pytorch两种返回分类最大值的方法

前言

本文是为了回忆两种输出分类结果的方法,记录两种的用法。

正文

src=torch.rand(3,10,512)
print(src.shape)
lin=nn.Linear(512,521)
out=lin(src)
print(out.shape)
pred=out.argmax(2)
print(pred)
_,pres=torch.max(out,dim=2)
print(pres)

第一种方法是利用argmax来进行返回,获取的是最大值对应的下标。输出如下所示:

tensor([[298, 298,  22,  22, 298, 312, 472, 298,  22, 491],
        [491, 298, 491, 298, 196,  22, 298, 156, 472, 491],
        [ 22,  88,  96, 156, 298, 491,  22,  88, 110, 298]])

第二种方法是利用的max函数,第一个返回的参数为最大值的具体张量数值,第二个参数是对应的下标,具体的用法与上方的代码块一致,输出如下:

tensor([[298, 298,  22,  22, 298, 312, 472, 298,  22, 491],
        [491, 298, 491, 298, 196,  22, 298, 156, 472, 491],
        [ 22,  88,  96, 156, 298, 491,  22,  88, 110, 298]])

你可能感兴趣的:(pytorch,分类,深度学习)