文本相似度之LSI

阅读更多

1.VSM简介

     空间向量模型VSM,是将文本表示成数值表示的向量。在使用VSM做文本相似度计算时,其基本步骤是:

    1)将文本分词,提取特征词s:(t1,t2,t3,t4)

    2)将特征词用权重表示,从而将文本表示成数值向量s:(w1,w2,w3,w4),权重表示的方式一般使用tfidf

    3)计算文本向量间的余弦值,判断文本间的相似度

缺点:空间向量模型以词袋为基础,没有考虑词与词间的关系,近义词等。

 

 2.LSI介绍

     潜在语义索引(Latent Semantic Indexing,以下简称LSI),有的文章也叫Latent Semantic  Analysis(LSA)。是一种简单实用的主题模型。LSI是基于奇异值分解(SVD)的方法来得到文本的主题的。其推导过程如下图:(参考:http://www.cnblogs.com/pinard/p/6251584.html ,https://www.cnblogs.com/pinard/p/6805861.html)

     
文本相似度之LSI_第1张图片

 

使用LSI计算文本相似度的基本过程:

    1).和VSM一样,得到文本的tfidf值的向量表示 

    2).做SVD分解,得到Uk词和词义之间的相关性。Vk文本和主题的相关性。

    3).利用文本主题矩阵计算文本的相似度(通过余弦值)

 

3.应用

1.准备数据

    数据分为两部分(Idx2ID,questionList)

   idx2ID:保存数据库中的ID及ID在文件中的下标,例如({0:100,1:101,2:102} 其中0,1,2为下标;100,102,103为数据库中的ID值)

  questionList:保存的是数据库中问题切词后的二维数组,例如[['我','跑路','了'],['你',‘很’,‘不开心’]]

  其中questionList中的数组下标与idx2ID中的下标一 一对应,及数据库中的保存记录为:(id:100,question:‘我跑路了’)

 

 

import re
from tqdm import tqdm
import os
import pickle
import jieba
import pymysql
import pandas as pd 

'''
从数据库获取训练数据
'''

#清洗数据,去掉停用词并进行jieba分词
stwlist = pickle.load(open('stop_word.pkl','rb'))
def getTrainData(sentence):
    sentence=re.sub(r'[a-zA-Z0-9]+','', sentence)
    wordList = jieba.cut(sentence.strip(), cut_all=False)
    new_wordList = []
    for word in wordList:
        if word not in stwlist:
            new_wordList.append(word)
    return new_wordList

def get_ID_Question():
    aldb = pymysql.connect(host='xx.xx.xx.xx', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8')
    IdList = []
    QuestionList = []
    with pymysql.cursors.SSCursor(aldb) as cursor:
        cursor.execute('select id,question from test')
        while True:
            rs = cursor.fetchone()
            if not rs:
                break
            id= rs[0]
            question = rs[1]
            IdList.append(id)
            QuestionList.append(getTrainData(question))
    #保存数据库中id及id在文件中对应的下标
    Idx2ID = {}
    for index, ID in enumerate(IdList):
        Idx2ID[index] = ID

    return Idx2ID,QuestionList

def save_ID_Question():
    Idx2ID_path = './cache/Idx2ID.pkl'
    QuesttionData_path='./cache/traindata.pkl'

    Idx2ID, QuestionList= get_ID_Question()

    #保存数据
    pickle.dump(Idx2ID, open(Idx2ID_path,'wb'))
    pickle.dump(QuestionList, open(QuesttionData_path, 'wb'))
    print('save data result!')

if __name__=='__main__':
    save_ID_Question()

 2.训练模型

#建立lsi模型,并将生成的LSI模型保存
def lsi_model(num_topics):
    lsi_model_dictionary = './cache/lsi_dictionary.pkl'
    lsi_model_lsi = './cache/model.lsi'
    lsi_model_index = './cache/lsimodel.index'
    cut_question_path = r'./cache/traindata.pkl'
    if os.path.exists(lsi_model_dictionary):
        dictionary = pickle.load(open(lsi_model_dictionary,'rb'))
        lsi = models.LsiModel.load(lsi_model_lsi)
        index = similarities.MatrixSimilarity.load(lsi_model_index)
    else:
        textdata1 = pickle.load(open(cut_question_path, 'rb'))
        #生成词典
        dictionary = corpora.Dictionary(textdata1)
        print('dictionary prepared!')
        corpus = [dictionary.doc2bow(text) for text in textdata1]
        print('corpus prepared!')
        #计算tfidf
        tfidf = models.TfidfModel(corpus)
        print('tfidf prepared!')
        corpus_tfidf = tfidf[corpus]
        #训练LSI
        lsi = models.LsiModel(corpus_tfidf, id2word=dictionary, num_topics = num_topics)
        print('lsi model prepared!')
        corpus_lsi = lsi[corpus_tfidf]
        #生成引用
        index = similarities.MatrixSimilarity(corpus_lsi)
        print('index prepared!')
        pickle.dump(dictionary, open(lsi_model_dictionary,'wb'))
        lsi.save(lsi_model_lsi)
        index.save(lsi_model_index)
    return dictionary, lsi, index

 3.测试

class test(object):
    def __init__(self):
        #加载模型数据
        self.dictionary = pickle.load(open('./cache/lsi_dictionary.pkl','rb'))
        self.lsi = models.LsiModel.load('./cache/model.lsi')
        self.index = similarities.MatrixSimilarity.load('./cache/lsimodel.index')
        self.stop_word = pickle.load(open('./stop_word.pkl','rb'))
        self.idx2ID = pickle.load(open('./cache/idx2ID.pkl','rb'))
    #清洗输入问题,并切词
    def deal_query(self,query):
        query = re.sub(r'[a-zA-Z0-9]+','',query)
        query_ques = jieba.cut(query.strip(), cut_all=False)
        new_wordList = []
        for word in query_ques:
            if word not in self.stop_word:
                new_wordList.append(word)
        return new_wordList

    #清洗匹配到的问题
    def clean_question(self, question):
        question = re.sub('律师', '', question)
        ques = pseg.cut(question)
        for word in ques:
            if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr':
                keys = word.word
                question = re.sub(keys, '', question)
        return question

   #从数据库获取匹配到的问答对
    def get_answer(self,idx_set):
        question_set = []
        answer_set = []
        db = pymysql.connect(host='xx.xx.xx.x', port=3306, user='user1', passwd='123456', db="testdb", charset='utf8')
        for i in idx_set:
            ID = self.idx2ID[i]
            # print(ID)
            cursor = db.cursor()
            cursor.execute('select question,answer from test where id=%d;' %ID)
            db.commit()
            result = cursor.fetchone()
            if result:
                question = self.clean_question(result[0])
                question_set.append(question)
                answer = self.clean_answer(result[1])
                answer_set.append(answer)
        db.close()
        return question_set,answer_set
    #匹配
    def match(self,query):
        query_ques = self.deal_query(query)
        query_bow = self.dictionary.doc2bow(query_ques)
        query_lsi = self.lsi[query_bow]
        #计算余弦相似度,获取前3个最匹配的值
        sims = self.index[query_lsi]
        sort_sims = sorted(enumerate(sims), key=lambda item: -item[1])
        idx_set = [i[0] for i in sort_sims[:3]]
        sim_value = [i[1] for i in sort_sims[:3]]

        question_set,answer_set = self.get_answer(idx_set)

        return question_set,answer_set,sim_value

    def request(self,query):
        import json
        question_set,answer_set,sim_value = self.match(query)
        #相似度小于0.65的问题丢弃
        if sim_value[0] < 0.65:
            print("we have no match answer")
            return None
        i=0
        question_answer={}
        while i 2 and len(a) > 2:
                question_answer[q]=a
            if len(question_answer)>3:
                break
            i+=1
        # print_result2 = {k: v for k, v in question_answer.items() }
        print(len(question_answer))
        print(question_answer)
        return json.dumps(question_answer, ensure_ascii=False)

    #清洗答案数据
    def clean_answer(self,answer):
        # 将如果有多个回答,取长度最长的一个回答
        answer_sent = answer.strip().split('NEXT_ANS')
        if len(answer_sent) > 1:
            answer_len = [len(ans) for ans in answer_sent]
            inx = answer_len.index(max(answer_len))
            answer = answer_sent[inx]
        answer = re.sub('[,]+', ',', answer)
        answer = re.sub('[。]+', '。', answer)
        answer = re.sub('[.]+', '', answer)
        answer = re.sub('[,]+', ',', answer)

        from zhon.hanzi import punctuation
        # 去除 '\nNEXT_ANS'
        answer = answer.replace('NEXT_ANS', '')
        duels = [x + y for x in punctuation
                 for y in punctuation]
        pattern1 = re.compile('^[%s]' % punctuation)
        pattern2 = re.compile('[%s]$' % punctuation)
        pattern3 = re.compile('\[.*?\]')

        # 去除[]中的内容
        answer = pattern3.sub('', answer)

        # 去除空格
        answer = answer.split()
        answer = ''.join(answer)

        # 去除【】符号
        answer = answer.replace('【', '')
        answer = answer.replace('】', '')

        # 去除《》符号不全的句子中的《》
        if (answer.count('《') != answer.count('》')):
            answer = answer.replace('《', '')
            answer = answer.replace('》', '')

        # 去除‘去),张三你好’类型的错误
        delCStr = '《》()&%¥#@!{}【】'
        if (answer[2] in delCStr):
            answer = answer[2:]

        # 去除\n \\n \u3000\u30005 \xa0 \r "
        answer = re.sub(r'\\u3\d{3,4}', '', answer)
        answer = re.sub(r'\\n', '', answer)
        answer = re.sub(r'\\xa0', '', answer)
        answer = re.sub(r'\\r', '', answer)
        answer = re.sub(r'"', '', answer)
        #去除电话
        answer = re.sub(r'电话:', '', answer)
        answer = re.sub(r'电话\d{11}', '', answer)  # 电话13466352858
        answer = re.sub(r'电话\d{3}\-\d{8}', '', answer)  # 电话010-68812000
        answer = re.sub(r'\d{3}\-\d{8}', '', answer)  # 010-68812000
        answer = re.sub(r'0\d{10}', '', answer)  # 01065083250
        answer = re.sub(r'1[3458]\d{9}', '', answer)  # 手机号码
        answer = re.sub(r'[一二三四五六七八九零]{11,}', '', answer)  # 一三八二零二八八一二六

        # 去除QQ
        answer = re.sub(r'QQ:[1-9]\d{4,10}', '', answer)
        answer = re.sub(r'qq:[1-9]\d{4,10}', '', answer)
        answer = re.sub(r'QQ[1-9]\d{4,10}', '', answer)
        # 邮箱
        answer = re.sub(r'[0-9]*@.*.com', '', answer)
        # answer = re.sub(r'[1-9]\d{4,10}', '', answer)
        # 去除邮箱
        answer = re.sub(r'E-mail:', '', answer)
        answer = re.sub(r'email:', '', answer)
        answer = re.sub(r'电子邮箱', '', answer)
        answer = re.sub(r'邮箱', '', answer)
        answer = re.sub(r'[\w\.-]+@([\w]+\.)+[a-z]{2,3}', '', answer)
        # 去除网址
        answer = re.sub(r'个人网址:', '', answer)
        answer = re.sub(r'网址:', '', answer)
        answer = re.sub('[http|https|www].*[com|cn|html|0-9|a-z]', '', answer)
        # http://www51wfcom/zxzx/ZixunMainViewjsp?ID=bd84d23a2ec1fd73012ec2a422e40044

        answer = re.sub(r'[a-zA-Z]*', '', answer)

        # 特殊符号
        answer = re.sub(r'&.t', '', answer)
        answer = re.sub(r'×', '', answer)
        # 去除人名
        address = ''
        ans_list = pseg.cut(answer)
        for word in ans_list:
            # 去除不含国字的地名
            '''
            if (word.flag == 'ns' and not (word.word.find('国') > -1)):
                keys = word.word
                answer = re.sub(keys, '', answer)
                continue
            '''
            # 去除姓名:nr 人名,nr1 汉语姓氏,nr2 汉语名字
            if word.flag == 'nr1' or word.flag == 'nr2' or word.flag == 'nr':
                keys = word.word
                answer = re.sub(keys, '', answer)
        # 去除重复标点符号
        for d in duels:
            while d in answer:
                answer = answer.replace(d, d[0])

        # 去除开头标点符号
        if answer[0]!='《':
            answer = pattern1.sub('', answer)

        # 结尾加句号
        answer = pattern2.sub('。', answer)
        if pattern2.search(answer) is None:
            answer = answer + '。'
        else:
            answer = pattern2.sub('。', answer)

        return answer

if __name__=='__main__':
    # lsi_model(800)
    tst=test()
    while True:
        text=input('输入问题:')
        st=time.time()
        tst.request(text)
        end=time.time()
        print(end-st)

 

 

  • 文本相似度之LSI_第2张图片
  • 大小: 473.6 KB
  • 查看图片附件

你可能感兴趣的:(LSI)