This blog post records how the author obtained visualizations (heatmaps) of the Transformer attention from a named entity recognition model.
The experiments run on Google's Colab cloud platform. The required libraries are matplotlib 3.3.2 (Colab's default version throws an error when a Chinese font is added; by sheer luck switching to this version worked, and it took a whole morning of trial and error to trace the failure to a version bug in the library, which was rather frustrating) and seaborn, whose heatmap function is used to draw the heatmaps.
The first half of this write-up covers how to draw the figure once the attention weights are in hand, which is the more general part; the second half covers the rather brute-force way the author extracted the attention weights from the model.
First, the data behind the multi-head self-attention used here. The attention weight tensor has shape [1, 8, 25, 25] (batch, heads, query length, key length); these weights are the post-softmax values from the Transformer encoder. They look roughly like the values in the figure below, and I will share the weight file later.
The text fed into the self-attention is a sentence from my named entity recognition data. The model used here fuses the text with lexicon (word) information to improve entity recognition. The test sentence is the one below; the goal is to explore what role the lexicon words play in the NER model.
sent=['一', '节', '课', '的', '时', '间', '真', '心', '感', '动', '了', '李', '开', '复', '感', '动', '节课', '时间', '真心' , '心感', '感动', '李开', '李开复', '开复', '感动']
Load the numpy data as follows; once loaded, it is accessed in a dictionary-like way.
import numpy as np
data = np.load('attn.npz')
data['arr_0'][0][0:1].shape  # (1, 25, 25): the attention matrix of one head
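As a quick sanity check on the loaded file (assuming it was produced by the np.savez call described in the second half of this post, so the array sits under the default key 'arr_0'):

print(data.files)             # ['arr_0']
print(data['arr_0'].shape)    # (1, 8, 25, 25): batch, heads, query length, key length
head0 = data['arr_0'][0][0]   # (25, 25) attention matrix of the first head
print(head0.sum(axis=-1))     # every row should sum to roughly 1 after the softmax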
Next, import the necessary libraries and download a Chinese font: Colab ships no Chinese fonts and matplotlib will not use one by default, so a Chinese font is imported following a few methods found online. Some of the statements below may be redundant, since many approaches were tried, but the rough flow is: download a SimHei .ttf file, then register the font through matplotlib's (mpl's) font functions; seaborn draws through matplotlib as well, as can be seen further down.
!wget -O /usr/share/fonts/truetype/liberation/SimHei.ttf "https://www.wfonts.com/download/data/2014/06/01/simhei/chinese.simhei.ttf"
import matplotlib.pyplot as plt
import matplotlib as mpl
zhfont = mpl.font_manager.FontProperties(fname='/usr/share/fonts/truetype/liberation/SimHei.ttf',size=12) # zhfont carries the SimHei (Chinese) font properties
plt.rcParams['axes.unicode_minus'] = False # render the minus sign correctly
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
import seaborn
rc = {'font.sans-serif': 'SimHei',
'axes.unicode_minus': False}
seaborn.set(context='talk', style='ticks', rc=rc)
seaborn.set(font='SimHei')
#seaborn.set_context(context="talk")
%matplotlib inline
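If the SimHei font is still not picked up by name, one more thing that may help (an extra suggestion, not part of the original setup; it needs matplotlib >= 3.2, which 3.3.2 satisfies) is to register the ttf file with matplotlib's font manager explicitly:

mpl.font_manager.fontManager.addfont('/usr/share/fonts/truetype/liberation/SimHei.ttf')  # register the downloaded ttf
plt.rcParams['font.sans-serif'] = ['SimHei']  # make it the default sans-serif font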
Next, define the plotting function and the code segment that draws the figure.
def draw(data, x, y, ax):
    seaborn.heatmap(data,
                    xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                    cbar=False, ax=ax)

fig, axs = plt.subplots(1, 4, figsize=(40, 20))  # lay out the canvas
plt.xticks(fontproperties=zhfont)
plt.yticks(fontproperties=zhfont)
for h in range(4):  # each iteration draws the attention of one head
    plt.xticks(fontproperties=zhfont)  # these two lines make the tick labels use the Chinese font
    plt.yticks(fontproperties=zhfont)
    draw(data['arr_0'][0][h:h+1].squeeze(0), x=sent, y=sent if h == 0 else [], ax=axs[h])
plt.show()
At this point the heatmaps can be drawn.
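The snippet above only shows the first four heads. A variation that draws all eight heads in a 2x4 grid and also writes the figure to a file might look like this (a sketch along the same lines, reusing draw, sent and zhfont from above):

fig, axs = plt.subplots(2, 4, figsize=(40, 20))
for h in range(8):
    ax = axs[h // 4][h % 4]
    draw(data['arr_0'][0][h], x=sent, y=sent if h % 4 == 0 else [], ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), fontproperties=zhfont)  # Chinese tick labels per subplot
    ax.set_yticklabels(ax.get_yticklabels(), fontproperties=zhfont)
fig.savefig('attention_heads.png', bbox_inches='tight')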
What follows is the approach for obtaining the attention weights. Environment: fastNLP 0.5.0 and PyTorch 1.2.0.
First, finish training the model and save it.
For the offline test, restrict the test set to a single example (a sketch of one way to do this follows).
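Assuming the test data is already a fastNLP DataSet called test_ds (a hypothetical name), slicing it should give back a new DataSet, so restricting it to one entry can look like:

single_ds = test_ds[0:1]   # keep only the first example
print(len(single_ds))      # 1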
Use fastNLP's Tester to run the test, with the Tester source code modified as shown below; make sure everything it needs is imported (the default library versions in Colab are fine).
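For reference, the imports the modified class relies on are roughly the ones below. Most of them mirror the top of fastNLP/core/tester.py in version 0.5.0, but the exact internal module paths are written from memory and are worth double-checking against the installed package:

import math
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from fastNLP.core.batch import BatchIter, DataSetIter
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import (_CheckError, _build_args, _check_loss_evaluate,
                                _get_func_signature, _get_model_device,
                                _move_dict_value_to_device, _move_model_to_device,
                                seq_len_to_mask)
from fastNLP.core._logger import logger
from fastNLP.core._parallel_utils import _data_parallel_wrapper, _model_contains_inner_module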
class Tester(object):
"""
Tester is the class that evaluates performance given data, a model and metrics. Pass in the model, the data and the metrics to run the evaluation.
"""
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
"""
:param ~fastNLP.DataSet data: the dataset to evaluate on
:param torch.nn.Module model: the model to use
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metrics used for testing
:param int batch_size: the batch size used during evaluation.
:param str,int,torch.device,list(int) device: the device to load the model onto. Defaults to None, i.e. the Tester does not
    manage where the model is computed. The following inputs are supported:

    1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...], meaning the CPU, the first visible GPU, the first visible GPU and the second visible GPU respectively;
    2. torch.device: load the model onto the given torch.device.
    3. int: use the GPU whose device_id equals this value.
    4. list(int): if more than one device is given, wrap the model with torch.nn.DataParallel and use the given devices.
    5. None: leave the model untouched; if the model passed in is a torch.nn.DataParallel, this value must be None.

    If the model runs prediction through predict(), multi-card (DataParallel) evaluation cannot be used; only the model on the first card is used.

:param int verbose: 0 prints nothing; 1 prints the evaluation results.
:param bool use_tqdm: whether to use tqdm to show testing progress; if False, nothing is shown.
"""
super(Tester, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
self.metrics = _prepare_metrics(metrics)
self.data = data
self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size
self.verbose = verbose
self.use_tqdm = use_tqdm
self.logger = logger
if isinstance(data, DataSet):
self.data_iterator = DataSetIter(
dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
elif isinstance(data, BatchIter):
self.data_iterator = data
else:
raise TypeError("data type {} not support".format(type(data)))
# check predict
print((hasattr(self._model, 'predict') and callable(self._model.predict)) or \
(_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
callable(self._model.module.predict)))
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
(_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
callable(self._model.module.predict)):
if isinstance(self._model, nn.DataParallel):
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
self._model.device_ids,
self._model.output_device),
network=self._model.module)
self._predict_func = self._model.module.predict # used to match arguments
elif isinstance(self._model, nn.parallel.DistributedDataParallel):
self._predict_func = self._model.module.predict
self._predict_func_wrapper = self._model.module.predict # used for the actual call
else:
self._predict_func = self._model.predict
self._predict_func_wrapper = self._model.predict
else:
print(_model_contains_inner_module(model))
if _model_contains_inner_module(model):
self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:
self._predict_func = self._model.forward
self._predict_func_wrapper = self._model.forward
def test(self):
r"""开始进行验证,并返回验证结果。
:return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
"""
# turn on the testing mode; clean up the history
#print('11111111111111111')
self._model_device = _get_model_device(self._model)
network = self._model
self._mode(network, is_test=True)
data_iterator = self.data_iterator
eval_results = {}
try:
with torch.no_grad():
if not self.use_tqdm:
from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
pbar.set_description_str(desc="Test")
start_time = time.time()
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
pred_dict = self._data_forward(self._predict_func, batch_x)
print(pred_dict)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics:
metric(pred_dict, batch_y)
if self.use_tqdm:
pbar.update()
for metric in self.metrics:
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.get_metric_name()
eval_results[metric_name] = eval_result
pbar.close()
end_time = time.time()
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
if self.verbose >= 0:
self.logger.info(test_str)
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)
if self.verbose >= 1:
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
"""
if is_test:
model.eval()
else:
model.train()
def _data_forward(self, func, x):
"""A forward pass of the model. """
x = _build_args(func, **x)
print(x)
#z = self._model.encoder.layer_0.attn.attnscore(**x)
y = self._predict_func_wrapper(**x)
#z = self._model.att(**x)
batch_size = x['lattice'].size(0)
max_seq_len_and_lex_num = x['lattice'].size(1)
max_seq_len = x['bigrams'].size(1)
raw_embed = self._model.lattice_embed(x['lattice'])
#raw_embed holds the pretrained embeddings of both characters and lexicon words, but they were pretrained separately, so they need to be handled separately
bigrams_embed = self._model.bigram_embed(x['bigrams'])
bigrams_embed = torch.cat([bigrams_embed,
torch.zeros(size=[batch_size,max_seq_len_and_lex_num-max_seq_len,
self._model.bigram_size]).to(bigrams_embed)],dim=1)
raw_embed_char = torch.cat([raw_embed, bigrams_embed],dim=-1)
dim2 = 0
dim3 = 2
# print('raw_embed:{}'.format(raw_embed[:,dim2,:dim3]))
# print('raw_embed_char:{}'.format(raw_embed_char[:, dim2, :dim3]))
if self._model.embed_dropout_pos == '0':
raw_embed_char = self._model.embed_dropout(raw_embed_char)
raw_embed = self._model.gaz_dropout(raw_embed)
# print('raw_embed_dropout:{}'.format(raw_embed[:,dim2,:dim3]))
# print('raw_embed_char_dropout:{}'.format(raw_embed_char[:, dim2, :dim3]))
embed_char = self._model.char_proj(raw_embed_char)
char_mask = seq_len_to_mask(x['seq_len'],max_len=max_seq_len_and_lex_num).bool()
# if self.embed_dropout_pos == '1':
# embed_char = self.embed_dropout(embed_char)
embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)
embed_lex = self._model.lex_proj(raw_embed)
lex_mask = (seq_len_to_mask(x['seq_len']+x['lex_num']).bool() ^ char_mask.bool())
embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)
assert char_mask.size(1) == lex_mask.size(1)
embedding = embed_char + embed_lex
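# embed_char is non-zero only on the character positions and embed_lex only on the
# appended lexicon-word positions, so their sum yields one embedding per lattice position.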
if self._model.embed_dropout_pos == '1':
embedding = self._model.embed_dropout(embedding)
if self._model.use_abs_pos:
embedding = self._model.abs_pos_encode(embedding, x['pos_s'], x['pos_e'])
if self._model.embed_dropout_pos == '2':
embedding = self._model.embed_dropout(embedding)
# embedding = self.embed_dropout(embedding)
# print('embedding:{}'.format(embedding[:,dim2,:dim3]))
encoded = self._model.encoder(embedding,x['seq_len'],lex_num=x['lex_num'],pos_s=x['pos_s'],pos_e=x['pos_e'],
print_=(self._model.batch_num==327))
print(x)
#print(encoded)
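# Note: `encoded` above is not used any further; the first encoder layer is replayed
# step by step below so that its intermediate attention scores can be captured.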
## step into the encoder
if self._model.encoder.relative_position:
if self._model.encoder.four_pos_fusion_shared and self._model.encoder.lattice:
rel_pos_embedding = self._model.encoder.four_pos_fusion_embedding(x['pos_s'],x['pos_e'])
else:
rel_pos_embedding = None
else:
rel_pos_embedding = None
output = embedding
output = self._model.encoder.layer_0.layer_preprocess(output)
if self._model.encoder.layer_0.lattice:
if self._model.encoder.layer_0.relative_position:
if rel_pos_embedding is None:
rel_pos_embedding = self._model.encoder.layer_0.four_pos_fusion_embedding(pos_s=x['pos_s'],pos_e=x['pos_e'])
output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'], pos_s=x['pos_s'], pos_e=x['pos_e'], lex_num=x['lex_num'],
rel_pos_embedding=rel_pos_embedding)
else:
output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'], x['lex_num'])
else:
output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'])
#print(output)
#self._model.encoder.layer_0.attn.attn
## step into the multi-head attention
key=output
query=output
value=output
batch = key.size(0)
pos_s=x['pos_s']
pos_e=x['pos_e']
lex_num=x['lex_num']
seq_len=x['seq_len']
if self._model.encoder.layer_0.attn.k_proj:
key = self._model.encoder.layer_0.attn.w_k(key)
if self._model.encoder.layer_0.attn.q_proj:
query = self._model.encoder.layer_0.attn.w_q(query)
if self._model.encoder.layer_0.attn.v_proj:
value = self._model.encoder.layer_0.attn.w_v(value)
if self._model.encoder.layer_0.attn.r_proj:
rel_pos_embedding = self._model.encoder.layer_0.attn.w_r(rel_pos_embedding)
batch = key.size(0)
max_seq_len = key.size(1)
# batch * seq_len * n_head * d_head
key = torch.reshape(key, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
query = torch.reshape(query, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
value = torch.reshape(value, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
rel_pos_embedding = torch.reshape(rel_pos_embedding,
[batch,max_seq_len, max_seq_len, self._model.encoder.layer_0.attn.num_heads,self._model.encoder.layer_0.attn.per_head_size])
# batch * n_head * seq_len * d_head
key = key.transpose(1, 2)
query = query.transpose(1, 2)
value = value.transpose(1, 2)
# batch * n_head * d_head * key_len
key = key.transpose(-1, -2)
# #A
# A_ = torch.matmul(query,key)
# #C
# # key: batch * n_head * d_head * key_len
u_for_c = self._model.encoder.layer_0.attn.u.unsqueeze(0).unsqueeze(-2)
# u_for_c: 1(batch broadcast) * num_heads * 1 *per_head_size
# key_for_c = key
# C_ = torch.matmul(u_for_c, key)
query_and_u_for_c = query + u_for_c
A_C = torch.matmul(query_and_u_for_c, key)
#B
rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)
# after above, rel_pos_embedding: batch * num_head * query_len * per_head_size * key_len
query_for_b = query.view([batch, self._model.encoder.layer_0.attn.num_heads, max_seq_len, 1, self._model.encoder.layer_0.attn.per_head_size])
# after above, query_for_b: batch * num_head * query_len * 1 * per_head_size
# print('query for b:{}'.format(query_for_b.size()))
# print('rel_pos_embedding_for_b{}'.format(rel_pos_embedding_for_b.size()))
# B_ = torch.matmul(query_for_b,rel_pos_embedding_for_b).squeeze(-2)
#D
# rel_pos_embedding_for_d = rel_pos_embedding.unsqueeze(-2)
# after above, rel_pos_embedding: batch * query_seq_len * key_seq_len * num_heads * 1 *per_head_size
# v_for_d = self.v.unsqueeze(-1)
# v_for_d: num_heads * per_head_size * 1
# D_ = torch.matmul(rel_pos_embedding_for_d,v_for_d).squeeze(-1).squeeze(-1).permute(0,3,1,2)
query_for_b_and_v_for_d = query_for_b + self._model.encoder.layer_0.attn.v.view(1,self._model.encoder.layer_0.attn.num_heads,1,1,self._model.encoder.layer_0.attn.per_head_size)
B_D = torch.matmul(query_for_b_and_v_for_d, rel_pos_embedding_for_b).squeeze(-2)
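# Transformer-XL style relative attention: A (content-content) and C (global content bias u)
# are folded into A_C, while B (content vs. relative position) and D (global position bias v)
# are folded into B_D; their sum below is the raw attention score before masking and softmax.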
#att_score: Batch * num_heads * query_len * key_len
# A, B, C and D all have exactly this shape
attn_score_raw = A_C + B_D
if self._model.encoder.layer_0.attn.scaled:
attn_score_raw = attn_score_raw / math.sqrt(self._model.encoder.layer_0.attn.per_head_size)
mask = seq_len_to_mask(seq_len+lex_num).bool().unsqueeze(1).unsqueeze(1)
attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)
attn_score = F.softmax(attn_score_raw_masked,dim=-1)
#attn_score = self._model.encoder.layer_0.attn.dropout(attn_score)
#import numpy as np
print(attn_score.cpu().numpy(), attn_score.size(), seq_len)
np.savez('attn', attn_score.cpu().numpy())
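# np.savez appends the .npz suffix automatically and overwrites the file on every batch,
# so restricting the test set to a single example keeps exactly the attention of that sentence.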
return y
def _format_eval_results(self, results):
"""Override this method to support more print formats.
:param results: dict, (str: float) is (metrics name: value)
"""
_str = ''
for metric_name, metric_result in results.items():
_str += metric_name + ': '
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
_str += '\n'
return _str[:-1]
The core of the trick is the _data_forward function: it walks through the lattice model layer by layer, re-implements the attention computation, and finally dumps the attention weights with np.savez so that the resulting attn.npz can be loaded again later.
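Putting it together, the modified Tester is driven just like the stock one (a sketch; model, single_ds and metric stand for the trained model, the one-example test set from above and whatever metric was used during training):

tester = Tester(data=single_ds, model=model, metrics=metric, batch_size=1, use_tqdm=False)
tester.test()                          # _data_forward writes attn.npz as a side effect
attn = np.load('attn.npz')['arr_0']    # shape (1, 8, 25, 25) for the example sentence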