import torch
x = torch.randn([4, 1, 2])
y = torch.randn([2, 2])
print("x",x,x.shape,"\n")
print("y",y,y.shape,"\n")
z = torch.max(x, y) # 除了普通比较大小外,这里还引入了广播机制
print("z",z,z.shape)
x tensor([[[-0.4735, -0.1077]],
[[-0.2462, -0.4511]],
[[ 2.1238, 1.0496]],
[[ 0.1055, 0.5510]]]) torch.Size([4, 1, 2])
y tensor([[-0.1034, 0.6950],
[-1.0687, -0.7792]]) torch.Size([2, 2])
z tensor([[[-0.1034, 0.6950],
[-0.4735, -0.1077]],
[[-0.1034, 0.6950],
[-0.2462, -0.4511]],
[[ 2.1238, 1.0496],
[ 2.1238, 1.0496]],
[[ 0.1055, 0.6950],
[ 0.1055, 0.5510]]]) torch.Size([4, 2, 2])
Pytorch中的 torch.max不仅可以求沿着某一维度的最大值和索引,而且可以用来比较两个tensor的大小,输出的tensor每个元素为被比较tensor中的最大值。