predic = torch.max(outputs.data, -1)[1].cpu()

背景是文本分类任务。

output = torch.max(input,dim)

input 这里是outputs.data,维度 [4,32,10] 的一个tensor,

dim 是max函数 索引 的维度 0/1,0是按列,1是按行。-1是最后一个维度,一般也就是按行了。

输出是两个tensor,第0个tensor是每行的最大值组成的,这里是所有行的最大值组成的tensor。第1个tensor是每行最大值对应的标签组成的。因为这里要找最大值对应的标签,找出来之后和class_list对比,就可以知道文本的类型。所以用1

predic = torch.max(outputs.data, -1)[1].cpu()_第1张图片

 

 

参考:https://www.jianshu.com/p/3ed11362b54f

你可能感兴趣的:(模型模块学习,神经网络,pytorch)