让我们将我要查找的函数称为“magic_combine”,该函数可以组合我赋予它的张量的连续尺寸。更具体地说,我希望它执行以下操作:

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)

我知道torch.view()可以做类似的事情。但我只是想知道是否还有其他更优雅的方法可以达成目标?

最佳答案

我不确定“更优雅的方式”是什么想法,但是 Tensor.view() 的优点是不为 View 重新分配数据(原始张量和 View 共享相同的数据),从而使此操作轻巧。

正如@UmangGupta所提到的那样,将这个函数包装起来以实现您想要的东西很简单,例如:

import torch

def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])

关于python - 是否有任何pytorch函数可以将张量的特定连续尺寸合并为一个?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50991189/

10-12 22:42