Chinese Medical Entity Recognition with BERT + Bi-LSTM + CRF

Project repository: GTyingzi/Chinese-Medical-Entity_Recognition (github.com)

Dataset format

Each line holds one character and its BIO tag, separated by a tab. In the sample below, 两上肢 ("both upper limbs") is tagged as a BODY entity:

伴	O
两	B-BODY
上	I-BODY
肢	I-BODY
水	O
肿	O
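
A minimal sketch of reading this format, one (char, tag) pair per non-empty line (the file name here is a placeholder, not necessarily the project's real file):

# Sketch: parse tab-separated char/tag lines (placeholder file name).
with open('train_dataset.txt', encoding='utf-8') as f:
    pairs = [line.rstrip('\n').split('\t') for line in f if line.strip()]
print(pairs[:3])  # e.g. [['伴', 'O'], ['两', 'B-BODY'], ['上', 'I-BODY']]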

Code

Main script (main):

import torch
from datetime import datetime
from torch.utils import data
import os
import warnings
import argparse
import numpy as np
from sklearn import metrics
from models import Bert_BiLSTM_CRF, Bert_CRF
from transformers import AdamW, get_linear_schedule_with_warmup
from utils import NerDataset, PadBatch, VOCAB, tokenizer, tag2idx, idx2tag

warnings.filterwarnings("ignore", category=DeprecationWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def train(e, model, iterator, optimizer, scheduler, device):
    model.train()
    losses = 0.0
    step = 0
    for i, batch in enumerate(iterator):
        step += 1
        x, y, z = batch
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)

        loss = model(x, y, z)
        losses += loss.item()
        # Optional: gradient accumulation over 2 steps (kept for reference)
        # full_loss = loss / 2               # normalize loss
        # full_loss.backward()               # backward and accumulate gradient
        # if step % 2 == 0:
        #     optimizer.step()               # update optimizer
        #     scheduler.step()               # update scheduler
        #     optimizer.zero_grad()          # clear gradient
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    print("Epoch: {}, Loss:{:.4f}".format(e, losses/step))

def validate(e, model, iterator, device, log_path):
    model.eval()
    Y, Y_hat = [], []
    losses = 0
    step = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            step += 1
            x, y, z = batch
            x = x.to(device)
            y = y.to(device)
            z = z.to(device)

            y_hat = model(x, y, z, is_test=True)  # Viterbi decoding

            loss = model(x, y, z)  # second forward pass for the loss
            losses += loss.item()
            # Save predictions
            for j in y_hat:
                Y_hat.extend(j)
            # Save gold labels at unmasked positions
            mask = (z == 1)
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig.cpu())

    Y = torch.cat(Y, dim=0).numpy()
    Y_hat = np.array(Y_hat)
    acc = (Y_hat == Y).mean()*100

    output = "{} Epoch: {}, Val Loss:{:.4f}, Val Acc:{:.3f}%".format(datetime.now(),e, losses/step, acc)
    print(output)
    with open(log_path, 'a') as f:  # append validation results to the log
        f.write(output + '\n')
    return model, losses/step, acc

def model_test(model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            x, y, z = batch
            x = x.to(device)
            z = z.to(device)
            y_hat = model(x, y, z, is_test=True)
            # Save predictions
            for j in y_hat:
                Y_hat.extend(j)
            # Save gold labels (y stayed on the CPU, so move the mask back)
            mask = (z == 1).cpu()
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig)

    Y = torch.cat(Y, dim=0).numpy()
    y_true = [idx2tag[i] for i in Y]
    y_pred = [idx2tag[i] for i in Y_hat]

    return y_true, y_pred

if __name__ == "__main__":

    labels = ['B-BODY',
      'B-DISEASES',
      'B-DRUG',
      'B-EXAMINATIONS',
      'B-TEST',
      'B-TREATMENT',
      'I-BODY',
      'I-DISEASES',
      'I-DRUG',
      'I-EXAMINATIONS',
      'I-TEST',
      'I-TREATMENT']
    
    best_model = None
    _best_val_loss = 1e18
    _best_val_acc = 1e-18

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--n_epochs", type=int, default=5)
    parser.add_argument("--trainset", type=str, default="./CCKS_2019_Task1/processed_data/train_dataset.txt")
    parser.add_argument("--validset", type=str, default="./CCKS_2019_Task1/processed_data/val_dataset.txt")
    parser.add_argument("--testset", type=str, default="./CCKS_2019_Task1/processed_data/test_dataset.txt")
    # parser.add_argument("--log",type=str,default="./logger/Bert_BiLSTM_CRF/train_acc.txt")
    parser.add_argument("--log", type=str, default="./logger/Bert_CRF/train_acc.txt")

    ner = parser.parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Choose the model (the Bi-LSTM variant is left as a comment)
    # model = Bert_BiLSTM_CRF(tag2idx).to(device)
    model = Bert_CRF(tag2idx).to(device)
    print(f'Initial model: {model.model_name} Done.')

    # Load the datasets
    train_dataset = NerDataset(ner.trainset)
    eval_dataset = NerDataset(ner.validset)
    test_dataset = NerDataset(ner.testset)
    print('Load Data Done.')

    train_iter = data.DataLoader(dataset=train_dataset,
                                 batch_size=ner.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=PadBatch)

    eval_iter = data.DataLoader(dataset=eval_dataset,
                                 batch_size=(ner.batch_size)//2,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=PadBatch)

    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=(ner.batch_size)//2,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=PadBatch)

    # optimizer = optim.Adam(model.parameters(), lr=ner.lr, weight_decay=0.01)
    optimizer = AdamW(model.parameters(), lr=ner.lr, eps=1e-6)

    # Warmup: linearly increase the learning rate over the first 10% of steps
    len_dataset = len(train_dataset)
    epoch = ner.n_epochs
    batch_size = ner.batch_size
    total_steps = (len_dataset + batch_size - 1) // batch_size * epoch  # batches per epoch (rounded up) * epochs

    warm_up_ratio = 0.1
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(warm_up_ratio * total_steps),
                                                num_training_steps=total_steps)

    print('Start Train...')
    for epoch in range(1, ner.n_epochs+1):

        train(epoch, model, train_iter, optimizer, scheduler, device)  # train for one epoch
        candidate_model, loss, acc = validate(epoch, model, eval_iter, device, ner.log)  # validate

        if loss < _best_val_loss and acc > _best_val_acc:  # keep the model with the best validation result
            best_model = candidate_model
            _best_val_loss = loss
            _best_val_acc = acc

        print("=============================================")
    
    y_true, y_pred = model_test(best_model, test_iter, device)  # gold and predicted tag sequences
    output = metrics.classification_report(y_true, y_pred, labels=labels, digits=3)  # per-label precision/recall/F1

    print(output)
    with open(ner.log, 'a') as f:  # append the test report to the log
        f.write(output + '\n')
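
The loop above only keeps best_model in memory. A hedged sketch of persisting it with torch.save; the checkpoint path is an assumption, not part of the original project:

# Assumption: save the best checkpoint to a path of your choosing.
ckpt_path = './checkpoints/best_model.pt'
os.makedirs('./checkpoints', exist_ok=True)
torch.save(best_model.state_dict(), ckpt_path)

# To reload later, rebuild the same architecture first:
# model = Bert_CRF(tag2idx)
# model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
# model.eval()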

Utilities (utils):

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

bert_model = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(bert_model)
# index 0 ('') is the padding tag
VOCAB = ('', '[CLS]', '[SEP]', 'O', 'B-BODY', 'I-TEST', 'I-EXAMINATIONS',
         'I-TREATMENT', 'B-DRUG', 'B-TREATMENT', 'I-DISEASES', 'B-EXAMINATIONS',
         'I-BODY', 'B-TEST', 'B-DISEASES', 'I-DRUG')

tag2idx = {tag: idx for idx, tag in enumerate(VOCAB)}
idx2tag = {idx: tag for idx, tag in enumerate(VOCAB)}
MAX_LEN = 256 - 2  # leave room for [CLS] and [SEP]

class NerDataset(Dataset):
    ''' Build the dataset: one sample per sentence, split on the full stop '。' '''
    def __init__(self, f_path):
        self.sents = []
        self.tags_li = []

        with open(f_path, 'r', encoding='utf-8') as f:
            lines = [line.rstrip('\n') for line in f if len(line.strip()) != 0]

        tags = [line.split('\t')[1] for line in lines]
        words = [line.split('\t')[0] for line in lines]

        word, tag = [], []
        for char, t in zip(words, tags):
            if char != '。':
                word.append(char)
                tag.append(t)
            else:  # a full stop closes the current sentence
                if len(word) > MAX_LEN:  # truncate overlong sentences
                    self.sents.append(['[CLS]'] + word[:MAX_LEN] + ['[SEP]'])
                    self.tags_li.append(['[CLS]'] + tag[:MAX_LEN] + ['[SEP]'])
                else:
                    self.sents.append(['[CLS]'] + word + ['[SEP]'])
                    self.tags_li.append(['[CLS]'] + tag + ['[SEP]'])
                word, tag = [], []

    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx]
        token_ids = tokenizer.convert_tokens_to_ids(words)
        label_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(label_ids)
        return token_ids, label_ids, seqlen

    def __len__(self):
        return len(self.sents)

def PadBatch(batch):
    maxlen = max(i[2] for i in batch)  # each item: (token_ids, label_ids, seqlen)
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    mask = (token_tensors > 0)  # padding id 0 is masked out
    return token_tensors, label_tensors, mask
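
A quick sanity check of PadBatch on a toy batch (the token and label ids below are made up) shows the padding and the resulting mask:

# Toy batch of (token_ids, label_ids, seqlen) tuples.
batch = [([101, 800, 102], [1, 3, 2], 3),   # length 3
         ([101, 102], [1, 2], 2)]           # length 2, padded to 3
tokens, labels, mask = PadBatch(batch)
print(tokens)  # tensor([[101, 800, 102], [101, 102, 0]])
print(mask)    # tensor([[True, True, True], [True, True, False]])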

Models (models):

import torch
import torch.nn as nn
from transformers import BertModel
from TorchCRF import CRF

class Bert_BiLSTM_CRF(nn.Module):

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super(Bert_BiLSTM_CRF, self).__init__()
        self.model_name = "Bert_BiLSTM_CRF"
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(self.tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim//2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size)
    
    def _get_features(self, sentence):  # sentence: (batch_size, seq_len)
        with torch.no_grad():  # BERT is frozen; only the LSTM/linear layers are trained
            encoder_output = self.bert(sentence)
            embeds = encoder_output[0]  # (batch_size, seq_len, embedding_dim)
        enc, _ = self.lstm(embeds)  # (batch_size, seq_len, hidden_dim)
        enc = self.dropout(enc)
        feats = self.linear(enc)  # emission scores: (batch_size, seq_len, tagset_size)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):  # inputs: (batch_size, seq_len)
        emissions = self._get_features(sentence)  # emission scores: (batch_size, seq_len, tagset_size)
        if not is_test:  # training: return the negative log-likelihood
            loss = -self.crf.forward(emissions, tags, mask).mean()
            return loss
        else:  # testing: return the Viterbi decoding
            decode = self.crf.viterbi_decode(emissions, mask)
            return decode

class Bert_CRF(nn.Module):

    def __init__(self, tag_to_ix, embedding_dim=768):
        super(Bert_CRF, self).__init__()
        self.model_name = "Bert_CRF"
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(self.tag_to_ix)
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(embedding_dim, self.tagset_size)

        self.crf = CRF(self.tagset_size)

    def _get_features(self, sentence):  # sentence: (batch_size, seq_len)
        with torch.no_grad():  # BERT is frozen; only the linear layer is trained
            encoder_output = self.bert(sentence)
            embeds = encoder_output[0]  # (batch_size, seq_len, embedding_dim)
        embeds = self.dropout(embeds)
        feats = self.linear(embeds)  # emission scores: (batch_size, seq_len, tagset_size)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):  # inputs: (batch_size, seq_len)
        emissions = self._get_features(sentence)  # emission scores: (batch_size, seq_len, tagset_size)
        if not is_test:  # training: return the negative log-likelihood
            loss = -self.crf.forward(emissions, tags, mask).mean()
            return loss
        else:  # testing: return the Viterbi decoding
            decode = self.crf.viterbi_decode(emissions, mask)
            return decode
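
A hedged inference sketch, assuming a trained Bert_CRF on the CPU plus tokenizer and idx2tag from utils; the zero tags only satisfy the forward signature, since viterbi_decode ignores them, and the example sentence is made up:

import torch

text = '伴两上肢水肿'
tokens = ['[CLS]'] + list(text) + ['[SEP]']
x = torch.LongTensor([tokenizer.convert_tokens_to_ids(tokens)])
mask = torch.ones_like(x, dtype=torch.bool)  # no padding in a single sentence
dummy_y = torch.zeros_like(x)  # unused during decoding

model.eval()
with torch.no_grad():
    decoded = model(x, dummy_y, mask, is_test=True)  # list of tag-id lists
print([idx2tag[i] for i in decoded[0]])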

Experimental results

BERT + CRF

[Figure 1: evaluation results for BERT + CRF]

BERT + Bi-LSTM + CRF

[Figure 2: evaluation results for BERT + Bi-LSTM + CRF]

References

XavierWww/Chinese-Medical-Entity-Recognition: Using BERT + Bi-LSTM + CRF (github.com)
