文本attention矩阵可视化

在机器阅读理解的论文中,经常可以看到对“文章-问题”可视化的二维热力图,例如下图。在看实验结果的时候用这种图可以直观的看到attention的效果怎么样。比如下图:
文本attention矩阵可视化_第1张图片

于是从github中找到了一个例子,进行了简单的实验。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.ticker as ticker

a = torch.randn(4, 2)
b = a.softmax(dim=1)
c = a.softmax(dim=0).transpose(0, 1)
print(a, '\n',  b, '\n', c)
d = b.matmul(c)
print(d)

d = d.numpy()

得到numpy的4*4数据。然后用matplotlib可视化。

variables = ['A','B','C','X']
labels = ['ID_0','ID_1','ID_2','ID_3']

df = pd.DataFrame(d, columns=variables, index=labels)

fig = plt.figure()

ax = fig.add_subplot(111)

cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
fig.colorbar(cax)

tick_spacing = 1
ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

ax.set_xticklabels([''] + list(df.columns))
ax.set_yticklabels([''] + list(df.index))

plt.show()

得到下图:文本attention矩阵可视化_第2张图片
如何把文字显示到坐标轴上:

nlp = spacy.load('en')    #import spacy,用于分词
sent = nlp('Which NFL team represented the AFC at Super Bowl 50')  #对文章提一个问题
doc = open('F:/spacy.txt').read()
doc = nlp(doc)

data = []
for token1 in doc:
     data.append([token1.similarity(token2) for token2 in sent])

d = np.array(data)
d = d.transpose()
col = [t.text for t in doc]     #需要显示的词
index = [t.text for t in sent]  #需要显示的词
df = pd.DataFrame(d, columns=col, index=index )

fig = plt.figure()

ax = fig.add_subplot(111)

cax = ax.matshow(df, interpolation='nearest', cmap='hot_r')
#cax = ax.matshow(df)
fig.colorbar(cax)

tick_spacing = 1
ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

# fontdict = {'rotation': 'vertical'}    #设置文字旋转
fontdict = {'rotation': 90}       #或者这样设置文字旋转
#ax.set_xticklabels([''] + list(df.columns), rotation=90)  #或者直接设置到这里
# Axes.set_xticklabels(labels, fontdict=None, minor=False, **kwargs)
ax.set_xticklabels([''] + list(df.columns), fontdict=fontdict)
ax.set_yticklabels([''] + list(df.index))

plt.show()

如果不设置fontdict,横轴的文字是横着的,叠加在一起,所以要设置旋转90度。
fontdict的更多文字属性设置可以在 https://matplotlib.org/api/text_api.html#matplotlib.text.Text 里面的Property找到

最后显示的结果是:
文本attention矩阵可视化_第3张图片

颜色越深的表示这两个词的相似度最高,比如50和50,24对应块的颜色挺深的。

。可以在matplotlib documentation找到

你可能感兴趣的:(可视化)