【Keras】A Transformer Implementation with TF 2.0+ and tf.keras

The code implements the Transformer from the paper Attention Is All You Need (Vaswani et al., 2017).

The code consists of train.py, data_loader.py, model.py, utils.py, and test.py. Each Python file plays the following role:

  • train.py: initializes the DataLoader (from data_loader.py) and calls its load() method to build the training and validation sets; builds the Transformer model (from model.py); defines the learning rate, optimizer, and loss function; initializes a Trainer (from utils.py) and calls its single_gpu_train() method to train the model on a single GPU.
  • data_loader.py: builds the training, validation, and test sets and performs some preprocessing.
  • model.py: builds the Transformer model, including the EncoderLayer, DecoderLayer, PositionWiseFeedForwardLayer, MultiHeadAttention, ScaledDotProductAttention, and Embeddinglayer classes.
  • utils.py: a utility module that implements the Trainer class used for training, the CustomSchedule class for learning-rate warmup, the Mask class for the masking operations, and several helper functions such as label_smoothing (label smoothing) and calculate_bleu_score (BLEU score computation).
  • test.py: initializes the DataLoader (from data_loader.py) and calls its load_test() method to build the test set, defines the Transformer model structure (from model.py) so that the trained weights can be restored, loads the checkpoint, and runs the evaluation.

1. train.py

Initializes the DataLoader (from data_loader.py) and calls its load() method to build the training and validation sets; builds the Transformer model (from model.py); defines the learning rate, optimizer, and loss function; initializes a Trainer (from utils.py) and calls its single_gpu_train() method to train the model on a single GPU.

from __future__ import (absolute_import, division, print_function, unicode_literals)

import os
import tensorflow as tf
from data_loader import DataLoader
from model import Transformer
from utils import CustomSchedule, Trainer

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


# hyperparameters
TRAIN_RATIO = 0.9
D_POINT_WISE_FF = 2048
D_MODEL = 512
ENCODER_COUNT = DECODER_COUNT = 6
EPOCHS = 20
ATTENTION_HEAD_COUNT = 8
DROPOUT_PROB = 0.1
BATCH_SIZE = 32
SEQ_MAX_LEN_SOURCE = 100
SEQ_MAX_LEN_TARGET = 100
BPE_VOCAB_SIZE = 32000

# hyperparameters for the overfitting test
# BATCH_SIZE = 32
# EPOCHS = 100
DATA_LIMIT = None

GLOBAL_BATCH_SIZE = (BATCH_SIZE * 1)  # multiply by the number of replicas when training on multiple GPUs
print('GLOBAL_BATCH_SIZE ', GLOBAL_BATCH_SIZE)

data_loader = DataLoader(
    dataset_name='wmt14/en-de',
    data_dir='./datasets',
    batch_size=GLOBAL_BATCH_SIZE,
    bpe_vocab_size=BPE_VOCAB_SIZE,
    seq_max_len_source=SEQ_MAX_LEN_SOURCE,
    seq_max_len_target=SEQ_MAX_LEN_TARGET,
    data_limit=DATA_LIMIT,
    train_ratio=TRAIN_RATIO
)

dataset, val_dataset = data_loader.load()

transformer = Transformer(
    inputs_vocab_size=BPE_VOCAB_SIZE,
    target_vocab_size=BPE_VOCAB_SIZE,
    encoder_count=ENCODER_COUNT,
    decoder_count=DECODER_COUNT,
    attention_head_count=ATTENTION_HEAD_COUNT,
    d_model=D_MODEL,
    d_point_wise_ff=D_POINT_WISE_FF,
    dropout_prob=DROPOUT_PROB
)

learning_rate = CustomSchedule(D_MODEL)
optimizer = tf.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
loss_object = tf.losses.CategoricalCrossentropy(from_logits=True, reduction='none')

trainer = Trainer(
    model=transformer,
    dataset=dataset,
    loss_object=loss_object,
    optimizer=optimizer,
    batch_size=GLOBAL_BATCH_SIZE,
    vocab_size=BPE_VOCAB_SIZE,
    epoch=EPOCHS,
)

trainer.single_gpu_train()
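
The warmup schedule follows the formula from the Transformer paper: lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), i.e. the learning rate grows linearly during warmup and then decays with the inverse square root of the step. A quick sanity check of CustomSchedule (a sketch; the step values are arbitrary):

import tensorflow as tf
from utils import CustomSchedule

schedule = CustomSchedule(d_model=512, warmup_steps=4000)
for step in [1.0, 1000.0, 4000.0, 10000.0, 100000.0]:
    # rises linearly until warmup_steps, then decays as step^-0.5
    print(int(step), float(schedule(tf.constant(step))))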

2. data_loader.py

Builds the training, validation, and test sets and performs some preprocessing.

import os
from urllib.request import urlretrieve

import sentencepiece
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import tqdm


class DataLoader:
    DIR = None
    PATHS = {}
    BPE_VOCAB_SIZE = 0
    MODES = ['source', 'target']
    dictionary = {
        'source': {
            'token2idx': None,
            'idx2token': None,
        },
        'target': {
            'token2idx': None,
            'idx2token': None,
        }
    }
    CONFIG = {
        'wmt14/en-de': {
            'source_lang': 'en',
            'target_lang': 'de',
            'base_url': 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/',
            'train_files': ['train.en', 'train.de'],
            'vocab_files': ['vocab.50K.en', 'vocab.50K.de'],
            'dictionary_files': ['dict.en-de'],
            'test_files': [
                'newstest2012.en', 'newstest2012.de',
                'newstest2013.en', 'newstest2013.de',
                'newstest2014.en', 'newstest2014.de',
                'newstest2015.en', 'newstest2015.de',
            ]
        }
    }
    BPE_MODEL_SUFFIX = '.model'
    BPE_VOCAB_SUFFIX = '.vocab'
    BPE_RESULT_SUFFIX = '.sequences'
    SEQ_MAX_LEN = {
        'source': 100,
        'target': 100
    }
    DATA_LIMIT = None
    TRAIN_RATIO = 0.9
    BATCH_SIZE = 16

    source_sp = None
    target_sp = None

    def __init__(self, dataset_name, data_dir, batch_size=16, bpe_vocab_size=32000, seq_max_len_source=100,
                 seq_max_len_target=100, data_limit=None, train_ratio=0.9):
        if dataset_name is None or data_dir is None:
            raise ValueError('dataset_name and data_dir must be defined')
        self.DIR = data_dir
        self.DATASET = dataset_name
        self.BPE_VOCAB_SIZE = bpe_vocab_size
        self.SEQ_MAX_LEN['source'] = seq_max_len_source
        self.SEQ_MAX_LEN['target'] = seq_max_len_target
        self.DATA_LIMIT = data_limit
        self.TRAIN_RATIO = train_ratio
        self.BATCH_SIZE = batch_size

        self.PATHS['source_data'] = os.path.join(self.DIR, self.CONFIG[self.DATASET]['train_files'][0])
        self.PATHS['source_bpe_prefix'] = self.PATHS['source_data'] + '.segmented'

        self.PATHS['target_data'] = os.path.join(self.DIR, self.CONFIG[self.DATASET]['train_files'][1])
        self.PATHS['target_bpe_prefix'] = self.PATHS['target_data'] + '.segmented'

    def load(self, custom_dataset=False):
        if custom_dataset:
            print('#1 use custom dataset. please implement custom download_dataset function.')
        else:            
            print('#1 download data')
            self.download_dataset()

        print('#2 parse data')
        source_data = self.parse_data_and_save(self.PATHS['source_data'])
        target_data = self.parse_data_and_save(self.PATHS['target_data'])

        print('#3 train bpe')

        self.train_bpe(self.PATHS['source_data'], self.PATHS['source_bpe_prefix'])
        self.train_bpe(self.PATHS['target_data'], self.PATHS['target_bpe_prefix'])

        print('#4 load bpe vocab')

        self.dictionary['source']['token2idx'], self.dictionary['source']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['source_bpe_prefix'] + self.BPE_VOCAB_SUFFIX)
        self.dictionary['target']['token2idx'], self.dictionary['target']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['target_bpe_prefix'] + self.BPE_VOCAB_SUFFIX)

        print('#5 encode data with bpe')
        source_sequences = self.texts_to_sequences(
            self.sentence_piece(
                source_data,
                self.PATHS['source_bpe_prefix'] + self.BPE_MODEL_SUFFIX,
                self.PATHS['source_bpe_prefix'] + self.BPE_RESULT_SUFFIX
            ),
            mode="source"
        )
        target_sequences = self.texts_to_sequences(
            self.sentence_piece(
                target_data,
                self.PATHS['target_bpe_prefix'] + self.BPE_MODEL_SUFFIX,
                self.PATHS['target_bpe_prefix'] + self.BPE_RESULT_SUFFIX
            ),
            mode="target"
        )

        print('source sequence example:', source_sequences[0])
        print('target sequence example:', target_sequences[0])

        if self.TRAIN_RATIO == 1.0:
            source_sequences_train = source_sequences
            source_sequences_val = []
            target_sequences_train = target_sequences
            target_sequences_val = []
        else:
            (source_sequences_train,
             source_sequences_val,
             target_sequences_train,
             target_sequences_val) = train_test_split(
                source_sequences, target_sequences, train_size=self.TRAIN_RATIO
            )

        if self.DATA_LIMIT is not None:
            print('data size limit ON. limit size:', self.DATA_LIMIT)
            source_sequences_train = source_sequences_train[:self.DATA_LIMIT]
            target_sequences_train = target_sequences_train[:self.DATA_LIMIT]

        print('source_sequences_train', len(source_sequences_train))
        print('source_sequences_val', len(source_sequences_val))
        print('target_sequences_train', len(target_sequences_train))
        print('target_sequences_val', len(target_sequences_val))

        print('train set size: ', len(source_sequences_train))
        print('validation set size: ', len(source_sequences_val))

        train_dataset = self.create_dataset(
            source_sequences_train,
            target_sequences_train
        )
        if self.TRAIN_RATIO == 1.0:
            val_dataset = None
        else:
            val_dataset = self.create_dataset(
                source_sequences_val,
                target_sequences_val
            )

        return train_dataset, val_dataset

    def load_test(self, index=0, custom_dataset=False):
        
        if index < 0 or index >= len(self.CONFIG[self.DATASET]['test_files']) // 2:
            raise ValueError('test file index out of range. min: 0, max: {}'.format(
                len(self.CONFIG[self.DATASET]['test_files']) // 2 - 1)
            )
        if custom_dataset:
            print('#1 use custom dataset. please implement custom download_dataset function.')
        else:
            print('#1 download data')
            self.download_dataset()

        print('#2 parse data')

        source_test_data_path, target_test_data_path = self.get_test_data_path(index)

        source_data = self.parse_data_and_save(source_test_data_path)
        target_data = self.parse_data_and_save(target_test_data_path)

        print('#3 load bpe vocab')

        self.dictionary['source']['token2idx'], self.dictionary['source']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['source_bpe_prefix'] + self.BPE_VOCAB_SUFFIX)
        self.dictionary['target']['token2idx'], self.dictionary['target']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['target_bpe_prefix'] + self.BPE_VOCAB_SUFFIX)

        return source_data, target_data

    def get_test_data_path(self, index):
        source_test_data_path = os.path.join(self.DIR, self.CONFIG[self.DATASET]['test_files'][index * 2])
        target_test_data_path = os.path.join(self.DIR, self.CONFIG[self.DATASET]['test_files'][index * 2 + 1])
        return source_test_data_path, target_test_data_path

    def download_dataset(self):
        for file in (self.CONFIG[self.DATASET]['train_files']
                     + self.CONFIG[self.DATASET]['vocab_files']
                     + self.CONFIG[self.DATASET]['dictionary_files']
                     + self.CONFIG[self.DATASET]['test_files']):
            self._download("{}{}".format(self.CONFIG[self.DATASET]['base_url'], file))

    def _download(self, url):
        path = os.path.join(self.DIR, url.split('/')[-1])
        if not os.path.exists(path):
            with TqdmCustom(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=url) as t:
                urlretrieve(url, path, t.update_to)

    def parse_data_and_save(self, path):
        print('load data from {}'.format(path))
        with open(path, encoding='utf-8') as f:
            lines = f.read().strip().split('\n')

        if lines is None:
            raise ValueError('Vocab file is invalid')

        with open(path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))

        return lines

    def train_bpe(self, data_path, model_prefix):
        model_path = model_prefix + self.BPE_MODEL_SUFFIX
        vocab_path = model_prefix + self.BPE_VOCAB_SUFFIX

        if not (os.path.exists(model_path) and os.path.exists(vocab_path)):
            print('bpe model does not exist. train bpe. model path:', model_path, ' vocab path:', vocab_path)
            train_source_params = "--inputs={} \
                --pad_id=0 \
                --unk_id=1 \
                --bos_id=2 \
                --eos_id=3 \
                --model_prefix={} \
                --vocab_size={} \
                --model_type=bpe ".format(
                data_path,
                model_prefix,
                self.BPE_VOCAB_SIZE
            )
            sentencepiece.SentencePieceTrainer.Train(train_source_params)
        else:
            print('bpe model exist. load bpe. model path:', model_path, ' vocab path:', vocab_path)

    def load_bpe_encoder(self):
        self.dictionary['source']['token2idx'], self.dictionary['source']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['source_bpe_prefix'] + self.BPE_VOCAB_SUFFIX
        )
        self.dictionary['target']['token2idx'], self.dictionary['target']['idx2token'] = self.load_bpe_vocab(
            self.PATHS['target_bpe_prefix'] + self.BPE_VOCAB_SUFFIX
        )

    def sentence_piece(self, source_data, source_bpe_model_path, result_data_path):
        sp = sentencepiece.SentencePieceProcessor()
        sp.load(source_bpe_model_path)

        if os.path.exists(result_data_path):
            print('encoded data exist. load data. path:', result_data_path)
            with open(result_data_path, 'r', encoding='utf-8') as f:
                sequences = f.read().strip().split('\n')
                return sequences

        print('encoded data does not exist. encode data. path:', result_data_path)
        sequences = []
        with open(result_data_path, 'w', encoding='utf-8') as f:
            for sentence in tqdm(source_data):
                pieces = sp.EncodeAsPieces(sentence)
                sequence = " ".join(pieces)
                sequences.append(sequence)
                f.write(sequence + "\n")
        return sequences

    def encode_data(self, inputs, mode='source'):
        if mode not in self.MODES:
            raise ValueError('not allowed mode.')

        if mode == 'source':
            if self.source_sp is None:
                self.source_sp = sentencepiece.SentencePieceProcessor()
                self.source_sp.load(self.PATHS['source_bpe_prefix'] + self.BPE_MODEL_SUFFIX)

            pieces = self.source_sp.EncodeAsPieces(inputs)
            sequence = " ".join(pieces)

        elif mode == 'target':
            if self.target_sp is None:
                self.target_sp = sentencepiece.SentencePieceProcessor()
                self.target_sp.load(self.PATHS['target_bpe_prefix'] + self.BPE_MODEL_SUFFIX)

            pieces = self.target_sp.EncodeAsPieces(inputs)
            sequence = " ".join(pieces)

        else:
            raise ValueError('not allowed mode.')

        return sequence

    def load_bpe_vocab(self, bpe_vocab_path):
        with open(bpe_vocab_path, 'r', encoding='utf-8') as f:
            vocab = [line.split()[0] for line in f.read().splitlines()]

        token2idx = {}
        idx2token = {}

        for idx, token in enumerate(vocab):
            token2idx[token] = idx
            idx2token[idx] = token
        return token2idx, idx2token

    def texts_to_sequences(self, texts, mode='source'):
        if mode not in self.MODES:
            raise ValueError('not allowed mode.')

        sequences = []
        for text in texts:
            text_list = [""] + text.split() + [""]

            sequence = [
                self.dictionary[mode]['token2idx'].get(
                    token, self.dictionary[mode]['token2idx']['<unk>']
                )
                for token in text_list
            ]
            sequences.append(sequence)
        return sequences

    def sequences_to_texts(self, sequences, mode='source'):
        if mode not in self.MODES:
            raise ValueError('not allowed mode.')

        texts = []
        for sequence in sequences:
            if mode == 'source':
                if self.source_sp is None:
                    self.source_sp = sentencepiece.SentencePieceProcessor()
                    self.source_sp.load(self.PATHS['source_bpe_prefix'] + self.BPE_MODEL_SUFFIX)
                text = self.source_sp.DecodeIds(sequence)
            else:
                if self.target_sp is None:
                    self.target_sp = sentencepiece.SentencePieceProcessor()
                    self.target_sp.load(self.PATHS['target_bpe_prefix'] + self.BPE_MODEL_SUFFIX)
                text = self.target_sp.DecodeIds(sequence)
            texts.append(text)
        return texts

    def create_dataset(self, source_sequences, target_sequences):
        new_source_sequences = []
        new_target_sequences = []
        for source, target in zip(source_sequences, target_sequences):
            if len(source) > self.SEQ_MAX_LEN['source']:
                continue
            if len(target) > self.SEQ_MAX_LEN['target']:
                continue
            new_source_sequences.append(source)
            new_target_sequences.append(target)

        source_sequences = tf.keras.preprocessing.sequence.pad_sequences(
            sequences=new_source_sequences, maxlen=self.SEQ_MAX_LEN['source'], padding='post'
        )
        target_sequences = tf.keras.preprocessing.sequence.pad_sequences(
            sequences=new_target_sequences, maxlen=self.SEQ_MAX_LEN['target'], padding='post'
        )
        buffer_size = int(source_sequences.shape[0] * 0.3)
        dataset = tf.data.Dataset.from_tensor_slices(
            (source_sequences, target_sequences)
        ).shuffle(buffer_size)
        dataset = dataset.batch(self.BATCH_SIZE)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        return dataset


class TqdmCustom(tqdm):

    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)
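
Note that create_dataset() drops any sentence pair longer than SEQ_MAX_LEN and right-pads the remaining sequences with 0 (the <pad> id). A toy illustration of the padding step (a sketch with made-up token ids):

import tensorflow as tf

# Two made-up token-id sequences of unequal length.
toy_sequences = [[2, 11, 12, 3], [2, 21, 3]]

# Right-pad with 0 up to a fixed length, exactly as create_dataset() does.
padded = tf.keras.preprocessing.sequence.pad_sequences(
    sequences=toy_sequences, maxlen=6, padding='post'
)
print(padded)
# [[ 2 11 12  3  0  0]
#  [ 2 21  3  0  0  0]]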

3. model.py

Builds the Transformer model, including the EncoderLayer, DecoderLayer, PositionWiseFeedForwardLayer, MultiHeadAttention, ScaledDotProductAttention, and Embeddinglayer classes.
The Transformer class itself inherits from tf.keras.Model. A Model organizes and connects the individual layers, wraps them into a single unit, and describes how the input data flows through the layers and operations to produce the output.

Keras models are expressed as classes: we define our own model by subclassing the Python class tf.keras.Model. In the subclass we need to override the __init__() method (the constructor, for initialization) and the call(input) method (the forward pass), and we can add further custom methods as needed.

The EncoderLayer, DecoderLayer, PositionWiseFeedForwardLayer, MultiHeadAttention, ScaledDotProductAttention, and Embeddinglayer classes inherit from tf.keras.layers.Layer. A Layer encapsulates a computation and its variables (for example a fully connected layer, or the convolution and pooling layers of a CNN).

We define our own layer by subclassing the Python class tf.keras.layers.Layer. As with a model, we need to override __init__() (the constructor, for initialization) and call(input) (the forward computation), and we can add custom methods as needed; see the minimal sketch below.
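
As an example, here is a minimal toy layer (a sketch, not part of model.py) that scales its input by a single trainable scalar:

import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ScaleLayer, self).__init__()
        # one trainable variable created in the constructor
        self.scale = tf.Variable(1.0, trainable=True)

    def call(self, inputs):
        # the forward computation
        return inputs * self.scale

layer = ScaleLayer()
print(layer(tf.constant([1.0, 2.0])))  # [1. 2.] while scale == 1.0

The full model.py is as follows: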

import os

import numpy as np
import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class Transformer(tf.keras.Model):
    def __init__(self,
                 inputs_vocab_size,
                 target_vocab_size,
                 encoder_count,
                 decoder_count,
                 attention_head_count,
                 d_model,
                 d_point_wise_ff,
                 dropout_prob):
        super(Transformer, self).__init__()

        # model hyper parameter variables
        self.encoder_count = encoder_count
        self.decoder_count = decoder_count
        self.attention_head_count = attention_head_count
        self.d_model = d_model
        self.d_point_wise_ff = d_point_wise_ff
        self.dropout_prob = dropout_prob

        self.encoder_embedding_layer = Embeddinglayer(inputs_vocab_size, d_model)
        self.encoder_embedding_dropout = tf.keras.layers.Dropout(dropout_prob)
        self.decoder_embedding_layer = Embeddinglayer(target_vocab_size, d_model)
        self.decoder_embedding_dropout = tf.keras.layers.Dropout(dropout_prob)

        self.encoder_layers = [
            EncoderLayer(
                attention_head_count,
                d_model,
                d_point_wise_ff,
                dropout_prob
            ) for _ in range(encoder_count)
        ]

        self.decoder_layers = [
            DecoderLayer(
                attention_head_count,
                d_model,
                d_point_wise_ff,
                dropout_prob
            ) for _ in range(decoder_count)
        ]

        self.linear = tf.keras.layers.Dense(target_vocab_size)

    def call(self,
             inputs,
             target,
             inputs_padding_mask,
             look_ahead_mask,
             target_padding_mask,
             training
             ):
        encoder_tensor = self.encoder_embedding_layer(inputs)
        encoder_tensor = self.encoder_embedding_dropout(encoder_tensor, training=training)

        for i in range(self.encoder_count):
            encoder_tensor, _ = self.encoder_layers[i](encoder_tensor, inputs_padding_mask, training=training)
        target = self.decoder_embedding_layer(target)
        decoder_tensor = self.decoder_embedding_dropout(target, training=training)
        for i in range(self.decoder_count):
            decoder_tensor, _, _ = self.decoder_layers[i](
                decoder_tensor,
                encoder_tensor,
                look_ahead_mask,
                target_padding_mask,
                training=training
            )
        return self.linear(decoder_tensor)


class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, attention_head_count, d_model, d_point_wise_ff, dropout_prob):
        super(EncoderLayer, self).__init__()

        # model hyper parameter variables
        self.attention_head_count = attention_head_count
        self.d_model = d_model
        self.d_point_wise_ff = d_point_wise_ff
        self.dropout_prob = dropout_prob

        self.multi_head_attention = MultiHeadAttention(attention_head_count, d_model)
        self.dropout_1 = tf.keras.layers.Dropout(dropout_prob)
        self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.position_wise_feed_forward_layer = PositionWiseFeedForwardLayer(
            d_point_wise_ff,
            d_model
        )
        self.dropout_2 = tf.keras.layers.Dropout(dropout_prob)
        self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, mask, training):
        output, attention = self.multi_head_attention(inputs, inputs, inputs, mask)
        output = self.dropout_1(output, training=training)
        output = self.layer_norm_1(tf.add(inputs, output))  # residual network
        output_temp = output

        output = self.position_wise_feed_forward_layer(output)
        output = self.dropout_2(output, training=training)
        output = self.layer_norm_2(tf.add(output_temp, output))  # residual network

        return output, attention


class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, attention_head_count, d_model, d_point_wise_ff, dropout_prob):
        super(DecoderLayer, self).__init__()

        # model hyper parameter variables
        self.attention_head_count = attention_head_count
        self.d_model = d_model
        self.d_point_wise_ff = d_point_wise_ff
        self.dropout_prob = dropout_prob

        self.masked_multi_head_attention = MultiHeadAttention(attention_head_count, d_model)
        self.dropout_1 = tf.keras.layers.Dropout(dropout_prob)
        self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.encoder_decoder_attention = MultiHeadAttention(attention_head_count, d_model)
        self.dropout_2 = tf.keras.layers.Dropout(dropout_prob)
        self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.position_wise_feed_forward_layer = PositionWiseFeedForwardLayer(
            d_point_wise_ff,
            d_model
        )
        self.dropout_3 = tf.keras.layers.Dropout(dropout_prob)
        self.layer_norm_3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def call(self, decoder_inputs, encoder_output, look_ahead_mask, padding_mask, training):
        output, attention_1 = self.masked_multi_head_attention(
            decoder_inputs,
            decoder_inputs,
            decoder_inputs,
            look_ahead_mask
        )
        output = self.dropout_1(output, training=training)
        query = self.layer_norm_1(tf.add(decoder_inputs, output))  # residual network
        output, attention_2 = self.encoder_decoder_attention(
            query,
            encoder_output,
            encoder_output,
            padding_mask
        )
        output = self.dropout_2(output, training=training)
        encoder_decoder_attention_output = self.layer_norm_2(tf.add(output, query))

        output = self.position_wise_feed_forward_layer(encoder_decoder_attention_output)
        output = self.dropout_3(output, training=training)
        output = self.layer_norm_3(tf.add(encoder_decoder_attention_output, output))  # residual network

        return output, attention_1, attention_2


class PositionWiseFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, d_point_wise_ff, d_model):
        super(PositionWiseFeedForwardLayer, self).__init__()
        self.w_1 = tf.keras.layers.Dense(d_point_wise_ff)
        self.w_2 = tf.keras.layers.Dense(d_model)

    def call(self, inputs):
        inputs = self.w_1(inputs)
        inputs = tf.nn.relu(inputs)
        return self.w_2(inputs)


class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, attention_head_count, d_model):
        super(MultiHeadAttention, self).__init__()

        # model hyper parameter variables
        self.attention_head_count = attention_head_count
        self.d_model = d_model

        if d_model % attention_head_count != 0:
            raise ValueError(
                "d_model({}) % attention_head_count({}) is not zero.d_model must be multiple of attention_head_count.".format(
                    d_model, attention_head_count
                )
            )

        self.d_h = d_model // attention_head_count

        self.w_query = tf.keras.layers.Dense(d_model)
        self.w_key = tf.keras.layers.Dense(d_model)
        self.w_value = tf.keras.layers.Dense(d_model)

        self.scaled_dot_product = ScaledDotProductAttention(self.d_h)

        self.ff = tf.keras.layers.Dense(d_model)

    def call(self, query, key, value, mask=None):
        batch_size = tf.shape(query)[0]

        query = self.w_query(query)
        key = self.w_key(key)
        value = self.w_value(value)

        query = self.split_head(query, batch_size)
        key = self.split_head(key, batch_size)
        value = self.split_head(value, batch_size)

        output, attention = self.scaled_dot_product(query, key, value, mask)
        output = self.concat_head(output, batch_size)

        return self.ff(output), attention

    def split_head(self, tensor, batch_size):
        # inputs tensor: (batch_size, seq_len, d_model)
        return tf.transpose(
            tf.reshape(
                tensor,
                (batch_size, -1, self.attention_head_count, self.d_h)
                # tensor: (batch_size, seq_len, attention_head_count, d_h)
            ),
            [0, 2, 1, 3]
            # tensor: (batch_size, attention_head_count, seq_len, d_h)
        )

    def concat_head(self, tensor, batch_size):
        return tf.reshape(
            tf.transpose(tensor, [0, 2, 1, 3]),
            (batch_size, -1, self.attention_head_count * self.d_h)
        )


class ScaledDotProductAttention(tf.keras.layers.Layer):
    def __init__(self, d_h):
        super(ScaledDotProductAttention, self).__init__()
        self.d_h = d_h

    def call(self, query, key, value, mask=None):
        matmul_q_and_transposed_k = tf.matmul(query, key, transpose_b=True)
        scale = tf.sqrt(tf.cast(self.d_h, dtype=tf.float32))
        scaled_attention_score = matmul_q_and_transposed_k / scale
        if mask is not None:
            scaled_attention_score += (mask * -1e9)

        attention_weight = tf.nn.softmax(scaled_attention_score, axis=-1)

        return tf.matmul(attention_weight, value), attention_weight


class Embeddinglayer(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        # model hyper parameter variables
        super(Embeddinglayer, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)

    def call(self, sequences):
        max_sequence_len = sequences.shape[1]
        output = self.embedding(sequences) * tf.sqrt(tf.cast(self.d_model, dtype=tf.float32))
        output += self.positional_encoding(max_sequence_len)

        return output

    def positional_encoding(self, max_len):
        pos = np.expand_dims(np.arange(0, max_len), axis=1)
        index = np.expand_dims(np.arange(0, self.d_model), axis=0)

        pe = self.angle(pos, index)

        pe[:, 0::2] = np.sin(pe[:, 0::2])
        pe[:, 1::2] = np.cos(pe[:, 1::2])

        pe = np.expand_dims(pe, axis=0)
        return tf.cast(pe, dtype=tf.float32)

    def angle(self, pos, index):
        return pos / np.power(10000, (index - index % 2) / np.float32(self.d_model))
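
As a quick shape check of the model above (a sketch with random token ids, tiny dimensions, and no masks), a single forward pass looks like this:

import tensorflow as tf
from model import Transformer

# Dummy batch: 2 sentences, 10 source and 12 target token ids each.
dummy_inputs = tf.random.uniform((2, 10), minval=1, maxval=100, dtype=tf.int32)
dummy_target = tf.random.uniform((2, 12), minval=1, maxval=100, dtype=tf.int32)

model = Transformer(
    inputs_vocab_size=100, target_vocab_size=100,
    encoder_count=2, decoder_count=2,
    attention_head_count=4, d_model=64,
    d_point_wise_ff=128, dropout_prob=0.1
)
out = model.call(
    inputs=dummy_inputs, target=dummy_target,
    inputs_padding_mask=None, look_ahead_mask=None,
    target_padding_mask=None, training=False
)
print(out.shape)  # (2, 12, 100): one logit per target position per vocab entry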

4. utils.py

A utility module that implements the Trainer class used for training, the CustomSchedule class for learning-rate warmup, the Mask class for the masking operations, and several helper functions such as label_smoothing (label smoothing) and calculate_bleu_score (BLEU score computation).

The full code is as follows:

import datetime
import os
import re
import time

import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
CURRENT_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BLEU_CALCULATOR_PATH = os.path.join(CURRENT_DIR_PATH, 'multi-bleu.perl')


class Mask:
    @classmethod
    def create_padding_mask(cls, sequences):
        sequences = tf.cast(tf.math.equal(sequences, 0), dtype=tf.float32)
        return sequences[:, tf.newaxis, tf.newaxis, :]

    @classmethod
    def create_look_ahead_mask(cls, seq_len):
        return 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

    @classmethod
    def create_masks(cls, inputs, target):
        encoder_padding_mask = Mask.create_padding_mask(inputs)
        decoder_padding_mask = Mask.create_padding_mask(inputs)

        look_ahead_mask = tf.maximum(
            Mask.create_look_ahead_mask(tf.shape(target)[1]),
            Mask.create_padding_mask(target)
            )

        return encoder_padding_mask, look_ahead_mask, decoder_padding_mask


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


def label_smoothing(target_data, depth, epsilon=0.1):
    target_data_one_hot = tf.one_hot(target_data, depth=depth)
    n = target_data_one_hot.get_shape().as_list()[-1]
    return ((1 - epsilon) * target_data_one_hot) + (epsilon / n)


class Trainer:
    def __init__(
            self,
            model,
            dataset,
            loss_object=None,
            optimizer=None,
            checkpoint_dir='./checkpoints',
            batch_size=None,
            distribute_strategy=None,
            vocab_size=32000,
            epoch=20,
            ):
        self.batch_size = batch_size
        self.distribute_strategy = distribute_strategy
        self.model = model
        self.loss_object = loss_object
        self.optimizer = optimizer
        self.checkpoint_dir = checkpoint_dir
        self.vocab_size = vocab_size
        self.epoch = epoch
        self.dataset = dataset

        os.makedirs(self.checkpoint_dir, exist_ok=True)
        if self.optimizer is None:
            self.checkpoint = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
        else:
            self.checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=self.optimizer, model=self.model)
        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3)

        # metrics
        self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
        self.validation_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)
        self.validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('validation_accuracy')

    def multi_gpu_train(self, reset_checkpoint=False):
        with self.distribute_strategy.scope():
            self.dataset = self.distribute_strategy.experimental_distribute_dataset(self.dataset)
            self.trainer(reset_checkpoint=reset_checkpoint, is_distributed=True)

    def single_gpu_train(self, reset_checkpoint=False):
        self.trainer(reset_checkpoint=reset_checkpoint, is_distributed=False)

    def trainer(self, reset_checkpoint, is_distributed=False):
        current_day = datetime.datetime.now().strftime("%Y%m%d")
        train_log_dir = './logs/gradient_tape/' + current_day + '/train'
        os.makedirs(train_log_dir, exist_ok=True)
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)

        if not reset_checkpoint:
            if self.checkpoint_manager.latest_checkpoint:
                print("Restored from {}".format(self.checkpoint_manager.latest_checkpoint))
            else:
                print("Initializing from scratch.")

            self.checkpoint.restore(
                self.checkpoint_manager.latest_checkpoint
            )
        else:
            print("reset and initializing from scratch.")

        for epoch in range(self.epoch):
            start = time.time()
            print('start learning')

            for (batch, (inputs, target)) in enumerate(self.dataset):
                if is_distributed:
                    self.distributed_train_step(inputs, target)
                else:
                    self.train_step(inputs, target)

                self.checkpoint.step.assign_add(1)
                if batch % 50 == 0:
                    print(
                        "Epoch: {}, Batch: {}, Loss:{}, Accuracy: {}".format(epoch, batch, self.train_loss.result(),
                                                                             self.train_accuracy.result()))
                if batch % 10000 == 0 and batch != 0:
                    self.checkpoint_manager.save()
            print("{} | Epoch: {} Loss:{}, Accuracy: {}, time: {} sec".format(
                datetime.datetime.now(), epoch, self.train_loss.result(), self.train_accuracy.result(),
                time.time() - start
            ))
            with train_summary_writer.as_default():
                tf.summary.scalar('train_loss', self.train_loss.result(), step=epoch)
                tf.summary.scalar('train_accuracy', self.train_accuracy.result(), step=epoch)

            self.checkpoint_manager.save()

            self.train_loss.reset_states()
            self.train_accuracy.reset_states()
            self.validation_loss.reset_states()
            self.validation_accuracy.reset_states()
        self.checkpoint_manager.save()

    def basic_train_step(self, inputs, target):
        target_include_start = target[:, :-1]
        target_include_end = target[:, 1:]
        encoder_padding_mask, look_ahead_mask, decoder_padding_mask = Mask.create_masks(
            inputs, target_include_start
        )

        with tf.GradientTape() as tape:
            pred = self.model.call(
                inputs=inputs,
                target=target_include_start,
                inputs_padding_mask=encoder_padding_mask,
                look_ahead_mask=look_ahead_mask,
                target_padding_mask=decoder_padding_mask,
                training=True
            )

            loss = self.loss_function(target_include_end, pred)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss(loss)
        self.train_accuracy(target_include_end, pred)

        if self.distribute_strategy is None:
            return tf.reduce_mean(loss)

        return loss

    def loss_function(self, real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        real_one_hot = label_smoothing(real, depth=self.vocab_size)
        loss = self.loss_object(real_one_hot, pred)

        mask = tf.cast(mask, dtype=loss.dtype)

        loss *= mask
        return tf.reduce_mean(loss)

    @tf.function
    def train_step(self, inputs, target):
        return self.basic_train_step(inputs, target)

    @tf.function
    def distributed_train_step(self, inputs, target):
        loss = self.distribute_strategy.experimental_run_v2(self.basic_train_step, args=(inputs, target))  # renamed to strategy.run since TF 2.2
        loss_value = self.distribute_strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
        return tf.reduce_mean(loss_value)


def translate(inputs, data_loader, trainer, seq_max_len_target=100):
    if data_loader is None:
        raise ValueError('data loader is None')

    if trainer is None:
        raise ValueError('trainer is None')

    if trainer.model is None:
        raise ValueError('model is None')

    if not isinstance(seq_max_len_target, int):
        raise ValueError('seq_max_len_target is not int')

    encoded_data = data_loader.encode_data(inputs, mode='source')
    encoded_data = data_loader.texts_to_sequences([encoded_data])
    encoder_inputs = tf.convert_to_tensor(
        encoded_data,
        dtype=tf.int32
    )
    decoder_inputs = [data_loader.dictionary['target']['token2idx']['<s>']]
    decoder_inputs = tf.expand_dims(decoder_inputs, 0)
    decoder_end_token = data_loader.dictionary['target']['token2idx']['</s>']

    for _ in range(seq_max_len_target):
        encoder_padding_mask, look_ahead_mask, decoder_padding_mask = Mask.create_masks(
            encoder_inputs, decoder_inputs
        )
        pred = trainer.model.call(
            inputs=encoder_inputs,
            target=decoder_inputs,
            inputs_padding_mask=encoder_padding_mask,
            look_ahead_mask=look_ahead_mask,
            target_padding_mask=decoder_padding_mask,
            training=False
        )
        pred = pred[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(pred, axis=-1), dtype=tf.int32)

        if predicted_id == decoder_end_token:
            break
        decoder_inputs = tf.concat([decoder_inputs, predicted_id], axis=-1)

    total_output = tf.squeeze(decoder_inputs, axis=0)
    return data_loader.sequences_to_texts([total_output.numpy().tolist()], mode='target')


def calculate_bleu_score(target_path, ref_path):

    get_bleu_score = f"perl {BLEU_CALCULATOR_PATH} {ref_path} < {target_path} > temp"
    os.system(get_bleu_score)
    with open("temp", "r") as f:
        bleu_score_report = f.read()
    score = re.findall("BLEU = ([^,]+)", bleu_score_report)[0]

    return score, bleu_score_report
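
To make the Mask helpers above concrete, here is a toy demonstration (a sketch with a hand-made batch; Mask and label_smoothing are the class and function defined in utils.py above):

import tensorflow as tf
from utils import Mask, label_smoothing

# One sentence of 5 tokens whose last two positions are <pad> (id 0).
batch = tf.constant([[7, 8, 9, 0, 0]])
print(Mask.create_padding_mask(batch))
# shape (1, 1, 1, 5); 1.0 marks the padded positions: [0. 0. 0. 1. 1.]

print(Mask.create_look_ahead_mask(3))
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]  1.0 blocks attention to future positions

# label_smoothing spreads epsilon of the probability mass over all classes:
print(label_smoothing(tf.constant([1]), depth=4))
# [[0.025 0.925 0.025 0.025]]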

4+1. The Trainer implementation in utils.py

In TensorFlow 2.0 and above, the Session-based graph execution of TensorFlow 1 is no longer the default; eager execution is used instead. This changes how the training loop is written, so this section focuses on the training procedure under TF 2.0+.

1. The basic_train_step() method of the Trainer class

basic_train_step() is called by train_step() (train_step() is covered in the next subsection):

@tf.function
def train_step(self, inputs, target):
    return self.basic_train_step(inputs, target)

In machine learning we frequently need to compute derivatives, and TensorFlow provides a powerful automatic differentiation mechanism for this. Under eager execution, TensorFlow uses tf.GradientTape(), a "gradient recorder", to implement automatic differentiation.

tf.GradientTape() is a recorder for automatic differentiation. Once execution enters the with tf.GradientTape() as tape context, every computation performed inside it is recorded automatically, as shown below:

# Inside the with tf.GradientTape() as tape context we call the model's
# call method, i.e. feed the input through the model step by step to
# obtain pred, and then compute the loss. The entire computation from
# input to loss is recorded automatically and used later to compute the
# gradients for the parameter update.
with tf.GradientTape() as tape:
    pred = self.model.call(
        inputs=inputs,
        target=target_include_start,
        inputs_padding_mask=encoder_padding_mask,
        look_ahead_mask=look_ahead_mask,
        target_padding_mask=decoder_padding_mask,
        training=True
    )

    loss = self.loss_function(target_include_end, pred)

# TensorFlow automatically computes the gradients of the loss with respect to the variables (the model parameters)
gradients = tape.gradient(loss, self.model.trainable_variables)

# TensorFlow automatically updates the parameters from the gradients
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

In the example above, inside the with tf.GradientTape() as tape context we call the model's call method, i.e. feed the input through the model to obtain pred and then compute the loss; the whole computation from input to loss is recorded and used afterwards to compute the gradients for the parameter update.

After leaving the with tf.GradientTape() as tape context, recording stops, but the recorder tape remains usable, so gradients = tape.gradient(loss, self.model.trainable_variables) computes the derivative of the tensor loss with respect to the variables in self.model.trainable_variables.

TensorFlow's eager execution mode provides fast computation (with GPU support), automatic differentiation, optimizers, and other features that matter for deep learning. The linear-regression sketch after the list below shows how these pieces fit together; here, TensorFlow does two important jobs for us:

  • tape.gradient(loss, self.model.trainable_variables) computes the gradients automatically;
  • optimizer.apply_gradients(grads_and_vars) updates the model parameters automatically, i.e. self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)).
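
A minimal end-to-end illustration of both steps (a toy linear-regression sketch, unrelated to the Transformer code):

import tensorflow as tf

# Fit y = 3x + 2 with gradient descent.
X = tf.constant([[1.0], [2.0], [3.0], [4.0]])
y = 3.0 * X + 2.0

w = tf.Variable(0.0)
b = tf.Variable(0.0)
optimizer = tf.optimizers.SGD(learning_rate=0.05)

for _ in range(1000):
    with tf.GradientTape() as tape:
        pred = w * X + b
        loss = tf.reduce_mean(tf.square(pred - y))
    # 1) compute the gradients of the loss w.r.t. the variables ...
    gradients = tape.gradient(loss, [w, b])
    # 2) ... and let the optimizer apply the update.
    optimizer.apply_gradients(zip(gradients, [w, b]))

print(float(w), float(b))  # converges towards 3.0 and 2.0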

The line gradients = tape.gradient(loss, self.model.trainable_variables) is where TensorFlow automatically computes the gradients of the loss with respect to the model parameters;

and self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) is where TensorFlow automatically updates the parameters from those gradients.

2. The train_step() method of the Trainer class

train_step() is called by trainer(), which is covered in the next subsection.

This subsection looks at train_step() itself; it simply wraps the basic_train_step() method introduced in the previous subsection:

@tf.function
def train_step(self, inputs, target):
    return self.basic_train_step(inputs, target)

Note that train_step() carries the @tf.function decorator, which can speed up execution to some extent; for details see the post 【Keras】tf.function :图执行模式.

In TensorFlow 2, tf.function (rather than tf.Session from 1.x) is the recommended way to get graph execution and turn a model into a deployable, high-performance TensorFlow graph. All that is needed is to wrap the code that should run in graph mode inside a function and decorate that function with @tf.function.

In a test over 400 batches, the program with @tf.function took 35.5 seconds, while the pure eager version without it took 43.8 seconds, so @tf.function brings a measurable speedup. In general the gain is larger when the model consists of many small operations, and smaller when the model has only a few operations that are each individually expensive.
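
A minimal way to reproduce this kind of comparison (a sketch; the exact timings depend on your hardware):

import time
import tensorflow as tf

def step(x):
    # many small ops: the case where @tf.function helps most
    for _ in range(100):
        x = x + tf.reduce_mean(x)
    return x

graph_step = tf.function(step)

x = tf.random.uniform((64, 64))
graph_step(x)  # the first call traces the graph; keep it out of the timing

for fn, name in [(step, 'eager'), (graph_step, 'tf.function')]:
    start = time.time()
    for _ in range(400):
        fn(x)
    print(name, time.time() - start, 'sec')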


3. The trainer() method of the Trainer class

trainer() is called by single_gpu_train(), and single_gpu_train() is invoked from train.py; that is how training starts. This subsection walks through the implementation of trainer().

def trainer(self, reset_checkpoint, is_distributed=False):
    current_day = datetime.datetime.now().strftime("%Y%m%d")
    train_log_dir = './logs/gradient_tape/' + current_day + '/train'
    os.makedirs(train_log_dir, exist_ok=True)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    if not reset_checkpoint:
        if self.checkpoint_manager.latest_checkpoint:
            print("Restored from {}".format(self.checkpoint_manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        self.checkpoint.restore(
            self.checkpoint_manager.latest_checkpoint
        )
    else:
        print("reset and initializing from scratch.")

    for epoch in range(self.epoch):
        start = time.time()
        print('start learning')

        for (batch, (inputs, target)) in enumerate(self.dataset):
            if is_distributed:
                self.distributed_train_step(inputs, target)
            else:
                self.train_step(inputs, target)

            self.checkpoint.step.assign_add(1)
            if batch % 50 == 0:
                print(
                    "Epoch: {}, Batch: {}, Loss:{}, Accuracy: {}".format(epoch, batch, self.train_loss.result(),
                                                                         self.train_accuracy.result()))
            if batch % 10000 == 0 and batch != 0:
                self.checkpoint_manager.save()
        print("{} | Epoch: {} Loss:{}, Accuracy: {}, time: {} sec".format(
            datetime.datetime.now(), epoch, self.train_loss.result(), self.train_accuracy.result(),
            time.time() - start
        ))
        with train_summary_writer.as_default():
            tf.summary.scalar('train_loss', self.train_loss.result(), step=epoch)
            tf.summary.scalar('train_accuracy', self.train_accuracy.result(), step=epoch)

        self.checkpoint_manager.save()

        self.train_loss.reset_states()
        self.train_accuracy.reset_states()
        self.validation_loss.reset_states()
        self.validation_accuracy.reset_states()
    self.checkpoint_manager.save()

5. test.py

Initializes the DataLoader (from data_loader.py) and calls its load_test() method to build the test set, defines the Transformer model structure (from model.py) so that the trained weights can be restored, loads the checkpoint, and runs the evaluation.

from __future__ import (absolute_import, division, print_function, unicode_literals)

import os

from data_loader import DataLoader
from model import Transformer
from utils import Trainer, calculate_bleu_score, translate

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# hyperparameters
TRAIN_RATIO = 0.9
D_POINT_WISE_FF = 2048
D_MODEL = 512
ENCODER_COUNT = DECODER_COUNT = 6
EPOCHS = 20
ATTENTION_HEAD_COUNT = 8
DROPOUT_PROB = 0.1
BATCH_SIZE = 32
SEQ_MAX_LEN_SOURCE = 100
SEQ_MAX_LEN_TARGET = 100
BPE_VOCAB_SIZE = 32000

data_loader = DataLoader(
    dataset_name='wmt14/en-de',
    data_dir='./datasets'
)
data_loader.load_bpe_encoder()

source_data, target_data = data_loader.load_test(index=3)
_, target_data_path = data_loader.get_test_data_path(index=3)

data = enumerate(zip(source_data, target_data))  # do_translate expects (index, (source, target)) pairs

transformer = Transformer(
    inputs_vocab_size=BPE_VOCAB_SIZE,
    target_vocab_size=BPE_VOCAB_SIZE,
    encoder_count=ENCODER_COUNT,
    decoder_count=DECODER_COUNT,
    attention_head_count=ATTENTION_HEAD_COUNT,
    d_model=D_MODEL,
    d_point_wise_ff=D_POINT_WISE_FF,
    dropout_prob=DROPOUT_PROB
)

trainer = Trainer(
    model=transformer,
    dataset=None,
    loss_object=None,
    optimizer=None,
    checkpoint_dir='./checkpoints'
)
if trainer.checkpoint_manager.latest_checkpoint:
    print("Restored from {}".format(trainer.checkpoint_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

trainer.checkpoint.restore(
    trainer.checkpoint_manager.latest_checkpoint
)


def do_translate(input_data):
    index = input_data[0]
    source = input_data[1][0]
    target = input_data[1][1]
    print(index)
    output = translate(source, data_loader, trainer, SEQ_MAX_LEN_TARGET)[0]  # translate() returns a list of texts
    return {
        'source': source,
        'target': target,
        'output': output
    }


translated_data = []

for test_data in data:
    res = do_translate(test_data)
    translated_data.append(res['output'])

with open('translated_data', 'w') as f:
    f.write('\n'.join(translated_data))

score, report = calculate_bleu_score(target_path='translated_data', ref_path=target_data_path)

6. Distributed training

The code also offers a distributed-training option (the Trainer's multi_gpu_train() method), which is not covered in this post; see 【Keras】TensorFlow分布式训练.
