tensorflow_Trax_transformer Usage Example

Data Preparation

Trax does not come with ready-made data preprocessing scripts, so you have to write the preprocessing step yourself. Here I simply reuse the TFRecord data generated in the tensorflow_official_nlp_transformer usage example.
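
The expected record format can be inferred from the parsing code below: each tf.Example stores the already-tokenized source and target IDs as int64 lists under the "inputs" and "targets" keys. As a minimal sketch of how such files could be produced yourself (my own illustration, not part of the original pipeline; write_examples and its arguments are hypothetical):

import tensorflow as tf

def write_examples(pairs, filename):
    """Write (input_ids, target_ids) pairs of token-ID lists to a TFRecord file."""
    with tf.io.TFRecordWriter(filename) as writer:
        for input_ids, target_ids in pairs:
            feature = {
                "inputs": tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)),
                "targets": tf.train.Feature(int64_list=tf.train.Int64List(value=target_ids)),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())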

import os
import tensorflow as tf
import trax

# Load the training corpus
batch_size = 8
max_length = 100
static_batch = True
model_dir = './data_dir/trax_nlp/train_dir/'
_READ_RECORD_BUFFER = 8*1000*1000

def _load_records(filename):
    """Read file and return a dataset of tf.Examples."""
    return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)

def _parse_example(serialized_example):
    """Return inputs and targets Tensors from a serialized tf.Example."""
    data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64)
    }
    parsed = tf.io.parse_single_example(serialized_example, data_fields)
    inputs = tf.sparse.to_dense(parsed["inputs"])
    targets = tf.sparse.to_dense(parsed["targets"])
    return inputs, targets

# Load the training data files
file_pattern = os.path.join(model_dir, "*train*")
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
dataset = dataset.interleave(
  _load_records,
  cycle_length=2)
dataset = dataset.map(_parse_example)

# Pad and split into batches
dataset = dataset.padded_batch(
    batch_size, ([max_length], [max_length]), drop_remainder=True)

dataset = dataset.repeat(2)

# Convert to the input format expected by trax.supervised.Trainer
def copy_task():
    for x, y in dataset:
        yield (x.numpy(), y.numpy())
copy_inputs = trax.supervised.Inputs(lambda _: copy_task())
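
As a quick sanity check (my own addition, not from the original post), you can pull one batch from the generator and confirm that the padded shapes match (batch_size, max_length). Each call to copy_task() creates a fresh iterator over the dataset, so this does not affect the stream the Trainer will use later:

sample_x, sample_y = next(copy_task())
print(sample_x.shape, sample_y.shape)  # expected: (8, 100) (8, 100)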

Model Construction

# Transformer 
def transformer(mode):
    return trax.models.Transformer(   # You can try trax_models.ReformerLM too.
        d_model=512, d_ff=2048,
        n_encoder_layers=6, n_decoder_layers=6,
        input_vocab_size=1700, mode=mode)
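
# (My own sanity-check sketch, not in the original post.) A dummy forward pass
# makes the model's expected input format explicit: a pair of int arrays of
# token IDs shaped (batch_size, max_length). Depending on the Trax version,
# the output is the per-token predictions alone or together with a copy of
# the targets used by the loss.
import numpy as np

sanity_model = transformer(mode='train')
dummy_batch = (np.ones((batch_size, max_length), dtype=np.int32),
               np.ones((batch_size, max_length), dtype=np.int32))
sanity_model.init(trax.shapes.signature(dummy_batch))
out = sanity_model(dummy_batch)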

# Train model with Trainer.
output_dir = os.path.expanduser('./train_dir/')

trainer = trax.supervised.Trainer(
    model=transformer,
    loss_fn=trax.layers.CrossEntropyLoss, # updated to CrossEntropyLoss() on GitHub
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.
    inputs=copy_inputs,
    output_dir=output_dir,
    has_weights=False)  # this argument has since been removed on GitHub
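
For reference, the inline comments above point at the newer Trax API on GitHub: there the loss is passed as an instance and has_weights is gone. Roughly (my adaptation based on those comments, not tested against a specific version):

trainer = trax.supervised.Trainer(
    model=transformer,
    loss_fn=trax.layers.CrossEntropyLoss(),   # instantiated in newer Trax
    optimizer=trax.optimizers.Adafactor,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=copy_inputs,
    output_dir=output_dir)                    # has_weights removed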

Model Training

# Example: just a short training run
n_epochs  = 3
train_steps = 100
eval_steps = 2
for _ in range(n_epochs):
    trainer.train_epoch(train_steps, eval_steps)

The training process is shown in the figure below. As you can see, the model's first compilation is extremely time-consuming, and that is with a very small vocab_size (1700); with a typical vocabulary of 30k+ tokens it takes even longer.
[Figure 1: training process log]

Model Prediction

To be updated later…
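
Until then, here is a rough, untested sketch (entirely my own assumption, not from the original post) of what decoding could look like with a recent Trax, using its autoregressive sampling helper and assuming the Trainer wrote a checkpoint named model.pkl into output_dir:

# Rebuild the model in 'predict' mode and load the trained weights
# (the checkpoint file name is an assumption).
predict_model = transformer(mode='predict')
predict_model.init_from_file(os.path.join(output_dir, 'model.pkl'),
                             weights_only=True)

# Take one tokenized source sequence and decode greedily (temperature=0.0).
src = next(copy_task())[0][:1]
decoded = trax.supervised.decoding.autoregressive_sample(
    predict_model, inputs=src, eos_id=1, max_length=max_length, temperature=0.0)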
