我有一个2 x 4 tensorA = [[0,1,0,1],[1,0,1,0]]
我想从维d提取索引i。
在Torch中,我可以执行:tensorA:select(d,i)

例如,tensorA:select(0,0)将返回[0,1,0,1]并且
tensorA:select(1,1)将返回[1,0]

我如何在TensorFlow中做到这一点?
我能找到的最简单的方法是:tf.gather(tensorA, indices=[i], axis=d)

但是为此使用聚集似乎太过分了。有谁知道更好的方法?

最佳答案

您可以使用以下配方:

用分号替换除d以外的所有轴,并将值i置于d轴上,例如:

tensorA[0, :]  # same as tensorA:select(0,0)
tensorA[:, 1]  # same as tensorA:select(1,1)
tensorA[:, 0]  # same as tensorA:select(1,0)


但是,当我尝试这样做时,出现了SyntaxError:

i = 1
selection = [:,i]  # this raises SyntaxError
tensorA[selection]


所以我改用一片

i = 1
selection = [slice(0,2,1), i]
tensorA[selection]  # same as tensorA:select(1,i)


这个函数可以达到目的:

def select(t, axis, index):
    shape = K.int_shape(t)
    selection = [slice(shape[a]) if a != axis else index for a in
                 range(len(shape))]
    return t[selection]


例如:

import numpy as np
t = K.constant(np.arange(60).reshape(2,5,6))
sub_tensor = select(t, 1, 1)
print(K.eval(sub_tensor)


版画


  [[6.,7.,8.,9.,10.,11.],
  
  [36.,37.,38.,39.,40.,41.]]

关于python - 在TensorFlow中提取子张量,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49158935/

10-12 18:13