这是 DIM 中第 15-20 行的快照
def random_permute(X):
X = X.transpose(1, 2)
b = torch.rand((X.size(0), X.size(1))).cuda()
idx = b.sort(0)[1]
adx = torch.range(0, X.size(1) - 1).long()
X = X[idx, adx[None, :]].transpose(1, 2)
return X
其中
X
是大小为 [64, 64, 128] 的张量,idx
是大小为 [64, 64] 的张量,adx
是大小为 [64] 的张量。X = X[idx, adx[None, :]]
如何工作?我们如何使用两个 2d 张量来索引一个 3d 张量?索引后 X
到底发生了什么? 最佳答案
根据我的猜测 X
必须是一个 3D 张量,因为它通常代表一批训练数据。
就该函数的功能而言,它随机排列输入数据张量 X
并使用以下步骤执行此操作:
b
。 idx
。 adx
只是一个取值范围为 0 到 63 的整数张量。 现在,下面这行是所有魔法发生的地方:
X[idx, adx[None, :]].transpose(1, 2)
我们使用在
idx
和 adx
之前得到的索引(adx[None, :]
只是一个二维行向量)。一旦我们有了它,我们就完全像我们在函数开头所做的那样对轴 1 和 2 进行转置:X = X.transpose(1, 2)
这是一个人为的例子,为了更好地理解:
# our input tensor
In [51]: X = torch.rand(64, 64, 32)
In [52]: X = X.transpose(1, 2)
In [53]: X.shape
Out[53]: torch.Size([64, 32, 64])
In [54]: b = torch.rand((X.size(0), X.size(1)))
# sort `b` which returns a tuple and take only indices
In [55]: idx = b.sort(0)[1]
In [56]: idx.shape
Out[56]: torch.Size([64, 32])
In [57]: adx = torch.arange(0, X.size(1)).long()
In [58]: adx.shape
Out[58]: torch.Size([32])
In [59]: X[idx, adx[None, :]].transpose(1, 2).shape
Out[59]: torch.Size([64, 64, 32])
这里要注意的重要一点是我们如何在最后一步中获得与输入张量的形状相同的形状,即
(64, 64, 32)
。关于python - 3-d 张量如何由两个 2d 张量索引?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52342514/