让我们将我要查找的函数称为“magic_combine
”,该函数可以组合我赋予它的张量的连续尺寸。更具体地说,我希望它执行以下操作:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
我知道
torch.view()
可以做类似的事情。但我只是想知道是否还有其他更优雅的方法可以达成目标? 最佳答案
我不确定“更优雅的方式”是什么想法,但是 Tensor.view()
的优点是不为 View 重新分配数据(原始张量和 View 共享相同的数据),从而使此操作轻巧。
正如@UmangGupta所提到的那样,将这个函数包装起来以实现您想要的东西很简单,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
关于python - 是否有任何pytorch函数可以将张量的特定连续尺寸合并为一个?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50991189/