import torch


def collate_fn(data):
    # Each item in `data` is a (sentence, label) pair.
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    # Tokenize the batch; every sentence is truncated/padded to 500 tokens.
    # `token` is the tokenizer (see the setup sketch after this function).
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels
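# The collate functions, loaders, and models in this listing assume that a
# tokenizer `token`, an encoder `pretrained`, and a `dataset` object already
# exist. A minimal sketch of that setup, assuming a 'bert-base-chinese'
# checkpoint (the actual checkpoint and dataset are not shown here):
from transformers import BertModel, BertTokenizer

token = BertTokenizer.from_pretrained('bert-base-chinese')
pretrained = BertModel.from_pretrained('bert-base-chinese')
# `dataset` should be an indexable dataset yielding the tuples that each
# collate_fn below expects; its construction is left out of this listing.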
# Batch the dataset with the collate function above.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
# Two-way sentence classifier on top of the frozen pretrained encoder.
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Classification head applied to the 768-d [CLS] hidden state.
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # The pretrained encoder is kept frozen; only `fc` is trainable.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        # Classify using the hidden state of the first token ([CLS]).
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out


model = Model()
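# Quick sanity check (not part of the original listing): pull one batch from
# the loader and run a forward pass. With batch_size=16 and a two-way head,
# the output probabilities should have shape [16, 2].
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
with torch.no_grad():
    probs = model(input_ids=input_ids,
                  attention_mask=attention_mask,
                  token_type_ids=token_type_ids)
print(probs.shape, labels.shape)  # torch.Size([16, 2]) torch.Size([16])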
def collate_fn(data):
    # Each item in `data` is a single sentence; there is no external label,
    # the masked token itself becomes the target.
    data = token.batch_encode_plus(batch_text_or_text_pairs=data,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=30,
                                   return_tensors='pt',
                                   return_length=True)

    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']

    # Remember the original token at position 15, then replace it with [MASK].
    labels = input_ids[:, 15].reshape(-1).clone()
    input_ids[:, 15] = token.get_vocab()[token.mask_token]

    return input_ids, attention_mask, token_type_ids, labels
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
# Predicts the token at the masked position over the full vocabulary.
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Decoder that maps the 768-d hidden state back onto the vocabulary,
        # with a separately registered bias parameter.
        self.decoder = torch.nn.Linear(768, token.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(token.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids):
        # The pretrained encoder is kept frozen; only the decoder is trainable.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        # Score the vocabulary at the masked position (index 15).
        out = self.decoder(out.last_hidden_state[:, 15])
        return out


model = Model()
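# Example of reading a prediction out of the masked-token model above (a
# sketch, not part of the original listing): the argmax over the vocabulary
# logits is the predicted id for position 15, which the tokenizer can decode.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
with torch.no_grad():
    logits = model(input_ids=input_ids,
                   attention_mask=attention_mask,
                   token_type_ids=token_type_ids)
predicted_ids = logits.argmax(dim=1)  # shape: [batch_size]
print(token.decode(predicted_ids[0].item()), token.decode(labels[0].item()))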
def collate_fn(data):
    # Each item in `data` is (first sentence, second sentence, label).
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]

    # Encode each pair as one sequence: [CLS] sent_a [SEP] sent_b [SEP].
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels
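# Illustration of the pair encoding (made-up sentences, not from the dataset):
# token_type_ids is 0 for [CLS], the first sentence, its [SEP], and the
# trailing padding, and 1 for the second sentence and its [SEP].
example = token.batch_encode_plus(
    batch_text_or_text_pairs=[('the first sentence', 'the second sentence')],
    truncation=True,
    padding='max_length',
    max_length=45,
    return_tensors='pt')
print(example['token_type_ids'][0])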
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
# Two-way classifier over sentence pairs; same head as the first model, but
# here the encoder sees both sentences of each pair at once.
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # The pretrained encoder is kept frozen; only `fc` is trainable.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        # The [CLS] hidden state summarises the whole pair.
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out


model = Model()