The previous Chinese-chatbot post did no word segmentation and fed the entire corpus into training in one shot, so the resulting model was fairly crude. This installment segments the corpus with jieba and splits the data into batches for training.
1. Data Preprocessing
Training a seq2seq model requires three kinds of data: encoder_input, decoder_input, and decoder_target.
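To make the three roles concrete, here is a small hypothetical example of what one segmented QA pair turns into (the BOS/EOS markers match those added in get_qa_vec below; the decoder target is the decoder input shifted left by one step, i.e. teacher forcing):

# One segmented question/answer pair (hypothetical example):
encoder_input  = ['你', '好', '吗']                # the question
decoder_input  = ['BOS', '我', '很', '好', 'EOS']  # answer with markers
decoder_target = ['我', '很', '好', 'EOS']         # shifted left one step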
The preprocessing code below is from data_util.py:
# data_util.py — load the raw question/answer pairs
# (data_path, num_samples and Traditional2Simplified are module-level
#  definitions elsewhere in data_util.py)
import jieba
import numpy as np

def get_raw_data():
    with open(data_path, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')
        lines = lines[:-1]  # drop the trailing empty line
    lines = lines[:min(num_samples, len(lines) - 1)]
    question = []
    answer = []
    for pos, line in enumerate(lines):
        if '\t' not in line:  # malformed line: log it and skip
            print(line)
            continue
        line = line.split('\t')
        q = line[0].strip()
        a = line[1].strip()
        # Convert to Simplified Chinese, segment with jieba, and
        # rejoin the tokens with single spaces
        question.append(' '.join(jieba.lcut(Traditional2Simplified(q).strip(), cut_all=False)))
        answer.append(' '.join(jieba.lcut(Traditional2Simplified(a).strip(), cut_all=False)))
    return question, answer
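get_raw_data calls a Traditional2Simplified helper that is not shown here. A minimal sketch using the opencc package (this is an assumption; the original code may rely on langconv or another converter):

from opencc import OpenCC  # pip install opencc-python-reimplemented

_t2s = OpenCC('t2s')  # traditional-to-simplified conversion table

def Traditional2Simplified(sentence):
    # Normalize Traditional Chinese characters to Simplified Chinese
    return _t2s.convert(sentence)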
With the raw question/answer pairs in hand, filter out stop words and build the dictionaries (word_to_index and index_to_word, which map words to indices and back). The answer corpus is used to generate both the decoder input and the decoder target. Then convert the QA pairs to index vectors:
def get_qa_vec(question, answer):
    global word_to_index, index_to_word
    # Decoder input starts with BOS and ends with EOS;
    # decoder target is the same answer shifted left, ending with EOS
    answer_a = ['BOS ' + i + ' EOS' for i in answer]
    answer_b = [i + ' EOS' for i in answer]
    # Map every word to its dictionary index (ragged object arrays)
    question = np.array([[word_to_index[w] for w in i.split(' ')] for i in question])
    answer_a = np.array([[word_to_index[w] for w in i.split(' ')] for i in answer_a])
    answer_b = np.array([[word_to_index[w] for w in i.split(' ')] for i in answer_b])
    # Shift all dictionary indices up by 1 so that 0 is reserved for
    # padding (the Embedding layer uses mask_zero=True)
    for i, j in word_to_index.items():
        word_to_index[i] = j + 1
    for key, value in word_to_index.items():
        index_to_word[value] = key
    pad_question = question
    pad_answer_a = answer_a
    pad_answer_b = answer_b
    # Apply the same +1 shift to the vectors and truncate to pad_maxLen
    for pos, i in enumerate(pad_question):
        for pos_, j in enumerate(i):
            i[pos_] = j + 1
        if len(i) > pad_maxLen:
            pad_question[pos] = i[:pad_maxLen]
    for pos, i in enumerate(pad_answer_a):
        for pos_, j in enumerate(i):
            i[pos_] = j + 1
        if len(i) > pad_maxLen:
            pad_answer_a[pos] = i[:pad_maxLen]
    for pos, i in enumerate(pad_answer_b):
        for pos_, j in enumerate(i):
            i[pos_] = j + 1
        if len(i) > pad_maxLen:
            pad_answer_b[pos] = i[:pad_maxLen]
    return question, answer_a, answer_b
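The word_to_index and index_to_word dictionaries used above must be built from the segmented corpus before get_qa_vec runs. A minimal sketch of that construction (the variable names are assumptions chosen to match the surrounding code):

# Collect every distinct token; the BOS/EOS markers added in
# get_qa_vec must also be present in the vocabulary
vocab_bag = {'BOS', 'EOS'}
for sentence in question + answer:
    vocab_bag.update(sentence.split(' '))

word_to_index = {word: idx for idx, word in enumerate(vocab_bag)}
index_to_word = {}  # populated by get_qa_vec after the +1 shift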
With the corpus pairs and dictionaries built, pad the vectors to a fixed length and save everything to disk (sequence here is keras.preprocessing.sequence):
# Pad the encoder/decoder inputs to pad_maxLen (index 0 = padding)
pad_question = sequence.pad_sequences(encoder_input_vec, maxlen=pad_maxLen,
                                      dtype='int32', padding='post',
                                      truncating='post')
pad_answer = sequence.pad_sequences(decoder_input_vec, maxlen=pad_maxLen,
                                    dtype='int32', padding='post',
                                    truncating='post')
# Persist the dictionaries, vocabulary and padded vectors
with open('data/word_to_index.pkl', 'wb') as f:
    pickle.dump(word_to_index, f, pickle.HIGHEST_PROTOCOL)
with open('data/index_to_word.pkl', 'wb') as f:
    pickle.dump(index_to_word, f, pickle.HIGHEST_PROTOCOL)
with open('data/vocab_bag.pkl', 'wb') as f:
    pickle.dump(vocab_bag, f, pickle.HIGHEST_PROTOCOL)
np.save('data/pad_question.npy', pad_question)
np.save('data/pad_answer.npy', pad_answer)
# Decoder targets stay unpadded here; they are one-hot encoded
# batch by batch in the training generator
np.save('data/answer_o.npy', decoder_target_vec)
2. Model Training (train.py)
Build the seq2seq model with dot attention as follows:
def build_model():
    # Shared embedding layer for encoder and decoder inputs;
    # mask_zero=True treats index 0 as padding
    truncatednormal = TruncatedNormal(mean=0.0, stddev=0.05)
    embed_layer = Embedding(input_dim=vocab_size,
                            output_dim=100,
                            mask_zero=True,
                            input_length=None,
                            embeddings_initializer=truncatednormal)
    LSTM_encoder = LSTM(512,
                        return_sequences=True,
                        return_state=True,
                        kernel_initializer='lecun_uniform',
                        name='encoder_lstm')
    LSTM_decoder = LSTM(512,
                        return_sequences=True,
                        return_state=True,
                        kernel_initializer='lecun_uniform',
                        name='decoder_lstm')

    # Encoder input and decoder input
    input_question = Input(shape=(None,), dtype='int32', name='input_question')
    input_answer = Input(shape=(None,), dtype='int32', name='input_answer')
    input_question_embed = embed_layer(input_question)
    input_answer_embed = embed_layer(input_answer)

    # The encoder's final states initialize the decoder
    encoder_lstm, question_h, question_c = LSTM_encoder(input_question_embed)
    decoder_lstm, _, _ = LSTM_decoder(input_answer_embed,
                                      initial_state=[question_h, question_c])

    # Dot-product attention over the encoder outputs
    attention = dot([decoder_lstm, encoder_lstm], axes=[2, 2])
    attention = Activation('softmax')(attention)
    context = dot([attention, encoder_lstm], axes=[2, 1])
    decoder_combined_context = concatenate([context, decoder_lstm])

    # Project the combined context to vocabulary-sized softmax outputs
    decoder_dense1 = TimeDistributed(Dense(256, activation="tanh"))
    decoder_dense2 = TimeDistributed(Dense(vocab_size, activation="softmax"))
    output = decoder_dense1(decoder_combined_context)
    output = decoder_dense2(output)

    model = Model([input_question, input_answer], output)
    return model
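The dot/softmax/dot sequence implements Luong-style dot attention. Tracing the tensor shapes makes the mechanism easier to follow (T_d and T_e denote decoder and encoder timesteps):

# decoder_lstm: (batch, T_d, 512)    encoder_lstm: (batch, T_e, 512)
# dot(axes=[2, 2])  -> attention scores   (batch, T_d, T_e)
# softmax           -> attention weights over the encoder steps
# dot(axes=[2, 1])  -> context vectors    (batch, T_d, 512)
# concatenate       -> (batch, T_d, 1024), fed into the Dense stack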
With the model built, load the preprocessed data and feed it in for training:
if __name__ == "__main__":
    # Load the padded vectors and the dictionaries saved by data_util.py
    question = np.load('data/pad_question.npy')
    answer = np.load('data/pad_answer.npy')
    answer_o = np.load('data/answer_o.npy', allow_pickle=True)
    with open('data/vocab_bag.pkl', 'rb') as f:
        words = pickle.load(f)
    with open('data/word_to_index.pkl', 'rb') as f:
        word_to_index = pickle.load(f)
    with open('data/index_to_word.pkl', 'rb') as f:
        index_to_word = pickle.load(f)
    vocab_size = len(word_to_index) + 1  # +1 for the padding index 0

    model = build_model()
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

    # filepath is the checkpoint filename pattern defined earlier in train.py
    checkpoint = ModelCheckpoint(filepath,
                                 monitor='loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min',
                                 period=1,
                                 save_weights_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='loss',
                                  factor=0.2,
                                  patience=2,
                                  verbose=1,
                                  mode='min',
                                  min_delta=0.0001,
                                  cooldown=0,
                                  min_lr=0)
    tensorboard = TensorBoard(log_dir='logs',
                              batch_size=100)
    callbacks_list = [checkpoint, reduce_lr, tensorboard]

    # Resume from the newest checkpoint if one exists
    initial_epoch_ = 0
    file_list = os.listdir('models/')
    if len(file_list) > 0:
        epoch_list = get_file_list('models/')
        epoch_last = epoch_list[-1]
        model.load_weights('models/' + epoch_last)
        print("**********checkpoint_loaded: ", epoch_last)
        initial_epoch_ = int(epoch_last.split('-')[2]) - 1
        print('**********Begin from epoch: ', str(initial_epoch_))

    model.fit_generator(generate_train(batch_size=100),
                        steps_per_epoch=900,  # (total samples) / batch_size: 90000/100 = 900
                        epochs=1,
                        verbose=1,
                        callbacks=callbacks_list,
                        validation_data=generate_test(batch_size=100),
                        validation_steps=100,  # 10000/100 = 100
                        class_weight=None,
                        max_queue_size=5,
                        workers=1,
                        use_multiprocessing=False,
                        shuffle=False,
                        initial_epoch=initial_epoch_)
    model.summary()
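generate_train and generate_test are not shown above. Because the output layer is a vocabulary-sized softmax trained with categorical_crossentropy, the targets must be one-hot encoded, which is far too large to hold in memory all at once; the generator therefore encodes them batch by batch. A minimal sketch of generate_train, assuming the variables loaded above plus pad_maxLen are in scope:

def generate_train(batch_size):
    # Endlessly yield ([encoder_input, decoder_input], one_hot_target)
    steps = 0
    while True:
        batch_question = question[steps:steps + batch_size]
        batch_answer = answer[steps:steps + batch_size]
        batch_target = answer_o[steps:steps + batch_size]
        # One-hot encode the target indices: (batch, time, vocab_size)
        outs = np.zeros((len(batch_question), pad_maxLen, vocab_size),
                        dtype='float32')
        for i, target in enumerate(batch_target):
            for t, word_index in enumerate(target[:pad_maxLen]):
                outs[i, t, word_index] = 1.0
        yield [batch_question, batch_answer], outs
        steps += batch_size
        if steps >= len(question):
            steps = 0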
3. Inference (infer.py)
Load the trained weights and run prediction. Note that build_model in infer.py additionally returns separate encoder and decoder sub-models, which are needed for step-by-step decoding:
model, encoder_model, decoder_model = build_model()
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.load_weights('models/W-- 51-0.7715-.h5')
model.summary()
while True:
    seq = input('Please input question:')
    if seq == 'exit':
        break
    # Segment the question and convert it to an index vector
    seq, sentence = input_question(seq)
    print(sentence)
    answer = decode_greedy(seq, sentence)
    # answer = decode_beamsearch(seq, 3)
    print('ANSWER: ', answer)
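input_question, decode_greedy and decode_beamsearch are defined elsewhere in infer.py. A minimal sketch of greedy decoding, assuming encoder_model returns the encoder outputs plus final states and decoder_model advances one token per call (both signatures are assumptions, not the repository's exact API):

def decode_greedy(seq, sentence):
    # Encode the question once, then generate one word per step,
    # feeding each prediction back in as the next decoder input
    encoder_outputs, h, c = encoder_model.predict(seq)
    target = np.array([[word_to_index['BOS']]])
    answer_words = []
    for _ in range(pad_maxLen):
        probs, h, c = decoder_model.predict([target, encoder_outputs, h, c])
        word_index = int(np.argmax(probs[0, -1, :]))
        word = index_to_word.get(word_index, 'EOS')
        if word == 'EOS':
            break
        answer_words.append(word)
        target = np.array([[word_index]])
    return ''.join(answer_words)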
4. Results
Open-source code: [repository link]