代码如下
>>> a=torch.randn((4,6))
>>> print(a)
tensor([[ 0.7042, 0.2533, 1.1596, -0.7436, 0.5264, 0.2085],
[ 0.2641, 0.9683, 0.4469, -1.9215, -0.7564, 1.1776],
[ 1.0520, -1.6003, -0.8634, 1.7596, -0.8464, 0.7166],
[-0.0492, -0.7746, 1.2592, -0.8273, 0.1266, 1.0450]])
>>> maxk=max((1,3))
>>> _, pred=a.topk(maxk,1,True,True)
>>> print(_)
tensor([[1.1596, 0.7042, 0.5264],
[1.1776, 0.9683, 0.4469],
[1.7596, 1.0520, 0.7166],
[1.2592, 1.0450, 0.1266]])
>>> print(pred)
tensor([[2, 0, 4],
[5, 1, 2],
[3, 0, 5],
[2, 5, 4]])
>>> _, pred=a.topk(1,1,True,True)
>>> print(_)
tensor([[1.1596],
[1.1776],
[1.7596],
[1.2592]])
>>> print(pred)
tensor([[2],
[5],
[3],
[2]])
如上我们可以看到,topk()函数取指定维度上的最大值(或最大几个),第二个参数dim=1,为按行取,dim=0,为按列取,如下:
>>> _, pred=a.topk(1,0,True,True)
>>> print(_)
tensor([[1.0520, 0.9683, 1.2592, 1.7596, 0.5264, 1.1776]])
>>> print(pred)
tensor([[2, 1, 3, 2, 0, 1]])