我有一个形状为(16, 4096, 3)的张量。我还有一个形状为(16, 32768, 3)的索引张量。我正在尝试沿dim=1收集值。最初是使用gather function在pytorch中完成的,如下所示-

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)


请注意,输出b的大小与idx的大小相同。但是,当我应用tensorflow的gather函数时,我得到了完全不同的输出。发现输出尺寸不匹配,如下所示-

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)


我也尝试使用tf.gather_nd,但徒劳无功。见下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)


为什么我得到不同形状的张量?我想得到与pytorch计算的形状相同的张量。

换句话说,我想知道torch.gather的张量流等效项。

最佳答案

对于2D情况,有一种方法可以做到:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)


但是,对于ND病例,这种方法可能非常复杂

关于python - tensorflow 相当于torch.gather,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52129909/

10-11 09:29