一. torch.max()函数解析
1. 官网链接
torch.max,如下图所示:
2. torch.max(input)函数解析
torch.max(input) → Tensor
将输入input张量,无论有几维,首先将其reshape排列成一个一维向量,然后找出这个一维向量里面最大值
3. 代码举例
3.1 输入一维张量,返回一维张量里面最大值
x = torch.randn(4)
y = torch.max(x)
x,y
输出结果如下:
(tensor([-0.6223, 0.0043, -0.8753, 1.4240]), tensor(1.4240))
3.2 输入二维张量,返回二维张量里面最大值
x = torch.randn(3,4)
y = torch.max(x)
x,y
输出结果如下:
(tensor([[-1.1052, 0.1026, 0.9994, -0.3092],
[-0.8400, 0.2004, 0.9212, 0.7807],
[-1.2979, -0.4327, 2.3044, 0.0140]]),
tensor(2.3044))
3.3 输入两个一维张量,输出这两个张量里面相应元素中的最大值
x = torch.randn(4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([-1.5147, -1.2790, -1.0159, -0.4732]),
tensor([-0.4547, -2.8545, 0.0554, -0.3548]),
tensor([-0.4547, -1.2790, 0.0554, -0.3548]))
3.4 输入两个张量,一个张量一维,一个张量二维,此时一维张量会进行广播成二维张量,然后再输出这两个张量里面相应元素中的最大值,输出张量为二维。
x = torch.randn(3,4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([[ 1.1917, 0.6338, 0.7590, -0.9802],
[ 0.2247, 0.3635, 1.3743, 1.6229],
[ 1.6165, 0.0634, 0.5259, 0.1285]]),
tensor([3.4765, 0.4480, 0.1502, 0.3738]),
tensor([[3.4765, 0.6338, 0.7590, 0.3738],
[3.4765, 0.4480, 1.3743, 1.6229],
[3.4765, 0.4480, 0.5259, 0.3738]]))
3.5 输入两个二维张量,输出这两个张量里面相应元素中的最大值,输出张量为二维。
x = torch.randn(3,4)
z = torch.randn(3,4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([[-0.0835, 0.0718, -1.7404, -0.3218],
[ 0.0577, 0.6271, 1.4014, -0.6417],
[ 0.3917, 0.0761, 1.2479, -0.4352]]),
tensor([[-0.0717, 0.3822, 0.7256, 1.4147],
[-0.1271, 0.1503, 0.3934, 1.6760],
[-2.2341, 2.5286, -0.3500, -0.1751]]),
tensor([[-0.0717, 0.3822, 0.7256, 1.4147],
[ 0.0577, 0.6271, 1.4014, 1.6760],
[ 0.3917, 2.5286, 1.2479, -0.1751]]))
4. torch.max(input,dim)函数解析
torch.max(input, dim, keepdim=False, *, out=None)
输入input(二维)张量,当dim=0时表示找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引;当dim=1时表示找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。
5. 代码举例
5.1 dim=0,找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引,两个tensor都是一维。
x = torch.randn(3,4)
max,indices = torch.max(x,dim=0)
x,max,indices
(tensor([[ 0.1806, 1.0274, 0.5138, -1.4184],
[ 0.5892, -0.7117, -1.2707, 0.7682],
[ 0.5152, -0.8803, 1.7604, 0.4852]]),
torch.return_types.max(
values=tensor([0.5892, 1.0274, 1.7604, 0.7682]),
indices=tensor([1, 0, 2, 1])))
输出结果如下:
(tensor([[ 0.0190, 0.8180, -1.0463, 1.7940],
[ 0.7537, -1.0291, -2.3431, 0.3906],
[ 0.3715, 1.6940, -1.1200, -0.4580]]),
tensor([ 0.7537, 1.6940, -1.0463, 1.7940]),
tensor([1, 2, 0, 0]))
5.2 dim=1,找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引,两个tensor都是一维。
x = torch.randn(3,4)
max,indices = torch.max(x,dim=1)
x,max,indices
输出结果如下:
(tensor([[ 1.4832, 0.1886, -0.3044, -0.6111],
[-0.8998, 0.0610, 0.3388, 1.7176],
[ 1.6153, 0.6864, 2.3225, 1.3818]]),
tensor([1.4832, 1.7176, 2.3225]),
tensor([0, 3, 2]))
参考知识文章