我在将tf.scatter_nd_add()应用于2D张量时遇到困难。该文档有点不清楚,并且没有包含稀疏更新的示例,而仅包含完整切片的更新。

我的情况如下:


updates-形状为[None, 6]的2D张量
indices-形状为[None, 6]的2D张量
ref-形状为[None, 6]的零的2D变量


可以保证updatesindicesref的第一维始终相等,但是该维的大小可以变化。我要执行的更新看起来像

for i, j:
    k = indices[i][j]
    ref[i][k] += updates[i][j]


请注意,indices包含重复项。 tf.scatter_nd_add(ref, indices, updates)抱怨形状不匹配,我无法弄清楚如何重组张量才能执行更新。

最佳答案

我想到了。 indices中的每个2D条目实际上必须指定将在ref中更新的绝对位置。这意味着indices必须为3D,然后非矢量化更新如下所示:

for i, j:
    r, k = indices[i][j]
    ref[r][k] += updates[i][j]


在上述问题中,恰巧r始终等于i

这是具有不同形状的完整Tensorflow实现。为了清楚起见,在下面的示例中,col_indices对应于原始问题的indices

import tensorflow as tf
import numpy as np

updates     = tf.placeholder(dtype=tf.float32,  shape=[None, 6])
col_indices = tf.placeholder(dtype=tf.int32,    shape=[None, 6])
row_indices = tf.cumsum(tf.ones_like(col_indices), axis=0, exclusive=True)
indices     = tf.concat([tf.expand_dims(row_indices, axis=-1),
                         tf.expand_dims(col_indices, axis=-1)], axis=-1)

tmp_var     = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
ref         = tf.assign(tmp_var, tf.zeros_like(updates), validate_shape=False)
# This makes sure that ref is always 0 before scatter_nd_add() runs
with tf.control_dependencies([target_var]):
  result = tf.scatter_nd_add(ref, indices, updates)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create example input data
np_input = np.arange(0, 6, 1, dtype=np.int32)
np_input = np.tile(np_input[None,:], [10, 1])

res = sess.run(result, feed_dict={updates: np_input, col_indices: np_input})
print(res)

关于python - 在TensorFlow中稀疏添加的scatter_nd add()示例,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/46940302/

10-12 18:01