# Chinese NER demo: BiGRU + CRF sequence labeling (Keras)

import codecs
import random
import numpy as np
from gensim import corpora
from keras.layers import Dense,GRU,Bidirectional,SpatialDropout1D,Embedding
from keras import preprocessing
from keras.models import Sequential
import re
from keras_contrib.layers import CRF
def load_iob2(file_path):
    """Load token/label sequences from an IOB2-format file.

    The file has one ``token label`` pair per line; a blank line marks the
    end of a sentence.

    Args:
        file_path: path to the IOB2 data file (utf-8).

    Returns:
        A pair of 1-D object ndarrays ``(token_seqs, label_seqs)``; each
        element is a list of tokens / labels for one sentence.  Object
        dtype is required because the sequences are ragged.
    """
    token_seqs = []
    label_seqs = []
    tokens = []
    labels = []
    with codecs.open(file_path, 'r', encoding='utf8') as f:
        for index, line in enumerate(f):
            items = line.strip().split()
            if len(items) == 2:
                token, label = items
                tokens.append(token)
                labels.append(label)
            elif len(items) == 0:
                # Blank line: flush the current sentence (if any).
                if tokens:
                    token_seqs.append(tokens)
                    label_seqs.append(labels)
                    tokens = []
                    labels = []
            else:
                print('格式错误。行号:{} 内容:{}'.format(index, line))
                continue

    if tokens:  # file did not end with a blank line — flush the last sentence
        token_seqs.append(tokens)
        label_seqs.append(labels)

    # dtype=object: ragged nested lists; a plain asarray raises a
    # ValueError ("inhomogeneous shape") on NumPy >= 1.24.
    return np.asarray(token_seqs, dtype=object), np.asarray(label_seqs, dtype=object)
def show_iob2(token_seqs, label_seqs, num=5, shuffle=True):
    """Print a sample of sequences as space-separated ``token/label`` pairs.

    With ``shuffle=True``, ``num`` sequences are sampled at random (repeats
    possible, requires ndarray inputs for fancy indexing); otherwise the
    first ``num`` sequences are shown.  Each sentence is followed by a
    blank line.
    """
    if shuffle:
        total = len(token_seqs)
        picks = [random.randrange(0, total) for _ in range(num)]
        pairs = zip(token_seqs[picks], label_seqs[picks])
    else:
        pairs = zip(token_seqs[:num], label_seqs[:num])
    for sent, tags in pairs:
        line = ''.join('{}/{} '.format(tok, tag) for tok, tag in zip(sent, tags))
        print(line, end='')
        print('\n')
def dic(file_path, token_seq, label_seq):
    """Build gensim dictionaries for tokens and labels and index the data.

    Re-reads the raw file (tab-separated ``token\\tlabel`` lines) to collect
    the vocabulary and label set, then maps ``token_seq``/``label_seq`` to
    integer-id sequences via ``Dictionary.doc2idx``.

    Args:
        file_path: path to the raw IOB2 data file (utf-8).
        token_seq: iterable of token sequences (from ``load_iob2``).
        label_seq: iterable of label sequences (from ``load_iob2``).

    Returns:
        ``(seq_train, label_target, label_file, max_words, dic_file_len,
        dic_file)`` — indexed token sequences, indexed label sequences,
        the label Dictionary, the longest sentence length, the vocabulary
        size, and the token Dictionary.
    """
    all_words = []
    dic_list = []
    seq_list = []
    with codecs.open(file_path, 'r', encoding='utf8') as f:
        lines = f.readlines()
        # Record each sentence length so the caller can pad to the maximum.
        for seq in token_seq:
            all_words.append(len(seq))
        for raw in lines:
            parts = raw.split('\t')
            if parts[0] == '\n':  # blank separator line between sentences
                continue
            dic_list.append(parts[0])
            try:
                seq_list.append(parts[1])
            except IndexError:
                # Line without a tab-separated label column — no label to keep.
                pass
        # Dictionary expects a list of documents; the whole vocab is one doc.
        dic_file = corpora.Dictionary([dic_list])
        dic_file_len = len(dic_file)
        # Strip trailing newlines from the labels and deduplicate them.
        label_res = list(set(re.sub(r'\n', '', s) for s in seq_list))
        label_file = corpora.Dictionary([label_res])
        # doc2idx maps each token/label to its dictionary id (-1 if unseen).
        seq_train = [dic_file.doc2idx(text) for text in token_seq]
        label_target = [label_file.doc2idx(text) for text in label_seq]
    max_words = np.max(all_words)
    return seq_train, label_target, label_file, max_words, dic_file_len, dic_file
# ---- Build training data -------------------------------------------------
token_seq, label_seq = load_iob2(r'C:\Users\DELL\Desktop\Kashgari-master\dh_msra.txt')
seq_train, label_target, label_file, max_word, dic_file_len, dic_file = dic(
    r'C:\Users\DELL\Desktop\Kashgari-master\dh_msra.txt', token_seq, label_seq)
# Pad token ids (value 0, compatible with mask_zero below) and labels (-1).
X_train = preprocessing.sequence.pad_sequences(seq_train, maxlen=max_word)
y_train = preprocessing.sequence.pad_sequences(label_target, maxlen=max_word, value=-1)
# NOTE(review): this hard-coded order must match the id assignment inside
# `label_file` — confirm against label_file.token2id before trusting output.
label = ['B-LOC', 'B-ORG', 'B-PER', 'I-LOC', 'I-ORG', 'I-PER', 'O']
lab_pad = np.expand_dims(y_train, 2)  # CRF sparse targets need a trailing axis

# ---- Prepare one demo sentence -------------------------------------------
sen_test = "北京故宫,清华大学图书馆"
# NOTE(review): .get() yields None for out-of-vocabulary characters, which
# would break pad_sequences — assumes every char of sen_test is in the vocab.
char2id = [[dic_file.token2id.get(i) for i in sen_test]]
X_test = preprocessing.sequence.pad_sequences(char2id, maxlen=max_word)

# ---- Model: Embedding -> BiGRU -> CRF ------------------------------------
model = Sequential()
model.add(Embedding(dic_file_len + 1, 100, mask_zero=True))
model.add(Bidirectional(GRU(32, return_sequences=True)))
print(len(label_file))
crf = CRF(len(label_file), sparse_target=True)
model.add(crf)
model.summary()
model.compile(optimizer='adam', loss=crf.loss_function, metrics=[crf.accuracy])
model.fit(X_train, lab_pad, batch_size=16, epochs=10)

# ---- Predict and decode ---------------------------------------------------
# BUG FIX: predict on the whole padded batch, then slice the first sample's
# per-timestep scores back to the sentence length.  The original sliced the
# *input ids* (`X_test[0][-len(sen_test):]`) and fed that 1-D array to
# predict(), which is not a valid (batch, timesteps) input.
result = model.predict(X_test)[0][-len(sen_test):]
result_label = [np.argmax(i) for i in result]
print(result_label)
res2label = [label[i] for i in result_label]
print(res2label)

# Group characters into entity strings: a B-* tag starts a new entity
# (prefixed by a space), an I-* tag extends the current one.
per, loc, org = '', '', ''
for s, t in zip(sen_test, res2label):
    if t in ('B-PER', 'I-PER'):
        per += ' ' + s if (t == 'B-PER') else s
    if t in ('B-ORG', 'I-ORG'):
        org += ' ' + s if (t == 'B-ORG') else s
    if t in ('B-LOC', 'I-LOC'):
        loc += ' ' + s if (t == 'B-LOC') else s
print(['person:' + per, 'location:' + loc, 'organzation:' + org])

 

# Related topic: NER (named entity recognition)