Transformer的一点理解,附一个简单例子理解attention中的QKV

Transformer用于目标检测的开山之作DETR,论文作者在附录最后放了一段简单的代码便于理解DETR模型。

DETR的backbone用的是resnet-50去掉了最后的AdaptiveAvgPool2d和Linear这两层。

self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])

经过一次卷积加上position embedding,输入到transformer,position embedding是直接加和,不是像叠盘子一样的concat。

这里回顾一下transformer

transformer中最重要的attention,这篇文章Attention Is All You Need (Transformer) 论文精读 - 知乎

举了一个简单的例子,去解释attention中的QKV到底是什么含义。 这里

引用上述文章作者的例子:

如果我们有这样姓名和年龄一个数据库

张三:18
李四:22
张伟:19
张三:20

如果查询『所有叫张三的人的平均年龄』,Key==“张三”,可以得到Key对应的两个Value,算出(18+20)/2=19。我们把『所有叫张三的人的平均年龄』这句话称为一个查询(Query)

如果有另一个查询Query‘:『所有姓张的人的平均年龄』, Key[0]==“张”,得到三个Value:(18+20+19)/3=19

这样查询很低效,为了高效,将Query,Key转为向量vector。

将姓名(Key)汉字编码为向量

张三:[1, 2, 0]
李四:[0, 0, 2]
张伟:[1, 4, 0]

如果一个Quary是查询所有姓张的人的平均年龄,那么Quary可以写成向量  [1, 0, 0],将Quary向量和Key向量做点积

dot([1, 0, 0], [1, 2, 0]) = 1
dot([1, 0, 0], [1, 2, 0]) = 1
dot([1, 0, 0], [0, 0, 2]) = 0
dot([1, 0, 0], [1, 4, 0]) = 1

将结果softmax归一化

softmax([1, 1, 0, 1]) = [1/3, 1/3, 0, 1/3]

再将归一化后的结果与Value做点积

dot([1/3, 1/3, 0, 1/3], [18, 20, 22, 19]) = 19

就得到了想要的结果。(说句题外话,这样查询感觉跟布隆过滤器Bloom Filter有点相似的感觉,将文字编码成位数组)

这个计算就是Attention is all you need论文里Scaled Dot-Product Attention

 在transformer中,query key value关系如下图所示,(reference:The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time.)

Transformer的一点理解,附一个简单例子理解attention中的QKV_第1张图片

 将文字编码为向量x,x与矩阵W相乘,得到q,q与k做点乘,再除8(the square root of the dimension of the key vectors used in the paper – 64),再softmax,再成v,得到z

Transformer的一点理解,附一个简单例子理解attention中的QKV_第2张图片

 Transformer的一点理解,附一个简单例子理解attention中的QKV_第3张图片

 Transformer的一点理解,附一个简单例子理解attention中的QKV_第4张图片

下面几张图,从更宏观的角度展示这个过程

Transformer的一点理解,附一个简单例子理解attention中的QKV_第5张图片

Transformer的一点理解,附一个简单例子理解attention中的QKV_第6张图片 

Transformer的一点理解,附一个简单例子理解attention中的QKV_第7张图片 

Transformer的一点理解,附一个简单例子理解attention中的QKV_第8张图片 

如果是多头注意力,在由a^{i}a^{i} q^{i} v^{i}这一步,比如说2个注意力头,就乘以2个矩阵,得到q^{i,1} q^{i,2}

Transformer的一点理解,附一个简单例子理解attention中的QKV_第9张图片

将多个头注意力的结果concat起来,传入一个linear层,就得到了最终输出Z。(Transformer系列 | BearCoding)

Transformer的一点理解,附一个简单例子理解attention中的QKV_第10张图片

在RNN中,是按顺序输入,所以网络是知道每个输入的位置次序,但是transformer不是这样,因此还要加一个positional encoding,告诉网络输入的每个词在句子中的位置

Transformer的一点理解,附一个简单例子理解attention中的QKV_第11张图片

 Transformer也使用了和resnet相似的残差连接。

将编码器得到的K,V矩阵输入到解码器

Transformer的一点理解,附一个简单例子理解attention中的QKV_第12张图片

在解码的第一步中,输入K V,得到一个output,而在后续的解码中,将前一部的结果也一起输入解码器。比如第二步中,将第一步的结果 “I”也输入decoder,直到decoder给出 end of sentence为止。

Transformer的一点理解,附一个简单例子理解attention中的QKV_第13张图片

transformer的损失函数,通过交叉熵,使两个分布相同

Transformer的一点理解,附一个简单例子理解attention中的QKV_第14张图片

Output Vocabulary是提前建好的词库,网络输出的是词库中所有词 出现在这个位置的概率。

回到DETR,DETR中叠了6个transformer的encoder和decoder,将transformer输出再分别输入两个Linear,就得到了class和bbox。

你可能感兴趣的:(transformer,深度学习,人工智能)