Attention Is All You Need: the official TensorFlow 1.x implementation

 

https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

1. Set up a CUDA 10.0 environment.

2. Install TensorFlow 1.14.0.

3. Install the tensor2tensor package for Python 3.

4. Example code:

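import os

from tensor2tensor import problems
from tensor2tensor.utils import trainer_lib

# Registered problem, model, and hyperparameter-set names
PROBLEM = "translate_enzh_wmt32k"
MODEL = "transformer"
HPARAMS_SET = "transformer_base"

# "~" is not expanded automatically, so build absolute paths first
data_dir = os.path.expanduser("~/t2t_data")
tmp_dir = os.path.expanduser("~/t2t_tmp")
train_dir = os.path.expanduser("~/t2t_train")
for d in (data_dir, tmp_dir, train_dir):
    os.makedirs(d, exist_ok=True)

# Look up the problem and download/preprocess its dataset
problem = problems.problem(PROBLEM)
problem.generate_data(data_dir, tmp_dir)

# Define the model hyperparameters
hparams = trainer_lib.create_hparams(
    HPARAMS_SET, data_dir=data_dir, problem_name=PROBLEM)

# Create and train the model (the step counts are example values)
run_config = trainer_lib.create_run_config(
    model_name=MODEL, model_dir=train_dir)
experiment = trainer_lib.create_experiment(
    run_config=run_config,
    hparams=hparams,
    model_name=MODEL,
    problem_name=PROBLEM,
    data_dir=data_dir,
    train_steps=1000,
    eval_steps=100)
experiment.train_and_evaluate()

# Inference with the trained model is normally run with the
# t2t-decoder command-line tool (for example on a file containing
# "Hello, how are you?") rather than through a method on a trainer object.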
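
For the inference step, tensor2tensor ships a t2t-decoder command-line tool. The sketch below only assembles that command as a Python list and prints it; it runs nothing itself. The flag names follow the tensor2tensor README. The problem name translate_enzh_wmt32k, the ~/t2t_train checkpoint directory, the file names input.en and output.zh, and the beam_size/alpha values are assumptions chosen for illustration and should be replaced with whatever was used during training.

```python
import os
import shlex


def build_decoder_command(data_dir, train_dir, source_file, output_file,
                          problem="translate_enzh_wmt32k",  # assumed problem name
                          model="transformer",
                          hparams_set="transformer_base",
                          beam_size=4, alpha=0.6):  # example decoding settings
    """Return the argv list for a t2t-decoder run (nothing is executed)."""
    return [
        "t2t-decoder",
        "--data_dir=" + os.path.expanduser(data_dir),
        "--problem=" + problem,
        "--model=" + model,
        "--hparams_set=" + hparams_set,
        # output_dir must be the directory that holds the trained checkpoints
        "--output_dir=" + os.path.expanduser(train_dir),
        "--decoder_hparams=beam_size={},alpha={}".format(beam_size, alpha),
        "--decode_from_file=" + source_file,
        "--decode_to_file=" + output_file,
    ]


# Hypothetical paths: input.en holds one source sentence per line
cmd = build_decoder_command("~/t2t_data", "~/t2t_train",
                            "input.en", "output.zh")

# Print a copy-pasteable shell command
print(" ".join(shlex.quote(arg) for arg in cmd))
```

To actually decode, pass the list to subprocess.run(cmd, check=True) once training has written checkpoints to the training directory; the translations are then written to the file given by --decode_to_file.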
