文本attention矩阵可视化

今日在进行attention可视化时查询了一些资料,文本attention矩阵可视化代码记录如下:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.ticker as ticker
import matplotlib as mpl

variables = ['A','B','C','X']
labels = ['ID_0','ID_1','ID_2','ID_3']

mpl.rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei'] # 设置中文字体为黑体
mpl.rcParams['axes.unicode_minus'] = False

df = pd.DataFrame(d, columns=variables, index=labels)  #其中d为4*4的矩阵
fig = plt.figure(figsize=(15,15))    #设置图片大小
ax = fig.add_subplot(111)
cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
fig.colorbar(cax)

tick_spacing = 1
ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

ax.set_xticklabels([''] + list(df.columns))
ax.set_yticklabels([''] + list(df.index))

plt.show()

最终结果显示:


attention.png

可能出现的问题:
■ 输出中文为方框。
使用下列代码查看可用字体,如果有中文就直接使用mpl.rcParams['font.sans-serif'] = [fontname] 将字体设置为对应的字体

from matplotlib.font_manager import FontManager
import subprocess
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
#print(mat_fonts)
output = subprocess.check_output(
    'fc-list :lang=zh -f "%{family}\n"', shell=True) # 获取字体列表
output = output.decode('utf-8')
#print(output)

zh_fonts = set(f.split(',', 1)[0] for f in output.split('\n'))
available = mat_fonts & zh_fonts
print('*' * 10, '可用的字体', '*' * 10)
for f in available:
    print(f)

若字体库中不存在可用字体,也可自己下载字体添加到matplotlib中。详细添加方法可自行查阅。

参考:
https://blog.csdn.net/m0_38133212/article/details/86664569

你可能感兴趣的:(文本attention矩阵可视化)