CNN中的混淆矩阵 | PyTorch系列

跟着学的pytorch神经网络教学

https://deeplizard.com/resources

https://deeplizard.com/learn/video/gZmobeGL0Yg

涉及混淆矩阵,我在敲代码的过程中,发现我的代码和原视频一模一样却出不了相似的训练结果,百思不得其解,最后我发现我只是单纯的将神经网络部分简单的放进代码里面,并没有进行训练,导致我出现异常的训练结果

CNN中的混淆矩阵 | PyTorch系列_第1张图片

CNN中的混淆矩阵 | PyTorch系列_第2张图片

 和正确结果大相庭径。

正确的过程应该是在进行混淆矩阵之前,我们需要将神经网络训练好。

也就是我们需要在得到混淆矩阵的预测值前,加上

以下的训练代码

train_loader = torch.utils.data.DataLoader(train_set,batch_size=100)
optimizer = optim.Adam(network.parameters(),lr=0.01)


for epoch in range(5):
    total_loss = 0
    total_correct = 0
    for batch in train_loader:

        images,labels = batch

        preds = network(images)
        loss = F.cross_entropy(preds,labels)#计算损失函数(交叉熵损失函数)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()
        total_correct+=get_num_correct(preds,labels)

    print("epoch:",epoch,"total_correct:",total_correct,"loss:",total_loss)

之后我的训练值就正常了。

CNN中的混淆矩阵 | PyTorch系列_第3张图片

最后说一下,混淆矩阵他用到的函数

plotcm.py:

import itertools
import numpy as np
import matplotlib.pyplot as plt
 
 
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
 
 
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
 
 
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
 
 
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

利用   from plotcm import plot_confusion_matrix
导入即可

 

中文翻译搬运在这里

(8条消息) CNN中的混淆矩阵 | PyTorch系列(二十三)_flyfor2013的博客-CSDN博客

你可能感兴趣的:(pytorch,cnn,深度学习)