Chinese Long-Document Summarization with Longformer

1 Longformer

I previously built Chinese summarization with BART, but the project actually needs to summarize long Chinese documents, so here I use Longformer instead (in practice LED, which adds a decoder on top of Longformer). With 11 GB of GPU memory the input length can reach 8K tokens, which is very convenient. On short texts it does not quite match BART, but that is not the point here.

1.1 Longformer/LED architecture

LED has essentially the same encoder-decoder structure as BART, except that its encoder adds global attention on top of sliding-window local attention. There is no Chinese pretrained LED model, but we do have a Chinese BART, and this post provides a script that converts BART weights into LED weights, so we initialize LED from the BART checkpoint.

2 Initializing LED from BART weights

2.1 Set the input and output lengths

import copy
import logging

from transformers import LEDConfig, LEDForConditionalGeneration, BertTokenizer
from transformers import BartForConditionalGeneration

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

max_encoder_position_embeddings = 8192  # maximum input length
max_decoder_position_embeddings = 1024  # maximum decoder length, i.e. the maximum generation length

2.2 Configure LED

Here we build the LED config with the same settings as bart-base-chinese and use it to instantiate the LED model.

led_config = LEDConfig(vocab_size=21128,
        max_encoder_position_embeddings=max_encoder_position_embeddings,
        max_decoder_position_embeddings=max_decoder_position_embeddings,
        encoder_layers=6,
        encoder_ffn_dim=3072,
        encoder_attention_heads=12,
        decoder_layers=6,
        decoder_ffn_dim=3072,
        decoder_attention_heads=12,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=768,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=102,
        classifier_dropout=0.0,
        pad_token_id=0,
        bos_token_id=101,
        eos_token_id=102,
        attention_window= 512,)
led_model = LEDForConditionalGeneration(led_config)

2.3 Load the BART model and tokenizer

bart_model = BartForConditionalGeneration.from_pretrained(r'E:\Project\NLP\long-document\bart-base')
tokenizer = BertTokenizer.from_pretrained(r'E:\Project\NLP\long-document\bart-base')

2.4 Copy the BART weights into LED

current_max_pos, embed_size = bart_model.model.encoder.embed_positions.weight.shape
new_encoder_pos_embed = bart_model.model.encoder.embed_positions.weight.new_empty(max_encoder_position_embeddings, embed_size)

# BART's learned position embeddings have a 2-position offset, so the usable
# embeddings start at index 2; tile them until the new 8192-long table is filled.
k = 0
step = current_max_pos - 2

encoder_position_embeddings = bart_model.model.encoder.embed_positions.weight[2:]
while k < max_encoder_position_embeddings:
    n = min(step, max_encoder_position_embeddings - k)
    new_encoder_pos_embed[k:k + n] = encoder_position_embeddings[:n]
    k += n
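
The loop above only fills the extended encoder position-embedding table. Below is a minimal sketch of the remaining copy steps: writing the new table into the LED model and copying the rest of the BART weights over. The submodule names (led_model.led.encoder.embed_positions, self_attn.longformer_self_attn.query, q_proj, and so on) are assumptions based on the Hugging Face implementations of LED and BART, so verify them against your installed transformers version before relying on this.

# Sketch of the remaining copy steps (module names assume the Hugging Face
# implementations of LED and BART; check them against your transformers version).

# tile the decoder position embeddings the same way (1024 = 2 x 512)
new_decoder_pos_embed = bart_model.model.decoder.embed_positions.weight.new_empty(max_decoder_position_embeddings, embed_size)
decoder_position_embeddings = bart_model.model.decoder.embed_positions.weight[2:]
k = 0
while k < max_decoder_position_embeddings:
    n = min(step, max_decoder_position_embeddings - k)
    new_decoder_pos_embed[k:k + n] = decoder_position_embeddings[:n]
    k += n

# write the extended position embeddings into LED
led_model.led.encoder.embed_positions.weight.data = new_encoder_pos_embed
led_model.led.decoder.embed_positions.weight.data = new_decoder_pos_embed

# token embeddings and LM head
led_model.led.shared.weight.data = bart_model.model.shared.weight.data
led_model.lm_head.weight.data = bart_model.lm_head.weight.data

# decoder layers: parameter names match BART's decoder layers one to one
for bart_layer, led_layer in zip(bart_model.model.decoder.layers, led_model.led.decoder.layers):
    led_layer.load_state_dict(bart_layer.state_dict())

# encoder layers: LED wraps self-attention in a Longformer module, so the
# q/k/v/output projections are copied into both the local and the global heads
for bart_layer, led_layer in zip(bart_model.model.encoder.layers, led_model.led.encoder.layers):
    attn = led_layer.self_attn.longformer_self_attn
    attn.query.load_state_dict(bart_layer.self_attn.q_proj.state_dict())
    attn.key.load_state_dict(bart_layer.self_attn.k_proj.state_dict())
    attn.value.load_state_dict(bart_layer.self_attn.v_proj.state_dict())
    attn.query_global.load_state_dict(bart_layer.self_attn.q_proj.state_dict())
    attn.key_global.load_state_dict(bart_layer.self_attn.k_proj.state_dict())
    attn.value_global.load_state_dict(bart_layer.self_attn.v_proj.state_dict())
    led_layer.self_attn.output.load_state_dict(bart_layer.self_attn.out_proj.state_dict())
    led_layer.self_attn_layer_norm.load_state_dict(bart_layer.self_attn_layer_norm.state_dict())
    led_layer.fc1.load_state_dict(bart_layer.fc1.state_dict())
    led_layer.fc2.load_state_dict(bart_layer.fc2.state_dict())
    led_layer.final_layer_norm.load_state_dict(bart_layer.final_layer_norm.state_dict())

# embedding layer norms
led_model.led.encoder.layernorm_embedding.load_state_dict(bart_model.model.encoder.layernorm_embedding.state_dict())
led_model.led.decoder.layernorm_embedding.load_state_dict(bart_model.model.decoder.layernorm_embedding.state_dict())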

2.5 Save the LED weights

logger.info("convert bart-chinese to led success")
led_model.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
tokenizer.save_pretrained(r'E:\Project\NLP\long-document\converted_model')

3 Complete conversion script

import copy
import logging

from transformers import LEDConfig, LEDForConditionalGeneration, BertTokenizer
from transformers import BartForConditionalGeneration

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

max_encoder_position_embeddings = 8192
max_decoder_position_embeddings = 1024

led_config = LEDConfig(vocab_size=21128,
        max_encoder_position_embeddings=max_encoder_position_embeddings,
        max_decoder_position_embeddings=max_decoder_position_embeddings,
        encoder_layers=6,
        encoder_ffn_dim=3072,
        encoder_attention_heads=12,
        decoder_layers=6,
        decoder_ffn_dim=3072,
        decoder_attention_heads=12,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=768,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=102,
        classifier_dropout=0.0,
        pad_token_id=0,
        bos_token_id=101,
        eos_token_id=102,
        attention_window= 512,)
led_model = LEDForConditionalGeneration(led_config)
bart_model = BartForConditionalGeneration.from_pretrained(r'E:\Project\NLP\long-document\bart-base')
tokenizer = BertTokenizer.from_pretrained(r'E:\Project\NLP\long-document\bart-base')

current_max_pos, embed_size = bart_model.model.encoder.embed_positions.weight.shape
new_encoder_pos_embed = bart_model.model.encoder.embed_positions.weight.new_empty(max_encoder_position_embeddings, embed_size)

# tile BART's position embeddings (skipping the 2-position offset) into the new 8192-long table
k = 0
step = current_max_pos - 2
# new_encoder_pos_embed[0]=bart_model.model.encoder.embed_positions.weight[0]
encoder_position_embeddings = bart_model.model.encoder.embed_positions.weight[2:]
while k < max_encoder_position_embeddings:
    n = min(step, max_encoder_position_embeddings - k)
    new_encoder_pos_embed[k:k + n] = encoder_position_embeddings[:n]
    k += n

# ... copy new_encoder_pos_embed and the remaining BART weights into led_model
# (see the sketch after section 2.4), then save:
logger.info("convert bart-chinese to led success")
led_model.save_pretrained(r'E:\Project\NLP\long-document\converted_model')
tokenizer.save_pretrained(r'E:\Project\NLP\long-document\converted_model')

That's it. Next, just train on your own short- or long-document summarization dataset. I have tested this myself: it works, and the results are actually decent. The next post will cover loading BART weights into a BigBird model as initialization, so we can get a Chinese BigBird as well. One caveat: fine-tuning the converted model directly only gets you so far; for better results you should continue pretraining first.
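
For reference, the fine-tuning script below reads the data with load_dataset('json', ..., field='data') and maps content/title to document/summary, so it expects a JSON file shaped roughly like this (the layout is inferred from that code; the records are placeholders):

{
  "data": [
    {"title": "...", "content": "..."},
    {"title": "...", "content": "..."}
  ]
}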

4 Training

4.1 Fine-tuning directly

# coding=utf-8
import logging
import datasets
import numpy as np
import lawrouge
import rouge
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import LEDForConditionalGeneration, BertTokenizer

from datasets import load_dataset

logger = logging.getLogger("longformer-chinese")
logging.basicConfig(level=logging.INFO)

dataset = load_dataset('json', data_files=r'D:\nlp\project\long-document\datasets\xxxx.json', field='data')  # load your own long-document summarization dataset
dataset = dataset.shuffle(seed=42)  # shuffle

tokenizer = BertTokenizer.from_pretrained(r'D:\nlp\project\long-document\bert-base-chinese')  # load the BERT tokenizer
model = LEDForConditionalGeneration.from_pretrained(r'D:\nlp\project\long-document\converted_model')  # load the converted LED model
# model.resize_token_embeddings(tokenizer.vocab_size)  # extend the vocabulary, e.g. 21128 -> 50000

def flatten(example):
    return {
        "document": example["content"],
        "summary": example["title"],
    }

dataset = dataset["train"].map(flatten, remove_columns=["title", "content"])

max_input_length = 8192  # 4096 or another value, but no more than the 8192 we converted to
max_target_length = 1024  # maximum summary (target) length

def preprocess_function(examples):
    inputs = [doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
dataset = dataset.shuffle()

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
tokenized_datasets = datasets.DatasetDict({"train": train_data_txt, "validation": validation_data_txt}).map(preprocess_function, batched=True)

batch_size = 4  # small batch, to fit in limited GPU memory
args = Seq2SeqTrainingArguments(
    fp16 = True,
    output_dir="results_long",
    num_train_epochs=50,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,  # demo
    per_device_eval_batch_size=batch_size,
    learning_rate=2e-04,
    warmup_steps=1000,
    weight_decay=0.0001,
    label_smoothing_factor=0.15,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=1,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=500,
    gradient_accumulation_steps=1,
    generation_max_length=64,
    generation_num_beams=1,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # remove the whitespace that BertTokenizer inserts between Chinese characters
    decoded_preds = [pred.replace(" ", "") for pred in decoded_preds]
    decoded_labels = [label.replace(" ", "") for label in decoded_labels]
    # word-level ROUGE with jieba instead:
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]

    labels_lens = [np.count_nonzero(label != tokenizer.pad_token_id) for label in labels]

    # ROUGE crashes on empty hypotheses, so replace them with a placeholder
    for i, pred in enumerate(decoded_preds):
        if pred == "":
            decoded_preds[i] = "decoding error, skipping..."

    # scorer = lawrouge.Rouge()
    scorer = rouge.Rouge()
    result = scorer.get_scores(decoded_preds, decoded_labels, avg=True)
    print(result)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}

    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(labels_lens)
    return result


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# train, then save the model and training metrics
train_result = trainer.train()
print(train_result)
trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
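
After training, a quick sanity check is to generate a summary for a single long document. The snippet below is a minimal sketch, not part of the script above: the checkpoint path ("results_long", the Trainer's output_dir), the placeholder document, and the generation settings are illustrative. It puts global attention on the first token, which is the usual recommendation for LED.

# Minimal inference sketch (illustrative, not part of the training script above).
import torch
from transformers import LEDForConditionalGeneration, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("results_long")   # the Trainer saved the tokenizer and model here
model = LEDForConditionalGeneration.from_pretrained("results_long").eval()

document = "..."  # your long Chinese document
inputs = tokenizer(document, max_length=8192, truncation=True, return_tensors="pt")

# LED uses sliding-window local attention; give the first token global attention
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
    max_length=64,
    num_beams=4,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))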



4.2 Continued pretraining (MLM)

# coding=utf-8

import jieba
from transformers import BertConfig, BertForMaskedLM, DataCollatorForWholeWordMask, \
    BertTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BartForCausalLM, \
    LEDForConditionalGeneration, LineByLineTextDataset
from torch.utils.data import Dataset
from tqdm import tqdm
import torch




class pretrain_dataset(Dataset):

    def __init__(self, path, tokenizer, dup_factor=5, max_length=512):  # dup_factor: duplicate each line so dynamic masking sees different masks
        self.examples = []
        with open(path,'r',encoding='utf-8') as f:
            total_data = f.readlines()
            with tqdm(total_data * dup_factor) as loader:
                for data in loader:
                    # clean data
                    data = data.replace('\n', '').replace('\r', '').replace('\t', '').replace(' ', '').replace('\u3000', '')  # strip newlines, tabs and (full-width) spaces
                    # chinese_ref = self.get_new_segment(data)
                    input_ids = tokenizer.encode_plus(data,truncation=True,max_length=max_length).input_ids
                    dict_data = {'input_ids' : input_ids} #, 'chinese_ref' : chinese_ref
                    self.examples.append(dict_data)
                    loader.set_description(f'loading data')


    def get_new_segment(self, segment):
        """
        Use a word segmenter to build the whole-word-mask reference
        for WWM pretraining, e.g. [喜, 欢] -> [喜, ##欢].
        """
        seq_cws = jieba.cut("".join(segment))  # segment with jieba
        # seq_cws = segment
        chinese_ref = []
        index = 1
        for seq in seq_cws:
            for i, word in enumerate(seq):
                if i > 0:
                    chinese_ref.append(index)
                index += 1
        return chinese_ref

    def __getitem__(self, index):
        return self.examples[index]

    def __len__(self):
        return len(self.examples)


if __name__ == '__main__':
    # configuration
    epoch = 100
    batch_size = 4
    pretrain_model = r'...'
    train_file = r'train_nlpcc.txt'
    test_file = r'test_nlpcc.txt'
    save_epoch = 1  # save a checkpoint every `save_epoch` epochs

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config = BertConfig.from_pretrained(pretrain_model)
    tokenizer = BertTokenizer.from_pretrained(pretrain_model)

    train_dataset = LineByLineTextDataset(
        tokenizer=tokenizer,
        file_path=train_file,  # mention train text file here
        block_size=512)

    # train_dataset = pretrain_dataset(train_file,tokenizer)

    test_dataset = LineByLineTextDataset(
        tokenizer=tokenizer,
        file_path=test_file,  # mention train text file here
        block_size=512,
        )
    # test_dataset = pretrain_dataset(test_file, tokenizer)
    # model = BertForMaskedLM(config)
    model = LEDForConditionalGeneration.from_pretrained(pretrain_model).to(device)  # or BartForCausalLM
    print('No of parameters: ', model.num_parameters())

    data_collator = DataCollatorForLanguageModeling(  # or DataCollatorForWholeWordMask for whole-word masking
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )
    print('No. of lines: ', len(train_dataset))
    save_step = len(train_dataset) * save_epoch
    tot_step = int(len(train_dataset) / batch_size * epoch)  # rough estimate; ignores gradient accumulation
    print(f'\n\t***** Running training *****\n'
          f'\tNum examples = {len(train_dataset)}\n'
          f'\tNum Epochs = {epoch}\n'
          f'\tBatch size = {batch_size}\n'
          f'\tTotal optimization steps = {tot_step}\n')

    # official training
    training_args = TrainingArguments(
        output_dir=r'nlpcc',
        overwrite_output_dir=True,
        num_train_epochs=epoch,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=16,
        learning_rate=2e-05,
        warmup_steps=100,
        weight_decay=0,
        # save_steps=save_step,
        logging_dir="../logs",
        logging_strategy="steps",
        logging_steps=1,
        save_total_limit=20,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        prediction_loss_only=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset
    )

    trainer.train()
    trainer.save_model(r'/mlm_outputs/led/nlpcc')

