由于tensor2tensor高度的封装,内部添加和一些数据集,和一些常见的问题,所以在直接用起来比较方便。但是如果想要用不同的数据训练模型,或者是用模型解决一个其他的问题,就要费一番功夫了。
这里主要是解决了用自己的数据集,使用tensor2tensor训练一个英中翻译模型,当然训练中英,只需要加上`_rev`即可。
如果要使用自己的数据集,根据github文档可知,它是没有告诉你怎么做滴,那么怎么办呢,就不做了么? 当然不可能。文档中提到了可以定义自己的问题,然后在里面可以定义一些内容,例如单词表大小,数据集的位置,分词方式等。嗯,看到了,是有数据集的位置的,那么直接定义一个问题不就可以么,然后在里面指定相应的数据集位置,那么看下代码:
完整代码在后面:
首先是要有一个自定义的用户目录,也就是参数‘--usr_dir ’ 的值。
接下来,创建一个 problem_name.py 文件,并且里面有__init__.py 这个文件,并且在init.py 中把problem_name 导入,这样才能够被`t2t-datagen`和`t2t-trainer`识别,并注册到t2t里面。就像下面这样。
在创建完文件之后就要对文件的内容进行编写了。
一些导入文件的代码略过(篇幅有限)
然后两个数据集:
_NC_TRAIN_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.tgz",
["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]
_NC_TEST_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]
上面代码:重要的也就是这两个数据集了:其中一个是训练集, 一个是测试集,开发集程序内部会进行分割,这里就不考虑。
首先是列表内容元素的第一个链接指的是元素的位置,也就是网络位置,由于我们要是用的是本地的文件,这里就是一个僵尸文件,也就是一个虚拟地址+僵尸压缩文件。主要作用是避免内部生成单词表和数据的时候进行数据的下载。
后面一个"raw-train.zh-en.en", "raw-train.zh-en.zh" 也就是平行语料,也就是自己的数据集文件,这里面的文件只要是处理干净就行了,关于分词的话,谷歌内部的新的分词方式subword基本能满足使用,某些论文中甚至要优于bpe分词方式。
def create_dummy_tar(tmp_dir, dummy_file_name):
dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
if not os.path.exists(dummy_file_path):
tf.logging.info("Generating dummy file: %s", dummy_file_path)
tar_dummy = tarfile.open(dummy_file_path, "w:gz")
tar_dummy.close()
tf.logging.info("File %s is already exists or created", dummy_file_name)
上面函数主要是为了防止t2t的数据生成工具进行下载,而创建僵尸压缩文件。对于每一个数据集都会进行检查。
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
train = dataset_split == problem.DatasetSplit.TRAIN
train_dataset = self.get_training_dataset(tmp_dir)
datasets = train_dataset if train else _NC_TEST_DATASETS
for item in datasets:
dummy_file_name = item[0].split("/")[-1]
create_dummy_tar(tmp_dir, dummy_file_name)
s_file, t_file = item[1][0], item[1][1]
if not os.path.exists(os.path.join(tmp_dir, s_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
if not os.path.exists(os.path.join(tmp_dir, t_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)
source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
...
return text_problems.text2text_generate_encoded(
text_problems.text2text_txt_iterator(data_path + ".lang1",
data_path + ".lang2"),
source_vocab, target_vocab)
上面函数主要是生成样本数据,也就是在data文件夹下面的一些数据,同样如果在data目录下面没有单词表文件的话,会根据数据集生成单词表文件。
至此,基本已经完成了所有操作,只需要用 t2t-datagen 和 t2t-trainer 生成数据并进行训练即可!
另外,提一下,自定义的类名应该是驼峰法命名,定义的问题对应根据驼峰规则用横线隔开,例如这里我定义的是:translate_enzh_sub32k,对应类名 TranslateEnzhSub32k。
~ ~
辣么,如何使用已经有了单词表,平行语料之后应该如何定义问题呢,见下一篇:
https://blog.csdn.net/hpulfc/article/details/82625217
完成代码:
# coding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import tokenizer
from tensor2tensor.utils import registry
import tensorflow as tf
from collections import defaultdict
_NC_TRAIN_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.tgz",
["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]
_NC_TEST_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]
def create_dummy_tar(tmp_dir, dummy_file_name):
dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
if not os.path.exists(dummy_file_path):
tf.logging.info("Generating dummy file: %s", dummy_file_path)
tar_dummy = tarfile.open(dummy_file_path, "w:gz")
tar_dummy.close()
tf.logging.info("File %s is already exists or created", dummy_file_name)
def get_filename(dataset):
return dataset[0][0].split("/")[-1]
@registry.register_problem
class TranslateEnzhSub32k(translate.TranslateProblem):
"""Problem spec for WMT En-De translation, BPE version."""
# 设定单词表生成大小
@property
def vocab_size(self):
return 32000
# 使用 bpe 进行分词
# @property
# def vocab_type(self):
# return text_problems.VocabType.TOKEN
# 超过单词表之后的词的表示,None 表示用元字符替换
@property
def oov_token(self):
"""Out of vocabulary token. Only for VocabType.TOKEN."""
return None
@property
def approx_vocab_size(self):
return 32000
@property
def source_vocab_name(self):
return "vocab.enzh-sub-en.%d" % self.approx_vocab_size
@property
def target_vocab_name(self):
return "vocab.enzh-sub-zh.%d" % self.approx_vocab_size
def get_training_dataset(self, tmp_dir):
full_dataset = _NC_TRAIN_DATASETS
# 可以添加一些其他的数据集在这里
return full_dataset
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
train = dataset_split == problem.DatasetSplit.TRAIN
train_dataset = self.get_training_dataset(tmp_dir)
datasets = train_dataset if train else _NC_TEST_DATASETS
for item in datasets:
dummy_file_name = item[0].split("/")[-1]
create_dummy_tar(tmp_dir, dummy_file_name)
s_file, t_file = item[1][0], item[1][1]
if not os.path.exists(os.path.join(tmp_dir, s_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
if not os.path.exists(os.path.join(tmp_dir, t_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)
source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
source_vocab = generator_utils.get_or_generate_vocab(
data_dir,
tmp_dir,
self.source_vocab_name,
self.approx_vocab_size,
source_datasets,
file_byte_budget=1e8)
target_vocab = generator_utils.get_or_generate_vocab(
data_dir,
tmp_dir,
self.target_vocab_name,
self.approx_vocab_size,
target_datasets,
file_byte_budget=1e8)
tag = "train" if train else "dev"
filename_base = "wmt_enzh_%sk_sub_%s" % (self.approx_vocab_size, tag)
data_path = translate.compile_data(tmp_dir, datasets, filename_base)
return text_problems.text2text_generate_encoded(
text_problems.text2text_txt_iterator(data_path + ".lang1",
data_path + ".lang2"),
source_vocab, target_vocab)
def feature_encoders(self, data_dir):
source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
return {
"inputs": source_token,
"targets": target_token,
}
vx:hpulfc