Problem description
I have a trained, frozen graph that I am trying to run on an ARM device. Basically, I am using contrib/pi_examples/label_image, but with my network instead of Inception. My network was trained with dropout, which is now causing trouble:
Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered kernels:
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_INT32]
device='GPU'; T in [DT_STRING]
device='GPU'; T in [DT_BOOL]
device='GPU'; T in [DT_INT32]
device='GPU'; T in [DT_FLOAT]
[[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]
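The kernel list in the message explains the failure. This build registers `Switch` on CPU only for `DT_FLOAT` and `DT_INT32`, while the dropout `tf.cond` switches on the bool placeholder `is_training_pl`, which needs `Switch[T=DT_BOOL]` (listed only for GPU). The sketch below replays that lookup in plain Python; the helper `missing_cpu_kernels` and the node list are hypothetical, and only the kernel table is copied from the message.

```python
# CPU kernels registered for Switch in this build, copied from the error message.
REGISTERED_CPU_KERNELS = {
    ("Switch", "DT_FLOAT"),
    ("Switch", "DT_INT32"),
}


def missing_cpu_kernels(nodes, registered=REGISTERED_CPU_KERNELS):
    """Return names of nodes whose (op, T) pair has no CPU kernel.

    `nodes` is a list of (name, op, dtype) tuples. Only ops that appear in
    `registered` are checked; every other op is assumed to be supported.
    """
    checked_ops = {op for op, _ in registered}
    return [
        name
        for name, op, dtype in nodes
        if op in checked_ops and (op, dtype) not in registered
    ]


# Hypothetical excerpt of the frozen graph: the dropout cond switches on the
# bool placeholder, so it needs Switch[T=DT_BOOL].
nodes = [
    ("l_fc1/Relu", "Relu", "DT_FLOAT"),
    ("l_fc1_dropout/cond/Switch", "Switch", "DT_BOOL"),
    ("prediction/MatMul", "MatMul", "DT_FLOAT"),
]
print(missing_cpu_kernels(nodes))  # ['l_fc1_dropout/cond/Switch']
```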
One solution I can see is to build a TF static library that includes the kernel for this operation. On the other hand, it might be a better idea to eliminate the dropout ops from the network to make it simpler and faster. Is there a way to do that?
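Removing dropout should not change inference results, assuming the layer was built with `tf.nn.dropout`: that op scales the kept activations by `1 / keep_prob` during training, so the non-training branch of the `tf.cond` presumably just passes its input through. Eliminating it then amounts to deleting the nodes in the dropout scope and pointing their consumer back at the scope's input. A minimal sketch of that rewiring on a toy graph stored as a dict of node name to input names; `bypass_scope` is a hypothetical helper, and the graph's internal wiring is simplified.

```python
def bypass_scope(graph, scope, scope_output, replacement):
    """Remove every node under `scope` and rewire its consumers.

    graph:        dict mapping node name -> list of input names
    scope:        name scope to drop, e.g. "l_fc1_dropout"
    scope_output: the scope node that downstream nodes read from
    replacement:  node to read from instead (the scope's original input)
    """
    prefix = scope + "/"
    stripped = {}
    for name, inputs in graph.items():
        if name.startswith(prefix):
            continue  # drop the dropout sub-graph entirely
        # Any input that pointed at the scope's output now reads the bypass node.
        stripped[name] = [replacement if i == scope_output else i for i in inputs]
    return stripped


# Toy graph shaped like the question's: Relu -> dropout cond -> MatMul.
graph = {
    "is_training_pl": [],
    "l_fc1/Relu": [],
    "l_fc1_dropout/cond/Switch": ["is_training_pl", "is_training_pl"],
    "l_fc1_dropout/cond/Merge": ["l_fc1_dropout/cond/Switch", "l_fc1/Relu"],
    "prediction/weights": [],
    "prediction/MatMul": ["l_fc1_dropout/cond/Merge", "prediction/weights"],
}
out = bypass_scope(graph, "l_fc1_dropout", "l_fc1_dropout/cond/Merge", "l_fc1/Relu")
print(out["prediction/MatMul"])  # ['l_fc1/Relu', 'prediction/weights']
```

The placeholder `is_training_pl` survives here with no consumers; the answer's script deletes it explicitly.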
Thanks.
Recommended answer
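```python
#!/usr/bin/env python
# Strips the dropout sub-graph (and the bool is_training placeholder that
# drives its tf.cond) from a frozen TensorFlow 1.x GraphDef, so inference no
# longer needs a Switch[T=DT_BOOL] kernel.
from __future__ import print_function

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2


def print_graph(input_graph):
    for node in input_graph.node:
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))


def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    """Return a copy of input_graph without the nodes under drop_scope.

    drop_scope:   name scope of the dropout layer (all its nodes are removed)
    input_before: node whose output fed the dropout layer
    output_after: node that consumed the dropout output; its input is rewired
                  from '<drop_scope>/cond/Merge' to input_before
    pl_name:      the is_training placeholder, which is removed as well
    """
    nodes_after_strip = []
    for node in input_graph.node:
        # Skip every node inside the dropout scope, plus the placeholder.
        if node.name.startswith(drop_scope + '/'):
            continue
        if node.name == pl_name:
            continue
        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            # Bypass dropout: feed the layer's input straight into its consumer.
            new_input = []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)
    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-graph', action='store', dest='input_graph', required=True)
    # Binary is the default. store_true with default=True alone can never be
    # switched off, so --input-text/--output-text select text (.pbtxt) format.
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
    parser.add_argument('--input-text', action='store_false', dest='input_binary')
    parser.add_argument('--output-graph', action='store', dest='output_graph', required=True)
    parser.add_argument('--output-binary', action='store_true', default=True, dest='output_binary')
    parser.add_argument('--output-text', action='store_false', dest='output_binary')
    args = parser.parse_args()

    input_graph = args.input_graph
    input_binary = args.input_binary
    output_graph = args.output_graph
    output_binary = args.output_binary

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            # as_text accepts both bytes (Python 2) and str (Python 3).
            text_format.Merge(tf.compat.as_text(f.read()), input_graph_def)

    print("Before:")
    print_graph(input_graph_def)
    # Node names are specific to this network: the dropout scope l_fc1_dropout
    # sits between l_fc1/Relu and prediction/MatMul, and its tf.cond is driven
    # by the bool placeholder is_training_pl. Adjust them for other models.
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu',
                             u'prediction/MatMul', u'is_training_pl')
    print("After:")
    print_graph(output_graph_def)

    if output_binary:
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(output_graph, "w") as f:
            f.write(text_format.MessageToString(output_graph_def))
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()
```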
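The script only rewires the inputs of `output_after` that point at `<drop_scope>/cond/Merge`, so any other reference into the removed scope (for example a control dependency written as `^name`) would be left dangling and the graph would be invalid. A quick check for that, written against plain `(name, inputs)` pairs so it can be fed `[(n.name, list(n.input)) for n in output_graph_def.node]`; `node_name` and `dangling_inputs` are hypothetical helpers, and the sample graph is illustrative.

```python
def node_name(ref):
    """Strip the control-dependency prefix '^' and the output suffix ':N'."""
    ref = ref.lstrip("^")
    return ref.split(":", 1)[0]


def dangling_inputs(nodes):
    """Return (consumer, input) pairs whose input has no producing node.

    `nodes` is an iterable of (name, inputs) pairs.
    """
    nodes = list(nodes)
    names = {name for name, _ in nodes}
    return [
        (name, ref)
        for name, inputs in nodes
        for ref in inputs
        if node_name(ref) not in names
    ]


# A stripped graph in which one node still has a control dependency on a
# removed dropout node.
stripped = [
    ("l_fc1/Relu", []),
    ("prediction/weights", []),
    ("prediction/MatMul", ["l_fc1/Relu", "prediction/weights:0"]),
    ("prediction/Softmax", ["prediction/MatMul", "^l_fc1_dropout/cond/Merge"]),
]
print(dangling_inputs(stripped))
# [('prediction/Softmax', '^l_fc1_dropout/cond/Merge')]
```

An empty list means every remaining input is produced by a node that survived the strip.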