doc2vec实现

以下代码用到了 gensim 包和 pandas 包, 可以参考另一篇文章 word2vec 更好的理解本文的代码。
题外话:为了更好的理解 doc2vec 和 word2vec,建议阅读相关的 paper。
word2vec paper
doc2vec paper

# coding: utf-8

from gensim.models import Doc2Vec
from gensim.models.doc2vec import LabeledSentence
import pandas as pd


# 加载 csv 文件
def load_csv(file_name):
    df = pd.read_csv(file_name, sep='\t|\n', encoding='utf8', header=None,
                     names=['id', 'head', 'content', 'label'], engine='python')
    return df


class DocList(object):
    """
    文档迭代器
    """

    def __init__(self, df_list=None):
        
        train_words_clean_file = './data/train_words_clean.csv'
        test_words_clean_file = './data/test_words_clean.csv'
        if not df_list:
            self.df_list = [load_csv(train_words_clean_file), load_csv(test_words_clean_file)]

    def __iter__(self):
        
        tag = ['Train', 'Test']
        for i, df in enumerate(self.df_list):
            for line_num in range(df.shape[0]):
                words = []
                words.extend(df.iloc[line_num]['head'].split())
                words.extend(df.iloc[line_num]['content'].split())
                yield LabeledSentence(words, ['%s_%s' % (tag[i], words)])


class D2VModelManager:
    """
    Doc2Vec 模型管理器
    """

    def __init__(self):

        self.dm_model_name = './model/dm.d2v'

    def train_model(self):
        """
        训练
        """
        
        # 文档迭代器
        doc_l = DocList()
        
        # 定义模型
        d2v = Doc2Vec(dm=1, size=300, negative=5, hs=1, sample=1e-5,
                      window=10, min_count=5, workers=12, alpha=0.025, min_alpha=0.025)
        
        # 建立词典
        d2v.build_vocab(doc_l)
        
        # 训练10个 epoch
        for epoch in range(10):
            d2v.train(doc_l, total_examples=d2v.corpus_count, epochs=1)  # epochs 设置为 1
            d2v.alpha -= 0.002
            d2v.min_alpha = d2v.alpha
        d2v.save(self.dm_model_name)

    @staticmethod
    def load_model(model_name):
        """
        加载
        """
        return Doc2Vec.load(model_name)

你可能感兴趣的:(doc2vec实现)