1.VSM简介
空间向量模型VSM,是将文本表示成数值表示的向量。在使用VSM做文本相似度计算时,其基本步骤是:
1)将文本分词,提取特征词s:(t1,t2,t3,t4)
2)将特征词用权重表示,从而将文本表示成数值向量s:(w1,w2,w3,w4),权重表示的方式一般使用tfidf
3)计算文本向量间的余弦值,判断文本间的相似度
缺点:空间向量模型以词袋为基础,没有考虑词与词间的关系,近义词等。
2.LSI介绍
潜在语义索引(Latent Semantic Indexing,以下简称LSI),有的文章也叫Latent Semantic Analysis(LSA)。是一种简单实用的主题模型。LSI是基于奇异值分解(SVD)的方法来得到文本的主题的。其推导过程如下图:(参考:http://www.cnblogs.com/pinard/p/6251584.html ,https://www.cnblogs.com/pinard/p/6805861.html)
使用LSI计算文本相似度的基本过程:
1).和VSM一样,得到文本的tfidf值的向量表示
2).做SVD分解,得到Uk词和词义之间的相关性。Vk文本和主题的相关性。
3).利用文本主题矩阵计算文本的相似度(通过余弦值)
3.应用
1.准备数据
数据分为两部分(Idx2ID,questionList)
idx2ID:保存数据库中的ID及ID在文件中的下标,例如({0:100,1:101,2:102} 其中0,1,2为下标;100,102,103为数据库中的ID值)
questionList:保存的是数据库中问题切词后的二维数组,例如[['我','跑路','了'],['你',‘很’,‘不开心’]]
其中questionList中的数组下标与idx2ID中的下标一 一对应,及数据库中的保存记录为:(id:100,question:‘我跑路了’)
import re from tqdm import tqdm import os import pickle import jieba import pymysql import pandas as pd ''' 从数据库获取训练数据 ''' #清洗数据,去掉停用词并进行jieba分词 stwlist = pickle.load(open('stop_word.pkl','rb')) def getTrainData(sentence): sentence=re.sub(r'[a-zA-Z0-9]+','', sentence) wordList = jieba.cut(sentence.strip(), cut_all=False) new_wordList = [] for word in wordList: if word not in stwlist: new_wordList.append(word) return new_wordList def get_ID_Question(): aldb = pymysql.connect(host='xx.xx.xx.xx', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8') IdList = [] QuestionList = [] with pymysql.cursors.SSCursor(aldb) as cursor: cursor.execute('select id,question from test') while True: rs = cursor.fetchone() if not rs: break id= rs[0] question = rs[1] IdList.append(id) QuestionList.append(getTrainData(question)) #保存数据库中id及id在文件中对应的下标 Idx2ID = {} for index, ID in enumerate(IdList): Idx2ID[index] = ID return Idx2ID,QuestionList def save_ID_Question(): Idx2ID_path = './cache/Idx2ID.pkl' QuesttionData_path='./cache/traindata.pkl' Idx2ID, QuestionList= get_ID_Question() #保存数据 pickle.dump(Idx2ID, open(Idx2ID_path,'wb')) pickle.dump(QuestionList, open(QuesttionData_path, 'wb')) print('save data result!') if __name__=='__main__': save_ID_Question()
2.训练模型
#建立lsi模型,并将生成的LSI模型保存 def lsi_model(num_topics): lsi_model_dictionary = './cache/lsi_dictionary.pkl' lsi_model_lsi = './cache/model.lsi' lsi_model_index = './cache/lsimodel.index' cut_question_path = r'./cache/traindata.pkl' if os.path.exists(lsi_model_dictionary): dictionary = pickle.load(open(lsi_model_dictionary,'rb')) lsi = models.LsiModel.load(lsi_model_lsi) index = similarities.MatrixSimilarity.load(lsi_model_index) else: textdata1 = pickle.load(open(cut_question_path, 'rb')) #生成词典 dictionary = corpora.Dictionary(textdata1) print('dictionary prepared!') corpus = [dictionary.doc2bow(text) for text in textdata1] print('corpus prepared!') #计算tfidf tfidf = models.TfidfModel(corpus) print('tfidf prepared!') corpus_tfidf = tfidf[corpus] #训练LSI lsi = models.LsiModel(corpus_tfidf, id2word=dictionary, num_topics = num_topics) print('lsi model prepared!') corpus_lsi = lsi[corpus_tfidf] #生成引用 index = similarities.MatrixSimilarity(corpus_lsi) print('index prepared!') pickle.dump(dictionary, open(lsi_model_dictionary,'wb')) lsi.save(lsi_model_lsi) index.save(lsi_model_index) return dictionary, lsi, index
3.测试
class test(object): def __init__(self): #加载模型数据 self.dictionary = pickle.load(open('./cache/lsi_dictionary.pkl','rb')) self.lsi = models.LsiModel.load('./cache/model.lsi') self.index = similarities.MatrixSimilarity.load('./cache/lsimodel.index') self.stop_word = pickle.load(open('./stop_word.pkl','rb')) self.idx2ID = pickle.load(open('./cache/idx2ID.pkl','rb')) #清洗输入问题,并切词 def deal_query(self,query): query = re.sub(r'[a-zA-Z0-9]+','',query) query_ques = jieba.cut(query.strip(), cut_all=False) new_wordList = [] for word in query_ques: if word not in self.stop_word: new_wordList.append(word) return new_wordList #清洗匹配到的问题 def clean_question(self, question): question = re.sub('律师', '', question) ques = pseg.cut(question) for word in ques: if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr': keys = word.word question = re.sub(keys, '', question) return question #从数据库获取匹配到的问答对 def get_answer(self,idx_set): question_set = [] answer_set = [] db = pymysql.connect(host='xx.xx.xx.x', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8') for i in idx_set: ID = self.idx2ID[i] # print(ID) cursor = db.cursor() cursor.execute('select question,answer from test where id=%d;' %ID) db.commit() result = cursor.fetchone() if result: question = self.clean_question(result[0]) question_set.append(question) answer = self.clean_answer(result[1]) answer_set.append(answer) db.close() return question_set,answer_set #匹配 def match(self,query): query_ques = self.deal_query(query) query_bow = self.dictionary.doc2bow(query_ques) query_lsi = self.lsi[query_bow] #计算余弦相似度,获取前3个最匹配的值 sims = self.index[query_lsi] sort_sims = sorted(enumerate(sims), key=lambda item: -item[1]) idx_set = [i[0] for i in sort_sims[:3]] sim_value = [i[1] for i in sort_sims[:3]] question_set,answer_set = self.get_answer(idx_set) return question_set,answer_set,sim_value def request(self,query): import json question_set,answer_set,sim_value = self.match(query) #相似度小于0.65的问题丢弃 if sim_value[0] < 0.65: print("we have no match answer") return None i=0 question_answer={} while i2 and len(a) > 2: question_answer[q]=a if len(question_answer)>3: break i+=1 # print_result2 = {k: v for k, v in question_answer.items() } print(len(question_answer)) print(question_answer) return json.dumps(question_answer, ensure_ascii=False) #清洗答案数据 def clean_answer(self,answer): # 将如果有多个回答,取长度最长的一个回答 answer_sent = answer.strip().split('NEXT_ANS') if len(answer_sent) > 1: answer_len = [len(ans) for ans in answer_sent] inx = answer_len.index(max(answer_len)) answer = answer_sent[inx] answer = re.sub('[,]+', ',', answer) answer = re.sub('[。]+', '。', answer) answer = re.sub('[.]+', '', answer) answer = re.sub('[,]+', ',', answer) from zhon.hanzi import punctuation # 去除 '\nNEXT_ANS' answer = answer.replace('NEXT_ANS', '') duels = [x + y for x in punctuation for y in punctuation] pattern1 = re.compile('^[%s]' % punctuation) pattern2 = re.compile('[%s]$' % punctuation) pattern3 = re.compile('\[.*?\]') # 去除[]中的内容 answer = pattern3.sub('', answer) # 去除空格 answer = answer.split() answer = ''.join(answer) # 去除【】符号 answer = answer.replace('【', '') answer = answer.replace('】', '') # 去除《》符号不全的句子中的《》 if (answer.count('《') != answer.count('》')): answer = answer.replace('《', '') answer = answer.replace('》', '') # 去除‘去),张三你好’类型的错误 delCStr = '《》()&%¥#@!{}【】' if (answer[2] in delCStr): answer = answer[2:] # 去除\n \\n \u3000\u30005 \xa0 \r " answer = re.sub(r'\\u3\d{3,4}', '', answer) answer = re.sub(r'\\n', '', answer) answer = re.sub(r'\\xa0', '', answer) answer = re.sub(r'\\r', '', answer) answer = re.sub(r'"', '', answer) #去除电话 answer = re.sub(r'电话:', '', answer) answer = re.sub(r'电话\d{11}', '', answer) # 电话13466352858 answer = re.sub(r'电话\d{3}\-\d{8}', '', answer) # 电话010-68812000 answer = re.sub(r'\d{3}\-\d{8}', '', answer) # 010-68812000 answer = re.sub(r'0\d{10}', '', answer) # 01065083250 answer = re.sub(r'1[3458]\d{9}', '', answer) # 手机号码 answer = re.sub(r'[一二三四五六七八九零]{11,}', '', answer) # 一三八二零二八八一二六 # 去除QQ answer = re.sub(r'QQ:[1-9]\d{4,10}', '', answer) answer = re.sub(r'qq:[1-9]\d{4,10}', '', answer) answer = re.sub(r'QQ[1-9]\d{4,10}', '', answer) # 邮箱 answer = re.sub(r'[0-9]*@.*.com', '', answer) # answer = re.sub(r'[1-9]\d{4,10}', '', answer) # 去除邮箱 answer = re.sub(r'E-mail:', '', answer) answer = re.sub(r'email:', '', answer) answer = re.sub(r'电子邮箱', '', answer) answer = re.sub(r'邮箱', '', answer) answer = re.sub(r'[\w\.-]+@([\w]+\.)+[a-z]{2,3}', '', answer) # 去除网址 answer = re.sub(r'个人网址:', '', answer) answer = re.sub(r'网址:', '', answer) answer = re.sub('[http|https|www].*[com|cn|html|0-9|a-z]', '', answer) # http://www51wfcom/zxzx/ZixunMainViewjsp?ID=bd84d23a2ec1fd73012ec2a422e40044 answer = re.sub(r'[a-zA-Z]*', '', answer) # 特殊符号 answer = re.sub(r'&.t', '', answer) answer = re.sub(r'×', '', answer) # 去除人名 address = '' ans_list = pseg.cut(answer) for word in ans_list: # 去除不含国字的地名 ''' if (word.flag == 'ns' and not (word.word.find('国') > -1)): keys = word.word answer = re.sub(keys, '', answer) continue ''' # 去除姓名:nr 人名,nr1 汉语姓氏,nr2 汉语名字 if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr': keys = word.word answer = re.sub(keys, '', answer) # 去除重复标点符号 for d in duels: while d in answer: answer = answer.replace(d, d[0]) # 去除开头标点符号 if answer[0]!='《': answer = pattern1.sub('', answer) # 结尾加句号 answer = pattern2.sub('。', answer) if pattern2.search(answer) is None: answer = answer + '。' else: answer = pattern2.sub('。', answer) return answer if __name__=='__main__': # lsi_model(800) tst=test() while True: text=input('输入问题:') st=time.time() tst.request(text) end=time.time() print(end-st)