A PyTorch implementation of Su Jianlin (苏神)'s GlobalPointer: a span-based approach for named entity recognition and event extraction.

GlobalPointer

GlobalPointer is essentially a span-matrix labeling scheme: both the rows and the columns of the matrix index the tokens of the sentence, and a 1 at position (i, j) means the span starting at token i and ending at token j is a target to extract (only the upper triangle, where start <= end, is valid, which matches how the labels are built and masked in the code below). For the derivation, see Su Jianlin's blog (科学空间).
The code consists of two modules: model.py, which defines the network, and span.py, which builds the training pipeline: training, evaluation, and prediction.
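
For intuition, here is a tiny sketch of the labeling scheme (the sentence length, entity span, and class id below are made up for illustration):

import torch

# Hypothetical example: 5 tokens, 2 entity classes, and one entity of class 0
# spanning tokens 0..1 (start = 0, end = 1, both inclusive).
seq_len, c_size = 5, 2
label = torch.zeros((c_size, seq_len, seq_len))
label[0, 0, 1] = 1  # (class, start, end)
# At prediction time the model's score matrix is read back the same way:
# every cell with a score > 0 is decoded as a (class, start, end) span.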

model.py

import torch
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch import nn
import math
# Sinusoidal position embedding, used to build RoPE (rotary/relative position encoding) below
class SinusoidalPositionEmbedding(Module):
    """定义Sin-Cos位置Embedding
    """
    def __init__(
        self, output_dim, merge_mode='add', custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        input_shape = inputs.shape
        batch_size, seq_len = input_shape[0], input_shape[1]
        position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum('bn,d->bnd', position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == 'add':
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == 'zero':
            return embeddings.to(inputs.device)
# Sequence (padding) mask
def sequence_masking(x, mask, value='-inf', axis=None):
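    """Mask out padded positions of x along `axis`, filling them with a large negative (or positive) value."""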
    if mask is None:
        return x
    else:
        if value == '-inf':
            value = -1e12
        elif value == 'inf':
            value = 1e12
        assert axis > 0, 'axis must be greater than 0'
        for _ in range(axis - 1):
            mask = torch.unsqueeze(mask, 1)
        for _ in range(x.ndim - mask.ndim):
            mask = torch.unsqueeze(mask, mask.ndim)
        return x * mask + value * (1 - mask)          
def add_mask_tril(logits, mask):
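    """Apply the padding mask to the last two axes of logits and suppress the lower triangle (start > end)."""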
    if mask.dtype != logits.dtype:
        mask = mask.type(logits.dtype)
    logits = sequence_masking(logits, mask, '-inf', logits.ndim - 2)
    logits = sequence_masking(logits, mask, '-inf', logits.ndim - 1)
    # mask out the lower triangle (spans with start > end)
    mask = torch.tril(torch.ones_like(logits), diagonal=-1)
    logits = logits - mask * 1e12
    return logits

class GlobalPointer(Module):
    """全局指针模块
    将序列的每个(start, end)作为整体来进行判断
    """
    def __init__(self, heads, head_size,hidden_size,RoPE=True):
        super(GlobalPointer, self).__init__()
        self.heads = heads
        self.head_size = head_size
        self.RoPE = RoPE
        self.dense = nn.Linear(hidden_size,self.head_size * self.heads * 2)

    def forward(self, inputs, mask=None):
        inputs = self.dense(inputs)
        inputs = torch.split(inputs, self.head_size * 2 , dim=-1)
        # split along the last dimension into chunks of size head_size * 2 (one chunk per head)
        inputs = torch.stack(inputs, dim=-2)
        # stack the chunks along a new head dimension; all chunks share the same shape
        qw, kw = inputs[..., :self.head_size], inputs[..., self.head_size:]
        # split into query (qw) and key (kw) representations
        # apply RoPE (rotary position embedding)
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
            # repeat_interleave so cos/sin pair up with the (-q_odd, q_even) rotation built below
            cos_pos = pos[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 4)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 4)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # pairwise inner products: logits[b, h, m, n] = <qw[b, m, h], kw[b, n, h]>
        logits = torch.einsum('bmhd , bnhd -> bhmn', qw, kw)
        # mask out padding and the lower triangle
        logits = add_mask_tril(logits,mask)
        # scale by sqrt(head_size) and return
        return logits / self.head_size ** 0.5
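
A quick sanity check of the module's output shape (a minimal sketch with made-up dimensions; it only exercises the code above):

import torch
from model import GlobalPointer

# hypothetical sizes: batch 2, sequence length 6, hidden size 768,
# 3 entity classes (heads), head_size 64
gp = GlobalPointer(heads=3, head_size=64, hidden_size=768)
hidden = torch.randn(2, 6, 768)   # stands in for BERT's last_hidden_state
mask = torch.ones(2, 6)           # 1 = real token, 0 = padding
logits = gp(hidden, mask=mask)
print(logits.shape)               # torch.Size([2, 3, 6, 6]) -> (batch, heads, start, end)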

span.py

import re
import json
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer, BertModel, BertConfig, BertTokenizerFast,AdamW
from model import GlobalPointer
import sys
import os
import transformers

# Data loading and preprocessing
def load_data(filename,language):
    data = []
    cat = set()
    f = json.load(open(filename, 'r', encoding='utf-8'))
    for text in f:
        if language == 'chinese':
            context = text['tokens']
        else:
            context = ' '.join(text['tokens'])
        data.append([context])
        for e in text['entities']:
            start_pos = int(e['start'])
            end_pos = int(e['end'])-1
            e_type = e['type']
            data[-1].append([start_pos,end_pos,text['tokens'][start_pos:end_pos+1],e_type])
            cat.add(e_type)
    return data,cat
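
# Expected input record format (a hypothetical example; only the fields read above matter):
#   {"tokens": ["张", "三", "在", "北", "京"],
#    "entities": [{"start": 0, "end": 2, "type": "PER"}]}
# Note that 'end' in the file is exclusive; load_data converts it to an inclusive index.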
    
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, maxlen):
        self.data = data
        self.tokenizer = tokenizer
        self.maxlen = maxlen

    @staticmethod
    def find_index(offset_mapping, index):
        for idx, internal in enumerate(offset_mapping[1:]):
            if internal[0] <= index < internal[1]:
                return idx + 1
        return None
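    # Example (hypothetical offset_mapping; index 0 is the [CLS] placeholder):
    #   offset_mapping = [(0, 0), (0, 1), (1, 2), (2, 3)]
    #   find_index(offset_mapping, 1) -> 2  (character position 1 maps to token position 2)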

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        d = self.data[idx]
        label = torch.zeros((c_size,self.maxlen,self.maxlen))
        enc_context = self.tokenizer(d[0], return_offsets_mapping=True, max_length=self.maxlen, truncation=True, padding='max_length', return_tensors='pt')
        enc_context = {key:enc_context[key][0] for key in enc_context.keys() if enc_context[key].shape[0] == 1}
        for entity_info in d[1:]:
            start, end = entity_info[0], entity_info[1]
            offset_mapping = enc_context['offset_mapping']
            start = self.find_index(offset_mapping, start)
            end = self.find_index(offset_mapping, end)
            if start and end and start < self.maxlen and end < self.maxlen:
                # the label matrix is maxlen x maxlen; [CLS] is prepended to the sentence,
                # so positions are shifted by one (find_index already returns idx + 1)
                label[c2id[entity_info[3]],start,end] = 1
        return enc_context,label
        # label shape: [class_count, sentence_length, sentence_length]

# Build the network: BERT encodes the sentence into token representations, then GlobalPointer
# produces the span-score matrix of shape [batch_size, class_count, sentence_length, sentence_length]
class Net(nn.Module):
    def __init__(self,model_path,hidden_size,prop_drop):
        super(Net, self).__init__()
        self.head = GlobalPointer(c_size, 64, hidden_size)
        self.bert = BertModel.from_pretrained(model_path)
        self.lstm = nn.LSTM(input_size = hidden_size, hidden_size = hidden_size//2, num_layers = 1,  bidirectional = True, dropout = 0.2, batch_first = True)
        self.dropout = nn.Dropout(prop_drop)
    def forward(self, input_ids, attention_mask, token_type_ids):
        x1 = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        x2 = x1.last_hidden_state
        x2,(_,_) = self.lstm(x2)
        x2 = self.dropout(x2)
        logits = self.head(x2, mask = attention_mask)
        return logits
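
# Shape walkthrough through Net (hypothetical batch of size B, sequence length L):
#   input_ids (B, L) -> BERT last_hidden_state (B, L, hidden_size)
#   -> BiLSTM (hidden_size // 2 per direction) -> (B, L, hidden_size)
#   -> GlobalPointer -> logits (B, c_size, L, L)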

# Optimizer parameter groups: no weight decay for bias and LayerNorm parameters
def get_optimizer_params(model,weight_decay):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_params = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    return optimizer_params

# Su Jianlin's multi-label classification loss; see his blog post on this loss for the derivation
def multilabel_categorical_crossentropy(y_true, y_pred):
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return neg_loss + pos_loss
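
# Quick sanity check of the loss on hypothetical toy tensors: a confident correct
# prediction scores far lower than a confident wrong one.
#   y_true = torch.tensor([[1., 0., 0.]])
#   multilabel_categorical_crossentropy(y_true, torch.tensor([[5., -5., -5.]]))   # ~0.02
#   multilabel_categorical_crossentropy(y_true, torch.tensor([[-5., 5., -5.]]))   # ~10.0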


def global_pointer_crossentropy(y_true, y_pred):
    """给GlobalPointer设计的交叉熵
    """
    #y_pred = (batch,l,l,c)
    bh = y_pred.shape[0] * y_pred.shape[1]
    y_true = torch.reshape(y_true, (bh, -1))
    y_pred = torch.reshape(y_pred, (bh, -1))
    return torch.mean(multilabel_categorical_crossentropy(y_true, y_pred))

def global_pointer_f1_score(y_true, y_pred):
    y_pred = torch.greater(y_pred, 0)
    # pre_index = y_pred.nonzero()  # entity indices: [batch_index, type_index, start_index, end_index]
    # y_true * y_pred  -> number of correctly predicted spans (true positives)
    # y_true + y_pred  -> predicted span count + gold span count
    return torch.sum(y_true * y_pred).item(), torch.sum(y_true + y_pred).item()

# Training and evaluation loops
def train(dataloader, model, loss_fn, optimizer,scheduler):
    model.train()
    size = len(dataloader.dataset)
    numerate, denominator = 0, 0
    for batch, (data,y) in enumerate(dataloader):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        y = y.to(device)
        pred = model(input_ids,attention_mask,token_type_ids)
        loss = loss_fn(y,pred)
        temp_n,temp_d = global_pointer_f1_score(y,pred)
        numerate += temp_n
        denominator += temp_d
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if batch % 50 == 0:
            loss, current = loss.item(), batch * len(input_ids)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    print(f"Train F1: {(2*numerate/denominator):>4f}%")
    return model

def test(dataloader,loss_fn, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss = 0
    numerate, denominator = 0, 0
    with torch.no_grad():
        for data,y in dataloader:
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            token_type_ids = data['token_type_ids'].to(device)
            y = y.to(device)
            pred = model(input_ids, attention_mask, token_type_ids)
            test_loss += loss_fn(y,pred).item()
            temp_n, temp_d = global_pointer_f1_score(y, pred)
            numerate += temp_n
            denominator += temp_d
    test_loss /= size
    test_f1 = 2*numerate/denominator
    print(f"Test Error: \n ,F1:{(test_f1):>4f},Avg loss: {test_loss:>8f} \n")
    return test_f1

# Configuration
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using {} device".format(device))
model_path = "../bert/span_bert"
best_model_save_path = "./bert_model/best_model"
final_model_save_path = "./bert_model/final_model"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
train_batch_size = 8
learning_rate = 2e-5
dev_batch_size = 8
pre_batch_size = 1
eval_batch_size = 8
train_epochs = 40
lr_warmup = 0.1
# Data
language = 'chinese'
train_data, cat = load_data('./datasets/ccf/Train',language)
val_data, _ = load_data('./datasets/ccf/Dev',language)
pre_data,_ = load_data('./datasets/Pre',language)
eval_data,_ = load_data('./datasets/ccf/Test',language)
train_sample_count = len(train_data)
updates_epoch = train_sample_count//train_batch_size
updates_total = updates_epoch*train_epochs

# Entity classes
c_size = len(cat)
c2id = {c:idx for idx,c in enumerate(cat)}
id2c = {idx:c for idx,c in enumerate(cat)}

# Model, optimizer, and scheduler
maxlen = 512
hidden_size = 768
prop_drop = 0.1
weight_decay = 0.01
model = Net(model_path,hidden_size,prop_drop).to(device)
optimizer_params = get_optimizer_params(model,weight_decay)
optimizer = AdamW(optimizer_params, lr=learning_rate, weight_decay=weight_decay, correct_bias=False)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                         num_warmup_steps=lr_warmup * updates_total,
                                                         num_training_steps=updates_total)


# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Datasets and dataloaders
training_data = CustomDataset(train_data,tokenizer,maxlen)
testing_data = CustomDataset(val_data,tokenizer,maxlen)
predicting_data = CustomDataset(pre_data,tokenizer,maxlen)
evaling_data = CustomDataset(eval_data,tokenizer,maxlen)

train_dataloader = DataLoader(training_data, batch_size=train_batch_size,shuffle=True)
test_dataloader = DataLoader(testing_data, batch_size=dev_batch_size)
pre_dataloader = DataLoader(predicting_data, batch_size=pre_batch_size)
eval_dataloader = DataLoader(evaling_data, batch_size=eval_batch_size)


prediction_state = 'False'
eval_state = 'True'
SUMMARY_OUTPUT_PATH = './prediction/prediction.json'

if __name__ == '__main__':

    max_F1 = 0
    best_F1 = 0
    for t in range(train_epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model = train(train_dataloader, model, global_pointer_crossentropy, optimizer,scheduler)
        F1 = test(test_dataloader,global_pointer_crossentropy, model)
        if F1 > max_F1:
            max_F1 = F1
            best_F1 = max_F1
            torch.save(model, best_model_save_path)
            print(f"Higher F1: {(max_F1):>4f}%")
        print(f"best f1: {(max_F1):>4f}%")
    print("train Done!")
    # save the final model
    torch.save(model, final_model_save_path)

    if prediction_state == 'True':
        print("Start predition...")
        print("Load bert_classifier model path: ", best_model_save_path)
        # load the best model
        model = torch.load(best_model_save_path)
        model = model.to(device)
        model.eval()

        with torch.no_grad():
            all = []
            for data, y in pre_dataloader:
                input_ids = data['input_ids'].to(device)
                attention_mask = data['attention_mask'].to(device)
                token_type_ids = data['token_type_ids'].to(device)
                pred = model(input_ids, attention_mask, token_type_ids)
                y_pred = torch.greater(pred, 0)
                pre_index = y_pred.nonzero()
                pre_index = pre_index.tolist()
                for i in range(len(data['input_ids'])):
                    entities = []
                    for pre in pre_index: #[batch_index,type_index,start_index,end_index]
                        if pre[0] == i:
                            pre_entity = {}
                            # [CLS] is prepended to the sentence, so shift the indices back by one
                            pre_entity['start'] = pre[2]-1
                            pre_entity['end'] = pre[3]-1
                            pre_entity['type'] = id2c[pre[1]]
                            entities.append(pre_entity)
                    json_str = {'entities':entities}
                    all.append(json_str)
            json_str = json.dumps(all, indent=2)
            with open(SUMMARY_OUTPUT_PATH, 'w', encoding="utf-8") as json_file:
                json_file.write(json_str + "\n")
        print("Evaluation done! Result has saved to: ", SUMMARY_OUTPUT_PATH)

    if eval_state == 'True':
        print("Start eval...")
        print("Load bert_classifier model path: ", best_model_save_path)
        # load the best model
        model = torch.load(best_model_save_path)
        model = model.to(device)
        F1 = test(eval_dataloader, global_pointer_crossentropy, model)
        print(f"eval f1: {(F1):>4f}%")

This also counts as a span-based approach, just with a different labeling scheme. One question I keep having about span methods: is it better to label the entire span as a unit (as done here), or only the span boundaries?
