【李沐】 softmax回归的代码

对于一个样本使用它的one-hot标签能获得它的正确分类的概率,方便后面使用交叉熵:

【李沐】 softmax回归的代码_第1张图片

cmp = y_hat.type(y.dtype) == y

对于yhat中的每一行,也就是每一张图片,找到它的最大值,并且用它的Index来表示。所以y_hat就可以表示为tensor([2,2])

【李沐】 softmax回归的代码_第2张图片

y_hat.argmax(axis=1)解释没有参数时,是默认将数组展平,当axis=1,是在行中比较,选出最大的   索引

这也就是代码中y_hat = y_hat.argmax(axis=1)的意义。

对于接下来的一行,

cmp = y_hat.type(y.dtype) == y

其实就是先把y_hat换成和y一样的数据类型,然后比较y_hat和y是否在每一个位置上的值相等。y之前的类型是troch.in64

【李沐】 softmax回归的代码_第3张图片

y与y_hat进行比较

第一个位置不等,第二个位置相等,也就是说第一张图片预测错误,第二章图片预测正确。所以我们得到[False, True],代码中用了cmp来表示。而在Python里,False是0,True是1,所以可以用[0,1]表示,所以最后的求和

float(cmp.type(y.dtype).sum())

的结果就是看有多少个1,也就是有多少张图片预测正确。所以这个结果除以预测的总图片数y,就是预测的准确率。

对于这里的2张图片,预测对了1张,所以准确率是50%

 isinstance(net,nn.Module)

函数isinstance()可以判断一个变量的类型,既可以用在Python内置的数据类型如str、list、dict,也可以用在我们自定义的类,它们本质上都是数据类型。

pytorch中view()和flatten和reshape()

view()和flatten都是和输入共享内存的,flatten的好处是不用输入形状参数,直接指定维度,在这之后的都被拉平。view则是更加灵活

torch.reshape(input, shape) → Tensor

flatten()直接展开为一行

loss = nn.CrossEntropyLoss()  #调用CrossE时自动调用softmax

【李沐】 softmax回归的代码_第4张图片

你可能感兴趣的:(李沐《动手学深度学习》,回归,机器学习,人工智能)