目录
一、整个架构
二、 源码
1.config.py
2.process.py
3.utils.py
4.model.py
5.train.py
6.test.py
7.predict.py
三、数据集以及模型
#config.py
import torch

# --- Data files -------------------------------------------------------------
# CSV relation table (relation name, id) written by process.py and read by utils.get_rel.
REL_PATH = './data/output/rel.csv'
# Number of relation types; sizes the object heads of the model and the
# per-token relation label matrices built in collate_fn.
REL_SIZE = 48
SCHEMA_PATH = './data/input/duie/duie_schema.json'
TRAIN_JSON_PATH = './data/input/duie/duie_train.json'
TEST_JSON_PATH = './data/input/duie/duie_test.json'
DEV_JSON_PATH = './data/input/duie/duie_dev.json'

# --- Model / training -------------------------------------------------------
# Local directory (or hub name) passed to `from_pretrained`.
BERT_MODEL_NAME = './bert-base-chinese'
# Use the GPU when one is available; otherwise fall back to the CPU so the
# scripts still run on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 2
# Input width of the linear heads placed on top of the encoder output.
BERT_DIM = 768
LR = 1e-4
EPOCH = 50
# Directory prefix for the per-epoch checkpoints.
MODEL_DIR = './data/output/'

# --- Loss weighting ---------------------------------------------------------
# Per-element BCE weights: index 0 for labels equal to 0, index 1 for labels > 0.
CLS_WEIGHT_COEF = [0.3, 1.0]
# Multiplier applied to both subject losses in the total loss.
SUB_WEIGHT_COEF = 3

# --- Decoding thresholds ----------------------------------------------------
# A position counts as a subject/object head/tail when its score exceeds the bar.
SUB_HEAD_BAR = 0.5
SUB_TAIL_BAR = 0.5
OBJ_HEAD_BAR = 0.5
OBJ_TAIL_BAR = 0.5

# Small constant added to the metric denominators (precision / recall / F1)
# so they stay defined when a count is zero. It was referenced by the
# evaluation code but never defined.
EPS = 1e-10
注:这里的DEVICE可以改成cpu
#process.py
import json
import pandas as pd
from config import *
def generate_rel():
    """Build the relation-to-id table from the schema file and save it as CSV.

    Reads SCHEMA_PATH as JSON lines (one JSON object per line), collects the
    'predicate' field of each object, assigns every predicate its line index
    as id, and writes headerless, index-less ``predicate,id`` rows to REL_PATH.
    """
    # Read as UTF-8 explicitly (the same encoding Dataset uses for the other
    # annotation files) instead of relying on the platform default.
    with open(SCHEMA_PATH, encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline) are skipped rather than
        # crashing json.loads.
        rel_list = [json.loads(line)['predicate'] for line in f if line.strip()]
    # If a predicate occurs more than once, the last index wins.
    rel_dict = {v: k for k, v in enumerate(rel_list)}
    df = pd.DataFrame(rel_dict.items())
    df.to_csv(REL_PATH, header=None, index=None)


if __name__ == '__main__':
    generate_rel()
#utils.py
import torch.utils.data as data
import pandas as pd
import random
from config import *
import json
import numpy as np
from transformers import BertTokenizerFast
def get_rel():
    """Load the relation table stored at REL_PATH.

    The CSV has no header; its two columns are read as 'rel' and 'id'.

    Returns:
        A pair ``(rel_names, rel_to_id)``: the 'rel' column as a list in file
        order, and a dict built from the (rel, id) rows.
    """
    table = pd.read_csv(REL_PATH, names=['rel', 'id'])
    rel_names = table['rel'].tolist()
    rel_to_id = dict(table.values)
    return rel_names, rel_to_id
class Dataset(data.Dataset):
    """Dataset over a JSON-lines annotation file for subject/object tagging.

    Lines are kept as raw strings and parsed lazily in ``__getitem__``.
    ``collate_fn`` pads a batch and builds the label structures as nested
    Python lists (conversion to tensors is left to the caller).
    """

    def __init__(self, type='train'):
        """Read the lines of the requested split and load the tokenizer.

        Args:
            type: Split to load: 'train', 'test' or 'dev'. (The name shadows
                the builtin but is kept because callers may pass it by keyword.)

        Raises:
            ValueError: If ``type`` is not one of the known splits.
        """
        super().__init__()
        _, self.rel2id = get_rel()
        # Select the file for the split.
        split_paths = {
            'train': TRAIN_JSON_PATH,
            'test': TEST_JSON_PATH,
            'dev': DEV_JSON_PATH,
        }
        if type not in split_paths:
            raise ValueError(
                f"unknown dataset type {type!r}; expected one of {sorted(split_paths)}"
            )
        with open(split_paths[type], encoding='utf-8') as f:
            # Blank lines cannot be parsed as JSON later, so drop them here.
            self.lines = [line for line in f if line.strip()]
        # Load the BERT tokenizer.
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        """Return the number of samples (non-blank lines) in the split."""
        return len(self.lines)

    def __getitem__(self, index):
        """Parse line ``index``, tokenize its text and return the sample dict."""
        info = json.loads(self.lines[index])
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']
        return self.parse_json(info)

    def parse_json(self, info):
        """Turn one annotated record into token-level labels.

        Args:
            info: Dict with 'text', 'input_ids', 'offset_mapping' and
                'spo_list'; each spo has 'subject', 'predicate' and
                ``object['@value']``.

        Returns:
            Dict with the text, token ids, offsets, the subject head/tail
            positions, every annotated text triple ('triple_list'), and the
            token-position triples that could be located ('triple_id_list').
        """
        input_ids = info['input_ids']
        dct = {
            'text': info['text'],
            'input_ids': input_ids,
            'offset_mapping': info['offset_mapping'],
            'sub_head_ids': [],
            'sub_tail_ids': [],
            'triple_list': [],
            'triple_id_list': [],
        }
        for spo in info['spo_list']:
            subject = spo['subject']
            obj_text = spo['object']['@value']
            predicate = spo['predicate']
            # The text triple is recorded even if its tokens cannot be located.
            dct['triple_list'].append((subject, predicate, obj_text))
            # Locate the subject tokens.
            sub_pos_id = self._entity_span(input_ids, subject)
            if sub_pos_id is None:
                continue
            sub_head_id, sub_tail_id = sub_pos_id
            # Locate the object tokens.
            obj_pos_id = self._entity_span(input_ids, obj_text)
            if obj_pos_id is None:
                continue
            obj_head_id, obj_tail_id = obj_pos_id
            # Assemble the position labels.
            dct['sub_head_ids'].append(sub_head_id)
            dct['sub_tail_ids'].append(sub_tail_id)
            dct['triple_id_list'].append((
                [sub_head_id, sub_tail_id],
                self.rel2id[predicate],
                [obj_head_id, obj_tail_id],
            ))
        return dct

    def _entity_span(self, input_ids, entity):
        """Tokenize ``entity`` without special tokens and locate it in ``input_ids``."""
        tokens = self.tokenizer(entity, add_special_tokens=False)['input_ids']
        return self.get_pos_id(input_ids, tokens)

    def get_pos_id(self, source, elem):
        """Find the first occurrence of the list ``elem`` inside the list ``source``.

        Returns:
            ``(head_id, tail_id)`` with an inclusive tail index, or ``None``
            when ``elem`` is empty or does not occur.
        """
        span = len(elem)
        # An empty pattern would "match" at index 0 and produce tail index -1,
        # which downstream indexing would treat as the last position.
        if span == 0:
            return None
        for head_id in range(len(source) - span + 1):
            if source[head_id:head_id + span] == elem:
                return head_id, head_id + span - 1
        return None

    def collate_fn(self, batch):
        """Pad a batch and build the subject / object label lists.

        Samples whose 'triple_id_list' is empty are dropped, so the returned
        lists may be shorter than ``batch`` (possibly empty).

        Returns:
            ``(batch_mask, (batch_text, batch_sub_rnd), (batch_sub, batch_obj_rel))``
            made of nested Python lists, not tensors.
        """
        # Sort longest first (in place); the first sample sets the padded length.
        batch.sort(key=lambda x: len(x['input_ids']), reverse=True)
        max_len = len(batch[0]['input_ids'])
        batch_text = {
            'text': [],
            'input_ids': [],
            'offset_mapping': [],
            'triple_list': [],
        }
        batch_mask = []
        batch_sub = {
            'heads_seq': [],
            'tails_seq': [],
        }
        batch_sub_rnd = {
            'head_seq': [],
            'tail_seq': [],
        }
        batch_obj_rel = {
            'heads_mx': [],
            'tails_mx': [],
        }
        for item in batch:
            # Nothing to learn from a sample without located triples; skip it
            # before doing any padding work.
            if not item['triple_id_list']:
                continue
            input_ids = item['input_ids']
            item_len = len(input_ids)
            pad_len = max_len - item_len
            input_ids = input_ids + [0] * pad_len
            mask = [1] * item_len + [0] * pad_len
            # Multi-hot sequences marking every subject head / tail position.
            sub_heads_seq = multihot(max_len, item['sub_head_ids'])
            sub_tails_seq = multihot(max_len, item['sub_tail_ids'])
            # Pick one subject at random.
            sub_rnd = random.choice(item['triple_id_list'])[0]
            sub_rnd_head_seq = multihot(max_len, [sub_rnd[0]])
            sub_rnd_tail_seq = multihot(max_len, [sub_rnd[1]])
            # Relation matrices (position x relation) for the chosen subject.
            obj_head_mx = [[0] * REL_SIZE for _ in range(max_len)]
            obj_tail_mx = [[0] * REL_SIZE for _ in range(max_len)]
            for triple in item['triple_id_list']:
                rel_id = triple[1]
                head_id, tail_id = triple[2]
                if triple[0] == sub_rnd:
                    obj_head_mx[head_id][rel_id] = 1
                    obj_tail_mx[tail_id][rel_id] = 1
            # Collect the per-sample pieces.
            batch_text['text'].append(item['text'])
            batch_text['input_ids'].append(input_ids)
            batch_text['offset_mapping'].append(item['offset_mapping'])
            batch_text['triple_list'].append(item['triple_list'])
            batch_mask.append(mask)
            batch_sub['heads_seq'].append(sub_heads_seq)
            batch_sub['tails_seq'].append(sub_tails_seq)
            batch_sub_rnd['head_seq'].append(sub_rnd_head_seq)
            batch_sub_rnd['tail_seq'].append(sub_rnd_tail_seq)
            batch_obj_rel['heads_mx'].append(obj_head_mx)
            batch_obj_rel['tails_mx'].append(obj_tail_mx)
        # Note: the structure is deeply nested, so nothing is converted to tensors here.
        return batch_mask, (batch_text, batch_sub_rnd), (batch_sub, batch_obj_rel)
def multihot(length, hot_pos):
    """Return a list of ``length`` ints: 1 at every index contained in ``hot_pos``, else 0.

    Membership is tested with ``in``, so ``hot_pos`` may be any object that
    supports that operator.
    """
    return [int(idx in hot_pos) for idx in range(length)]
def get_triple_list(sub_head_ids, sub_tail_ids, model, encoded_text, text, mask, offset_mapping):
    """Decode (subject, relation, object) text triples for one sentence.

    Each candidate subject head is paired with the first candidate tail at or
    after it. For that subject the model scores object heads/tails per
    relation, and objects are paired the same way.

    Args:
        sub_head_ids: 1-D tensor of candidate subject head positions
            (callers pass ``torch.where(...)[0]``).
        sub_tail_ids: 1-D tensor of candidate subject tail positions.
        model: Object providing ``get_objs_for_specific_sub``.
        encoded_text: Encoder output for this sentence, without batch axis.
        text: The original sentence string.
        mask: Per-token mask (list or tensor); positions with 0 are ignored.
        offset_mapping: Per-token (start, end) character offsets into ``text``.

    Returns:
        De-duplicated list of ``(subject_text, relation_name, object_text)``.
    """
    id2rel, _ = get_rel()
    seq_len = len(mask)
    # Build the one-hot inputs on the same device as the encoder output, not a
    # global device setting that may not match.
    device = encoded_text.device
    triples = set()
    # Decoding is inference only, so no autograd graph is needed.
    with torch.no_grad():
        for sub_head_id in sub_head_ids:
            sub_head_id = int(sub_head_id)
            # Keep only tails at or after the current head.
            sub_tail_ids = sub_tail_ids[sub_tail_ids >= sub_head_id]
            if len(sub_tail_ids) == 0:
                continue
            sub_tail_id = int(sub_tail_ids[0])
            if mask[sub_head_id] == 0 or mask[sub_tail_id] == 0:
                continue
            # Recover the subject text from the character offsets.
            sub_head_pos_id = offset_mapping[sub_head_id][0]
            sub_tail_pos_id = offset_mapping[sub_tail_id][1]
            subject_text = text[sub_head_pos_id:sub_tail_pos_id]
            # One-hot markers for the subject, set directly by index.
            sub_head_seq = torch.zeros(seq_len, device=device)
            sub_head_seq[sub_head_id] = 1
            sub_tail_seq = torch.zeros(seq_len, device=device)
            sub_tail_seq[sub_tail_id] = 1
            pred_obj_head, pred_obj_tail = model.get_objs_for_specific_sub(
                encoded_text.unsqueeze(0), sub_head_seq.unsqueeze(0), sub_tail_seq.unsqueeze(0))
            # Transpose so that each row holds the scores of one relation.
            pred_obj_head = pred_obj_head[0].T
            pred_obj_tail = pred_obj_tail[0].T
            for j, (head_scores, tail_scores) in enumerate(zip(pred_obj_head, pred_obj_tail)):
                obj_head_ids = torch.where(head_scores > OBJ_HEAD_BAR)[0]
                obj_tail_ids = torch.where(tail_scores > OBJ_TAIL_BAR)[0]
                for obj_head_id in obj_head_ids:
                    obj_head_id = int(obj_head_id)
                    obj_tail_ids = obj_tail_ids[obj_tail_ids >= obj_head_id]
                    if len(obj_tail_ids) == 0:
                        continue
                    obj_tail_id = int(obj_tail_ids[0])
                    if mask[obj_head_id] == 0 or mask[obj_tail_id] == 0:
                        continue
                    # The end offset is used directly as the slice end, so no +1 is needed.
                    obj_head_pos_id = offset_mapping[obj_head_id][0]
                    obj_tail_pos_id = offset_mapping[obj_tail_id][1]
                    object_text = text[obj_head_pos_id:obj_tail_pos_id]
                    triples.add((subject_text, id2rel[j], object_text))
    return list(triples)
def report(model, encoded_text, pred_y, batch_text, batch_mask):
    """Decode the predicted triples of a batch and print precision / recall / F1.

    Args:
        model: Passed through to ``get_triple_list``.
        encoded_text: Batched encoder output; row ``i`` belongs to sample ``i``.
        pred_y: 4-tuple whose first two items are the subject head / tail scores.
        batch_text: Dict with parallel lists 'text', 'offset_mapping', 'triple_list'.
        batch_mask: Per-sample token masks.

    Returns:
        ``(precision, recall, f1_score)`` for this batch. The values are also
        printed; callers that ignore the return value are unaffected.
    """
    pred_sub_head, pred_sub_tail, _, _ = pred_y
    correct_num, predict_num, gold_num = 0, 0, 0
    # Walk over the samples of the batch.
    for i, (head_scores, tail_scores) in enumerate(zip(pred_sub_head, pred_sub_tail)):
        gold = set(batch_text['triple_list'][i])
        sub_head_ids = torch.where(head_scores > SUB_HEAD_BAR)[0]
        sub_tail_ids = torch.where(tail_scores > SUB_TAIL_BAR)[0]
        predicted = set(get_triple_list(
            sub_head_ids, sub_tail_ids, model, encoded_text[i],
            batch_text['text'][i], batch_mask[i], batch_text['offset_mapping'][i]))
        # Update the counts.
        correct_num += len(gold & predicted)
        predict_num += len(predicted)
        gold_num += len(gold)
    # EPS is added to every denominator so the ratios stay defined when a
    # count is zero (assumes EPS is a small positive constant in scope).
    precision = correct_num / (predict_num + EPS)
    recall = correct_num / (gold_num + EPS)
    f1_score = 2 * precision * recall / (precision + recall + EPS)
    print('\tcorrect_num:', correct_num, 'predict_num:', predict_num, 'gold_num:', gold_num)
    print('\tprecision:%.3f' % precision, 'recall:%.3f' % recall, 'f1_score:%.3f' % f1_score)
    return precision, recall, f1_score
if __name__ == '__main__':
    # Smoke test: build the training set and print the first collated batch.
    dataset = Dataset()
    loader = data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
    # Use the builtin next(); iterator objects are not guaranteed to have a
    # `.next()` method in Python 3.
    print(next(iter(loader)))
#model.py
import torch.nn as nn
from transformers import BertModel
from config import *
import torch
import torch.nn.functional as F
# 忽略 transformers 警告
from transformers import logging
logging.set_verbosity_error()
class CasRel(nn.Module):
    """Frozen BERT encoder with linear heads for subject and per-relation object tagging."""

    def __init__(self):
        """Load the encoder, freeze it, and create the four linear heads."""
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # Freeze the encoder; only the heads below receive gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.sub_head_linear = nn.Linear(BERT_DIM, 1)
        self.sub_tail_linear = nn.Linear(BERT_DIM, 1)
        self.obj_head_linear = nn.Linear(BERT_DIM, REL_SIZE)
        self.obj_tail_linear = nn.Linear(BERT_DIM, REL_SIZE)

    def get_encoded_text(self, input_ids, mask):
        """Run the encoder and return the first element of its output."""
        outputs = self.bert(input_ids, attention_mask=mask)
        return outputs[0]

    def get_subs(self, encoded_text):
        """Return sigmoid scores for subject head and subject tail per position."""
        head_logits = self.sub_head_linear(encoded_text)
        tail_logits = self.sub_tail_linear(encoded_text)
        return torch.sigmoid(head_logits), torch.sigmoid(tail_logits)

    def get_objs_for_specific_sub(self, encoded_text, sub_head_seq, sub_tail_seq):
        """Score object heads / tails for every relation, conditioned on one subject.

        The subject is given as two marker sequences; the encoder vectors they
        select are averaged and added to every position before the object heads.
        """
        # (b, c) -> (b, 1, c) so the markers can be matrix-multiplied with the text.
        head_selector = sub_head_seq.unsqueeze(1).float()
        tail_selector = sub_tail_seq.unsqueeze(1).float()
        # (b, 1, c) @ (b, c, dim) -> (b, 1, dim)
        sub_head = torch.matmul(head_selector, encoded_text)
        sub_tail = torch.matmul(tail_selector, encoded_text)
        # Broadcast the subject representation over all positions: (b, c, dim).
        fused = encoded_text + (sub_head + sub_tail) / 2
        head_logits = self.obj_head_linear(fused)
        tail_logits = self.obj_tail_linear(fused)
        # Both results have shape (b, c, REL_SIZE).
        return torch.sigmoid(head_logits), torch.sigmoid(tail_logits)

    def forward(self, input, mask):
        """Encode the batch and predict subject and object scores.

        Args:
            input: Tuple ``(input_ids, sub_head_seq, sub_tail_seq)``.
            mask: Attention mask passed to the encoder.

        Returns:
            ``(encoded_text, (pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail))``.
        """
        input_ids, sub_head_seq, sub_tail_seq = input
        encoded_text = self.get_encoded_text(input_ids, mask)
        pred_sub_head, pred_sub_tail = self.get_subs(encoded_text)
        # Relation-object scores for the given subject.
        pred_obj_head, pred_obj_tail = self.get_objs_for_specific_sub(
            encoded_text, sub_head_seq, sub_tail_seq)
        predictions = (pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail)
        return encoded_text, predictions

    @staticmethod
    def _masked_bce(pred, true, mask):
        """Weighted binary cross-entropy summed over unmasked entries, divided by the mask sum."""
        true = true.float()
        # Drops a trailing axis of size 1: (b, c, 1) -> (b, c).
        pred = pred.squeeze(-1)
        weight = torch.where(true > 0, CLS_WEIGHT_COEF[1], CLS_WEIGHT_COEF[0])
        loss = F.binary_cross_entropy(pred, true, weight=weight, reduction='none')
        # Give the mask a trailing axis when the loss has one more dimension.
        if loss.shape != mask.shape:
            mask = mask.unsqueeze(-1)
        return (loss * mask).sum() / mask.sum()

    def loss_fn(self, true_y, pred_y, mask):
        """Total loss: both subject losses scaled by SUB_WEIGHT_COEF plus both object losses."""
        pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail = pred_y
        true_sub_head, true_sub_tail, true_obj_head, true_obj_tail = true_y
        sub_head_loss = self._masked_bce(pred_sub_head, true_sub_head, mask) * SUB_WEIGHT_COEF
        sub_tail_loss = self._masked_bce(pred_sub_tail, true_sub_tail, mask) * SUB_WEIGHT_COEF
        obj_head_loss = self._masked_bce(pred_obj_head, true_obj_head, mask)
        obj_tail_loss = self._masked_bce(pred_obj_tail, true_obj_tail, mask)
        return sub_head_loss + sub_tail_loss + obj_head_loss + obj_tail_loss
#train.py
from utils import *
from model import *
from torch.utils import data
if __name__ == '__main__':
    import os

    model = CasRel().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    dataset = Dataset()
    # torch.save does not create missing directories; make sure the target
    # exists before the first epoch finishes.
    os.makedirs(MODEL_DIR, exist_ok=True)
    # With shuffle=True the loader draws a new order on every pass, so a
    # single loader serves all epochs.
    loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn)
    for e in range(EPOCH):
        for b, (batch_mask, batch_x, batch_y) in enumerate(loader):
            # collate_fn drops samples without located triples; if nothing is
            # left, there is nothing to feed to the model.
            if not batch_mask:
                continue
            batch_text, batch_sub_rnd = batch_x
            batch_sub, batch_obj_rel = batch_y
            # Build the input tensors and run the model.
            input_mask = torch.tensor(batch_mask).to(DEVICE)
            input = (
                torch.tensor(batch_text['input_ids']).to(DEVICE),
                torch.tensor(batch_sub_rnd['head_seq']).to(DEVICE),
                torch.tensor(batch_sub_rnd['tail_seq']).to(DEVICE),
            )
            encoded_text, pred_y = model(input, input_mask)
            # Build the target tensors and compute the loss.
            true_y = (
                torch.tensor(batch_sub['heads_seq']).to(DEVICE),
                torch.tensor(batch_sub['tails_seq']).to(DEVICE),
                torch.tensor(batch_obj_rel['heads_mx']).to(DEVICE),
                torch.tensor(batch_obj_rel['tails_mx']).to(DEVICE),
            )
            loss = model.loss_fn(true_y, pred_y, input_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if b % 50 == 0:
                print('>> epoch:', e, 'batch:', b, 'loss:', loss.item())
            if b % 500 == 0:
                report(model, encoded_text, pred_y, batch_text, batch_mask)
        # Save a checkpoint (the whole model object) after every epoch.
        torch.save(model, MODEL_DIR + f'model_{e}.pth')
#test.py
from utils import *
from model import *
from torch.utils import data
if __name__ == '__main__':
    # NOTE(review): torch.load unpickles a whole model object here; only load
    # checkpoints you trust. Newer PyTorch releases may need weights_only=False
    # for this kind of file — confirm against the installed version.
    model = torch.load(MODEL_DIR + 'model_27.pth', map_location=DEVICE)
    # Switch to evaluation mode so layers such as dropout behave deterministically.
    model.eval()
    dataset = Dataset('dev')
    correct_num, predict_num, gold_num = 0, 0, 0
    pred_triple_list = []
    true_triple_list = []
    with torch.no_grad():
        loader = data.DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
        for b, (batch_mask, batch_x, batch_y) in enumerate(loader):
            # collate_fn drops samples without located triples; skip batches
            # that end up empty.
            if not batch_mask:
                continue
            batch_text, batch_sub_rnd = batch_x
            batch_sub, batch_obj_rel = batch_y
            # Build the input tensors and run the model.
            input_mask = torch.tensor(batch_mask).to(DEVICE)
            input = (
                torch.tensor(batch_text['input_ids']).to(DEVICE),
                torch.tensor(batch_sub_rnd['head_seq']).to(DEVICE),
                torch.tensor(batch_sub_rnd['tail_seq']).to(DEVICE),
            )
            encoded_text, pred_y = model(input, input_mask)
            # Build the target tensors and compute the loss.
            true_y = (
                torch.tensor(batch_sub['heads_seq']).to(DEVICE),
                torch.tensor(batch_sub['tails_seq']).to(DEVICE),
                torch.tensor(batch_obj_rel['heads_mx']).to(DEVICE),
                torch.tensor(batch_obj_rel['tails_mx']).to(DEVICE),
            )
            loss = model.loss_fn(true_y, pred_y, input_mask)
            print('>> batch:', b, 'loss:', loss.item())
            # Decode triples and update the counts.
            pred_sub_head, pred_sub_tail, _, _ = pred_y
            batch_true = batch_text['triple_list']
            true_triple_list += batch_true
            # Walk over the samples of this batch.
            for i in range(len(pred_sub_head)):
                text = batch_text['text'][i]
                # Index the gold triples of THIS batch. The accumulated list
                # grows across batches, so indexing it with a within-batch
                # index would always return items of the first batch.
                true_triple_item = batch_true[i]
                mask = batch_mask[i]
                offset_mapping = batch_text['offset_mapping'][i]
                sub_head_ids = torch.where(pred_sub_head[i] > SUB_HEAD_BAR)[0]
                sub_tail_ids = torch.where(pred_sub_tail[i] > SUB_TAIL_BAR)[0]
                pred_triple_item = get_triple_list(
                    sub_head_ids, sub_tail_ids, model,
                    encoded_text[i], text, mask, offset_mapping)
                # Update the counts.
                correct_num += len(set(true_triple_item) & set(pred_triple_item))
                predict_num += len(set(pred_triple_item))
                gold_num += len(set(true_triple_item))
                pred_triple_list.append(pred_triple_item)
    # Overall metrics over the whole split. EPS is added to the denominators
    # so the ratios stay defined when a count is zero (assumes EPS is a small
    # positive constant in scope).
    precision = correct_num / (predict_num + EPS)
    recall = correct_num / (gold_num + EPS)
    f1_score = 2 * precision * recall / (precision + recall + EPS)
    print('\tcorrect_num:', correct_num, 'predict_num:', predict_num, 'gold_num:', gold_num)
    print('\tprecision:%.3f' % precision, 'recall:%.3f' % recall, 'f1_score:%.3f' % f1_score)
#predict.py
from config import *
from utils import *
from transformers import BertTokenizerFast
from model import *
if __name__ == '__main__':
    text = '俞敏洪,出生于1962年9月4日的江苏省江阴市,大学毕业于北京大学西语系。'
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)
    tokenized = tokenizer(text, return_offsets_mapping=True)
    info = {}
    info['input_ids'] = tokenized['input_ids']
    info['offset_mapping'] = tokenized['offset_mapping']
    info['mask'] = tokenized['attention_mask']
    # Add a batch axis of size 1 for the single sentence.
    input_ids = torch.tensor([info['input_ids']]).to(DEVICE)
    batch_mask = torch.tensor([info['mask']]).to(DEVICE)
    # NOTE(review): torch.load unpickles a whole model object here; only load
    # checkpoints you trust. Newer PyTorch releases may need weights_only=False
    # for this kind of file — confirm against the installed version.
    model = torch.load(MODEL_DIR + 'newmodel_21.pth', map_location=DEVICE)
    # Evaluation mode makes layers such as dropout deterministic, and
    # prediction needs no gradients.
    model.eval()
    with torch.no_grad():
        encoded_text = model.get_encoded_text(input_ids, batch_mask)
        pred_sub_head, pred_sub_tail = model.get_subs(encoded_text)
        sub_head_ids = torch.where(pred_sub_head[0] > SUB_HEAD_BAR)[0]
        sub_tail_ids = torch.where(pred_sub_tail[0] > SUB_TAIL_BAR)[0]
        mask = batch_mask[0]
        encoded_text = encoded_text[0]
        offset_mapping = info['offset_mapping']
        pred_triple_item = get_triple_list(
            sub_head_ids, sub_tail_ids, model,
            encoded_text, text, mask, offset_mapping)
    print(text)
    print(pred_triple_item)
在第一节已经交代过了,可以查看我的博客。