import torch
import pandas as pd
from torch.utils.data import Dataset
import time
import csv
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
def getBertTokenizer(model, max_token_length):
    """Instantiate the fast tokenizer matching a supported model name.

    Supported names: 'bert-base-uncased', 'roberta-base',
    'distilbert-base-uncased'. Raises ValueError for anything else.
    """
    tokenizer_classes = {
        'bert-base-uncased': BertTokenizerFast,
        'roberta-base': RobertaTokenizerFast,
        'distilbert-base-uncased': DistilBertTokenizerFast,
    }
    if model not in tokenizer_classes:
        raise ValueError(f'Model: {model} not recognized.')
    return tokenizer_classes[model].from_pretrained(
        model, truncation=True, max_length=max_token_length)
def initialize_bert_transform(net, max_token_length=512):
    """Return a ``transform(text) -> tensor`` closure for the given model name.

    The closure tokenizes ``text`` padded/truncated to ``max_token_length`` and
    stacks the tokenizer outputs along the last dim, so ``x[i]`` holds
    (input_ids, attention_mask[, token_type_ids]) for token position ``i``.
    """
    tokenizer = getBertTokenizer(net, max_token_length)

    def transform(text):
        tokens = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_token_length,
            return_tensors='pt')
        if net == 'bert-base-uncased':
            x = torch.stack(
                (tokens['input_ids'],
                 tokens['attention_mask'],
                 tokens['token_type_ids']),
                dim=2)
        elif net in ('distilbert-base-uncased', 'roberta-base'):
            # RoBERTa and DistilBERT tokenizers emit no token_type_ids.
            # Bug fix: the original had no roberta branch, so `x` was unbound
            # and transform() raised UnboundLocalError for 'roberta-base'.
            x = torch.stack(
                (tokens['input_ids'],
                 tokens['attention_mask']),
                dim=2)
        else:
            # Unreachable in practice: getBertTokenizer already rejects
            # unknown model names, but fail loudly rather than NameError.
            raise ValueError(f'Model: {net} not recognized.')
        x = torch.squeeze(x, dim=0)  # First shape dim is always 1
        return x

    return transform
class BertDataset(Dataset):
    """Sentence-pair NLI dataset that tokenizes premise/hypothesis on the fly.

    Expects a DataFrame with 'sentence1' and 'sentence2' columns and — unless
    ``is_testing`` — a 'gold_label' column whose values are keys of
    ``self.label_dict``. Subclasses override ``label_dict`` for datasets with
    different label vocabularies.
    """

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        ## Filter abnormal lines: drop NaN rows and rows where either
        ## sentence is empty/whitespace-only.
        df = df.dropna()
        df = df[(df['sentence1'].str.split().str.len() > 0) &
                (df['sentence2'].str.split().str.len() > 0)]
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_testing = is_testing
        # Default MNLI/SNLI three-way label scheme.
        self.label_dict = {'entailment': 0, 'neutral': 1, 'contradiction': 2}

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        premise = self.df.iloc[idx]['sentence1']
        hypothesis = self.df.iloc[idx]['sentence2']
        premise_id = self.tokenizer.encode(premise)
        hypothesis_id = self.tokenizer.encode(hypothesis)
        # [CLS] premise [SEP] hypothesis [SEP]: drop the hypothesis' leading [CLS].
        input_ids = premise_id + hypothesis_id[1:]
        attn_mask = [1] * len(input_ids)  # mask padded values
        token_type_ids = [0] * len(premise_id) + [1] * len(hypothesis_id[1:])  # sentence1->0, sentence2->1
        # Bug fix: truncate before padding. The original skipped this, so a
        # pair longer than max_length produced over-length, unevenly sized
        # tensors (pad_len went negative) that break batch collation.
        input_ids = input_ids[:self.max_length]
        attn_mask = attn_mask[:self.max_length]
        token_type_ids = token_type_ids[:self.max_length]
        # PAD. The mask and segment ids are padded with 0, not pad_token_id:
        # padded positions must be masked out, and the original only worked
        # because BERT's pad_token_id happens to be 0.
        pad_len = self.max_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad_len
        attn_mask += [0] * pad_len
        token_type_ids += [0] * pad_len
        input_ids, attn_mask, token_type_ids = map(
            torch.LongTensor, [input_ids, attn_mask, token_type_ids])
        encoded_dict = {
            'input_ids': input_ids,
            'attn_mask': attn_mask,
            'token_type_ids': token_type_ids,
        }
        if not self.is_testing:
            label = self.df.iloc[idx]['gold_label']
            encoded_dict['label'] = self.label_dict[label]
        return encoded_dict
class SICKBertDataset(BertDataset):
    """BertDataset variant for SICK, whose gold labels are upper-cased."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'ENTAILMENT': 0, 'NEUTRAL': 1, 'CONTRADICTION': 2}
class HANSBertDataset(BertDataset):
    """BertDataset variant for HANS, which uses a binary label scheme."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entailment': 0, 'non-entailment': 1}
class QNLIBertDataset(BertDataset):
    """BertDataset variant for QNLI/RTE, binary entailment labels."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entailment': 0, 'not_entailment': 1}
class WNLIBertDataset(BertDataset):
    """BertDataset variant for WNLI, whose gold labels are already 0/1 ints."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {0: 0, 1: 1}
class ANLIBertDataset(BertDataset):
    """BertDataset variant for ANLI, which abbreviates labels to e/n/c."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'e': 0, 'n': 1, 'c': 2}
class SciTailBertDataset(BertDataset):
    """BertDataset variant for SciTail, labelled entails/neutral."""

    def __init__(self, df, tokenizer, max_length, is_testing=False):
        super().__init__(df, tokenizer, max_length, is_testing=is_testing)
        self.label_dict = {'entails': 0, 'neutral': 1}
def get_mnli(batch_size, num_workers, net, data_dir, max_token_length, train=False, eval=True):
    """Build the MNLI training loader and a battery of evaluation loaders.

    Args:
        batch_size: batch size shared by every DataLoader built here.
        num_workers: DataLoader worker count.
        net: model name understood by ``getBertTokenizer``.
        data_dir: root directory containing all the dataset folders.
        max_token_length: max sequence length for tokenization/padding.
        train: when True, also load the MNLI training set/loader.
        eval: when True, build every evaluation set (selection follows
            Table 7 of Zhang et al., 2020).

    Returns:
        ``(trainset, trainloader, testsets, testloaders)``. The train pair is
        None when ``train`` is False. ``testsets``/``testloaders`` hold the
        seven in-distribution dev sets (QNLI, RTE, WNLI, ANLI R1-R3, SciTail)
        followed by one list of the 15 OOD sets/loaders.
    """
    # NOTE(review): infinite busy-retry; assumes tokenizer download failures
    # are transient network errors — consider bounded retries with backoff.
    while True:
        try:
            tokenizer = getBertTokenizer(net, max_token_length)
            break
        except Exception:
            continue

    def _loader(dataset):
        # Shared DataLoader configuration for every split built here.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    def _tsv(path, **kwargs):
        # TSV reader; QUOTE_NONE because these files contain raw quotes.
        return pd.read_csv(f"{data_dir}/{path}", sep='\t', quoting=csv.QUOTE_NONE, **kwargs)

    def _jsonl(path):
        return pd.read_json(f"{data_dir}/{path}", lines=True)

    trainset = None
    trainloader = None
    if train:
        ## MNLI: 392702 lines
        trainset = BertDataset(_tsv("MNLI/train.tsv"), tokenizer, max_token_length)
        trainloader = _loader(trainset)

    testsets = []
    testloaders = []
    if eval:
        ## MNLI-M: 9815 lines
        testsetv1 = BertDataset(_tsv("MNLI/dev_matched.tsv"), tokenizer, max_token_length)
        testloaderv1 = _loader(testsetv1)
        ## MNLI-MM: 9832 lines
        testsetv2 = BertDataset(_tsv("MNLI/dev_mismatched.tsv"), tokenizer, max_token_length)
        testloaderv2 = _loader(testsetv2)
        ## SNLI: {snli_1.0_train.jsonl: 550152 lines, snli_1.0_dev.jsonl: 10000 lines. snli_1.0_test.jsonl: 10000 lines}
        testsetv3 = BertDataset(_jsonl("SNLI/snli_1.0_dev.jsonl"), tokenizer, max_token_length)
        testloaderv3 = _loader(testsetv3)
        ## Breaking_NLI: 8193 lines
        testsetv4 = BertDataset(_jsonl("Breaking_NLI/data/dataset.jsonl"), tokenizer, max_token_length)
        testloaderv4 = _loader(testsetv4)
        ## HANS: {heuristics_train_set: 30000 lines, heuristics_evaluation_set: 30000 lines}
        testsetv5 = HANSBertDataset(_jsonl("HANS/heuristics_evaluation_set.jsonl"), tokenizer, max_token_length)
        testloaderv5 = _loader(testsetv5)
        ## SNLI-hard: 3261 lines
        testsetv6 = BertDataset(_jsonl("SNLI/snli_1.0_test_hard.jsonl"), tokenizer, max_token_length)
        testloaderv6 = _loader(testsetv6)
        ## Stess-Tests-L: length_mismatch_matched: 9815 lines, length_mismatch_mismatched: 9832 lines
        testsetv7 = BertDataset(_jsonl("Stress-Tests/Length_Mismatch/multinli_0.9_length_mismatch_matched.jsonl"), tokenizer, max_token_length)
        testloaderv7 = _loader(testsetv7)
        ## Stess-Tests-S: gram_contentword_swap_perturbed_matched 8243 lines, gram_contentword_swap_perturbed_mismatched: 6824 lines
        testsetv8 = BertDataset(_jsonl("Stress-Tests/Spelling_Error/multinli_0.9_dev_gram_contentword_swap_perturbed_matched.jsonl"), tokenizer, max_token_length)
        testloaderv8 = _loader(testsetv8)
        ## Stess-Tests-NE: negation_matched: 9815 lines, negation_mismatched: 9832 lines
        testsetv9 = BertDataset(_jsonl("Stress-Tests/Negation/multinli_0.9_negation_matched.jsonl"), tokenizer, max_token_length)
        testloaderv9 = _loader(testsetv9)
        ## Stess-Tests-O: taut2_matched: 9815 lines, taut2_mismatched: 9832 lines
        testsetv10 = BertDataset(_jsonl("Stress-Tests/Word_Overlap/multinli_0.9_taut2_matched.jsonl"), tokenizer, max_token_length)
        testloaderv10 = _loader(testsetv10)
        ## Stess-Tests-A: antonym_matched: 1561 lines, antonym_mismatched: 1734 lines
        testsetv11 = BertDataset(_jsonl("Stress-Tests/Antonym/multinli_0.9_antonym_matched.jsonl"), tokenizer, max_token_length)
        testloaderv11 = _loader(testsetv11)
        ## Stess-Tests-NU: quant_hard: 7596 lines
        testsetv12 = BertDataset(_jsonl("Stress-Tests/Numerical_Reasoning/multinli_0.9_quant_hard.jsonl"), tokenizer, max_token_length)
        testloaderv12 = _loader(testsetv12)
        ## SICK: {train+dev+test, 9841 lines}
        testsetv13_df = _tsv("SICK/SICK.txt")
        testsetv13_df.rename(columns={'sentence_A': 'sentence1', 'sentence_B': 'sentence2', 'entailment_label': 'gold_label'}, inplace=True)
        testsetv13 = SICKBertDataset(testsetv13_df, tokenizer, max_token_length)
        testloaderv13 = _loader(testsetv13)
        ## EQUATE-NAT: {NewsNLI: 968 lines, RedditNLI: 250 lines, RTE_Quant: 166 lines, All: 1384 lines}
        testsetv14 = torch.utils.data.ConcatDataset([
            BertDataset(_jsonl("EQUATE/NewsNLI.jsonl"), tokenizer, max_token_length),
            BertDataset(_jsonl("EQUATE/RedditNLI.jsonl"), tokenizer, max_token_length),
            BertDataset(_jsonl("EQUATE/RTE_Quant.jsonl"), tokenizer, max_token_length),
        ])
        # Bug fix: the original built this loader over testsetv5 (HANS), so
        # the EQUATE-NAT loader silently evaluated the wrong dataset.
        testloaderv14 = _loader(testsetv14)
        ## EQUATE-SYN: {AWPNLI: 722 lines, StressTest: 7596 lines, All: 8318 lines}
        testsetv15 = torch.utils.data.ConcatDataset([
            BertDataset(_jsonl("EQUATE/AWPNLI.jsonl"), tokenizer, max_token_length),
            BertDataset(_jsonl("EQUATE/StressTest.jsonl"), tokenizer, max_token_length),
        ])
        testloaderv15 = _loader(testsetv15)
        ## QNLI: {train.tsv: 104743 lines, dev.tsv: 5453 lines, test.tsv(without label): 5463 lines}
        testsetv16_df = _tsv("QNLI/dev.tsv")
        testsetv16_df.rename(columns={'question': 'sentence1', 'sentence': 'sentence2', 'label': 'gold_label'}, inplace=True)
        testsetv16 = QNLIBertDataset(testsetv16_df, tokenizer, max_token_length)
        testloaderv16 = _loader(testsetv16)
        ## RTE: {train.tsv: 2490 lines, dev.tsv: 277 lines, test.tsv(without label): 3000 lines}
        testsetv17_df = _tsv("RTE/dev.tsv")
        testsetv17_df.rename(columns={'label': 'gold_label'}, inplace=True)
        # RTE shares QNLI's entailment/not_entailment label vocabulary.
        testsetv17 = QNLIBertDataset(testsetv17_df, tokenizer, max_token_length)
        testloaderv17 = _loader(testsetv17)
        ## WNLI: {train.tsv: 635 lines, dev.tsv: 71 lines, test.tsv(without label): 146 lines}
        testsetv18_df = _tsv("WNLI/dev.tsv")
        testsetv18_df.rename(columns={'label': 'gold_label'}, inplace=True)
        testsetv18 = WNLIBertDataset(testsetv18_df, tokenizer, max_token_length)
        testloaderv18 = _loader(testsetv18)
        ## ANLI R1-R3 dev sets share the same schema.
        anli_sets = []
        anli_loaders = []
        for rnd in ('R1', 'R2', 'R3'):
            anli_df = _jsonl(f"ANLI/{rnd}/dev.jsonl")
            anli_df.rename(columns={'context': 'sentence1', 'hypothesis': 'sentence2', 'label': 'gold_label'}, inplace=True)
            anli_set = ANLIBertDataset(anli_df, tokenizer, max_token_length)
            anli_sets.append(anli_set)
            anli_loaders.append(_loader(anli_set))
        testsetv19, testsetv20, testsetv21 = anli_sets
        testloaderv19, testloaderv20, testloaderv21 = anli_loaders
        ## SciTail: train.tsv: 23596 lines, dev.tsv: 1304 lines, test.tsv: 2126 lines
        testsetv22_df = _tsv("SciTail/tsv_format/scitail_1.0_dev.tsv", header=None)
        testsetv22_df.rename(columns={0: 'sentence1', 1: 'sentence2', 2: 'gold_label'}, inplace=True)
        testsetv22 = SciTailBertDataset(testsetv22_df, tokenizer, max_token_length)
        testloaderv22 = _loader(testsetv22)

        # In-distribution dev sets first, then one list of the 15 OOD sets.
        testsets.extend([testsetv16, testsetv17, testsetv18, testsetv19,
                         testsetv20, testsetv21, testsetv22])
        testset_oods = [testsetv1, testsetv2, testsetv3, testsetv4, testsetv5, testsetv6, testsetv7, testsetv8, testsetv9,
                        testsetv10, testsetv11, testsetv12, testsetv13, testsetv14, testsetv15]
        testsets.append(testset_oods)
        testloaders.extend([testloaderv16, testloaderv17, testloaderv18, testloaderv19,
                            testloaderv20, testloaderv21, testloaderv22])
        testloader_oods = [testloaderv1, testloaderv2, testloaderv3, testloaderv4, testloaderv5, testloaderv6, testloaderv7, testloaderv8, testloaderv9,
                           testloaderv10, testloaderv11, testloaderv12, testloaderv13, testloaderv14, testloaderv15]
        testloaders.append(testloader_oods)
    return trainset, trainloader, testsets, testloaders
def multi_acc(y_pred, y_test):
    """Fraction of rows in logits `y_pred` whose argmax equals `y_test`."""
    predicted = torch.log_softmax(y_pred, dim=1).argmax(dim=1)
    correct = (predicted == y_test).sum().float()
    return correct / float(y_test.size(0))
def main():
    """Fine-tune BERT on MNLI, validating each epoch on the first test loader.

    Saves the best model (by validation accuracy) to ``./model.pth``.
    """
    ## Load Datasets.
    ## Bug fix: the original passed train=False and then iterated the
    ## resulting trainloader, which is None (TypeError on first epoch).
    _, trainloader, _, testloaders = get_mnli(batch_size=16, num_workers=1, net="bert-base-uncased",
                                              data_dir="/data/pengru/datasets/", max_token_length=512, train=True)

    ## Load Model.
    ## Bug fix: the original built a RoBERTa model and immediately overwrote
    ## it with DistilBERT, whose forward() rejects the token_type_ids that
    ## BertDataset emits. Use BERT to match the tokenizer above.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Best-effort retry: model download may fail transiently.
    while True:
        try:
            ## default num_labels (classification heads) is 2
            model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
            model.to(device)
            break
        except Exception:
            continue
    print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

    # Exclude biases and LayerNorm weights from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]

    EPOCHS = 5
    WARMUP_FRAC = 0.1
    SAVED_MODEL_PATH = './model.pth'
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
    # Bug fix: the original referenced undefined GCONF and train_dl here
    # (NameError on first run). The schedule spans every optimizer step.
    total_steps = len(trainloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(WARMUP_FRAC * total_steps),
        num_training_steps=total_steps)

    best_acc = 0.0  # bug fix: was read before ever being assigned
    for epoch in range(EPOCHS):
        start = time.time()
        model.train()
        total_train_loss = 0
        total_train_acc = 0
        for batch in trainloader:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            label = batch['label'].to(device)
            optimizer.zero_grad()
            loss, prediction = model(input_ids,
                                     token_type_ids=token_type_ids,
                                     attention_mask=attn_mask,
                                     labels=label).values()
            loss.backward()
            optimizer.step()
            # Bug fix: a linear schedule over total_steps must step once per
            # optimizer step, not once per epoch as the original did.
            scheduler.step()
            acc = multi_acc(prediction, label)
            total_train_loss += loss.item()
            total_train_acc += acc.item()
        train_acc = total_train_acc / len(trainloader)
        train_loss = total_train_loss / len(trainloader)

        model.eval()
        total_val_acc = 0
        total_val_loss = 0
        with torch.no_grad():
            for batch in testloaders[0]:
                input_ids = batch['input_ids'].to(device)
                attn_mask = batch['attn_mask'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                label = batch['label'].to(device)
                loss, prediction = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         attention_mask=attn_mask,
                                         labels=label).values()
                acc = multi_acc(prediction, label)
                total_val_loss += loss.item()
                total_val_acc += acc.item()
        val_acc = total_val_acc / len(testloaders[0])
        val_loss = total_val_loss / len(testloaders[0])
        end = time.time()
        print(f'Epoch {epoch + 1}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}, time = {end - start:.4f}')
        if val_acc > best_acc:
            best_acc = val_acc
            # Bug fix: original used undefined GCONF.saved_model_path.
            torch.save(model, SAVED_MODEL_PATH)
            print('save best model\t\tacc:%.6f' % best_acc)


if __name__ == '__main__':
    main()