Removing dropout operations from a TensorFlow graph

Question

I have a trained, frozen graph that I am trying to run on an ARM device. Basically, I am using contrib/pi_examples/label_image, but with my own network instead of Inception. My network was trained with dropout, which is now causing trouble:

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs.  Registered kernels:
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_INT32]
  device='GPU'; T in [DT_STRING]
  device='GPU'; T in [DT_BOOL]
  device='GPU'; T in [DT_INT32]
  device='GPU'; T in [DT_FLOAT]

 [[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]

One solution I can see is to build a TF static library that includes the corresponding operation. On the other hand, it might be a better idea to eliminate the dropout ops from the network, which would make it simpler and faster. Is there a way to do that?

Thanks.

Recommended answer

#!/usr/bin/env python2

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def print_graph(input_graph):
    # Dump every node as "name : op ( inputs )" for before/after inspection.
    for node in input_graph.node:
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))

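def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    # Remove every node under drop_scope plus the pl_name placeholder, then
    # feed output_after directly from input_before instead of the dropout Merge.
    input_nodes = input_graph.node
    nodes_after_strip = []
    for node in input_nodes:
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))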

        if node.name.startswith(drop_scope + '/'):
            continue

        if node.name == pl_name:
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            new_input = []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-graph', action='store', dest='input_graph')
    # Both *-binary flags use store_true with default=True, so they are always
    # True and the text-format branches below cannot be reached from the CLI.
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
    parser.add_argument('--output-graph', action='store', dest='output_graph')
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True)

    args = parser.parse_args()

    input_graph = args.input_graph
    input_binary = args.input_binary
    output_graph = args.output_graph
    output_binary = args.output_binary

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)

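    print("Before:")
    print_graph(input_graph_def)
    # The scope, node, and placeholder names below are specific to the graph in
    # the question; replace them with the names from your own graph.
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl')
    print("After:")
    print_graph(output_graph_def)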

    if output_binary:
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(output_graph, "w") as f:
            f.write(text_format.MessageToString(output_graph_def))
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()
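
The rewiring that the script performs can be tried out without TensorFlow. The following is only a minimal sketch, not part of the original answer: `Node` is a hypothetical stand-in that carries just the `name`, `op`, and `input` fields, the toy graph reuses the node names from the question plus an assumed `prediction/weights` constant, and `strip_scope` and `dangling_inputs` are illustrative helper names. Unlike `strip` above, the sketch rewires every consumer of the dropout `Merge` node rather than a single named one, and it adds a check for inputs that point at nodes which no longer exist.

```python
from collections import namedtuple

# Hypothetical stand-in for a GraphDef node: only the fields the logic reads.
Node = namedtuple("Node", ["name", "op", "input"])


def producer(input_name):
    """Return the node name behind an input reference such as '^a' or 'a:1'."""
    return input_name.lstrip("^").split(":")[0]


def strip_scope(nodes, drop_scope, input_before, pl_name):
    """Drop all nodes under drop_scope and the pl_name placeholder, and point
    every consumer of the scope's Merge output at input_before instead."""
    merge_name = drop_scope + "/cond/Merge"
    kept = []
    for node in nodes:
        if node.name.startswith(drop_scope + "/") or node.name == pl_name:
            continue
        new_inputs = [input_before if producer(i) == merge_name else i
                      for i in node.input]
        kept.append(Node(node.name, node.op, new_inputs))
    return kept


def dangling_inputs(nodes):
    """List (consumer, input) pairs whose producing node is not in nodes."""
    names = {n.name for n in nodes}
    return [(n.name, i) for n in nodes for i in n.input
            if producer(i) not in names]


# Toy graph; 'prediction/weights' is an assumed name used for illustration.
toy = [
    Node("is_training_pl", "Placeholder", []),
    Node("l_fc1/Relu", "Relu", []),
    Node("l_fc1_dropout/cond/Switch", "Switch",
         ["is_training_pl", "is_training_pl"]),
    Node("l_fc1_dropout/cond/Merge", "Merge", ["l_fc1_dropout/cond/Switch"]),
    Node("prediction/MatMul", "MatMul",
         ["l_fc1_dropout/cond/Merge", "prediction/weights"]),
    Node("prediction/weights", "Const", []),
]

stripped = strip_scope(toy, "l_fc1_dropout", "l_fc1/Relu", "is_training_pl")
print([n.name for n in stripped])
print(dangling_inputs(stripped))
```

Because both helpers only read `name` and `input`, the same dangling-input check should also work on the `node` list of a real `GraphDef`. Running it on the stripped graph before writing the file is a cheap way to confirm that no remaining node still references the removed scope or placeholder.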

