gensim 官网: https://radimrehurek.com/gensim/tutorial.html
训练tfidf, lsi, lda, doc2vec等4种模型向量化文档
输入文件每行两列，以制表符分隔: 标题 \t 分词结果 (词与词之间以空格分隔，文件编码为 GBK)
do_train_model.py 训练模型
#! /usr/bin/env python
#encoding: utf-8
import sys
import os
import re
import logging
import time
from six import iteritems
from gensim import corpora, models, similarities
from gensim.models.doc2vec import LabeledSentence
# Write INFO-and-above log records to 1.log, truncating the file on every run.
logging.basicConfig(level=logging.INFO,
                    filemode='w',
                    filename='1.log',
                    format='%(asctime)s : %(levelname)s : %(message)s')
# Encoding of the input corpus file and of the stop-word file.
g_charset = 'gbk'
# Separator between tokens in the second column: one or more ASCII spaces.
g_pattern = re.compile(' +')
class MyCorpus(object):
    '''
    Streaming corpus over a tab-separated file of "title \t tokens" lines.

    Iterating yields exactly one token list per input line: the second
    column, decoded with g_charset (undecodable bytes dropped) and split on
    runs of spaces (g_pattern). The file is re-opened on every iteration, so
    the corpus can be iterated more than once (read_corpus does so twice).
    A line without a tab raises IndexError.
    '''
    def __init__(self, fname):
        self.fname = fname  # path of the "title \t tokens" input file

    def __iter__(self):
        # Binary mode + explicit decode behaves the same on Python 2 and 3;
        # `with` closes the file as soon as iteration ends.
        with open(self.fname, 'rb') as fin:
            for line in fin:
                # Strip '\r' as well so CRLF files don't glue it to the
                # last token.
                fields = line.rstrip(b'\r\n').split(b'\t')
                # 'ignore' matches MyLabelCorpus and read_stop_file: a single
                # bad byte should not abort the whole run.
                text = fields[1].decode(g_charset, 'ignore')
                # Leading/trailing spaces make the regex split emit '' tokens;
                # drop them so '' never becomes a dictionary entry. One list is
                # still yielded per line, keeping vector i aligned with line i.
                yield [tok for tok in g_pattern.split(text) if tok]
class MyLabelCorpus(object):
    '''
    Streaming Doc2Vec corpus over a tab-separated "title \t tokens" file.

    Yields one LabeledSentence per input line. `words` is the second column
    decoded with g_charset (undecodable bytes dropped) and split on
    whitespace. `tags` is [line_number], with 0-based line numbers.
    A line without a tab raises IndexError.
    '''
    def __init__(self, fname):
        self.fname = fname  # path of the "title \t tokens" input file

    def __iter__(self):
        # NOTE(review): newer gensim releases deprecate LabeledSentence in
        # favour of TaggedDocument (same words/tags fields) -- confirm against
        # the installed gensim version before switching.
        # Binary mode + explicit decode works identically on Python 2 and 3;
        # `with` closes the file once iteration ends.
        with open(self.fname, 'rb') as fin:
            for lineno, line in enumerate(fin):
                fields = line.rstrip(b'\r\n').split(b'\t')
                words = fields[1].decode(g_charset, 'ignore').split()
                yield LabeledSentence(words=words, tags=[lineno])
def train_tfidf(corpus, dictionary, model_file, vec_file):
    '''
    Fit a TF-IDF model on a bag-of-words corpus and export the weighted vectors.

    corpus:     bag-of-words corpus (as built by read_corpus).
    dictionary: accepted for signature parity with the other trainers; unused.
    model_file: path the TfidfModel is saved to.
    vec_file:   path the TF-IDF document vectors are written to in SVMlight
                format; read them back with corpora.SvmLightCorpus(vec_file).
    '''
    model = models.TfidfModel(corpus)
    model.save(model_file)
    weighted = model[corpus]
    corpora.SvmLightCorpus.serialize(vec_file, weighted)
def train_lsi(corpus, dictionary, model_file, vec_file, num_topics=100):
    '''
    Train an LSI model on the tf-idf weighted corpus and export document vectors.

    corpus:     bag-of-words corpus (as built by read_corpus).
    dictionary: gensim Dictionary, passed to LsiModel as id2word.
    model_file: path the LsiModel is saved to.
    vec_file:   path the per-document LSI vectors are written to (SVMlight).
    num_topics: number of LSI dimensions (default 100).
    '''
    tfidf = models.TfidfModel(corpus)
    corpus_tfidf = tfidf[corpus]
    # NOTE(review): only the LSI model is saved. Projecting *new* documents
    # later also needs this tf-idf model; consider saving it alongside.
    lsi = models.LsiModel(corpus_tfidf, id2word=dictionary,
                          num_topics=num_topics)
    lsi.save(model_file)
    corpora.SvmLightCorpus.serialize(vec_file, lsi[corpus_tfidf])
def train_lda(corpus, dictionary, model_file, vec_file, num_topics=100):
    '''
    Train an LDA topic model and export per-document topic vectors.

    LDA is fitted directly on the bag-of-words counts in `corpus`; no tf-idf
    weighting is applied.

    corpus:     bag-of-words corpus (as built by read_corpus).
    dictionary: gensim Dictionary, passed to the model as id2word.
    model_file: path the LDA model is saved to.
    vec_file:   path the lda[doc] vectors are written to (SVMlight).
    num_topics: number of topics (default 100).
    '''
    # models.LdaModel can be used here instead of LdaMulticore.
    lda = models.LdaMulticore(corpus, id2word=dictionary,
                              num_topics=num_topics, chunksize=2000,
                              passes=50, iterations=50, eval_every=None,
                              workers=8)
    lda.save(model_file)
    corpora.SvmLightCorpus.serialize(vec_file, lda[corpus])
def save_svmlight_format(docvecs, outfile):
    '''
    Write dense document vectors to `outfile` in SVMlight text format.

    Each vector becomes one line "0 1:v1 2:v2 ...": a dummy label 0
    followed by 1-based feature indices with values printed to 6 decimals.

    docvecs: iterable of dense numeric sequences (e.g. model.docvecs).
    outfile: output path; an existing file is overwritten.
    '''
    # open() replaces the Python-2-only file() builtin, and `with` closes
    # the file even if formatting a value raises.
    with open(outfile, 'w') as fout:
        for vec in docvecs:
            feats = " ".join("%d:%.6f" % (i, v) for i, v in enumerate(vec, 1))
            fout.write("0 %s\n" % feats)
def train_doc2vec(infile, model_file, vec_file):
    '''
    Train a Doc2Vec model straight from the "title \t tokens" input file.

    infile:     corpus file, streamed through MyLabelCorpus (one tagged
                document per line).
    model_file: path the trained Doc2Vec model is saved to.
    vec_file:   path model.docvecs is written to in SVMlight format.
    '''
    params = dict(
        size=100,
        window=5,
        min_count=3,
        workers=12,
        hs=1,
        negative=0,
        dbow_words=1,
        iter=40,
    )
    # NOTE(review): `dm` is not set, so gensim's default training mode is
    # used. Per the gensim docs, dbow_words only matters in PV-DBOW mode
    # (dm=0) -- confirm which mode is intended.
    model = models.Doc2Vec(MyLabelCorpus(infile), **params)
    model.save(model_file)
    save_svmlight_format(model.docvecs, vec_file)
def read_stop_file(stop_file, charset=None):
    '''
    Read a stop-word list, one word per line.

    stop_file: path to the stop-word file. A missing file is not an error;
               an empty list is returned instead.
    charset:   encoding of the file; defaults to the module-wide g_charset.
               Undecodable bytes are dropped.

    Returns a list of unicode words with surrounding whitespace stripped.
    Blank lines yield empty strings.
    '''
    if charset is None:
        charset = g_charset
    if not os.path.isfile(stop_file):
        return []
    # Read raw bytes and decode explicitly. This is identical on Python 2;
    # on Python 3, text mode would decode with the locale encoding, and str
    # has no .decode method.
    with open(stop_file, 'rb') as f:
        return [w.strip().decode(charset, 'ignore') for w in f]
def read_corpus(infile, stop_file='stopwords.txt'):
    '''
    Build a filtered dictionary and bag-of-words corpus from `infile`.

    Tokens listed in `stop_file` (if it exists) and tokens whose document
    frequency is <= 1 are removed from the dictionary. Then every document
    is converted to bag-of-words. The number of ids dropped for each reason
    and the remaining vocabulary size are printed.

    infile:    "title \t tokens" file, read through MyCorpus.
    stop_file: stop-word list, one word per line (default 'stopwords.txt'
               relative to the working directory).

    Returns (corpus, dictionary). corpus is an in-memory list with one
    bag-of-words vector per input line, in file order.
    '''
    corp = MyCorpus(infile)
    dictionary = corpora.Dictionary(corp)
    token2id = dictionary.token2id
    stop_ids = [token2id[w] for w in read_stop_file(stop_file) if w in token2id]
    once_ids = [tokenid for tokenid, docfreq in iteritems(dictionary.dfs)
                if docfreq <= 1]
    # A single parenthesised argument prints the same under Python 2's print
    # statement and Python 3's print(). The text matches the original
    # `print "stop_ids: ", n` output, including its two spaces.
    print("stop_ids:  %d" % len(stop_ids))
    print("once_ids:  %d" % len(once_ids))
    dictionary.filter_tokens(stop_ids + once_ids)
    dictionary.compactify()
    print("uniq tokens: %d" % len(dictionary))
    # Second pass over the file: convert each document with the filtered ids.
    corpus = [dictionary.doc2bow(text) for text in corp]
    return corpus, dictionary
def train_model(infile, tag):
    '''
    Train the model selected by `tag` on `infile` and print the elapsed time.

    tag must be one of 'tfidf', 'lsi', 'lda', 'doc2vec'. Any other value
    prints an error and returns without training. Outputs are written next
    to the input file: "<infile>.<tag>.model" and "<infile>.<tag>.vec".
    '''
    if tag not in {"tfidf", 'lsi', 'lda', 'doc2vec'}:
        print("wrong tag: %s" % tag)
        return
    started = time.time()
    prefix = "%s.%s" % (infile, tag)
    model_file = prefix + ".model"
    vec_file = prefix + ".vec"
    if tag == 'doc2vec':
        # Doc2Vec streams the raw file itself; no dictionary/BOW corpus needed.
        train_doc2vec(infile, model_file, vec_file)
    else:
        bow_trainers = {
            'tfidf': train_tfidf,
            'lsi': train_lsi,
            'lda': train_lda,
        }
        corpus, dictionary = read_corpus(infile)
        bow_trainers[tag](corpus, dictionary, model_file, vec_file)
    elapsed = int(time.time() - started)
    print("cost_time:\t%s\t%s\t%d" % (infile, tag, elapsed))
if __name__ == '__main__':
    # Command line: <infile> <tag>. See train_model for the accepted tags.
    if len(sys.argv) != 3:
        # The usage text names both positional arguments so the message is
        # actually actionable.
        print("Usage: %s <infile> <tag>" % __file__)
        print("\t train different model to vectorize the document")
        print("<tag>: tfidf, lsi, lda, doc2vec")
        sys.exit(-1)
    infile = sys.argv[1]
    tag = sys.argv[2]
    train_model(infile, tag)
do_query_simi.py 查询相似文档
#! /usr/bin/env python
#encoding: utf-8
import sys
import os
import logging
from gensim import corpora, models, similarities
# Write INFO-and-above log records to 1.log, truncating the file on each run.
# NOTE(review): the training script logs to the same 1.log with filemode='w',
# so running this script overwrites the training log. Consider distinct names.
logging.basicConfig(level=logging.INFO,
                    filemode='w',
                    filename='1.log',
                    format='%(asctime)s : %(levelname)s : %(message)s')
def read_doc_file(infile):
    '''
    Return the title (first tab-separated column) of every line in `infile`.

    Titles come back in file order, so docs[i] is the title of line i. That
    is the same per-line order in which the training script wrote the
    document vectors. A line without a tab contributes the whole line
    (minus its newline).
    '''
    # open() + `with` replaces the Python-2-only file() builtin and closes
    # the handle instead of leaving it to the garbage collector.
    with open(infile) as fin:
        return [line.rstrip('\n').split('\t')[0] for line in fin]
def get_feature_num(corpus_semantic):
    '''
    Return the number of features (largest feature id + 1) in a sparse corpus.

    corpus_semantic: iterable of sparse vectors, each a list of
                     (feature_id, value) pairs, e.g. corpora.SvmLightCorpus.

    Empty vectors (documents with no non-zero features) are skipped instead
    of making max() raise ValueError. Returns 0 for an empty corpus or one
    containing only empty vectors.
    '''
    max_index = -1
    for vec in corpus_semantic:
        if vec:
            max_index = max(max_index, max(fid for fid, _ in vec))
    return max_index + 1
def _load_or_build_index(tag, corpus_semantic, num_features, index_file):
    '''
    Return a similarity index over `corpus_semantic`, cached in `index_file`.

    tf-idf vectors use SparseMatrixSimilarity; every other tag uses
    MatrixSimilarity. A freshly built index is saved to `index_file`.
    NOTE(review): an existing index file is reused as-is, even if the .vec
    file was regenerated afterwards. Delete it to force a rebuild.
    '''
    if tag == 'tfidf':
        index_cls = similarities.SparseMatrixSimilarity
    else:
        index_cls = similarities.MatrixSimilarity
    if os.path.isfile(index_file):
        return index_cls.load(index_file)
    # Passing the already-computed feature count avoids another full scan
    # of the corpus file.
    index = index_cls(corpus_semantic, num_features=num_features)
    index.save(index_file)
    return index


def _query_loop(index, corpus_semantic, docs, top_n):
    '''
    Interactively print the `top_n` documents most similar to a query title.

    Each input line must be an exact document title. Unknown titles are
    ignored. 'q' or 'quit' (surrounding whitespace allowed), or end of
    input, exits the loop.
    '''
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3
    # Title -> line number. For duplicate titles the last line wins.
    doc_map = {doc: i for i, doc in enumerate(docs)}
    vectors = list(corpus_semantic)  # random access by line number
    while True:
        try:
            query = read_line("\ninput query: ").strip()
        except EOFError:
            break
        if query in ('q', 'quit'):
            break
        q = doc_map.get(query, -1)
        if q == -1:
            continue
        print(q)
        sims = index[vectors[q]]
        # Stable descending sort: equal scores keep ascending document order.
        ranked = sorted(enumerate(sims), key=lambda item: item[1], reverse=True)
        for i, score in ranked[:top_n]:
            print("%.3f\t%s" % (score, docs[i]))


def query_simi(infile, tag, top_n=10):
    '''
    Interactively query documents similar to a given document title.

    Reads the vectors written by the training script ("<infile>.<tag>.vec"),
    loads or builds a similarity index ("<infile>.<tag>.vec.index"), then
    reads titles from stdin and prints the `top_n` highest-scoring documents
    with their similarity scores.

    infile: the original "title \t tokens" file; titles identify documents.
    tag:    one of 'tfidf', 'lsi', 'lda', 'doc2vec'. Anything else prints an
            error and returns.
    top_n:  number of results printed per query (default 10).
    '''
    if tag not in set(["tfidf", 'lsi', 'lda', 'doc2vec']):
        print("wrong tag: %s" % tag)
        return
    prefix = "%s.%s" % (infile, tag)
    vec_file = prefix + ".vec"
    index_file = vec_file + ".index"
    corpus_semantic = corpora.SvmLightCorpus(vec_file)
    n = get_feature_num(corpus_semantic)
    print("feature num: %d" % n)
    index = _load_or_build_index(tag, corpus_semantic, n, index_file)
    docs = read_doc_file(infile)
    _query_loop(index, corpus_semantic, docs, top_n)
if __name__ == '__main__':
    # Command line: <infile> <tag>. See query_simi for the accepted tags.
    if len(sys.argv) != 3:
        # The usage text names both positional arguments so the message is
        # actually actionable.
        print("Usage: %s <infile> <tag>" % __file__)
        print("<tag>: tfidf, lsi, lda, doc2vec")
        sys.exit(-1)
    infile = sys.argv[1]
    tag = sys.argv[2]
    query_simi(infile, tag)