Transformer: Principles and Code Walkthrough

Resources

  • Full code with detailed comments: github

  • Reference paper: Attention Is All You Need

  • Reference implementation: TensorFlow 2.0 official tutorials/text/transformer

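Principles

The Transformer model comes from the paper Attention Is All You Need. Its original application is machine translation: with self-attention and positional encoding it replaces the RNN structure of traditional Seq2Seq models. Because of the Transformer's strong performance, later models were built on it: OpenAI GPT uses the Transformer Decoder stack, while BERT uses the Encoder stack.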

Transformer algorithm flow:

Input: inputs, targets

For example:
inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
targets = ‘SOS imagination is more important than knowledge EOS’

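Training

Training uses teacher forcing:
inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
targets = ‘SOS imagination is more important than knowledge EOS’

The targets are split into tar_inp and tar_real. tar_inp is fed to the Decoder as input. tar_real is the same sequence shifted by one position: at each position of tar_inp, tar_real holds the next token that should be predicted.
tar_inp = ‘SOS imagination is more important than knowledge’
tar_real = ‘imagination is more important than knowledge EOS’

That is, the Encoder encodes inputs into a representation of the source sentence, and the Decoder, starting from SOS, predicts the probability of the next word at every position. Because training uses teacher forcing, the ground-truth previous tokens (rather than the model's own predictions) are fed to the Decoder when predicting each next word.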
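
The split into tar_inp and tar_real can be illustrated with plain Python lists. This is only a sketch on word tokens; the helper name `split_target` is made up for this example, and the real code performs the same slicing on batched integer tensors.

```python
def split_target(tokens):
    """Return (tar_inp, tar_real): the decoder input and the shifted labels."""
    # tar_inp drops the last token, tar_real drops the first one.
    return tokens[:-1], tokens[1:]

targets = "SOS imagination is more important than knowledge EOS".split()
tar_inp, tar_real = split_target(targets)

# At every position the decoder sees tar_inp[: t + 1] and must predict tar_real[t].
for t, expected in enumerate(tar_real):
    print(tar_inp[: t + 1], "->", expected)
```

Each printed row is one training signal: the prefix on the left is what the decoder is allowed to see, and the token on the right is the label at that position.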

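Predicted output

tar_pred = ‘imagination is more important than knowledge EOS’
This is of course the ideal case, in which the prediction equals the ground truth tar_real. Early in training the predictions will not be this accurate.

Loss: cross-entropy loss

The cross-entropy loss is computed from tar_pred and tar_real.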
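
As a rough illustration of that loss, the NumPy sketch below computes the cross-entropy between predicted logits and the true token ids while ignoring padded positions. The function name `masked_cross_entropy` and the choice of token id 0 as padding are assumptions made for this example (the padding id matches the masks used later in this post).

```python
import numpy as np

def masked_cross_entropy(logits, targets, pad_id=0):
    """Mean cross-entropy over non-padding positions.

    logits:  (seq_len, vocab_size) unnormalised scores
    targets: (seq_len,) integer token ids; pad_id marks padding
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-likelihood of the correct token at each position.
    nll = -log_probs[np.arange(len(targets)), targets]
    # Padded positions contribute nothing to the loss.
    mask = (targets != pad_id).astype(np.float64)
    return float((nll * mask).sum() / mask.sum())

logits = np.zeros((3, 4))       # a uniform prediction over 4 tokens
targets = np.array([1, 2, 0])   # the last position is padding
loss = masked_cross_entropy(logits, targets)  # equals log(4) for a uniform prediction
```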

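How does the trained model predict?

Here SOS is the token marking the start of a sentence and EOS the token marking its end.

Encoder stage: inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
Decoder stage: predict in a loop
Feed [SOS]; the predicted next token is: imagination
Feed [SOS, imagination]; the predicted next token is: is
...
Feed [SOS, imagination, is, more, important, than, knowledge]; the predicted next token is EOS, and decoding ends.
Decoding stops on either of two conditions: EOS is predicted, or the maximum length target_seq_len is reached.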
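
The decoding loop and its two stopping conditions can be sketched in plain Python. Here `toy_next_token` is a hypothetical stand-in for the trained model (it simply replays a fixed sentence), and `greedy_decode` is a name chosen for this example.

```python
def greedy_decode(next_token, sos="SOS", eos="EOS", max_len=20):
    """Feed the tokens generated so far and append the predicted next token."""
    output = [sos]
    for _ in range(max_len):          # stop condition 2: maximum length reached
        token = next_token(output)
        if token == eos:              # stop condition 1: EOS predicted
            break
        output.append(token)
    return output[1:]                 # drop the leading SOS

# Hypothetical model: always continues the same reference sentence.
reference = "imagination is more important than knowledge EOS".split()

def toy_next_token(prefix):
    return reference[len(prefix) - 1]

decoded = greedy_decode(toy_next_token)
truncated = greedy_decode(toy_next_token, max_len=3)
```

With the real model, `next_token` would run the Encoder once, run the Decoder on the current prefix, and take the argmax over the vocabulary at the last position.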

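Network structure

Network structure from the original paper

Network structure of my own implementation:

Encoder part:

Notation used in the pseudocode below: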
MultiHeadAttention(v, k, q, mask)

Encoder block
It consists of two sub-layers:

  1. Multi-head attention (with a padding mask)
  2. Point-wise feed-forward network, which is simply two fully connected layers

The input x is input_sentence, with shape (batch_size, seq_len, d_model)

  • out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
  • out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
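
The residual-plus-normalization pattern behind out1 and out2 can be sketched in NumPy. This version uses layer normalization, as in the paper, and for brevity leaves out dropout and the learnable scale and offset that a real normalization layer has; `layer_norm` and `add_and_norm` are names chosen for this example.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalise each position's feature vector to zero mean and unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def add_and_norm(x, sublayer_out):
    """Residual connection followed by layer normalization."""
    return layer_norm(x + sublayer_out)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))      # (batch_size, seq_len, d_model)
sub = rng.normal(size=(2, 5, 8))    # stand-in for the attention or ffn output
out1 = add_and_norm(x, sub)         # same shape as x
```

Because the statistics are computed over the last axis only, every token is normalised independently of the other tokens and of the other sentences in the batch.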

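Decoder part:
The difference from the Encoder is that the Decoder first applies self-attention to its own input. That result then serves as the query, with the Encoder output as key and value, in an ordinary (encoder-decoder) attention whose output is the input to the feed-forward network.

Decoder block, required sub-layers:

  1. Masked multi-head attention (look-ahead mask and padding mask)
  2. Multi-head attention (with a padding mask). V (value) and K (key) take the encoder output as input. Q (query) takes the output of the masked multi-head attention sub-layer.
  3. Point-wise feed-forward network

The input x is target_sentence, with shape (batch_size, seq_len, d_model)

  • out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
  • out2 = LayerNormalization( out1 + (MultiHeadAttention(enc_output, enc_output, out1) => dropout) )
  • out3 = LayerNormalization( out2 + (ffn(out2) => dropout) )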

Code implementation

Position

import numpy as np
import tensorflow as tf

def get_angles(pos, i, d_model):
    '''
    :param pos: position of the token in the sentence
    :param i: index of the dimension within the embedding vector
    :param d_model: embedding dimension
    :return: angles pos / 10000^(2*(i//2)/d_model)
    '''
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    '''
    :param position: maximum number of positions
    :param d_model: embedding dimension
    :return: tensor of shape [1, position, d_model], later added to the embedding output
    '''
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
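
To inspect the values without TensorFlow, the same formula can be written in NumPy alone. This is a sketch for exploration, not a replacement for the function above; the name `positional_encoding_np` and the small sizes are chosen for this example.

```python
import numpy as np

def positional_encoding_np(max_position, d_model):
    """Sinusoidal positional encoding of shape (max_position, d_model)."""
    pos = np.arange(max_position)[:, np.newaxis]    # (max_position, 1)
    i = np.arange(d_model)[np.newaxis, :]           # (1, d_model)
    # Dimensions 2i and 2i+1 share the same frequency.
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])       # even dimensions
    angles[:, 1::2] = np.cos(angles[:, 1::2])       # odd dimensions
    return angles

pe = positional_encoding_np(50, 8)
# Position 0 has angle 0 everywhere, so its row alternates sin(0)=0 and cos(0)=1.
first_row = pe[0]
```

Every value lies in [-1, 1], and low dimensions oscillate quickly with position while high dimensions change slowly, which is what lets the model tell positions apart.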

point_wise_feed_forward_network

def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
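
What "point-wise" means can be checked with a small NumPy sketch: the same two dense layers are applied to every position separately, so running a single position through the network gives the same result as running the whole batch. The weights below are random placeholders, not trained values, and `ffn` is a name chosen for this example.

```python
import numpy as np

def ffn(x, w1, b1, w2, b2):
    """Two dense layers with a ReLU in between, applied along the last axis."""
    hidden = np.maximum(0.0, x @ w1 + b1)   # (..., dff)
    return hidden @ w2 + b2                 # (..., d_model)

rng = np.random.default_rng(0)
d_model, dff = 4, 16
w1, b1 = rng.normal(size=(d_model, dff)), np.zeros(dff)
w2, b2 = rng.normal(size=(dff, d_model)), np.zeros(d_model)

x = rng.normal(size=(2, 3, d_model))        # (batch_size, seq_len, d_model)
y = ffn(x, w1, b1, w2, b2)                  # (batch_size, seq_len, d_model)
y_single = ffn(x[0, 1], w1, b1, w2, b2)     # one position on its own
```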

Attention

MultiHeadAttention splits the d_model (word embedding) dimension into num_heads pieces and applies attention to each piece separately.

def scaled_dot_product_attention(q, k, v, mask=None):
    '''Compute scaled dot-product attention.
    q, k, v must have matching leading dimensions.
    q and k must have the same last dimension.
    k and v must have the same second-to-last dimension: seq_len_k == seq_len_v == seq_len.
    Args:
    q: query, shape == (..., seq_len_q, d)
    k: key, shape == (..., seq_len, d)
    v: value, shape == (..., seq_len, d_v)
    mask: float tensor whose shape is broadcastable to
          (..., seq_len_q, seq_len). Defaults to None.
    Returns:
    output, attention weights
    '''
    # (batch_size, num_heads, seq_len_q, d) dot (batch_size, num_heads, d, seq_len) = (batch_size, num_heads, seq_len_q, seq_len)
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], dtype=tf.float32)
    scaled_attention_logits = matmul_qk/tf.math.sqrt(dk)

    # add the mask to the scaled tensor: masked positions get a large negative logit
    if mask is not None:
        # (batch_size, num_heads, seq_len_q, seq_len) + (batch_size, 1, 1, seq_len)
        scaled_attention_logits += (mask * -1e9)

    # softmax over the last axis normalises the weights: (batch_size, num_heads, seq_len_q, seq_len)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # each of the seq_len_q positions takes a weighted sum over v
    # (batch_size, num_heads, seq_len_q, seq_len) dot (batch_size, num_heads, seq_len, d_v) = (batch_size, num_heads, seq_len_q, d_v)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights
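
A small worked example helps show what the function computes. The sketch below reimplements the same steps in NumPy (the name `attention_np` and the toy matrices are made up for this example): the query matches the second key, so the output is the second value, and masking that key spreads the weight evenly over the remaining two.

```python
import numpy as np

def attention_np(q, k, v, mask=None):
    """Scaled dot-product attention; mask entries of 1 mark positions to hide."""
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(k.shape[-1])
    if mask is not None:
        logits = logits + mask * -1e9
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

k = np.array([[10., 0., 0.], [0., 10., 0.], [0., 0., 10.]])   # (3, 3)
v = np.array([[1., 0.], [10., 0.], [100., 5.]])               # (3, 2)
q = np.array([[0., 10., 0.]])                                 # (1, 3), aligned with key 2

out, w = attention_np(q, k, v)                       # out is close to [[10., 0.]]
mask = np.array([[0., 1., 0.]])                      # hide the second key
out_masked, w_masked = attention_np(q, k, v, mask)   # weights become [0.5, 0., 0.5]
```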

class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert (d_model > num_heads) and (d_model % num_heads == 0)
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.qw = tf.keras.layers.Dense(d_model)
        self.kw = tf.keras.layers.Dense(d_model)
        self.vw = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) # (batch_size, seq_len, num_heads,  depth)
        return tf.transpose(x, perm=(0, 2, 1, 3)) # (batch_size, num_heads, seq_len, depth)


    def call(self, v, k, q, mask=None):
        # v = inputs
        batch_size = tf.shape(q)[0]

        q = self.qw(q)  # (batch_size, seq_len_q, d_model)
        k = self.kw(k)  # (batch_size, seq_len, d_model)
        v = self.vw(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len, depth_v)

        # scaled_attention, (batch_size, num_heads, seq_len_q, depth_v)
        # attention_weights, (batch_size, num_heads, seq_len_q, seq_len)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=(0, 2, 1, 3)) # (batch_size, seq_len_q, num_heads, depth_v)
        concat_attention = tf.reshape(scaled_attention, shape=(batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
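
The reshape and transpose in split_heads, and the inverse step before the final dense layer, can be checked in NumPy. The helper names `split_heads_np` and `merge_heads_np` are chosen for this example; the point is that splitting only rearranges the d_model axis and loses no information.

```python
import numpy as np

def split_heads_np(x, num_heads):
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)."""
    batch_size, seq_len, d_model = x.shape
    depth = d_model // num_heads
    x = x.reshape(batch_size, seq_len, num_heads, depth)
    return x.transpose(0, 2, 1, 3)

def merge_heads_np(x):
    """(batch, num_heads, seq_len, depth) -> (batch, seq_len, d_model)."""
    batch_size, num_heads, seq_len, depth = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads * depth)

x = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)   # d_model = 8
heads = split_heads_np(x, num_heads=4)                        # depth = 2 per head
restored = merge_heads_np(heads)                              # identical to x
```

Head h of a token therefore sees the slice of that token's embedding from index h * depth to (h + 1) * depth.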

Encoder

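Input:

  • inputs (batch_size, seq_len_inp): token ids, which become (batch_size, seq_len_inp, d_model) after embedding
  • mask (batch_size, 1, 1, seq_len_inp): because input sequences are padded to the same length, the padded positions must be masked during self-attention. The mask has shape (batch_size, 1, 1, seq_len_inp) because MultiHeadAttention splits inputs into (batch_size, num_heads, seq_len_inp, d_model//num_heads), and the attention weights it computes have shape (batch_size, num_heads, seq_len_inp, seq_len_inp). When the mask is applied, it is automatically broadcast to (batch_size, num_heads, seq_len_inp, seq_len_inp).

Output:

  • encode_output (batch_size, seq_len_inp, d_model)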
class EncoderLayer(tf.keras.layers.Layer):
    '''Encoder block
    Two sub-layers: 1. multi-head attention (with a padding mask) 2. point-wise feed-forward network.
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
    '''
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        # the paper and the DecoderLayer use layer normalization, not batch normalization
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layer_norm1(x+attn_output) # (batch_size, input_seq_len, d_model)

        ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layer_norm2(out1+ffn_output) # (batch_size, input_seq_len, d_model)
        return out2

class Encoder(tf.keras.layers.Layer):
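    '''
    Input embedding
    Positional encoding
    N encoder layers
    The input is passed through an embedding, which is summed with the positional encoding. The result of this sum is the input to the encoder layers. The encoder output is the input to the decoder.
    '''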
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.enc_layer = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # x.shape == (batch_size, seq_len)
        seq_len = tf.shape(x)[1]
        x = self.embedding(x) # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, dtype=tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x = self.enc_layer[i](x, training, mask)
        return  x #(batch_size, input_seq_len, d_model)

Decoder

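Input:

  • targets_inp (batch_size, seq_len_tar): token ids, which become (batch_size, seq_len_tar, d_model) after embedding
  • encode_output (batch_size, seq_len_inp, d_model)
  • self_mask (batch_size, 1, seq_len_tar, seq_len_tar): the look-ahead mask combined with the target padding mask; enc_output_mask (batch_size, 1, 1, seq_len_inp)

Output:

  • decode_output (batch_size, seq_len_tar, d_model), which the final linear layer of the Transformer projects to (batch_size, seq_len_tar, tar_vocab_size)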
class DecoderLayer(tf.keras.layers.Layer):
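    ''' Decoder block
    Required sub-layers:
    1. Masked multi-head attention (look-ahead mask and padding mask)
    2. Multi-head attention (with a padding mask). V (value) and K (key) take the encoder output as input. Q (query) takes the output of the masked multi-head attention sub-layer.
    3. Point-wise feed-forward network
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (MultiHeadAttention(enc_output, enc_output, out1) => dropout) )
    out3 = LayerNormalization( out2 + (ffn(out2) => dropout) )
    '''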
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape == (batch_size, target_seq_len, d_model)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layer_norm1(x+attn1)

        # encoder-decoder attention uses its own layer (mha2): V and K come from the encoder, Q from out1
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layer_norm2(out1+attn2)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layer_norm3(out2+ffn_output)  # (batch_size, target_seq_len, d_model)

        return out3, attn_weights_block1, attn_weights_block2

class Decoder(tf.keras.layers.Layer):
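    '''The decoder consists of:
    Output embedding
    Positional encoding
    N decoder layers
    The target is passed through an embedding, which is summed with the positional encoding. The result of this sum is the input to the decoder layers. The decoder output is the input to the final linear layer.
    '''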
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layer = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape==(batch_size, target_seq_len)
        # enc_output.shape==(batch_size, input_seq_len, d_model)
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x) # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layer[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2
        # x.shape==(batch_size, target_seq_len, d_model)
        return x, attention_weights

Transformer

class Transformer(tf.keras.Model):
    def __init__(self, params):
        super(Transformer, self).__init__()
        self.encoder = Encoder(params['num_layers'],params['d_model'],params['num_heads'],params['dff'],params['input_vocab_size'],params['pe_input'],params['rate'])
        self.decoder = Decoder(params['num_layers'],params['d_model'],params['num_heads'],params['dff'],params['target_vocab_size'],params['pe_target'],params['rate'])
        self.final_layer = tf.keras.layers.Dense(params['target_vocab_size'])

    def call(self, inp, tar, training, enc_padding_mask=None, look_ahead_mask=None, dec_padding_mask=None):
        # (batch_size, inp_seq_len, d_model)
        enc_output = self.encoder(inp, training, enc_padding_mask)
        # (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output, attention_weights

Mask

def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add extra dimensions so the padding mask can be added
    # to the attention logits
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
    '''
    e.g.
    x = tf.random.uniform((1, 3))
    temp = create_look_ahead_mask(x.shape[1])
    temp:
    [[0., 1., 1.],
     [0., 0., 1.],
     [0., 0., 0.]]
    '''
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)

def create_masks(inp, tar):
    # encoder padding mask
    enc_padding_mask = create_padding_mask(inp)
    # used in the second attention block of the decoder;
    # this padding mask masks the encoder output
    dec_padding_mask = create_padding_mask(inp)
    # used in the first attention block of the decoder;
    # masks both padding and future tokens in the decoder input
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) #(tar_seq_len, tar_seq_len)
    dec_target_padding_mask = create_padding_mask(tar) # (batch_size, 1, 1, tar_seq_len)
    # broadcasting: look_ahead_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    # dec_target_padding_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
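
How the look-ahead mask and the target padding mask combine is easiest to see on a tiny example. The NumPy sketch below mirrors the two mask functions (the `_np` names are chosen for this example) for one target sentence of length 3 whose last token is padding; a 1 means the position is hidden.

```python
import numpy as np

def padding_mask_np(seq):
    """1.0 where the token id is 0 (padding); shape (batch, 1, 1, seq_len)."""
    return (seq == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]

def look_ahead_mask_np(size):
    """1.0 strictly above the diagonal, i.e. for future positions."""
    return 1.0 - np.tril(np.ones((size, size), dtype=np.float32))

tar = np.array([[4, 8, 0]])   # one sentence; the last token is padding
# Broadcasting (1, 1, 1, 3) against (3, 3) gives (1, 1, 3, 3).
combined = np.maximum(padding_mask_np(tar), look_ahead_mask_np(tar.shape[1]))
# combined[0, 0] is
# [[0, 1, 1],
#  [0, 0, 1],
#  [0, 0, 1]]
```

Row t lists which positions token t may attend to: the first two rows are the plain look-ahead pattern, and in the last row the padding mask additionally hides the padded third position.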

Putting it all together

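# ==============================================================
import time

# tokenizer_pt, tokenizer_en, MAX_LENGTH, CustomSchedule and checkout_dir
# are used below but are not defined in this section.
params = {
    'num_layers':4,
    'd_model':128,
    'dff':512,
    'num_heads':8,
    'input_vocab_size' :tokenizer_pt.vocab_size + 2,
    'target_vocab_size':tokenizer_en.vocab_size + 2,
    # maximum position for the positional encoding; the vocabulary size is reused as a generous upper bound
    'pe_input':tokenizer_pt.vocab_size + 2,
    'pe_target':tokenizer_en.vocab_size + 2,
    'rate':0.1,
    'epochs':20,  # assumed value: train() reads params['epochs'], which the original dict did not set
    'checkpoint_path':'./checkpoints/train',
    'checkpoint_do_delete':False
}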


print('input_vocab_size is {}, target_vocab_size is {}'.format(params['input_vocab_size'], params['target_vocab_size']))


class ModelHelper:

    def __init__(self):
        self.transformer  = Transformer(params)
        # optimizer
        learning_rate = CustomSchedule(params['d_model'])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

        # accumulates the per-batch losses within an epoch; the mean gives the epoch loss
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        # accumulates the per-batch accuracy within an epoch; the mean gives the epoch accuracy
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        self.test_loss = tf.keras.metrics.Mean(name='test_loss')
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


        # checkpoint directory: create params['checkpoint_path'] if it does not exist; if it exists and checkpoint_do_delete=True, delete it first and then recreate it
        checkout_dir(dir_path=params['checkpoint_path'], do_delete=params.get('checkpoint_do_delete', False))
        # checkpoint
        ckpt = tf.train.Checkpoint(transformer=self.transformer,
                                   optimizer=self.optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(ckpt, params['checkpoint_path'], max_to_keep=5)
        # if a checkpoint exists, restore the latest one
        if self.ckpt_manager.latest_checkpoint:
            ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print('Latest checkpoint restored!!')

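    def loss_function(self, real, pred):
        # mask out padding positions (token id 0) so they do not contribute to the loss
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        # average over real tokens only, rather than over all positions including padding
        return tf.reduce_sum(loss_) / tf.reduce_sum(mask)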


    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]
    @tf.function(input_signature=train_step_signature)
    def train_step(self, inp, tar):
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]

        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

        with tf.GradientTape() as tape:
            predictions, _ = self.transformer(inp, tar_inp,
                                         True,
                                         enc_padding_mask,
                                         combined_mask,
                                         dec_padding_mask)
            loss = self.loss_function(tar_real, predictions)

        gradients = tape.gradient(loss, self.transformer.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.transformer.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy(tar_real, predictions)

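    @tf.function(input_signature=train_step_signature)
    def test_step(self, inp, tar):
        # evaluate with teacher forcing, mirroring train_step without a gradient update
        # (predict() returns token ids, not logits, so it cannot feed the loss directly)
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
        predictions, _ = self.transformer(inp, tar_inp, False, enc_padding_mask, combined_mask, dec_padding_mask)
        self.test_loss(self.loss_function(tar_real, predictions))
        self.test_accuracy(tar_real, predictions)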

    def train(self, train_dataset):
        for epoch in range(params['epochs']):
            start = time.time()
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()
            # inp -> portuguese, tar -> english
            for (batch, (inp, tar)) in enumerate(train_dataset):
                self.train_step(inp, tar)
                if batch % 50 == 0:
                    print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, batch, self.train_loss.result(), self.train_accuracy.result()))
            if (epoch + 1) % 5 == 0:
                ckpt_save_path = self.ckpt_manager.save()
                print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,ckpt_save_path))
            print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, self.train_loss.result(), self.train_accuracy.result()))
            print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

    # evaluation
    def predict(self, inp_sentence):
        start_token = [tokenizer_pt.vocab_size]
        end_token = [tokenizer_pt.vocab_size + 1]

        # the input sentence is Portuguese; add the start and end tokens
        inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
        encoder_input = tf.expand_dims(inp_sentence, 0)

        # the target is English, so the first token fed to the transformer
        # should be the English start token
        decoder_input = [tokenizer_en.vocab_size]
        output = tf.expand_dims(decoder_input, 0)

        for i in range(MAX_LENGTH):
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
                encoder_input, output)

            # predictions.shape == (batch_size, seq_len, vocab_size)
            predictions, attention_weights = self.transformer(encoder_input,
                                                         output,
                                                         False,
                                                         enc_padding_mask,
                                                         combined_mask,
                                                         dec_padding_mask)

            # select the last token along the seq_len dimension
            predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
            # if predicted_id equals the end token, return the result
            if predicted_id == tokenizer_en.vocab_size + 1:
                return tf.squeeze(output, axis=0), attention_weights
            # append predicted_id to the output, which is fed back to the decoder as its input
            output = tf.concat([output, predicted_id], axis=-1)
        # MAX_LENGTH reached without an end token: return the same (output, attention_weights) pair
        return tf.squeeze(output, axis=0), attention_weights
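
ModelHelper relies on two helpers that are not defined in this section: CustomSchedule and checkout_dir. The sketch below shows what they plausibly do, under stated assumptions. `transformer_lr` is a plain-Python version of the warmup schedule from the paper, lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), with the paper's default of 4000 warmup steps; in TensorFlow the same formula would normally be wrapped in a `tf.keras.optimizers.schedules.LearningRateSchedule` subclass. `checkout_dir` follows the behaviour described in the checkpoint comment above (create the directory, optionally wiping it first); its body is a guess, not the author's code.

```python
import os
import shutil
import tempfile

def transformer_lr(step, d_model, warmup_steps=4000):
    """Learning rate: linear warmup, then decay proportional to step^-0.5."""
    step = max(step, 1)  # avoid a division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

def checkout_dir(dir_path, do_delete=False):
    """Ensure dir_path exists; with do_delete=True, remove old contents first."""
    if do_delete and os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path, exist_ok=True)

# The learning rate peaks at step == warmup_steps.
peak = transformer_lr(4000, d_model=128)

# Exercise checkout_dir inside a temporary directory.
base = tempfile.mkdtemp()
ckpt_dir = os.path.join(base, "checkpoints", "train")
checkout_dir(ckpt_dir)                                    # creates the directory
open(os.path.join(ckpt_dir, "marker"), "w").close()
checkout_dir(ckpt_dir)                                    # keeps existing files
kept = os.path.exists(os.path.join(ckpt_dir, "marker"))
checkout_dir(ckpt_dir, do_delete=True)                    # wipes and recreates
removed = not os.path.exists(os.path.join(ckpt_dir, "marker"))
```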
