我用NCHW
数据格式在gpu上训练了一个小型cnn,现在我想导出一个.pb
文件,然后可以在其他应用程序中使用它进行推理。
我编写了一个小的helper函数,在给定包含检查点文件和graph.pbtxt的目录下,用默认值调用Tensorflow的freeze_graph
函数:
import os
import argparse
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
dir(tf.contrib) #fix for tf.contrib undefined ops bug
from tensorflow.python.tools.freeze_graph import freeze_graph
def my_freeze_graph_2(model_dir, output_node_names):
"""Extract the sub graph defined by the output nodes and convert
all its variables into constant
Args:
model_dir: the root folder containing the checkpoint state file
output_node_names: a string, containing all the output node's names,
comma separated
"""
if not tf.gfile.Exists(model_dir):
raise AssertionError(
"Export directory doesn't exists. Please specify an export "
"directory: %s" % model_dir)
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path
# We precise the file fullname of our freezed graph
absolute_model_dir = os.path.abspath(model_dir)
output_graph = os.path.join(absolute_model_dir, "frozen_model.pb")
freeze_graph(input_graph=os.path.join(model_dir, 'graph.pbtxt'),
input_saver='',
input_binary=False,
input_checkpoint=input_checkpoint,
output_node_names=output_node_names,
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
output_graph=output_graph,
clear_devices=True,
initializer_nodes='')
然后,我有一个小脚本,尝试从
frozen_model.pb
构建图表,以测试冻结是否有效:import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import argparse
import tensorflow as tf
from freeze_graph import load_graph
import cv2
if __name__ == '__main__':
# Let's allow the user to pass the filename as an argument
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="model-multiple_starts/frozen_model.pb", type=str, help="Frozen model file to import")
args = parser.parse_args()
# We use our "load_graph" function
graph = load_graph(args.frozen_model_filename)
# We can verify that we can access the list of operations in the graph
for op in graph.get_operations():
print(op.name)
# We access the input and output nodes
x = graph.get_tensor_by_name('prefix/Reshape:0')
y = graph.get_tensor_by_name('prefix/softmax_tensor:0')
# We launch a Session
with tf.Session(graph=graph, config=tf.ConfigProto(log_device_placement=True)) as sess:
# Note: we don't nee to initialize/restore anything
# There is no Variables in this graph, only hardcoded constants
# Load an image to use as test
im = cv2.imread('57_00000000.png', cv2.IMREAD_GRAYSCALE)
im = im.T
im = im / 255 - 0.5
im = im[None,:,:,None]
y_out = sess.run(y, feed_dict={
x: im
})
print(y_out)
如果尝试运行测试脚本,则会出现以下错误:
InvalidArgumenterRor:CPU BIASOP仅支持NHWC。[[节点:
前缀/conv2d/biasadd=biasadd[t=dt_float,data_format=“nchw”,
_device=“/job:localhost/replica:0/task:0/cpu:0”](前缀/conv2d/卷积,
前缀/conv2d/bias/read)]]
我尝试了不同的配置:
仅从CPU脚本生成.pb文件,仅在CPU上运行
从GPU可见的脚本生成.pb文件,在GPU可见的情况下运行
从纯CPU脚本生成.pb文件,在GPU可见的情况下运行
他们都犯了同样的错误。
问题在于,我要冻结的检查点具有用
data_format='NCHW'
定义的操作。如何使用NHWC
数据格式冻结检查点?更新:
拨开文件,我看到在
graph.pbtxt
中许多操作data_format
都被硬编码为NCHW
。我想,然后,我需要用NHWC
格式创建一个新模型,有选择地从检查点加载层的权重,并使用该图手动保存.pb
文件。。。我想已经有一个过程可以做到这一点了,但是我找不到关于这个的任何文档,也找不到示例。
更新2:
在尝试导入OpenCV的DNN模块中的
.pb
文件之后,我发现了以下内容:将数据格式为NCHW和数据格式为NHWC的培训中的检查点冻结在一起会导致一个不可用的
graph.pbtxt
文件。我还没有找到确切的原因,但是将.pb
转换为.pb
并将其与工作冻结图进行比较,文件只在weights和biases常量中存储的值不同。使用数据格式nhwc将来自训练和
.pbtxt
的检查点冻结在一起会生成一个工作冻结图。这样看来,检查点在不同数据格式的图之间是不可转换的(即使在冻结过程中没有出现错误或警告)。
最佳答案
通常,您需要将图形构造包装在函数中,以便可以根据预测情况有条件地重新生成图形,因为通常有相当多的图形片段从训练更改为预测正如您所发现的,例如卷积层的NCHW
和NWHC
版本实际上是图形原型中的不同操作,它们是以这种方式硬编码的,因为gpu优化只可能用于其中一种格式。
编辑图形原型非常困难,这就是为什么执行此操作的大多数TensorFlow代码都遵循上面描述的模式在很高的层次上:
def build_graph(data_format='NCHW'):
# Conditionally use proper ops based on data_format arg
training_graph = tf.Graph()
with training_graph.as_default():
build_graph(data_format='NCHW')
with tf.Session() as sess:
# train
# checkpoint session
prediction_graph = tf.Graph()
with prediction_graph.as_default():
build_graph(data_format='NHWC')
# load checkpoint
# freeze graph
注意,
tf.estimator.Estimator
框架使这相对容易您可以使用mode
中的model_fn
参数来决定数据格式,然后使用两种不同的input_fn
进行训练和预测,框架将完成其余工作。您可以在这里找到一个端到端的示例:https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/cifar10_main.py#L77(我已链接到相关行)关于python - 卡住具有不同数据格式的图形,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/47014306/