在这里,我使用的是一个.pth文件的pytorch segnet实现,该文件包含50个阶段的训练的权重。
如何加载单个测试图像并查看网络预测?
我知道这听起来像个愚蠢的问题,但我被卡住了。
我得到的是:
from segnet import SegNet
import torch
model = SegNet(2)
model.load_state_dict(torch.load('./model_segnet_epoch50.pth'))
如何在一张测试图片上“使用”网络?
最佳答案
output = model(image)
.
注意,图像应该是一个Variable
对象,输出也应该是。
例如,如果您的图像是一个Numpy数组,您可以这样转换它:var_image = Variable(torch.Tensor(image))