我在将tf.scatter_nd_add()
应用于2D张量时遇到困难。该文档有点不清楚,并且没有包含稀疏更新的示例,而仅包含完整切片的更新。
我的情况如下:updates
-形状为[None, 6]
的2D张量indices
-形状为[None, 6]
的2D张量ref
-形状为[None, 6]
的零的2D变量
可以保证updates
,indices
和ref
的第一维始终相等,但是该维的大小可以变化。我要执行的更新看起来像
for i, j:
k = indices[i][j]
ref[i][k] += updates[i][j]
请注意,
indices
包含重复项。 tf.scatter_nd_add(ref, indices, updates)
抱怨形状不匹配,我无法弄清楚如何重组张量才能执行更新。 最佳答案
我想到了。 indices
中的每个2D条目实际上必须指定将在ref
中更新的绝对位置。这意味着indices
必须为3D,然后非矢量化更新如下所示:
for i, j:
r, k = indices[i][j]
ref[r][k] += updates[i][j]
在上述问题中,恰巧
r
始终等于i
。这是具有不同形状的完整Tensorflow实现。为了清楚起见,在下面的示例中,
col_indices
对应于原始问题的indices
:import tensorflow as tf
import numpy as np
updates = tf.placeholder(dtype=tf.float32, shape=[None, 6])
col_indices = tf.placeholder(dtype=tf.int32, shape=[None, 6])
row_indices = tf.cumsum(tf.ones_like(col_indices), axis=0, exclusive=True)
indices = tf.concat([tf.expand_dims(row_indices, axis=-1),
tf.expand_dims(col_indices, axis=-1)], axis=-1)
tmp_var = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
ref = tf.assign(tmp_var, tf.zeros_like(updates), validate_shape=False)
# This makes sure that ref is always 0 before scatter_nd_add() runs
with tf.control_dependencies([target_var]):
result = tf.scatter_nd_add(ref, indices, updates)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Create example input data
np_input = np.arange(0, 6, 1, dtype=np.int32)
np_input = np.tile(np_input[None,:], [10, 1])
res = sess.run(result, feed_dict={updates: np_input, col_indices: np_input})
print(res)
关于python - 在TensorFlow中稀疏添加的scatter_nd add()示例,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/46940302/