# BERT model training and evaluation on the MNLI dataset

import torch
import pandas as pd
from torch.utils.data import Dataset
import time
import csv
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup


def getBertTokenizer(model, max_token_length):
    """Return a fast tokenizer for the given checkpoint name.

    Args:
        model: checkpoint name; one of 'bert-base-uncased', 'roberta-base',
            'distilbert-base-uncased'.
        max_token_length: maximum sequence length to configure on the tokenizer.

    Raises:
        ValueError: if `model` is not one of the supported checkpoints.

    Note:
        `model_max_length` is the documented way to bound sequence length on
        the tokenizer itself; the original passed `truncation`/`max_length`,
        which are per-call arguments and have no effect in `from_pretrained`.
    """
    tokenizer_classes = {
        'bert-base-uncased': BertTokenizerFast,
        'roberta-base': RobertaTokenizerFast,
        'distilbert-base-uncased': DistilBertTokenizerFast,
    }
    try:
        tokenizer_cls = tokenizer_classes[model]
    except KeyError:
        raise ValueError(f'Model: {model} not recognized.') from None

    return tokenizer_cls.from_pretrained(model, model_max_length=max_token_length)


def initialize_bert_transform(net, max_token_length=512):
    """Build a text -> tensor transform for the given checkpoint.

    The returned callable tokenizes `text` and stacks the token fields along a
    trailing dim, producing shape (max_token_length, num_fields): 3 fields for
    BERT (ids, mask, token types), 2 for DistilBERT/RoBERTa (no token types).

    Args:
        net: checkpoint name understood by `getBertTokenizer`.
        max_token_length: pad/truncate length for every input.

    Raises:
        ValueError: if `net` is not a supported checkpoint.
    """
    tokenizer = getBertTokenizer(net, max_token_length)

    def transform(text):
        tokens = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_token_length,
            return_tensors='pt')
        if net == 'bert-base-uncased':
            fields = (tokens['input_ids'],
                      tokens['attention_mask'],
                      tokens['token_type_ids'])
        elif net in ('distilbert-base-uncased', 'roberta-base'):
            # Bug fix: the original had no roberta branch, so `x` was unbound
            # (UnboundLocalError). RoBERTa, like DistilBERT, has no
            # token_type_ids field.
            fields = (tokens['input_ids'],
                      tokens['attention_mask'])
        else:
            raise ValueError(f'Model: {net} not recognized.')
        x = torch.stack(fields, dim=2)
        return torch.squeeze(x, dim=0)  # first shape dim is always 1

    return transform


class BertDataset(Dataset):
    """NLI sentence-pair dataset.

    Encodes (sentence1, sentence2) as a single BERT-style sequence
    `[CLS] premise [SEP] hypothesis [SEP]` with segment ids, truncated and
    padded to `max_length`. Rows with NaN or empty sentences are dropped.

    Args:
        df: DataFrame with 'sentence1', 'sentence2' (str) and, unless
            `is_testing`, a 'gold_label' column.
        tokenizer: object exposing `encode(text)` (returning special-token
            wrapped ids) and `pad_token_id`.
        max_length: fixed output sequence length.
        is_testing: when True, items carry no 'label' key.
    """

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        ## Filter abnormal lines such as NaN lines and empty lines.
        df = df.dropna()
        df = df[(df['sentence1'].str.split().str.len() > 0) & (df['sentence2'].str.split().str.len() > 0)]
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_testing = is_testing
        self.label_dict = {'entailment': 0, 'neutral': 1, 'contradiction': 2}

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        premise = self.df.iloc[idx]['sentence1']
        hypothesis = self.df.iloc[idx]['sentence2']
        premise_id = self.tokenizer.encode(premise)
        hypothesis_id = self.tokenizer.encode(hypothesis)

        # Drop the leading [CLS] of the hypothesis so the pair reads
        # [CLS] p ... [SEP] h ... [SEP].
        input_ids = premise_id + hypothesis_id[1:]
        attn_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding
        token_type_ids = [0] * len(premise_id) + [1] * len(hypothesis_id[1:])  # sentence1->0, sentence2->1

        # Truncate. Bug fix: the original skipped this, so pairs longer than
        # max_length were returned over-length (pad_len went negative).
        input_ids = input_ids[:self.max_length]
        attn_mask = attn_mask[:self.max_length]
        token_type_ids = token_type_ids[:self.max_length]

        # Pad. Bug fix: attn_mask/token_type_ids must be padded with 0, not
        # pad_token_id (RoBERTa's pad id is 1, which would un-mask padding).
        pad_len = self.max_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad_len
        attn_mask += [0] * pad_len
        token_type_ids += [0] * pad_len

        input_ids, attn_mask, token_type_ids = map(torch.LongTensor, [input_ids, attn_mask, token_type_ids])

        encoded_dict = {
            'input_ids': input_ids,
            'attn_mask': attn_mask,
            'token_type_ids': token_type_ids,
        }
        if not self.is_testing:
            label = self.df.iloc[idx]['gold_label']
            encoded_dict['label'] = self.label_dict[label]
        return encoded_dict

class SICKBertDataset(BertDataset):
    """BertDataset variant for SICK, whose gold labels are upper-cased."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'ENTAILMENT': 0, 'NEUTRAL': 1, 'CONTRADICTION': 2}

class HANSBertDataset(BertDataset):
    """BertDataset variant for HANS, which uses two-way labels."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entailment': 0, 'non-entailment': 1}

class QNLIBertDataset(BertDataset):
    """BertDataset variant for QNLI/RTE two-way entailment labels."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entailment': 0, 'not_entailment': 1}

class WNLIBertDataset(BertDataset):
    """BertDataset variant for WNLI, whose labels are already 0/1 ints."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {0: 0, 1: 1}

class ANLIBertDataset(BertDataset):
    """BertDataset variant for ANLI, whose labels are single letters e/n/c."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'e': 0, 'n': 1, 'c': 2}

class SciTailBertDataset(BertDataset):
    """BertDataset variant for SciTail's two-way entails/neutral labels."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entails': 0, 'neutral': 1}



def get_mnli(batch_size, num_workers, net, data_dir, max_token_length, train=False, eval=True):
    """Build the MNLI training loader plus a suite of NLI evaluation loaders.

    Args:
        batch_size: batch size for every DataLoader.
        num_workers: DataLoader worker count.
        net: tokenizer checkpoint name (see `getBertTokenizer`).
        data_dir: root directory containing the dataset folders.
        max_token_length: fixed sequence length used by the datasets.
        train: when True, also build the MNLI training set/loader
            (otherwise `trainset`/`trainloader` are None).
        eval: when True, build all evaluation sets/loaders.

    Returns:
        (trainset, trainloader, testsets, testloaders). `testsets` and
        `testloaders` each hold QNLI, RTE, WNLI, ANLI R1-R3, SciTail, then a
        nested list with 15 out-of-distribution entries (MNLI-M/MM, SNLI,
        Breaking NLI, HANS, SNLI-hard, six Stress-Tests, SICK, EQUATE-NAT,
        EQUATE-SYN), following Table 7 of Zhang et al., 2020.
    """
    # Bounded retry for flaky tokenizer downloads; the original's
    # `while True ... except: continue` could spin forever on a hard failure.
    for attempt in range(10):
        try:
            tokenizer = getBertTokenizer(net, max_token_length)
            break
        except Exception:
            if attempt == 9:
                raise
            time.sleep(1)

    def _loader(dataset):
        # NOTE(review): shuffle=True is kept for parity with the original,
        # though evaluation loaders are normally not shuffled.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers)

    trainset = None
    trainloader = None
    if train:
        ## MNLI: 392702 lines
        train_df = pd.read_csv(f"{data_dir}/MNLI/train.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        trainset = BertDataset(train_df, tokenizer, max_token_length)
        trainloader = _loader(trainset)

    testsets = []
    testloaders = []
    if eval:
        ## MNLI-M: 9815 lines
        testsetv1_df = pd.read_csv(f"{data_dir}/MNLI/dev_matched.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv1 = BertDataset(testsetv1_df, tokenizer, max_token_length)

        ## MNLI-MM: 9832 lines
        testsetv2_df = pd.read_csv(f"{data_dir}/MNLI/dev_mismatched.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv2 = BertDataset(testsetv2_df, tokenizer, max_token_length)

        ## SNLI dev: 10000 lines
        testsetv3_df = pd.read_json(f"{data_dir}/SNLI/snli_1.0_dev.jsonl", lines=True)
        testsetv3 = BertDataset(testsetv3_df, tokenizer, max_token_length)

        ## Breaking_NLI: 8193 lines
        testsetv4_df = pd.read_json(f"{data_dir}/Breaking_NLI/data/dataset.jsonl", lines=True)
        testsetv4 = BertDataset(testsetv4_df, tokenizer, max_token_length)

        ## HANS evaluation set: 30000 lines (two-way labels)
        testsetv5_df = pd.read_json(f"{data_dir}/HANS/heuristics_evaluation_set.jsonl", lines=True)
        testsetv5 = HANSBertDataset(testsetv5_df, tokenizer, max_token_length)

        ## SNLI-hard: 3261 lines
        testsetv6_df = pd.read_json(f"{data_dir}/SNLI/snli_1.0_test_hard.jsonl", lines=True)
        testsetv6 = BertDataset(testsetv6_df, tokenizer, max_token_length)

        ## Stress-Tests-L (length mismatch, matched): 9815 lines
        testsetv7_df = pd.read_json(f"{data_dir}/Stress-Tests/Length_Mismatch/multinli_0.9_length_mismatch_matched.jsonl", lines=True)
        testsetv7 = BertDataset(testsetv7_df, tokenizer, max_token_length)

        ## Stress-Tests-S (spelling error, matched): 8243 lines
        testsetv8_df = pd.read_json(f"{data_dir}/Stress-Tests/Spelling_Error/multinli_0.9_dev_gram_contentword_swap_perturbed_matched.jsonl", lines=True)
        testsetv8 = BertDataset(testsetv8_df, tokenizer, max_token_length)

        ## Stress-Tests-NE (negation, matched): 9815 lines
        testsetv9_df = pd.read_json(f"{data_dir}/Stress-Tests/Negation/multinli_0.9_negation_matched.jsonl", lines=True)
        testsetv9 = BertDataset(testsetv9_df, tokenizer, max_token_length)

        ## Stress-Tests-O (word overlap, matched): 9815 lines
        testsetv10_df = pd.read_json(f"{data_dir}/Stress-Tests/Word_Overlap/multinli_0.9_taut2_matched.jsonl", lines=True)
        testsetv10 = BertDataset(testsetv10_df, tokenizer, max_token_length)

        ## Stress-Tests-A (antonym, matched): 1561 lines
        testsetv11_df = pd.read_json(f"{data_dir}/Stress-Tests/Antonym/multinli_0.9_antonym_matched.jsonl", lines=True)
        testsetv11 = BertDataset(testsetv11_df, tokenizer, max_token_length)

        ## Stress-Tests-NU (numerical reasoning): 7596 lines
        testsetv12_df = pd.read_json(f"{data_dir}/Stress-Tests/Numerical_Reasoning/multinli_0.9_quant_hard.jsonl", lines=True)
        testsetv12 = BertDataset(testsetv12_df, tokenizer, max_token_length)

        ## SICK (train+dev+test): 9841 lines; normalize column names first
        testsetv13_df = pd.read_csv(f"{data_dir}/SICK/SICK.txt", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv13_df.rename(columns={'sentence_A': 'sentence1', 'sentence_B': 'sentence2', 'entailment_label': 'gold_label'}, inplace=True)
        testsetv13 = SICKBertDataset(testsetv13_df, tokenizer, max_token_length)

        ## EQUATE-NAT: NewsNLI (968) + RedditNLI (250) + RTE_Quant (166)
        testsetv14 = torch.utils.data.ConcatDataset([
            BertDataset(pd.read_json(f"{data_dir}/EQUATE/NewsNLI.jsonl", lines=True), tokenizer, max_token_length),
            BertDataset(pd.read_json(f"{data_dir}/EQUATE/RedditNLI.jsonl", lines=True), tokenizer, max_token_length),
            BertDataset(pd.read_json(f"{data_dir}/EQUATE/RTE_Quant.jsonl", lines=True), tokenizer, max_token_length),
        ])

        ## EQUATE-SYN: AWPNLI (722) + StressTest (7596)
        testsetv15 = torch.utils.data.ConcatDataset([
            BertDataset(pd.read_json(f"{data_dir}/EQUATE/AWPNLI.jsonl", lines=True), tokenizer, max_token_length),
            BertDataset(pd.read_json(f"{data_dir}/EQUATE/StressTest.jsonl", lines=True), tokenizer, max_token_length),
        ])

        ## QNLI dev: 5453 lines
        testsetv16_df = pd.read_csv(f"{data_dir}/QNLI/dev.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv16_df.rename(columns={'question': 'sentence1', 'sentence': 'sentence2', 'label': 'gold_label'}, inplace=True)
        testsetv16 = QNLIBertDataset(testsetv16_df, tokenizer, max_token_length)

        ## RTE dev: 277 lines (same label set as QNLI)
        testsetv17_df = pd.read_csv(f"{data_dir}/RTE/dev.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv17_df.rename(columns={'label': 'gold_label'}, inplace=True)
        testsetv17 = QNLIBertDataset(testsetv17_df, tokenizer, max_token_length)

        ## WNLI dev: 71 lines
        testsetv18_df = pd.read_csv(f"{data_dir}/WNLI/dev.tsv", sep='\t', quoting=csv.QUOTE_NONE)
        testsetv18_df.rename(columns={'label': 'gold_label'}, inplace=True)
        testsetv18 = WNLIBertDataset(testsetv18_df, tokenizer, max_token_length)

        ## ANLI R1/R2/R3 dev: 1000/1000/1200 lines
        anli_sets = []
        for anli_round in ('R1', 'R2', 'R3'):
            anli_df = pd.read_json(f"{data_dir}/ANLI/{anli_round}/dev.jsonl", lines=True)
            anli_df.rename(columns={'context': 'sentence1', 'hypothesis': 'sentence2', 'label': 'gold_label'}, inplace=True)
            anli_sets.append(ANLIBertDataset(anli_df, tokenizer, max_token_length))
        testsetv19, testsetv20, testsetv21 = anli_sets

        ## SciTail dev: 1304 lines (headerless TSV)
        testsetv22_df = pd.read_csv(f"{data_dir}/SciTail/tsv_format/scitail_1.0_dev.tsv", sep='\t', quoting=csv.QUOTE_NONE, header=None)
        testsetv22_df.rename(columns={0: 'sentence1', 1: 'sentence2', 2: 'gold_label'}, inplace=True)
        testsetv22 = SciTailBertDataset(testsetv22_df, tokenizer, max_token_length)

        testset_oods = [testsetv1, testsetv2, testsetv3, testsetv4, testsetv5, testsetv6, testsetv7, testsetv8,
                        testsetv9, testsetv10, testsetv11, testsetv12, testsetv13, testsetv14, testsetv15]
        testsets = [testsetv16, testsetv17, testsetv18, testsetv19, testsetv20, testsetv21, testsetv22, testset_oods]

        ## Bug fix: the original built the EQUATE-NAT loader from testsetv5
        ## (HANS) instead of testsetv14. Building every loader from its own
        ## dataset here makes that mismatch impossible.
        testloaders = [_loader(ds) for ds in testsets[:-1]]
        testloaders.append([_loader(ds) for ds in testset_oods])

    return trainset, trainloader, testsets, testloaders


def multi_acc(y_pred, y_test):
    """Return the fraction of correct predictions as a 0-d float tensor.

    Args:
        y_pred: logits of shape (N, num_classes).
        y_test: integer class labels of shape (N,).
    """
    # log_softmax is monotonic, so argmax over logits gives the same result
    # without the extra pass (the original applied log_softmax first).
    return (y_pred.argmax(dim=1) == y_test).sum().float() / float(y_test.size(0))


def main():
    """Fine-tune BERT on MNLI and evaluate on the first eval loader (QNLI dev)
    each epoch, checkpointing the model with the best validation accuracy."""
    ## Load datasets. Bug fix: the original passed train=False, leaving
    ## trainloader as None, which the training loop below would crash on.
    _, trainloader, _, testloaders = get_mnli(batch_size=16, num_workers=1, net="bert-base-uncased",
                                              data_dir="/data/pengru/datasets/", max_token_length=512, train=True)

    ## Load model. Use the BERT head to match the bert-base-uncased tokenizer
    ## above. Bug fix: the original built RoBERTa and immediately overwrote it
    ## with DistilBERT, whose forward() rejects the token_type_ids passed below.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    for attempt in range(10):  # bounded retry for flaky downloads
        try:
            model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
            break
        except Exception:
            if attempt == 9:
                raise
    model.to(device)
    print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

    ## No weight decay on biases and LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]

    EPOCHS = 5
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
    ## Linear decay over every optimizer step (one per batch), no warmup.
    ## Bug fix: the original referenced undefined GCONF/train_dl (NameError)
    ## and stepped the scheduler once per epoch instead of once per batch.
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                num_training_steps=len(trainloader) * EPOCHS)

    best_acc = 0.0  # bug fix: previously read before ever being assigned
    for epoch in range(EPOCHS):
        start = time.time()
        model.train()
        total_train_loss = 0
        total_train_acc = 0
        for batch in trainloader:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            label = batch['label'].to(device)
            optimizer.zero_grad()
            outputs = model(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attn_mask,
                            labels=label)
            outputs.loss.backward()
            optimizer.step()
            scheduler.step()  # linear schedule advances once per batch
            total_train_loss += outputs.loss.item()
            total_train_acc += multi_acc(outputs.logits, label).item()

        train_acc = total_train_acc / len(trainloader)
        train_loss = total_train_loss / len(trainloader)

        ## Validation on the first eval loader (QNLI dev). No gradients here,
        ## so the stray optimizer.zero_grad() from the original is dropped.
        model.eval()
        total_val_acc = 0
        total_val_loss = 0
        with torch.no_grad():
            for batch in testloaders[0]:
                input_ids = batch['input_ids'].to(device)
                attn_mask = batch['attn_mask'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                label = batch['label'].to(device)
                outputs = model(input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attn_mask,
                                labels=label)
                total_val_loss += outputs.loss.item()
                total_val_acc += multi_acc(outputs.logits, label).item()

        val_acc = total_val_acc / len(testloaders[0])
        val_loss = total_val_loss / len(testloaders[0])

        end = time.time()
        print(f'Epoch {epoch + 1}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}, time = {end - start:.4f}')

        if val_acc > best_acc:
            best_acc = val_acc
            # Bug fix: GCONF.saved_model_path was never defined; save locally.
            torch.save(model, './model.pth')
            print('save best model\t\tacc:%.6f' % best_acc)


if __name__ == "__main__":
    main()

# Tags: artificial intelligence, deep learning, python