tensorflow保存模型有多种方法
第一种:saver.save(sess, "./hello_model") # 生成ckpt模型文件, hello_model.data-00000-of-00001 hello_model.index hello_model.meta
第二种:tf.train.write_graph(sess.graph_def, ./, 'hello.pb') # 生成hello.pb, 再通过freeze_graph把hello.pb与ckpt固化成新的pb文件
第三种:用tf.graph_util.convert_variables_to_constants把变量转成常量之后写入PB文件中
第四种:使用tf.saved_model.builder.SavedModelBuilder
具体看代码,及示例
保存模型的文件 saver_hello.py
点击(此处)折叠或打开
- import tensorflow as tf
- import sys
- import os
- # 把变量转成常量之后写入PB文件中
- def SaveFrozenPb(nodeNameList, pbFile):
- gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), nodeNameList)
- with tf.gfile.GFile(pbFile, 'wb') as f:
- f.write(gd.SerializeToString())
- # 通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件
- # freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
- # 如果不调用freeze_graph, 直接使用会报错‘google.protobuf.message.DecodeError: Error parsing message’
- def SavePbForFreezeGraph(pbDir, pbName):
- tf.train.write_graph(sess.graph_def, pbDir, pbName)
- def SaveBuilderPb(pbDir):
- if not os.path.exists(pbDir):
- os.makedirs(pbDir)
- builder = tf.saved_model.builder.SavedModelBuilder(pbDir)
- builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
- builder.save()
- if __name__ == '__main__':
- hello = tf.Variable(tf.constant('Hello World', name = "hello")) # 要save成功,需要tf.Variable, 否则会报错'ValueError: No variables to save'
- x = tf.placeholder(tf.float32, name="x")
- y = tf.multiply(x, 2, name="y")
- init = tf.global_variables_initializer()
- sess = tf.Session()
- sess.run(init)
- saver = tf.train.Saver()
- typeStr = sys.argv[1]
- if typeStr == 'ckpt' or typeStr == 'pbNotFrozen':
- saver.save(sess, "./hello_model", write_meta_graph=True) # hello_model.data-00000-of-00001 hello_model.index hello_model.meta
- if typeStr == 'pbNotFrozen':
- SavePbForFreezeGraph('./', 'hello.pb') # 需要经由freeze_graph工具处理
- elif typeStr == 'pbFrozen':
- SaveFrozenPb(['x', 'y', 'hello'], './hello_frozen.pb') # 无需再经由freeze_graph工具处理
- elif typeStr == 'builderPb':
- SaveBuilderPb('./save/')
加载模型文件restore_hello.py
点击(此处)折叠或打开
- import tensorflow as tf
- import sys
- def RestoreMeta(sess, name):
- #ckpt = tf.train.get_checkpoint_state('./')
- #restore = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
- #restore.restore(sess, ckpt.model_checkpoint_path)
- restore = tf.train.import_meta_graph(name)
- restore.restore(sess, "hello_model")
- def RestorePb(sess, name):
- # 二进制读取模型文件
- with tf.gfile.FastGFile(name, 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- sess.graph.as_default()
- tf.import_graph_def(graph_def, name='') # 导入计算图
- def RestoreBuilderPb(sess, pbDir):
- tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], pbDir)
- if __name__ == '__main__':
- sess = tf.Session()
- typeStr = sys.argv[1]
- if typeStr == 'ckpt':
- RestoreMeta(sess, 'hello_model.meta')
- elif typeStr == 'pbFrozen':
- RestorePb(sess, './hello_frozen.pb')
- elif typeStr == 'builderPb':
- RestoreBuilderPb(sess, './save/')
- x = tf.get_default_graph().get_tensor_by_name("x:0")
- y = tf.get_default_graph().get_tensor_by_name("y:0")
- hello = tf.get_default_graph().get_tensor_by_name("hello:0")
- print(sess.run(y, feed_dict={x:5})) # 10.0
- print(sess.run(hello)) # b'Hello World'
第一种:ckpt
保存模型
python3 ./saver_hello.py ckpt
生成checkpoint hello_model.data-00000-of-00001 hello_model.index hello_model.meta
加载模型
python3 ./restore_hello.py ckpt
运行结果
10.0
b'Hello World'
第二种:ckpt+pb+固化
python3 ./saver_hello.py pbNotFrozen
生成checkpoint hello_model.data-00000-of-00001 hello_model.index hello_model.meta hello.pb
固化
freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
加载
python3 ./restore_hello.py pbFrozen
第三种:固化的pb
保存
python3 ./saver_hello.py pbFrozen
加载
python3 ./restore_hello.py pbFrozen
第四种:
python3 ./saver_hello.py builderPb
python3 ./restore_hello.py builderPb