本文以项目readme.md训练逻辑的顺序解读
更多bert模型参考github地址
本文用的是BERT-Base, Cased(12-layer, 768-hidden, 12-heads , 110M parameters)下载地址。其中Cased表示保留真实的大小写和重音标记符,uncased表示文本在单词标记之前就已经变为小写,也去掉了任何重音标记例如,John Smith变成john smith。通常,Uncased模型会更好,除非大小写信息对于我们的任务很重要,如命名实体识别或词性标记。
以NYT数据集为例:NYT数据集是关于远程监督关系抽取任务的广泛使用的数据集。该数据集是通过将freebase中的关系与纽约时报(NYT)语料库对齐而生成的。纽约时报New York Times数据集包含150篇来自纽约时报的商业文章。抓取了从2009年11月到2010年1月纽约时报网站上的所有文章。在句子拆分和标记化之后,使用斯坦福NER标记器来标识PER和ORG从每个句子中的命名实体。对于包含多个标记的命名实体,我们将它们连接成单个标记。然后,我们将同一句子中出现的每一对(PER,ORG)实体作为单个候选关系实例,PER实体被视为ARG-1,ORG实体被视为ARG-2。
generate.py
# 将raw_NYT\train.json中的数字形式生成训练集的文本形式
def load_data(in_file, word_dict, rel_dict, out_file, normal_file, epo_file, seo_file):
with open(in_file, 'r') as f1, open(out_file, 'w') as f2, open(normal_file, 'w') as f3, \
open(epo_file,'w') as f4, open(seo_file, 'w') as f5:
seo_file, 'w') as f5:
cnt_normal = 0
cnt_epo = 0
cnt_seo = 0
lines = f1.readlines() # readlines()方法用于读取所有行(直到结束符EOF)并返回列表
for line in lines:
line = json.loads(line)
print(len(line))
lengths, sents, spos = line[0], line[1], line[2]
print(len(spos))
print(len(sents))
for i in range(len(sents)):
new_line = dict()
# print(sents[i])
# print(spos[i])
tokens = [word_dict[i] for i in sents[i]] # tokens为sents对应数字形式的字符串数组
sent = ' '.join(tokens) # 以空格形式连接字符串数组生成一个新的字符串
new_line['sentText'] = sent # new_line为包含三元组的字典
triples = np.reshape(spos[i], (-1, 3)) # 将spo[i]关系三元组的维度变为3列
relationMentions = []
for triple in triples:
rel = dict()
rel['em1Text'] = tokens[triple[0]]
rel['em2Text'] = tokens[triple[1]]
rel['label'] = rel_dict[triple[2]]
relationMentions.append(rel)
new_line['relationMentions'] = relationMentions
f2.write(json.dumps(new_line) + '\n')
if is_normal_triple(spos[i]):
f3.write(json.dumps(new_line) + '\n')
if is_multi_label(spos[i]):
f4.write(json.dumps(new_line) + '\n')
if is_over_lapping(spos[i]):
f5.write(json.dumps(new_line) + '\n')
build_data.py
# 读取数据集文件,将文本、三元组分类存储
with open('train.json') as f:
for l in tqdm(f): # tqdm是可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)
a = json.loads(l)
if not a['relationMentions']: # 若某个句子a中关系'relationMentions'为空,跳过之
continue
# 提取出每个句子及其三元组
line = {
'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), #去除'sentText'中的\r(回车)、\n(换行)、两头的'\'
'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if
i['label'] != 'None']
}
if not line['triple_list']:
continue
# 将提取出来的句子及其三元组信息加入到训练集数据train_data中,将三元组中的关系加入到集合rel_set中(无序不重复元素序列)
train_data.append(line)
for rm in a['relationMentions']:
if rm['label'] != 'None':
rel_set.add(rm['label'])
run.py中的默认参数:
{
"bert_model": "cased_L-12_H-768_A-12",
"max_len": 100,
"learning_rate": 1e-5,
"batch_size": 6,
"epoch_num": 100,
}
根据自己的设置修改
确定运行方式,使用的数据集
python run.py ---train=True --dataset=NYT
在测试集上评估
python run.py --dataset=NYT