Pytorch框架下的transformers的使用

huggingface团队在pytorch框架下开发了transformers工具包:https://github.com/huggingface/transformers,工具包实现了大量基于transformer的模型,如albert,bert,roberta等。工具包的代码结构如图所示:

Pytorch框架下的transformers的使用_第1张图片

其中比较重要的是src/transformers以及example这两个文件夹。其中,src/transformers文件夹下是各类transformer模型的实现代码;而examples下主要是各类下游任务的微调代码。我们以文本分类任务为例来说明微调过程具体是如何实现的,在官方的例子中,使用GLUE数据集。

一、run_glue.sh文件解析

按照官方文档的指引,首先需要构建用于启动微调程序的脚本文件,脚本为微调程序提供参数。

export GLUE_DIR=/path/to/glue
export TASK_NAME=MRPC

python ./examples/text-classification/run_glue.py \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_device_eval_batch_size=8   \
    --per_device_train_batch_size=8   \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir /tmp/$TASK_NAME/

其中几个主要参数的意义如下:

  • model_name_or_path:用于指定进行微调的预训练模型。参数可以是模型名称,在第一次执行微调程序时,会自动下载对应的模型;参数也可以是模型路径,此时需要提前下载对应的模型到设定的路径中。
  • task_name:用于指定具体的下游任务,微调程序需要根据任务名称选择相应的processor以实现数据加载。
  • data_dir:用于指定微调数据的存储路径。
  • output_dir:用于指定微调好的模型的存放路径

二、run_glue.py文件解析

启动脚本会调用run_glue.py文件来执行微调程序。程序主要有三部分功能:加载模型,加载数据,进行微调(训练,验证,预测)。

1、加载预训练模型

(1)加载用于构建模型以及用于微调过程的参数

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

其中,类ModelArguments中包含的是关于模型的属性,如model_name,config_name,tokenizer_name等,类在run.py文件中定义;类DataTrainingArguments中包含的是关于微调数据的属性,如task_name,data_dir等,类在transformers/data/datasets/glue.py文件中定义;TrainingArguments中包含的是关于微调过程的参数,如batch_size,learning_rate等参数,类在transformers/training_args.py中定义。

(2)生成model,config,tokenizer

其中,config用于加载配置信息,model根据config加载模型,tokenize用于在加载数据时提供编码信息。

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
)

2、加载数据

需要使用GlueDataset类构建数据,类定义在transformers/data/datasets/glue.py中,是对Dataset类的继承。

train_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
)
eval_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
    if training_args.do_eval
    else None
)
test_dataset = (
    GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
    if training_args.do_predict
    else None
)

在GlueDataset类中,需要利用glue_processors类来加载数据内容。glue_processors类定义在transformers/data/processors/glue.py中。

self.processor = glue_processors[args.task_name]()

if mode == Split.dev:
    examples = self.processor.get_dev_examples(args.data_dir)
elif mode == Split.test:
    examples = self.processor.get_test_examples(args.data_dir)
else:
    examples = self.processor.get_train_examples(args.data_dir)

3、微调(训练,验证,预测)

(1)构建训练器

训练器Trainer类:主要用于指定使用的模型,数据,微调过程所用参数的信息。类中包含用于训练,验证,预测的方法:trainer.train(train_dataset),trainer.evaluate(eval_dataset),trainer.predicate(test_dataset)。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=build_compute_metrics_fn(data_args.task_name),
)

(2)进行微调(训练,验证,预测)

三、如何定义自己的微调方法

有时候,我们的数据可能与官方所用的数据形式不同,这时候需要对方法进行重写以定义自己的微调方法,重写的内容主要包括:

  1. 重写dataset类
  2. 重写processor类

所有用到的参数都以属性的形式存在于ModelArguments,DataTrainingArguments,TrainingArguments这三个类中,若要改变某个参数,只需要在启动脚本中设置即可。

你可能感兴趣的:(NLP,pytorch,自然语言处理)