首先贴上nn.Transformer官方介绍。网上有许多关于这个函数的解释,但道理我都懂,具体怎么实战我是一个也没找到。最直观的就是前向传播需要传入下图中的八个参数,具体怎么用,长啥样子着实让人摸不着头脑。因此本人自己实现了一个机器翻译的简单任务。贴上GitHub链接:
https://github.com/wulele2/nn.Transformer
本文不过多涉及Transformer原理介绍,因为网上太多了,只会简单介绍下核心部分。
d_model = 512 # Embedding Size
d_ff = 2048 # FeedForward dimension
n_layers = 6 # number of Encoder of Decoder Layer
n_heads = 8 # number of heads in Multi-Head Attention
self.transformer = nn.Transformer(d_model = d_model, nhead= n_heads,
num_encoder_layers=n_layers, num_decoder_layers=n_layers,
dim_feedforward=d_ff)
首先初始化了一个transformer对象,其中参数设定跟标准Transformer一致:嵌入向量是512;总共各6层EncoderLayer和DecoderLayer;每层内部使用8个头;FFN模块全连接层维度是2048。
src 和tgt是原始句子和真实句子的嵌入向量,shape分别为[src_len,N,512]和[tgt_len,N,512]。其中N表示批次,这里值得注意的是nn.Transformer没有自动添加positionEmbedding。故传入之前需要自己实现并相加。
我在简单说下*_mask和*_key_padding_mask,前者是用来遮挡未来的单词,是行列相等的上三角矩阵,后者遮挡pad部分的单词。具体可以debug下代码。
src_mask取None;tgt_mask = [tgt_len, tgt_len]是上三角bool型矩阵;memory_mask取None;
src_key_padding_mask是输入句子的pad mask张量;tgt_key_padding_mask是真实句子的pad_mask张量;memory_key_padding_mask和src_key_padding_mask一样。后续我会介绍为啥这么选取,各个形状可以见官网。
建议读者先自行运行一遍,厘清各个参数的shape以及含义。
class myTransformer(nn.Module):
def __init__(self):
super(myTransformer, self).__init__()
self.transformer = nn.Transformer(d_model = d_model, nhead= n_heads,
num_encoder_layers=n_layers, num_decoder_layers=n_layers,
dim_feedforward=d_ff)
self.src_emb = nn.Embedding(src_vocab_size, d_model)
self.pos_emb = PositionalEncoding(d_model)
self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()
def forward(self, enc_inputs, dec_inputs):
'''
enc_inputs: [batch_size, src_len], [1,2,3,4,0]
dec_inputs: [batch_size, tgt_len], [6,1,2,3,5,8]
'''
b, src_len = enc_inputs.shape[0], enc_inputs.shape[1]
b, tgt_len = dec_inputs.shape[0], dec_inputs.shape[1]
src_mask = self.transformer.generate_square_subsequent_mask(src_len).cuda()
tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len).cuda()
memory_mask = None
src_key_padding_mask = enc_inputs.data.eq(0).cuda() # [N,S]
tgt_key_padding_mask = dec_inputs.data.eq(0).cuda() # [N,T]
memory_key_padding_mask = src_key_padding_mask # [N,S]
# 嵌入向量
enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).cuda() # [ src_len, batch_size, d_model]
dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).cuda() # [ tgt_len,batch_size, d_model]
#送入Transformer
dec_outputs = self.transformer(src= enc_outputs, tgt = dec_outputs, src_mask = None, tgt_mask = tgt_mask,
memory_mask = None, src_key_padding_mask = src_key_padding_mask,
tgt_key_padding_mask = tgt_key_padding_mask, memory_key_padding_mask = memory_key_padding_mask)
# 维度变换
dec_logits = self.projection(dec_outputs.transpose(0,1)) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
return dec_logits.view(-1, dec_logits.size(-1)), None, None, None
至于上述mask参数为何如此选择,需要简要介绍下Transformer内部原理:首先在encoder部分:q=k=v均来自src的三个线性层变换,在做multiheadAttn中没必要遮住未来的单词,而仅仅需要遮住pad的部分即可,因此src_mask=None,src_key_padding_mask需要传入;
其次到了decoder,包括self-attn和cross-attn。首先介绍self-attn,此时q=k=v来自tgt的三个线性层变换,我们不仅需要遮住未来信息也需要遮住pad的信息,所以,tgt_mask和tgt_key_padding_mask两个参数均需要指定;其次对于cross attn,q = tgt,k=v来自encoder输出称之为memory,但做multiheadAttn时仅需要遮住pad部分即可,不需要遮住未来的信息,因为tgt已经遮住过了,所以memory_mask=None, memory_key_padding_mask需要指定成和src_key_padding_mask一样的shape。
至于上述更深层原理,则需要看pytorch内部调用的MultiHeadAttn模块。我这里贴下核心部分:
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # 计算Q*K
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] # 判断一个tensor的shape是否等于某个尺寸,将其转成list。
# 利用attn_mask将未来的词遮挡住
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
# 借助key_padding_mask将pad部分遮挡住
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) # [2,8,5,5]
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
上述代码流程比较简单,就是首先计算Q中每个元素和K的相似度,要依次用两种mask来遮挡住,为后续的softmax做准备。以cross attn为例,attn_output_weights是计算了每个真实单词和原始句子每个单词的相似性权重,所以要用和src_key_padding_mask一样的memory_key_padding_mask在行的维度上进行遮挡,故二者pad_mask是一致的。
1) 在计算交叉熵损失,记得指定ignore_index=0,即忽略类别0即pad的损失。
2)nn.Transformer需要自己实现positional embedding。
3)希望自己做到知行合一,不要不懂装懂。
完整代码:https://github.com/wulele2/nn.Transformer
后期会出detr,deformable detr在mmdet中源码解读系列。敬请期待。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。