PyTorch模型的多种导出方式提供给其他程序使用
flyfish
PyTorch模型的多种导出方式
1 模型可视化
以下使用模型可视化工具时netron
工具下载到本地
https://github.com/lutzroeder/netron/releases/
或者在使用
https://netron.app/
2 预训练模型
当下载一个预训练模型时,只是一个一个的module
3 ONNX模型导出有输入有输出
import torch
import torchvision
if __name__ == '__main__':
input = torch.randn(1, 3, 224, 224)
model = torchvision.models.resnet18()
model.load_state_dict(torch.load("resnet18-f37072fd.pth"))
model.eval()
torch.onnx.export(model, input, "a.onnx", training=torch.onnx.TrainingMode.TRAINING)
torch.onnx.export(model, input, "b.onnx", training=torch.onnx.TrainingMode.EVAL)
TRAINING导出方式
算子没有融合
EVAL导出方式
当采用EVAL方式进行模型导出的时候,Conv和BatchNorm层进行了合并
4 自定义输入输出的名字,并可批量推理
import torch
import torchvision
if __name__ == '__main__':
model = torchvision.models.resnet18()
model.load_state_dict(torch.load("resnet18-f37072fd.pth"))
model.eval()
batch_size = 4
input_data = torch.randn(batch_size, 3, 224, 224)
output_path = "c.onnx"
torch.onnx.export(model, input_data, output_path,
input_names=["input"], output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
5 导出JIT模型
JIT(Just-In-Time)
在Yolov5中叫torchscript
import torch
import torchvision
if __name__ == '__main__':
model = torchvision.models.resnet18()
model.load_state_dict(torch.load("resnet18-f37072fd.pth"))
model.eval()
input = torch.rand(1, 3, 224, 224)
jit_model = torch.jit.trace(model, input)
torch.jit.save(jit_model, 'resnet18_jit.trace.pth')
#script_model = torch.jit.script(model, input)
#torch.jit.save(script_model, 'resnet18_jit.script.pth')
本文使用的PyTorch版本 1.10.1