# 预处理及训练过程：
# 代码：
# -*- coding:utf-8 -*-
import math
import re
from collections import Counter

import numpy as np
from tensorflow.python.keras.models import Model, load_model
from tensorflow.python.keras.layers import Input, Dense, Dropout, LSTM, Embedding, TimeDistributed, Bidirectional
from tensorflow.python.keras.utils import np_utils
# Build the character vocabulary from the training word list (one entry per
# line), most frequent characters first. Ids start at 1 so that 0 stays free
# for padding, which the Embedding layer masks via mask_zero=True.
with open("data/msr/msr_training_words.utf8", encoding="utf-8") as f:
    vocab = f.read().rstrip('\n').split('\n')
vocab = list(''.join(vocab))
# (char, count) pairs by descending count; equal counts keep first-seen order.
stat = Counter(vocab).most_common()
vocab = [s[0] for s in stat]
print(len(vocab))
char2id = {w: c + 1 for c, w in enumerate(vocab)}
id2char = {c + 1: w for c, w in enumerate(vocab)}
# BMES tagging scheme ('s' single, 'b' begin, 'm' middle, 'e' end) plus 'x',
# the tag given to padding positions.
tags = {'s': 0, 'b': 1, 'm': 2, 'e': 3, 'x': 4}
embedding_size = 128
maxlen = 32
hidden_size = 64
batch_size = 64
epochs = 50
def load_data(path):
    """Load a space-segmented corpus as padded character ids and BMES tags.

    The file is split into clauses at ASCII or full-width , 。 ? ! 、 and
    newlines. Each word of a clause is tagged per character ('s' for a
    one-character word, otherwise 'b', 'm'..., 'e'), and each clause is
    truncated or padded (id 0, tag 'x') to ``maxlen`` positions.

    Clauses containing a character missing from ``char2id`` are skipped, as
    are clauses with no characters at all.

    :param path: path of a UTF-8 file whose words are separated by spaces.
    :return: ``(X_data, Y_data)``: an int array with one row of ``maxlen``
        ids per clause, and the per-position tags one-hot encoded by
        ``np_utils.to_categorical`` with 5 classes.
    """
    with open(path, encoding='utf-8') as f:
        data = f.read().rstrip('\n')
    data = re.split(r'[,，。?？!！、\n]', data)
    print('共有 %s 条数据' % len(data))
    print('平均长度 %d' % np.mean([len(d.replace(" ", "")) for d in data]))
    X_data = []
    Y_data = []
    for sentence in data:
        X = []
        y = []
        try:
            for s in sentence.split(" "):
                s = s.strip()
                if not s:
                    continue
                if len(s) == 1:
                    word_tags = ['s']
                else:
                    word_tags = ['b'] + ['m'] * (len(s) - 2) + ['e']
                X.extend(char2id[c] for c in s)
                y.extend(tags[t] for t in word_tags)
        except KeyError:
            # A character is not in the vocabulary: drop the whole clause.
            continue
        # Test for emptiness before padding: once padded, every clause has
        # length maxlen and empty clauses would slip through as pure padding.
        if not X:
            continue
        X = X[:maxlen] + [0] * (maxlen - len(X))
        y = y[:maxlen] + [tags['x']] * (maxlen - len(y))
        X_data.append(X)
        Y_data.append(y)
    X_data = np.array(X_data)
    Y_data = np_utils.to_categorical(Y_data, 5)
    return X_data, Y_data
# Load the MSR training set and gold test set as padded ids / one-hot tags.
X_train, y_train = load_data('data/msr/msr_training.utf8')
X_test, y_test = load_data('data/msr/msr_test_gold.utf8')
print('X_train size:', X_train.shape)
print('y_train size:', y_train.shape)
print('X_test size:', X_test.shape)
print('y_test size:', y_test.shape)

# Character embedding -> two stacked bidirectional LSTMs -> per-position
# softmax over the 5 tags (s, b, m, e, x).
X = Input(shape=[maxlen, ], dtype='int32', name='input')
# input_dim is len(vocab) + 1 because ids start at 1; id 0 is padding and is
# masked by mask_zero=True.
embedding = Embedding(input_dim=len(vocab) + 1, output_dim=embedding_size,
                      input_length=maxlen, mask_zero=True)(X)
blstm = Bidirectional(LSTM(hidden_size, return_sequences=True), merge_mode='concat')(embedding)
blstm = Dropout(0.6)(blstm)
blstm = Bidirectional(LSTM(hidden_size, return_sequences=True), merge_mode='concat')(blstm)
blstm = Dropout(0.6)(blstm)
output = TimeDistributed(Dense(5, activation='softmax'))(blstm)
model = Model(X, output)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# summary() prints the table itself and returns None; wrapping it in print()
# only adds a stray "None" line.
model.summary()
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs)
# Prints [loss, accuracy] on the training set and then on the test set.
print(model.evaluate(X_train, y_train, batch_size=batch_size))
print(model.evaluate(X_test, y_test, batch_size=batch_size))
def viterbi(nodes):
    """Return the most likely BMES tag string for one clause.

    :param nodes: one dict per character mapping the tags 's', 'b', 'm'
        and 'e' to the model's probability for that tag (``cut_words``
        builds these from the softmax output).
    :return: one tag letter per node, e.g. ``'bes'``; ``''`` if *nodes*
        is empty.

    A path is scored by the sum of its log-probabilities, i.e. by the
    product of its per-character probabilities; adding raw probabilities
    is not a valid path score. ``trans`` lists the only legal tag bigrams.
    All of them have the same weight, so they restrict the search without
    favouring one legal path over another. Paths start with 'b' or 's' and
    must end with 's' or 'e', so the last word is always closed.
    """
    if not nodes:
        return ''
    trans = {'be': 0.5, 'bm': 0.5, 'eb': 0.5, 'es': 0.5,
             'me': 0.5, 'mm': 0.5, 'sb': 0.5, 'ss': 0.5}
    log_trans = {pair: math.log(p) for pair, p in trans.items()}

    def safe_log(p):
        # A softmax output can underflow to 0; treat it as impossible.
        return math.log(p) if p > 0 else float('-inf')

    # For each tag, keep only the best-scoring partial tag string ending in it.
    paths = {t: safe_log(nodes[0][t]) for t in ('b', 's')}
    for node in nodes[1:]:
        new_paths = {}
        for tag, prob in node.items():
            candidates = {
                path + tag: score + safe_log(prob) + log_trans[path[-1] + tag]
                for path, score in paths.items()
                if path[-1] + tag in log_trans
            }
            if candidates:
                best = max(candidates, key=candidates.get)
                new_paths[best] = candidates[best]
        paths = new_paths
    # Only paths that close their last word are acceptable; fall back to all
    # paths if (unexpectedly) none does.
    finals = {p: s for p, s in paths.items() if p[-1] in ('s', 'e')} or paths
    return max(finals, key=finals.get)
def cut_words(data):
    """Segment raw Chinese text with the trained model.

    The text is split into clauses at ASCII or full-width , 。 ! ? 、 and
    newlines, and these separators are dropped from the output. Whitespace
    and characters absent from ``char2id`` are skipped. Each clause is sent
    to ``model`` in pieces of at most ``maxlen`` characters, so long clauses
    are no longer cut off at ``maxlen``; a word spanning a piece boundary is
    split there. Tags are decoded with ``viterbi`` and all words are joined
    with '/'.

    :param data: raw text to segment.
    :return: the segmented words joined by '/', or '' if no character of
        *data* is in the vocabulary.
    """
    pieces = []
    for sentence in re.split(r'[,，。!！?？、\n]', data):
        chars = [c for c in sentence if c.strip() and c in char2id]
        for start in range(0, len(chars), maxlen):
            pieces.append(chars[start:start + maxlen])
    if not pieces:
        # Nothing to tag; an empty array would not match the model's input.
        return ''
    Xs = np.array([[char2id[c] for c in piece] + [0] * (maxlen - len(piece))
                   for piece in pieces])
    ys = model.predict(Xs)
    words = []
    for piece, probs in zip(pieces, ys):
        # Decode only the real characters. Otherwise the padded positions
        # after them would take part in the Viterbi search.
        nodes = [dict(zip(['s', 'b', 'm', 'e'], p[:4])) for p in probs[:len(piece)]]
        word = ''
        for char, tag in zip(piece, viterbi(nodes)):
            word += char
            if tag in ('s', 'e'):
                words.append(word)
                word = ''
        if word:
            # Close a word left open at the end of the piece so that it does
            # not merge with the next clause.
            words.append(word)
    return '/'.join(words)
# Segment a few sample sentences with the trained model. The last one is a
# classic ambiguity case (和尚 "monk" vs. 和 / 尚未 "and / not yet").
for sample in (
    '中国共产党第十九次全国代表大会,是在全面建成小康社会决胜阶段、中国特色社会主义进入新时代的关键时期召开的一次十分重要的大会。',
    '把这本书推荐给,具有一定编程基础,希望了解数据分析、人工智能等知识领域,进一步提升个人技术能力的社会各界人士。',
    '结婚的和尚未结婚的。',
):
    print(cut_words(sample))
# 张宏伦《深度有趣》