PaddleNLP "Deep Learning for Natural Language Processing" Camp, Assignment 2 -- Required | Text Semantic Similarity

The "Deep Learning for Natural Language Processing" course

Course address: https://aistudio.baidu.com/aistudio/education/group/info/24177

Complete the code that calls the pretrained model in the prediction step, run the whole project end to end, successfully submit to the Qianyan text similarity competition, take a screenshot as required, and submit the assignment.

Tips:

  • For prediction you may use a model you trained yourself (training takes a while), or directly use the downloadable model weights provided;
  • Register for the Qianyan text similarity competition and successfully submit your results;
  • Paste a screenshot of the results, as shown in the figure below, at the last line of this project's assignment to finish.

(Figure 1: example screenshot of a successful competition submission)

Semantic Matching with the Pretrained ERNIE-Gram Model

The NLP live check-in course starts on June 7. Welcome to follow the course, and bring any questions to the comment section or the QQ group (group number: 758287592)!

For the livestream link, click here; live every evening 20:00-21:30.

For the course page, click here.

This case study introduces one of the most fundamental NLP task types, text semantic matching, and uses PaddleNLP with ERNIE-Gram, a pretrained model open-sourced by Baidu, to build a high-performing semantic matching model that judges whether two pieces of text are semantically equivalent.

1. Background

The text semantic matching task, simply put: given two pieces of text, the model judges whether they are semantically similar.

This case study uses the authoritative semantic matching dataset LCQMC, a general question-matching dataset built from similar-question recommendations on Baidu Zhidao. Every text pair in the training set is labeled 1 (semantically similar) or 0 (semantically dissimilar). More datasets are available from Qianyan.

In the Baidu Zhidao scenario, for example, when a user searches a question, the model computes whether that question is semantically similar to each candidate question; the matching model then returns the semantically similar candidates to the user, shortening the path from asking a question to getting an answer. For instance, when a user searches "What are some textbooks on deep learning?" in a search engine, the model automatically surfaces a few semantically similar questions:
(Figure 2: semantically similar questions shown for the query "What are some textbooks on deep learning?")

2. Quick Practice

This section shows how to prepare the data, build a matching network on top of the ERNIE-Gram model, and then quickly train, evaluate, and predict with the semantic matching model.

2.1 Data Loading

To train a matching model, you generally need three datasets: a training set train.tsv, a validation set dev.tsv, and a test set test.tsv. In this case study we use LCQMC, a semantic matching dataset built into PaddleNLP, for training, evaluation, and prediction.

Training set: the data used to fit the model parameters; the model adjusts its parameters directly on this set to improve classification.

Validation set: used to check the model's state and convergence during training. The validation set is typically used to tune hyperparameters: based on how several configurations perform on it, you decide which hyperparameters to adopt.

Test set: used to compute the model's evaluation metrics and verify its generalization ability.

LCQMC is a public, authoritative semantic matching dataset. PaddleNLP ships with it built in, so it can be loaded with a single call.

# Before starting the experiment, install the latest version of paddlenlp with the following command
!pip install --upgrade paddlenlp -i https://pypi.org/simple
Requirement already up-to-date: paddlenlp in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.0.2)
... (remaining dependency-check output omitted: all requirements were already satisfied)
import time
import os
import numpy as np

import paddle
import paddle.nn.functional as F
from paddlenlp.datasets import load_dataset
import paddlenlp

# Load the LCQMC train and dev splits with a single call
train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"])
# Print the first 4 training samples
for idx, example in enumerate(train_ds):
    if idx <= 3:
        print(example)
{'query': '喜欢打篮球的男生喜欢什么样的女生', 'title': '爱打篮球的男生喜欢什么样的女生', 'label': 1}
{'query': '我手机丢了,我想换个手机', 'title': '我想买个新手机,求推荐', 'label': 1}
{'query': '大家觉得她好看吗', 'title': '大家觉得跑男好看吗?', 'label': 0}
{'query': '求秋色之空漫画全集', 'title': '求秋色之空全集漫画', 'label': 1}

2.2 Data Preprocessing

The LCQMC dataset loaded through PaddleNLP is raw plain text. In this part we implement the preprocessing logic, such as batching and tokenization, to convert the raw text into the input data the network is trained on.

Define the sample conversion function

# Since we build on the pretrained ERNIE-Gram model, we first load the ERNIE-Gram tokenizer;
# the sample conversion function below uses this tokenizer to tokenize the text

tokenizer = paddlenlp.transformers.ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
[2021-06-10 00:00:41,034] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-gram-zh/vocab.txt
# Concatenate the query and title of one raw-text example and convert the text into ID data with the pretrained model's tokenizer
# Returns input_ids and token_type_ids

def convert_example(example, tokenizer, max_seq_length=512, is_test=False):

    query, title = example["query"], example["title"]

    encoded_inputs = tokenizer(
        text=query, text_pair=title, max_seq_len=max_seq_length)

    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if not is_test:
        label = np.array([example["label"]], dtype="int64")
        return input_ids, token_type_ids, label
    # In prediction or evaluation, do not return the label field
    else:
        return input_ids, token_type_ids
### Convert the 1st example of the training set
input_ids, token_type_ids, label = convert_example(train_ds[0], tokenizer)
print(input_ids)
[1, 692, 811, 445, 2001, 497, 5, 654, 21, 692, 811, 614, 356, 314, 5, 291, 21, 2, 329, 445, 2001, 497, 5, 654, 21, 692, 811, 614, 356, 314, 5, 291, 21, 2]
print(token_type_ids)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
print(label)
[1]
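As a quick sanity check, the ids can be mapped back to tokens; this is a small sketch using convert_ids_to_tokens (available on PaddleNLP tokenizers) to make the [CLS] query [SEP] title [SEP] layout visible, which is exactly what token_type_ids segments into 0s and 1s.

# Sketch: decode the ids back to tokens to visualize the text-pair layout
tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(tokens)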
# For convenience later, we use Python's functools.partial to bind some default arguments to convert_example
from functools import partial

# Sample conversion function for the training and validation sets
trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=512)

Assembling batch data & padding

In the previous subsection we converted individual samples; in this one we combine samples into batches. Variable-length samples also need padding so the batch can be trained efficiently on GPU.

PaddleNLP provides many common APIs for building effective data pipelines for NLP tasks:

API                    Description
paddlenlp.data.Stack   Stacks N input samples of the same shape into one batch
paddlenlp.data.Pad     Pads multiple sentences of different lengths to a uniform length, the maximum length among the N inputs
paddlenlp.data.Tuple   Wraps several batchify functions together

More data processing operations are documented at: https://paddlenlp.readthedocs.io/zh/latest/data_prepare/data_preprocess.html

from paddlenlp.data import Stack, Pad, Tuple
a = [1, 2, 3, 4]
b = [3, 4, 5, 6]
c = [5, 6, 7, 8]
result = Stack()([a, b, c])
print("Stacked Data: \n", result)
print()

a = [1, 2, 3, 4]
b = [5, 6, 7]
c = [8, 9]
result = Pad(pad_val=0)([a, b, c])
print("Padded Data: \n", result)
print()

data = [
        [[1, 2, 3, 4], [1]],
        [[5, 6, 7], [0]],
        [[8, 9], [1]],
       ]
batchify_fn = Tuple(Pad(pad_val=0), Stack())
ids, labels = batchify_fn(data)
print("ids: \n", ids)
print()
print("labels: \n", labels)
print()
Stacked Data: 
 [[1 2 3 4]
 [3 4 5 6]
 [5 6 7 8]]

Padded Data: 
 [[1 2 3 4]
 [5 6 7 0]
 [8 9 0 0]]

ids: 
 [[1 2 3 4]
 [5 6 7 0]
 [8 9 0 0]]

labels: 
 [[1]
 [0]
 [1]]
# Our training data returns 3 fields: input_ids, token_type_ids, labels
# so we define one batchify operation for each of the 3 fields
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
    Stack(dtype="int64")  # label
): [data for data in fn(samples)]
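As a quick check (a sketch reusing the objects defined above), batchify_fn can be applied directly to a handful of converted samples:

# Sketch: batchify two converted samples and inspect the padded arrays
sample_batch = [convert_example(train_ds[i], tokenizer) for i in range(2)]
ids, types, labels = batchify_fn(sample_batch)
print(ids.shape, types.shape, labels.shape)  # ids/types are padded to the longer sample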

Define the DataLoader

Next we construct the training DataLoader from the batchify_fn function and the sample conversion function trans_func; it supports multi-GPU training.


# Define a distributed sampler: it automatically shards the training data and supports multi-GPU parallel training
batch_sampler = paddle.io.DistributedBatchSampler(train_ds, batch_size=200, shuffle=True)

# Define train_data_loader based on train_ds
# Because we use the distributed DistributedBatchSampler, train_data_loader automatically shards the training data
train_data_loader = paddle.io.DataLoader(
        dataset=train_ds.map(trans_func),
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)

# For validation data we evaluate on a single card, so paddle.io.BatchSampler suffices
# Define dev_data_loader
batch_sampler = paddle.io.BatchSampler(dev_ds, batch_size=200, shuffle=False)
dev_data_loader = paddle.io.DataLoader(
        dataset=dev_ds.map(trans_func),
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
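Before training, it is worth pulling a single batch to confirm the shapes; a minimal sketch:

# Sketch: fetch one batch from the training DataLoader and check its shapes
for input_ids, token_type_ids, labels in train_data_loader:
    print(input_ids.shape, token_type_ids.shape, labels.shape)
    break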

2.3 Building the Model

Since October 2018, tasks across all areas of NLP have seen significant gains over traditional DNN methods through the pretrain-and-finetune paradigm. In this section we take ERNIE-Gram, a pretrained model open-sourced by Baidu, as the base model and build a point-wise semantic matching network on top of it.

First, let's define the network structure:

import paddle.nn as nn

# We build the point-wise semantic matching network on top of the ERNIE-Gram model structure,
# so first define the ERNIE-Gram pretrained_model
pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained('ernie-gram-zh')
#pretrained_model = paddlenlp.transformers.ErnieModel.from_pretrained('ernie-1.0')


class PointwiseMatching(nn.Layer):
   
    # In this example, pretrained_model is initialized with the ERNIE-Gram pretrained model
    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

        # Semantic matching task: binary classification, similar vs. dissimilar
        self.classifier = nn.Linear(self.ptm.config["hidden_size"], 2)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):

        # input_ids here is the concatenation of the token ids of the two texts
        # token_type_ids marks which of the two segments each token belongs to
        # the returned cls_embedding is the semantic representation vector the model computes for the text pair
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
                                    attention_mask)

        cls_embedding = self.dropout(cls_embedding)

        # Binary classification on the semantic representation vector of the text pair
        logits = self.classifier(cls_embedding)
        probs = F.softmax(logits)

        return probs

# Define the point-wise semantic matching network
model = PointwiseMatching(pretrained_model)
[2021-06-10 00:00:41,111] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-gram-zh/ernie_gram_zh.pdparams
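As a sanity check (a sketch using the dev loader defined above), one forward pass through the untrained network should already produce a (batch_size, 2) probability tensor:

# Sketch: run a single dev batch through the matching network
input_ids, token_type_ids, labels = next(iter(dev_data_loader))
probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
print(probs.shape)  # expected: [batch_size, 2], each row summing to 1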

2.4 Model Training & Evaluation

from paddlenlp.transformers import LinearDecayWithWarmup

epochs = 20
num_training_steps = len(train_data_loader) * epochs

# Define the learning-rate scheduler, responsible for adjusting the lr during training
lr_scheduler = LinearDecayWithWarmup(5E-5, num_training_steps, 0.0)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

# Define the optimizer
# (with weight_decay=0.0 the decay_params filter above has no effect;
#  it only matters once a non-zero weight_decay is set)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=0.0,
    apply_decay_param_fun=lambda x: x in decay_params)

# Use the cross-entropy loss function
# (note: the network outputs softmax probabilities, and paddle's CrossEntropyLoss
#  applies softmax again by default; training still converges, as the logs below show)
criterion = paddle.nn.loss.CrossEntropyLoss()

# Accuracy is used as the metric during evaluation
metric = paddle.metric.Accuracy()
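To see what the scheduler defined above does, here is a small sketch: with a warmup proportion of 0.0, LinearDecayWithWarmup simply decays the learning rate linearly from 5e-5 toward 0 over num_training_steps. The demo_scheduler name is a throwaway for illustration.

# Sketch: inspect the first few scheduled learning rates on a throwaway scheduler
demo_scheduler = LinearDecayWithWarmup(5E-5, num_training_steps, 0.0)
for _ in range(3):
    print(demo_scheduler.get_lr())  # 5e-5, then linearly smaller each step
    demo_scheduler.step()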
# Add VisualDL logging
from visualdl import LogWriter

writer = LogWriter("./log")
# Because the model is also evaluated on the validation set during training, define the evaluation function first

@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, phase="dev"):
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, labels = batch
        probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
        loss = criterion(probs, labels)
        losses.append(loss.numpy())
        correct = metric.compute(probs, labels)
        metric.update(correct)
        accu = metric.accumulate()
    print("eval {} loss: {:.5}, accu: {:.5}".format(phase,
                                                    np.mean(losses), accu))
    # Log eval metrics (`step` here refers to the global variable set in the training loop below)
    writer.add_scalar(tag="eval/loss", step=step, value=np.mean(losses))
    writer.add_scalar(tag="eval/acc", step=step, value=accu)
    model.train()
    metric.reset()
# Next, train the model for real. Training takes a long time, so this part can be commented out

global_step = 0
tic_train = time.time()

for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):

        input_ids, token_type_ids, labels = batch
        probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
        loss = criterion(probs, labels)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        global_step += 1
        
        # Print training metrics every 10 steps
        if global_step % 10 == 0:
            print(
                "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                % (global_step, epoch, step, loss, acc,
                    10 / (time.time() - tic_train)))
            tic_train = time.time()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

        # Evaluate on the validation set every 100 steps
        if global_step % 100 == 0:
            evaluate(model, criterion, metric, dev_data_loader, "dev")
            
            # Log train metrics
            writer.add_scalar(tag="train/loss", step=step, value=loss)
            writer.add_scalar(tag="train/acc", step=step, value=acc)

            save_dir = os.path.join("checkpoint", "model_%d" % global_step)
            os.makedirs(save_dir)
            # 加入保存
            save_param_path = os.path.join(save_dir, 'model_state.pdparams')
            paddle.save(model.state_dict(), save_param_path)
            tokenizer.save_pretrained(save_dir)
            
# After training finishes, save the final model parameters
save_dir = os.path.join("checkpoint_final", "model_%d" % global_step)
os.makedirs(save_dir, exist_ok=True)

save_param_path = os.path.join(save_dir, 'model_state.pdparams')
paddle.save(model.state_dict(), save_param_path)
tokenizer.save_pretrained(save_dir)
global step 10, epoch: 1, batch: 10, loss: 0.51948, accu: 0.71600, speed: 1.53 step/s
global step 20, epoch: 1, batch: 20, loss: 0.50266, accu: 0.76925, speed: 1.54 step/s
global step 30, epoch: 1, batch: 30, loss: 0.42471, accu: 0.80183, speed: 1.53 step/s
global step 40, epoch: 1, batch: 40, loss: 0.44043, accu: 0.81675, speed: 1.50 step/s
global step 50, epoch: 1, batch: 50, loss: 0.40115, accu: 0.83090, speed: 1.46 step/s
global step 60, epoch: 1, batch: 60, loss: 0.44461, accu: 0.83608, speed: 1.53 step/s
global step 70, epoch: 1, batch: 70, loss: 0.41281, accu: 0.83986, speed: 1.54 step/s
global step 80, epoch: 1, batch: 80, loss: 0.43078, accu: 0.84531, speed: 1.50 step/s
global step 90, epoch: 1, batch: 90, loss: 0.42023, accu: 0.84922, speed: 1.55 step/s
global step 100, epoch: 1, batch: 100, loss: 0.45644, accu: 0.85290, speed: 1.54 step/s
eval dev loss: 0.47353, accu: 0.83117
...
(intermediate training logs omitted; dev accuracy climbs from about 0.83 to roughly 0.88-0.89 over the first four epochs)
...
global step 4110, epoch: 4, batch: 528, loss: 0.35398, accu: 0.93200, speed: 0.46 step/s
global step 4120, epoch: 4, batch: 538, loss: 0.39028, accu: 0.92275, speed: 1.45 step/s
global step 4130, epoch: 4, batch: 548, loss: 0.38433, accu: 0.92367, speed: 1.48 step/s
global step 4140, epoch: 4, batch: 558, loss: 0.37391, accu: 0.92463, speed: 1.50 step/s
global step 4150, epoch: 4, batch: 568, loss: 0.38000, accu: 0.92490, speed: 1.55 step/s
global step 4160, epoch: 4, batch: 578, loss: 0.39064, accu: 0.92600, speed: 1.53 step/s
global step 4170, epoch: 4, batch: 588, loss: 0.36829, accu: 0.92486, speed: 1.41 step/s
global step 4180, epoch: 4, batch: 598, loss: 0.39404, accu: 0.92425, speed: 1.54 step/s
global step 4190, epoch: 4, batch: 608, loss: 0.39362, accu: 0.92450, speed: 1.49 step/s
global step 4200, epoch: 4, batch: 618, loss: 0.37504, accu: 0.92540, speed: 1.57 step/s
eval dev loss: 0.42783, accu: 0.87741
global step 4210, epoch: 4, batch: 628, loss: 0.37362, accu: 0.92850, speed: 0.46 step/s
global step 4220, epoch: 4, batch: 638, loss: 0.40731, accu: 0.92825, speed: 1.47 step/s



(Training was stopped manually at this point, so the cell ended with a `KeyboardInterrupt` raised from inside `optimizer.step()`; the lengthy traceback is omitted here. Because checkpoints are saved periodically during training, interrupting once the dev accuracy plateaus loses nothing.)

During training, logs like the following are printed:

global step 5310, epoch: 3, batch: 1578, loss: 0.31671, accu: 0.95000, speed: 0.63 step/s
global step 5320, epoch: 3, batch: 1588, loss: 0.36240, accu: 0.94063, speed: 6.98 step/s
global step 5330, epoch: 3, batch: 1598, loss: 0.41451, accu: 0.93854, speed: 7.40 step/s
global step 5340, epoch: 3, batch: 1608, loss: 0.31327, accu: 0.94063, speed: 7.01 step/s
global step 5350, epoch: 3, batch: 1618, loss: 0.40664, accu: 0.93563, speed: 7.83 step/s
global step 5360, epoch: 3, batch: 1628, loss: 0.33064, accu: 0.93958, speed: 7.34 step/s
global step 5370, epoch: 3, batch: 1638, loss: 0.38411, accu: 0.93795, speed: 7.72 step/s
global step 5380, epoch: 3, batch: 1648, loss: 0.35376, accu: 0.93906, speed: 7.92 step/s
global step 5390, epoch: 3, batch: 1658, loss: 0.39706, accu: 0.93924, speed: 7.47 step/s
global step 5400, epoch: 3, batch: 1668, loss: 0.41198, accu: 0.93781, speed: 7.41 step/s
eval dev loss: 0.4177, accu: 0.89082
global step 5410, epoch: 3, batch: 1678, loss: 0.34453, accu: 0.93125, speed: 0.63 step/s
global step 5420, epoch: 3, batch: 1688, loss: 0.34569, accu: 0.93906, speed: 7.75 step/s
global step 5430, epoch: 3, batch: 1698, loss: 0.39160, accu: 0.92917, speed: 7.54 step/s
global step 5440, epoch: 3, batch: 1708, loss: 0.46002, accu: 0.93125, speed: 7.05 step/s
global step 5450, epoch: 3, batch: 1718, loss: 0.32302, accu: 0.93188, speed: 7.14 step/s
global step 5460, epoch: 3, batch: 1728, loss: 0.40802, accu: 0.93281, speed: 7.22 step/s
global step 5470, epoch: 3, batch: 1738, loss: 0.34607, accu: 0.93348, speed: 7.44 step/s
global step 5480, epoch: 3, batch: 1748, loss: 0.34709, accu: 0.93398, speed: 7.38 step/s
global step 5490, epoch: 3, batch: 1758, loss: 0.31814, accu: 0.93437, speed: 7.39 step/s
global step 5500, epoch: 3, batch: 1768, loss: 0.42689, accu: 0.93125, speed: 7.74 step/s
eval dev loss: 0.41789, accu: 0.88968

With the default configuration, single-GPU training takes roughly 4 hours to complete 3 epochs, and the final converged metrics are:

Dataset	Accuracy
dev.tsv	89.62

As you can see, with PaddleNLP and the ERNIE-Gram pretrained model, a very small amount of code achieves strong results on an authoritative semantic matching dataset.

2.5 Model Prediction

Next, we use the trained semantic matching model to predict on unseen data. The input is a tsv file with one tab-separated text pair per line; we use the test split of the LCQMC dataset as prediction data and submit the predictions to the 千言文本相似度竞赛.

Download the semantic matching model we trained in advance, and extract it

# Download the semantic matching model pretrained on LCQMC, then extract it
# (uncomment the wget line below if the tar archive is not already present)
# ! wget https://paddlenlp.bj.bcebos.com/models/text_matching/ernie_gram_zh_pointwise_matching_model.tar
! tar -xvf ernie_gram_zh_pointwise_matching_model.tar
ernie_gram_zh_pointwise_matching_model/
ernie_gram_zh_pointwise_matching_model/model_state.pdparams
ernie_gram_zh_pointwise_matching_model/vocab.txt
ernie_gram_zh_pointwise_matching_model/tokenizer_config.json
# The test data consists of 2 text columns separated by a tab
# LCQMC is downloaded to the following path by default
! head -n3 "${HOME}/.paddlenlp/datasets/LCQMC/lcqmc/lcqmc/test.tsv"
谁有狂三这张高清的	这张高清图,谁有
英雄联盟什么英雄最好	英雄联盟最好英雄是什么
这是什么意思,被蹭网吗	我也是醉了,这是什么意思

Define the prediction function


def predict(model, data_loader):
    batch_probs = []

    # Switch to eval mode for prediction so that dropout and similar layers are disabled
    model.eval()

    with paddle.no_grad():
        for batch_data in data_loader:
            input_ids, token_type_ids = batch_data
            input_ids = paddle.to_tensor(input_ids)
            token_type_ids = paddle.to_tensor(token_type_ids)

            # Prediction probabilities for each sample: a [batch_size, 2] matrix
            batch_prob = model(
                input_ids=input_ids, token_type_ids=token_type_ids).numpy()

            batch_probs.append(batch_prob)
        batch_probs = np.concatenate(batch_probs, axis=0)

        return batch_probs

Define the data_loader for the prediction data

# Conversion function for the prediction data
# The predict data has no label, so convert_example's is_test argument is set to True
trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=512,
    is_test=True)

# Batchify the prediction data
# Predict data only yields input_ids and token_type_ids, so 2 Pad objects suffice as batchify_fn
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment_ids
): [data for data in fn(samples)]

# Load the prediction data
test_ds = load_dataset("lcqmc", splits=["test"])
batch_sampler = paddle.io.BatchSampler(test_ds, batch_size=32, shuffle=False)

# Build the data_loader for prediction
predict_data_loader = paddle.io.DataLoader(
        dataset=test_ds.map(trans_func),
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
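
To make the padding step above concrete, here is a minimal sketch of what each `Pad` object does (the pad value 0 is an assumption for illustration; the real code uses the tokenizer's pad ids):

from paddlenlp.data import Pad

pad = Pad(axis=0, pad_val=0)  # pad_val=0 is assumed here for illustration
# Sequences in a batch are padded to the length of the longest one
print(pad([[1, 2, 3], [4, 5]]))
# expected output (a numpy array):
# [[1 2 3]
#  [4 5 0]]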

Define the prediction model

# Load the pretrained ERNIE-Gram model (the fill-in-the-blank part of this assignment)
pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained('ernie-gram-zh')

model = PointwiseMatching(pretrained_model)
[2021-06-10 01:01:11,310] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-gram-zh/ernie_gram_zh.pdparams

Load the trained model parameters

# The downloaded model extracts to ./ernie_gram_zh_pointwise_matching_model/model_state.pdparams
# state_dict = paddle.load("./ernie_gram_zh_pointwise_matching_model/model_state.pdparams")

# Here we instead load the checkpoint we trained ourselves above
state_dict = paddle.load("checkpoint/model_3100_88957/model_state.pdparams")
model.set_dict(state_dict)
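
As an optional sanity check (a sketch, not part of the original notebook), you can verify that the checkpoint's parameter names match the model before predicting, since `set_dict` will not necessarily complain about mismatches:

# Hypothetical sanity check: compare parameter names between model and checkpoint
missing = set(model.state_dict().keys()) - set(state_dict.keys())
unexpected = set(state_dict.keys()) - set(model.state_dict().keys())
print("missing:", missing)        # expect an empty set
print("unexpected:", unexpected)  # expect an empty set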

Run prediction

# Inspect the first batch produced by the data loader
for idx, batch in enumerate(predict_data_loader):
    if idx < 1:
        print(batch)
    else:
        break  # stop after the first batch instead of iterating the whole loader
[Tensor(shape=[32, 38], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,
       [[1   , 1022, 9   , ..., 0   , 0   , 0   ],
        [1   , 514 , 904 , ..., 0   , 0   , 0   ],
        [1   , 47  , 10  , ..., 0   , 0   , 0   ],
        ...,
        [1   , 733 , 404 , ..., 0   , 0   , 0   ],
        [1   , 134 , 170 , ..., 0   , 0   , 0   ],
        [1   , 379 , 3122, ..., 0   , 0   , 0   ]]), Tensor(shape=[32, 38], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,
       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]])]
# Run the prediction function
y_probs = predict(model, predict_data_loader)

# Derive the predicted label from the probability of each class
y_preds = np.argmax(y_probs, axis=1)

Write out the prediction results

# Store the predictions in lcqmc.tsv following the 千言文本相似度竞赛 submission
# format, and also print them so the model's behavior is easy to inspect

# Reload the raw test set: the earlier copy was mapped to id features by trans_func
test_ds = load_dataset("lcqmc", splits=["test"])

with open("lcqmc.tsv", 'w', encoding="utf-8") as f:
    f.write("index\tprediction\n")
    for idx, y_pred in enumerate(y_preds):
        f.write("{}\t{}\n".format(idx, y_pred))
        text_pair = test_ds[idx]
        text_pair["label"] = y_pred
        print(text_pair)
{'query': '怎么才能让新浪微博有很多粉丝?', 'title': '怎么才能在新浪微博有很多粉丝', 'label': 1}

Submit the LCQMC predictions to the 千言文本相似度竞赛

The 千言文本相似度竞赛 covers 3 datasets: lcqmc, bq_corpus and paws-x. We just generated the lcqmc predictions (lcqmc.tsv), and this project ships empty prediction files for bq_corpus and paws-x. Package these 3 files and submit the archive to the competition to see your model's score on the LCQMC dataset.
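
If those placeholder files were ever missing, a sketch like the following could regenerate them in the required format (the row counts below are assumptions for illustration; they must match each dataset's real test-set length):

# Hypothetical sketch: write all-zero placeholder predictions for the other datasets
# The row counts are assumptions and must match the real test-set sizes
for name, n_rows in [("bq_corpus.tsv", 10000), ("paws-x.tsv", 2000)]:
    with open(name, "w", encoding="utf-8") as f:
        f.write("index\tprediction\n")
        for idx in range(n_rows):
            f.write("{}\t0\n".format(idx))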

# Package the prediction files
!zip submit.zip lcqmc.tsv paws-x.tsv bq_corpus.tsv
  adding: lcqmc.tsv (deflated 65%)
  adding: paws-x.tsv (deflated 61%)
  adding: bq_corpus.tsv (deflated 62%)


Submit the prediction archive submit.zip to the 千言文本相似度竞赛

Screenshot of the 千言文本相似度竞赛 result

Paste your own competition result here

1. Screenshot: lcqmc score 0.8758

2. Added VisualDL to watch how the training metrics evolve (see the sketch after this list)

3. Modified the checkpointing code to save every 100 steps, so training does not have to run to completion (also covered in the sketch below)

4. At prediction time, load your own fine-tuned weights instead of the bare pretrained model
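
A minimal sketch of notes 2 and 3 combined, assuming the variable names from the training loop (`global_step`, `loss`, `model`, `tokenizer`); the log directory and checkpoint layout here are illustrative choices, not the notebook's exact code:

import os

import paddle
from visualdl import LogWriter


def log_and_save(global_step, loss, model, tokenizer,
                 writer, save_every=100, ckpt_root="checkpoint"):
    """Log the training loss to VisualDL and checkpoint every `save_every` steps."""
    if global_step % save_every != 0:
        return
    # Scalars logged here can be viewed with: visualdl --logdir ./visualdl_log
    writer.add_scalar(tag="train/loss", step=global_step, value=float(loss))
    save_dir = os.path.join(ckpt_root, "model_%d" % global_step)
    os.makedirs(save_dir, exist_ok=True)
    paddle.save(model.state_dict(),
                os.path.join(save_dir, "model_state.pdparams"))
    tokenizer.save_pretrained(save_dir)


# Inside the training loop one would call, for example:
# writer = LogWriter(logdir="./visualdl_log")
# log_and_save(global_step, loss, model, tokenizer, writer)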

Course schedule

PaddleNLP《基于深度学习的自然语言处理》打卡营作业2-- 必修|文本语义相似度计算_第3张图片
