Doc2Vec计算句子相似度

X_train 就是自己的训练语料（由 get_datasest 构建的 TaggedDocument 列表）。

"""
date:2018_7_25
doc2vec计算句子相似性
"""
# coding:utf-8

import sys
import time
import csv
import glob
import gensim
import sklearn
import numpy as np
import jieba.posseg as pseg
import jieba

from gensim.models.doc2vec import Doc2Vec, LabeledSentence

# Short alias for gensim's TaggedDocument; the spelling is kept as-is because
# other code in this file refers to the class by this name.
from gensim.models.doc2vec import TaggedDocument as TaggededDocument

def loadPoorEnt(path2='G:/project/sentimation_analysis/data/stopwords.csv'):
    """Load a stop-word list, one entry per line.

    Args:
        path2: Path to a UTF-8 encoded file whose lines are stop words.

    Returns:
        A list of the file's lines with surrounding whitespace stripped, in
        file order (blank lines become empty strings).
    """
    # `with` guarantees the file handle is closed once the list is built.
    with open(path2, encoding='UTF-8') as csvfile:
        return [line.strip() for line in csvfile]
# Stop-word list loaded once, at import time, from loadPoorEnt's default path;
# cut() filters tokens against it. Importing this module raises if that file
# does not exist.
stop_words = loadPoorEnt()

def cut(data):
    """Segment *data* with jieba POS tagging and drop stop words.

    Args:
        data: The text to segment.

    Returns:
        A one-element list wrapping the list of kept words, e.g.
        ``[['word1', 'word2']]``. The outer list is preserved because callers
        index the result with ``[0]``.
    """
    # item.flag (the POS tag) is available here if a noun/verb/adjective
    # filter is wanted; currently only stop words are removed.
    words = [item.word for item in pseg.cut(data) if item.word not in stop_words]
    return [words]


def get_all_content(pattern=r'D:/GFZQ/GFZQ/xuesu2018/xuesu/*.csv'):
    """Return the paths of all input CSV files matching *pattern*.

    Args:
        pattern: Glob pattern to expand; defaults to the corpus directory.

    Returns:
        A sorted list of matching paths. glob's own ordering depends on the
        file system, and callers pick files by index, so sorting makes that
        selection reproducible.
    """
    return sorted(glob.glob(pattern))

def get_wenben(path):
    """Read every row of the CSV file at *path*.

    Args:
        path: Path to a UTF-8 encoded CSV file.

    Returns:
        A list of rows, each a list of strings. Rows are read eagerly so the
        file handle can be closed before returning.
    """
    # The csv module requires newline='' so quoted fields that contain line
    # breaks are parsed correctly.
    with open(path, 'r', encoding='UTF-8', newline='') as csvfile:
        return list(csv.reader(csvfile))

def get_QA(wenben):
    """Split CSV rows into their Q (column 1) and A (column 2) fields.

    Args:
        wenben: An iterable of rows, such as the result of ``get_wenben``.

    Returns:
        A tuple ``(all, Q_all, A_all)``. ``Q_all`` and ``A_all`` hold column 1
        and column 2 of each row, in row order. ``all`` is ``Q_all`` followed
        by ``A_all``.
    """
    Q_all = []
    A_all = []
    for QA in wenben:
        # csv.reader yields [] for blank lines; skip them instead of failing
        # with IndexError.
        if not QA:
            continue
        Q_all.append(QA[1])
        A_all.append(QA[2])
    return Q_all + A_all, Q_all, A_all


def get_datasest(all_csv):
    """Build the Doc2Vec training corpus from per-file sentence lists.

    Args:
        all_csv: A list with one entry per input file, each entry being the
            list of sentences extracted from that file.

    Returns:
        A list of ``TaggededDocument`` objects, one per sentence, in order.
        Each document's words are the output of ``cut`` for that sentence.
        Its single tag is the sentence's running index, so ``x_train[tag]``
        recovers the document later.
    """
    print(len(all_csv))
    # Flatten file -> sentences; the position in this flattened order becomes
    # the document tag.
    all_sent = [sent for file_one in all_csv for sent in file_one]
    # cut() wraps its token list in an outer list, hence the [0].
    return [TaggededDocument(cut(text)[0], tags=[i]) for i, text in enumerate(all_sent)]

def getVecs(model, corpus, size):
    """Stack the stored document vectors for *corpus* into one matrix.

    Args:
        model: A trained Doc2Vec model; ``model.docvecs`` is indexed by tag.
        corpus: Documents whose first tag selects the vector to fetch.
        size: Vector length; each vector is reshaped to ``(1, size)``.

    Returns:
        An array of shape ``(len(corpus), size)``, with rows in corpus order.
    """
    rows = []
    for doc in corpus:
        first_tag = doc.tags[0]
        rows.append(np.array(model.docvecs[first_tag].reshape(1, size)))
    return np.concatenate(rows)

def train(x_train, size=200, epoch_num=1, epochs=70,
          model_path='G:/project/sentimation_analysis/data/conference.model'):
    """Train a Doc2Vec model on *x_train*, save it, and return it.

    Args:
        x_train: Training corpus of tagged documents (see ``get_datasest``).
        size: Dimensionality of the document vectors.
        epoch_num: Unused; kept only so existing callers keep working.
        epochs: Number of training passes over the corpus.
        model_path: Where the trained model is saved.

    Returns:
        The trained ``Doc2Vec`` model.
    """
    # Build the model without documents. Passing the corpus to the constructor
    # already trains it, and a second train() call would then train it again.
    # NOTE: `size=` is the gensim 3.x keyword (renamed `vector_size` in 4.0).
    model_dm = Doc2Vec(min_count=1, window=3, size=size, sample=1e-3, negative=5, workers=4)
    model_dm.build_vocab(x_train)
    model_dm.train(x_train, total_examples=model_dm.corpus_count, epochs=epochs)
    model_dm.save(model_path)
    return model_dm

def get_csvfile(max_files=28):
    """Load the Q/A sentences of up to *max_files* input CSV files.

    Args:
        max_files: Maximum number of files to parse. If fewer files exist,
            all of them are parsed.

    Returns:
        A list with one entry per parsed file: that file's column-1 values
        followed by its column-2 values (the first item returned by
        ``get_QA``).
    """
    all_files = get_all_content()
    # Never index past the files that actually exist.
    length = min(max_files, len(all_files))
    print("统计了%d家公司的情感词" % length)
    all_csv = []
    for i, file_one in enumerate(all_files[:length]):
        print("正在解析第%d家公司" % i)
        all_text, _, _ = get_QA(get_wenben(file_one))
        all_csv.append(all_text)
    return all_csv

def stest(test_text="我们是好孩子",
          model_path='G:/project/sentimation_analysis/data/conference.model',
          topn=10):
    """Find the training documents most similar to *test_text*.

    Args:
        test_text: Raw (unsegmented) query sentence.
        model_path: Saved model to load; defaults to the path ``train`` saves to.
        topn: Number of neighbours to return.

    Returns:
        A list of ``(tag, similarity)`` pairs from
        ``model.docvecs.most_similar``, most similar first.
    """
    model_dm = Doc2Vec.load(model_path)
    # infer_vector expects a list of tokens. Passing the raw sentence as one
    # element would treat it as a single (unknown) word, so tokenize it with
    # the same cut() used to build the training corpus.
    test_words = cut(test_text)[0]
    inferred_vector_dm = model_dm.infer_vector(test_words)
    return model_dm.docvecs.most_similar([inferred_vector_dm], topn=topn)


if __name__ == '__main__':

    # time.clock() was removed in Python 3.8; perf_counter() replaces it for
    # measuring elapsed time.
    start = time.perf_counter()
    all_csv = get_csvfile()
    x_train = get_datasest(all_csv)
    model_dm = train(x_train)
    sims = stest()
    # Document tags are running sentence indices, so they index x_train directly.
    for count, sim in sims:
        sentence = x_train[count]
        print(sentence, sim)
    print("elapsed: %.2fs" % (time.perf_counter() - start))

 

你可能感兴趣的:(NLP)