import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
def main():
transform = transforms.Compose(
[transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))
im = Image.open('1.jpg')
im = transform(im) # [C, H, W]
im = torch.unsqueeze(im, dim=0) # [N, C, H, W]
with torch.no_grad():
outputs = net(im)
predict = torch.max(outputs, dim=1)[1].numpy()
print(classes[int(predict)])
if __name__ == '__main__':
main()
下面逐行进行分析:
import torch: 导入PyTorch库,这是一个用于深度学习的开源库。
import torchvision.transforms as transforms: 导入PyTorch的图像处理模块,并简写为transforms。这个模块提供了许多图像预处理的功能。
from PIL import Image: 从Python Imaging Library (PIL)中导入Image模块,这是一个用于图像处理的库。
from model import LeNet: 从model模块中导入LeNet类
transform = transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
因为我们训练的网络的输入统一都是32*32的,但是我们用来测试的图片尺寸是随机的,因此输入前要做统一的尺寸裁剪,并将其转换为Tensor数据格式并进行归一化处理
classes = (‘plane’, ‘car’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’): 定义一个包含所有类别的列表。
net = LeNet(): 创建一个LeNet模型实例。
net.load_state_dict(torch.load(‘Lenet.pth’)): 从文件’Lenet.pth’中加载预训练的模型权重。
im = Image.open(‘1.jpg’): 使用PIL库打开名为’1.jpg’的图像文件。
im = transform(im): 对图像应用之前定义的转换流程。
im = torch.unsqueeze(im, dim=0): 在新的维度上增加一个维度,这样模型就可以处理单个图像了。
with torch.no_grad():: 在这个代码块中,我们告诉PyTorch不要计算任何梯度,因为我们只是进行前向传播,不需要反向传播和优化。
outputs = net(im): 将处理过的图像输入到LeNet模型中,得到输出结果。
predict = torch.max(outputs, dim=1)[1].numpy(): 从模型的输出中找到最大值的索引,这对应于预测的类别。
print(classes[int(predict)]): 打印预测的类别。