我已经通过 official docthis 但很难理解发生了什么。

我试图理解 DQN 源代码,它使用了第 197 行的 Gather 函数。

有人可以简单地解释一下gather函数的作用吗?该功能的目的是什么?

最佳答案

torch.gather 函数(或 torch.Tensor.gather )是一种多索引选择方法。查看官方文档中的以下示例:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

让我们从不同参数的语义开始:第一个参数 input 是我们想要从中选择元素的源张量。第二个 dim 是我们想要收集的维度(或 tensorflow/numpy 中的轴)。最后, index 是索引 input 的索引。
至于操作的语义,官方文档是这样解释的:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

那么让我们来看看这个例子。

输入张量是 [[1, 2], [3, 4]] ,dim 参数是 1 ,即我们要从第二维收集。第二个维度的索引为 [0, 0][1, 0]

当我们“跳过”第一个维度(我们想要收集的维度是 1 )时,结果的第一个维度被隐式给出为 index 的第一个维度。这意味着索引包含第二个维度或列索引,但不包含行索引。这些由 index 张量本身的索引给出。
例如,这意味着输出将在其第一行中选择 input 张量的第一行的元素,正如 index 张量的第一行的第一行给出的那样。由于列索引由 [0, 0] 给出,因此我们选择输入的第一行的第一个元素两次,结果为 [1, 1] 。类似地,结果的第二行的元素是 input 张量的第二行被 0x2518122231343141 张量的第二行元素索引的结果,得到 index

为了进一步说明这一点,让我们交换示例中的维度:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

如您所见,索引现在是沿第一维收集的。

对于你提到的例子,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
[4, 3] 将通过 Action 的批处理列表索引 q 值的行(即一批 q 值中的每个样本 q 值)。结果将与您执行以下操作相同(尽管它会比循环快得多):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)

关于pytorch - 用外行的话来说,pytorch 中的 gather 函数有什么作用?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50999977/

10-14 06:50