我有一个形状为(16, 4096, 3)
的张量。我还有一个形状为(16, 32768, 3)
的索引张量。我正在尝试沿dim=1
收集值。最初是使用gather function在pytorch中完成的,如下所示-
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
请注意,输出
b
的大小与idx
的大小相同。但是,当我应用tensorflow的gather
函数时,我得到了完全不同的输出。发现输出尺寸不匹配,如下所示-b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
我也尝试使用
tf.gather_nd
,但徒劳无功。见下文-b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
为什么我得到不同形状的张量?我想得到与pytorch计算的形状相同的张量。
换句话说,我想知道torch.gather的张量流等效项。
最佳答案
对于2D情况,有一种方法可以做到:
# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)
但是,对于ND病例,这种方法可能非常复杂
关于python - tensorflow 相当于torch.gather,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52129909/