使用deeplearning4j训练Word2Vec(Java操作)

本文作者:合肥工业大学 管理学院 钱洋 email:[email protected] 内容可能有不到之处,欢迎交流。

未经本人允许禁止转载。

DeepLearning4J(DL4J)是一套基于Java语言的神经网络工具包,可以构建、定型和部署神经网络

本文训练的数据集是deeplearning4j中的自带数据集,数据表示如下:


使用deeplearning4j训练Word2Vec(Java操作)_第1张图片

对应的Word2Vec操作程序如下:

package org.deeplearning4j.examples.nlp.word2vec;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.Arrays;
import java.util.Collection;

/**
 * Created by qianyang on 8/28/2018.
 */
public class Test {

    private static Logger log = LoggerFactory.getLogger(Test.class);
    private static String outputPath = "E:/word2vec.txt";
    public static void main(String[] args) throws Exception {
        //输入文本文件的目录
        File inputTxt = new File("E:/raw_sentences.txt");
        log.info("开始加载数据...."+inputTxt.getName());
        //加载数据
        SentenceIterator iter = new LineSentenceIterator(inputTxt);
        //切词操作
        TokenizerFactory token = new DefaultTokenizerFactory();
        //去除特殊符号及大小写转换操作
        token.setTokenPreProcessor(new CommonPreprocessor());
        log.info("训练模型....");
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(5)//词在语料中必须出现的最少次数
                .iterations(1)
                .layerSize(100)  //向量维度
                .seed(42)
                .windowSize(10) //窗口大小
                .iterate(iter)
                .tokenizerFactory(token)
                .build();
        log.info("配置模型....");
        vec.fit();
        log.info("输出词向量....");
        WordVectorSerializer.writeWordVectors(vec, outputPath);
        log.info("相似的词:");
        //获取相似的词
        Collection lst = vec.wordsNearest("day", 10);
        System.out.println(lst);
        //获取某词对应的向量
        log.info("向量获取:");
        double[] wordVector = vec.getWordVector("day");
        System.out.println(Arrays.toString(wordVector));
    }
}

其中,在控制台可以输出一下结果:


使用deeplearning4j训练Word2Vec(Java操作)_第2张图片

经过此训练,在指定目录下输出了每个词对应的向量,结果如下:

使用deeplearning4j训练Word2Vec(Java操作)_第3张图片

你可能感兴趣的:(深度学习(Deep,Learning),java,数据挖掘算法,自然语言处理方法及应用,深度学习算法原理与代码剖析)