import itertools
import pickle
# import torch
import matplotlib.pyplot as plt
import numpy as np
import os
# Set KMP_DUPLICATE_LIB_OK for this process before the remaining imports run.
# NOTE(review): this is the usual workaround for Intel OpenMP's
# "OMP: Error #15" abort when several copies of the OpenMP runtime get
# loaded; presumably one of the libraries imported below triggers it — confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# import torch
from dataloader import IEMOCAPDataset
# Load the pre-extracted IEMOCAP feature file; it unpickles to a nine-element
# tuple that is unpacked into the module-level names below.
# The file is opened in a ``with`` block so the handle is closed as soon as
# loading finishes instead of being left open for the rest of the run.
# NOTE: unpickling can execute arbitrary code; only load feature files that
# come from a trusted source.
# NOTE(review): encoding='latin1' suggests the pickle was produced by
# Python 2 — confirm.
with open('./IEMOCAP_features/IEMOCAP_features.pkl', 'rb') as features_file:
    (videoIDs, videoSpeakers, videoLabels, videoText,
     videoAudio, videoVisual, videoSentence, trainVid,
     testVid) = pickle.load(features_file, encoding='latin1')
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix of MELD', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map and display it.

    Args:
        cm: 2-D array-like of cell values; converted with ``np.asarray`` so
            nested lists are accepted as well as NumPy arrays.
        classes: Sequence of tick labels, used for both the x and y axes.
        normalize: If true, each cell is printed with two decimals;
            otherwise the value is truncated to an integer. The matrix is
            not rescaled here, so pass already-normalized values when
            setting this flag.
        title: Currently unused (the ``plt.title`` call is disabled); kept
            in the signature so existing callers keep working.
        cmap: Colormap handed to ``plt.imshow``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    cm = np.asarray(cm)

    plt.figure(figsize=(4.8, 4.8))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.title(title)
    # plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.axis("equal")

    # Pin the vertical spines to the current x-limits, then paint every
    # spine white.
    ax = plt.gca()
    left, right = plt.xlim()
    ax.spines['left'].set_position(('data', left))
    ax.spines['right'].set_position(('data', right))
    for edge_i in ['top', 'bottom', 'right', 'left']:
        ax.spines[edge_i].set_edgecolor("white")

    # Cells above half of the maximum get white text, the rest black text.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        cell_text = '{:.2f}'.format(value) if normalize else str(int(value))
        # Compare the numeric cell value, not the formatted text: with
        # normalize=True the text is a str, and ``str > float`` raises
        # TypeError.
        plt.text(j, i, cell_text,
                 verticalalignment='center',
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    plt.tight_layout()
    plt.show()
# Confusion matrix for the IEMOCAP dataset.
# Alternative 6x6 result, currently disabled:
# trans_mat = np.array([[78, 1, 3, 0, 62, 0],
#                       [3, 200, 14, 6, 0, 22],
#                       [22, 26, 227, 10, 17, 82],
#                       [0, 7, 5, 118, 0, 40],
#                       [31, 7, 29, 0, 228, 4],
#                       [3, 20, 50, 38, 10, 260]], dtype=int)
# Active 6x6 matrix: one row and one column per entry of ``label`` below.
trans_mat = np.array(
    [
        [51, 9, 16, 1, 64, 3],
        [5, 192, 17, 3, 1, 27],
        [19, 23, 216, 18, 29, 79],
        [0, 1, 7, 103, 0, 59],
        [50, 18, 30, 0, 199, 2],
        [0, 11, 89, 38, 8, 235],
    ],
    dtype=int,
)

# Confusion matrix for the MELD dataset (7x7), currently disabled:
# trans_mat = np.array([[213, 0, 0, 17, 70, 11, 34],
#                       [32, 4, 0, 0, 22, 2, 8],
#                       [14, 0, 0, 2, 19, 7, 8],
#                       [33, 0, 0, 248, 89, 8, 24],
#                       [50, 0, 0, 54, 1070, 35, 47],
#                       [41, 0, 0, 9, 85, 62, 11],
#                       [51, 1, 0, 22, 39, 3, 165]], dtype=int)

# method 2
# Class names for the disabled MELD matrix:
# label = ["angry", "disgust", "fear", "joy", "neutral", "sadness", "surprise"]
# Class names matching the active 6x6 matrix.
label = ["happy", "sad", "neutral", "angry", "excited", "frustrated"]

# Render the active matrix with integer cell annotations (normalize is off).
plot_confusion_matrix(trans_mat, label)