I have converted a model from PyTorch to Keras and used the backend to extract the TensorFlow graph. Because PyTorch uses the NCHW data format, the extracted and saved model is NCHW as well. When I try to convert the model to TFLite, the conversion fails because of the NCHW format. Is there a way to convert the whole graph to NHWC?
It is better to have a graph whose data format already matches what TFLite expects (NHWC), since that gives faster inference. One approach is to manually insert transpose ops into the graph, as in this example from "How to convert the CIFAR10 tutorial to NCHW":
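import tensorflow as tf  # TF 1.x API (tf.ConfigProto, tf.Session)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
with tf.Session(config=config) as session:
    kernel = tf.ones(shape=[5, 5, 3, 64])    # filter in HWIO layout
    images = tf.ones(shape=[64, 24, 24, 3])  # input batch in NHWC layout
    imgs = tf.transpose(images, [0, 3, 1, 2])  # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format='NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1])  # NCHW -> NHWC
    # Note: NCHW conv2d generally needs a GPU kernel; on CPU this run may fail.
    print(session.run(conv).shape)  # (64, 24, 24, 64)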
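
The permutation lists passed to tf.transpose in the example are easy to get wrong, so it may help to derive them from the layout strings instead of writing them by hand. Below is a minimal pure-Python sketch; the helper names `layout_permutation` and `permute_shape` are hypothetical (they are not part of TensorFlow), and it assumes each layout is a string of unique axis letters. The OIHW -> HWIO line reflects the usual PyTorch and TensorFlow convolution kernel layouts, which you would also need to handle if you port weights by hand.

```python
def layout_permutation(src, dst):
    """Return perm so that transposing a `src`-layout tensor with perm
    gives a `dst`-layout tensor (output axis i = input axis perm[i])."""
    if len(set(src)) != len(src) or sorted(src) != sorted(dst):
        raise ValueError("layouts must contain the same unique axis letters")
    return [src.index(axis) for axis in dst]


def permute_shape(shape, perm):
    """Apply a transpose permutation to a shape list."""
    return [shape[i] for i in perm]


# Activations: the two permutations used in the example.
nhwc_to_nchw = layout_permutation("NHWC", "NCHW")  # [0, 3, 1, 2]
nchw_to_nhwc = layout_permutation("NCHW", "NHWC")  # [0, 2, 3, 1]

# Conv kernels: PyTorch stores OIHW, TensorFlow expects HWIO.
oihw_to_hwio = layout_permutation("OIHW", "HWIO")  # [2, 3, 1, 0]

# Shape check for the input batch from the example.
nchw_shape = permute_shape([64, 24, 24, 3], nhwc_to_nchw)  # [64, 3, 24, 24]
```

The resulting lists can be passed directly as the perm argument of tf.transpose (or numpy.transpose), which keeps the intent readable as "NCHW to NHWC" rather than a bare list of indices.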
