torch.max函数的用法

官方介绍:Link

有两种使用场景,输入的参数不同以及返回值不同:

第一种

没有参数dim,但这种只适合一维张量。

Returns the maximum value of all elements in the input tensor.

举例:

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)

第二种

指定了参数dim,这种就适合多维张量了。

Notes:dim参数的值跟函数选取最大值的结果关系。我觉的还是挺让我意外的,和我想的不太一样。

  这种情况下函数会返回一个元组(values,indices),其中,每一个value是input张量中在给定的dim维度中的最大值。并且indices是找到的每一个最大值的索引。

  如果keepdim=True,那么输出的tensors和input保持相同的size,除了在dim维度上size为1哦!否则,如果keepdim=False,那么dim所在的维度是会被squeeze的,也就是输出的tensors比input少一个维度。

Notes:但是,再次注意,dim的数值和挑选最大值方式之间的关系。请看下面的例子:

import torch

tensor = torch.randn(4, 4)
tensor

tensor([[ 0.1789,  0.7102,  0.7627,  0.4721],
           [-0.2287, -0.7618,  0.1439, -0.5439],
           [-0.4963,  0.3786,  0.1666, -0.5676],
           [ 0.6240,  0.0017,  1.0748,  0.4061]])
torch.max(tensor, dim=1)

torch.return_types.max(
values=tensor([0.7627, 0.1439, 0.3786, 1.0748]),
indices=tensor([2, 2, 1, 2]))

  所以,从这个结果可以看出,对于这个二维张量而言,dim=1,表示最大值的选取方式是固定行,然后从所有列中选取最大值

再举一个三维数组的例子看看:

import torch
mine = torch.rand(3, 4, 4)
mine

tensor([[[0.0945, 0.1062, 0.1506, 0.1382],
         [0.2846, 0.4346, 0.1247, 0.3741],
         [0.9909, 0.7365, 0.6959, 0.8086],
         [0.4392, 0.0296, 0.8124, 0.1953]],

        [[0.6884, 0.9824, 0.4943, 0.6683],
         [0.5548, 0.7565, 0.2543, 0.3552],
         [0.0100, 0.5609, 0.9483, 0.6310],
         [0.3992, 0.1476, 0.9362, 0.0209]],

        [[0.8073, 0.9579, 0.2604, 0.0848],
         [0.3591, 0.4507, 0.5978, 0.6411],
         [0.6008, 0.0967, 0.7433, 0.0602],
         [0.9017, 0.2228, 0.1419, 0.3229]]])
res = torch.max(mine, dim=2)  #注意维度dim=2了哦!
res

torch.return_types.max(
values=tensor([[0.1506, 0.4346, 0.9909, 0.8124],
        	  [0.9824, 0.7565, 0.9483, 0.9362],
       		  [0.9579, 0.6411, 0.7433, 0.9017]]),
indices=tensor([[2, 1, 0, 2],
		        [1, 1, 2, 2],
      		    [1, 3, 2, 0]]))
res[0].shape

torch.Size([3, 4])

现在能get到torch.max函数在取最大值的方式跟dim是什么关系了吗?
就是

那下面是感受当keepdim=True的结果,

res = torch.max(mine, dim=2, keepdim=True)
res

torch.return_types.max(
values=tensor([[[0.1506],
         [0.4346],
         [0.9909],
         [0.8124]],

        [[0.9824],
         [0.7565],
         [0.9483],
         [0.9362]],

        [[0.9579],
         [0.6411],
         [0.7433],
         [0.9017]]]),
indices=tensor([[[2],
         [1],
         [0],
         [2]],

        [[1],
         [1],
         [2],
         [2]],

        [[1],
         [3],
         [2],
         [0]]]))
res[0].shape

torch.Size([3, 4, 1])

所以现在能get到函数的输出结果跟keepdim参数的关系了吗?

06-21 23:53