Fairseq框架学习(三)Fairseq 模型

本节主要介绍Fairseq模型以及如何自定义模型。Fairseq中的模型一般是在fairseq/models中,包括常用的Transformer、BART等等,本文以BART为例讲解Fairseq模型的使用以及自定义方法。

BART模型

Fairseq中,BART包括两个文件,分别是model.py(模型内部结构)以及hub_interface.py(读入预训练参数)。

model.py

  1. Fairseq需要注册模型及模型框架,这样在训练时可以通过--arch bart_large识别到该模型及其初始化参数设置
from fairseq.models import register_model, register_model_architecture

@register_model("bart")

@register_model_architecture("bart", "bart_large")

自定义模型时同样需要注册模型以及模型框架,如果是在新的文件夹中写模型的话,需要添加__init__.py文件,使得程序能够识别到自定义模型

import importlib
import os

for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module("xxx.models." + model_name)
  1. BARTModel继承TransformerModel,在init时运用BERT的参数随机初始化
from fairseq.models.transformer import TransformerModel
class BARTModel(TransformerModel):
   def __init__(self, args, encoder, decoder):
    super().__init__(args, encoder, decoder)
    self.apply(init_bert_params)
  1. upgrade_state_dict_named读入已经训练好的预训练参数时,比如bart.large,需要对数据进行一个处理。因为原本在训练BART时是通过预测mask内容训练embedding的,我们在finetune时是不需要mask标识的,所以这里要去除掉最后添加的mask标识
if (
    loaded_dict_size == len(self.encoder.dictionary) + 1
    and "" not in self.encoder.dictionary
):
  truncate_emb("encoder.embed_tokens.weight")
  truncate_emb("decoder.embed_tokens.weight")
  truncate_emb("encoder.output_projection.weight")
  truncate_emb("decoder.output_projection.weight")

这里同步去掉output_projection.weight最后一行,是因为weight tying,令pre-softmax的权重等于embedding层的权重。

自定义模型时若添加了其他标识,在读入预训练模型时embedding层要做好对齐。

hub_interface.py

  1. encode在前后端添加以及标识,将文本通过dictionary转变为模型可识别的数字
  2. decode去掉,并将数字通过dictionary变为文字
  3. generatebeam_size进行beam search,测试时生成最终文本

下一篇(四)Fairseq任务,主要介绍fairseq自定义任务。

你可能感兴趣的:(Fairseq框架学习(三)Fairseq 模型)