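Generating Poetry with ERNIE-GEN, a PaddleNLP Pretrained Model
Poetry is a treasure of Chinese culture. It carries the author's thoughts and feelings and a rich imagination. Its language is concise and vivid, with a distinct rhythm and harmonious rhyme that give it a musical beauty.
Poems are usually laid out line by line, and their structure and form are part of their beauty. Classical Chinese poetry can be divided into ancient-style poetry (古体诗) and modern-style poetry (近体诗). Ancient-style poetry includes pre-Tang poems (古诗), Chu Ci (楚辞) and yuefu poems (乐府诗); poems in genres titled "歌", "歌行", "引", "曲", "吟" and similar also count as ancient-style. Ancient-style poetry does not require parallelism, and its rhyming is relatively free. Modern-style poetry, also called 今体诗, is a form of regulated verse that took shape in the Tang dynasty, with strict rules on the number of characters and lines, tonal patterns (平仄) and rhyme. It comes in two kinds. The first is the jueju (绝句, quatrain), with four lines per poem: the five-character form is called wujue (五绝) and the seven-character form qijue (七绝). The second is the lüshi (律诗, regulated verse), with eight lines per poem: the five-character form is called wulü (五律) and the seven-character form qilü (七律); poems longer than eight lines are called pailü (排律, also 长律).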
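The form rules above can be turned into a small check on line count and characters per line. The sketch below is a hypothetical helper (classify_regulated_verse is not part of PaddleNLP). It assumes lines are delimited by Chinese or ASCII punctuation and deliberately ignores tonal patterns, rhyme and parallelism, so it only recognises the shape of a poem, not whether it is truly regulated verse.

```python
import re

# Line delimiters: full-width and ASCII commas, the Chinese full stop,
# exclamation marks, question marks and semicolons.
_LINE_SPLIT = re.compile(r"[,,。!!??;;]")


def classify_regulated_verse(poem: str) -> str:
    """Guess the modern-style form of a poem from its shape only.

    Hypothetical helper: tonal patterns (平仄), rhyme and parallelism are ignored.
    Returns "五绝", "七绝", "五律", "七律", "排律", or "other".
    """
    lines = [s.strip() for s in _LINE_SPLIT.split(poem) if s.strip()]
    lengths = {len(s) for s in lines}
    # Every line must have the same number of characters, either 5 or 7.
    if len(lengths) != 1:
        return "other"
    chars = lengths.pop()
    if chars not in (5, 7):
        return "other"
    prefix = "五" if chars == 5 else "七"
    if len(lines) == 4:
        return prefix + "绝"  # quatrain (jueju)
    if len(lines) == 8:
        return prefix + "律"  # regulated verse (lüshi)
    if len(lines) > 8:
        return "排律"  # extended regulated verse
    return "other"
```

For instance, the training sample printed later in this tutorial (画精禅室冷, followed by its continuation) has eight five-character lines, so this check labels it 五律.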
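Poems have traditionally been composed by literati through the ages. This article instead focuses on generating poetry automatically, and shows how to use PaddleNLP with the ERNIE-GEN model to continue a poem from its opening lines.
Figure 1: Poetry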
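Introduction to ERNIE-GEN
ERNIE-GEN is a pre-training and fine-tuning framework for generation tasks. It was the first to add a span-by-span generation task to pre-training, so that the model learns to generate a semantically complete span at each step. In both pre-training and fine-tuning, it uses an infilling generation mechanism and a noise-aware generation mechanism to alleviate exposure bias. In addition, ERNIE-GEN adopts a multi-span, multi-granularity target sampling strategy, which strengthens the correlation between source and target text and the interaction between encoder and decoder. With these strategies, ERNIE-GEN achieved state-of-the-art results on several generation tasks.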
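To make the noise-aware mechanism more concrete: during training, some tokens of the target sequence fed to the decoder are replaced with random tokens while the labels keep the original tokens, so the model learns to keep generating correctly even when its context contains mistakes. The NumPy sketch below is only an illustration under assumptions: the helper name add_target_noise, the uniform sampling over the vocabulary and the noise rate are hypothetical choices, and the training code in this tutorial does not apply such noise.

```python
import numpy as np


def add_target_noise(tgt_ids, noise_rate, vocab_size, special_ids, rng):
    """Return a noised copy of the decoder input ids (labels are left untouched).

    Each non-special position is replaced, with probability noise_rate, by an id
    drawn uniformly from [0, vocab_size). Hypothetical helper for illustration.
    """
    tgt_ids = np.asarray(tgt_ids)
    noisy = tgt_ids.copy()
    # Never corrupt special tokens such as [CLS] and [SEP].
    candidates = ~np.isin(tgt_ids, list(special_ids))
    flip = (rng.random(tgt_ids.shape) < noise_rate) & candidates
    noisy[flip] = rng.integers(0, vocab_size, size=int(flip.sum()))
    return noisy


rng = np.random.default_rng(0)
# Toy ids; 1 and 2 stand in for [CLS] and [SEP] in this example only.
labels = np.array([1, 57, 88, 903, 12, 2])
noisy_inputs = add_target_noise(labels, noise_rate=0.3, vocab_size=18000,
                                special_ids={1, 2}, rng=rng)
```

The noisy ids would be used as the decoder input while labels stays as the prediction target; 18000 matches the vocab_size printed in the model config log below.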
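Figure 2: ERNIE-GEN overview
For more details, see the paper ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation.
PaddleNLP currently supports three generation models, ernie-gen-base-en, ernie-gen-large-en and ernie-gen-large-en-430g, and it can also warm-start from the parameters of any of PaddleNLP's non-generation pretrained transformer models. Since this article generates classical Chinese poetry, we warm-start from the Chinese ernie-1.0 model.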
In [1]
!pip install --upgrade paddlenlp -i https://pypi.org/simple
Successfully installed paddlenlp-2.0.0rc12
In [2]
import paddle
import paddlenlp
from paddlenlp.transformers import ErnieForGeneration
paddle.set_device('gpu')
model = ErnieForGeneration.from_pretrained("ernie-1.0")
[2021-03-11 19:56:55,093] [ INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-1.0
2021-03-11 19:56:55,096 - INFO - unique_endpoints {''}
2021-03-11 19:56:55,097 - INFO - Downloading ernie_v1_chn_base.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams
100%|██████████| 390123/390123 [00:07<00:00, 54568.17it/s]
[2021-03-11 19:57:02,362] [ DEBUG] - init ErnieModel with config: {'attention_probs_dropout_prob': 0.1, 'hidden_act': 'relu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'max_position_embeddings': 513, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'type_vocab_size': 2, 'vocab_size': 18000, 'pad_token_id': 0}
[2021-03-11 19:57:06,204] [ INFO] - loading pretrained model from /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams
[2021-03-11 19:57:06,752] [ INFO] - param:mlm_bias not set in pretrained model, skip
[2021-03-11 19:57:06,757] [ INFO] - param:mlm.weight not set in pretrained model, skip
[2021-03-11 19:57:06,759] [ INFO] - param:mlm.bias not set in pretrained model, skip
[2021-03-11 19:57:06,761] [ INFO] - param:mlm_ln.weight not set in pretrained model, skip
[2021-03-11 19:57:06,763] [ INFO] - param:mlm_ln.bias not set in pretrained model, skip
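Fine-tuning takes a long time, so to let you try the model quickly we provide an already fine-tuned checkpoint. If you want to fine-tune from scratch, comment out the code in this cell.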
In [3]
!wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_gen_finetuned/ernie_1.0_poetry.pdparams
# Load the saved parameters as follows; training below then continues from this checkpoint (incremental training)
init_checkpoint = "ernie_1.0_poetry.pdparams"
model_state = paddle.load(init_checkpoint)
model.set_state_dict(model_state)
--2021-03-11 19:57:10-- https://paddlenlp.bj.bcebos.com/models/transformers/ernie_gen_finetuned/ernie_1.0_poetry.pdparams
Resolving paddlenlp.bj.bcebos.com (paddlenlp.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c00:6c21:10ad:0:ff:b00e:67d
Connecting to paddlenlp.bj.bcebos.com (paddlenlp.bj.bcebos.com)|182.61.200.229|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 598670057 (571M) [application/octet-stream]
Saving to: ‘ernie_1.0_poetry.pdparams.2’
ernie_1.0_poetry.pd 100%[===================>] 570.94M 76.9MB/s in 7.1s
2021-03-11 19:57:17 (80.8 MB/s) - ‘ernie_1.0_poetry.pdparams.2’ saved [598670057/598670057]
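Dataset
The data comes from the 3 million lines of Tang and Song poetry open-sourced in the chinese-poetry project. For each poem, the first two lines are used as the model input and the remaining lines as the output, with the special character "\t" separating input from output. To keep the tokenizer from grouping characters into words, the special character "\002" is also inserted between every pair of characters as a separator.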
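The raw line format described above can be illustrated with two small helpers. This is a sketch that assumes one example per line containing exactly one "\t"; make_poetry_line and parse_poetry_line are hypothetical names, and the built-in loader used below already handles the real files.

```python
SEP_IO = "\t"      # separates the input lines from the output lines
SEP_CHAR = "\002"  # inserted between characters so the tokenizer does not merge them into words


def make_poetry_line(src: str, tgt: str) -> str:
    """Build one raw example in the format described above (hypothetical helper)."""
    return SEP_CHAR.join(src) + SEP_IO + SEP_CHAR.join(tgt)


def parse_poetry_line(line: str):
    """Split a raw example back into its source and target characters."""
    src, tgt = line.rstrip("\n").split(SEP_IO)
    return src.split(SEP_CHAR), tgt.split(SEP_CHAR)


line = make_poetry_line("画精禅室冷,方暑久徘徊。", "不尽林端雪,长青石上苔。")
src_chars, tgt_chars = parse_poetry_line(line)
```

Because every character, punctuation included, becomes its own item, the tokenizer sees single characters rather than multi-character words.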
PaddleNLP ships this dataset built in, so it can be loaded with a single call.
In [19]
from paddlenlp.datasets import load_dataset
train_dataset, dev_dataset = load_dataset('poetry', splits=('train', 'dev'), lazy=False)
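Each input is a poem's opening couplet (its first two lines). The outputs vary in length and are generated one character at a time until the end-of-sequence token is produced.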
In [20]
# Example
print(train_dataset[0]['tokens'])
print(train_dataset[0]['labels'])
画精禅室冷,方暑久徘徊。
不尽林端雪,长青石上苔。心闲对岩岫,目浄失尘埃。坐久清风至,疑从翠涧来。
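Data Preprocessing
This stage converts the raw data into a format the model can read.
ERNIE-GEN's input is similar to BERT's: we need a tokenizer that converts plain text into the corresponding IDs.
PaddleNLP provides a built-in ErnieTokenizer; its encode method directly returns the input_ids and the segment IDs (under the key token_type_ids).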
In [21]
from copy import deepcopy
import numpy as np
from paddlenlp.transformers import ErnieTokenizer
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
# ERNIE-GEN fills the prediction slots with an [ATTN] token. ERNIE 1.0 has no such token, so we use [MASK] instead
attn_id = tokenizer.vocab['[MASK]']
tgt_type_id = 1
# Maximum input and output lengths
max_encode_len = 24
max_decode_len = 72
def convert_example(example):
    """convert an example into necessary features"""
    encoded_src = tokenizer.encode(
        example['tokens'], max_seq_len=max_encode_len, pad_to_max_seq_len=False)
    src_ids, src_sids = encoded_src["input_ids"], encoded_src["token_type_ids"]
    src_pids = np.arange(len(src_ids))
    encoded_tgt = tokenizer.encode(
        example['labels'],
        max_seq_len=max_decode_len,
        pad_to_max_seq_len=False)
    tgt_ids, tgt_sids = encoded_tgt["input_ids"], encoded_tgt[
        "token_type_ids"]
    tgt_ids = np.array(tgt_ids)
    tgt_sids = np.array(tgt_sids) + tgt_type_id
    # Target positions continue right after the source positions
    tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)
    # One [ATTN] ([MASK]) slot per target token
    attn_ids = np.ones_like(tgt_ids) * attn_id
    tgt_labels = tgt_ids
    return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids,
            attn_ids, tgt_labels)
# Apply the preprocessing to both splits
train_dataset = train_dataset.map(convert_example)
dev_dataset = dev_dataset.map(convert_example)
[2021-03-11 20:06:24,980] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
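To see what convert_example produces without downloading the tokenizer, the sketch below repeats its ID bookkeeping on already-encoded toy IDs. It assumes, as the code above relies on, that encode returns all-zero token_type_ids for a single text, so the target segment IDs all become tgt_type_id; build_features and the toy IDs are hypothetical. Note that each [MASK] slot in attn_ids shares the position ID of the target token it predicts, which is why the training loop passes tgt_pids to both the target stream and the attention stream.

```python
import numpy as np


def build_features(src_ids, tgt_ids, attn_id, tgt_type_id=1):
    """Mirror convert_example's id bookkeeping on already-encoded ids (hypothetical helper)."""
    src_ids = np.asarray(src_ids)
    tgt_ids = np.asarray(tgt_ids)
    src_pids = np.arange(len(src_ids))                 # positions 0 .. len(src)-1
    src_sids = np.zeros_like(src_ids)                  # segment 0 for the source
    tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)  # positions continue after the source
    tgt_sids = np.zeros_like(tgt_ids) + tgt_type_id    # segment tgt_type_id for the target
    attn_ids = np.full_like(tgt_ids, attn_id)          # one [ATTN]/[MASK] slot per target token
    tgt_labels = tgt_ids                               # the model predicts the target tokens themselves
    return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids,
            attn_ids, tgt_labels)


# Toy example: a 4-token source and a 5-token target; 3 stands in for the [MASK] id.
features = build_features([1, 10, 11, 2], [1, 20, 21, 22, 2], attn_id=3)
```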
Next, we group the examples into batches and build the attention mask matrices that ERNIE-GEN additionally requires.
In [23]
from paddle.io import DataLoader
from paddlenlp.data import Stack, Tuple, Pad
# Build a [batch, query_len, key_len] visibility mask over batch_ids (1 = visible)
def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
    if query_len is None:
        query_len = batch_ids.shape[1]
    if mask_type != 'empty':
        mask = (batch_ids != pad_value).astype(np.float32)
        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
        if mask_type == 'causal':
            assert query_len == batch_ids.shape[1]
            mask = np.tril(mask)
        elif mask_type == 'causal_without_diag':
            assert query_len == batch_ids.shape[1]
            mask = np.tril(mask, -1)
        elif mask_type == 'diag':
            assert query_len == batch_ids.shape[1]
            mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
    else:
        # mask_type == 'empty': nothing is visible
        mask = np.zeros_like(batch_ids).astype(np.float32)
        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
    return mask
def after_padding(args):
    '''
    attention mask:
    ***  src,  tgt, attn
    src  00,   01,   02
    tgt  10,   11,   12
    attn 20,   21,   22
    ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
    s1    1,  1  | 0, 0, 0,| 0, 0, 0,
    s2    1,  1  | 0, 0, 0,| 0, 0, 0,
    -
    t1    1,  1, | 1, 0, 0,| 0, 0, 0,
    t2    1,  1, | 1, 1, 0,| 0, 0, 0,
    t3    1,  1, | 1, 1, 1,| 0, 0, 0,
    -
    attn1 1,  1, | 0, 0, 0,| 1, 0, 0,
    attn2 1,  1, | 1, 0, 0,| 0, 1, 0,
    attn3 1,  1, | 1, 1, 0,| 0, 0, 1,
    for details, see Fig3. https://arxiv.org/abs/2001.11314
    '''
    src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels = args
    src_len = src_ids.shape[1]
    tgt_len = tgt_ids.shape[1]
    mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
    mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
    mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)
    mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
    mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
    mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)
    mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
    mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
    mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)
    mask_src_2_src = mask_00
    mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
    mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)
    raw_tgt_labels = deepcopy(tgt_labels)
    # Flatten the non-padding labels in row-major order
    tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
    return (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
            mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
            tgt_labels, raw_tgt_labels)
# fn pads each field returned by convert_example at its own position; after_padding then builds the attention masks
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # src_sids
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # tgt_sids
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
): after_padding(fn(samples))
batch_size = 48
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=batchify_fn,
    return_list=True)
dev_data_loader = DataLoader(
    dataset=dev_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=batchify_fn,
    return_list=True)
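The diagram in the after_padding docstring can be checked directly. The NumPy sketch below builds the full (source + target + attention) visibility matrix for a single unpadded example with 2 source and 3 target tokens; ernie_gen_mask is a hypothetical helper. after_padding builds the same three row blocks (mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn) separately per batch and additionally hides padding.

```python
import numpy as np


def ernie_gen_mask(src_len, tgt_len):
    """Full visibility matrix over [src | tgt | attn] for one unpadded example (1 = visible)."""
    s, t = src_len, tgt_len
    n = s + 2 * t
    mask = np.zeros((n, n), dtype=np.int64)
    mask[:s, :s] = 1                                                  # src -> src: bidirectional
    mask[s:s + t, :s] = 1                                             # tgt -> src
    mask[s:s + t, s:s + t] = np.tril(np.ones((t, t), dtype=np.int64))      # tgt -> tgt: causal
    mask[s + t:, :s] = 1                                              # attn -> src
    mask[s + t:, s:s + t] = np.tril(np.ones((t, t), dtype=np.int64), -1)   # attn -> earlier tgt only
    mask[s + t:, s + t:] = np.eye(t, dtype=np.int64)                  # attn -> itself only
    return mask


full = ernie_gen_mask(2, 3)
```

Each attention slot sees the whole source, the target tokens strictly before the one it predicts, and itself, which is exactly what lets it predict a target token without seeing that token.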
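Optimizer
Here we create the optimizer with a learning rate that first rises linearly (warmup) and then decays linearly, which helps the model converge.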
In [24]
import paddle.nn as nn
num_epochs = 1
learning_rate = 2e-5
warmup_proportion = 0.1
weight_decay = 0.1
max_steps = (len(train_data_loader) * num_epochs)
# Linear warmup over the first warmup_proportion of the steps, then linear decay to 0
lr_scheduler = paddle.optimizer.lr.LambdaDecay(
    learning_rate,
    lambda current_step, num_warmup_steps=max_steps * warmup_proportion,
    num_training_steps=max_steps: float(
        current_step) / float(max(1, num_warmup_steps))
    if current_step < num_warmup_steps else max(
        0.0,
        float(num_training_steps - current_step) / float(
            max(1, num_training_steps - num_warmup_steps))))
# Apply weight decay to every parameter whose name contains neither "bias" nor "norm"
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    grad_clip=nn.ClipGradByGlobalNorm(1.0),
    apply_decay_param_fun=lambda x: x in decay_params)
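The schedule and the weight-decay filter above can be written as plain Python functions, which makes them easy to inspect without a model. This is a sketch that mirrors the lambda passed to LambdaDecay (multiplied by the base learning rate) and the name filter; linear_warmup_decay and uses_weight_decay are hypothetical helper names.

```python
def linear_warmup_decay(step, max_steps, warmup_proportion=0.1, base_lr=2e-5):
    """Learning rate at a given step: linear warmup to base_lr, then linear decay to 0."""
    warmup_steps = max_steps * warmup_proportion
    if step < warmup_steps:
        factor = step / max(1, warmup_steps)
    else:
        factor = max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))
    return base_lr * factor


def uses_weight_decay(param_name):
    """Same substring filter as above: names containing "bias" or "norm" are not decayed."""
    return not any(nd in param_name for nd in ["bias", "norm"])


# With 1000 steps: 0 at step 0, peak 2e-5 at step 100, back to 0 at step 1000.
schedule = [linear_warmup_decay(s, 1000) for s in (0, 50, 100, 550, 1000)]
```

Because the filter matches substrings, a name such as mlm_ln.weight (which appears in the loading log above) contains neither "bias" nor "norm", so under this rule it would still be decayed.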
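Training
Once everything is ready, we can feed the data to the model and keep updating its parameters. During training, the logger object provided by PaddleNLP prints logs with timestamps.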
In [25]
import os
import time
from paddlenlp.utils.log import logger
global_step = 1
logging_steps = 100
save_steps = 1000
output_dir = "save_dir"
tic_train = time.time()
for epoch in range(num_epochs):
    for step, batch in enumerate(train_data_loader, start=1):
        (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
         mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
         tgt_labels, _) = batch
        # 1) Encode the source and keep its key/value caches
        _, __, info = model(
            src_ids,
            sent_ids=src_sids,
            pos_ids=src_pids,
            attn_bias=mask_src_2_src,
            encode_only=True)
        cached_k, cached_v = info['caches']
        # 2) Encode the target (teacher forcing), attending to the cached source
        _, __, info = model(
            tgt_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_tgt_2_srctgt,
            past_cache=(cached_k, cached_v),
            encode_only=True)
        cached_k2, cached_v2 = info['caches']
        past_cache_k = [
            paddle.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
        ]
        past_cache_v = [
            paddle.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
        ]
        # 3) The [ATTN] ([MASK]) stream attends to source + target and predicts every target token
        loss, _, __ = model(
            attn_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_attn_2_srctgtattn,
            past_cache=(past_cache_k, past_cache_v),
            tgt_labels=tgt_labels,
            tgt_pos=paddle.nonzero(attn_ids == attn_id))
        if global_step % logging_steps == 0:
            logger.info(
                "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                % (global_step, epoch, step, loss, logging_steps /
                   (time.time() - tic_train), lr_scheduler.get_lr()))
            tic_train = time.time()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_gradients()
        if global_step % save_steps == 0:
            # Use a separate variable so checkpoints are not nested inside one another
            save_path = os.path.join(output_dir, "model_%d" % global_step)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
        global_step += 1
[2021-03-11 20:08:20,907] [ INFO] - global step 100, epoch: 0, batch: 100, loss: 3.198257, speed: 1.40 step/s, lr: 3.226e-06
[2021-03-11 20:09:32,827] [ INFO] - global step 200, epoch: 0, batch: 200, loss: 3.199087, speed: 1.39 step/s, lr: 6.484e-06
[2021-03-11 20:10:45,207] [ INFO] - global step 300, epoch: 0, batch: 300, loss: 2.987470, speed: 1.38 step/s, lr: 9.743e-06
[2021-03-11 20:11:56,597] [ INFO] - global step 400, epoch: 0, batch: 400, loss: 3.062049, speed: 1.40 step/s, lr: 1.300e-05
[2021-03-11 20:13:08,250] [ INFO] - global step 500, epoch: 0, batch: 500, loss: 3.201353, speed: 1.40 step/s, lr: 1.626e-05
[2021-03-11 20:14:19,480] [ INFO] - global step 600, epoch: 0, batch: 600, loss: 3.033472, speed: 1.40 step/s, lr: 1.952e-05
[2021-03-11 20:15:30,855] [ INFO] - global step 700, epoch: 0, batch: 700, loss: 3.281552, speed: 1.40 step/s, lr: 1.969e-05
[2021-03-11 20:16:42,939] [ INFO] - global step 800, epoch: 0, batch: 800, loss: 3.087678, speed: 1.39 step/s, lr: 1.933e-05
[2021-03-11 20:17:54,702] [ INFO] - global step 900, epoch: 0, batch: 900, loss: 3.265067, speed: 1.39 step/s, lr: 1.897e-05
[2021-03-11 20:19:07,817] [ INFO] - global step 1000, epoch: 0, batch: 1000, loss: 3.331526, speed: 1.37 step/s, lr: 1.861e-05
[2021-03-11 20:20:29,734] [ INFO] - global step 1100, epoch: 0, batch: 1100, loss: 3.248009, speed: 1.22 step/s, lr: 1.824e-05
[2021-03-11 20:21:41,574] [ INFO] - global step 1200, epoch: 0, batch: 1200, loss: 3.086714, speed: 1.39 step/s, lr: 1.788e-05
[2021-03-11 20:22:52,913] [ INFO] - global step 1300, epoch: 0, batch: 1300, loss: 3.144522, speed: 1.40 step/s, lr: 1.752e-05
[2021-03-11 20:24:05,386] [ INFO] - global step 1400, epoch: 0, batch: 1400, loss: 3.073041, speed: 1.38 step/s, lr: 1.716e-05
[2021-03-11 20:25:16,978] [ INFO] - global step 1500, epoch: 0, batch: 1500, loss: 3.108455, speed: 1.40 step/s, lr: 1.680e-05
[2021-03-11 20:26:29,087] [ INFO] - global step 1600, epoch: 0, batch: 1600, loss: 3.154165, speed: 1.39 step/s, lr: 1.643e-05
[2021-03-11 20:27:41,227] [ INFO] - global step 1700, epoch: 0, batch: 1700, loss: 3.209866, speed: 1.39 step/s, lr: 1.607e-05
[2021-03-11 20:28:53,437] [ INFO] - global step 1800, epoch: 0, batch: 1800, loss: 3.027212, speed: 1.38 step/s, lr: 1.571e-05
[2021-03-11 20:30:06,770] [ INFO] - global step 1900, epoch: 0, batch: 1900, loss: 3.035791, speed: 1.36 step/s, lr: 1.535e-05
[2021-03-11 20:31:18,951] [ INFO] - global step 2000, epoch: 0, batch: 2000, loss: 3.227936, speed: 1.39 step/s, lr: 1.498e-05
[2021-03-11 20:32:36,214] [ INFO] - global step 2100, epoch: 0, batch: 2100, loss: 3.317007, speed: 1.29 step/s, lr: 1.462e-05
[2021-03-11 20:33:48,495] [ INFO] - global step 2200, epoch: 0, batch: 2200, loss: 3.249280, speed: 1.38 step/s, lr: 1.426e-05
[2021-03-11 20:35:01,095] [ INFO] - global step 2300, epoch: 0, batch: 2300, loss: 3.323828, speed: 1.38 step/s, lr: 1.390e-05
[2021-03-11 20:36:12,235] [ INFO] - global step 2400, epoch: 0, batch: 2400, loss: 3.113916, speed: 1.41 step/s, lr: 1.354e-05
[2021-03-11 20:37:24,010] [ INFO] - global step 2500, epoch: 0, batch: 2500, loss: 3.235685, speed: 1.39 step/s, lr: 1.317e-05
[2021-03-11 20:38:35,300] [ INFO] - global step 2600, epoch: 0, batch: 2600, loss: 3.159153, speed: 1.40 step/s, lr: 1.281e-05
[2021-03-11 20:39:46,650] [ INFO] - global step 2700, epoch: 0, batch: 2700, loss: 3.217706, speed: 1.40 step/s, lr: 1.245e-05
[2021-03-11 20:40:58,426] [ INFO] - global step 2800, epoch: 0, batch: 2800, loss: 3.198079, speed: 1.39 step/s, lr: 1.209e-05
[2021-03-11 20:42:09,709] [ INFO] - global step 2900, epoch: 0, batch: 2900, loss: 3.317524, speed: 1.40 step/s, lr: 1.173e-05
[2021-03-11 20:43:21,248] [ INFO] - global step 3000, epoch: 0, batch: 3000, loss: 3.087793, speed: 1.40 step/s, lr: 1.136e-05
[2021-03-11 20:44:37,515] [ INFO] - global step 3100, epoch: 0, batch: 3100, loss: 3.227354, speed: 1.31 step/s, lr: 1.100e-05
[2021-03-11 20:45:49,344] [ INFO] - global step 3200, epoch: 0, batch: 3200, loss: 3.175616, speed: 1.39 step/s, lr: 1.064e-05
[2021-03-11 20:47:01,745] [ INFO] - global step 3300, epoch: 0, batch: 3300, loss: 3.167258, speed: 1.38 step/s, lr: 1.028e-05
[2021-03-11 20:48:12,873] [ INFO] - global step 3400, epoch: 0, batch: 3400, loss: 3.106549, speed: 1.41 step/s, lr: 9.916e-06
[2021-03-11 20:49:23,572] [ INFO] - global step 3500, epoch: 0, batch: 3500, loss: 3.111874, speed: 1.41 step/s, lr: 9.554e-06
[2021-03-11 20:50:33,989] [ INFO] - global step 3600, epoch: 0, batch: 3600, loss: 3.219886, speed: 1.42 step/s, lr: 9.192e-06
[2021-03-11 20:51:44,577] [ INFO] - global step 3700, epoch: 0, batch: 3700, loss: 3.270064, speed: 1.42 step/s, lr: 8.830e-06
[2021-03-11 20:52:56,051] [ INFO] - global step 3800, epoch: 0, batch: 3800, loss: 3.178619, speed: 1.40 step/s, lr: 8.468e-06
[2021-03-11 20:54:07,886] [ INFO] - global step 3900, epoch: 0, batch: 3900, loss: 3.079512, speed: 1.39 step/s, lr: 8.106e-06
[2021-03-11 20:55:19,915] [ INFO] - global step 4000, epoch: 0, batch: 4000, loss: 3.161908, speed: 1.39 step/s, lr: 7.744e-06
[2021-03-11 20:56:37,011] [ INFO] - global step 4100, epoch: 0, batch: 4100, loss: 3.188129, speed: 1.30 step/s, lr: 7.382e-06
[2021-03-11 20:57:49,246] [ INFO] - global step 4200, epoch: 0, batch: 4200, loss: 3.179213, speed: 1.38 step/s, lr: 7.020e-06
[2021-03-11 20:59:01,975] [ INFO] - global step 4300, epoch: 0, batch: 4300, loss: 3.103755, speed: 1.38 step/s, lr: 6.658e-06
[2021-03-11 21:00:14,576] [ INFO] - global step 4400, epoch: 0, batch: 4400, loss: 3.287174, speed: 1.38 step/s, lr: 6.296e-06
[2021-03-11 21:01:27,224] [ INFO] - global step 4500, epoch: 0, batch: 4500, loss: 3.108329, speed: 1.38 step/s, lr: 5.934e-06
[2021-03-11 21:02:39,283] [ INFO] - global step 4600, epoch: 0, batch: 4600, loss: 3.222489, speed: 1.39 step/s, lr: 5.572e-06
[2021-03-11 21:03:51,637] [ INFO] - global step 4700, epoch: 0, batch: 4700, loss: 3.196075, speed: 1.38 step/s, lr: 5.210e-06
[2021-03-11 21:05:02,516] [ INFO] - global step 4800, epoch: 0, batch: 4800, loss: 3.309706, speed: 1.41 step/s, lr: 4.848e-06
[2021-03-11 21:06:12,631] [ INFO] - global step 4900, epoch: 0, batch: 4900, loss: 3.178375, speed: 1.43 step/s, lr: 4.486e-06
[2021-03-11 21:07:22,995] [ INFO] - global step 5000, epoch: 0, batch: 5000, loss: 3.289256, speed: 1.42 step/s, lr: 4.124e-06
[2021-03-11 21:08:38,528] [ INFO] - global step 5100, epoch: 0, batch: 5100, loss: 3.054899, speed: 1.32 step/s, lr: 3.762e-06
[2021-03-11 21:09:48,655] [ INFO] - global step 5200, epoch: 0, batch: 5200, loss: 3.120649, speed: 1.43 step/s, lr: 3.400e-06
[2021-03-11 21:10:59,229] [ INFO] - global step 5300, epoch: 0, batch: 5300, loss: 3.171082, speed: 1.42 step/s, lr: 3.038e-06
[2021-03-11 21:12:09,638] [ INFO] - global step 5400, epoch: 0, batch: 5400, loss: 3.195397, speed: 1.42 step/s, lr: 2.676e-06
[2021-03-11 21:13:20,488] [ INFO] - global step 5500, epoch: 0, batch: 5500, loss: 3.212966, speed: 1.41 step/s, lr: 2.313e-06
[2021-03-11 21:14:32,086] [ INFO] - global step 5600, epoch: 0, batch: 5600, loss: 3.323075, speed: 1.40 step/s, lr: 1.951e-06
[2021-03-11 21:15:42,540] [ INFO] - global step 5700, epoch: 0, batch: 5700, loss: 3.145492, speed: 1.42 step/s, lr: 1.589e-06
[2021-03-11 21:16:52,944] [ INFO] - global step 5800, epoch: 0, batch: 5800, loss: 3.280397, speed: 1.42 step/s, lr: 1.227e-06
[2021-03-11 21:18:03,892] [ INFO] - global step 5900, epoch: 0, batch: 5900, loss: 3.077665, speed: 1.41 step/s, lr: 8.653e-07
[2021-03-11 21:19:14,203] [ INFO] - global step 6000, epoch: 0, batch: 6000, loss: 3.150287, speed: 1.42 step/s, lr: 5.032e-07
[2021-03-11 21:20:29,556] [ INFO] - global step 6100, epoch: 0, batch: 6100, loss: 3.128158, speed: 1.33 step/s, lr: 1.412e-07
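The loss call pairs tgt_pos (the coordinates of every [MASK] slot in the batch) with tgt_labels, which after_padding flattens by dropping padding. The two only line up if both are enumerated in the same row-major order and padded identically. The NumPy sketch below checks that property on a toy batch. It assumes paddle.nonzero enumerates indices in row-major order like np.argwhere; the helper name and the IDs are hypothetical.

```python
import numpy as np


def labels_align_with_positions(tgt_labels, attn_ids, attn_id):
    """Check that the flattened labels line up with the predicted positions."""
    # Positions that get a prediction, in row-major order (as np.argwhere returns them).
    tgt_pos = np.argwhere(attn_ids == attn_id)
    # Flattened non-padding labels, built the same way as in after_padding.
    flat_labels = tgt_labels[np.where(tgt_labels != 0)]
    picked = tgt_labels[tgt_pos[:, 0], tgt_pos[:, 1]]
    return picked.shape == flat_labels.shape and bool(np.all(picked == flat_labels))


attn_id = 3  # toy stand-in for the [MASK] id
tgt_labels = np.array([[1, 20, 21, 2],
                       [1, 30, 2, 0]])  # second row is padded with 0
# One slot per real target token, padded exactly like the labels.
attn_ids = np.where(tgt_labels != 0, attn_id, 0)
```

In the real pipeline this holds because attn_ids is built with np.ones_like(tgt_ids) and both fields are padded with the same pad id; if the padding ever differed, the counts would no longer match.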
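Decoding
ERNIE-GEN predicts by infilling generation, so decoding has to implement this procedure.
Here we decode with greedy search; for beam search, refer to the ERNIE-GEN example in PaddleNLP.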
In [26]
def gen_bias(encoder_inputs, decoder_inputs, step):
    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
    encoder_bsz, encoder_seqlen = encoder_inputs.shape[:2]
    attn_bias = paddle.reshape(
        paddle.arange(
            0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
    decoder_bias = paddle.cast(
        (paddle.matmul(
            attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
        'float32')  #[1, decoderlen, decoderlen]
    encoder_bias = paddle.unsqueeze(
        paddle.cast(paddle.ones_like(encoder_inputs), 'float32'),
        [1])  #[bsz, 1, encoderlen]
    encoder_bias = paddle.expand(
        encoder_bias, [encoder_bsz, decoder_seqlen,
                       encoder_seqlen])  #[bsz,decoderlen, encoderlen]
    decoder_bias = paddle.expand(
        decoder_bias, [decoder_bsz, decoder_seqlen,
                       decoder_seqlen])  #[bsz, decoderlen, decoderlen]
    if step > 0:
        bias = paddle.concat([
            encoder_bias, paddle.ones([decoder_bsz, decoder_seqlen, step],
                                      'float32'), decoder_bias
        ], -1)
    else:
        bias = paddle.concat([encoder_bias, decoder_bias], -1)
    return bias
@paddle.no_grad()
def greedy_search_infilling(model,
                            q_ids,
                            q_sids,
                            sos_id,
                            eos_id,
                            attn_id,
                            pad_id,
                            unk_id,
                            vocab_size,
                            max_encode_len=640,
                            max_decode_len=100,
                            tgt_type_id=3):
    # Encode the source once and keep its key/value caches
    _, logits, info = model(q_ids, q_sids)
    d_batch, d_seqlen = q_ids.shape
    seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
    has_stopped = np.zeros([d_batch], dtype=bool)
    gen_seq_len = np.zeros([d_batch], dtype=np.int64)
    output_ids = []
    past_cache = info['caches']
    # Each step feeds [previous token, [ATTN]]; the first "previous token" is [CLS]
    cls_ids = paddle.ones([d_batch], dtype='int64') * sos_id
    attn_ids = paddle.ones([d_batch], dtype='int64') * attn_id
    ids = paddle.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = paddle.to_tensor(
            np.tile(
                np.array(
                    [[step, step + 1]], dtype=np.int64), [d_batch, 1]))
        pos_ids += seqlen
        _, logits, info = model(
            ids,
            paddle.ones_like(ids) * tgt_type_id,
            pos_ids=pos_ids,
            attn_bias=bias,
            past_cache=past_cache)
        # Suppress ids that must never be generated; a large negative value (rather than 0)
        # ensures they cannot win the argmax even when every real logit is negative
        if logits.shape[-1] > vocab_size:
            logits[:, :, vocab_size:] = -1e9
        logits[:, :, pad_id] = -1e9
        logits[:, :, unk_id] = -1e9
        logits[:, :, attn_id] = -1e9
        gen_ids = paddle.argmax(logits, -1)
        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        # Append only the previous token's key/value (position 0) to the cache;
        # the [ATTN] slot is recomputed at every step
        cached_k = [
            paddle.concat([pk, k[:, :1, :]], 1)
            for pk, k in zip(past_cached_k, cached_k)
        ]
        cached_v = [
            paddle.concat([pv, v[:, :1, :]], 1)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)
        # The prediction made at the [ATTN] slot is the next token
        gen_ids = gen_ids[:, 1]
        ids = paddle.stack([gen_ids, attn_ids], 1)
        gen_ids = gen_ids.numpy()
        has_stopped |= (gen_ids == eos_id).astype(bool)
        gen_seq_len += (1 - has_stopped.astype(np.int64))
        output_ids.append(gen_ids.tolist())
        if has_stopped.all():
            break
    output_ids = np.array(output_ids).transpose([1, 0])
    return output_ids
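gen_bias builds the decoder's causal block with an arithmetic trick: with a = 1, 2, ..., n, the outer product a_i * (1 / a_j) is at least 1 exactly when j <= i, which yields a lower-triangular matrix. The NumPy sketch below reproduces the trick and the per-step column layout of the bias (all source positions, then the step previously generated tokens held in the cache, then the 2x2 block for [previous token, [ATTN]]). It uses hypothetical helper names and drops the batch dimension. The trick relies on a_i * (1 / a_i) rounding to exactly 1, which holds for the two-token decoder input used here; for long sequences it may not hold for every length in floating point, and np.tril is the more robust construction.

```python
import numpy as np


def causal_via_ratio(n):
    """gen_bias's trick: entry (i, j) is a[i] * (1 / a[j]) >= 1 with a = 1..n, i.e. j <= i."""
    a = np.arange(1, n + 1, dtype=np.float32).reshape(-1, 1)
    return (a @ (1.0 / a).T >= 1.0).astype(np.float32)


def step_bias(encoder_len, step):
    """Visibility for the decoder input [previous token, [ATTN]] at one decoding step.

    Columns: all encoder positions, then `step` cached generated tokens,
    then the 2x2 causal block for the two current inputs.
    """
    encoder_part = np.ones((2, encoder_len), dtype=np.float32)
    cache_part = np.ones((2, step), dtype=np.float32)
    return np.concatenate([encoder_part, cache_part, causal_via_ratio(2)], axis=-1)


bias_step3 = step_bias(encoder_len=5, step=3)  # shape (2, 5 + 3 + 2)
```

At step 0 there is nothing in the cache yet, which is why gen_bias skips the ones block when step is 0.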
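Evaluation
During evaluation, the decoding logic above generates predictions, which are then scored to measure model quality. paddlenlp.metrics includes metrics such as Rouge1 and Rouge2; here we use Rouge1.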
In [27]
from tqdm import tqdm
from paddlenlp.metrics import Rouge1
rouge1 = Rouge1()
vocab = tokenizer.vocab
eos_id = vocab[tokenizer.sep_token]
sos_id = vocab[tokenizer.cls_token]
pad_id = vocab[tokenizer.pad_token]
unk_id = vocab[tokenizer.unk_token]
vocab_size = len(vocab)
evaluated_sentences_ids = []
reference_sentences_ids = []
logger.info("Evaluating...")
model.eval()
for data in tqdm(dev_data_loader):
    (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
     raw_tgt_labels) = data  # never use target when infer
    output_ids = greedy_search_infilling(
        model,
        src_ids,
        src_sids,
        eos_id=eos_id,
        sos_id=sos_id,
        attn_id=attn_id,
        pad_id=pad_id,
        unk_id=unk_id,
        vocab_size=vocab_size,
        max_decode_len=max_decode_len,
        max_encode_len=max_encode_len,
        tgt_type_id=tgt_type_id)
    # Cut each prediction at the first [SEP] (end of sequence)
    for ids in output_ids.tolist():
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]
        evaluated_sentences_ids.append(ids)
    # Drop the leading [CLS] and everything from [SEP] onward in the references
    for ids in raw_tgt_labels.numpy().tolist():
        ids = ids[1:ids.index(eos_id)]
        reference_sentences_ids.append(ids)
score = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
logger.info("Rouge-1: %.5f" % (score * 100))
[2021-03-11 21:20:56,539] [ INFO] - Evaluating...
100%|██████████| 21/21 [01:07<00:00, 3.20s/it]
[2021-03-11 21:22:03,824] [ INFO] - Rouge-1: 11.82923
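ROUGE-1 measures unigram overlap between a prediction and its reference; here the unigrams are token IDs, i.e. single characters. The sketch below computes per-pair recall, precision and F1 with clipped counts. It follows the textbook definition and is not a copy of paddlenlp.metrics.Rouge1; which variant and which averaging that class reports is not confirmed here, so its numbers may differ from this sketch.

```python
from collections import Counter


def rouge1_scores(candidate, reference):
    """ROUGE-1 recall, precision and F1 for one pair of token-id lists."""
    # Clipped overlap: each id counts at most as often as it appears in both lists.
    overlap = sum((Counter(candidate) & Counter(reference)).values())
    recall = overlap / len(reference) if reference else 0.0
    precision = overlap / len(candidate) if candidate else 0.0
    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return recall, precision, f1


# Two of the three reference ids are recovered by a four-id prediction.
r, p, f = rouge1_scores([1, 2, 3, 4], [1, 2, 5])
```

Since a poem continuation can be fluent while sharing few characters with the reference, a low unigram overlap such as the score above does not necessarily mean poor poems.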
Predictions
For generation tasks, automatic metrics do not fully reflect model quality, so let's look at the model's predictions directly.
In [28]
evaluated_sentences = []
reference_sentences = []
for ids in reference_sentences_ids[:5]:
    reference_sentences.append(''.join(vocab.to_tokens(ids)))
for ids in evaluated_sentences_ids[:5]:
    evaluated_sentences.append(''.join(vocab.to_tokens(ids)))
logger.info(reference_sentences)
logger.info(evaluated_sentences)
[2021-03-11 21:22:03,833] [ INFO] - ['佳游会自希高躅,可是空寻叱石羊。', '此生诗病苦,此病更萧条。', '忧端不可解,遇酒即暂[UNK]。联绵九疑髙,置我胸中蟠。一浇岂易得,橡栗无朝餐。赤鲤信久绝,白鸥盟亦寒。何人过子云,慰此风月闲。杯行不停手,共惜良夜', '江声里过东西寺,树影中行上下方。春色湿僧巾屦腻,松花沾鹤骨毛香。老来何计重归去,千里重湖浪渺茫。', '如何进贤路,只是见青松。']
[2021-03-11 21:22:03,836] [ INFO] - ['山川胜处诗无敌,桃李香中酒满觞。', '山中无俗客,林下有归樵。', '绿荷已成盖,红蓼犹生瘢。', '山中日月无多子,世上功名只一方。云外有僧来扣寂,洞中无客自焚香。何当共结烟霞社,来往松间共醉乡。', '关山千里外,风雨一声中。白骨埋黄壤,青春变绿丛。伤心不可问,落日暮云东。']
This has been a brief introduction to poetry generation with ERNIE-GEN. More PaddleNLP tutorials are available on GitHub: https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/examples/