本文是为了回忆两种输出分类结果的方法,记录两种的用法。
src=torch.rand(3,10,512)
print(src.shape)
lin=nn.Linear(512,521)
out=lin(src)
print(out.shape)
pred=out.argmax(2)
print(pred)
_,pres=torch.max(out,dim=2)
print(pres)
第一种方法是利用argmax来进行返回,获取的是最大值对应的下标。输出如下所示:
tensor([[298, 298, 22, 22, 298, 312, 472, 298, 22, 491],
[491, 298, 491, 298, 196, 22, 298, 156, 472, 491],
[ 22, 88, 96, 156, 298, 491, 22, 88, 110, 298]])
第二种方法是利用的max函数,第一个返回的参数为最大值的具体张量数值,第二个参数是对应的下标,具体的用法与上方的代码块一致,输出如下:
tensor([[298, 298, 22, 22, 298, 312, 472, 298, 22, 491],
[491, 298, 491, 298, 196, 22, 298, 156, 472, 491],
[ 22, 88, 96, 156, 298, 491, 22, 88, 110, 298]])