本文介绍了参数维度对gather函数的影响的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用 gatherpytorch中的函数,但是看不懂dim参数的作用.

I am trying to use the gather function in pytorch but can't understand the role of dim parameter.

代码:

t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))

输出:

 1  2
 3  2
[torch.FloatTensor of size 2x2]

维度设置为 1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))

输出变成:

 1  1
 4  3
[torch.FloatTensor of size 2x2]

gather 函数实际上是如何工作的?

How, gather function actually works?

推荐答案

我意识到了 Gather 函数的工作原理.

I realized how the gather function works.

t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)

由于 dimension 为零,所以输出将是:

Since the dimension is zero, so the output will be:

| t[index[0, 0], 0]   t[index[0, 1], 1] |
| t[index[1, 0], 0]   t[index[1, 1], 1] |

如果dimension设置为1,输出会变成:

If the dimension is set to one, the output will become:

| t[0, index[0, 0]]   t[0, index[0, 1]] |
| t[1, index[1, 0]]   t[1, index[1, 1]] |

所以公式是:

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

参考:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather

这篇关于参数维度对gather函数的影响的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-12 16:26