Pytorch实现Bert/RoBerta微调(以MELD数据集为例)

Bert/RoBerta 微调笔记

    • 前言
    • 为什么要进行微调?
    • 怎么微调?参数的设置?
    • 问题:
      • (1)Bert/RoBerta所有参数是不是都要训练?
      • (2)微调Bert/RoBerta时,无法载入全部参数报错CUDA out of memory
      • (3)如何冻结模型参数?
      • (4)如何保存fine-tune好的BERT/ROBERTA模型参数,以及如何在特征提取阶段使用这些参数?
      • (5)逐层微调
    • 以MELD数据集为例的微调RoBerta的代码

前言

本文记录我在学习BERT/ROBERTA fine-tuning过程的遇到的问题,包括内存受限,微调概念,微调方法等。文章方法不适用于逐层微调,且只以NLP文本分类举例,微调代码参考github link
更新: (5)逐层微调

为什么要进行微调?

任务数据集与BERT/ROBERTA预训练的数据集差异较大,用微调使预训练模型(PLM)能更好地适应数据集。在实际任务中,用官方下载的PLM参数进行文本特征提取,提取出来的特征效果特别差,例如ROBERTA提取出来的特征经过训练,同一模型的f1score低于原论文10%以上。而进行微调之后能与原论文持平。

怎么微调?参数的设置?

1.第一种将PLM模型后接分类器,直接训练完整的网络。以[1e-6,2e-5,3e-5]相同量级的小学习率训练PLM,epoch小于10完全可以训练出一个好的PLM,再保存PLM的参数。在特征提取时,Bert/RoBerta作为特征提取器载入已保存的参数提取文本特征。本文实现的也是这种方式。
2.第二种PLM后接完整的自己的模型,逐层冻结PLM层参数,一层一层地调Bert/RoBerta的参数,因为在Bert/RoBerta中越底层,参数变化越不明显,所以需要调节的参数都在高层。

问题:

(1)Bert/RoBerta所有参数是不是都要训练?

在fine-tune阶段整个bert模型参数都需要训练。

(2)微调Bert/RoBerta时,无法载入全部参数报错CUDA out of memory

首先确定batch_size要足够小,我在3090/24G显存上fine-tune,设置的bacth_size=1勉强可以跑得动
其次是代码的复杂度,不要有多余的参数放到cuda()里,在每个epoch后都torch.cuda.empty_cache()清理缓存

fro e in range(config['epoch']):
	torch.cuda.empyt_cache()
	# train, valid, test dateset...
	e += 1

然后是数据集的格式,之前使用csv格式逐条读取,只能用24GB的3090训练,后面换成txt文件读取速度和内存占用有很大区别。
或者在test or eval()阶段可以添加with torch.no_grad()使模型在预测阶段不计算参数梯度来减少内存的使用量。

(3)如何冻结模型参数?

有很多种方法,遍历模型参数列表,使模型参数梯度计算False;

for param in model.parameters():
	param.requires_grad = False

在预测阶段,模型所有参数都不需要梯度计算,都可以冻结,就可以用with torch.no_grad()

# 训练阶段需要计算梯度
model.train()
log_prob = model(**kwargs)
# 验证,测试阶段不需要梯度
model.eval()
with torch.no_grad():
	log_prob = model(**kwargs)

(4)如何保存fine-tune好的BERT/ROBERTA模型参数,以及如何在特征提取阶段使用这些参数?

保存、载入模型参数方法来自于:solution from github
保存阶段:

save_path = './saved_models/myModel.pth'
# finetune过程跳过
torch.save({'model_state_dict': model.encoder.state_dict()}, save_path)
# 其中self.encoder = AutoModel.from_pretrain('./roberta=large')

调用阶段:

checkpoint = torch.load('../saved_models/mymodel.pth')
model = AutoModel.from_pretrained('../roberta-large')
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda()

(5)逐层微调

方法:在模型参数初始化下面,添加需要冻结的层名字,或者添加不需要冻结的层名字,以roberta-large为例,冻结列表所列出以外的roberta层。

class testModel(nn.Module):
	def __init__(self):
		self.encoder = AutoModel.from_pretrained('./roberta-large')
		self.classifier = nn.Sequential(
            nn.Linear(args.emb_dim, 300),
            nn.ReLU(),
            nn.Linear(300, n_class)
        )

		unfreeze_layers = ['layer.17', 'layer.18', 'layer.19', 'layer.20',
                           'layer.21', 'layer.22', 'layer.23',
                           'bert.pooler', 'out.']
        for name, param in self.encoder.named_parameters():
            param.requires_grad = False
            for ele in unfreeze_layers:
                if ele in name:
                    param.requires_grad = True

或者冻结指定层,其他层参数默认需要训练

		freeze_layers = ['layer.1', 'layer.2']
        for name, param in self.encoder.named_parameters():
            for ele in freeze_layers:
                if ele in name:
                    param.requires_grad = False

以MELD数据集为例的微调RoBerta的代码

import torch
import random
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence
from sklearn.metrics import precision_recall_fscore_support
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup

config = {
    'bert_path': './roberta-large',
    'dataset_path': './data/MELD',
    'saved_model': './saved_models/mymodel.pth',  # 这里要注意一下
    'emb_dim': 1024,
    'n_class': 7,
    'batch': 1,
    'epoch': 10,
    'lr': 1e-6,
    'max_grad_norm': 10
}

roberta_tokenizer = AutoTokenizer.from_pretrained(config['bert_path'])


class MELD_loader(Dataset):
    def __init__(self, txt_file, dataclass):
        self.dialogs = []

        f = open(txt_file, 'r')
        dataset = f.readlines()
        f.close()

        temp_speakerList = []
        context = []
        context_speaker = []
        self.speakerNum = []
        # 'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'
        emodict = {'anger': "anger", 'disgust': "disgust", 'fear': "fear", 'joy': "joy", 'neutral': "neutral",
                   'sadness': "sad", 'surprise': 'surprise'}
        self.sentidict = {'positive': ["joy"], 'negative': ["anger", "disgust", "fear", "sadness"],
                          'neutral': ["neutral", "surprise"]}
        self.emoSet = set()
        self.sentiSet = set()
        for i, data in enumerate(dataset):
            if i < 2:
                continue
            if data == '\n' and len(self.dialogs) > 0:
                self.speakerNum.append(len(temp_speakerList))
                temp_speakerList = []
                context = []
                context_speaker = []
                continue
            speaker, utt, emo, senti = data.strip().split('\t')
            context.append(utt)
            if speaker not in temp_speakerList:
                temp_speakerList.append(speaker)
            speakerCLS = temp_speakerList.index(speaker)
            context_speaker.append(speakerCLS)

            self.dialogs.append([context_speaker[:], context[:], emodict[emo], senti])
            self.emoSet.add(emodict[emo])
            self.sentiSet.add(senti)

        self.emoList = sorted(self.emoSet)
        self.sentiList = sorted(self.sentiSet)
        if dataclass == 'emotion':
            self.labelList = self.emoList
        else:
            self.labelList = self.sentiList
        self.speakerNum.append(len(temp_speakerList))

    def __len__(self):
        return len(self.dialogs)

    def __getitem__(self, idx):
        return self.dialogs[idx], self.labelList, self.sentidict


def encode_right_truncated(text, tokenizer, max_length=511):
    '''
    完成分词工作
    :return:
    '''
    tokenized = tokenizer.tokenize(text)
    truncated = tokenized[-max_length:]
    ids = tokenizer.convert_tokens_to_ids(truncated)

    return [tokenizer.cls_token_id] + ids


def padding(ids_list, tokenizer):
    max_len = 0
    for ids in ids_list:
        if len(ids) > max_len:
            max_len = len(ids)

    pad_ids = []
    for ids in ids_list:
        pad_len = max_len - len(ids)
        add_ids = [tokenizer.pad_token_id for _ in range(pad_len)]

        pad_ids.append(ids + add_ids)

    return torch.tensor(pad_ids)


def make_batch_roberta(sessions):
    '''
    collate_fn
    :return:
    '''
    batch_input, batch_labels, batch_speaker_tokens = [], [], []
    for session in sessions:
        data = session[0]
        label_list = session[1]

        context_speaker, context, emotion, sentiment = data
        now_speaker = context_speaker[-1]
        speaker_utt_list = []

        inputString = ""
        for turn, (speaker, utt) in enumerate(zip(context_speaker, context)):
            inputString += ' + str(speaker + 1) + '> '  # s1, s2, s3...
            inputString += utt + " "

            if turn < len(context_speaker) - 1 and speaker == now_speaker:
                speaker_utt_list.append(encode_right_truncated(utt, roberta_tokenizer))

        concat_string = inputString.strip()
        batch_input.append(encode_right_truncated(concat_string, roberta_tokenizer))

        if len(label_list) > 3:
            label_ind = label_list.index(emotion)
        else:
            label_ind = label_list.index(sentiment)
        batch_labels.append(label_ind)

        batch_speaker_tokens.append(padding(speaker_utt_list, roberta_tokenizer))

    batch_input_tokens = padding(batch_input, roberta_tokenizer)
    batch_labels = torch.tensor(batch_labels)

    return batch_input_tokens, batch_labels, batch_speaker_tokens


def CELoss(pred_outs, labels):
    """
        pred_outs: [batch, clsNum]
        labels: [batch]
    """
    loss = nn.CrossEntropyLoss()
    loss_val = loss(pred_outs, labels)
    return loss_val


def _CalACC(model, dataloader):
    model.eval()
    correct = 0
    label_list = []
    pred_list = []

    # label arragne
    with torch.no_grad():
        for i_batch, data in tqdm(enumerate(dataloader), desc='testing is on...'):
            """Prediction"""
            batch_input_tokens, batch_labels, batch_speaker_tokens = data
            batch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()

            pred_logits = model(batch_input_tokens)  # (1, clsNum)

            """Calculation"""
            pred_label = pred_logits.argmax(1).item()
            true_label = batch_labels.item()

            pred_list.append(pred_label)
            label_list.append(true_label)
            if pred_label == true_label:
                correct += 1
        acc = correct / len(dataloader)
    return acc, pred_list, label_list


class ft_model(nn.Module):
    def __init__(self):
        super(ft_model, self).__init__()
        self.context_model = AutoModel.from_pretrained(config['bert_path'])
        self.classifier = nn.Sequential(
            nn.Linear(config['emb_dim'], 300),
            nn.ReLU(),
            nn.Linear(300, config['n_class'])
        )

    def forward(self, batch_input_tokens):
        batch_context_output = self.context_model(batch_input_tokens).last_hidden_state[:, 0, :]
        logits = self.classifier(batch_context_output)
        return logits


if __name__ == '__main__':
    torch.cuda.empty_cache()
    make_batch = make_batch_roberta
    train_path = config['dataset_path'] + '/MELD_train.txt'
    dev_path = config['dataset_path'] + '/MELD_dev.txt'
    test_path = config['dataset_path'] + '/MELD_test.txt'
    train_dataset = MELD_loader(train_path, 'emotion')
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch'], shuffle=True,
                                  num_workers=4, collate_fn=make_batch)
    dev_dataset = MELD_loader(dev_path, 'emotion')
    dev_dataloader = DataLoader(dev_dataset, batch_size=config['batch'], shuffle=False,
                                num_workers=4, collate_fn=make_batch)
    test_dataset = MELD_loader(test_path, 'emotion')
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch'], shuffle=False,
                                num_workers=4, collate_fn=make_batch)

    model = ft_model().cuda()
    model.train()

    # training process
    num_warmup_steps = len(train_dataset)
    num_training_steps = len(train_dataset) * config['epoch']
    train_sample_num = int(len(train_dataloader))
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                num_training_steps=num_training_steps)

    best_dev_fscore, best_test_fscore = 0, 0
    best_dev_fscore_macro, best_dev_fscore_micro,\
    best_test_fscore_macro, best_test_fscore_micro = 0, 0, 0, 0
    best_epoch = 0
    for epoch in range(config['epoch']):
        model.train()
        for i_batch, data in tqdm(enumerate(train_dataloader), desc='training is on ...'):
            if i_batch > train_sample_num:
                print(i_batch, train_sample_num)
                break

            """Prediction"""
            batch_input_tokens, batch_labels, batch_speaker_tokens = data
            batch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()

            pred_logits = model(batch_input_tokens)

            """Loss calculation & training"""
            loss_val = CELoss(pred_logits, batch_labels)

            loss_val.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           config['max_grad_norm'])  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        model.eval()
        dev_acc, dev_pred_list, dev_label_list = _CalACC(model, dev_dataloader)
        dev_pre, dev_rec, dev_fbeta, _ = precision_recall_fscore_support(dev_label_list, dev_pred_list,
                                                                         average='weighted')

        test_acc, test_pred_list, test_label_list = _CalACC(model, test_dataloader)
        test_pre, test_rec, test_fbeta, _ = precision_recall_fscore_support(test_label_list, test_pred_list,
                                                                            average='weighted')
        """Best Score & Model Save"""
        if test_fbeta > best_test_fscore:
            best_test_fscore = test_fbeta
            best_epoch = epoch
            torch.save({
                'model_state_dict': model.context_model.state_dict(),
            }, config['saved_model'])
        print('epoch: {}, accuracy: {}, precision: {}, recall: {}, fscore: {}'.
              format(epoch + 1, test_acc, test_pre, test_rec, test_fbeta))
    print('Final Fscore ## test-fscore: {}, test_epoch: {}'.format(best_test_fscore, best_epoch))

你可能感兴趣的:(NLP,pytorch,深度学习,transformer,自然语言处理,bert)