问题描述
假设我们有两个TensorFlow计算图G1
和G2
,并且权重为W1
和W2
.假设我们仅通过构造G1
和G2
来构建新图G
.我们如何为这个新图形G
恢复W1
和W2
?
Suppose we have two TensorFlow computation graphs, G1
and G2
, with saved weights W1
and W2
. Assume we build a new graph G
simply by constructing G1
and G2
. How can we restore both W1
and W2
for this new graph G
?
举一个简单的例子:
import tensorflow as tf
V1 = tf.Variable(tf.zeros([1]))
saver_1 = tf.train.Saver()
V2 = tf.Variable(tf.zeros([1]))
saver_2 = tf.train.Saver()
sess = tf.Session()
saver_1.restore(sess, 'W1')
saver_2.restore(sess, 'W2')
在此示例中,saver_1
成功恢复了相应的V1
,但是saver_2
失败,并显示NotFoundError
.
In this example, saver_1
succesfully restores the corresponding V1
, but saver_2
fails with a NotFoundError
.
推荐答案
您可能可以使用两个保护程序,其中每个保护程序仅查找变量之一.如果仅使用tf.train.Saver()
,我认为它将查找您已定义的所有变量.您可以使用tf.train.Saver([v1, ...])
为其提供要查找的变量列表.有关更多信息,您可以在此处阅读有关tf.train.Saver
构造函数的信息: https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver
You can probably use two savers where each saver looks for just one of the variables. If you just use tf.train.Saver()
, I think it will look for all variables you have defined. You can give it a list of variables to look for by using tf.train.Saver([v1, ...])
. For more info, you can read about the tf.train.Saver
constructor here: https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver
这是一个简单的工作示例.假设您在文件"save_vars.py"中进行计算,它具有以下代码:
Here's a simple working example. Suppose you do your computation in a file "save_vars.py" and it has the following code:
import tensorflow as tf
# Graph 1 - set v1 to have value [1.0]
g1 = tf.Graph()
with g1.as_default():
v1 = tf.Variable(tf.zeros([1]), name="v1")
assign1 = v1.assign(tf.constant([1.0]))
init1 = tf.initialize_all_variables()
save1 = tf.train.Saver()
# Graph 2 - set v2 to have value [2.0]
g2 = tf.Graph()
with g2.as_default():
v2 = tf.Variable(tf.zeros([1]), name="v2")
assign2 = v2.assign(tf.constant([2.0]))
init2 = tf.initialize_all_variables()
save2 = tf.train.Saver()
# Do the computation for graph 1 and save
sess1 = tf.Session(graph=g1)
sess1.run(init1)
print sess1.run(assign1)
save1.save(sess1, "tmp/v1.ckpt")
# Do the computation for graph 2 and save
sess2 = tf.Session(graph=g2)
sess2.run(init2)
print sess2.run(assign2)
save2.save(sess2, "tmp/v2.ckpt")
如果确保您具有tmp
目录并运行python save_vars.py
,则将获取已保存的检查点文件.
If you ensure that you have a tmp
directory and run python save_vars.py
, you'll get the saved checkpoint files.
现在,您可以使用名为"restore_vars.py"的文件通过以下代码进行还原:
Now, you can restore using a file named "restore_vars.py" with the following code:
import tensorflow as tf
# The variables v1 and v2 that we want to restore
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")
# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
saver1.restore(sess, "tmp/v1.ckpt")
saver2.restore(sess, "tmp/v2.ckpt")
print sess.run(v1)
print sess.run(v2)
,当您运行python restore_vars.py
时,输出应为
and when you run python restore_vars.py
, the output should be
[1.]
[2.]
(至少在我的计算机上是输出).如果有任何不清楚的地方,随时发表评论.
(at least on my computer that's the output). Feel free to post a comment if anything was unclear.
这篇关于TensorFlow:恢复多个图的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!