虽然Mask机制在NLP领域是一个十分常见的操作,但是过去并没有仔细思考它的意义。最近参加了阿里天池的一个关于医学影像报告异常检测的数据竞赛。本质上是一个关于文本的多标签分类任务。在这个任务中,我尝试使用Transformer的Encoder结构作为基础来构建分类模型。为了巩固以及加深理解,没有使用PyTorch自带的Transformer模型,而是选择手动搭建。
在Encoder部分,涉及到的mask主要指self-attention过程中,在计算每个token的query与key的相似度时,需要考虑一个重要的问题就是padding。因为我们的每条语料数据的长度一般是不同的,因此为了保证输入模型的input的size完全一致,我们会在末尾添加padding部分来使得每个输入的长度完全一样。但是这部分内容实际上是没有意义的。因此在attention时,注意力不应该放在这部分,应该将这部分mask起来。也就是说我们需要将query与padding部分对应的key的相似度度量统一转化为一个很小的数,比如1e-9或1e-10。这样经过softmax之后,这部分的权重会接近于0。那么具体该怎么做呢?额,直接看代码吧。
首先是Encoder部分的最底层实现,MultiHeadAttention以及之后的全连接层。
假设我们的输入尺寸为[B,L], B代表batch_size, L代表seq_len,也就是序列长度,那么我们再经过self-attention之后得到的输出,也就是下面的scaled_attn,尺寸为[B,H,L,L],其中H代表Head个数,具体的转化过程见代码,我就不展开了。然后就到我们需要做mask的时候了,这个时候我们会在模型中传入一个尺寸为[B,1,1,L]的mask,其中pad_idx对应位置被设置为0,也就是需要mask的位置,其余为1。借助PyTorch的broadcasting机制,我们可以顺利地实现对scaled_attn的mask任务,然后将经过mask之后的scaled_attn(相似度度量矩阵)去与value相乘,得到每个token最终的向量表示。因为本文重点在与解释mask 机制,对于后面的全连接以及residual+layer_norm的操作就不展开细说了,大家可以参考下面的代码。
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaleDotProductAttention(nn.Module):
def __init__(self,scale,atten_dropout=0.1):
super(ScaleDotProductAttention,self).__init__()
self.dropout=nn.Dropout(atten_dropout)
self.scale=scale
def forward(self,q,k,v,mask=None): #shape=[B,H,L,D]
attn=torch.matmul(q,k.transpose(-2,-1)) #这里q:[B,H,L,D] k:[B,H,D,L]
scaled_attn=attn/self.scale #attn 的output:[B,H,L,L]
if mask is not None: #传入的mask:[B,1,1,L]
scaled_attn.masked_fill(mask==0,-1e9)
scaled_attn=self.dropout(F.softmax(scaled_attn,dim=-1))
output=torch.matmul(scaled_attn,v)
return output,scaled_attn
class MultiHeadAttention(nn.Module):
def __init__(self,n_head,dim_model,dim_k,dim_v,dropout=0.2):
super(MultiHeadAttention,self).__init__()
self.dim_model=dim_model
self.n_head=n_head
self.dim_k=dim_k #query 和 key的维度相同所以这里只定义一个
self.dim_q=dim_k
self.dim_v=dim_v
self.w_q=nn.Linear(dim_model,n_head*dim_k,bias=False)
self.w_k=nn.Linear(dim_model,n_head*dim_k,bias=False)
self.w_v=nn.Linear(dim_model,n_head*dim_v,bias=False)
self.fc=nn.Linear(n_head*dim_v,dim_model,bias=False)
self.attention=ScaleDotProductAttention(scale=dim_k**0.5)
self.dropout=nn.Dropout(dropout)
self.layer_norm=nn.LayerNorm(dim_model,eps=1e-6)
def forward(self,q,k,v,mask=None):
d_k,d_v,n_head=self.dim_k,self.dim_v,self.n_head
batch_size,len_q,len_k,len_v=q.size(0),q.size(1),k.size(1),v.size(1)
residual=q
q=self.w_q(q).view(batch_size,len_q,n_head,d_k) #将head单独取出作为一维
k=self.w_k(k).view(batch_size,len_k,n_head,d_k)
v=self.w_v(v).view(batch_size,len_v,n_head,d_v)
#在attention前将len_ 与 head维度互换
q,k,v=q.transpose(1,2),k.transpose(1,2),v.transpose(1,2) #shape=[B,H,L,D]
if mask is not None: #传入的mask:[B,1,L]
mask = mask.unsqueeze(1) # For head axis broadcasting--->mask:[B,1,1,L]
#attention
output,attn=self.attention(q,k,v,mask=mask)
output=output.transpose(1,2).contiguous().view(batch_size,len_q,-1)#合并heads
output=self.dropout(self.fc(output))
#print(output.shape,q.shape)
output+=residual #+residual
output=self.layer_norm(output) #layer normalization
return output
class PositionwiseFeedForward(nn.Module):
'''two feed forward layers'''
def __init__(self,dim_in,dim_hid,dropout=0.2):
super(PositionwiseFeedForward,self).__init__()
self.w1=nn.Linear(dim_in,dim_hid)
self.w2=nn.Linear(dim_hid,dim_in) #输出维度不变
self.layer_norm=nn.LayerNorm(dim_in,eps=1e-6)
self.dropout=nn.Dropout(dropout)
def forward(self,x):
residual=x
x=self.w2(self.dropout(F.relu(self.w1(x))))
x+=residual
return x
在完成来上述基础组件的搭建之后,我们就可以实现单个encoder_layer以及由任意多个encoder_layer搭建的完整Encoder了,下面是代码,为了看起来清晰,我将encoder_layer单独写在一个脚本上了。
import torch.nn as nn
import torch
from transformer_sublayers import ScaleDotProductAttention,MultiHeadAttention, PositionwiseFeedForward
class EncoderLayer(nn.Module):
def __init__(self,dim_model,dim_hid,n_head,dim_k,dim_v,dropout=0.2):
super(EncoderLayer,self).__init__()
self.slf_attn=MultiHeadAttention(n_head,dim_model,dim_k,dim_v)
self.ffn=PositionwiseFeedForward(dim_model,dim_hid,dropout=dropout)
def forward(self,enc_input,slf_attn_mask=None):
attn_output=self.slf_attn(enc_input,enc_input,enc_input,mask=slf_attn_mask) #mask:Boolean构成的[B,1,L]
output=self.ffn(attn_output)
return output
下面是完整Encoder(其中包括设计mask的函数定义)
import torch
import torch.nn as nn
import numpy as np
from transformer_layers import EncoderLayer
def get_pad_mask(seq,pad_idx):
return (seq!=pad_idx).unsqueeze(-2)
#定义位置信息
class PositionalEncoding(nn.Module):
def __init__(self, dim_hid, n_position=200):
super(PositionalEncoding, self).__init__()
#缓存在内存中,常量
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, dim_hid))
def _get_sinusoid_encoding_table(self, n_position, dim_hid):
''' Sinusoid position encoding table '''
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dim_hid) for hid_j in range(dim_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
class Encoder(nn.Module):
def __init__(self,vocab_size,dim_word_vec,n_layers,n_head,dim_k,dim_v,dim_model,dim_hid,pad_idx,dropout=0.2,n_position=200):
super(Encoder,self).__init__()
self.embedding_layer=nn.Embedding(num_embeddings=vocab_size,
embedding_dim=dim_word_vec,
padding_idx=pad_idx)
self.positionencode=PositionalEncoding(dim_hid=dim_word_vec,n_position=200)
self.dropout=nn.Dropout(dropout)
self.layer_stacks=nn.ModuleList([
EncoderLayer(dim_model=dim_model,dim_hid=dim_hid,n_head=n_head,dim_k=dim_k,dim_v=dim_v)
for _ in range(n_layers)])
self.layer_norm=nn.LayerNorm(dim_model,eps=1e-6)
self.dim_model=dim_model
self.pad_idx=pad_idx
def forward(self,x):
token_embedd=self.embedding_layer(x)
token_position_embedd=self.dropout(self.positionencode(token_embedd))
encode_output=self.layer_norm(token_position_embedd) #shape=[B,L,E]--->(batch_size,seq_len,embed_dim)
mask=get_pad_mask(x,self.pad_idx) #shape=[B,1,L]
for encode_layer in self.layer_stacks:
encode_output=encode_layer(encode_output,slf_attn_mask=mask)
return encode_output
参考代码:
https://github.com/jadore801120/attention-is-all-you-need-pytorch