1.VSM简介
空间向量模型VSM,是将文本表示成数值表示的向量。在使用VSM做文本相似度计算时,其基本步骤是:
1)将文本分词,提取特征词s:(t1,t2,t3,t4)
2)将特征词用权重表示,从而将文本表示成数值向量s:(w1,w2,w3,w4),权重表示的方式一般使用tfidf
3)计算文本向量间的余弦值,判断文本间的相似度
缺点:空间向量模型以词袋为基础,没有考虑词与词间的关系,近义词等。
2.LSI介绍
潜在语义索引(Latent Semantic Indexing,以下简称LSI),有的文章也叫Latent Semantic Analysis(LSA)。是一种简单实用的主题模型。LSI是基于奇异值分解(SVD)的方法来得到文本的主题的。其推导过程如下图:(参考:http://www.cnblogs.com/pinard/p/6251584.html ,https://www.cnblogs.com/pinard/p/6805861.html)
使用LSI计算文本相似度的基本过程:
1).和VSM一样,得到文本的tfidf值的向量表示
2).做SVD分解,得到Uk词和词义之间的相关性。Vk文本和主题的相关性。
3).利用文本主题矩阵计算文本的相似度(通过余弦值)
3.应用
1.准备数据
数据分为两部分(Idx2ID,questionList)
idx2ID:保存数据库中的ID及ID在文件中的下标,例如({0:100,1:101,2:102} 其中0,1,2为下标;100,102,103为数据库中的ID值)
questionList:保存的是数据库中问题切词后的二维数组,例如[['我','跑路','了'],['你',‘很’,‘不开心’]]
其中questionList中的数组下标与idx2ID中的下标一 一对应,及数据库中的保存记录为:(id:100,question:‘我跑路了’)
import re
from tqdm import tqdm
import os
import pickle
import jieba
import pymysql
import pandas as pd
'''
从数据库获取训练数据
'''
#清洗数据,去掉停用词并进行jieba分词
stwlist = pickle.load(open('stop_word.pkl','rb'))
def getTrainData(sentence):
sentence=re.sub(r'[a-zA-Z0-9]+','', sentence)
wordList = jieba.cut(sentence.strip(), cut_all=False)
new_wordList = []
for word in wordList:
if word not in stwlist:
new_wordList.append(word)
return new_wordList
def get_ID_Question():
aldb = pymysql.connect(host='xx.xx.xx.xx', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8')
IdList = []
QuestionList = []
with pymysql.cursors.SSCursor(aldb) as cursor:
cursor.execute('select id,question from test')
while True:
rs = cursor.fetchone()
if not rs:
break
id= rs[0]
question = rs[1]
IdList.append(id)
QuestionList.append(getTrainData(question))
#保存数据库中id及id在文件中对应的下标
Idx2ID = {}
for index, ID in enumerate(IdList):
Idx2ID[index] = ID
return Idx2ID,QuestionList
def save_ID_Question():
Idx2ID_path = './cache/Idx2ID.pkl'
QuesttionData_path='./cache/traindata.pkl'
Idx2ID, QuestionList= get_ID_Question()
#保存数据
pickle.dump(Idx2ID, open(Idx2ID_path,'wb'))
pickle.dump(QuestionList, open(QuesttionData_path, 'wb'))
print('save data result!')
if __name__=='__main__':
save_ID_Question()
2.训练模型
#建立lsi模型,并将生成的LSI模型保存
def lsi_model(num_topics):
lsi_model_dictionary = './cache/lsi_dictionary.pkl'
lsi_model_lsi = './cache/model.lsi'
lsi_model_index = './cache/lsimodel.index'
cut_question_path = r'./cache/traindata.pkl'
if os.path.exists(lsi_model_dictionary):
dictionary = pickle.load(open(lsi_model_dictionary,'rb'))
lsi = models.LsiModel.load(lsi_model_lsi)
index = similarities.MatrixSimilarity.load(lsi_model_index)
else:
textdata1 = pickle.load(open(cut_question_path, 'rb'))
#生成词典
dictionary = corpora.Dictionary(textdata1)
print('dictionary prepared!')
corpus = [dictionary.doc2bow(text) for text in textdata1]
print('corpus prepared!')
#计算tfidf
tfidf = models.TfidfModel(corpus)
print('tfidf prepared!')
corpus_tfidf = tfidf[corpus]
#训练LSI
lsi = models.LsiModel(corpus_tfidf, id2word=dictionary, num_topics = num_topics)
print('lsi model prepared!')
corpus_lsi = lsi[corpus_tfidf]
#生成引用
index = similarities.MatrixSimilarity(corpus_lsi)
print('index prepared!')
pickle.dump(dictionary, open(lsi_model_dictionary,'wb'))
lsi.save(lsi_model_lsi)
index.save(lsi_model_index)
return dictionary, lsi, index
3.测试
class test(object):
def __init__(self):
#加载模型数据
self.dictionary = pickle.load(open('./cache/lsi_dictionary.pkl','rb'))
self.lsi = models.LsiModel.load('./cache/model.lsi')
self.index = similarities.MatrixSimilarity.load('./cache/lsimodel.index')
self.stop_word = pickle.load(open('./stop_word.pkl','rb'))
self.idx2ID = pickle.load(open('./cache/idx2ID.pkl','rb'))
#清洗输入问题,并切词
def deal_query(self,query):
query = re.sub(r'[a-zA-Z0-9]+','',query)
query_ques = jieba.cut(query.strip(), cut_all=False)
new_wordList = []
for word in query_ques:
if word not in self.stop_word:
new_wordList.append(word)
return new_wordList
#清洗匹配到的问题
def clean_question(self, question):
question = re.sub('律师', '', question)
ques = pseg.cut(question)
for word in ques:
if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr':
keys = word.word
question = re.sub(keys, '', question)
return question
#从数据库获取匹配到的问答对
def get_answer(self,idx_set):
question_set = []
answer_set = []
db = pymysql.connect(host='xx.xx.xx.x', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8')
for i in idx_set:
ID = self.idx2ID[i]
# print(ID)
cursor = db.cursor()
cursor.execute('select question,answer from test where id=%d;' %ID)
db.commit()
result = cursor.fetchone()
if result:
question = self.clean_question(result[0])
question_set.append(question)
answer = self.clean_answer(result[1])
answer_set.append(answer)
db.close()
return question_set,answer_set
#匹配
def match(self,query):
query_ques = self.deal_query(query)
query_bow = self.dictionary.doc2bow(query_ques)
query_lsi = self.lsi[query_bow]
#计算余弦相似度,获取前3个最匹配的值
sims = self.index[query_lsi]
sort_sims = sorted(enumerate(sims), key=lambda item: -item[1])
idx_set = [i[0] for i in sort_sims[:3]]
sim_value = [i[1] for i in sort_sims[:3]]
question_set,answer_set = self.get_answer(idx_set)
return question_set,answer_set,sim_value
def request(self,query):
import json
question_set,answer_set,sim_value = self.match(query)
#相似度小于0.65的问题丢弃
if sim_value[0] < 0.65:
print("we have no match answer")
return None
i=0
question_answer={}
while i 2 and len(a) > 2:
question_answer[q]=a
if len(question_answer)>3:
break
i+=1
# print_result2 = {k: v for k, v in question_answer.items() }
print(len(question_answer))
print(question_answer)
return json.dumps(question_answer, ensure_ascii=False)
#清洗答案数据
def clean_answer(self,answer):
# 将如果有多个回答,取长度最长的一个回答
answer_sent = answer.strip().split('NEXT_ANS')
if len(answer_sent) > 1:
answer_len = [len(ans) for ans in answer_sent]
inx = answer_len.index(max(answer_len))
answer = answer_sent[inx]
answer = re.sub('[,]+', ',', answer)
answer = re.sub('[。]+', '。', answer)
answer = re.sub('[.]+', '', answer)
answer = re.sub('[,]+', ',', answer)
from zhon.hanzi import punctuation
# 去除 '\nNEXT_ANS'
answer = answer.replace('NEXT_ANS', '')
duels = [x + y for x in punctuation
for y in punctuation]
pattern1 = re.compile('^[%s]' % punctuation)
pattern2 = re.compile('[%s]$' % punctuation)
pattern3 = re.compile('\[.*?\]')
# 去除[]中的内容
answer = pattern3.sub('', answer)
# 去除空格
answer = answer.split()
answer = ''.join(answer)
# 去除【】符号
answer = answer.replace('【', '')
answer = answer.replace('】', '')
# 去除《》符号不全的句子中的《》
if (answer.count('《') != answer.count('》')):
answer = answer.replace('《', '')
answer = answer.replace('》', '')
# 去除‘去),张三你好’类型的错误
delCStr = '《》()&%¥#@!{}【】'
if (answer[2] in delCStr):
answer = answer[2:]
# 去除\n \\n \u3000\u30005 \xa0 \r "
answer = re.sub(r'\\u3\d{3,4}', '', answer)
answer = re.sub(r'\\n', '', answer)
answer = re.sub(r'\\xa0', '', answer)
answer = re.sub(r'\\r', '', answer)
answer = re.sub(r'"', '', answer)
#去除电话
answer = re.sub(r'电话:', '', answer)
answer = re.sub(r'电话\d{11}', '', answer) # 电话13466352858
answer = re.sub(r'电话\d{3}\-\d{8}', '', answer) # 电话010-68812000
answer = re.sub(r'\d{3}\-\d{8}', '', answer) # 010-68812000
answer = re.sub(r'0\d{10}', '', answer) # 01065083250
answer = re.sub(r'1[3458]\d{9}', '', answer) # 手机号码
answer = re.sub(r'[一二三四五六七八九零]{11,}', '', answer) # 一三八二零二八八一二六
# 去除QQ
answer = re.sub(r'QQ:[1-9]\d{4,10}', '', answer)
answer = re.sub(r'qq:[1-9]\d{4,10}', '', answer)
answer = re.sub(r'QQ[1-9]\d{4,10}', '', answer)
# 邮箱
answer = re.sub(r'[0-9]*@.*.com', '', answer)
# answer = re.sub(r'[1-9]\d{4,10}', '', answer)
# 去除邮箱
answer = re.sub(r'E-mail:', '', answer)
answer = re.sub(r'email:', '', answer)
answer = re.sub(r'电子邮箱', '', answer)
answer = re.sub(r'邮箱', '', answer)
answer = re.sub(r'[\w\.-]+@([\w]+\.)+[a-z]{2,3}', '', answer)
# 去除网址
answer = re.sub(r'个人网址:', '', answer)
answer = re.sub(r'网址:', '', answer)
answer = re.sub('[http|https|www].*[com|cn|html|0-9|a-z]', '', answer)
# http://www51wfcom/zxzx/ZixunMainViewjsp?ID=bd84d23a2ec1fd73012ec2a422e40044
answer = re.sub(r'[a-zA-Z]*', '', answer)
# 特殊符号
answer = re.sub(r'&.t', '', answer)
answer = re.sub(r'×', '', answer)
# 去除人名
address = ''
ans_list = pseg.cut(answer)
for word in ans_list:
# 去除不含国字的地名
'''
if (word.flag == 'ns' and not (word.word.find('国') > -1)):
keys = word.word
answer = re.sub(keys, '', answer)
continue
'''
# 去除姓名:nr 人名,nr1 汉语姓氏,nr2 汉语名字
if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr':
keys = word.word
answer = re.sub(keys, '', answer)
# 去除重复标点符号
for d in duels:
while d in answer:
answer = answer.replace(d, d[0])
# 去除开头标点符号
if answer[0]!='《':
answer = pattern1.sub('', answer)
# 结尾加句号
answer = pattern2.sub('。', answer)
if pattern2.search(answer) is None:
answer = answer + '。'
else:
answer = pattern2.sub('。', answer)
return answer
if __name__=='__main__':
# lsi_model(800)
tst=test()
while True:
text=input('输入问题:')
st=time.time()
tst.request(text)
end=time.time()
print(end-st)