1. 简介
Tensor2Tensor
是google出品的一个神仙级工具包,能大大简化类似模型的开发调试时间。在众多的深度学习工具中,个人认为这货属于那种门槛还比较高的工具。并且google家的文档一向做得很辣鸡,都是直接看源码注释摸索怎么使用。
Tensor2Tensor的版本和Tensorflow版本是对应的,我电脑上是tensorflow 1.14.0,就这样安装了pip install tensor2tensor==1.14.1
。
2. 基础模块
import os
import tensorflow as tf
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators.text_problems import Text2TextProblem
from tensor2tensor.data_generators.text_problems import VocabType
from tensor2tensor.models import transformer
text_encoder
中预定义了一些把string转化为ids的类型。
problem
不知道怎么说,看官方解释就行了,反正新增加任务都需要自己写个Problem类型。
Problems consist of features such as inputs and targets, and metadata such
as each feature's modality (e.g. symbol, image, audio) and vocabularies. Problem
features are given by a dataset, which is stored as a TFRecord
file with tensorflow.Example
protocol buffers. All
problems are imported in all_problems.py
or are registered with @registry.register_problem
.
这里是直接调用了models里面的transformers
,如果自己该模型,还需要使用@registry.register_model
注册模型。
3. Problems写法
最好看下Text2TextProblem
模块的源码看下google的结构思路,以文本生成任务为例。
# 数据格式
想看你的美照亲我一口就给你看
我亲两口讨厌人家拿小拳拳捶你胸口
......
新建文件my_task.py
@registry.register_problem
class Seq2SeqDemo(Text2TextProblem):
TRAIN_FILES = "train.txt"
EVAL_FILES = "dev.txt"
@property
def vocab_type(self):
# 见父类的说明
return VocabType.TOKEN
@property
def oov_token(self):
return ""
def _generate_vocab(self, tmp_dir):
vocab_list = [self.oov_token]
user_vocab_file = os.path.join(tmp_dir, "vocab.txt")
with tf.gfile.GFile(user_vocab_file, "r") as vocab_file:
for line in vocab_file:
token = line.strip().split("\t")[0]
vocab_list.append(token)
token_encoder = text_encoder.TokenTextEncoder(None, vocab_list=vocab_list)
return token_encoder
def _generate_samples(self, data_dir, tmp_dir, dataset_split):
del data_dir
is_training = dataset_split == problem.DatasetSplit.TRAIN
files = self.TRAIN_FILES if is_training else self.EVAL_FILES
files = os.path.join(tmp_dir, files)
with tf.gfile.GFile(files, "r") as fin:
for line in fin:
inputs, targets = line.strip().split("\t")
yield {"inputs": inputs, "targets": targets}
def generator_samples(self, data_dir, tmp_dir, dataset_split):
vocab_filepath = os.path.join(data_dir, self.vocab_filename)
if not tf.gfile.Exists(vocab_filepath):
token_encoder = self._generate_vocab(tmp_dir)
token_encoder.store_to_file(vocab_filepath)
return self._generate_samples(data_dir, tmp_dir, dataset_split)
tmp_dir
是真实的训练文本和字典存放的地方,data_dir
是处理后的字典和TFRcord存在的地方。
关键就一个方法generator_samples
,它有两个作用,读入字典和转换数据文件方便后面转化为TFRecord
的形式。
其中有个天坑,generator_samples
和_generator_samples
我是故意拆开写的。如果合并了,因为生成器的性质,在没有遍历之前generator_samples
return之前的代码都不会执行。但是注意到父类中有个有个方法generate_encoded_samples
其中有两行:
generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)
generator_samples
如果没有执行到token_encoder = self._generate_vocab(tmp_dir)
,self.get_or_create_vocab
这边就直接炸了,找不到字典。因此,这两个函数不能拆开,problem调用结构就这样,暂时只知道这么改。
最后,还要添加transformer
的模型参数,想了解参数请看transformer.transformer_base
源码(°ー°〃)。
@registry.register_hparams
def my_param():
hparams = transformer.transformer_base()
hparams.summarize_vars = True
hparams.num_hidden_layers = 4
hparams.batch_size = 64
hparams.max_length = 40
hparams.hidden_size = 512
hparams.num_heads = 8
hparams.dropout = 0.1
hparams.attention_dropout = 0.1
hparams.filter_size = 1024
hparams.layer_prepostprocess_dropout = 0.1
hparams.learning_rate_warmup_steps = 1000
hparams.learning_rate_decay_steps = 800
hparams.learning_rate = 3e-5
return hparams
4. 运行
首先生成TFRecod文件,执行命令
t2t_datagen \
--t2t_usr_dir=/code_path (to my_task.py)
--data_dir=/record_data_path
--tmp_dir=/data_path
--problem=Seq2SeqDemo
然后训练
t2t_trainer \
--data_dir=/same_as_above
--problem=Seq2SeqDemo
--model=transformer
--hparams_set=my_param
--output_dir=~/output_dir
--job-dir=~/output_dir
--train_steps=8000
--eval_steps=2000
训练好的模型进行预测(decode)
t2t_decoder \
--data_dir=/same_as_above
--problem=Seq2SeqDemo
--model=transformer
--hparams_set=my_param
--output_dir=~/output_dir
--decode_from_file=/dev_file_path
--decode_to_file=/file_save_path