deeplearning4j之GloVe实现

GloVe 与 word2vec 类似，可以用于自然语言处理，据说效果比 word2vec 更强。正好在学习 deeplearning4j 的时候看到了，顺便记录在这里。

本文使用的数据与上一篇 word2vec 文章相同，来看看效果吧。需要注意的是，GloVe 的训练时间比 word2vec 长得多。代码如下：

package com.meituan.deeplearning4j;

import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collection;

/**
 * Trains a GloVe word-embedding model with deeplearning4j on a line-oriented text file and
 * prints a few nearest-neighbour and similarity queries against the trained vectors.
 */
public class GloVeRaw {

	/** Corpus file used when no path is passed on the command line (the author's local file). */
	private static final String DEFAULT_CORPUS_PATH = "/Users/shuubiasahi/Desktop/bayies/deeplearning/part-00000";

	/** Word whose nearest neighbours and raw vector are printed. */
	private static final String QUERY_WORD = "微信";

	/** Word compared against {@link #QUERY_WORD}; its nearest neighbours are printed too. */
	private static final String OTHER_WORD = "腾讯聊天账号";

	/** Number of nearest neighbours requested for each query. */
	private static final int TOP_N = 10;

	/**
	 * Builds and fits the GloVe model, then prints the query results to standard output.
	 *
	 * @param args optional; when present, {@code args[0]} overrides the corpus file path
	 * @throws FileNotFoundException if the corpus file cannot be opened
	 */
	public static void main(String[] args) throws FileNotFoundException {
		String filePath = args.length > 0 ? args[0] : DEFAULT_CORPUS_PATH;
		// BasicLineIterator yields the file line by line as the training sentences.
		SentenceIterator iter = new BasicLineIterator(new File(filePath));
		TokenizerFactory t = new DefaultTokenizerFactory();
		t.setTokenPreProcessor(new CommonPreprocessor());
		Glove glove = new Glove.Builder()
				.iterate(iter)
				.tokenizerFactory(t)
				.alpha(0.75)
				.learningRate(0.1)
				.epochs(25)
				.xMax(100)
				.batchSize(1000)
				.shuffle(true)
				.symmetric(true)
				.build();

		glove.fit();

		// Labels are derived from the queried words so the printed text always matches the query
		// (the previous hard-coded labels mentioned "qq" / "美女" while querying other words).
		System.out.println("和" + QUERY_WORD + "最接近的" + TOP_N + "个词汇:" + glove.wordsNearest(QUERY_WORD, TOP_N));
		System.out.println(Arrays.toString(glove.getWordVector(QUERY_WORD)));
		System.out.println(QUERY_WORD + "和" + OTHER_WORD + "的相似度为:" + glove.similarity(QUERY_WORD, OTHER_WORD));
		System.out.println("和" + OTHER_WORD + "最接近的" + TOP_N + "个词汇:" + glove.wordsNearest(OTHER_WORD, TOP_N));

		// NOTE(review): explicit exit presumably terminates lingering library threads — confirm.
		System.exit(0);
	}

}






你可能感兴趣的:(机器学习,java)