Visualizing Transformer Self-Attention

This post documents how the author obtained heat-map visualizations of the attention inside the Transformer of a named entity recognition (NER) model.

The experiments run on Google's Colab platform. Required libraries: matplotlib 3.3.2 (the Colab default version throws an error when a Chinese font is added; by sheer luck switching versions fixed it, after a whole morning of tracking the problem down to a library-version bug, which was rather frustrating) and seaborn, whose heatmap function draws the heat maps.
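
Since the Colab default matplotlib triggered that font error for me, the first step is simply pinning the version that worked (an assumption here is that the version switch takes effect after restarting the runtime, which Colab usually prompts for):

!pip install matplotlib==3.3.2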

The first half of this post covers how to draw the figure once you already have the attention weights, which is the more generally useful part; the second half covers the rather brute-force way the author extracted the attention weights from the model.


First, the data used for the multi-head self-attention. The attention weight tensor has shape [1, 8, 25, 25] (batch, heads, query positions, key positions); these are the post-softmax values from the Transformer encoder. They look roughly like the figure below; I will share the weight file later.

(Figure 1: a sample of the attention weight values)

The text is the sentence the author runs NER on. The model fuses character and lexicon-word information to improve entity recognition; the test sentence is the one below, and the goal is to examine what role the lexicon words play in the NER model.

sent=['一', '节', '课', '的', '时', '间', '真', '心', '感', '动', '了', '李', '开', '复', '感', '动', '节课', '时间', '真心', '心感', '感动', '李开', '李开复', '开复', '感动']
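
For reference (this split is implied by the sentence itself rather than stated in the weight file): the first 16 entries are the characters of the raw sentence and the last 9 are the lexicon words matched against it, which is exactly why each head's attention matrix is 25 x 25.

chars = sent[:16]   # the 16 characters of the raw sentence (seq_len = 16)
words = sent[16:]   # the 9 matched lexicon words (lex_num = 9)
print(len(chars), len(words), len(sent))   # 16 9 25 -> matches the 25x25 attention matrices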

Next, read in the numpy data; once loaded, the archive is accessed like a dictionary.

import numpy as np
data = np.load('attn.npz')
data['arr_0'][0][0:1].shape      # (1, 25, 25): one head's attention matrix
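
A quick sanity check on the archive contents ('arr_0' is the default key np.savez assigns when the array is passed positionally):

print(data.files)            # ['arr_0']
print(data['arr_0'].shape)   # (1, 8, 25, 25): batch, heads, query positions, key positions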

Below, import the necessary libraries and download a Chinese font. Colab ships without Chinese fonts and matplotlib will not use one by default, so following some recipes found online, the font is installed by hand. Some of these lines may well be redundant, since many variants were tried; the rough flow is: download a SimHei .ttf file, then register the font through matplotlib (mpl) functions. Seaborn also draws through matplotlib, as shown below.

!wget -O /usr/share/fonts/truetype/liberation/SimHei.ttf "https://www.wfonts.com/download/data/2014/06/01/simhei/chinese.simhei.ttf"
import matplotlib.pyplot as plt
import matplotlib as mpl
zhfont = mpl.font_manager.FontProperties(fname='/usr/share/fonts/truetype/liberation/SimHei.ttf',size=12) # FontProperties for the downloaded SimHei (Chinese) font
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'

import seaborn
rc = {'font.sans-serif': 'SimHei',
      'axes.unicode_minus': False}
seaborn.set(context='talk', style='ticks', rc=rc)


seaborn.set(font='SimHei')
#seaborn.set_context(context="talk")
%matplotlib inline
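
If the rc-based setup above still falls back to the default font (this seems to vary with the matplotlib version), registering the downloaded file directly with matplotlib's font manager is another option that has worked for me; the sketch below assumes matplotlib >= 3.2, where fontManager.addfont exists.

import matplotlib as mpl
import matplotlib.pyplot as plt

font_path = '/usr/share/fonts/truetype/liberation/SimHei.ttf'  # same path as the wget above
mpl.font_manager.fontManager.addfont(font_path)   # register the ttf with matplotlib (3.2+)
plt.rcParams['font.sans-serif'] = ['SimHei']      # make SimHei the default sans-serif font
plt.rcParams['axes.unicode_minus'] = False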

Next, define the plotting function and the code block that draws the figure.

def draw(data, x, y, ax):
    seaborn.heatmap(data, 
                    xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0, 
                    cbar=False, ax=ax)
fig, axs = plt.subplots(1,4, figsize=(40, 20))  # lay out the canvas: one row, four heads side by side
plt.xticks(fontproperties=zhfont)
plt.yticks(fontproperties=zhfont)
for h in range(4):  # each iteration draws one head's attention
  plt.xticks(fontproperties=zhfont)  # these two lines tell matplotlib to render the tick labels with the Chinese font
  plt.yticks(fontproperties=zhfont)
  draw(data['arr_0'][0][h:h+1].squeeze(0), x=sent, y=sent if h == 0 else [], ax=axs[h])
plt.show()

(Figure 2: attention heat maps of the first four heads)

With that, the heat maps are drawn.
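
The loop above only shows the first four heads. Since the weight tensor holds eight, a minimal variant (same assumptions as above: data, sent, draw and zhfont are already defined) that lays all heads out on a 2 x 4 grid and saves the figure could look like this:

fig, axs = plt.subplots(2, 4, figsize=(40, 20))
for h in range(8):
    ax = axs[h // 4][h % 4]
    draw(data['arr_0'][0][h], x=sent, y=sent if h % 4 == 0 else [], ax=ax)  # (25, 25) slice for head h
    ax.set_title('head {}'.format(h))
    ax.set_xticklabels(ax.get_xticklabels(), fontproperties=zhfont)
    ax.set_yticklabels(ax.get_yticklabels(), fontproperties=zhfont)
plt.savefig('attention_heads.png', bbox_inches='tight')
plt.show()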


The rest of the post covers how the attention was obtained in the first place. Environment: fastNLP 0.5.0, PyTorch 1.2.0.

First, train the model to completion and save it.

For the offline test, restrict the test set to a single example.

The evaluation is run with fastNLP's Tester, but with its source code modified; every library it needs is imported explicitly (the Colab default versions are fine for these).
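
Because the modified Tester below is pasted as a standalone class, it needs the symbols the original fastNLP tester imports internally, plus a few extras used in _data_forward. The list below is a sketch based on the fastNLP 0.5.x source layout; the exact import paths are assumptions and may need adjusting for your installed version.

import math
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# fastNLP internals used by the Tester (module paths assumed from the 0.5.x source tree)
from fastNLP.core.batch import BatchIter, DataSetIter
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import (_CheckError, _build_args, _check_loss_evaluate,
                                _get_func_signature, _get_model_device,
                                _move_dict_value_to_device, _move_model_to_device,
                                seq_len_to_mask)
from fastNLP.core._parallel_utils import _data_parallel_wrapper, _model_contains_inner_module
from fastNLP.core._logger import logger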

class Tester(object):
    """
    Tester evaluates model performance given data, a model, and metrics. Pass in the model, the data, and the metric(s) to run the evaluation.
    """
    
    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
        """
        
        :param ~fastNLP.DataSet data: the dataset to evaluate on
        :param torch.nn.Module model: the model to evaluate
        :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metric(s) used during evaluation
        :param int batch_size: the batch size used for evaluation.
        :param str,int,torch.device,list(int) device: which device to load the model onto. Defaults to None, i.e. the Tester does not
            manage where the model is computed. Supported values:
    
            1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] meaning the CPU, the first visible GPU, the first visible GPU, and the second visible GPU respectively;
    
            2. torch.device: load the model onto that torch.device.
    
            3. int: use the GPU with that device_id.
    
            4. list(int): if more than one device is given, wrap the model in torch.nn.DataParallel and use the given devices.
    
            5. None: leave the model untouched. If the model passed in is already a torch.nn.DataParallel, this value must be None.
    
            If predictions go through predict(), multi-GPU (DataParallel) evaluation is not possible; only the model on the first card will be used.
        :param int verbose: 0 prints nothing; 1 prints the evaluation results.
        :param bool use_tqdm: whether to show evaluation progress with tqdm; if False, nothing is shown.
        """
        super(Tester, self).__init__()

        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
        
        self.metrics = _prepare_metrics(metrics)
        
        self.data = data
        self._model = _move_model_to_device(model, device=device)
        self.batch_size = batch_size
        self.verbose = verbose
        self.use_tqdm = use_tqdm
        self.logger = logger

        if isinstance(data, DataSet):
            self.data_iterator = DataSetIter(
                dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
        elif isinstance(data, BatchIter):
            self.data_iterator = data
        else:
            raise TypeError("data type {} not support".format(type(data)))

        # check predict
        print((hasattr(self._model, 'predict') and callable(self._model.predict)) or \
                (_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
                 callable(self._model.module.predict)))
        if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
                (_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
                 callable(self._model.module.predict)):
            if isinstance(self._model, nn.DataParallel):
                self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
                                                                    self._model.device_ids,
                                                                    self._model.output_device),
                                                     network=self._model.module)
                self._predict_func = self._model.module.predict  # used to match arguments
            elif isinstance(self._model, nn.parallel.DistributedDataParallel):
                self._predict_func = self._model.module.predict
                self._predict_func_wrapper = self._model.module.predict  # used for the actual call
            else:
                self._predict_func = self._model.predict
                self._predict_func_wrapper = self._model.predict
        else:
            print(_model_contains_inner_module(model))
            if _model_contains_inner_module(model):
                self._predict_func_wrapper = self._model.forward
                self._predict_func = self._model.module.forward
            else:
                self._predict_func = self._model.forward
                self._predict_func_wrapper = self._model.forward
    
    def test(self):
        r"""开始进行验证,并返回验证结果。

        :return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
        """
        # turn on the testing mode; clean up the history
        #print('11111111111111111')
        self._model_device = _get_model_device(self._model)
        network = self._model
        self._mode(network, is_test=True)
        data_iterator = self.data_iterator
        eval_results = {}
        try:
            with torch.no_grad():
                if not self.use_tqdm:
                    from .utils import _pseudo_tqdm as inner_tqdm
                else:
                    inner_tqdm = tqdm
                with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
                    pbar.set_description_str(desc="Test")

                    start_time = time.time()

                    for batch_x, batch_y in data_iterator:
                        _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                        pred_dict = self._data_forward(self._predict_func, batch_x)
                        print(pred_dict)
                        if not isinstance(pred_dict, dict):
                            raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
                                            f"must be `dict`, got {type(pred_dict)}.")
                        for metric in self.metrics:
                            metric(pred_dict, batch_y)

                        if self.use_tqdm:
                            pbar.update()

                    for metric in self.metrics:
                        eval_result = metric.get_metric()
                        if not isinstance(eval_result, dict):
                            raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
                                            f"`dict`, got {type(eval_result)}")
                        metric_name = metric.get_metric_name()
                        eval_results[metric_name] = eval_result
                    pbar.close()
                    end_time = time.time()
                    test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
                    if self.verbose >= 0:
                        self.logger.info(test_str)
        except _CheckError as e:
            prev_func_signature = _get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)
        
        if self.verbose >= 1:
            logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
        self._mode(network, is_test=False)
        return eval_results
    
    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.

        :param model: a PyTorch model
        :param is_test: bool, whether in test mode or not.

        """
        if is_test:
            model.eval()
        else:
            model.train()
    
    def _data_forward(self, func, x):
        """A forward pass of the model. """
        x = _build_args(func, **x)
        

        print(x)

        #z = self._model.encoder.layer_0.attn.attnscore(**x)
        y = self._predict_func_wrapper(**x)
        #z = self._model.att(**x)
        batch_size = x['lattice'].size(0)
        max_seq_len_and_lex_num = x['lattice'].size(1)
        max_seq_len = x['bigrams'].size(1)

        raw_embed = self._model.lattice_embed(x['lattice'])
        # raw_embed holds the pretrained embeddings of both characters and words, but they were pretrained separately, so they have to be treated separately
        bigrams_embed = self._model.bigram_embed(x['bigrams'])
        bigrams_embed = torch.cat([bigrams_embed,
                                       torch.zeros(size=[batch_size,max_seq_len_and_lex_num-max_seq_len,
                                                         self._model.bigram_size]).to(bigrams_embed)],dim=1)
        raw_embed_char = torch.cat([raw_embed, bigrams_embed],dim=-1)

        dim2 = 0
        dim3 = 2
        # print('raw_embed:{}'.format(raw_embed[:,dim2,:dim3]))
        # print('raw_embed_char:{}'.format(raw_embed_char[:, dim2, :dim3]))
        if self._model.embed_dropout_pos == '0':
            raw_embed_char = self._model.embed_dropout(raw_embed_char)
            raw_embed = self._model.gaz_dropout(raw_embed)
        # print('raw_embed_dropout:{}'.format(raw_embed[:,dim2,:dim3]))
        # print('raw_embed_char_dropout:{}'.format(raw_embed_char[:, dim2, :dim3]))

        embed_char = self._model.char_proj(raw_embed_char)
        char_mask = seq_len_to_mask(x['seq_len'],max_len=max_seq_len_and_lex_num).bool()
        # if self.embed_dropout_pos == '1':
        #     embed_char = self.embed_dropout(embed_char)
        embed_char.masked_fill_(~(char_mask.unsqueeze(-1)), 0)

        embed_lex = self._model.lex_proj(raw_embed)
        lex_mask = (seq_len_to_mask(x['seq_len']+x['lex_num']).bool() ^ char_mask.bool())
        embed_lex.masked_fill_(~(lex_mask).unsqueeze(-1), 0)

        assert char_mask.size(1) == lex_mask.size(1)

        embedding = embed_char + embed_lex
        
        if self._model.embed_dropout_pos == '1':
            embedding = self._model.embed_dropout(embedding)

        if self._model.use_abs_pos:
            embedding = self._model.abs_pos_encode(embedding, x['pos_s'], x['pos_e'])  # positions come from the batch dict

        if self._model.embed_dropout_pos == '2':
            embedding = self._model.embed_dropout(embedding)
        # embedding = self.embed_dropout(embedding)

        # print('embedding:{}'.format(embedding[:,dim2,:dim3]))


        encoded = self._model.encoder(embedding,x['seq_len'],lex_num=x['lex_num'],pos_s=x['pos_s'],pos_e=x['pos_e'],
                               print_=(self._model.batch_num==327))
        print(x)
        #print(encoded)
        ## now step into the encoder by hand
        if self._model.encoder.relative_position:
            if self._model.encoder.four_pos_fusion_shared and self._model.encoder.lattice:
                rel_pos_embedding = self._model.encoder.four_pos_fusion_embedding(x['pos_s'],x['pos_e'])
            else:
                rel_pos_embedding = None
        else:
            rel_pos_embedding = None
        output = embedding
        output = self._model.encoder.layer_0.layer_preprocess(output)
        if self._model.encoder.layer_0.lattice:
            if self._model.encoder.layer_0.relative_position:
                if rel_pos_embedding is None:
                    rel_pos_embedding = self._model.encoder.layer_0.four_pos_fusion_embedding(pos_s=x['pos_s'],pos_e=x['pos_e'])
                output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'], pos_s=x['pos_s'], pos_e=x['pos_e'], lex_num=x['lex_num'],
                                   rel_pos_embedding=rel_pos_embedding)
            else:
                output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'], x['lex_num'])
        else:
            output = self._model.encoder.layer_0.attn(output, output, output, x['seq_len'])
        #print(output)
        #self._model.encoder.layer_0.attn.attn
        ## now step into the multi-head attention

        key=output
        query=output
        value=output
        batch = key.size(0)
        pos_s=x['pos_s']
        pos_e=x['pos_e']
        lex_num=x['lex_num']
        seq_len=x['seq_len']

        if self._model.encoder.layer_0.attn.k_proj:
            key = self._model.encoder.layer_0.attn.w_k(key)
        if self._model.encoder.layer_0.attn.q_proj:
            query = self._model.encoder.layer_0.attn.w_q(query)
        if self._model.encoder.layer_0.attn.v_proj:
            value = self._model.encoder.layer_0.attn.w_v(value)
        if self._model.encoder.layer_0.attn.r_proj:
            rel_pos_embedding = self._model.encoder.layer_0.attn.w_r(rel_pos_embedding)

        batch = key.size(0)
        max_seq_len = key.size(1)


        # batch * seq_len * n_head * d_head
        key = torch.reshape(key, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
        query = torch.reshape(query, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
        value = torch.reshape(value, [batch, max_seq_len, self._model.encoder.layer_0.attn.num_heads, self._model.encoder.layer_0.attn.per_head_size])
        rel_pos_embedding = torch.reshape(rel_pos_embedding,
                                          [batch,max_seq_len, max_seq_len, self._model.encoder.layer_0.attn.num_heads,self._model.encoder.layer_0.attn.per_head_size])


        # batch * n_head * seq_len * d_head
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)



        # batch * n_head * d_head * key_len
        key = key.transpose(-1, -2)
        # #A
        # A_ = torch.matmul(query,key)
        # #C
        # # key: batch * n_head * d_head * key_len
        u_for_c = self._model.encoder.layer_0.attn.u.unsqueeze(0).unsqueeze(-2)
        # u_for_c: 1(batch broadcast) * num_heads * 1 *per_head_size
        # key_for_c = key
        # C_ = torch.matmul(u_for_c, key)
        query_and_u_for_c = query + u_for_c
        A_C = torch.matmul(query_and_u_for_c, key)


        #B
        rel_pos_embedding_for_b = rel_pos_embedding.permute(0, 3, 1, 4, 2)
        # after above, rel_pos_embedding: batch * num_head * query_len * per_head_size * key_len
        query_for_b = query.view([batch, self._model.encoder.layer_0.attn.num_heads, max_seq_len, 1, self._model.encoder.layer_0.attn.per_head_size])
        # after above, query_for_b: batch * num_head * query_len * 1 * per_head_size
        # print('query for b:{}'.format(query_for_b.size()))
        # print('rel_pos_embedding_for_b{}'.format(rel_pos_embedding_for_b.size()))
        # B_ = torch.matmul(query_for_b,rel_pos_embedding_for_b).squeeze(-2)

        #D
        # rel_pos_embedding_for_d = rel_pos_embedding.unsqueeze(-2)
        # after above, rel_pos_embedding: batch * query_seq_len * key_seq_len * num_heads * 1 *per_head_size
        # v_for_d = self.v.unsqueeze(-1)
        # v_for_d: num_heads * per_head_size * 1
        # D_ = torch.matmul(rel_pos_embedding_for_d,v_for_d).squeeze(-1).squeeze(-1).permute(0,3,1,2)

        query_for_b_and_v_for_d = query_for_b + self._model.encoder.layer_0.attn.v.view(1,self._model.encoder.layer_0.attn.num_heads,1,1,self._model.encoder.layer_0.attn.per_head_size)
        B_D = torch.matmul(query_for_b_and_v_for_d, rel_pos_embedding_for_b).squeeze(-2)
        # att_score: batch * num_heads * query_len * key_len
        # A, B, C and D all have exactly this shape. This is the Transformer-XL style decomposition of the
        # attention logits: A is the content-content term (q.k), C the global content bias (u.k), B the
        # content-position term (q.r) and D the global position bias (v.r); A and C are folded together
        # via (query + u) above, and B and D via (query_for_b + v).
        attn_score_raw = A_C + B_D

        if self._model.encoder.layer_0.attn.scaled:
            attn_score_raw  = attn_score_raw / math.sqrt(self._model.encoder.layer_0.attn.per_head_size)

        mask = seq_len_to_mask(seq_len+lex_num).bool().unsqueeze(1).unsqueeze(1)
        attn_score_raw_masked = attn_score_raw.masked_fill(~mask, -1e15)


        attn_score = F.softmax(attn_score_raw_masked,dim=-1)

        #attn_score = self._model.encoder.layer_0.attn.dropout(attn_score)
        #import numpy as np
        # attn_score is the post-softmax tensor visualized in the first half of the post
        print(attn_score.cpu().numpy(), attn_score.size(), seq_len)
        np.savez('attn', attn_score.cpu().numpy())  # .cpu() so this also works when the model sits on a GPU
        return y
    
    def _format_eval_results(self, results):
        """Override this method to support more print formats.

        :param results: dict, (str: float) is (metrics name: value)

        """
        _str = ''
        for metric_name, metric_result in results.items():
            _str += metric_name + ': '
            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
            _str += '\n'
        return _str[:-1]

The real work happens in the _data_forward function: it walks through the lattice model layer by layer, re-implements the attention computation step by step, and finally saves the attention weights to attn.npz with np.savez so they can be loaded again later.
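
For completeness, a rough sketch of how the modified Tester gets invoked. The dataset, model and metric objects here are placeholders from your own training script (hypothetical names, not defined in this post); the only constraint from the steps above is that the test DataSet holds a single sentence.

# `model`, `test_dataset` (restricted to one sentence) and `metric` come from your own training code.
tester = Tester(data=test_dataset, model=model, metrics=metric, batch_size=1, use_tqdm=False)
tester.test()   # runs the patched _data_forward, which writes attn.npz as a side effect

# afterwards the file can be loaded exactly as in the first half of the post
attn = np.load('attn.npz')
print(attn['arr_0'].shape)   # expected: (1, num_heads, seq_len + lex_num, seq_len + lex_num)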
