关于Tensor的对象的“==”运算,发现“==”后可以使用sum()方法进行计数(即元素相等的个数)

在学习pytorch过程中有这样一段代码:

_,predicted=torch.max(outputs.detach(),1) #1代表每一行的最大值
        # _,predicted=torch.max(outputs.data,1) #1代表每一行的最大值
total+=labels.size(0)   #每次都是一个batch_size
correct+=(predicted==labels).sum()

这段代码得最后一句我不太懂

于是作了以下测试

例1是两个数作“==”运算后是不可以调用sum函数进行技术求和

#例1
a=1
b=1
print((a==b).sum())  
#报错:AttributeError: 'bool' object has no attribute 'sum'


例2 是对两个数组进行以上运算,发现使用sum()函数后可以进行求和(即相等元素的个数)

#例2
import numpy as np
a=np.array([1,2,3])
b=np.array([1,2,3])
print((a==b).sum())

例3 则是对Tensor进行以上操作(同例2)

#例3
import torch
predicted=torch.Tensor([8, 9, 0, 1, 0])
labels=torch.Tensor([8, 9, 0, 1, 2])
x=0
x+=(predicted==labels).sum()
print(x)
print(type(predicted==labels))  #

你可能感兴趣的:(pytorch学习库,pytorch,python,深度学习)