This article looks at how to extract the variables of a single scope with tf.get_collection; the question and answer below may be a useful reference for anyone facing the same problem.

Question

I have n (e.g., n=3) scopes and x (e.g., x=4) variables defined in each scope. The scopes are:

model/generator_0
model/generator_1
model/generator_2

Once I compute the loss, I want to extract and provide all the variables from only one of the scopes, based on a criterion evaluated at run time. Hence the index idx of the scope I select is an argmin tensor cast to int32:

<tf.Tensor 'model/Cast:0' shape=() dtype=int32>
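
For reference, such an index might be produced along these lines (a minimal sketch; per_scope_losses is a hypothetical rank-1 tensor holding the criterion value for each of the n scopes):

# Hypothetical: one criterion value per scope, e.g. shape (3,).
idx = tf.cast(tf.argmin(per_scope_losses, axis=0), tf.int32)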

I have tried:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string))

which obviously did not work. Is there any way to get all the x variables belonging to that particular scope using idx, so that they can be passed to the optimizer?
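
The concatenation cannot work: tf.get_collection matches its scope argument against variable names at the moment the call is made, so the argument has to be an ordinary Python string known at graph-construction time, whereas idx only receives a value once the session runs. With static strings, per-scope lookups of this form do work (a sketch using the scope names above):

# Gather each scope's trainable variables up front with static names.
gen_vars = [tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              scope='model/generator_%d' % i)
            for i in range(3)]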

Thanks in advance!

Vignesh Srinivasan

Answer

You can do something like this in TF 1.0 rc1 or later:

import tensorflow as tf

v = tf.Variable(tf.ones(()))
loss = tf.identity(v)
# Create the optimizer inside a named variable scope so that the
# variables it creates can later be filtered by that scope's name.
with tf.variable_scope('adamoptim') as vs:
    optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
print([v.name for v in optim_vars])  # => prints the list of vars created
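
Building on this, one way to honor the run-time idx from the question is to gather each scope's variables with a static string, build one minimize op per scope, and only run the op whose index matches idx. The following is a minimal sketch with hypothetical stand-in generators and losses, since the original model is not shown:

import tensorflow as tf

n, x = 3, 4  # from the question: n scopes, x variables per scope

# Hypothetical stand-ins for the real generators: x scalar weights per
# scope, plus a per-scope loss.
losses = []
for i in range(n):
    with tf.variable_scope('model/generator_%d' % i):
        ws = [tf.get_variable('w%d' % j, shape=(),
                              initializer=tf.random_normal_initializer())
              for j in range(x)]
        losses.append(tf.add_n([tf.square(w) for w in ws]))

# The argmin-based selector described in the question.
idx = tf.cast(tf.argmin(tf.stack(losses), axis=0), tf.int32)

# One train op per scope, each restricted to that scope's variables.
opt = tf.train.AdamOptimizer(learning_rate=0.1)
train_ops = [opt.minimize(losses[i],
                          var_list=tf.get_collection(
                              tf.GraphKeys.TRAINABLE_VARIABLES,
                              scope='model/generator_%d' % i))
             for i in range(n)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Evaluate the selector first, then run only the matching train op;
    # the scope choice happens in Python rather than inside the graph.
    chosen = sess.run(idx)
    sess.run(train_ops[chosen])

One caveat: anything stochastic feeding the losses is recomputed between the two session.run calls, so with feed-driven inputs the same feed_dict should be passed to both calls (the branching could instead be moved into the graph with tf.cond/tf.case, at the cost of considerably fiddlier control-flow handling).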

That concludes this look at extracting the variables of a scope with tf.get_collection; hopefully the answer above is helpful.
