Attention is all you need源码学习3

先用structure看一下code的整体架构阿整体架构,如下图所示:
Attention is all you need源码学习3_第1张图片
我理解的这部分框架是酱的,也有可能不对辣,尽力去理解了,有问题请指出,下图所示:
Attention is all you need源码学习3_第2张图片
接下来就看看代码吧~

Models.py

Transformer

搭建transformer模型,文章的模型如图所示:
Attention is all you need源码学习3_第3张图片
搭建的代码继承了pytorch的nn.model,写法固定,分为两部分:1.__init__定义网络中的参数和模型框架;2.forword定义传输数据的连接,即网络或模型中的线。最简单的神经网络如下所示。

class XXX(torch.nn.Module):     # 继承 torch 的 Module
    def __init__(self, n_feature, n_hidden, n_output):
        super(XXX, self).__init__()     # 继承 __init__ 功能
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出
        self.out = torch.nn.Linear(n_hidden, n_output)       # 输出层线性输出

    def forward(self, x):
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)
        x = self.out(x)                 # 输出值, 但是这个不是预测值, 预测值还需要再另外计算
        return x

再回到本文代码,Transformer的模型代码如下:

class Transformer(nn.Module): #Transformer模型继承pytorch的nn.model,搭建网络的固定写法
    ''' A sequence to sequence model with attention mechanism. '''

    def __init__(  #init传入参数
            self,
            n_src_vocab, n_tgt_vocab, len_max_seq,#词表的大小、句子序列的最大长度
            d_word_vec=512, d_model=512, d_inner=2048,#可选参数:词表维度,模型维度,内部层维度
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,#层数,attention的头数为8
            tgt_emb_prj_weight_sharing=True,
            emb_src_tgt_weight_sharing=True):

        super().__init__()

        # 初始化encoder模型,用于组成encoder-decoder的组件
        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        # 初始化dncoder模型,用于组成encoder-decoder的组件
        self.decoder = Decoder(
            n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)  #线性层y = Ax+0b,输入时model,输出是n_tgt_vocab,b=0
        nn.init.xavier_normal_(self.tgt_word_prj.weight)  #权值初始化,服从正态分布

        assert d_model == d_word_vec, \
        'To facilitate the residual connections, \
         the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            #共享权重矩阵
            self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.x_logit_scale = (d_model ** -0.5)  #点积的缩放因子
        else:
            self.x_logit_scale = 1.

        if emb_src_tgt_weight_sharing:
            # Share the weight matrix between source & target word embeddings
            assert n_src_vocab == n_tgt_vocab, \
            "To share word embedding table, the vocabulary size of src/tgt shall be the same."
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):

        tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]  #除去每行最后一个全要 为啥子。。存疑??干哈去除??

        # 将训练集的data传入encoder模型,得到encoder的output
        enc_output, *_ = self.encoder(src_seq, src_pos)
        #将训练集的target、data以及encoder得到的output传入decoder,得到decoder的output
        dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
        #将decoder的输出结果进行一个线性变化再进行缩放
        seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale

        return seq_logit.view(-1, seq_logit.size(2))

其中:

  1. self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 线性函数解释如下:
    Attention is all you need源码学习3_第4张图片
    参考网址:https://pytorch.org/docs/master/nn.html#linear-layers
  2. nn.init.xavier_normal_()权重初始化参考网址:https://blog.csdn.net/dss_dssssd/article/details/83959474
  3. tgt_seq[:, :-1]是numpy里的切片操作,对比或理解可用如下程序实验,总之它去除了每行数据最后一个数据,但不知道为啥子。。
import numpy as np
data_list=[[1,2,3],[1,2,1],[3,4,5],[4,5,6],[5,6,7],[6,7,8],[6,7,9],[0,4,7],[4,6,0],[2,9,1],[5,8,7],[9,7,8],[3,7,9]]
a=np.array(data_list)
print(a) 
print("------取最后一个元素-------")
print(a[-1]) ###取最后一个元素
print("------除了最后一个取全部------")
print(a[:-1])  ### 除了最后一个取全部 
print("------除了每行最后一个取全部------")
print(a[:,:-1])  ### 除了每行最后一个取全部
print("------取从后向前(相反)的元素-----")
print(a[::-1]) ### 取从后向前(相反)的元素 
print("------取从下标为2的元素翻转读取------")
print(a[2::-1]) ### 取从下标为2的元素翻转读取

整个Transformer.py所搭建的如下模型,其中没展现参数,init决定了模型框架,forward里决定了数据传输,也就是下图的箭头。

4个函数

下面就将搭建Encoder和Decoder的模型,在介绍这两个模型之前,先铺垫好几个函数。
一个是关于mask的函数:mask就是 掩码 ,在我们这里的意思大概就是 对某些值进行掩盖,使其不产生效果 。Transformer模型里面涉及两种mask。分别是 padding mask (如1.)和 sequence mask (如2.)。其中, padding mask 在所有的scaled dot-product attention里面都需要用到,而 sequence mask 只有在decoder的self-attention里面用到。
另一个是关于位置嵌入的函数
1. padding mask
我们的每个批次输入序列长度是不一样的!也就是说,我们要对输入序列进行 对齐 !具体来说,就是给在较短的序列后面填充 0 。因为这些填充的位置,其实是没什么意义的,所以我们的attention机制 不应该把注意力放在这些位置上 ,所以我们需要进行一些处理。
具体的做法是, 把这些位置的值加上一个非常大的负数(可以是负无穷),这样的话,经过softmax,这些位置的概率就会接近0 !
而我们的padding mask实际上是一个张量,每个值都是一个 Boolen ,值为 False 的地方就是我们要进行处理的地方。
代码如下:

def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask

2. Sequence mask
sequence mask是为了使得decoder不能看见未来的信息。也就是对于一个序列,在time_step为t的时刻,我们的解码输出应该只能依赖于t时刻之前的输出,而不能依赖t之后的输出。因此我们需要想一个办法,把t之后的信息给隐藏起来。
方法:产生一个上三角矩阵,上三角的值全为1,下三角的值权威0,对角线也是0 。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。
代码如下:

def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(  #输入序列都是批量的,所以把原本二维的矩阵扩张成3维的张量
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask

3. get_sinusoid_encoding_table
Positional encoding: 对序列中的词语出现的位置进行编码,使用正余弦函数:
Attention is all you need源码学习3_第5张图片
这个编码公式的意思就是: 给定词语的位置 ,我们可以把它编码成 维的向量 !在偶数位置,使用正弦编码,在奇数位置,使用余弦编码 。
代码如下:

def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)

4. get_non_pad_mask

def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)

Encoder与Decoder

接下来就是Encoder和Decoder模型嘞~

  1. Encoder
class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(  #初始化,参数继承Transformer中的参数
            self,
            n_src_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1  #位置信息=最长+1 防止溢出

        self.src_word_emb = nn.Embedding(  #词嵌入
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(  #存疑??没查到 但好像是positional encoding
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([  #6个EncoderLayer层
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])#此步是复制n个一模一样的EncoderLayer层

    def forward(self, src_seq, src_pos, return_attns=False):  #这里的return_attns不懂鸭!!在哪里改变他的值呢?代表什么呢?

        enc_slf_attn_list = []

        # -- Prepare masks
        #调用函数对其进行mask
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        #词嵌入+位置嵌入作为输出enc_output
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            #mask后的作为enc_output和enc_slf_attn
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]  #将enc_slf_attn存入list

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,

其中nn.Embedding()是一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索它们。模块的输入是一个下标的列表,输出是对应的词嵌入。

class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

参数:

  • num_embeddings (int) - 嵌入字典的大小
  • embedding_dim (int) - 每个嵌入向量的大小
  • padding_idx (int, optional) - 如果提供的话,输出遇到此下标时用零填充
  • max_norm (float,optional) - 如果提供的话,会重新归一化词嵌入,使它们的范数小于提供的值
  • norm_type (float, optional)
  • 对于max_norm选项计算p范数时的p
  • scale_grad_by_freq (boolean, optional) - 如果提供的话,会根据字典中单词频率缩放梯度
    参考官方文档torch.nn中Sparse layers的部分:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#sparse-layers
  1. Decoder
    与Encoder相似,但其多了一个Masked Multi-Head Attention层,这一层主要用到了4个函数中第3个函数Sequence mask(),代码如下:
class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self,
            n_tgt_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()
        n_position = len_max_seq + 1

        self.tgt_word_emb = nn.Embedding(
            n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)

        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        # -- Forward
        dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,

其中Encoder里用到了EncoderLayer,Decoder里用到了DecoderLayer,下面就介绍Layers.py中的内容。
**

Layers.py

**
这里都是一层,决定几层(重复几遍)是在上面的代码做到的。

EncoderLayer

encoder由6层相同的层组成,每一层分别由两部分组成:
第一部分是一个 multi-head self-attention mechanism
第二部分是一个 position-wise feed-forward network ,是一个全连接层
两个部分,都有一个 残差连接(residual connection) ,然后接着一个 Layer Normalization 。

class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        # 多头注意力模型
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        # 前馈层
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(  #使用多头注意力模型进行训练结果传给enc_output, enc_slf_attn
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output *= non_pad_mask  #non_pad_mask存疑,按照论文应该是Add&Norm那一步,但是不懂

        enc_output = self.pos_ffn(enc_output)  #总之就是残差并正则化的多头注意力模型作为FFN的输入
        enc_output *= non_pad_mask

        return enc_output, enc_slf_attn

DecoderLayer

和encoder类似,decoder由6个相同的层组成,每一个层包括以下3个部分:
第一个部分是 masked multi-head attention
第二部分是 multi-head self-attention mechanism
第三部分是一个 position-wise feed-forward network
还是和encoder类似,上面三个部分的每一个部分,都有一个残差连接 ,后接一个 Layer Normalization 。

class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        #多的那一层也是用的多头注意力模型,只不过!mask用到的是get_subsequent_mask
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output *= non_pad_mask

        #这里可以看到两个都用到多头注意力模型,但是mask的值不一样
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output *= non_pad_mask

        dec_output = self.pos_ffn(dec_output)
        dec_output *= non_pad_mask

        return dec_output, dec_slf_attn, dec_enc_attn

接下来就看一下 SubLayers.py 里构造的MultiHeadAttention和PositionwiseFeedForward。
**

SubLayers.py

**

MultiHeadAttention

多头注意力模型:将query、key和value分别用不同的、学到的线性映射h倍到dk、dk和dv维。基于每个映射版本的query、key和value,我们并行执行attention函数,产生dv 维输出值。 将它们连接并再次映射,产生最终值。
Attention is all you need源码学习3_第6张图片
公式如下:
Attention is all you need源码学习3_第7张图片
文中参数采用:h = 8个并行attention层或head。 对每个head,使用d_k=d_v=d_model ∕ h = 64。 由于每个head的大小减小,总的计算成本与具有全部维度的单个head attention相似。
代码如下:

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''
    # 构造多头注意力模型
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)  # Query
        self.w_ks = nn.Linear(d_model, n_head * d_k)  # Key
        self.w_vs = nn.Linear(d_model, n_head * d_v)  # Value
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))  # 初始化权重,服从正态分布mean为下限,std为上限
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        # 缩放的点积注意力模型
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)  # 归一化

        self.fc = nn.Linear(n_head * d_v, d_model)  # 用headi=Attention(,,,)的公式算完的维度n_head * d_v作为输入,输出维度是d_model
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)


    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q  # 用于残差连接

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk->(n_head * sz_b, len_q, d_k)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        # 此代码的意思是不是对每个头进行缩放的点积注意力模型??
        mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)  # 缩放的点积注意力模型

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)->(sz_b, len_q, n_head * d_v)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)  # Add & Norm层

        return output, attn
  1. 权重初始化参考上文给出的网址
  2. LayerNorm:channel方向做归一化,算CHW的均值,主要对RNN作用明显,参考网址:https://blog.csdn.net/shanglianlm/article/details/85075706
  3. q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k),permute(2, 0, 1, 3)是维度置换,将源数据第2列移到第0列,第0列->第1列,第1列->第2列,第3列不变;view只能用在contiguous的variable上,如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。 一种可能的解释是: 有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式。
  4. 代码中对多头那里的操作还是不太懂。。。这个意思是把h作为矩阵维度做索引吗?相当于分成了h块,每一个h拥有相同维度的Q、K、V,使用 mask = mask.repeat(n_head, 1, 1)重复对每一个h做缩放的点乘注意力?

PositionwiseFeedForward

这里是讲到Feed Forward,论文中的公式如下:
在这里插入图片描述
代码如下:

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)  # 转置
        output = self.w_2(F.relu(self.w_1(output)))  # 论文中公式FFN(x)
        output = output.transpose(1, 2)  # 呃呃?又转置嘞
        output = self.dropout(output)
        output = self.layer_norm(output + residual)  # Add & Norm
        return output

最后说一下前面提到多次的缩放的点乘注意力机制

Modules.py

ScaledDotProductAttention模型,它其实就是刚刚MultiHeadAttention中紫色的那一部分的内容,结构如下图所示:
Attention is all you need源码学习3_第8张图片
代码如下所示,与上图对应的:

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))  # k转置,q点乘k转置
        attn = attn / self.temperature   # 除以放缩因子

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)  # mask

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)  # 和V点乘

        return output, attn

其中mask_fill(mask, -np.inf)中的mask必须是一个 ByteTensor 而且shape必须和 attn一样 并且元素只能是0或者1,将mask中为1的元素所在的索引,在attn中相同的的索引处替换为 value。

你可能感兴趣的:(transformer,attention机制,pytorch,transformer)