I'm trying to use TensorRT to speed up inference for yolov3 TF2.
I'm using the TrtGraphConverter function in TensorFlow 2.

My code is essentially this:



import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Put Keras in inference mode before converting.
tf.keras.backend.set_learning_phase(0)
converter = trt.TrtGraphConverter(
    input_saved_model_dir="./tmp/yolosaved/",
    precision_mode="FP16",
    is_dynamic_op=True)
converter.convert()

saved_model_dir_trt = "./tmp/yolov3.trt"
converter.save(saved_model_dir_trt)  # the error below is raised here





It produces the following error:



Traceback (most recent call last):
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 427, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 1 of node StatefulPartitionedCall was passed float from conv2d/kernel:0 incompatible with expected resource.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pierre/Documents/GitHub/yolov3-tf2/tensorrt.py", line 23, in <module>
    converter.save(saved_model_dir_trt)
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 822, in save
    super(TrtGraphConverter, self).save(output_saved_model_dir)
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 432, in save
    importer.import_graph_def(self._converted_graph_def, name="")
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 1 of node StatefulPartitionedCall was passed float from conv2d/kernel:0 incompatible with expected resource.





Does this mean that some of my nodes can't be converted? If so, why does my code fail at the .save step?

Best answer

I eventually solved the problem with the following code. I also switched from tf 2.0.0-beta0 to tf-nightly-gpu-2.0-preview.



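from tensorflow.python.compiler.tensorrt import trt_convert as trt

# SavedModel exported from the original model (path taken from the question).
saved_model_dir = "./tmp/yolosaved/"

# Start from the default conversion parameters and override only what differs.
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode='FP16',
    is_dynamic_op=True)

# TrtGraphConverterV2 is the converter for TF 2.x SavedModels.
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=saved_model_dir,
    conversion_params=params)
converter.convert()
saved_model_dir_trt = "/tmp/model.trt"
converter.save(saved_model_dir_trt)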
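
For completeness, here is a rough sketch of how the saved result might be loaded back and run. It is not part of the original answer: the helper names `looks_like_saved_model` and `run_converted_model` are made up for illustration, and it assumes the model was exported with a "serving_default" signature that takes a single input tensor, which may not match your export.

```python
import os


def looks_like_saved_model(path):
    """Return True if `path` contains the saved_model.pb file of a SavedModel."""
    return os.path.isfile(os.path.join(path, "saved_model.pb"))


def run_converted_model(saved_model_dir_trt, batch):
    """Load the converted SavedModel and run one batch through it.

    Assumption (not from the original post): the export has a
    "serving_default" signature with exactly one input tensor.
    """
    # Imported lazily so the path check above can be used without TensorFlow.
    import tensorflow as tf

    if not looks_like_saved_model(saved_model_dir_trt):
        raise FileNotFoundError("no saved_model.pb in " + saved_model_dir_trt)

    loaded = tf.saved_model.load(saved_model_dir_trt)
    infer = loaded.signatures["serving_default"]

    # Signature functions take keyword arguments named after the model inputs,
    # so look up the (assumed single) input name instead of hard-coding it.
    input_name = next(iter(infer.structured_input_signature[1]))

    # With is_dynamic_op=True the TensorRT engines are built at runtime, so the
    # first call can be noticeably slower than later ones.
    return infer(**{input_name: tf.constant(batch)})
```

Calling `run_converted_model("/tmp/model.trt", images)` with a float image batch of the shape the network expects should return a dict of output tensors; the output names depend on how the model was exported.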





Thanks for your help.

Regarding tensorflow - TensorRT and Tensorflow 2, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/57117397/
