Attention 扫盲:注意力机制及其 PyTorch 应用实现

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

来自 | 知乎

作者 | Lucas

地址 | https://zhuanlan.zhihu.com/p/88376673

Attention 扫盲:注意力机制及其 PyTorch 应用实现

仿生人脑注意力模型->计算资源分配

 

深度学习attention 机制是对人类视觉注意力机制的仿生,本质上是一种资源分配机制。生理原理就是人类视觉注意力能够以高分辨率接收于图片上的某个区域,并且以低分辨率感知其周边区域,并且视点能够随着时间而改变。换而言之,就是人眼通过快速扫描全局图像,找到需要关注的目标区域,然后对这个区域分配更多注意,目的在于获取更多细节信息和抑制其他无用信息。提高 representation 的高效性。例如,对于下面一张图,我的主要关注点就在于中间的 icon 和 ATTENTION 文字,对于边框上的条纹就不太关注,而且看一眼还有点晕。

Attention 扫盲:注意力机制及其 PyTorch 应用实现_第1张图片

Encoder-Decoder框架==sequence to sequence 条件生成框架

 

Encoder-Decoder框架,也被称为 sequence to sequence 条件生成框架[1],是一种文本处理领域的研究模式。常规的 encoder-decoder方法,第一步,将输入句子序列 X通过神经网络编码为固定长度的上下文向量C,也就是文本的语义表示;第二步,由另外一个神经网络作为解码器根据当前已经预测出来的词记忆编码后的上下文向量 C,来预测目标词序列,过程中编码器和解码器的 RNN 是联合训练的,但是监督信息只出现在解码器 RNN 一端,梯度随着反向传播到编码器 RNN 一端。使用 LSTM 进行文本建模时当前流行的有效方法[2]

attention 机制的最典型应用是统计机器翻译。给定任务,输入是“Echt”, “Dicke” and “Kiste”进 encoder,使用 rnn 表示文本为固定长度向量 h3。但问题就在于,当前 decoder 生成 y1 时仅仅依赖于最后一个隐层状态h3,也就是 sentence_embedding。那么这个 h3 必须 encode 输入句子中的全部信息才行。可实际上,传统Encoder-Decoder模型并不能达到这个功能。那 LSTM [3]不就是用来解决长期依赖信息问题的嘛?但事实上,长短期记忆网络仍然存在问题。我们说,RNN在长期信息访问当前处理单元之前,需要按顺序地通过所有之前的单元。这意味着它很容易遭遇梯度消失问题。然后引入 LSTM,使用门控某种程度上解决这个问题。的确,LSTM、GRU 和其变体能学习大量的长期信息,但它们最多只能记住相对长的信息,而不是更大更长。

Attention 扫盲:注意力机制及其 PyTorch 应用实现_第2张图片 使用 RNN 文本表示与生成

所以,我们来总结一下传统 encoder-decoder的一般范式及其问题:任务是翻译中文“我/爱/赛尔”到英文。传统 encoder-decoder 先把整句话输入进去,编码最后一个词“赛尔”结束之后,使用 RNN生成一个整句话的表示-向量 C,在条件生成时,当翻译到第 2个词“赛尔”的时候,需要退 1 步找到已经预测出来的h_1以及上下文表示 C, 然后 decode 输出。 

从注意力均等到注意力集中

 

在传统Encoder-Decoder 框架下:由解码器根据当前已经预测出来的词记忆编码后的上下文向量 C,来预测目标词序列。也就是说,不论生成那个词,我们使用的句子编码表示 C 都是一样的。换句话说,句子中任意单词对生成某个目标单词P_yi来说影响力都是相同的,也就是注意力均等。很显然这不符合直觉。直觉应该:我翻译哪个部分,哪个部分就应该把注意力集中于我的翻译的原文,翻译到第一个词,就应该多关注原文中的第一个词是什么意思。详见伪代码和下图:

P_y1 = F(E,C),
P_y2 = F((E,C)
P_y3 = F((E,C)
Attention 扫盲:注意力机制及其 PyTorch 应用实现_第3张图片 传统 Encoder-Decoder 框架下的 RNN 进行文本翻译,一直使用同一个 c

接下来观察上下两个图的区别:相同的上下文表示C会替换成根据当前生成单词而不断变化的Ci。

Attention 扫盲:注意力机制及其 PyTorch 应用实现_第4张图片 融合 attention 机制的RNN 模型进行文本翻译每个时刻生成不同的 c

文本翻译过程变为:

P_y1 = F(E,C_0),
P_y2 = F((E,C_1)
P_y3 = F((E,C_2)

Encoder-Decoder框架的代码实现[4]

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                            tgt, tgt_mask)
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

考虑可解释性

不含注意力模型的传统encoder-decoder 可解释差:对于编码向量中究竟编码了什么信息,如何利用这些信息以及解码器特定行为的原因是什么我们并没有明确的认识。包含注意力机制的结构提供了一张相对简单的方式让我们了解解码器的推理过程以及模型究竟在学习什么内容,学到那些东西。尽管是一种弱可解释性,但是已经 make sense 了。 

直面 attention 的核心公式

 

outside_default.png

在预测目标语言的第i个词时,源语言第j个词的权重为 outside_default.png , 权重的大小可i以j 看做是一种源语言和目标语言的软对齐信息。  

总结

 

使用 attention 方法实际上就在于预测一个目标词 yi 时,自动获取原句中不同位置的语义信息,并给每个位置信息的语义赋予的一个权重,也就是“软”对齐信息,将这些信息整理起来计算对于当前词 yi 的原句向量表示 c_i。 

Attention 的 PyTorch应用实现

 
import torch
import torch.nn as nn


class BiLSTM_Attention(nn.Module):
    def __init__(self):
        super(BiLSTM_Attention, self).__init__()


        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
        self.out = nn.Linear(n_hidden * 2, num_classes)


    # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix
    def attention_net(self, lstm_output, final_state):
        hidden = final_state.view(-1, n_hidden * 2, 1)   # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # [batch_size, n_hidden * num_directions(=2), n_step] * [batch_size, n_step, 1] = [batch_size, n_hidden * num_directions(=2), 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.data.numpy() # context : [batch_size, n_hidden * num_directions(=2)]


    def forward(self, X):
        input = self.embedding(X) # input : [batch_size, len_seq, embedding_dim]
        input = input.permute(1, 0, 2) # input : [len_seq, batch_size, embedding_dim]


        hidden_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        cell_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]


        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))
        output = output.permute(1, 0, 2) # output : [batch_size, len_seq, n_hidden]
        attn_output, attention = self.attention_net(output, final_hidden_state)
        return self.out(attn_output), attention # model : [batch_size, num_classes], attention : [batch_size, n_step]

github地址:

https://github.com/zy1996code/nlp_basic_model/blob/master/lstm_attention.py

 
   

好消息!

小白学视觉知识星球

开始面向外开放啦

 
   

Attention 扫盲:注意力机制及其 PyTorch 应用实现_第5张图片

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

你可能感兴趣的:(python,机器学习,人工智能,深度学习,java)