Python 中可使用 seaborn.heatmap 画热力图,详见官方文档:https://seaborn.pydata.org/generated/seaborn.heatmap.html
在分类任务中,也可用于画混淆矩阵:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
def confusion_matrix(y_true, y_pred, labels=None):
    """Plot a confusion matrix as an annotated heatmap.

    Counts how often each true label was predicted as each label, draws the
    counts with ``sns.heatmap`` (rows are true labels, columns are predicted
    labels), saves the figure to ``./confusion_matrix.jpg`` and shows it.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, paired element-wise with ``y_true``
            (pairs are formed with ``zip``, so surplus elements of the
            longer input are ignored).
        labels: Ordered labels used for the matrix rows and columns. When
            ``None``, the sorted set of labels appearing in ``y_true`` or
            ``y_pred`` is used.

    Raises:
        KeyError: If ``y_true`` or ``y_pred`` contains a label that is not
            in ``labels``.
    """
    if labels is None:
        # Materialize first so one-shot iterables survive the set() pass
        # below and can still be zipped afterwards.
        y_true = list(y_true)
        y_pred = list(y_pred)
        labels = sorted(set(y_true) | set(y_pred))

    n = len(labels)
    label_to_index = {label: i for i, label in enumerate(labels)}

    counts = np.zeros((n, n), dtype=np.int32)
    for gold, predict in zip(y_true, y_pred):
        counts[label_to_index[gold], label_to_index[predict]] += 1

    df = pd.DataFrame(counts, index=labels, columns=labels)
    # fmt='d' prints the integer counts without decimals.
    sns.heatmap(df, annot=True, fmt='d')
    # Save before show(): the figure may be empty once the window is closed.
    plt.savefig("./confusion_matrix.jpg")
    plt.show()
# Demo: plot the confusion matrix for a small three-class example.
y_true = ["cat", "ant", "cat", "cat", "ant", "bird"] # ground-truth labels
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"] # predicted labels
labels = ["ant", "bird", "cat"]
confusion_matrix(y_true, y_pred, labels)
一些参数的含义:
def heatmap(
# data: 2D dataset to plot; a DataFrame's index/columns label the rows/columns.
data, *,
# vmin/vmax: values that anchor the colormap range; cmap: colormap to use;
# center: data value at which the colormap is centered;
# robust: compute the range from robust quantiles when vmin/vmax are absent.
vmin=None, vmax=None, cmap=None, center=None, robust=False,
# annot: write the data value in each cell; fmt: format code for those values;
# annot_kws: extra keyword arguments for the annotation text.
annot=None, fmt=".2g", annot_kws=None,
# linewidths/linecolor: width and color of the lines dividing the cells.
linewidths=0, linecolor="white",
# cbar: whether to draw a colorbar; cbar_kws: keyword arguments for it;
# cbar_ax: axes in which to draw the colorbar.
cbar=True, cbar_kws=None, cbar_ax=None,
# square: make each cell square; xticklabels/yticklabels: tick label control
# (True/False, an integer step n to label every n-th entry, a list, or "auto").
square=False, xticklabels="auto", yticklabels="auto",
# mask: cells where the mask is True are not drawn; ax: axes to draw the plot on.
mask=None, ax=None,
# kwargs: remaining keyword arguments are forwarded to the underlying
# matplotlib drawing call.
**kwargs
)
例子:
import numpy as np
# Fix the random seed so the generated data is the same on every run.
np.random.seed(0)
import seaborn as sns
# Apply seaborn's default plot theme.
sns.set_theme()
# 10x12 array of uniform random values in [0, 1).
uniform_data = np.random.rand(10, 12)
# Draw the heatmap with default settings and keep the returned axes.
ax = sns.heatmap(uniform_data)
将最后一行改为如下代码,可设置颜色映射的最小值和最大值:
# Fix the colormap range to [0, 1] instead of inferring it from the data.
ax = sns.heatmap(uniform_data, vmin=0, vmax=1)
设置中心值:
# 10x12 array of standard-normal samples, so values fall on both sides of zero.
normal_data = np.random.randn(10, 12)
# Center the colormap at 0.
ax = sns.heatmap(normal_data, center=0)
加载 seaborn 的示例数据集 flights,并将其转换为以月份为行、年份为列的透视表,从而在图中给出有意义的横纵坐标:
# Load seaborn's "flights" example dataset.
flights = sns.load_dataset("flights")
# Reshape to a month-by-year table of passenger counts. Keyword arguments are
# used because DataFrame.pivot rejects positional arguments in pandas 2.0+.
flights = flights.pivot(index="month", columns="year", values="passengers")
# Row and column labels of the pivoted table become the axis tick labels.
ax = sns.heatmap(flights)
将passengers对应的人数标出:
# Write the value into each cell, formatted as an integer.
ax = sns.heatmap(flights, annot=True, fmt="d")
# Draw dividing lines of width 0.5 between the cells.
ax = sns.heatmap(flights, linewidths=.5)
# Use the "YlGnBu" colormap instead of the default one.
ax = sns.heatmap(flights, cmap="YlGnBu")
以某个具体的数据为中心:
# Center the colormap on the value of one specific cell (row "Jan", column 1955).
ax = sns.heatmap(flights, center=flights.loc["Jan", 1955])
控制坐标轴刻度标签:x 轴每隔一个显示一个标签,y 轴不显示标签:
# 50x20 array of standard-normal samples.
data = np.random.randn(50, 20)
# Label every 2nd column on the x axis and draw no y-axis tick labels.
ax = sns.heatmap(data, xticklabels=2, yticklabels=False)
不画右边的热度条:
# Draw the heatmap without the colorbar.
ax = sns.heatmap(flights, cbar=False)