问题描述
我正在编写一个python脚本,该脚本将任何流行的框架(TensorFlow,Keras,PyTorch)的深度学习模型转换为ONNX格式。目前,我已将用于tensorflow和可以将keras转换为ONNX,并且可以使用。
I am writing a python script, which converts any deep learning models from popular frameworks (TensorFlow, Keras, PyTorch) to ONNX format. Currently I have used tf2onnx for tensorflow and keras2onnx for keras to ONNX conversion, and those work.
现在PyTorch集成了ONNX支持,因此我可以直接从PyTorch保存ONNX模型。但是问题是我需要为该模型输入张量形状,以便将其保存为ONNX格式。您可能已经猜到了,我正在编写此脚本来转换未知的深度学习模型。
Now PyTorch has integrated ONNX support, so I can save ONNX models from PyTorch directly. But the problem is I will need input tensor shape for that model, in order to save it in ONNX format. As you already might have guessed, I am writing this script to convert unknown deep learning models.
是PyTorch的ONNX转换教程。上面写着:
Here is PyTorch's tutorial for ONNX conversion. There it's written:
我正在使用的代码段是这个:
The code snippet I am using is this:
import torch
def convert_pytorch2onnx(self):
"""pytorch -> onnx"""
model = torch.load(self._model_file_path)
# Don't know how to get this INPUT_SHAPE
dummy_input = torch.randn(INPUT_SHAPE)
torch.onnx.export(model, dummy_input, self._onnx_file_path)
return
那么我怎么知道那个未知的PyTorch模型的输入张量的INPUT_SHAPE?还是有其他方法可以将PyTorch模型转换为ONNX?
So how do I know the INPUT_SHAPE of the input tensor of that unknown PyTorch model? Or is there any other way to convert the PyTorch model to ONNX?
推荐答案
您可以以此为起点进行调试
you can follow this as a starting point to debug
list(model.parameters())[0].shape # weights of the first layer in the format (N,C,Kernel dimensions) # 64, 3, 7 ,7
之后,得到N,C并通过专门放置H来创建张量,w像这个玩具示例一样,没有像
after that get the N,C and create a tensor out of that by specially putting H,W as None like this toy example
import torch
import torchvision
net = torchvision.models.resnet18(pretrained = True)
shape_of_first_layer = list(net.parameters())[0].shape #shape_of_first_layer
N,C = shape_of_first_layer[:2]
dummy_input = torch.Tensor(N,C)
dummy_input = dummy_input[...,:, None,None] #adding the None for height and weight
torch.onnx.export(net, dummy_input, './alpha')
这篇关于如何获取未知PyTorch模型的输入张量形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!