问题描述
我尝试了 tf.Graph() 但无法通过 new 重置变量.代码如下:
I tried tf.Graph() but can't get the variable to reset by new. The code is below:
with tf.Graph().as_default() as g:
clf_ = tf.estimator.Estimator(model_fn=my_w2d.model_fn_wide2deep, params=param, model_dir="/Users/zhouliaoming/data/credit_dnn/model_retrain/rm_gene_v2_sall/")
with tf.name_scope("rewrite"):
clf2 = tf.estimator.Estimator(model_fn=my_w2d.model_fn_wide2deep, params=param, model_dir="/Users/zhouliaoming/data/credit_dnn/model_retrain/genev2_s0/")
out_bias = tf.get_variable("output_0/bias")
out_b_rew = tf.get_variable("rewrite/output_0/bias")
vars_ = clf_.get_variable_names() ## only has clf_.get_variable_values()
print("vars: %r\n output_0/bias: %r\ntrain-vars: %r" % (vars_, clf_.get_variable_value('output_0/bias'), tf.contrib.framework.get_trainable_variables()))
print("before rewrite: out_bias: %r, out_b_rew: %r" % (out_bias.eval(), out_b_rew.eval()))
out_b_rew.assing(out_bias)
print("after rewrite: out_bias: %r, out_b_rew: %r" % (out_bias.eval(), out_b_rew.eval()))
它只是返回错误:
Traceback (most recent call last):
File "tf_utils.py", line 31, in <module>
out_bias = tf.get_variable("output_0/bias")
File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1262, in get_variable
constraint=constraint)
File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1097, in get_variable
constraint=constraint)
File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 435, in get_variable
constraint=constraint)
File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 404, in _true_getter
use_resource=use_resource, constraint=constraint)
File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 764, in _get_single_variable
"but instead was %s." % (name, shape))
ValueError: Shape of a new variable (output_0/bias) must be fully defined, but instead was <unknown>.
==============旧信息截线==========
=============== old infomation cut line =========
我通过 model_fn 处理程序定义了一个 tf.estimator.Estimator 模型 A.我想通过与 ckpt 文件相同的旧模型参数更改模型 A 的参数.我尝试获取模型 A 的图形,然后获取图形中参数的变量,然后通过我的旧模型的参数对其进行分配.希望给点建议!非常感谢!
I defined a tf.estimator.Estimator model A by model_fn handler. I want to change model A's parameter by same old model's parameters as ckpt file.I try to get model A's graph and then get the parameter's variable in Graph and then assigned it by my old model's parameter.Hope some advices! Thanks very much!
推荐答案
有很多方法可以做到这一点,具体取决于您可用的方法.例如,如果您拥有来自两个模型的代码和检查点,您可以创建两个单独的图(with tf.Graph() as g
)将两个检查点加载到其中,从其中一个读取变量值图形并将其分配给另一个图形中的变量.
There are many ways of doing this, depending on exactly what you have available. For example, if you have the code and checkpoints from both models, you can create two separate graphs (with tf.Graph() as g
) load the two checkpoints into them, read the variable values from one graph and assign it to a variable in another graph.
如果您确切地知道要在一个检查点中读取的变量,您可以只恢复它(Saver.restore
需要一个要恢复的变量列表),或者您可以使用类似的工具读取它CheckpointReader
If you know exactly the variable you want to read in one checkpoint, you can restore just it (Saver.restore
takes a list of variables to restore), or you can read it using tools like CheckpointReader
这篇关于如何重置 tf.estimator.Estimator 参数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!