predict = model(test_batch).data.max(1, keepdim=True)[1]

*model是我实验中实例化的模型,不重要*

是一些自己做实验时遇到的问题,仅仅是一些比较浅显的理解

这里.max(1,keedim=True)[1]的意思是:首先括号里的1代表需要查找第二维中的最大值,keepdim=true时对应维度被变成1(具体见探究 torch.max() 中 keepdim 参数的影响_绫清隆的博客-CSDN博客)

例如上图是model(test_batch).data的结果

上图是model(test_batch).data.max(1, keepdim=True)输出的结果

可以知道输出的是最大值4.0379以及最大值的位置1

 上图是 model(test_batch).data.max(1, keepdim=True)[1]的输出结果

易知其结果是最大值的位置信息1

 

你可能感兴趣的:(batch,开发语言,pytorch)