This is a kind of span-matrix annotation: both the rows and the columns of the matrix index the tokens of the sentence (1..n), and a 1 at position (i, j) means the span starting at token i and ending at token j is a target to extract (so only the upper triangle, start <= end, is used). For details, see Su Jianlin's blog, Scientific Spaces (kexue.fm).
The code consists of two modules: model.py, which defines the network, and span.py, which builds the training pipeline along with evaluation and prediction.
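As a minimal sketch of this labelling scheme (a hypothetical example, not part of the original code): with a single entity type, an entity covering tokens 2 to 4 becomes a single 1 in the upper triangle of the label matrix.
# Hypothetical illustration of the span-matrix label: row = start token, column = end token.
import torch

seq_len, num_types = 8, 1
label = torch.zeros((num_types, seq_len, seq_len))
start, end = 2, 4                   # suppose tokens 2..4 (inclusive) form an entity of type 0
label[0, start, end] = 1            # only the upper triangle (start <= end) is ever used
print(label[0].nonzero().tolist())  # [[2, 4]]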
import torch
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch import nn
import math
# Rotary / sinusoidal position embedding
class SinusoidalPositionEmbedding(Module):
    """Sin-Cos position embedding."""
    def __init__(self, output_dim, merge_mode='add', custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        input_shape = inputs.shape
        batch_size, seq_len = input_shape[0], input_shape[1]
        position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum('bn,d->bnd', position_ids, indices)
        # Interleave sin and cos: [sin0, cos0, sin1, cos1, ...]
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == 'add':
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == 'zero':
            return embeddings.to(inputs.device)
# Sequence (padding) mask
def sequence_masking(x, mask, value='-inf', axis=None):
    if mask is None:
        return x
    else:
        if value == '-inf':
            value = -1e12
        elif value == 'inf':
            value = 1e12
        assert axis > 0, 'axis must be greater than 0'
        # Align the mask's rank with x so it broadcasts over the given axis
        for _ in range(axis - 1):
            mask = torch.unsqueeze(mask, 1)
        for _ in range(x.ndim - mask.ndim):
            mask = torch.unsqueeze(mask, mask.ndim)
        return x * mask + value * (1 - mask)
def add_mask_tril(logits, mask):
    if mask.dtype != logits.dtype:
        mask = mask.type(logits.dtype)
    # Mask out padding positions along both span dimensions
    logits = sequence_masking(logits, mask, '-inf', logits.ndim - 2)
    logits = sequence_masking(logits, mask, '-inf', logits.ndim - 1)
    # Mask out the lower triangle (spans with end < start)
    mask = torch.tril(torch.ones_like(logits), diagonal=-1)
    logits = logits - mask * 1e12
    return logits
class GlobalPointer(Module):
    """Global pointer module.
    Scores every (start, end) pair of the sequence as a whole.
    """
    def __init__(self, heads, head_size, hidden_size, RoPE=True):
        super(GlobalPointer, self).__init__()
        self.heads = heads
        self.head_size = head_size
        self.RoPE = RoPE
        self.dense = nn.Linear(hidden_size, self.head_size * self.heads * 2)

    def forward(self, inputs, mask=None):
        inputs = self.dense(inputs)
        # Split the last dimension into `heads` chunks of size head_size * 2
        inputs = torch.split(inputs, self.head_size * 2, dim=-1)
        # Stack along a new head dimension: (batch, seq_len, heads, head_size * 2)
        inputs = torch.stack(inputs, dim=-2)
        # Split into query-like and key-like halves
        qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]
        # RoPE (rotary position embedding)
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
            # pos is interleaved [sin0, cos0, sin1, cos1, ...]; repeat_interleave keeps the
            # pairwise (cos, cos) / (sin, sin) layout that the rotation below expects
            cos_pos = pos[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], dim=-1)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # Pairwise inner products: (batch, heads, seq_len, seq_len)
        logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)
        # Mask out padding and the lower triangle
        logits = add_mask_tril(logits, mask)
        # Scale and return
        return logits / self.head_size ** 0.5
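A quick, hypothetical shape check for the module above (the sizes are arbitrary and only meant to show what the head returns):
# Hypothetical smoke test for GlobalPointer (not part of the original code).
x = torch.randn(2, 10, 768)               # (batch, seq_len, hidden_size), e.g. BERT output
mask = torch.ones(2, 10)                  # attention mask, 1 = real token
gp = GlobalPointer(heads=2, head_size=64, hidden_size=768)
scores = gp(x, mask)
print(scores.shape)                       # torch.Size([2, 2, 10, 10]) = (batch, heads, seq_len, seq_len)
Next comes span.py, which builds the data pipeline, the model, and the training, evaluation, and prediction loops.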
import re
import json
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer, BertModel, BertConfig, BertTokenizerFast,AdamW
from model import GlobalPointer
import sys
import os
import transformers
# Data processing
def load_data(filename, language):
    data = []
    cat = set()
    f = json.load(open(filename, 'r', encoding='utf-8'))
    for text in f:
        if language == 'chinese':
            context = text['tokens']
        else:
            context = ' '.join(text['tokens'])
        data.append([context])
        for e in text['entities']:
            start_pos = int(e['start'])
            end_pos = int(e['end']) - 1
            e_type = e['type']
            data[-1].append([start_pos, end_pos, text['tokens'][start_pos:end_pos + 1], e_type])
            cat.add(e_type)
    return data, cat
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, maxlen):
        self.data = data
        self.tokenizer = tokenizer
        self.maxlen = maxlen

    @staticmethod
    def find_index(offset_mapping, index):
        # Map a character-level index to a token index (offset 0 is [CLS], hence idx + 1)
        for idx, internal in enumerate(offset_mapping[1:]):
            if internal[0] <= index < internal[1]:
                return idx + 1
        return None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        d = self.data[idx]
        label = torch.zeros((c_size, self.maxlen, self.maxlen))
        enc_context = self.tokenizer(d[0], return_offsets_mapping=True, max_length=self.maxlen,
                                     truncation=True, padding='max_length', return_tensors='pt')
        enc_context = {key: enc_context[key][0] for key in enc_context.keys() if enc_context[key].shape[0] == 1}
        for entity_info in d[1:]:
            start, end = entity_info[0], entity_info[1]
            offset_mapping = enc_context['offset_mapping']
            start = self.find_index(offset_mapping, start)
            end = self.find_index(offset_mapping, end)
            if start and end and start < self.maxlen and end < self.maxlen:
                # [CLS] is prepended to the sentence, so positions shift by one
                # (already handled by find_index above)
                label[c2id[entity_info[3]], start, end] = 1
        return enc_context, label
# Label tensor shape: [class_count, sentence_length, sentence_length]
# Network: BERT encodes the sentence into token features, then the global pointer head
# produces a span-matrix annotation of shape [batch_size, class_count, sentence_length, sentence_length]
class Net(nn.Module):
    def __init__(self, model_path, hidden_size, prop_drop):
        super(Net, self).__init__()
        self.head = GlobalPointer(c_size, 64, hidden_size)
        self.bert = BertModel.from_pretrained(model_path)
        # Single-layer BiLSTM on top of BERT (the LSTM dropout argument has no effect with num_layers=1)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size // 2, num_layers=1,
                            bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(prop_drop)

    def forward(self, input_ids, attention_mask, token_type_ids):
        x1 = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        x2 = x1.last_hidden_state
        x2, (_, _) = self.lstm(x2)
        x2 = self.dropout(x2)
        logits = self.head(x2, mask=attention_mask)
        return logits
# Optimizer parameter groups (no weight decay on bias / LayerNorm)
def get_optimizer_params(model, weight_decay):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_params = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}]
    return optimizer_params
# Su Jianlin's multi-label classification loss; see his article on the hard-truncation loss
def multilabel_categorical_crossentropy(y_true, y_pred):
    y_pred = (1 - 2 * y_true) * y_pred           # flip the sign of scores at positive positions
    y_pred_neg = y_pred - y_true * 1e12          # mask positives out of the negative term
    y_pred_pos = y_pred - (1 - y_true) * 1e12    # mask negatives out of the positive term
    zeros = torch.zeros_like(y_pred[..., :1])    # the 0 threshold
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return neg_loss + pos_loss
def global_pointer_crossentropy(y_true, y_pred):
    """Cross-entropy designed for GlobalPointer."""
    # y_pred has shape (batch, class_count, seq_len, seq_len)
    bh = y_pred.shape[0] * y_pred.shape[1]
    y_true = torch.reshape(y_true, (bh, -1))
    y_pred = torch.reshape(y_pred, (bh, -1))
    return torch.mean(multilabel_categorical_crossentropy(y_true, y_pred))

def global_pointer_f1_score(y_true, y_pred):
    y_pred = torch.greater(y_pred, 0)
    # y_pred.nonzero() would give entity indices as [batch_index, type_index, start_index, end_index]
    # numerator: correctly predicted entities; denominator: predicted entities + gold entities
    return torch.sum(y_true * y_pred).item(), torch.sum(y_true + y_pred).item()
# Training and evaluation loops
def train(dataloader, model, loss_fn, optimizer, scheduler):
    model.train()
    size = len(dataloader.dataset)
    numerate, denominator = 0, 0
    for batch, (data, y) in enumerate(dataloader):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        y = y.to(device)
        pred = model(input_ids, attention_mask, token_type_ids)
        loss = loss_fn(y, pred)
        temp_n, temp_d = global_pointer_f1_score(y, pred)
        numerate += temp_n
        denominator += temp_d
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if batch % 50 == 0:
            loss, current = loss.item(), batch * len(input_ids)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    print(f"Train F1: {(2 * numerate / denominator):>4f}")
    return model
def test(dataloader, loss_fn, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss = 0
    numerate, denominator = 0, 0
    with torch.no_grad():
        for data, y in dataloader:
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            token_type_ids = data['token_type_ids'].to(device)
            y = y.to(device)
            pred = model(input_ids, attention_mask, token_type_ids)
            test_loss += loss_fn(y, pred).item()
            temp_n, temp_d = global_pointer_f1_score(y, pred)
            numerate += temp_n
            denominator += temp_d
    test_loss /= size
    test_f1 = 2 * numerate / denominator
    print(f"Test Error: \n F1: {test_f1:>4f}, Avg loss: {test_loss:>8f} \n")
    return test_f1
# Configuration
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using {} device".format(device))
model_path = "../bert/span_bert"
best_model_save_path = "./bert_model/best_model"
final_model_save_path = "./bert_model/final_model"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
train_batch_size = 8
learning_rate = 2e-5
dev_batch_size = 8
pre_batch_size = 1
eval_batch_size = 8
train_epochs = 40
lr_warmup = 0.1
# Data
language = 'chinese'
train_data, cat = load_data('./datasets/ccf/Train',language)
val_data, _ = load_data('./datasets/ccf/Dev',language)
pre_data,_ = load_data('./datasets/Pre',language)
eval_data,_ = load_data('./datasets/ccf/Test',language)
train_sample_count = len(train_data)
updates_epoch = train_sample_count//train_batch_size
updates_total = updates_epoch*train_epochs
# Entity categories
c_size = len(cat)
c2id = {c:idx for idx,c in enumerate(cat)}
id2c = {idx:c for idx,c in enumerate(cat)}
# Model
maxlen = 512
hidden_size = 768
prop_drop = 0.1
weight_decay = 0.01
model = Net(model_path,hidden_size,prop_drop).to(device)
optimizer_params = get_optimizer_params(model,weight_decay)
optimizer = AdamW(optimizer_params, lr=learning_rate, weight_decay=weight_decay, correct_bias=False)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
num_warmup_steps=lr_warmup * updates_total,
num_training_steps=updates_total)
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Dataset iterators
training_data = CustomDataset(train_data,tokenizer,maxlen)
testing_data = CustomDataset(val_data,tokenizer,maxlen)
predicting_data = CustomDataset(pre_data,tokenizer,maxlen)
evaling_data = CustomDataset(eval_data,tokenizer,maxlen)
train_dataloader = DataLoader(training_data, batch_size=train_batch_size,shuffle=True)
test_dataloader = DataLoader(testing_data, batch_size=dev_batch_size)
pre_dataloader = DataLoader(predicting_data, batch_size=pre_batch_size)
eval_dataloader = DataLoader(evaling_data, batch_size=eval_batch_size)
prediction_state = False
eval_state = True
SUMMARY_OUTPUT_PATH = './prediction/prediction.json'
if __name__ == '__main__':
    max_F1 = 0
    for t in range(train_epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model = train(train_dataloader, model, global_pointer_crossentropy, optimizer, scheduler)
        F1 = test(test_dataloader, global_pointer_crossentropy, model)
        if F1 > max_F1:
            max_F1 = F1
            torch.save(model, best_model_save_path)
            print(f"Higher F1: {max_F1:>4f}")
        print(f"Best F1: {max_F1:>4f}")
    print("Training done!")
    # Save the final model
    torch.save(model, final_model_save_path)
if prediction_state:
    print("Start prediction...")
    print("Load bert_classifier model path: ", best_model_save_path)
    # Load the best model
    model = torch.load(best_model_save_path)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        all = []
        for data, y in pre_dataloader:
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            token_type_ids = data['token_type_ids'].to(device)
            pred = model(input_ids, attention_mask, token_type_ids)
            y_pred = torch.greater(pred, 0)
            # Each index is [batch_index, type_index, start_index, end_index]
            pre_index = y_pred.nonzero()
            pre_index = pre_index.tolist()
            for i in range(len(data['input_ids'])):
                entities = []
                for pre in pre_index:
                    if pre[0] == i:
                        pre_entity = {}
                        # The sentence starts with [CLS], so shift the indices back by one
                        pre_entity['start'] = pre[2] - 1
                        pre_entity['end'] = pre[3] - 1
                        pre_entity['type'] = id2c[pre[1]]
                        entities.append(pre_entity)
                json_str = {'entities': entities}
                all.append(json_str)
    json_str = json.dumps(all, indent=2)
    json_file = open(SUMMARY_OUTPUT_PATH, 'w', encoding="utf-8")
    json_file.write(json_str + "\n")
    print("Prediction done! Result has been saved to: ", SUMMARY_OUTPUT_PATH)
if eval_state:
    print("Start eval...")
    print("Load bert_classifier model path: ", best_model_save_path)
    # Load the best model
    model = torch.load(best_model_save_path)
    model = model.to(device)
    F1 = test(eval_dataloader, global_pointer_crossentropy, model)
    print(f"Eval F1: {F1:>4f}")
This also counts as a span-based approach; only the annotation scheme differs. One question I keep having about span methods: is it better to label the whole span as one unit, or to label only the span boundaries?