按维度 dim 返回最大值:官方文档
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
代码如下(示例):
x = torch.randn(3,3) print(x) y = torch.max(x,0) print(y)
输出:
tensor([[-1.3235, -0.1796, 0.5769], [ 0.2651, -0.7733, 1.3953], [-0.3819, -0.4154, -1.3699]]) torch.return_types.max( values=tensor([ 0.2651, -0.1796, 1.3953]), indices=tensor([1, 0, 1]))
代码如下(示例):
x = torch.randn(3,3) print(x) y = torch.max(x,1) print(y)
输出:
tensor([[-0.5239, 1.9699, -1.6698], [ 2.0430, -1.2690, 0.0052], [-0.4934, -0.2064, -2.0130]]) torch.return_types.max( values=tensor([ 1.9699, 2.0430, -0.2064]), indices=tensor([1, 0, 1]))