基于TensorFlow的混淆矩阵计算与可视化

混淆矩阵的基本概念 百度百科
seaborn 使用说明 seaborn 0.9 中文文档

import tensorflow as tf
import seaborn as sn
import pandas as pd

y_true = [1, 0, 2, 3, 3, 1]  #真实标签
y_pred = [2, 0, 2, 3, 3, 1]  #预测标签

kind = ['one' ,'two' ,'three' ,'four']  #类别名称

with tf.Session() as sess:
    conf = tf.confusion_matrix(y_true, y_pred, num_classes=4)  #计算混淆矩阵
    print(conf.eval())
    
    conf_numpy = conf.eval()  #将 Tensor 转化为 NumPy

conf_df = pd.DataFrame(conf_numpy, index=kind ,columns=kind)  #将矩阵转化为 DataFrame

conf_fig = sn.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")  #绘制 heatmap

混淆矩阵计算结果:
[[1 0 0 0]
[0 1 1 0]
[0 0 1 0]
[0 0 0 2]]

heatmap 结果:
基于TensorFlow的混淆矩阵计算与可视化_第1张图片

你可能感兴趣的:(tensorflow)