官方介绍:Link
有两种使用场景,输入的参数不同以及返回值不同:
第一种
没有参数dim,但这种只适合一维张量。
Returns the maximum value of all elements in the input
tensor.
举例:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
第二种
指定了参数dim,这种就适合多维张量了。
Notes:dim参数的值跟函数选取最大值的结果关系。我觉的还是挺让我意外的,和我想的不太一样。
这种情况下函数会返回一个元组(values,indices),其中,每一个value是input张量中在给定的dim维度中的最大值。并且indices是找到的每一个最大值的索引。
如果keepdim=True,那么输出的tensors和input保持相同的size,除了在dim维度上size为1哦!否则,如果keepdim=False,那么dim所在的维度是会被squeeze的,也就是输出的tensors比input少一个维度。
Notes:但是,再次注意,dim的数值和挑选最大值方式之间的关系。
请看下面的例子:
import torch
tensor = torch.randn(4, 4)
tensor
tensor([[ 0.1789, 0.7102, 0.7627, 0.4721],
[-0.2287, -0.7618, 0.1439, -0.5439],
[-0.4963, 0.3786, 0.1666, -0.5676],
[ 0.6240, 0.0017, 1.0748, 0.4061]])
torch.max(tensor, dim=1)
torch.return_types.max(
values=tensor([0.7627, 0.1439, 0.3786, 1.0748]),
indices=tensor([2, 2, 1, 2]))
所以,从这个结果可以看出,对于这个二维张量而言,dim=1,表示最大值的选取方式是固定行
,然后从所有列中选取最大值
。
再举一个三维数组的例子看看:
import torch
mine = torch.rand(3, 4, 4)
mine
tensor([[[0.0945, 0.1062, 0.1506, 0.1382],
[0.2846, 0.4346, 0.1247, 0.3741],
[0.9909, 0.7365, 0.6959, 0.8086],
[0.4392, 0.0296, 0.8124, 0.1953]],
[[0.6884, 0.9824, 0.4943, 0.6683],
[0.5548, 0.7565, 0.2543, 0.3552],
[0.0100, 0.5609, 0.9483, 0.6310],
[0.3992, 0.1476, 0.9362, 0.0209]],
[[0.8073, 0.9579, 0.2604, 0.0848],
[0.3591, 0.4507, 0.5978, 0.6411],
[0.6008, 0.0967, 0.7433, 0.0602],
[0.9017, 0.2228, 0.1419, 0.3229]]])
res = torch.max(mine, dim=2) #注意维度dim=2了哦!
res
torch.return_types.max(
values=tensor([[0.1506, 0.4346, 0.9909, 0.8124],
[0.9824, 0.7565, 0.9483, 0.9362],
[0.9579, 0.6411, 0.7433, 0.9017]]),
indices=tensor([[2, 1, 0, 2],
[1, 1, 2, 2],
[1, 3, 2, 0]]))
res[0].shape
torch.Size([3, 4])
现在能get到torch.max函数在取最大值的方式跟dim是什么关系了吗?
就是
那下面是感受当keepdim=True的结果,
res = torch.max(mine, dim=2, keepdim=True)
res
torch.return_types.max(
values=tensor([[[0.1506],
[0.4346],
[0.9909],
[0.8124]],
[[0.9824],
[0.7565],
[0.9483],
[0.9362]],
[[0.9579],
[0.6411],
[0.7433],
[0.9017]]]),
indices=tensor([[[2],
[1],
[0],
[2]],
[[1],
[1],
[2],
[2]],
[[1],
[3],
[2],
[0]]]))
res[0].shape
torch.Size([3, 4, 1])
所以现在能get到函数的输出结果跟keepdim参数的关系了吗?