2021-10-19

Multi-Head Attention的讲解

一、什么是 Attention

Attention机制最早是在视觉图像领域提出来的,应该是在九几年思想就提出来了,但是真正火起来应该算是2014年google mind团队的这篇论文《Recurrent Models of Visual Attention》,他们在RNN模型上使用了attention机制来进行图像分类。2017年,google机器翻译团队发表的《Attention is all you need》中大量使用了自注意力(self-attention)机制来学习文本表示。自注意力机制也成为了大家近期的研究热点,并在各种NLP任务上进行探索。

Attention机制的本质来自于人类视觉注意力机制。人们视觉在感知东西的时候一般不会是一个场景从到头看到尾每次全部都看,而往往是根据需求观察注意特定的一部分。而且当人们发现一个场景经常在某部分出现自己想观察的东西时,人们会进行学习在将来再出现类似场景时把注意力放到该部分上。

Attention 的核心是用文本中的其它词来增强目标词的语义表示,从而更好的利用上下文的信息。

二、什么是 Multi-Head Attention

Multi-Head Attention是在Tansformer 中提出的,多头 Attention,简单来说就是多个 Self-Attention 的组合,它的作用类似于 CNN 中的多核。但是多头 attention的实现不是循环的计算每个头,而是通过 transposes and reshapes,用矩阵乘法来完成的。

三、Multi-Head Attention的计算流程

2021-10-19_第1张图片

由上图可以看出多头attention的计算分成三个步骤:

  1. 获取Q、K、V,并进行线性变化
  2. 将线性变化后的Q、K、V进行放缩点积attention(scaled dot-Product attention)
  3. 放缩点积attention结果进行拼接(concat),再线性变换,得到attention后的编码信息

3.1、获取Q、K、V,并进行线性变化

Q、K、V是输入文本进行向量编码后,再使用三角函数为每个字加上位置编码后拷贝三份得到的,然后经过线性变换(每个的参数W是不一样的)后输入到放缩点积中。

3.2、 放缩点积attention(scaled dot-Product attention)

对比attention的一般形式,scaled dot-Product attention就是我们常用的使用点积进行相似度计算的attention,其实就是我们常说的self-attention
2021-10-19_第2张图片

上图是self-attention的整个流程,也主要包含了三个部分:

  1. 创建多头,关联度计算

  2. softmax

  3. 加权平均

    为什么要加 softmax() ,因为权重必须为概率分布即和为1。softmax() 里面 2 部分算的就是注意力的原始分数,通过计算Q(query)与K(key)的点积得到相似度分数,其中 [公式] 起到一个调节作用,不至于过大或过小,导致 softmax() 里面 1 部分之后就非0即1。因此这种注意力的形式也叫放缩的点积注意力。

    3.2.1、关联度计算
    创建多头就是将embedding拆分为多个子空间,精确到子空间粒度为了更好地学习特征。

    所谓的关联度计算就是在一个段文本中,每个词之间在整个文本语义表达上的关联程度,这个关联程度值由Q、K经过 transposes and reshapes 的变化得到。一下就是其数据维度变化的过程

    3.2.2、softmax
    2021-10-19_第3张图片

    上面是softmax关于Q、K的计算公式,整个计算流程具体的维度变化见图3或源码注释

    3.2.3、加权平均
    加权平均就是将softmax的输出作为最终的词之间的关联度,加权到每个词上,具体见图3 的:

    2021-10-19_第4张图片

softmax的具体原理请看:Softmax

四、Multi-Head Attention 的源码剖析

在整个 Transformer / BERT 的代码中,(Multi-Head Scaled Dot-Product) Self-Attention 的部分是相对最复杂的,也是 Transformer / BERT 的精髓所在。见一下源码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/10/19 3:14 PM
# @Author  : Yingjun Zhu
# @File    : selfAttention.py
from torch import nn
import torch
import math

class SelfAttention(nn.Module):
    def __init__(self, config):
        super(SelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        # 在Transformer/BERT中,这里的 all_head_size 就等于 config.hidden_size == 768
        # 这样使得多个attention头合起来维度还是config.hidden_size
        # 而 attention_head_size== 64  就是每个attention头的维度,要保证可以整除
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        #
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Q、K、V三个参数矩阵 ,每个数据结构都是:[32 * 128 * 768]
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """
        shape of x: batch_size * seq_length * hidden_size  == >[32 * 128 * 768]
        这个操作是把hidden_size分解为 self.num_attention_heads * self.attention_head_size   ===> 768 分解为:12*64
        然后再交换 seq_length 维度 和 num_attention_heads 维度    ==
        为什么要做这一步:因为attention是要对query中的每个字和key中的每个字做点积,即是在 seq_length 维度上
        query和key的点积是 [batch_size * num_attention_heads * seq_length * attention_head_size] * [batch_size * num_attention_heads * attention_head_size * seq_length]=[batch_size * num_attention_heads * seq_length * seq_length]
                          [32 * 12 * 128 * 64] * [32 * 12 * 64 * 128] = [128 * 128]
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        # shape of hidden_states and mixed_*_layer: batch_size * seq_length * hidden_size == >[32 * 128 * 768]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # Q、K、V三个参数矩阵更改维度后的shape : batch_size * num_attention_heads * seq_length * attention_head_size  == > [32 * 12 * 128 * 64]
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        #  "query" 和 "key"之间的点乘 得到关联度分数  .
        #  attention_scores: batch_size * num_attention_heads * seq_length * seq_length  == > [32 * 12 * 128 * 64]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 这里就是做 Scaled,将方差统一到1,避免维度的影响
        attention_scores /= math.sqrt(self.attention_head_size)

        #  attention_mask的形状: batch_size * 1 * 1 * seq_length. 它可以自动广播到和attention_scores一样的维度
        # 我们初始输入的attention_mask是:batch_size * seq_length,做了两次unsqueeze之后得到当前的attention_mask
        attention_scores = attention_scores + attention_mask

        # Softmax 不改变维度
        # batch_size * num_attention_heads * seq_length * seq_length == > [32 * 12 * 128 * 128]
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # V 的 shape: batch_size * num_attention_heads * seq_length * attention_head_size == > [32 * 12 * 128 * 64]
        # v和score点乘后的隐层的shape : batch_size * num_attention_heads * seq_length * attention_head_size == > [32 * 12 * 128 * 64]
        # 隐层的再次改变后的shape: batch_size * seq_length * num_attention_heads * attention_head_size == > [32 * 128 * 12 * 64]
        # context_layer 合并多头后 维度恢复到:batch_size * seq_length * hidden_size  == >[32 * 128 * 768]
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

你可能感兴趣的:(自然语言处理,深度学习,机器学习)