Reproducing the BART Finetuning Process

Preparation

  • Install fairseq; we use the official finetuning code provided by fairseq

    git clone https://github.com/pytorch/fairseq
    cd fairseq
    pip install --editable ./
    
  • Download the Xsum and DailyCNN (CNN/DailyMail) datasets, already processed into the train.source / train.target format, and extract them to /home/DataSets/Xsum and /home/DataSets/DailyCNN (the expected layout is sketched below the link)

    https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md
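
    The extracted directories are expected to look roughly like the following (based on the split/extension naming used throughout this post; exact contents may differ):

    ls /home/DataSets/Xsum
    # train.source  train.target  val.source  val.target  test.source  test.target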
    
  • Download the officially released bart_large model and extract it to /home/LM/bart_large

    https://github.com/pytorch/fairseq/tree/master/examples/bart
    
  • Install files2rouge so that ROUGE is computed the same way as in the paper

    git clone https://github.com/pltrdy/files2rouge.git     
    cd files2rouge
    python setup_rouge.py
    python setup.py install
    

    Before installing on a Linux system, modify line 29 of setup.py:

        install_requires=[
            "pyrouge==0.1.3"
        ],
    

    If the following error appears at runtime after installation:

    TypeError: __init__() got an unexpected keyword argument 'log_level'
    

    then install pyrouge manually with:

    pip install -U git+https://github.com/pltrdy/pyrouge
    

Preprocessing

Data preprocessing

  • BPE tokenization, using the vocabulary of the bart_large model (a quick sanity check on the encoded output follows the script)

    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
    
    TASK=Xsum
    for SPLIT in train val
    do
      for LANG in source target
      do
        python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json encoder.json \
        --vocab-bpe vocab.bpe \
        --inputs "$TASK/$SPLIT.$LANG" \
        --outputs "$TASK/$SPLIT.bpe.$LANG" \
        --workers 60 \
        --keep-empty;
      done
    done
    

    The encoder.json, vocab.bpe, and dict.txt downloaded above are the GPT-2 BPE files and dictionary used by bart_large.
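
    To verify the encoding step, each line of the *.bpe.* files should be a sequence of space-separated integer token ids rather than raw text:

    head -n 1 Xsum/train.bpe.source   # should print space-separated BPE token ids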

  • Binarize the dataset. MODE can be switched to use the dictionary of either the bart_base or the bart_large model.

    MODE=large
    TASK=Xsum
    fairseq-preprocess \
      --source-lang "source" \
      --target-lang "target" \
      --trainpref "${TASK}/train.bpe" \
      --validpref "${TASK}/val.bpe" \
      --destdir "${TASK}-${MODE}-bin/" \
      --workers 60 \
      --srcdict dict.txt \
      --tgtdict dict.txt;
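
    If preprocessing succeeds, the destination directory should contain files following the standard fairseq-preprocess naming, roughly as listed below (given as a reference, not verified against this exact run):

    ls Xsum-large-bin/
    # dict.source.txt  dict.target.txt  preprocess.log
    # train.source-target.source.bin   train.source-target.source.idx
    # train.source-target.target.bin   train.source-target.target.idx
    # valid.source-target.source.bin   valid.source-target.source.idx
    # valid.source-target.target.bin   valid.source-target.target.idx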
    

Training

  • The finetuning setup below makes minor adjustments to the officially provided version, but the hyperparameter settings are kept identical. If training on a single GPU, multiply TOTAL_NUM_UPDATES, WARMUP_UPDATES, and UPDATE_FREQ by 8 each (a single-GPU sketch follows the script).

    When finetuning on DailyCNN with a single 16GB Tesla V100, MAX_TOKENS can only be set to 1024, otherwise training runs out of GPU memory.

    #!/bin/bash
    export PYTHONUNBUFFERED=1 
    
    MODE=large
    TASK=/home/DataSets/Xsum
    TOTAL_NUM_UPDATES=15000
    WARMUP_UPDATES=500   
    LR=3e-05
    MAX_TOKENS=2048
    UPDATE_FREQ=2
    BART_PATH=/home/LM/bart_${MODE}/model.pt
    
    
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train $TASK-${MODE}-bin \
        --restore-file $BART_PATH \
        --reset-optimizer --reset-dataloader --reset-meters \
        --save-dir Xsum_checkpoints_${MODE} \
        --max-tokens $MAX_TOKENS \
        --task translation \
        --source-lang source --target-lang target \
        --truncate-source \
        --layernorm-embedding \
        --share-all-embeddings \
        --share-decoder-input-output-embed \
        --required-batch-size-multiple 1 \
        --arch bart_${MODE} \
        --criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 \
        --dropout 0.1 --attention-dropout 0.1 \
        --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
        --clip-norm 0.1 \
        --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
        --fp16 --update-freq $UPDATE_FREQ \
        --skip-invalid-size-inputs-valid-test \
        --no-epoch-checkpoints \
        --find-unused-parameters;
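
    A single-GPU variant only needs the schedule scaled by 8, per the note above (these values follow that note and are not separately verified); all other fairseq-train flags stay as in the script above:

    # Single-GPU sketch: scale the update schedule by 8x
    export CUDA_VISIBLE_DEVICES=0
    TOTAL_NUM_UPDATES=120000   # 15000 * 8
    WARMUP_UPDATES=4000        # 500 * 8
    UPDATE_FREQ=16             # 2 * 8
    MAX_TOKENS=2048            # use 1024 for DailyCNN on a 16GB V100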
    
    

Evaluation

Taking Xsum as an example, single-GPU training reaches the best checkpoint after roughly 4 epochs.

Run inference with the following script; it writes the generated summaries under ./results/Xsum/large:

import torch
from fairseq.models.bart import BARTModel
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default="Xsum", type=str, required=False, help='dataset name, e.g. Xsum or DailyCNN')
parser.add_argument('--model_size', default="large", type=str, required=False, help='BART model size: base or large')
parser.add_argument('--file_prefix', default="large_model", type=str, required=False, help='prefix of the output hypothesis file')
args = parser.parse_args()


result_path = 'results/'+args.dataset+'/'+args.model_size
os.makedirs(result_path, exist_ok=True)  # make sure the output directory exists
checkpoint_path = './'+args.dataset+'_checkpoints_'+args.model_size
#checkpoint_path = '/home/LM/bart-large-xsum' # use the officially finetuned checkpoint for re-verification
print(args.dataset)
print(args.model_size)
print(result_path)
'''
bart = BARTModel.from_pretrained(
    '/home/LM/bart_'+args.model_size+'_pt',
    checkpoint_file='model.pt',
    data_name_or_path=result_path,
    task='translation',
    source_lang = "source",
    target_lang = "target",
)
'''
bart = BARTModel.from_pretrained(
    checkpoint_path,
    checkpoint_file='checkpoint_best.pt',
    #checkpoint_file = 'model.pt', # load the officially finetuned checkpoint for re-verification
    data_name_or_path=result_path,
    #task='translation',
    #source_lang = "source",
    #target_lang = "target",
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('/home/DataSets/'+args.dataset+'/test.source',encoding="utf8") as source, open(result_path+'/'+args.file_prefix+'_test.hypo', 'w',encoding = "utf8") as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            with torch.no_grad():
                #print(slines,"\n\n\n\n\n\n")
                hypotheses_batch = bart.sample(slines, beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
                #print(hypotheses_batch)
                #exit()
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
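
Assuming the script above is saved as eval_summarization.py (the file name is only illustrative), a typical invocation is:

$ python eval_summarization.py --dataset Xsum --model_size large --file_prefix large_model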

For the DailyCNN data, use

beam=4, lenpen=2.0, max_len_b=140, min_len=55
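
That is, only the generation parameters of the bart.sample call change; a minimal sketch, keeping no_repeat_ngram_size=3 from the Xsum call above:

hypotheses_batch = bart.sample(
    slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55,
    no_repeat_ngram_size=3,
)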

Afterwards, use files2rouge to compute ROUGE scores and compare the Average_F values:

$ sudo files2rouge xsum_test.target xsum_my_finetuned_model.hypo 
---------------------------------------------
1 ROUGE-1 Average_R: 0.49389 (95%-conf.int. 0.49098 - 0.49668)
1 ROUGE-1 Average_P: 0.39964 (95%-conf.int. 0.39682 - 0.40230)
1 ROUGE-1 Average_F: 0.43521 (95%-conf.int. 0.43248 - 0.43775)
---------------------------------------------
1 ROUGE-2 Average_R: 0.23021 (95%-conf.int. 0.22718 - 0.23331)
1 ROUGE-2 Average_P: 0.18480 (95%-conf.int. 0.18230 - 0.18739)
1 ROUGE-2 Average_F: 0.20179 (95%-conf.int. 0.19915 - 0.20462)
---------------------------------------------
1 ROUGE-L Average_R: 0.38980 (95%-conf.int. 0.38671 - 0.39278)
1 ROUGE-L Average_P: 0.31509 (95%-conf.int. 0.31253 - 0.31770)
1 ROUGE-L Average_F: 0.34325 (95%-conf.int. 0.34069 - 0.34584)


$ sudo files2rouge xsum_test.target using_offical_model_test.hypo
---------------------------------------------
1 ROUGE-1 Average_R: 0.49030 (95%-conf.int. 0.48742 - 0.49314)
1 ROUGE-1 Average_P: 0.41509 (95%-conf.int. 0.41240 - 0.41798)
1 ROUGE-1 Average_F: 0.44299 (95%-conf.int. 0.44036 - 0.44571)
---------------------------------------------
1 ROUGE-2 Average_R: 0.23270 (95%-conf.int. 0.22979 - 0.23559)
1 ROUGE-2 Average_P: 0.19613 (95%-conf.int. 0.19347 - 0.19868)
1 ROUGE-2 Average_F: 0.20964 (95%-conf.int. 0.20691 - 0.21222)
---------------------------------------------
1 ROUGE-L Average_R: 0.38979 (95%-conf.int. 0.38700 - 0.39267)
1 ROUGE-L Average_P: 0.32982 (95%-conf.int. 0.32691 - 0.33261)
1 ROUGE-L Average_F: 0.35206 (95%-conf.int. 0.34932 - 0.35488)

