PaddleNLP (飞桨) Study Notes

Full code: https://gitee.com/chfengyiliu/paddlenlp_learn. It covers entity extraction, relation extraction, and text-to-image generation.

Below are my notes from fine-tuning the entity-extraction model:

(1) The prompt is encoded into input_ids together with the text;

(2) For each sample, the model outputs two per-token score distributions: the score that each token is an entity start position, and the score that it is an entity end position. Each distribution has shape (1, max_len) per sample, so a batch of samples yields outputs of shape (batch_size, max_len) (see the example in the code).

(3) The loss is the (binary) cross-entropy between each token's predicted start/end position score and the ground-truth distribution; a small sketch follows the remark below.

Remark:
The entity-extraction and relation-extraction code is exactly the same; the only difference is the prompt fed in with each sample:
Entity extraction: the prompt contains only entity types, e.g. 人名 (person name), 地名 (place name), 公司名 (company name);
Relation extraction: the prompt contains both entity types and relation types, e.g. entity types such as 人名 and 地名, and relation types such as 公司的高管 (a company's executive), 奶奶 (grandmother), 孙子 (grandson).
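
A minimal sketch of this objective (my own illustration with made-up scores; it mirrors the BCELoss usage in the training loop of finetune.py below):

import paddle

# Hypothetical batch for illustration: batch_size=2, max_len=6.
# In training, start_prob/end_prob come from the UIE model.
start_prob = paddle.to_tensor([[0.9, 0.1, 0.0, 0.2, 0.1, 0.0],
                               [0.1, 0.8, 0.1, 0.0, 0.1, 0.1]])
end_prob = paddle.to_tensor([[0.1, 0.1, 0.9, 0.1, 0.0, 0.0],
                             [0.0, 0.1, 0.1, 0.9, 0.1, 0.0]])
# Ground-truth 0/1 labels, as built by convert_example in utils.py.
start_ids = paddle.to_tensor([[1., 0., 0., 0., 0., 0.],
                              [0., 1., 0., 0., 0., 0.]])
end_ids = paddle.to_tensor([[0., 0., 1., 0., 0., 0.],
                            [0., 0., 0., 1., 0., 0.]])

criterion = paddle.nn.BCELoss()
# Same as the training loop: average the start-position and end-position BCE losses.
loss = (criterion(start_prob, start_ids) + criterion(end_prob, end_ids)) / 2.0
print(float(loss))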

(1)finetune.py

How each input sample is encoded:

The batch is built in the convert_example(xx) method in utils.py.

Under the hood, that method encodes mainly through the following (a short sketch follows the list):

(1) the _batch_encode_plus(xx) method in site-packages\paddlenlp\transformers\tokenizer_utils.py;

(2) the prepare_for_model(xx) method in site-packages\paddlenlp\transformers\tokenizer_utils_base.py.
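
For reference, a minimal sketch of the tokenizer call that convert_example makes (the arguments mirror utils.py below; it assumes the uie-base resource files are available locally, as do_train() downloads them, or that your PaddleNLP version resolves the "uie-base" name):

from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("uie-base")
# The prompt goes into `text` and the passage into `text_pair`, producing
# input_ids = [CLS] + prompt + [SEP] + content + [SEP], padded/truncated to max_seq_len.
encoded = tokenizer(text=["公司"],
                    text_pair=["网易公司首席架构设计师,丁磊1997年6月创立网易公司。"],
                    truncation=True,
                    max_seq_len=30,
                    pad_to_max_seq_len=True,
                    return_attention_mask=True,
                    return_position_ids=True,
                    return_dict=False,
                    return_offsets_mapping=True)[0]
print(encoded["input_ids"])
print(encoded["offset_mapping"])  # later used to map character offsets to token indices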

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import os
from functools import partial

import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.utils.log import logger

from model import UIE
from evaluate import evaluate
from utils import set_seed, convert_example, reader, MODEL_MAP, create_data_loader



def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    resource_file_urls = MODEL_MAP[args.model]['resource_file_urls']

    logger.info("Downloading resource files...")
    for key, val in resource_file_urls.items():
        file_path = os.path.join(args.model, key)
        if not os.path.exists(file_path):
            get_path_from_url(val, args.model)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = UIE.from_pretrained(args.model)

    train_ds = load_dataset(reader,
                            data_path=args.train_path,
                            max_seq_len=args.max_seq_len,
                            lazy=False)
    dev_ds = load_dataset(reader,
                          data_path=args.dev_path,
                          max_seq_len=args.max_seq_len,
                          lazy=False)

    trans_fn = partial(convert_example,
                       tokenizer=tokenizer,
                       max_seq_len=args.max_seq_len)

    train_data_loader = create_data_loader(train_ds,
                                           mode="train",
                                           batch_size=args.batch_size,
                                           trans_fn=trans_fn)
    dev_data_loader = create_data_loader(dev_ds,
                                         mode="dev",
                                         batch_size=args.batch_size,
                                         trans_fn=trans_fn)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    optimizer = paddle.optimizer.AdamW(learning_rate=args.learning_rate,
                                       parameters=model.parameters())

    criterion = paddle.nn.BCELoss()
    metric = SpanEvaluator()

    loss_list = []
    global_step = 0
    best_f1 = 0
    tic_train = time.time()
    for epoch in range(1, args.num_epochs + 1):
        for batch in train_data_loader:
            # todo Model input:
            #  The batch is built in the convert_example(xx) method in utils.py.
            #  Under the hood, that method encodes via:
            #  (1) the _batch_encode_plus(xx) method in site-packages\paddlenlp\transformers\tokenizer_utils.py;
            #  (2) the prepare_for_model(xx) method in site-packages\paddlenlp\transformers\tokenizer_utils_base.py.

            #  Each field below is then truncated or zero-padded to max_len:
            #  input_ids = cls + prompt + sep + content + sep. The prompt records the entity's type, not the entity
            #  itself (e.g. for 张三, whose type is 人名, the prompt holds 人名, not 张三), followed by the text.
            #  token_type_ids = [0]*len(cls + prompt + sep) + [1]*len(content + sep)
            #  att_mask = [1]*len(input_ids)
            #  pos_ids = range(len(input_ids))

            #  todo Example of an annotated sample:
            #   {
            #       "content": "网易公司首席架构设计师,丁磊1997年6月创立网易公司。",
            #       "result_list": [{"text": "网易公司", "start": 0, "end": 4}, {"text": "网易公司", "start": 20, "end": 24}],
            #       "prompt": "公司"
            #   }
            #  Converted into the model input sequence = [cls]+公司+[sep]+网,易,公,司,首,席,架,构,设,计,师,,,丁,磊,1997,年,6,月,创,立,网,易,公,司,。,+[sep]
            #
            #  start_ids = 1 at the start position of every text in the sample's result_list, 0 elsewhere; it records
            #  where entities begin. Encoded against the input sequence:
            #  start_ids = [0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0]
            #
            #  end_ids = 1 at the end position of every text in the sample's result_list, 0 elsewhere; it records
            #  where entities end. Encoded against the input sequence:
            #  end_ids = [0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0]
            input_ids, token_type_ids, att_mask, pos_ids, start_ids, end_ids = batch

            # todo Model output (here batch_size=8, max_len=30):
            #  start_prob -- the score of each token in each sample being an entity start position; shape per sample = [1, max_len], batched = [batch_size, max_len]
            #  end_prob -- the probability of each token in each sample being an entity end position; shape per sample = [1, max_len], batched = [batch_size, max_len]
            #  例如:Tensor(shape=[8, 30], dtype=float32, place=Place(gpu:0), stop_gradient=False,
            #        [[0.00000031, 0.00000129, 0.00000024, 0.00000031, 0.00000042, 0.00070048,
            #          0.00000877, 0.00002038, 0.00000138, 0.00001625, 0.00000029, 0.00008086,
            #          0.00000202, 0.00000990, 0.00000406, 0.00016828, 0.00001869, 0.00021678,
            #          0.00000937, 0.00046841, 0.00002430, 0.00000023, 0.00006322, 0.00000595,
            #          0.00000095, 0.00000166, 0.00000030, 0.00000033, 0.00000023, 0.00000073],
            #         [0.00000190, 0.00000188, 0.00000061, 0.00000031, 0.00000152, 0.00057717,
            #          0.00000707, 0.00004766, 0.00000537, 0.00017263, 0.00000520, 0.00002636,
            #          0.00000283, 0.00066136, 0.00001277, 0.00100498, 0.00003792, 0.00072994,
            #          0.00085996, 0.00004832, 0.00000354, 0.00000241, 0.00000104, 0.00000189,
            #          0.00000262, 0.00000088, 0.00000112, 0.00000747, 0.00000136, 0.00000190],
            #         [0.00000082, 0.00000135, 0.00000011, 0.00000184, 0.00000276, 0.01416136,
            #          0.00006575, 0.00002398, 0.00433168, 0.00007143, 0.00002933, 0.00001502,
            #          0.00000795, 0.00002079, 0.00000122, 0.00001549, 0.00000252, 0.00002578,
            #          0.00000451, 0.00000495, 0.00000058, 0.00000076, 0.00000072, 0.00000065,
            #          0.00000183, 0.00000128, 0.00000089, 0.00000143, 0.00000092, 0.00000059],
            #         [0.00000310, 0.00000056, 0.00000012, 0.00000066, 0.98270828, 0.00001563,
            #          0.00001566, 0.00013158, 0.00000047, 0.00003079, 0.00000261, 0.00002862,
            #          0.00000128, 0.00000287, 0.00000035, 0.00000161, 0.00000160, 0.00000046,
            #          0.00000067, 0.00000049, 0.00000047, 0.00000088, 0.00000070, 0.00000102,
            #          0.00000038, 0.00000045, 0.00000058, 0.00000048, 0.00000076, 0.00000086],
            #         [0.00000288, 0.00000049, 0.00000015, 0.00000016, 0.00000197, 0.00002671,
            #          0.00000357, 0.00000193, 0.00000425, 0.00000676, 0.00000180, 0.00009420,
            #          0.05136376, 0.00011208, 0.00007415, 0.02065694, 0.00003858, 0.00000352,
            #          0.00000202, 0.00004982, 0.00000858, 0.00003040, 0.00002816, 0.00000214,
            #          0.00000188, 0.00000202, 0.00000312, 0.00000156, 0.00000234, 0.00000251],
            #         [0.00000035, 0.00000123, 0.00000029, 0.00000053, 0.71992922, 0.00012994,
            #          0.00003325, 0.00004523, 0.00003941, 0.00000144, 0.00001109, 0.00000611,
            #          0.00000129, 0.00012184, 0.00004307, 0.00000224, 0.00001191, 0.00002946,
            #          0.00000537, 0.00008847, 0.00002351, 0.00000039, 0.00000065, 0.00000039,
            #          0.00000026, 0.00000050, 0.00000176, 0.00000064, 0.00000034, 0.00000079],
            #         [0.00000126, 0.00000964, 0.00000142, 0.00000091, 0.00000283, 0.01225603,
            #          0.00028963, 0.00000450, 0.00007187, 0.00000975, 0.00098875, 0.00000235,
            #          0.00004167, 0.00000309, 0.00036950, 0.00002466, 0.00126630, 0.00023501,
            #          0.00562316, 0.00004048, 0.00000703, 0.00022977, 0.00001388, 0.00000123,
            #          0.00000202, 0.00000168, 0.00000258, 0.00000113, 0.00000218, 0.00000248],
            #         [0.00000040, 0.00000045, 0.00000008, 0.00000032, 0.00022130, 0.00000234,
            #          0.00001500, 0.00006374, 0.76793706, 0.00115055, 0.00060874, 0.38059372,
            #          0.00022558, 0.00005922, 0.00001344, 0.00001371, 0.00000058, 0.00000036,
            #          0.00000137, 0.00000051, 0.00000054, 0.00000072, 0.00000063, 0.00000039,
            #          0.00000051, 0.00000096, 0.00000066, 0.00000111, 0.00000035, 0.00000046]])
            start_prob, end_prob = model(input_ids, token_type_ids, att_mask, pos_ids)
            start_ids = paddle.cast(start_ids, 'float32')
            end_ids = paddle.cast(end_ids, 'float32')
            loss_start = criterion(start_prob, start_ids)
            loss_end = criterion(end_prob, end_ids)
            loss = (loss_start + loss_end) / 2.0
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            loss_list.append(float(loss))

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                loss_avg = sum(loss_list) / len(loss_list)
                logger.info(
                    "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, loss_avg,
                       args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(save_dir)
                logger.disable()
                tokenizer.save_pretrained(save_dir)
                logger.enable()

                precision, recall, f1 = evaluate(model, metric, dev_data_loader)
                logger.info(
                    "Evaluation precision: %.5f, recall: %.5f, F1: %.5f" %
                    (precision, recall, f1))
                if f1 > best_f1:
                    logger.info(
                        f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
                    )
                    best_f1 = f1
                    save_dir = os.path.join(args.save_dir, "model_best")
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(save_dir)
                    logger.disable()
                    tokenizer.save_pretrained(save_dir)
                    logger.enable()
                tic_train = time.time()


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()

    # todo Path prefix selecting which task's data to use
    # path_prefix = './self-mark'  # self-annotated samples: both entities and relations
    path_prefix = './relation'  # relation extraction [goal: extract entities and relations at once, so the annotation prompts include both entity prompts such as "人名" and relation prompts such as "小米公司的高管"]
    # path_prefix = './ner'  # named entity recognition
    '''
    If the run exits with "Process finished with exit code -1073740791 (0xC0000409)" and no error message,
    the GPU is probably out of memory; try reducing batch_size.
    Here, cutting batch_size from 16 down to 2 made it run fine!
    '''
    # parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
    # todo batch_size: default=8 for entity recognition, default=2 for relation extraction
    parser.add_argument("--batch_size", default=2, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
    # parser.add_argument("--train_path", default='./data/train.txt', type=str, help="The path of train set.")
    parser.add_argument("--train_path", default=path_prefix + '/data/train.txt', type=str, help="The path of train set.")
    # parser.add_argument("--dev_path", default='./data/dev.txt', type=str, help="The path of dev set.")
    parser.add_argument("--dev_path", default=path_prefix + '/data/dev.txt', type=str, help="The path of dev set.")
    # parser.add_argument("--save_dir", default='./checkpoint_gx', type=str, help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--save_dir", default=path_prefix + '/checkpoint', type=str, help="The output directory where the model checkpoints will be written.")
    # todo The default is 512; use a smaller value when debugging so the matrices are easy to inspect, e.g. 30 for entity extraction, 100 for relation extraction.
    parser.add_argument("--max_seq_len", default=100, type=int, help="The maximum input sequence length. "
        "Sequences longer than this will be split automatically.")
    # parser.add_argument("--num_epochs", default=100, type=int, help="Total number of training epochs to perform.")
    # todo num_epochs: default=10 for entity recognition, 10 or 15 for relation extraction
    parser.add_argument("--num_epochs", default=10, type=int, help="Total number of training epochs to perform.")
    parser.add_argument("--seed", default=1000, type=int, help="Random seed for initialization")
    parser.add_argument("--logging_steps", default=10, type=int, help="The interval steps to logging.")
    # parser.add_argument("--valid_steps", default=100, type=int, help="The interval steps to evaluate model performance.")
    # todo valid_steps: default=10 for entity recognition, default=4 for relation extraction.
    #  By default the checkpoint-saving interval equals --valid_steps, i.e. a checkpoint every 10 steps for entity recognition and every 4 steps for relation extraction.
    parser.add_argument("--valid_steps", default=4, type=int, help="The interval steps to evaluate model performance.")
    parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
    parser.add_argument("--model", choices=["uie-base", "uie-tiny", "uie-medium", "uie-mini", "uie-micro", "uie-nano"], default="uie-base", type=str, help="Select the pretrained model for few-shot learning.")
    parser.add_argument("--init_from_ckpt", default=None, type=str, help="The path of model parameters for initialization.")

    args = parser.parse_args()
    # yapf: enable

    do_train()

(2)utils.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import math
import json
import random
from tqdm import tqdm

import numpy as np
import paddle
from paddlenlp.utils.log import logger

MODEL_MAP = {
    # vocab.txt/special_tokens_map.json/tokenizer_config.json are common to the default model.
    "uie-base": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    "uie-medium": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    "uie-mini": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    "uie-micro": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    "uie-nano": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
        }
    },
    # `uie-tiny` is being renamed to `uie-medium`; the name `uie-tiny` will be deprecated in the future.
    "uie-tiny": {
        "resource_file_urls": {
            "model_state.pdparams":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
            "model_config.json":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
            "vocab_file":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
            "special_tokens_map":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
            "tokenizer_config":
            "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
        }
    }
}


def set_seed(seed):
    paddle.seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
    """
    Create dataloader.
    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
        trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
    """
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == "train":
        sampler = paddle.io.DistributedBatchSampler(dataset=dataset,
                                                    batch_size=batch_size,
                                                    shuffle=shuffle)
    else:
        sampler = paddle.io.BatchSampler(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle)
    dataloader = paddle.io.DataLoader(dataset,
                                      batch_sampler=sampler,
                                      return_list=True)
    return dataloader


def convert_example(example, tokenizer, max_seq_len):
    """
    example: {
        title
        prompt
        content
        result_list
    }
    """
    encoded_inputs = tokenizer(text=[example["prompt"]],
                               text_pair=[example["content"]],
                               truncation=True,
                               max_seq_len=max_seq_len,
                               pad_to_max_seq_len=True,
                               return_attention_mask=True,
                               return_position_ids=True,
                               return_dict=False,
                               return_offsets_mapping=True)
    encoded_inputs = encoded_inputs[0]
    # todo For "[CLS]出,发,地[SEP]深,大,到,双,龙,28,块,钱,4,月,24,号,交,通,费" with max_seq_len=30, the corresponding offset_mapping =
    #  [(0, 0), (0, 1), (1, 2), (2, 3),
    #   (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 13), (13, 14), (14, 15), (15, 16), (16, 17),
    #   (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]
    offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
    bias = 0
    for index in range(1, len(offset_mapping)):
        mapping = offset_mapping[index]
        if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
            bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
        if mapping[0] == 0 and mapping[1] == 0:
            continue
        offset_mapping[index][0] += bias
        offset_mapping[index][1] += bias
    start_ids = [0 for x in range(max_seq_len)]
    end_ids = [0 for x in range(max_seq_len)]
    for item in example["result_list"]:
        start = map_offset(item["start"] + bias, offset_mapping)
        end = map_offset(item["end"] - 1 + bias, offset_mapping)
        start_ids[start] = 1.0
        end_ids[end] = 1.0

    tokenized_output = [
        encoded_inputs["input_ids"], encoded_inputs["token_type_ids"],
        encoded_inputs["position_ids"], encoded_inputs["attention_mask"],
        start_ids, end_ids
    ]
    tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output]
    return tuple(tokenized_output)


def map_offset(ori_offset, offset_mapping):
    """
    map ori offset to token offset
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1


def reader(data_path, max_seq_len=512):
    """
    read json
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_line = json.loads(line)
            content = json_line['content'].strip()
            prompt = json_line['prompt']
            # The model input looks like: [CLS] Prompt [SEP] Content [SEP]
            # It includes three special tokens.
            if max_seq_len <= len(prompt) + 3:
                raise ValueError(
                    "The value of max_seq_len is too small, please set a larger value"
                )
            max_content_len = max_seq_len - len(prompt) - 3
            if len(content) <= max_content_len:
                yield json_line
            else:
                result_list = json_line['result_list']
                json_lines = []
                accumulate = 0
                while True:
                    cur_result_list = []

                    for result in result_list:
                        # Warn when a single annotated span is longer than max_content_len.
                        if result['end'] - result['start'] > max_content_len:
                            logger.warn(
                                "result['end'] - result['start'] exceeds max_content_len, which will result in no valid instance being returned"
                            )
                        if result['start'] + 1 <= max_content_len < result[
                                'end'] and result['end'] - result[
                                    'start'] <= max_content_len:
                            max_content_len = result['start']
                            break

                    cur_content = content[:max_content_len]
                    res_content = content[max_content_len:]

                    while True:
                        if len(result_list) == 0:
                            break
                        elif result_list[0]['end'] <= max_content_len:
                            if result_list[0]['end'] > 0:
                                cur_result = result_list.pop(0)
                                cur_result_list.append(cur_result)
                            else:
                                cur_result_list = [
                                    result for result in result_list
                                ]
                                break
                        else:
                            break

                    json_line = {
                        'content': cur_content,
                        'result_list': cur_result_list,
                        'prompt': prompt
                    }
                    json_lines.append(json_line)

                    for result in result_list:
                        if result['end'] <= 0:
                            break
                        result['start'] -= max_content_len
                        result['end'] -= max_content_len
                    accumulate += max_content_len
                    max_content_len = max_seq_len - len(prompt) - 3
                    if len(res_content) == 0:
                        break
                    elif len(res_content) < max_content_len:
                        json_line = {
                            'content': res_content,
                            'result_list': result_list,
                            'prompt': prompt
                        }
                        json_lines.append(json_line)
                        break
                    else:
                        content = res_content

                for json_line in json_lines:
                    yield json_line


def unify_prompt_name(prompt):
    # The classification labels are shuffled during finetuning, so they need
    # to be unified during evaluation.
    if re.search(r'\[.*?\]$', prompt):
        prompt_prefix = prompt[:prompt.find("[", 1)]
        cls_options = re.search(r'\[.*?\]$', prompt).group()[1:-1].split(",")
        cls_options = sorted(list(set(cls_options)))
        cls_options = ",".join(cls_options)
        prompt = prompt_prefix + "[" + cls_options + "]"
        return prompt
    return prompt


def get_relation_type_dict(relation_data):

    def compare(a, b):
        a = a[::-1]
        b = b[::-1]
        res = ''
        for i in range(min(len(a), len(b))):
            if a[i] == b[i]:
                res += a[i]
            else:
                break
        if res == "":
            return res
        elif res[::-1][0] == "的":
            return res[::-1][1:]
        return ""

    relation_type_dict = {}
    added_list = []
    for i in range(len(relation_data)):
        added = False
        if relation_data[i][0] not in added_list:
            for j in range(i + 1, len(relation_data)):
                match = compare(relation_data[i][0], relation_data[j][0])
                if match != "":
                    match = unify_prompt_name(match)
                    if relation_data[i][0] not in added_list:
                        added_list.append(relation_data[i][0])
                        relation_type_dict.setdefault(match, []).append(
                            relation_data[i][1])
                    added_list.append(relation_data[j][0])
                    relation_type_dict.setdefault(match, []).append(
                        relation_data[j][1])
                    added = True
            if not added:
                added_list.append(relation_data[i][0])
                suffix = relation_data[i][0].rsplit("的", 1)[1]
                suffix = unify_prompt_name(suffix)
                relation_type_dict.setdefault(suffix,
                                              []).append(relation_data[i][1])
    return relation_type_dict


def add_entity_negative_example(examples, texts, prompts, label_set,
                                negative_ratio):
    negative_examples = []
    positive_examples = []
    with tqdm(total=len(prompts)) as pbar:
        for i, prompt in enumerate(prompts):
            redundants = list(set(label_set) ^ set(prompt))
            redundants.sort()

            num_positive = len(examples[i])
            if num_positive != 0:
                actual_ratio = math.ceil(len(redundants) / num_positive)
            else:
                # Set num_positive to 1 for text without positive example
                num_positive, actual_ratio = 1, 0

            if actual_ratio <= negative_ratio or negative_ratio == -1:
                idxs = [k for k in range(len(redundants))]
            else:
                idxs = random.sample(range(0, len(redundants)),
                                     negative_ratio * num_positive)

            for idx in idxs:
                negative_result = {
                    "content": texts[i],
                    "result_list": [],
                    "prompt": redundants[idx]
                }
                negative_examples.append(negative_result)
            positive_examples.extend(examples[i])
            pbar.update(1)
    return positive_examples, negative_examples


def add_relation_negative_example(redundants, text, num_positive, ratio):
    added_example = []
    rest_example = []

    if num_positive != 0:
        actual_ratio = math.ceil(len(redundants) / num_positive)
    else:
        # Set num_positive to 1 for text without positive example
        num_positive, actual_ratio = 1, 0

    all_idxs = [k for k in range(len(redundants))]
    if actual_ratio <= ratio or ratio == -1:
        idxs = all_idxs
        rest_idxs = []
    else:
        idxs = random.sample(range(0, len(redundants)), ratio * num_positive)
        rest_idxs = list(set(all_idxs) ^ set(idxs))

    for idx in idxs:
        negative_result = {
            "content": text,
            "result_list": [],
            "prompt": redundants[idx]
        }
        added_example.append(negative_result)

    for rest_idx in rest_idxs:
        negative_result = {
            "content": text,
            "result_list": [],
            "prompt": redundants[rest_idx]
        }
        rest_example.append(negative_result)

    return added_example, rest_example


def add_full_negative_example(examples, texts, relation_prompts, predicate_set,
                              subject_goldens):
    with tqdm(total=len(relation_prompts)) as pbar:
        for i, relation_prompt in enumerate(relation_prompts):
            negative_sample = []
            for subject in subject_goldens[i]:
                for predicate in predicate_set:
                    # The relation prompt is constructed as follows:
                    # subject + "的" + predicate
                    prompt = subject + "的" + predicate
                    if prompt not in relation_prompt:
                        negative_result = {
                            "content": texts[i],
                            "result_list": [],
                            "prompt": prompt
                        }
                        negative_sample.append(negative_result)
            examples[i].extend(negative_sample)
            pbar.update(1)
    return examples


def generate_cls_example(text, labels, prompt_prefix, options):
    random.shuffle(options)
    cls_options = ",".join(options)
    prompt = prompt_prefix + "[" + cls_options + "]"

    result_list = []
    example = {"content": text, "result_list": result_list, "prompt": prompt}
    for label in labels:
        start = prompt.rfind(label) - len(prompt) - 1
        end = start + len(label)
        result = {"text": label, "start": start, "end": end}
        example["result_list"].append(result)
    return example


def convert_cls_examples(raw_examples,
                         prompt_prefix="情感倾向",
                         options=["正向", "负向"]):
    """
    Convert labeled data export from doccano for classification task.
    """
    examples = []
    logger.info(f"Converting doccano data...")
    with tqdm(total=len(raw_examples)) as pbar:
        for line in raw_examples:
            items = json.loads(line)
            # Compatible with doccano >= 1.6.2
            if "data" in items.keys():
                text, labels = items["data"], items["label"]
            else:
                text, labels = items["text"], items["label"]
            example = generate_cls_example(text, labels, prompt_prefix, options)
            examples.append(example)
    return examples


def convert_ext_examples(raw_examples,
                         negative_ratio,
                         prompt_prefix="情感倾向",
                         options=["正向", "负向"],
                         separator="##",
                         is_train=True):
    """
    Convert labeled data export from doccano for relation and aspect-level classification task.
    """

    def _sep_cls_label(label, separator):
        label_list = label.split(separator)
        if len(label_list) == 1:
            return label_list[0], None
        return label_list[0], label_list[1:]

    texts = []
    entity_examples = []
    relation_examples = []
    entity_cls_examples = []
    entity_prompts = []
    relation_prompts = []
    entity_label_set = []
    entity_name_set = []
    predicate_set = []
    subject_goldens = []
    inverse_relation_list = []
    predicate_list = []

    logger.info(f"Converting doccano data...")
    with tqdm(total=len(raw_examples)) as pbar:
        for line in raw_examples:
            items = json.loads(line)
            entity_id = 0
            if "data" in items.keys():
                relation_mode = False
                if isinstance(items["label"],
                              dict) and "entities" in items["label"].keys():
                    relation_mode = True
                text = items["data"]
                entities = []
                relations = []
                if not relation_mode:
                    # Export file in JSONL format which doccano < 1.7.0
                    # e.g. {"data": "", "label": [ [0, 2, "ORG"], ... ]}
                    for item in items["label"]:
                        entity = {
                            "id": entity_id,
                            "start_offset": item[0],
                            "end_offset": item[1],
                            "label": item[2]
                        }
                        entities.append(entity)
                        entity_id += 1
                else:
                    # Export file in JSONL format for relation labeling task which doccano < 1.7.0
                    # e.g. {"data": "", "label": {"relations": [ {"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, ... ], "entities": [ {"id": 0, "from_id": 0, "to_id": 1, "type": "foundedAt"}, ... ]}}
                    entities.extend(
                        [entity for entity in items["label"]["entities"]])
                    if "relations" in items["label"].keys():
                        relations.extend([
                            relation for relation in items["label"]["relations"]
                        ])
            else:
                # Export file in JSONL format which doccano >= 1.7.0
                # e.g. {"text": "", "label": [ [0, 2, "ORG"], ... ]}
                if "label" in items.keys():
                    text = items["text"]
                    entities = []
                    for item in items["label"]:
                        entity = {
                            "id": entity_id,
                            "start_offset": item[0],
                            "end_offset": item[1],
                            "label": item[2]
                        }
                        entities.append(entity)
                        entity_id += 1
                    relations = []
                else:
                    # Export file in JSONL (relation) format
                    # e.g. {"text": "", "relations": [ {"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, ... ], "entities": [ {"id": 0, "from_id": 0, "to_id": 1, "type": "foundedAt"}, ... ]}
                    text, relations, entities = items["text"], items[
                        "relations"], items["entities"]
            texts.append(text)

            entity_example = []
            entity_prompt = []
            entity_example_map = {}
            entity_map = {}  # id to entity name
            for entity in entities:
                entity_name = text[entity["start_offset"]:entity["end_offset"]]
                entity_map[entity["id"]] = {
                    "name": entity_name,
                    "start": entity["start_offset"],
                    "end": entity["end_offset"]
                }

                entity_label, entity_cls_label = _sep_cls_label(
                    entity["label"], separator)

                # Define the prompt prefix for entity-level classification
                entity_cls_prompt_prefix = entity_name + "的" + prompt_prefix
                if entity_cls_label is not None:
                    entity_cls_example = generate_cls_example(
                        text, entity_cls_label, entity_cls_prompt_prefix,
                        options)

                    entity_cls_examples.append(entity_cls_example)

                result = {
                    "text": entity_name,
                    "start": entity["start_offset"],
                    "end": entity["end_offset"]
                }
                if entity_label not in entity_example_map.keys():
                    entity_example_map[entity_label] = {
                        "content": text,
                        "result_list": [result],
                        "prompt": entity_label
                    }
                else:
                    entity_example_map[entity_label]["result_list"].append(
                        result)

                if entity_label not in entity_label_set:
                    entity_label_set.append(entity_label)
                if entity_name not in entity_name_set:
                    entity_name_set.append(entity_name)
                entity_prompt.append(entity_label)

            for v in entity_example_map.values():
                entity_example.append(v)

            entity_examples.append(entity_example)
            entity_prompts.append(entity_prompt)

            subject_golden = []  # Golden entity inputs
            relation_example = []
            relation_prompt = []
            relation_example_map = {}
            inverse_relation = []
            predicates = []
            for relation in relations:
                predicate = relation["type"]
                subject_id = relation["from_id"]
                object_id = relation["to_id"]
                # The relation prompt is constructed as follows:
                # subject + "的" + predicate
                prompt = entity_map[subject_id]["name"] + "的" + predicate
                if entity_map[subject_id]["name"] not in subject_golden:
                    subject_golden.append(entity_map[subject_id]["name"])
                result = {
                    "text": entity_map[object_id]["name"],
                    "start": entity_map[object_id]["start"],
                    "end": entity_map[object_id]["end"]
                }

                inverse_negative = entity_map[object_id][
                    "name"] + "的" + predicate
                inverse_relation.append(inverse_negative)
                predicates.append(predicate)

                if prompt not in relation_example_map.keys():
                    relation_example_map[prompt] = {
                        "content": text,
                        "result_list": [result],
                        "prompt": prompt
                    }
                else:
                    relation_example_map[prompt]["result_list"].append(result)

                if predicate not in predicate_set:
                    predicate_set.append(predicate)
                relation_prompt.append(prompt)

            for v in relation_example_map.values():
                relation_example.append(v)

            relation_examples.append(relation_example)
            relation_prompts.append(relation_prompt)
            subject_goldens.append(subject_golden)
            inverse_relation_list.append(inverse_relation)
            predicate_list.append(predicates)
            pbar.update(1)

    logger.info(f"Adding negative samples for first stage prompt...")
    positive_examples, negative_examples = add_entity_negative_example(
        entity_examples, texts, entity_prompts, entity_label_set,
        negative_ratio)
    if len(positive_examples) == 0:
        all_entity_examples = []
    else:
        all_entity_examples = positive_examples + negative_examples

    all_relation_examples = []
    if len(predicate_set) != 0:
        logger.info(f"Adding negative samples for second stage prompt...")
        if is_train:

            positive_examples = []
            negative_examples = []
            per_n_ratio = negative_ratio // 3

            with tqdm(total=len(texts)) as pbar:
                for i, text in enumerate(texts):
                    negative_example = []
                    collects = []
                    num_positive = len(relation_examples[i])

                    # 1. inverse_relation_list
                    redundants1 = inverse_relation_list[i]

                    # 2. entity_name_set ^ subject_goldens[i]
                    redundants2 = []
                    if len(predicate_list[i]) != 0:
                        nonentity_list = list(
                            set(entity_name_set) ^ set(subject_goldens[i]))
                        nonentity_list.sort()

                        redundants2 = [
                            nonentity + "的" +
                            predicate_list[i][random.randrange(
                                len(predicate_list[i]))]
                            for nonentity in nonentity_list
                        ]

                    # 3. entity_label_set ^ entity_prompts[i]
                    redundants3 = []
                    if len(subject_goldens[i]) != 0:
                        non_ent_label_list = list(
                            set(entity_label_set) ^ set(entity_prompts[i]))
                        non_ent_label_list.sort()

                        redundants3 = [
                            subject_goldens[i][random.randrange(
                                len(subject_goldens[i]))] + "的" + non_ent_label
                            for non_ent_label in non_ent_label_list
                        ]

                    redundants_list = [redundants1, redundants2, redundants3]

                    for redundants in redundants_list:
                        added, rest = add_relation_negative_example(
                            redundants,
                            texts[i],
                            num_positive,
                            per_n_ratio,
                        )
                        negative_example.extend(added)
                        collects.extend(rest)

                    num_sup = num_positive * negative_ratio - len(
                        negative_example)
                    if num_sup > 0 and collects:
                        if num_sup > len(collects):
                            idxs = [k for k in range(len(collects))]
                        else:
                            idxs = random.sample(range(0, len(collects)),
                                                 num_sup)
                        for idx in idxs:
                            negative_example.append(collects[idx])

                    positive_examples.extend(relation_examples[i])
                    negative_examples.extend(negative_example)
                    pbar.update(1)
            all_relation_examples = positive_examples + negative_examples
        else:
            relation_examples = add_full_negative_example(
                relation_examples, texts, relation_prompts, predicate_set,
                subject_goldens)
            all_relation_examples = [
                r for relation_example in relation_examples
                for r in relation_example
            ]
    return all_entity_examples, all_relation_examples, entity_cls_examples
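
To sanity-check the offset logic in convert_example and map_offset above, a toy example (assuming map_offset from this file is in scope; the offset_mapping and bias values are hypothetical, modeled on the "出发地" prompt case in the comments):

# Token layout: [CLS] 出 发 地 [SEP] 深 大 ..., so bias = len("出发地") + 1 = 4.
# The content-token offsets below are already shifted by the bias, as convert_example does.
offset_mapping = [[0, 0], [0, 1], [1, 2], [2, 3], [0, 0],
                  [4, 5], [5, 6], [6, 7]]
bias = 4
char_start_in_content = 0  # the entity "深大" starts at character 0 of the content
token_index = map_offset(char_start_in_content + bias, offset_mapping)
print(token_index)  # -> 5: the first content token, right after [CLS] + prompt + [SEP]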

(3)evaluate.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from functools import partial

import paddle
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.utils.log import logger

from model import UIE
from utils import convert_example, reader, unify_prompt_name, get_relation_type_dict, create_data_loader


@paddle.no_grad()
def evaluate(model, metric, data_loader):
    """
    Given a dataset, it evaluates the model and computes the metric.
    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
    """
    model.eval()
    metric.reset()
    for batch in data_loader:
        input_ids, token_type_ids, att_mask, pos_ids, start_ids, end_ids = batch
        start_prob, end_prob = model(input_ids, token_type_ids, att_mask,
                                     pos_ids)
        start_ids = paddle.cast(start_ids, 'float32')
        end_ids = paddle.cast(end_ids, 'float32')
        # todo Worked example
        # metric = SpanEvaluator()
        # Inside metric.compute(start_prob, end_prob, start_ids, end_ids):
        #  todo step 1: for each sample, pick out the indices whose predicted score exceeds the threshold.
        #   Suppose one sample (max_len=30) has the following "this token is an entity start" scores = [
        #             0.00000031, 0.00000129, 0.00000024, 0.00000031, 0.50000042, 0.00070048,
        #             0.00000877, 0.00002038, 0.00000138, 0.00001625, 0.00000029, 0.00008086,
        #             0.00000202, 0.00000990, 0.00000406, 0.00016828, 0.00001869, 0.00021678,
        #             0.00000937, 0.00046841, 0.00002430, 0.00000023, 0.00006322, 0.00000595,
        #             0.00000095, 0.00000166, 0.00000030, 0.00000033, 0.00000023, 0.00000073]
        #   The only index with a score above 0.5 is [4], i.e. index 4 of this sample is predicted to start an entity.
        #   Doing the same for every sample in the batch yields each sample's pred_start and pred_end. Batched:
        # pred_start = [[4],  [],  [14],  [4],  [8],   [],  [5, 8], [5, 8]]  # samples 7 and 8 each have 2 positions whose score > 0.5.
        # pred_end   = [[7],  [],   [],   [7],  [9],   [],   [6],   [6, 9]]
        #
        # gold_start = [[4], [14], [14],  [4],  [8],  [13],   [5],    [8]]
        # gold_end   = [[7], [14], [14],  [7],  [8],  [14],   [6],    [9]]

        #  todo step 2: pair (pred_start, pred_end) and (gold_start, gold_end) along the sequence, giving
        # pred: (4,7),   (),     (),  (4,7),(8,9),   (),  (5,6), ((5,6),(8,9))
        # gold: (4,7),(14,14),(14,14),(4,7),(8,8),(13,14),(5,6),     (8,9)
        #
        #  todo step 3: AND (intersect) each pred/gold pair, giving:
        # result: 1,     0,      0,     1,    1,     0,     1,         1
        # pred: (4,7),   (),     (),  (4,7),(8,9),   (),  (5,6), ((5,6),(8,9))
        # gold: (4,7),(14,14),(14,14),(4,7),(8,8),(13,14),(5,6),     (8,9)

        # todo step 4: count the predicted spans, gold spans, and correct spans:
        # correct: 1,    0,      0,     1,    1,     0,     1,         1           --num_correct=5
        # pred: (4,7),   (),     (),  (4,7),(8,9),   (),  (5,6), ((5,6),(8,9))     --num_infer=6
        # gold: (4,7),(14,14),(14,14),(4,7),(8,8),(13,14),(5,6),     (8,9)         --num_label=8

        # todo step 5: compute precision, recall, and F1:
        # precision = num_correct / num_infer
        # recall = num_correct / num_label
        # F1 = 2 * precision * recall / (precision + recall)
        num_correct, num_infer, num_label = metric.compute(
            start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    model.train()
    return precision, recall, f1


def do_eval():
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = UIE.from_pretrained(args.model_path)

    test_ds = load_dataset(reader,
                           data_path=args.test_path,
                           max_seq_len=args.max_seq_len,
                           lazy=False)
    class_dict = {}
    relation_data = []
    if args.debug:
        for data in test_ds:
            class_name = unify_prompt_name(data['prompt'])
            # Only positive examples are evaluated in debug mode
            if len(data['result_list']) != 0:
                if "的" not in data['prompt']:
                    class_dict.setdefault(class_name, []).append(data)
                else:
                    relation_data.append((data['prompt'], data))
        relation_type_dict = get_relation_type_dict(relation_data)
    else:
        class_dict["all_classes"] = test_ds

    trans_fn = partial(convert_example,
                       tokenizer=tokenizer,
                       max_seq_len=args.max_seq_len)

    for key in class_dict.keys():
        if args.debug:
            test_ds = MapDataset(class_dict[key])
        else:
            test_ds = class_dict[key]

        test_data_loader = create_data_loader(test_ds,
                                              mode="test.txt",
                                              batch_size=args.batch_size,
                                              trans_fn=trans_fn)

        metric = SpanEvaluator()
        precision, recall, f1 = evaluate(model, metric, test_data_loader)
        logger.info("-----------------------------")
        logger.info("Class Name: %s" % key)
        logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
                    (precision, recall, f1))

    if args.debug and len(relation_type_dict.keys()) != 0:
        for key in relation_type_dict.keys():
            test_ds = MapDataset(relation_type_dict[key])

            test_data_loader = create_data_loader(test_ds,
                                                  mode="test.txt",
                                                  batch_size=args.batch_size,
                                                  trans_fn=trans_fn)

            metric = SpanEvaluator()
            precision, recall, f1 = evaluate(model, metric, test_data_loader)
            logger.info("-----------------------------")
            logger.info("Class Name: X的%s" % key)
            logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
                        (precision, recall, f1))


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.")
    parser.add_argument("--test_path", type=str, default=None, help="The path of test.txt set.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU/CPU for training.")
    # todo The default is 512; use a smaller value when debugging so the matrices are easy to inspect, e.g. 30.
    # parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--max_seq_len", type=int, default=30, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--debug", action='store_true', help="Precision, recall and F1 score are calculated for each class separately if this option is enabled.")

    args = parser.parse_args()
    # yapf: enable

    do_eval()
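
Plugging the counts from the worked example inside evaluate() (num_correct=5, num_infer=6, num_label=8) into step 5 gives:

# Step-5 arithmetic from the SpanEvaluator walk-through above.
num_correct, num_infer, num_label = 5, 6, 8
precision = num_correct / num_infer                     # 0.8333...
recall = num_correct / num_label                        # 0.625
f1 = 2 * precision * recall / (precision + recall)      # 0.7142...
print(precision, recall, f1)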
