torch.max(input, dim, max=None, max_indices=None, keepdim=False) -->> (Tensor, LongTensor)
作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。
应用举例
例1——返回相应维度上的最大值,并返回最大值的位置索引
a = torch.randn(4, 4)
a
>tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
torch.max(a, 1)
>torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]),
indices=tensor([3, 0, 0, 1]))
例2——如果max的参数只有一个tensor,则返回该tensor里所有值中的最大值。
a = torch.randn(4, 4)
a
>tensor([[ 0.4997, 0.8054, 0.1761, 0.3055],
[-1.2234, 0.3823, 0.2266, -2.9062],
[ 0.4390, -1.0142, -0.5314, -1.7095],
[-0.2296, -0.4230, -0.7446, -0.0828]])
torch.max(a)
>tensor(0.8054)
例3——如果max的参数是两个相同shape的tensor,则返回两tensor元素对应位置的最大值的新tensor
a = torch.randint(2, 10,(6,4))
a
>tensor([[8, 7, 3, 5],
[2, 8, 3, 4],
[3, 2, 5, 5],
[4, 7, 5, 2],
[2, 9, 3, 8],
[4, 4, 2, 2]])
b = torch.randint(2, 10,(6,4))
b
>tensor([[9, 8, 9, 2],
[4, 3, 3, 4],
[6, 9, 2, 7],
[4, 3, 2, 7],
[4, 4, 9, 2],
[8, 2, 6, 2]])
torch.max(a, b)
>tensor([[9, 8, 9, 5],
[4, 8, 3, 4],
[6, 9, 5, 7],
[4, 7, 5, 7],
[4, 9, 9, 8],
[8, 4, 6, 2]])
函数定义
torch.argmax(input, dim, keepdim=False) → LongTensor
作用:返回输入张量中指定维度的最大值的索引。
应用举例:
例1——指定维度:返回相应维度最大值的索引
a = torch.randn(4, 4)
a
>tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
torch.argmax(a, dim=1)
>tensor([ 0, 2, 0, 1])
例2——不指定维度,返回整体上最大值的序号
a = torch.randint(9,(3, 3))
a
>tensor([[5, 2, 2],
[7, 2, 0],
[8, 0, 6]])
torch.argmax(a)
>tensor(6)
用法同max
用法同argmax