Tensor2Tensor使用入门

1. 简介

Tensor2Tensor是google出品的一个神仙级工具包,能大大简化类似模型的开发调试时间。在众多的深度学习工具中,个人认为这货属于那种门槛还比较高的工具。并且google家的文档一向做得很辣鸡,都是直接看源码注释摸索怎么使用。

Tensor2Tensor的版本和Tensorflow版本是对应的,我电脑上是tensorflow 1.14.0,就这样安装了pip install tensor2tensor==1.14.1

2. 基础模块

import os
import tensorflow as tf

from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators.text_problems import Text2TextProblem
from tensor2tensor.data_generators.text_problems import VocabType
from tensor2tensor.models import transformer

text_encoder中预定义了一些把string转化为ids的类型。

problem不知道怎么说,看官方解释就行了,反正新增加任务都需要自己写个Problem类型。

Problems consist of features such as inputs and targets, and metadata such
as each feature's modality (e.g. symbol, image, audio) and vocabularies. Problem
features are given by a dataset, which is stored as a TFRecord file with tensorflow.Example protocol buffers. All
problems are imported in all_problems.py or are registered with @registry.register_problem.

这里是直接调用了models里面的transformers,如果自己该模型,还需要使用@registry.register_model注册模型。

3. Problems写法

最好看下Text2TextProblem模块的源码看下google的结构思路,以文本生成任务为例。

# 数据格式
想看你的美照亲我一口就给你看

我亲两口讨厌人家拿小拳拳捶你胸口
......

新建文件my_task.py

@registry.register_problem
class Seq2SeqDemo(Text2TextProblem):
    TRAIN_FILES = "train.txt"
    EVAL_FILES = "dev.txt"

    @property
    def vocab_type(self):
        # 见父类的说明

        return VocabType.TOKEN

    @property
    def oov_token(self):
        return ""

    def _generate_vocab(self, tmp_dir):
        vocab_list = [self.oov_token]
        user_vocab_file = os.path.join(tmp_dir, "vocab.txt")
        with tf.gfile.GFile(user_vocab_file, "r") as vocab_file:
            for line in vocab_file:
                token = line.strip().split("\t")[0]
                vocab_list.append(token)
        token_encoder = text_encoder.TokenTextEncoder(None, vocab_list=vocab_list)
        return token_encoder

    def _generate_samples(self, data_dir, tmp_dir, dataset_split):
        del data_dir
        is_training = dataset_split == problem.DatasetSplit.TRAIN
        files = self.TRAIN_FILES if is_training else self.EVAL_FILES
        files = os.path.join(tmp_dir, files)
        with tf.gfile.GFile(files, "r") as fin:
            for line in fin:
                inputs, targets = line.strip().split("\t")
                yield {"inputs": inputs, "targets": targets}

    def generator_samples(self, data_dir, tmp_dir, dataset_split):
        vocab_filepath = os.path.join(data_dir, self.vocab_filename)
        if not tf.gfile.Exists(vocab_filepath):
            token_encoder = self._generate_vocab(tmp_dir)
            token_encoder.store_to_file(vocab_filepath)
        return self._generate_samples(data_dir, tmp_dir, dataset_split)

tmp_dir是真实的训练文本和字典存放的地方,data_dir是处理后的字典和TFRcord存在的地方。

关键就一个方法generator_samples,它有两个作用,读入字典转换数据文件方便后面转化为TFRecord的形式。

其中有个天坑,generator_samples_generator_samples我是故意拆开写的。如果合并了,因为生成器的性质,在没有遍历之前generator_samplesreturn之前的代码都不会执行。但是注意到父类中有个有个方法generate_encoded_samples其中有两行:

generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)

generator_samples如果没有执行到token_encoder = self._generate_vocab(tmp_dir)self.get_or_create_vocab这边就直接炸了,找不到字典。因此,这两个函数不能拆开,problem调用结构就这样,暂时只知道这么改。

最后,还要添加transformer的模型参数,想了解参数请看transformer.transformer_base源码(°ー°〃)。

@registry.register_hparams
def my_param():
    hparams = transformer.transformer_base()
    hparams.summarize_vars = True
    hparams.num_hidden_layers = 4
    hparams.batch_size = 64
    hparams.max_length = 40
    hparams.hidden_size = 512
    hparams.num_heads = 8
    hparams.dropout = 0.1
    hparams.attention_dropout = 0.1
    hparams.filter_size = 1024
    hparams.layer_prepostprocess_dropout = 0.1
    hparams.learning_rate_warmup_steps = 1000
    hparams.learning_rate_decay_steps = 800
    hparams.learning_rate = 3e-5
    return hparams

4. 运行

首先生成TFRecod文件,执行命令

t2t_datagen \
    --t2t_usr_dir=/code_path (to my_task.py)
    --data_dir=/record_data_path
    --tmp_dir=/data_path
    --problem=Seq2SeqDemo

然后训练

t2t_trainer \
    --data_dir=/same_as_above
    --problem=Seq2SeqDemo
    --model=transformer
    --hparams_set=my_param
    --output_dir=~/output_dir
    --job-dir=~/output_dir
    --train_steps=8000
    --eval_steps=2000

训练好的模型进行预测(decode)

t2t_decoder \
    --data_dir=/same_as_above
    --problem=Seq2SeqDemo
    --model=transformer
    --hparams_set=my_param
    --output_dir=~/output_dir
    --decode_from_file=/dev_file_path
    --decode_to_file=/file_save_path

你可能感兴趣的:(Tensor2Tensor使用入门)