Pytorch中torch.max()函数解析

一. torch.max()函数解析

1. 官网链接

torch.max,如下图所示:

2. torch.max(input)函数解析

torch.max(input) → Tensor

将输入input张量,无论有几维,首先将其reshape排列成一个一维向量,然后找出这个一维向量里面最大值

3. 代码举例

3.1 输入一维张量,返回一维张量里面最大值

x = torch.randn(4)
y = torch.max(x)
x,y

输出结果如下:
(tensor([-0.6223,  0.0043, -0.8753,  1.4240]), tensor(1.4240))

3.2 输入二维张量,返回二维张量里面最大值

x = torch.randn(3,4)
y = torch.max(x)
x,y

输出结果如下:
(tensor([[-1.1052,  0.1026,  0.9994, -0.3092],
         [-0.8400,  0.2004,  0.9212,  0.7807],
         [-1.2979, -0.4327,  2.3044,  0.0140]]),
 tensor(2.3044))

3.3 输入两个一维张量,输出这两个张量里面相应元素中的最大值

x = torch.randn(4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max

输出结果如下:
(tensor([-1.5147, -1.2790, -1.0159, -0.4732]),
 tensor([-0.4547, -2.8545,  0.0554, -0.3548]),
 tensor([-0.4547, -1.2790,  0.0554, -0.3548]))

3.4 输入两个张量,一个张量一维,一个张量二维,此时一维张量会进行广播成二维张量,然后再输出这两个张量里面相应元素中的最大值,输出张量为二维。

x = torch.randn(3,4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max

输出结果如下:
(tensor([[ 1.1917,  0.6338,  0.7590, -0.9802],
         [ 0.2247,  0.3635,  1.3743,  1.6229],
         [ 1.6165,  0.0634,  0.5259,  0.1285]]),
 tensor([3.4765, 0.4480, 0.1502, 0.3738]),
 tensor([[3.4765, 0.6338, 0.7590, 0.3738],
         [3.4765, 0.4480, 1.3743, 1.6229],
         [3.4765, 0.4480, 0.5259, 0.3738]]))

3.5 输入两个二维张量,输出这两个张量里面相应元素中的最大值,输出张量为二维。

x = torch.randn(3,4)
z = torch.randn(3,4)
max = torch.max(x,z)
x,z,max

输出结果如下:
(tensor([[-0.0835,  0.0718, -1.7404, -0.3218],
         [ 0.0577,  0.6271,  1.4014, -0.6417],
         [ 0.3917,  0.0761,  1.2479, -0.4352]]),
 tensor([[-0.0717,  0.3822,  0.7256,  1.4147],
         [-0.1271,  0.1503,  0.3934,  1.6760],
         [-2.2341,  2.5286, -0.3500, -0.1751]]),
 tensor([[-0.0717,  0.3822,  0.7256,  1.4147],
         [ 0.0577,  0.6271,  1.4014,  1.6760],
         [ 0.3917,  2.5286,  1.2479, -0.1751]]))

4. torch.max(input,dim)函数解析

torch.max(input, dim, keepdim=False, *, out=None)

输入input(二维)张量,当dim=0时表示找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引;当dim=1时表示找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

5. 代码举例

5.1 dim=0,找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引,两个tensor都是一维。

x = torch.randn(3,4)
max,indices = torch.max(x,dim=0)
x,max,indices

(tensor([[ 0.1806,  1.0274,  0.5138, -1.4184],
         [ 0.5892, -0.7117, -1.2707,  0.7682],
         [ 0.5152, -0.8803,  1.7604,  0.4852]]),
 torch.return_types.max(
 values=tensor([0.5892, 1.0274, 1.7604, 0.7682]),
 indices=tensor([1, 0, 2, 1])))

输出结果如下:
(tensor([[ 0.0190,  0.8180, -1.0463,  1.7940],
         [ 0.7537, -1.0291, -2.3431,  0.3906],
         [ 0.3715,  1.6940, -1.1200, -0.4580]]),
 tensor([ 0.7537,  1.6940, -1.0463,  1.7940]),
 tensor([1, 2, 0, 0]))

5.2 dim=1,找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引,两个tensor都是一维。

x = torch.randn(3,4)
max,indices = torch.max(x,dim=1)
x,max,indices

输出结果如下:
(tensor([[ 1.4832,  0.1886, -0.3044, -0.6111],
         [-0.8998,  0.0610,  0.3388,  1.7176],
         [ 1.6153,  0.6864,  2.3225,  1.3818]]),
 tensor([1.4832, 1.7176, 2.3225]),
 tensor([0, 3, 2]))

参考知识文章

你可能感兴趣的:(Pytorch中torch.max()函数解析)