知识图谱:【知识图谱问答KBQA(六)】——P-tuning V2训练代码解析

文章目录

    • 一.arguments.py
      • DataTrainingArguments类
      • ModelArguments类
      • QuestionAnwseringArguments类
      • get_args()函数
    • 二.run.py
      • Step 1. 获取所有参数
      • Step 2. 根据任务名称选择导入对应的get_trainer
      • Step 3. 将参数args传入get_trainer,得到trainer
        • 1)根据模型名称或路径加载tokenizer
        • 2)根据tokenizer和参数data_args、training_args加载数据集dataset
        • 3)根据模型名称或路径、dataset加载模型配置config
        • 4)根据模型参数和模型配置加载模型(get_model)
        • 5)根据model、训练参数、tokenizer以及dataset初始化并返回trainer
      • Step 4. 模型训练、验证及测试

一.arguments.py

DataTrainingArguments类

关于我们将输入模型进行训练和评估的数据参数

  • task_name.任务名称
  • dataset_name.数据集名称
  • dataset_config_name.要使用的数据集的配置名称
  • max_seq_length.标记化(tokenization)后的最大总输入序列长度。 比这长的序列将被截断,短的序列将被填充。
  • overwrite_cache.是否覆盖缓存的预处理数据集
  • pad_to_max_length.是否将所有样本填充到 max_seq_length。 如果为 False,将在批处理时动态填充样本到批处理中的最大长度
  • max_train_samples.出于调试目的或更快的训练,将训练示例的数量截断为该值(如果已设置)
  • max_eval_samples.出于调试目的或更快的训练,将验证示例的数量截断为该值(如果已设置)
  • max_predict_samples.出于调试目的或更快的训练,将测试示例的数量截断为该值(如果已设置)
  • train_file.包含训练数据的 csv 或 json 文件
  • validation_file.包含验证数据的 csv 或 json 文件
  • test_file.包含测试数据的 csv 或 json 文件
  • template_id.要使用的特定prompt字符串

ModelArguments类

关于我们将从哪个模型/配置/标记器进行微调的参数

  • model_name_or_path.从 huggingface.co/models 下载预训练模型的路径或模型标识符
  • config_name.如果与 model_name 不同,则需指定预训练的配置名称或路径
  • tokenizer_name.如果与 model_name 不同,则需指定预训练的标记器名称或路径
  • cache_dir.用于存储从 huggingface.co 下载的预训练模型的路径
  • use_fast_tokenizer.是否使用快速分词器之一(由分词器库支持)
  • model_revision.要使用的特定模型版本(可以是分支名称、标签名称或提交 ID)
  • use_auth_token.是否使用模型加密
  • prefix.训练时使用P-Tuning V2
  • prompt.训练时使用P-Tuning
  • pre_seq_len.prompt的长度
  • prefix_projection.在前缀嵌入上应用两层 MLP 头
  • prefix_hidden_size.如果使用前缀投影,则前缀编码器中 MLP 投影头的隐藏层大小
  • hidden_dropout_prob.dropout比例

QuestionAnwseringArguments类

  • n_best_size.寻找答案时生成的 n 最佳预测的总数
  • max_answer_length.可以生成的答案的最大长度
  • version_2_with_negative.如果为真,有些例子没有答案
  • null_score_diff_threshold.用于选择空答案的阈值:如果最佳答案的分数小于空答案的分数减去此阈值,则本示例选择空答案。 仅在 version_2_with_negative=True 时有用

get_args()函数

用于解析P-Tuning V2中的所有参数。

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, QuestionAnwseringArguments))
args = parser.parse_args_into_dataclasses()
return args

二.run.py

Step 1. 获取所有参数

args = get_args()
_, data_args, training_args, _ = args

Step 2. 根据任务名称选择导入对应的get_trainer

Step 3. 将参数args传入get_trainer,得到trainer

1)根据模型名称或路径加载tokenizer

tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
    )

2)根据tokenizer和参数data_args、training_args加载数据集dataset

3)根据模型名称或路径、dataset加载模型配置config

 config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=dataset.num_labels,
            label2id=dataset.label2id,
            id2label=dataset.id2label,
            finetuning_task=data_args.dataset_name,
            revision=model_args.model_revision,
        )

4)根据模型参数和模型配置加载模型(get_model)

通过模型参数可以选择三种不同的训练方式:

  • 训练方式1:P-Tuning V2(prefix=True)
    if model_args.prefix:
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size
        
        model_class = PREFIX_MODELS[config.model_type][task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )
  • 训练方式2:P-Tuning(prefix=False && prompt=True)
    elif model_args.prompt:
        config.pre_seq_len = model_args.pre_seq_len
        model_class = PROMPT_MODELS[config.model_type][task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )
  • 训练方式3:fine-tuning(prefix=False && prompt=False)
    else:
        model_class = AUTO_MODELS[task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )

        bert_param = 0
        if fix_bert:
            if config.model_type == "bert":
                for param in model.bert.parameters():
                    param.requires_grad = False
                for _, param in model.bert.named_parameters():
                    bert_param += param.numel()
            elif config.model_type == "roberta":
                for param in model.roberta.parameters():
                    param.requires_grad = False
                for _, param in model.roberta.named_parameters():
                    bert_param += param.numel()
            elif config.model_type == "deberta":
                for param in model.deberta.parameters():
                    param.requires_grad = False
                for _, param in model.deberta.named_parameters():
                    bert_param += param.numel()
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('***** total param is {} *****'.format(total_param))

5)根据model、训练参数、tokenizer以及dataset初始化并返回trainer

# Initialize our Trainer
    trainer = BaseTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset.train_dataset if training_args.do_train else None,
        eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
        compute_metrics=dataset.compute_metrics,
        tokenizer=tokenizer,
        data_collator=dataset.data_collator,
        test_key=dataset.test_key
    )


    return trainer, None

Step 4. 模型训练、验证及测试

    if training_args.do_train:
        train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
    
    if training_args.do_eval:
        evaluate(trainer)

    if training_args.do_predict:
        predict(trainer, predict_dataset)

你可能感兴趣的:(python,自然语言处理,深度学习,知识图谱,人工智能,python)