问题描述
我已经将模型从PyTorch转换为Keras,并使用后端提取了张量流图.由于PyTorch的数据格式为NCHW,因此提取并保存的模型也是如此.将模型转换为TFLite时,由于格式为NCHW,因此无法转换.有没有办法将整个图转换为NHCW?
I have converted a model from PyTorch to Keras and used the backend to extract the tensorflow graph. Since the data format for PyTorch was NCHW, the model extracted and saved is also that. While converting the model to TFLite, due to the format being NCHW, it cannot get converted. Is there a way to convert the whole graph into NHCW?
推荐答案
最好让图的数据格式与TFLite匹配,以加快推理速度.一种方法是手动将转置ops插入到图形中,例如以下示例:如何将CIFAR10教程转换为NCHW
It is better to have a graph with the data-format matched to TFLite for faster inference. One approach is to manually insert transpose ops into the graph, like this example:How to convert the CIFAR10 tutorial to NCHW
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as session:
kernel = tf.ones(shape=[5, 5, 3, 64])
images = tf.ones(shape=[64,24,24,3])
imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC
print("conv=",conv.eval())
这篇关于将预训练的已保存模型从NCHW转换为NHWC以使其与Tensorflow Lite兼容的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!