Transformer代码简单实现2

https://blog.csdn.net/BXD1314/article/details/126187598
Transformer代码简单实现2_第1张图片

1.数据准备

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

device='cpu'

epochs=100

# S:显示  解码 输入开始  的符号
# E:显示  解码 输出开始  的符号
# P: 如果当前批次的数据量小于时间步数,将填写空白序列的符号。

#手动输入两对句子,索引也是硬凑的
# 训练集
sentences = [
    # 中文和英语的单词个数不要求相同
    # enc_input                dec_input                dec_output
    ['我 有 一 个 好 朋 友 P', 'S I have a good friend .', 'I have a good friend . E'],
    ['我 有 零 个 女 朋 友 P', 'S I have zero girl friend .', 'I have zero girl friend . E'],
    ['我 有 一 个 男 朋 友 P', 'S I have a boy friend .', 'I have a boy friend . E']
]

# 中文和英语的单词要分开建立词库
# Padding Should be Zero

src_vocab = {'P': 0, '我': 1, '有': 2, '一': 3,'个': 4, '好': 5, '朋': 6, '友': 7, '零': 8, '女': 9, '男': 10}
src_idx2word = {i: w for i, w in enumerate(src_vocab)}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P': 0, 'I': 1, 'have': 2, 'a': 3, 'good': 4,'friend': 5, 'zero': 6, 'girl': 7,  'boy': 8, 'S': 9, 'E': 10, '.': 11}
tgt_idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len=8   #enc_input最大序列长度
tgt_len=7   #dec_input(=dec_output) 最大序列长度

#   Transformer参数
d_model=512 # Embedding Size(token embedding和position编码的维度)

# FeedForward dimension (两次线性层中的隐藏层 512->2048->512,线性层是用来做特征提取的),当然最后会再接一个projection层
"""
d_model:我们需要定义embeding 的维度,论文中设置的512
dim_ffn: FeedForward 层隐藏神经元个数
d_k = d_v: Q、K、V 向量的维度,其中 Q 与 K 的维度必须相等,V 的维度没有限制,我们都设为 64
n_layers:Encoder 和 Decoder 的个数,也就是图中的Nx
n_heads:多头注意力中 head 的数量
"""
dim_ffn=2048 #FeedForward 层隐藏神经元个数
d_k=d_v=64  # dimension of K(=Q), V(Q和K的维度需要相同,这里为了方便让K=V)
n_layers=6  # number of Encoder of Decoder Layer(Block的个数)
n_heads=8   # number of heads in Multi-Head Attention(有几套头)

2.构建数据

def make_data(sentences):
    """将单词序列转换为数字序列"""
    enc_inputs,dec_inputs,dec_outputs=[],[],[]
    for i in range(len(sentences)):
        enc_input=[[src_vocab[n] for n in sentences[i][0].split()]]
        dec_input=[[tgt_vocab[n] for n in sentences[i][1].split()]]
        dec_output=[[tgt_vocab[n] for n in sentences[i][2].split()]]

        # 我 有 一 个 好 朋 友 P:[[1, 2, 3, 4, 5, 6, 7, 0],
        # 我 有 零 个 女 朋 友 P:[1, 2, 8, 4, 9, 6, 7, 0],
        # 我 有 一 个 男 朋 友 P:[1, 2, 3, 4, 10, 6, 7, 0]]
        enc_inputs.extend(enc_input)
        # S I have a good friend .:[[9, 1, 2, 3, 4, 5, 11], [9, 1, 2, 6, 7, 5, 11], [9, 1, 2, 3, 8, 5, 11]]===dec_inputs
        # S开头
        dec_inputs.extend(dec_input)
        # I have a good friend .E:[[1, 2, 3, 4, 5, 11, 10], [1, 2, 6, 7, 5, 11, 10], [1, 2, 3, 8, 5, 11, 10]]dec_outputs
        # E结尾
        dec_outputs.extend((dec_output))

    return torch.LongTensor(enc_inputs),torch.LongTensor(dec_inputs),torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

#   2.自定义一个MyDataSet去读取这些句子
class MyDataSet(Data.Dataset):
    """自定义DataLoader"""
    def __init__(self,enc_inputs,dec_inputs,dec_outputs):
        super(MyDataSet,self).__init__()
        self.enc_inputs=enc_inputs
        self.dec_inputs=dec_inputs
        self.dec_outputs=dec_outputs

    # 我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:
    #
    # __len__方法,能够实现通过全局的len()  方法获取其中的元素个数
    #
    # __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]  获取其中的第i条数据
    # 注意缩进问题!!!!!!!!!!!
    def __len__(self):
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]


#DataLoader进行封装:dataset,batch_size,shuffle(是否打乱)
loader = Data.DataLoader(
    MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)#这里报错,改成false了?

3.Transformer模型

预览:
·Positional Encoding
·Pad Mask(针对句子不够长,加了 pad,因此需要对 pad 进行 mask)
·Subsequence Mask(Decoder input 不能看到未来时刻单词信息,因此需要 mask)
·ScaledDotProductAttention(计算 context vector)
·Multi-Head Attention
·FeedForward Layer
·Encoder Layer
·Encoder
·Decoder Layer
·Decoder
·Transformer

Transformer 是并行输入计算的,需要知道每个字的位置信息,才能识别出语言中的顺序关系。
首先你需要知道,Transformer 是以字作为输入,将字进行字嵌入之后,再与位置嵌入进行相加(不是拼接,就是单纯的对应位置上的数值进行加和)
Transformer代码简单实现2_第2张图片

3.1 位置编码

字编码:将维度变为3维的。字向量训练或预训练等等得到。x维
位置编码:没有用到循环神经网络(有位置关系),在transformer中不训练(Bert中训练)。y维
x=y维,才可以相加
在这里插入图片描述
在这里插入图片描述

# 位置编码
class PositionEncoding(nn.Module):
    def __init__(self,d_model,dropout=0.1,max_len=5000):
        super(PositionEncoding,self).__init__()
        self.dropout=nn.Dropout(p=dropout)

        pe = torch.zeros(max_len,d_model)
        position = torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2).float() * (-math.log(10000.0) / d_model))
        pe[:,0::2] = torch.sin(position*div_term)
        pe[:,1::2] = torch.cos(position*div_term)
        pe = pe.unsqueeze(0).transpose(0,1)
        self.register_buffer('pe',pe)

    def forward(self, x):
        """
            x: [seq_len, batch_size, d_model]
        """
        x=x+self.pe[:x.size(0),:]
        return self.dropout(x)

3.2 pad mask

针对句子不够长,加了 pad,因此需要对 pad 进行 mask
Transformer代码简单实现2_第3张图片

iter1:【batch_size=32,sequence_length=28】32个句子,每个句子都是28个单词
iter2:【batch_size=32,sequence_length=32】
不同batch之间句子长度可以不一样,但是每个batch的长度必须是一样的:因此出现一个问题,不够长度需要加pad,使得其长度变成一样。
阴影部分是没有意义的,希望它是0,以便后续的softmax等操作
在这里插入图片描述Transformer代码简单实现2_第4张图片
由于在 Encoder 和 Decoder 中都需要进行 mask(和矩阵原大小一样,有问题的地方加负无穷) 操作,因此就无法确定这个函数的参数中 seq_len 的值,如果是在 Encoder 中调用的,seq_len 就等于 src_len;如果是在 Decoder 中调用的,seq_len 就有可能等于 src_len,也有可能等于 tgt_len(因为 Decoder 有两次 mask)

#   4.pad mask
def get_attn_pad_mask(seq_q,seq_k):
    # pad mask的作用:在对value向量加权平均的时候,可以让pad对应的alpha_ij=0,这样注意力就不会考虑到pad向量
    """这里的q,k表示的是两个序列(跟注意力机制的q,k没有关系),例如encoder_inputs (x1,x2,..xm)和encoder_inputs (x1,x2..xm)
    encoder和decoder都可能调用这个函数,所以seq_len视情况而定
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    """
    batch_size,len_q=seq_q.size()# 这个seq_q只是用来expand维度的
    batch_size,len_k=seq_k.size()

    # eq(zero) is PAD token
    # 例如:seq_k = [[1,2,3,4,0], [1,2,3,5,0]]
    # [batch_size, 1, len_k], True is masked

    # seq_k.data.eq(0)核心代码
    # 返回一个大小和 seq_k 一样的 tensor,只不过里面的值只有 True 和 False。
    # 如果 seq_k 某个位置的值等于 0,那么对应位置就是 True,否则即为 False。
    # 举个例子,输入为 seq_data = [1, 2, 3, 4, 0],seq_data.data.eq(0) 就会返回 [False, False, False, False, True]
    pad_attn_mask=seq_k.data.eq(0).unsqueeze(1)
    # [batch_size, len_q, len_k] 构成一个立方体(batch_size个这样的矩阵)
    return pad_attn_mask.expand(batch_size,len_q,len_k)

3.3 Subsequence Mask

不能看到未来时刻单词信息,因此需要 mask。
Subsequence Mask 只有 Decoder 会用到,主要作用是屏蔽未来时刻单词的信息。首先通过 np.ones() 生成一个全 1 的方阵,然后通过 np.triu() 生成一个上三角矩阵

#   5.屏蔽子序列的mask
def get_attn_subsequence_mask(seq):
    """
        建议打印出来看看是什么的输出(一目了然)
        seq: [batch_size, tgt_len]
        get_attn_subsequence_mask 只有 Decoder 会用到,
        主要作用是屏蔽未来时刻单词的信息。首先通过 np.ones() 生成一个全 1 的方阵,
        然后通过 np.triu() 生成一个上三角矩阵
    """
    attn_shape=[seq.size(0),seq.size(1),seq.size(1)]
    # attn_shape: [batch_size, tgt_len, tgt_len]
    # 生成一个上三角矩阵
    subsequence_mask=np.triu(np.ones(attn_shape),k=1)
    subsequence_mask=torch.from_numpy(subsequence_mask).byte()

    # [batch_size, tgt_len, tgt_len]
    return subsequence_mask

3.4 ScaledDotProductAttention

这里要做的是,通过 Q 和 K 计算出 scores,然后将 scores 和 V 相乘,得到每个单词的 context vector

将 Q 和 K 的转置相乘,相乘之后得到的 scores 还不能立刻进行 softmax,需要和 attn_mask 相加,把一些需要屏蔽的信息屏蔽掉,attn_mask 是一个仅由 True 和 False 组成的 tensor,并且一定会保证 attn_mask 和 scores 的维度四个值相同(不然无法做对应位置相加)

mask 完了之后,就可以对 scores 进行 softmax 了。然后再与 V 相乘,得到 context

#   6.ScaledDotProductAttention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention,self).__init__()

    def forward(self, Q,K,V,attn_mask):
        """
            Q: [batch_size, n_heads, len_q, d_k]
            K: [batch_size, n_heads, len_k, d_k]
            V: [batch_size, n_heads, len_v(=len_k), d_v]
            attn_mask: [batch_size, n_heads, seq_len, seq_len]
            说明:在encoder-decoder的Attention层中len_q(q1,..qt)和len_k(k1,...km)可能不同

            通过 Q 和 K 计算出 scores,然后将 scores 和 V 相乘,得到每个单词的 context vector

            第一步是将 Q 和 K 的转置相乘没什么好说的,
            相乘之后得到的 scores 还不能立刻进行 softmax,需要和 attn_mask 相加,把一些需要屏蔽的信息屏蔽掉,
            attn_mask 是一个仅由 True 和 False 组成的 tensor,
            并且一定会保证 attn_mask 和 scores 的维度四个值相同(不然无法做对应位置相加)

            mask 完了之后,就可以对 scores 进行 softmax 了。然后再与 V 相乘,得到 context
        """
        scores=torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(d_k)# scores : [batch_size, n_heads, len_q, len_k]
        # mask矩阵填充scores(用-1e9填充scores中与attn_mask中值为1位置相对应的元素)
        scores.masked_fill_(attn_mask,-1e9)# Fills elements of self tensor with value where mask is True.

        attn=nn.Softmax(dim=-1)(scores)# 对最后一个维度(v)做softmax
        # scores : [batch_size, n_heads, len_q, len_k] * V: [batch_size, n_heads, len_v(=len_k), d_v]
        context=torch.matmul(attn,V)# context: [batch_size, n_heads, len_q, d_v]
        # context:[[z1,z2,...],[...]]向量, attn注意力稀疏矩阵(用于可视化的)
        return context,attn

3.5 PoswiseFeedForwardNet

#   7.PoswiseFeedForwardNet = feed forward + Add&Norm
# Pytorch中的Linear只会对最后一维操作,所以正好是我们希望的每个位置用同一个全连接网络
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc=nn.Sequential(
            nn.Linear(d_model,dim_ffn,bias=False),
            nn.ReLU(),
            nn.Linear(dim_ffn, d_model, bias=False)
        )

    def forward(self, inputs):
        """
           inputs: [batch_size, seq_len, d_model]
        """
        residual=inputs
        output=self.fc(inputs)
        return nn.LayerNorm(d_model).to(device)(output+residual)# [batch_size, seq_len, d_model]

3.6 MultiHeadAttention

X分别乘以W得Qi,Ki,Vi
QK相乘除以根号dk,整体softmax
乘以V,得到Z:Transformer代码简单实现2_第5张图片Transformer代码简单实现2_第6张图片
有几个头就有几个Z:
Transformer代码简单实现2_第7张图片

#   8.MultiHeadAttention
class MultiHeadAttention(nn.Module):
    """这个Attention类可以实现:
        Encoder的Self-Attention
        Decoder的Masked Self-Attention
        Encoder-Decoder的Attention
        输入:seq_len x d_model
        输出:seq_len x d_model

        完整代码中一定会有三处地方调用 MultiHeadAttention(),
        Encoder Layer 调用一次,传入的 input_Q、input_K、input_V 全部都是 enc_inputs;
        Decoder Layer 中两次调用,第一次传入的全是 dec_inputs,
        第二次传入的分别是 dec_outputs,enc_outputs,enc_outputs
    """
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q=nn.Linear(d_model,d_k * n_heads,bias=False) # q,k必须维度相同,不然无法做点积
        self.W_K=nn.Linear(d_model,d_k * n_heads,bias=False)
        self.W_V=nn.Linear(d_model,d_v * n_heads,bias=False)
        # 这个全连接层可以保证多头attention的输出仍然是seq_len x d_model
        self.fc=nn.Linear(n_heads * d_v,d_model,bias=False)

    def forward(self, input_Q,input_K,input_V,attn_mask):
        """
                input_Q: [batch_size, len_q, d_model]
                input_K: [batch_size, len_k, d_model]
                input_V: [batch_size, len_v(=len_k), d_model]
                attn_mask: [batch_size, seq_len, seq_len]
                """
        residual , batch_size=input_Q,input_Q.size(0)
        # 下面的多头的参数矩阵是放在一起做线性变换的,然后再拆成多个头,这是工程实现的技巧
        # B: batch_size, S:seq_len, D: dim
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, Head, W) -trans-> (B, Head, S, W)
        #           线性变换               拆成多头

        # Q: [batch_size, n_heads, len_q, d_k]
        Q=self.W_Q(input_Q).view(batch_size,-1,n_heads,d_k).transpose(1,2)
        # K: [batch_size, n_heads, len_k, d_k] # K和V的长度一定相同,维度可以不同
        K=self.W_K(input_K).view(batch_size,-1,n_heads,d_k).transpose(1,2)
        # V: [batch_size, n_heads, len_v(=len_k), d_v]
        V=self.W_V(input_V).view(batch_size,-1,n_heads,d_v).transpose(1,2)

        # 因为是多头,所以mask矩阵要扩充成4维的
        # attn_mask: [batch_size, seq_len, seq_len] -> [batch_size, n_heads, seq_len, seq_len]
        attn_mask=attn_mask.unsqueeze(1).repeat(1,n_heads,1,1)

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context,attn=ScaledDotProductAttention()(Q,K,V,attn_mask)
        # 下面将不同头的输出向量拼接在一起
        # context: [batch_size, n_heads, len_q, d_v] -> [batch_size, len_q, n_heads * d_v]
        context=context.transpose(1,2).reshape(
            batch_size,-1,n_heads * d_v)

        # 这个全连接层可以保证多头attention的输出仍然是seq_len x d_model
        # [batch_size, len_q, d_model]
        output=self.fc(context)
        return nn.LayerNorm(d_model).to(device)(output+residual),attn

3.7 Encoder

https://wmathor.com/index.php/archives/1438/

Transformer代码简单实现2_第8张图片
Transformer代码简单实现2_第9张图片
Transformer代码简单实现2_第10张图片Transformer代码简单实现2_第11张图片

Transformer代码简单实现2_第12张图片
relu激活

#   9.EncoderLayer
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer,self).__init__()
        self.enc_self_attn=MultiHeadAttention()
        self.pos_ffn=PoswiseFeedForwardNet()

    def forward(self, enc_inputs,enc_self_attn_mask):
        """E
                enc_inputs: [batch_size, src_len, d_model]
                enc_self_attn_mask: [batch_size, src_len, src_len]  mask矩阵(pad mask or sequence mask)
        """
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        # 第一个enc_inputs * W_Q = Q
        # 第二个enc_inputs * W_K = K
        # 第三个enc_inputs * W_V = V
        # enc_inputs to same Q,K,V(未线性变换前)
        enc_outputs,attn=self.enc_self_attn(enc_inputs,enc_inputs,enc_inputs,enc_self_attn_mask)
        enc_outputs=self.pos_ffn(enc_outputs)

        # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs,attn


#   10.Encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb=nn.Embedding(src_vocab_size,d_model)# token Embedding
        self.pos_emb=PositionEncoding(d_model) # Transformer中位置编码时固定的,不需要学习
        self.layers=nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        """
                enc_inputs: [batch_size, src_len]
        """
        enc_outputs=self.src_emb(enc_inputs)# [batch_size, src_len, d_model]
        enc_outputs=self.pos_emb(enc_outputs.transpose(0,1)).transpose(0,1) # [batch_size, src_len, d_model]
        # Encoder输入序列的pad mask矩阵
        enc_self_attn_mask=get_attn_pad_mask(enc_inputs,enc_inputs)# [batch_size, src_len, src_len]
        # 在计算中不需要用到,它主要用来保存你接下来返回的attention的值(这个主要是为了你画热力图等,用来看各个词之间的关系
        enc_self_attns=[]
        for layer in self.layers:# for循环访问nn.ModuleList对象,上一个block的输出enc_outputs作为当前block的输入
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            # 传入的enc_outputs其实是input,传入mask矩阵是因为你要做self attention
            enc_outputs,enc_self_attn=layer(enc_outputs,enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)# 这个只是为了可视化
        return enc_outputs,enc_self_attns

3.8 Decoder

Transformer代码简单实现2_第13张图片
Transformer代码简单实现2_第14张图片

1.关于第一个masked multi-Head Attention:

Transformer代码简单实现2_第15张图片

Transformer代码简单实现2_第16张图片
之后再做 softmax,就能将 - inf 变为 0,得到的这个矩阵即为每个字之间的权重
Transformer代码简单实现2_第17张图片

2.关于第二个multi-Head Attention:在这里插入图片描述

Transformer代码简单实现2_第18张图片

#   11.Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn=MultiHeadAttention()
        self.dec_enc_attn=MultiHeadAttention()
        self.pos_ffn=PoswiseFeedForwardNet()

    def forward(self, dec_inputs,enc_outputs,dec_self_attn_mask,dec_enc_attn_mask):
        """
                dec_inputs: [batch_size, tgt_len, d_model]
                enc_outputs: [batch_size, src_len, d_model]
                dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
                dec_enc_attn_mask: [batch_size, tgt_len, src_len]
                """
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        # 这里的Q,K,V全是Decoder自己的输入
        dec_outputs,dec_self_attn=self.dec_self_attn(dec_inputs,dec_inputs,dec_inputs,dec_self_attn_mask)

        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
        # Attention层的Q(来自decoder) 和 K,V(来自encoder)
        dec_outputs,dec_enc_attn=self.dec_enc_attn(dec_outputs,enc_outputs,enc_outputs,dec_enc_attn_mask)

        # [batch_size, tgt_len, d_model]
        dec_outputs=self.pos_ffn(dec_outputs)

        # dec_self_attn, dec_enc_attn这两个是为了可视化的
        return dec_outputs,dec_self_attn,dec_enc_attn

#   12.Decoder
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb=nn.Embedding(tgt_vocab_size,d_model)# Decoder输入的embed词表
        self.pos_emb=PositionEncoding(d_model)
        # Decoder的blocks
        self.layers=nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs,enc_inputs,enc_outputs):
        """
            dec_inputs: [batch_size, tgt_len]
            enc_inputs: [batch_size, src_len]
            enc_outputs: [batch_size, src_len, d_model]   # 用在Encoder-Decoder Attention层
        """
        # [batch_size, tgt_len, d_model]
        dec_outputs=self.tgt_emb(dec_inputs)
        # [batch_size, tgt_len, d_model]
        dec_outputs=self.pos_emb(dec_outputs.transpose(0,1)).transpose(0,1).to(device)

        # Decoder输入序列的pad mask矩阵(这个例子中decoder是没有加pad的,实际应用中都是有pad填充的)
        dec_self_attn_pad_mask=get_attn_pad_mask(dec_inputs,dec_inputs).to(device)# [batch_size, tgt_len, tgt_len]

        # Masked Self_Attention:当前时刻是看不到未来的信息的
        dec_self_attn_subsequence_mask=get_attn_subsequence_mask(dec_inputs).to(device)# [batch_size, tgt_len, tgt_len]

        # Decoder中把两种mask矩阵相加(既屏蔽了pad的信息,也屏蔽了未来时刻的信息)
        dec_self_attn_mask=torch.gt((dec_self_attn_pad_mask+dec_self_attn_subsequence_mask),0).to(device)

        # 这个mask主要用于encoder-decoder attention层
        # get_attn_pad_mask主要是enc_inputs的pad mask矩阵
        # (因为enc是处理K,V的,求Attention时是用v1,v2,..vm去加权的,要把pad对应的v_i的相关系数设为0,这样注意力就不会关注pad向量)
        # dec_inputs只是提供expand的size的
        dec_enc_attn_mask=get_attn_pad_mask(dec_inputs,enc_inputs)# [batc_size, tgt_len, src_len]

        dec_self_attns,dec_enc_attns=[],[]
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model],
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len],
            # dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
            # Decoder的Block是上一个Block的输出dec_outputs(变化)和Encoder网络的输出enc_outputs(固定)
            dec_outputs,dec_self_attn,dec_enc_attn=layer(dec_outputs,enc_outputs,dec_self_attn_mask,dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)

            # dec_outputs: [batch_size, tgt_len, d_model]
            return dec_outputs,dec_self_attns,dec_enc_attns

3.9 transformer

Transformer代码简单实现2_第19张图片

#   13.Transformer
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder=Encoder().to(device)
        self.decoder=Decoder().to(device)
        self.projection=nn.Linear(d_model,tgt_vocab_size,bias=False).to(device)

    def forward(self, enc_inputs,dec_inputs):
        """
            Transformers的输入:两个序列
            enc_inputs: [batch_size, src_len]
            dec_inputs: [batch_size, tgt_len]
        """
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)

        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        # 经过Encoder网络后,得到的输出还是[batch_size, src_len, d_model]
        enc_outputs,enc_self_attns=self.encoder(enc_inputs)

        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]
        dec_outputs,dec_self_attns,dec_enc_attns=self.decoder(dec_inputs,enc_inputs,enc_outputs)

        # dec_outputs: [batch_size, tgt_len, d_model] -> dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        dec_logits=self.projection(dec_outputs)
        return dec_logits.view(-1,dec_logits.size(-1)),enc_self_attns,dec_self_attns,dec_enc_attns
#   14.train
model=Transformer().to(device)
# 这里的损失函数里面设置了一个参数 ignore_index=0,因为 "pad" 这个单词的索引为 0,
# 这样设置以后,就不会计算 "pad" 的损失(因为本来 "pad" 也没有意义,不需要计算)
criterion=nn.CrossEntropyLoss(ignore_index=0)
optimizer=optim.SGD(model.parameters(),lr=1e-3,momentum=0.99)# 用adam的话效果不好

for epoch in range(epochs):
    for enc_inputs,dec_inputs,dec_outputs in loader:
        """
            enc_inputs: [batch_size, src_len]
            dec_inputs: [batch_size, tgt_len]
            dec_outputs: [batch_size, tgt_len]
        """
        enc_inputs,dec_inputs,dec_outputs=enc_inputs.to(device),dec_inputs.to(device),dec_outputs.to(device)

        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs,enc_self_attns,dec_self_attns,dec_enc_attns=model(enc_inputs,dec_inputs)

        # dec_outputs.view(-1):[batch_size * tgt_len * tgt_vocab_size]
        loss=criterion(outputs,dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#   15.贪心编码
def greedy_decoder(model,enc_input,start_symbol):
    """
    贪心编码
    为了简单起见,当K=1时,贪婪解码器是Beam搜索。这对于推理是必要的,因为我们不知道
    目标序列的输入。因此,我们试图逐字生成目标输入,然后将其送入变换器。
    开始参考:http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: 变换器模型
    :param enc_input: 编码器的输入
    :param start_symbol: 开始符号。在这个例子中,它是 "S",对应于索引4。
    :return: 目标输入
    """
    enc_outputs,enc_self_attns=model.encoder(enc_input)
    # 初始化一个空的tensor: tensor([], size=(1, 0), dtype=torch.int64)
    dec_input=torch.zeros(1,0).type_as(enc_inputs.data)
    terminal=False
    next_symbol=start_symbol
    while not terminal:
        # 预测阶段:dec_input序列会一点点变长(每次添加一个新预测出来的单词)
        dec_input = torch.cat([dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
                              -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        # 增量更新(我们希望重复单词预测结果是一样的)
        # 我们在预测是会选择性忽略重复的预测的词,只摘取最新预测的单词拼接到输入序列中
        # 拿出当前预测的单词(数字)。我们用x'_t对应的输出z_t去预测下一个单词的概率,不用z_1,z_2..z_{t-1}
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["E"]:
            terminal = True
        # print(next_word)

        # greedy_dec_predict = torch.cat(
        #     [dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
        #     -1)
    greedy_dec_predict = dec_input[:, 1:]
    return greedy_dec_predict


# ==========================================================================================
# 预测阶段
# 测试集
sentences = [
    # enc_input                dec_input           dec_output
    ['我 有 零 个 女 朋 友 P', '', '']
]

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
test_loader = Data.DataLoader(
    MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
enc_inputs, _, _ = next(iter(test_loader))

print()
print("=" * 30)
print("利用训练好的Transformer模型将中文句子'我 有 零 个 女 朋 友' 翻译成英文句子: ")
for i in range(len(enc_inputs)):
    greedy_dec_predict = greedy_decoder(model, enc_inputs[i].view(
        1, -1).to(device), start_symbol=tgt_vocab["S"])
    print(enc_inputs[i], '->', greedy_dec_predict.squeeze())
    print([src_idx2word[t.item()] for t in enc_inputs[i]], '->',
          [tgt_idx2word[n.item()] for n in greedy_dec_predict.squeeze()])







整体(简洁版):

# -*- coding: utf-8 -*-
"""Transformer-Torch

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/15yTJSjZpYuIWzL9hSbyThHLer4iaJjBD
"""

'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
              https://github.com/JayParks/transformer
'''
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len = 5 # enc_input max sequence length
tgt_len = 6 # dec_input(=dec_output) max sequence length

# Transformer Parameters
d_model = 512  # Embedding Size
d_ff = 2048 # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of Encoder of Decoder Layer
n_heads = 8  # number of heads in Multi-Head Attention

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
  def __init__(self, enc_inputs, dec_inputs, dec_outputs):
    super(MyDataSet, self).__init__()
    self.enc_inputs = enc_inputs
    self.dec_inputs = dec_inputs
    self.dec_outputs = dec_outputs
  
  def __len__(self):
    return self.enc_inputs.shape[0]
  
  def __getitem__(self, idx):
    return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1) # Upper triangular matrix
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask # [batch_size, tgt_len, tgt_len]

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.
        
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
        return context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context) # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model)(output + residual), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model)(output + residual) # [batch_size, seq_len, d_model]

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_intpus: [batch_size, src_len]
        enc_outputs: [batsh_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, tgt_len, d_model]
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs) # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0) # [batch_size, tgt_len, tgt_len]

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batc_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
        
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outpus: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

model = Transformer()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
      '''
      enc_inputs: [batch_size, src_len]
      dec_inputs: [batch_size, tgt_len]
      dec_outputs: [batch_size, tgt_len]
      '''
      enc_inputs, dec_inputs, dec_outputs = enc_inputs, dec_inputs, dec_outputs
      # outputs: [batch_size * tgt_len, tgt_vocab_size]
      outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
      loss = criterion(outputs, dec_outputs.view(-1))
      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:         
        dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_symbol]],dtype=enc_input.dtype)],-1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)            
    return dec_input

# Test
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])



你可能感兴趣的:(代码练习,深度学习,python,transformer,深度学习,pytorch)