# Python confusion-matrix (error-matrix) plotting code — Python混淆矩阵(误差矩阵)代码

import itertools
import pickle
# import torch
import matplotlib.pyplot as plt
import numpy as np
import os

# Let the Intel OpenMP runtime tolerate a second copy of itself being loaded
# (otherwise it aborts with "OMP: Error #15" when two libraries each bundle it).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# NOTE(review): IEMOCAPDataset is not referenced in this script as seen here;
# the import still has to succeed. Confirm before removing it.
from dataloader import IEMOCAPDataset

# Load the pre-extracted IEMOCAP features as nine objects.
# The `with` block closes the file handle as soon as loading finishes.
# encoding='latin1' is the pickle module's documented way to read Python 2
# pickles containing NumPy arrays; presumably this file was written under
# Python 2.
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load feature files from a trusted source.
# NOTE(review): none of these nine names is used later in this script as seen
# here.
with open('./IEMOCAP_features/IEMOCAP_features.pkl', 'rb') as feature_file:
    (videoIDs, videoSpeakers, videoLabels, videoText,
     videoAudio, videoVisual, videoSentence, trainVid,
     testVid) = pickle.load(feature_file, encoding='latin1')


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix of MELD', cmap=plt.cm.Blues):
    """Draw ``cm`` as an annotated heat map and display it with ``plt.show()``.

    Cell ``cm[i, j]`` is drawn at x=j, y=i and labelled with its value.

    Args:
        cm: 2-D array-like of confusion counts. Presumably rows are true
            labels and columns are predictions (the usual convention) —
            confirm against the caller.
        classes: Tick labels, one per row/column of ``cm``.
        normalize: If True, each row of ``cm`` is divided by its sum. Rows
            that sum to zero stay all-zero. Cells are then labelled with two
            decimals. If False, cells are labelled with integer counts.
        title: Accepted for backward compatibility but not drawn.
        cmap: Matplotlib colormap used for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        cm = cm.astype(float)
        row_sums = cm.sum(axis=1, keepdims=True)
        # `where` skips 0/0 for classes with no samples, so those rows stay 0
        # instead of becoming NaN. Re-normalizing an already row-normalized
        # matrix leaves it unchanged.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    plt.figure(figsize=(4.8, 4.8))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # Title and colorbar are deliberately not drawn.
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    plt.axis("equal")

    # Pin the left/right spines to the current x-limits, then recolor all
    # four spines white.
    ax = plt.gca()
    left, right = plt.xlim()
    ax.spines['left'].set_position(('data', left))
    ax.spines['right'].set_position(('data', right))
    for edge in ('top', 'bottom', 'right', 'left'):
        ax.spines[edge].set_edgecolor("white")

    # Cells above half the maximum get white text and the rest get black.
    # This keeps labels legible on the dark end of a sequential colormap
    # such as Blues.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        # Compare the numeric value, not the label: with normalize=True the
        # label is a str, and str > float raises TypeError in Python 3.
        label = '{:.2f}'.format(value) if normalize else int(value)
        plt.text(j, i, label,
                 verticalalignment='center',
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    plt.tight_layout()

    plt.show()


# Confusion matrices for the IEMOCAP dataset. Only one `trans_mat` is active
# at a time; swap the commented-out alternative in by hand to plot it instead.
# Alternative IEMOCAP result:
# trans_mat = np.array([[78, 1, 3, 0, 62, 0],
#                       [3, 200, 14, 6, 0, 22],
#                       [22, 26, 227, 10, 17, 82],
#                       [0, 7, 5, 118, 0, 40],
#                       [31, 7, 29, 0, 228, 4],
#                       [3, 20, 50, 38, 10, 260]], dtype=int)

# Active IEMOCAP matrix: 6x6 integer counts.
trans_mat = np.asarray(
    (
        (51, 9, 16, 1, 64, 3),
        (5, 192, 17, 3, 1, 27),
        (19, 23, 216, 18, 29, 79),
        (0, 1, 7, 103, 0, 59),
        (50, 18, 30, 0, 199, 2),
        (0, 11, 89, 38, 8, 235),
    ),
    dtype=int,
)

# Confusion matrix for the MELD dataset (7x7). To plot it, uncomment it
# together with a matching 7-class label list.
# trans_mat = np.array([[213, 0, 0, 17, 70, 11, 34],
#                       [32, 4, 0, 0, 22, 2, 8],
#                       [14, 0, 0, 2, 19, 7, 8],
#                       [33, 0, 0, 248, 89, 8, 24],
#                       [50, 0, 0, 54, 1070, 35, 47],
#                       [41, 0, 0, 9, 85, 62, 11],
#                       [51, 1, 0, 22, 39, 3, 165]], dtype=int)

# Method 2
# Plot the active matrix. The class names are the tick labels on both axes,
# presumably listed in the matrix's row/column order.
# Class names for a 7x7 MELD matrix:
# label = ["angry", "disgust", "fear", "joy", "neutral", "sadness", "surprise"]
label = ["happy", "sad", "neutral", "angry", "excited", "frustrated"]  # IEMOCAP, 6 classes
plot_confusion_matrix(cm=trans_mat, classes=label)

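# (Scraped page footer — "You may also be interested in: python, AI, mathematical modeling, programming language")
# 你可能感兴趣的:(python,人工智能,数学建模,python,开发语言)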