终章Pytorch_BERT_CASREL关系抽取模型源码

目录

一、整个架构

二、 源码

1.config.py

2.process.py

3.utils.py

4.model.py

5.train.py

6.test.py

7.predict.py

三、数据集以及模型


一、整个架构

终章Pytorch_BERT_CASREL关系抽取模型源码_第1张图片

二、 源码

1.config.py

# config.py
import torch

# Relation vocabulary CSV written by process.py, and the schema it is built from.
REL_PATH = './data/output/rel.csv'
# Width of the relation dimension used by the object taggers and label matrices.
REL_SIZE = 48
SCHEMA_PATH = './data/input/duie/duie_schema.json'

# One JSON sample per line in each split file.
TRAIN_JSON_PATH = './data/input/duie/duie_train.json'
TEST_JSON_PATH = './data/input/duie/duie_test.json'
DEV_JSON_PATH = './data/input/duie/duie_dev.json'

BERT_MODEL_NAME = './bert-base-chinese'

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# scripts do not crash on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 2
BERT_DIM = 768
LR = 1e-4
EPOCH = 50
MODEL_DIR = './data/output/'

# Loss weights for [negative, positive] labels, and the extra factor applied
# to the two subject-tagging losses.
CLS_WEIGHT_COEF = [0.3, 1.0]
SUB_WEIGHT_COEF = 3

# Probability thresholds above which a position counts as a predicted
# subject/object head or tail.
SUB_HEAD_BAR = 0.5
SUB_TAIL_BAR = 0.5
OBJ_HEAD_BAR = 0.5
OBJ_TAIL_BAR = 0.5

# Small constant added to metric denominators (precision / recall / F1) so
# they never divide by zero. The evaluation code references this name.
EPS = 1e-10

注:如果没有可用的GPU,可以将这里的DEVICE改成'cpu'。

2.process.py

#process.py
import json
import pandas as pd
from config import *

def generate_rel(schema_path=None, rel_path=None):
    """Build the relation vocabulary CSV from the schema file.

    Every non-blank line of the schema file is parsed as a JSON object and
    its 'predicate' value is collected. Each distinct predicate receives a
    consecutive id in order of first appearance, and the (predicate, id)
    rows are written to a header-less CSV.

    Args:
        schema_path: schema file to read; defaults to SCHEMA_PATH.
        rel_path: CSV file to write; defaults to REL_PATH.
    """
    schema_path = SCHEMA_PATH if schema_path is None else schema_path
    rel_path = REL_PATH if rel_path is None else rel_path

    rel2id = {}
    # Explicit UTF-8: the predicates are Chinese and the platform default
    # encoding is not UTF-8 everywhere.
    with open(schema_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            predicate = json.loads(line)['predicate']
            # Ids must stay consecutive because the CSV is later read back
            # both as an id -> relation list (by row position) and as a
            # relation -> id dict; a repeated predicate must not leave a gap.
            rel2id.setdefault(predicate, len(rel2id))

    df = pd.DataFrame(list(rel2id.items()))
    df.to_csv(rel_path, header=None, index=None)


if __name__ == '__main__':
    generate_rel()

3.utils.py

#utils.py
import torch.utils.data as data
import pandas as pd
import random
from config import *
import json
import numpy as np
from transformers import BertTokenizerFast

def get_rel():
    """Load the relation vocabulary from REL_PATH.

    Returns:
        A pair ``(id2rel, rel2id)``: the relation names in CSV row order, and
        a dict built from the CSV's (rel, id) rows.
    """
    rel_table = pd.read_csv(REL_PATH, names=['rel', 'id'])
    id2rel = rel_table['rel'].tolist()
    rel2id = dict(rel_table.values)
    return id2rel, rel2id


class Dataset(data.Dataset):
    """Line-per-sample JSON dataset producing CasRel training records.

    Each line of the split file is a JSON object with a 'text' field and an
    'spo_list' of triples. Samples are tokenized lazily in ``__getitem__``.
    """

    def __init__(self, type='train'):
        """Read the split file and load the tokenizer.

        Args:
            type: which split to read: 'train', 'test' or 'dev'.

        Raises:
            ValueError: if ``type`` is not one of the known split names.
        """
        super().__init__()
        _, self.rel2id = get_rel()
        # Pick the split file; reject unknown names explicitly instead of
        # failing later on an unbound local variable.
        if type == 'train':
            file_path = TRAIN_JSON_PATH
        elif type == 'test':
            file_path = TEST_JSON_PATH
        elif type == 'dev':
            file_path = DEV_JSON_PATH
        else:
            raise ValueError(
                f"unknown dataset type {type!r}; expected 'train', 'test' or 'dev'")
        with open(file_path, encoding='utf-8') as f:
            self.lines = f.readlines()
        # Load the BERT tokenizer
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        """Return the number of lines (samples) in the split file."""
        return len(self.lines)

    def __getitem__(self, index):
        """Parse and tokenize the sample at ``index``; see ``parse_json``."""
        line = self.lines[index]
        info = json.loads(line)
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']
        return self.parse_json(info)

    def parse_json(self, info):
        """Convert one tokenized sample into token-level triple annotations.

        Args:
            info: dict with 'text', 'input_ids', 'offset_mapping' and
                'spo_list' (items with 'subject', 'predicate' and
                'object' -> '@value').

        Returns:
            A dict with the text, token ids, offset mapping, the subject
            head/tail token positions, the textual triples ('triple_list')
            and the position-based triples ('triple_id_list').
        """
        text = info['text']
        input_ids = info['input_ids']
        dct = {
            'text': text,
            'input_ids': input_ids,
            'offset_mapping': info['offset_mapping'],
            'sub_head_ids': [],
            'sub_tail_ids': [],
            'triple_list': [],
            'triple_id_list': []
        }
        for spo in info['spo_list']:
            subject = spo['subject']
            obj_text = spo['object']['@value']
            predicate = spo['predicate']
            # The textual triple is always kept as gold, even when its
            # entities cannot be located in the token sequence below.
            dct['triple_list'].append((subject, predicate, obj_text))
            # Locate the subject entity in the token sequence
            sub_token = self.tokenizer(subject, add_special_tokens=False)['input_ids']
            sub_pos_id = self.get_pos_id(input_ids, sub_token)
            if not sub_pos_id:
                continue
            sub_head_id, sub_tail_id = sub_pos_id
            # Locate the object entity in the token sequence
            obj_token = self.tokenizer(obj_text, add_special_tokens=False)['input_ids']
            obj_pos_id = self.get_pos_id(input_ids, obj_token)
            if not obj_pos_id:
                continue
            obj_head_id, obj_tail_id = obj_pos_id
            # Assemble the position-based annotations
            dct['sub_head_ids'].append(sub_head_id)
            dct['sub_tail_ids'].append(sub_tail_id)
            dct['triple_id_list'].append((
                [sub_head_id, sub_tail_id],
                self.rel2id[predicate],
                [obj_head_id, obj_tail_id],
            ))
        return dct

    def get_pos_id(self, source, elem):
        """Find the first occurrence of the token list ``elem`` in ``source``.

        Returns:
            ``(head_id, tail_id)`` with an inclusive tail index, or ``None``
            when ``elem`` is empty or does not occur in ``source``.
        """
        # An empty entity would otherwise "match" at position 0 and yield the
        # bogus span (0, -1).
        if not elem:
            return None
        span = len(elem)
        for head_id in range(len(source) - span + 1):
            if source[head_id:head_id + span] == elem:
                return head_id, head_id + span - 1
        return None

    def collate_fn(self, batch):
        """Pad a list of parsed samples and build the CasRel label structures.

        Samples without any located triple are dropped, so the returned batch
        can be smaller than the input (or empty).

        Returns:
            ``batch_mask, (batch_text, batch_sub_rnd), (batch_sub,
            batch_obj_rel)`` as nested Python lists (not tensors).
        """
        batch.sort(key=lambda x: len(x['input_ids']), reverse=True)
        max_len = len(batch[0]['input_ids'])
        batch_text = {
            'text': [],
            'input_ids': [],
            'offset_mapping': [],
            'triple_list': [],
        }
        batch_mask = []
        batch_sub = {
            'heads_seq': [],
            'tails_seq': [],
        }
        batch_sub_rnd = {
            'head_seq': [],
            'tail_seq': [],
        }
        batch_obj_rel = {
            'heads_mx': [],
            'tails_mx': [],
        }

        for item in batch:
            triples = item['triple_id_list']
            # Skip samples with no located triple before doing any padding work
            if not triples:
                continue
            item_len = len(item['input_ids'])
            pad_len = max_len - item_len
            input_ids = item['input_ids'] + [0] * pad_len
            mask = [1] * item_len + [0] * pad_len
            # Multi-hot sequences marking every subject head / tail position
            sub_heads_seq = multihot(max_len, item['sub_head_ids'])
            sub_tails_seq = multihot(max_len, item['sub_tail_ids'])
            # Randomly pick one subject for the object/relation task
            sub_rnd = random.choice(triples)[0]
            sub_rnd_head_seq = multihot(max_len, [sub_rnd[0]])
            sub_rnd_tail_seq = multihot(max_len, [sub_rnd[1]])
            # Build the (position x relation) matrices for the chosen subject
            obj_head_mx = [[0] * REL_SIZE for _ in range(max_len)]
            obj_tail_mx = [[0] * REL_SIZE for _ in range(max_len)]
            for sub_span, rel_id, (head_id, tail_id) in triples:
                if sub_span == sub_rnd:
                    obj_head_mx[head_id][rel_id] = 1
                    obj_tail_mx[tail_id][rel_id] = 1
            # Collect everything into the batch containers
            batch_text['text'].append(item['text'])
            batch_text['input_ids'].append(input_ids)
            batch_text['offset_mapping'].append(item['offset_mapping'])
            batch_text['triple_list'].append(item['triple_list'])
            batch_mask.append(mask)
            batch_sub['heads_seq'].append(sub_heads_seq)
            batch_sub['tails_seq'].append(sub_tails_seq)
            batch_sub_rnd['head_seq'].append(sub_rnd_head_seq)
            batch_sub_rnd['tail_seq'].append(sub_rnd_tail_seq)
            batch_obj_rel['heads_mx'].append(obj_head_mx)
            batch_obj_rel['tails_mx'].append(obj_tail_mx)

        # Note: the structure is too nested, so nothing is converted to tensors here
        return batch_mask, (batch_text, batch_sub_rnd), (batch_sub, batch_obj_rel)


# Build a list of the given length holding 1 at the hot_pos positions and 0 elsewhere
def multihot(length, hot_pos):
    """Return a 0/1 list of ``length`` entries, 1 where the index is in ``hot_pos``."""
    flags = []
    for idx in range(length):
        flags.append(int(idx in hot_pos))
    return flags


def get_triple_list(sub_head_ids, sub_tail_ids, model, encoded_text, text, mask, offset_mapping):
    """Decode the (subject, relation, object) text triples of one sample.

    Each predicted subject head is paired with the nearest predicted tail at
    or after it. For that subject the model predicts object heads/tails per
    relation, which are paired the same way. Spans touching masked positions
    are dropped. Returns the de-duplicated list of triples.
    """
    id2rel, _ = get_rel()
    seq_len = len(mask)
    collected = []
    for sub_head_id in sub_head_ids:
        # Discard tail candidates that lie before the current head
        sub_tail_ids = sub_tail_ids[sub_tail_ids >= sub_head_id]
        if len(sub_tail_ids) == 0:
            continue
        sub_tail_id = sub_tail_ids[0]
        if mask[sub_head_id] == 0 or mask[sub_tail_id] == 0:
            continue
        # Map the token positions back to the subject's text span
        sub_start = offset_mapping[sub_head_id][0]
        sub_end = offset_mapping[sub_tail_id][1]
        subject_text = text[sub_start:sub_end]

        # Predict objects and relations for this specific subject
        sub_head_seq = torch.tensor(multihot(seq_len, sub_head_id)).to(DEVICE)
        sub_tail_seq = torch.tensor(multihot(seq_len, sub_tail_id)).to(DEVICE)
        pred_obj_head, pred_obj_tail = model.get_objs_for_specific_sub(
            encoded_text.unsqueeze(0),
            sub_head_seq.unsqueeze(0),
            sub_tail_seq.unsqueeze(0),
        )

        # Transpose so that each row holds the scores of one relation
        head_by_rel = pred_obj_head[0].T
        tail_by_rel = pred_obj_tail[0].T
        for rel_idx, (head_scores, tail_scores) in enumerate(zip(head_by_rel, tail_by_rel)):
            obj_head_ids = torch.where(head_scores > OBJ_HEAD_BAR)[0]
            obj_tail_ids = torch.where(tail_scores > OBJ_TAIL_BAR)[0]
            for obj_head_id in obj_head_ids:
                obj_tail_ids = obj_tail_ids[obj_tail_ids >= obj_head_id]
                if len(obj_tail_ids) == 0:
                    continue
                obj_tail_id = obj_tail_ids[0]
                if mask[obj_head_id] == 0 or mask[obj_tail_id] == 0:
                    continue
                # Map token positions back to the object's text span; the
                # offset mapping already accounts for the shift, no +1 needed
                obj_start = offset_mapping[obj_head_id][0]
                obj_end = offset_mapping[obj_tail_id][1]
                object_text = text[obj_start:obj_end]
                collected.append((subject_text, id2rel[rel_idx], object_text))
    return list(set(collected))


def report(model, encoded_text, pred_y, batch_text, batch_mask):
    """Decode the triples of one batch and print precision / recall / F1.

    Args:
        model: model exposing ``get_objs_for_specific_sub``.
        encoded_text: per-sample encoder output, indexed by batch position.
        pred_y: 4-tuple whose first two items are the predicted subject
            head / tail scores per sample.
        batch_text: dict with per-sample 'text', 'triple_list' and
            'offset_mapping' lists.
        batch_mask: per-sample attention masks.
    """
    pred_sub_head, pred_sub_tail, _, _ = pred_y

    correct_num, predict_num, gold_num = 0, 0, 0

    # Iterate over the samples of the batch
    for i in range(len(pred_sub_head)):
        sub_head_ids = torch.where(pred_sub_head[i] > SUB_HEAD_BAR)[0]
        sub_tail_ids = torch.where(pred_sub_tail[i] > SUB_TAIL_BAR)[0]

        pred_triples = set(get_triple_list(
            sub_head_ids, sub_tail_ids, model, encoded_text[i],
            batch_text['text'][i], batch_mask[i], batch_text['offset_mapping'][i]))
        gold_triples = set(batch_text['triple_list'][i])

        # Update the counters
        correct_num += len(gold_triples & pred_triples)
        predict_num += len(pred_triples)
        gold_num += len(gold_triples)

    # Guard every division explicitly so the metrics do not depend on an
    # epsilon constant being defined elsewhere; an empty denominator gives 0.
    precision = correct_num / predict_num if predict_num else 0.0
    recall = correct_num / gold_num if gold_num else 0.0
    if precision + recall:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0.0
    print('\tcorrect_num:', correct_num, 'predict_num:', predict_num, 'gold_num:', gold_num)
    print('\tprecision:%.3f' % precision, 'recall:%.3f' % recall, 'f1_score:%.3f' % f1_score)


if __name__ == '__main__':
    # Smoke test: build the training dataset and print the first collated batch.
    dataset = Dataset()
    loader = data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
    # Use the builtin next(): the iterator protocol only guarantees __next__,
    # and a Python-2 style ``.next()`` method is not available on all versions.
    print(next(iter(loader)))

4.model.py

#model.py
import torch.nn as nn
from transformers import BertModel
from config import *
import torch
import torch.nn.functional as F

# Silence transformers warnings (only errors are reported)
from transformers import logging
logging.set_verbosity_error()

class CasRel(nn.Module):
    """CasRel tagger: subject head/tail scoring plus per-relation object scoring.

    A BERT encoder with frozen parameters feeds four linear layers: two that
    score subject head/tail positions and two that score object head/tail
    positions for each of the REL_SIZE relations.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # Freeze the BERT parameters; only the downstream layers are trained
        for param in self.bert.parameters():
            param.requires_grad = False
        self.sub_head_linear = nn.Linear(BERT_DIM, 1)
        self.sub_tail_linear = nn.Linear(BERT_DIM, 1)
        self.obj_head_linear = nn.Linear(BERT_DIM, REL_SIZE)
        self.obj_tail_linear = nn.Linear(BERT_DIM, REL_SIZE)

    def get_encoded_text(self, input_ids, mask):
        """Run BERT and return the first element of its output."""
        bert_output = self.bert(input_ids, attention_mask=mask)
        return bert_output[0]

    def get_subs(self, encoded_text):
        """Return sigmoid scores for subject head and subject tail positions."""
        head_logits = self.sub_head_linear(encoded_text)
        tail_logits = self.sub_tail_linear(encoded_text)
        return torch.sigmoid(head_logits), torch.sigmoid(tail_logits)

    def get_objs_for_specific_sub(self, encoded_text, sub_head_seq, sub_tail_seq):
        """Score object head/tail positions per relation for a given subject.

        ``sub_head_seq`` / ``sub_tail_seq`` select the subject's head and tail
        token vectors; their average is added to every position of
        ``encoded_text`` before the object layers are applied.
        """
        # (b, c) -> (b, 1, c) so the selector can be multiplied with the encoding
        head_selector = sub_head_seq.unsqueeze(1).float()
        tail_selector = sub_tail_seq.unsqueeze(1).float()

        # encoded_text is (b, c, 768); each product picks the subject vector
        sub_head = torch.matmul(head_selector, encoded_text)
        sub_tail = torch.matmul(tail_selector, encoded_text)
        fused = encoded_text + (sub_head + sub_tail) / 2

        # Both results have shape (b, c, REL_SIZE)
        obj_head_scores = torch.sigmoid(self.obj_head_linear(fused))
        obj_tail_scores = torch.sigmoid(self.obj_tail_linear(fused))
        return obj_head_scores, obj_tail_scores

    def forward(self, input, mask):
        """Encode the text and return it together with all four predictions."""
        input_ids, sub_head_seq, sub_tail_seq = input
        encoded_text = self.get_encoded_text(input_ids, mask)
        subject_preds = self.get_subs(encoded_text)

        # Predict the relation-object matrices for the given subject
        object_preds = self.get_objs_for_specific_sub(
            encoded_text, sub_head_seq, sub_tail_seq)

        pred_sub_head, pred_sub_tail = subject_preds
        pred_obj_head, pred_obj_tail = object_preds
        return encoded_text, (pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail)

    def loss_fn(self, true_y, pred_y, mask):
        """Weighted, mask-averaged binary cross-entropy over the four outputs."""

        def masked_bce(pred, true, mask):
            true = true.float()
            # Drop a trailing singleton dimension: (b, c, 1) -> (b, c)
            pred = pred.squeeze(-1)
            positive_w, negative_w = CLS_WEIGHT_COEF[1], CLS_WEIGHT_COEF[0]
            weight = torch.where(true > 0, positive_w, negative_w)
            per_elem = F.binary_cross_entropy(pred, true, weight=weight, reduction='none')
            # Add a trailing axis to the mask when the loss has a relation axis
            if per_elem.shape != mask.shape:
                mask = mask.unsqueeze(-1)
            return torch.sum(per_elem * mask) / torch.sum(mask)

        pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail = pred_y
        true_sub_head, true_sub_tail, true_obj_head, true_obj_tail = true_y

        sub_head_loss = masked_bce(pred_sub_head, true_sub_head, mask) * SUB_WEIGHT_COEF
        sub_tail_loss = masked_bce(pred_sub_tail, true_sub_tail, mask) * SUB_WEIGHT_COEF
        obj_head_loss = masked_bce(pred_obj_head, true_obj_head, mask)
        obj_tail_loss = masked_bce(pred_obj_tail, true_obj_tail, mask)
        return sub_head_loss + sub_tail_loss + obj_head_loss + obj_tail_loss

5.train.py

#train.py
from utils import *
from model import *
from torch.utils import data

if __name__ == '__main__':
    # Training script: fit the CasRel heads and save a checkpoint per epoch.
    model = CasRel().to(DEVICE)

    # Only hand trainable parameters to the optimizer (the encoder is frozen).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=LR)

    dataset = Dataset()
    for e in range(EPOCH):
        loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn)
        for b, (batch_mask, batch_x, batch_y) in enumerate(loader):
            batch_text, batch_sub_rnd = batch_x
            batch_sub, batch_obj_rel = batch_y

            # collate_fn drops samples without a located triple, so a whole
            # batch can come back empty; there is nothing to train on then.
            if not batch_mask:
                continue

            # Assemble the input tensors and run the model
            input_mask = torch.tensor(batch_mask).to(DEVICE)
            inputs = (
                torch.tensor(batch_text['input_ids']).to(DEVICE),
                torch.tensor(batch_sub_rnd['head_seq']).to(DEVICE),
                torch.tensor(batch_sub_rnd['tail_seq']).to(DEVICE),
            )
            encoded_text, pred_y = model(inputs, input_mask)

            # Assemble the target tensors and compute the loss
            true_y = (
                torch.tensor(batch_sub['heads_seq']).to(DEVICE),
                torch.tensor(batch_sub['tails_seq']).to(DEVICE),
                torch.tensor(batch_obj_rel['heads_mx']).to(DEVICE),
                torch.tensor(batch_obj_rel['tails_mx']).to(DEVICE),
            )
            loss = model.loss_fn(true_y, pred_y, input_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if b % 50 == 0:
                print('>> epoch:', e, 'batch:', b, 'loss:', loss.item())

            if b % 500 == 0:
                # Metrics only: no autograd graph is needed while decoding.
                with torch.no_grad():
                    report(model, encoded_text, pred_y, batch_text, batch_mask)

        # Save the whole model after every epoch
        torch.save(model, MODEL_DIR + f'model_{e}.pth')

6.test.py

#test.py
from utils import *
from model import *
from torch.utils import data

if __name__ == '__main__':
    # Evaluation script: score a saved checkpoint on the dev split.
    # NOTE(review): the checkpoint is a whole pickled model; newer PyTorch
    # releases may need torch.load(..., weights_only=False) to load it.
    # Confirm against the installed version.
    model = torch.load(MODEL_DIR + f'model_27.pth', map_location=DEVICE)
    # Switch to inference mode so dropout inside the encoder is disabled.
    model.eval()

    dataset = Dataset('dev')

    with torch.no_grad():

        loader = data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)

        correct_num, predict_num, gold_num = 0, 0, 0
        pred_triple_list = []
        true_triple_list = []

        for b, (batch_mask, batch_x, batch_y) in enumerate(loader):
            batch_text, batch_sub_rnd = batch_x
            batch_sub, batch_obj_rel = batch_y

            # collate_fn drops samples without a located triple, so a whole
            # batch can come back empty; skip it instead of crashing.
            if not batch_mask:
                continue

            # Assemble the input tensors and run the model
            input_mask = torch.tensor(batch_mask).to(DEVICE)
            inputs = (
                torch.tensor(batch_text['input_ids']).to(DEVICE),
                torch.tensor(batch_sub_rnd['head_seq']).to(DEVICE),
                torch.tensor(batch_sub_rnd['tail_seq']).to(DEVICE),
            )
            encoded_text, pred_y = model(inputs, input_mask)

            # Assemble the target tensors and compute the loss
            true_y = (
                torch.tensor(batch_sub['heads_seq']).to(DEVICE),
                torch.tensor(batch_sub['tails_seq']).to(DEVICE),
                torch.tensor(batch_obj_rel['heads_mx']).to(DEVICE),
                torch.tensor(batch_obj_rel['tails_mx']).to(DEVICE),
            )
            loss = model.loss_fn(true_y, pred_y, input_mask)

            print('>> batch:', b, 'loss:', loss.item())

            # Decode the relation triples and update the counters
            pred_sub_head, pred_sub_tail, _, _ = pred_y
            batch_true_triples = batch_text['triple_list']
            true_triple_list += batch_true_triples

            # Iterate over the samples of the batch
            for i in range(len(pred_sub_head)):
                text = batch_text['text'][i]
                # Index the CURRENT batch's gold triples. Indexing the
                # accumulated list with the batch-local i would compare every
                # later batch against the first batch's gold triples.
                true_triple_item = batch_true_triples[i]
                mask = batch_mask[i]
                offset_mapping = batch_text['offset_mapping'][i]

                sub_head_ids = torch.where(pred_sub_head[i] > SUB_HEAD_BAR)[0]
                sub_tail_ids = torch.where(pred_sub_tail[i] > SUB_TAIL_BAR)[0]

                pred_triple_item = get_triple_list(
                    sub_head_ids, sub_tail_ids, model,
                    encoded_text[i], text, mask, offset_mapping)

                # Update the counters
                correct_num += len(set(true_triple_item) & set(pred_triple_item))
                predict_num += len(set(pred_triple_item))
                gold_num += len(set(true_triple_item))

                pred_triple_list.append(pred_triple_item)

        # Guard every division explicitly so the metrics do not depend on an
        # epsilon constant being defined elsewhere; an empty denominator gives 0.
        precision = correct_num / predict_num if predict_num else 0.0
        recall = correct_num / gold_num if gold_num else 0.0
        if precision + recall:
            f1_score = 2 * precision * recall / (precision + recall)
        else:
            f1_score = 0.0
        print('\tcorrect_num:', correct_num, 'predict_num:', predict_num, 'gold_num:', gold_num)
        print('\tprecision:%.3f' % precision, 'recall:%.3f' % recall, 'f1_score:%.3f' % f1_score)

7.predict.py

#predict.py
from config import *
from utils import *
from transformers import BertTokenizerFast
from model import *

if __name__ == '__main__':
    # Prediction script: extract triples from a single sentence.
    text = '俞敏洪,出生于1962年9月4日的江苏省江阴市,大学毕业于北京大学西语系。'
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)
    tokenized = tokenizer(text, return_offsets_mapping=True)
    info = {}
    info['input_ids'] = tokenized['input_ids']
    info['offset_mapping'] = tokenized['offset_mapping']
    info['mask'] = tokenized['attention_mask']

    # Wrap the single sample in a batch dimension
    input_ids = torch.tensor([info['input_ids']]).to(DEVICE)
    batch_mask = torch.tensor([info['mask']]).to(DEVICE)

    model = torch.load(MODEL_DIR + 'newmodel_21.pth', map_location=DEVICE)
    # Inference mode: disable dropout so predictions are deterministic.
    model.eval()

    # No gradients are needed for prediction.
    with torch.no_grad():
        encoded_text = model.get_encoded_text(input_ids, batch_mask)
        pred_sub_head, pred_sub_tail = model.get_subs(encoded_text)

        sub_head_ids = torch.where(pred_sub_head[0] > SUB_HEAD_BAR)[0]
        sub_tail_ids = torch.where(pred_sub_tail[0] > SUB_TAIL_BAR)[0]
        mask = batch_mask[0]
        encoded_text = encoded_text[0]

        offset_mapping = info['offset_mapping']

        pred_triple_item = get_triple_list(
            sub_head_ids, sub_tail_ids, model,
            encoded_text, text, mask, offset_mapping)

    print(text)
    print(pred_triple_item)

三、数据集以及模型

在第一节已经交代过了,可以查看我的博客。

你可能感兴趣的:(关系抽取Casrel,bert,人工智能,深度学习,机器学习,nlp)