Transformer 极致精简剖析(Pytorch实现)

写在前面的话
Transformer模型设计的哲学思想:大道至简 **
即:完全基于attention机制,无差别对待序列中任意token与其他token的关系计算,用李宏毅老师的话讲就是天涯若比邻
模型的设计哲学对比:
拓扑结构好神奇
从数据的拓扑结构来讲,Transformer将序列化数据看成了完全连通图(也就是没有所谓的终点和起点)而RNN则是看成了其最小连通子图(且有明确的起点和终点,即信息流定向传导,即便是双向RNN也没有改变时间轴的存在的本质)。
编码阶段的差异清晰可见,RNN被一条隐藏的时间轴所束缚,当前时间的编码嗷嗷待哺前面的结果。所以 慢!!!
而Transformer模型结构中就没有时间轴 图上的节点(序列的token)可以自由与其他节点通过边交互(也就是attention),且任意两个不同节点的行动完全互相不干扰。
天涯若比邻带来的问题
前面有说,Transformer的设计哲学,注定是将时间轴抹去了,进而加速了信息编码的速度。可是让我们回到模型的本质,任何模型提出都是为了解决改进任务,很明显我们的任务就摆在那里,编码序列信息,而这个序列信息本身是有序的!注意了,原来的RNN为了体现这个
序**,所付出的代价就是在一条隐形的时间轴进行编码,自然的将序列本身的序信息嵌入到时间信息中,代价就是速度很慢。而Transoformer为了也把序列每一个token在序信息编码进去,所以加入了Position Embedding。总结来说,Transformer牺牲了时间轴提供的序列的位置信息,换来了速度的提升和编码计算的并行,通过Position Embedding的技巧给找补回来了。Fairseq中也是这么做的。

Talk is cheap show me the code
笔者通过实现的一段Transformer玩具代码,展示Transformer都做了什么
假设我们有一个翻译任务
输入: [‘’我’,‘爱’,‘林’,“允”,‘儿’,‘PAD’] 输出:[“I”,“love”,“Yoona”] 这所以加PAD是因为实际训练时Batch数据要长度对齐,并且Transfomer针对PAD有特定的操作,所以这里加一个PAD。PAD:0 Start:1 EOS:2
那么我们这一个训练数据token化就变成了:
Ecoder_input: [6,3,4,8,9,0] Decoder_input:[1,11,12,20] Decoder_target:[11,12,20,2]
所以经过look up 操作(这里定embedding_size = 512) 进入到模型前输入是一个矩阵Matrix,shape = [1,6,512],这里batch_size = 1,后面都是1,不在重复。

Attention前的变形 内积函数计算attention矩阵
看你之前我要打扮的漂漂亮亮,所以我摇身一变,编出来K,Q,V 你可以把这三个矩阵看做实现天涯若比邻的代理,这三个矩阵是self-attention的关键和本质实现,这里按照原论文采用8头注意力,单头是64维。则K,Q,V的列数=8*64
所以有 shape of K = shape of Q = shape of V = [batch_size,6,8*64]

Transformer 极致精简剖析(Pytorch实现)_第1张图片
这里我们通过K,Q 计算出了得分矩阵 scores : [batch_size,6,6]
这里的mask矩阵是一个 行为6 列为6 的0,1矩阵,且最后一列为1,目的是就是用一个极小值掩盖调输入[‘’我’,‘爱’,‘林’,“允”,‘儿’,‘PAD’]中每个字对PAD的attention。这样这个极小值在做softmax是接近于0,也就是不在attention起到作用。
经过softmax得到归一化的attn矩阵 再用attn右乘V 得到 context,context的每一行都是按照attn的每一行的值为系数加权对V按行提取得到,可以看做是attn后的
隐向量。
多头注意力层
Transformer 极致精简剖析(Pytorch实现)_第2张图片
这里是多头注意力层实现,调用了前面的attention内积函数,最后的输出的形状是 [batch_size,6,512] d_model = 512。

前馈层 主要是起到融合Position信息和汇聚的作用
Transformer 极致精简剖析(Pytorch实现)_第3张图片
Encoder layer
Transformer 极致精简剖析(Pytorch实现)_第4张图片
Decoder layer
Transformer 极致精简剖析(Pytorch实现)_第5张图片
Encoder
Transformer 极致精简剖析(Pytorch实现)_第6张图片
Decoder
Transformer 极致精简剖析(Pytorch实现)_第7张图片
注意在decoder部分的mask很有讲究 因为在实际预测过程中 是看不到后面信息的所以训练中也要模拟这个过程,本质上来说就是decode部分的self_attention
部分是一个下三角矩阵,要用一个上三角0,1矩阵去mask从而实现在并行中也能将序列解码时序导致的只能attention左侧已经解码的单词的情形给模拟出来
那么在实际decode预测的时候,每一输出一个token,decode_output都新增一个token
TransFormer
Transformer 极致精简剖析(Pytorch实现)_第8张图片

你可能感兴趣的:(Transformer 极致精简剖析(Pytorch实现))