Java实现word2vec

依赖:

  1. java深度学习框架,deeplearning4j:http://deeplearning4j.org/word2vec

  2. 开源中文分词框架,ansj_seg:http://www.oschina.net/p/ansj

<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-nlp</artifactId>
			<version>0.4-rc3.8</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-x86</artifactId>
			<version>0.4-rc3.8</version>
		</dependency>
		<dependency>
			<groupId>org.ansj</groupId>
			<artifactId>ansj_seg</artifactId>
			<version>3.7.2</version>
		</dependency>

说明:word2vec深层次的原理不做说明,要实现word2vec主要要做到只有一件事,那就是词汇的识别。英文由于是空格符隔开的,所以分词比较容易,但是中文(包括日文、韩文)等是靠字形成句子的,所以借助分词工具将句子进行语法拆分很重要。word2vec借助一定的模型,通过对语料上下文进行分析,从而将词的含义投射到向量空间。相似的词在向量空间夹角很小,而不同的词差别则较大。需要注意的是,这一过程是不需要人工干预的,你只需要准备好语料即可.

下面直接贴代码:

第一个是本文核心的工具类,将模型封装为训练和装载操作。

/**
 * 
 * @author yuyuzhao
 * @since 2016年4月13日
 */
public class Word2VecUtils {

	private static Logger logger = LoggerFactory.getLogger(Word2VecUtils.class);

	private static String CHARSET = "UTF-8";
	private static int MIN_WORD_FREQUENCY = 5;
	private static float LEARNING_RATE = 0.025f;
	private static int LAYER_SIZE = 100;
	private static int SEED = 42;
	private static int WINDOW_SIZE = 5;
	private static Dictionary DICTIONARY = null;

	public static Word2Vec fit(String filePath, Memory memory) throws IOException {
		SentenceFactory spliter = new TextSentenceFactory(filePath, CHARSET);
		return fit(spliter.create(), memory);
	}

	public static Word2Vec fit(Collection<String> sentences, Memory memory) {

		if (CollectionUtils.isEmpty(sentences))
			return null;
		SentenceIterator iterator = new CollectionSentenceIterator(sentences);
		TokenizerFactory tokenizerFactory = new ANSJTokenizerFactory(DICTIONARY);
		tokenizerFactory.setTokenPreProcessor(new ChineseTokenPreProcess());

		return fit(iterator, tokenizerFactory, memory);
	}

	private static Word2Vec fit(SentenceIterator iterator, TokenizerFactory tokenizerFactory, Memory memory) {

		logger.info("Building model....");
		InMemoryLookupCache cache = new InMemoryLookupCache();
		WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
				.vectorLength(100)
				.useAdaGrad(false)
				.cache(cache)
				.lr(LEARNING_RATE)
				.build();

		Word2Vec vec = new Word2Vec.Builder().minWordFrequency(MIN_WORD_FREQUENCY)
				.iterations(1)
				.epochs(1)
				.layerSize(LAYER_SIZE)
				.seed(SEED)
				.windowSize(WINDOW_SIZE)
				.iterate(iterator)
				.tokenizerFactory(tokenizerFactory)
				.lookupTable(table)
				.vocabCache(cache)
				.build();

		logger.info("Fitting Word2Vec model....");
		vec.fit();

		if (memory != null) {
			WordVectorSerializer.writeFullModel(vec, memory.getPath());
			logger.info("The training has completed successfully and the result has been saved to Path[{}]",
					memory.getPath());
		}

		return vec;
	}

	public static Word2Vec load(@NonNull Memory memory) throws FileNotFoundException {

		Word2Vec vec = WordVectorSerializer.loadFullModel(memory.getPath());
		return vec;
	}

	public static class Config {

		private int minWordFrequency = 0;
		private String charset = null;
		private float learningRate = 0;
		private int layerSize = 0;
		private int seed = 0;
		private int windowSize = 0;
		private Dictionary dictionary = null;

		public Config minWordFrequency(int minWordFrequency) {
			this.minWordFrequency = minWordFrequency;
			return this;
		}

		public Config charset(String charset) {
			this.charset = charset;
			return this;
		}

		public Config learningRate(float learningRate) {
			this.learningRate = learningRate;
			return this;
		}

		public Config layerSize(int layerSize) {
			this.layerSize = layerSize;
			return this;
		}

		public Config seed(int seed) {
			this.seed = seed;
			return this;
		}

		public Config windowSize(int windowSize) {
			this.windowSize = windowSize;
			return this;
		}

		public Config dictionary(Dictionary dictionary) {
			this.dictionary = dictionary;
			return this;
		}

		public void apply() {
			if (minWordFrequency > 0)
				MIN_WORD_FREQUENCY = minWordFrequency;
			if (charset != null)
				CHARSET = charset;
			if (learningRate > 0)
				LEARNING_RATE = learningRate;
			if (layerSize > 0)
				LAYER_SIZE = layerSize;
			if (seed > 0)
				SEED = seed;
			if (windowSize > 0)
				WINDOW_SIZE = windowSize;
			if (dictionary != null)
				DICTIONARY = dictionary;
		}
	}
}

第二个定义一个记忆对象用于保存训练结果从而达到重用的目的,注意,多次训练虽然可行但不可取。推荐的方式是一次将语料加载进内存训练完成后多次调用。

public class Memory {

	private String path;

	private String folderName;

	private String fileName;

	private static final Logger logger = LoggerFactory.getLogger(Memory.class);

	public Memory(@NonNull String path, Policy policy) throws FileNotFoundException {
		String[] dirs = path.split("[\\\\/]");
		String fileName = dirs[dirs.length - 1];
		StringBuffer buffer = new StringBuffer();
		for (int i = 0; i < dirs.length - 1; i++) {
			buffer.append(dirs[i]);
			buffer.append(File.separator);
		}
		String folderName = buffer.toString();

		this.fileName = fileName;
		this.folderName = folderName;
		this.path = folderName + fileName;

		if (policy.value == Policy.INIT.value) {
			File file = new File(folderName);
			if (!file.exists())
				file.mkdirs();
			file = new File(this.path);
			if (file.exists()) {
				logger.info(
						"Memory in path [{}] has already existed,the operation will delete the old file then continue.",
						this.path);
				file.delete();
			}
		} else if (policy.value == Policy.RESTORE.value) {
			File file = new File(this.path);
			if (!file.exists())
				throw new FileNotFoundException();
		}

		logger.info("Memory certified successfully in path [{}]", this.path);
	}

	public String getPath() {
		return path;
	}

	public void setPath(String path) {
		this.path = path;
	}

	public String getFolderName() {
		return folderName;
	}

	public void setFolderName(String folderName) {
		this.folderName = folderName;
	}

	public String getFileName() {
		return fileName;
	}

	public void setFileName(String fileName) {
		this.fileName = fileName;
	}

	public enum Policy {
		INIT(1), RESTORE(2);

		public final int value;

		private Policy(int value) {
			this.value = value;
		}
	}

}

然后定义一个分词器tokenizer作为中文分词组件

public class ANSJTokenizer implements Tokenizer {

	private List<String> tokenizer;
	private TokenPreProcess tokenPreProcess;
	private int index = 0;

	public ANSJTokenizer(String toTokenize) {
		List<Term> terms = ToAnalysis.parse(toTokenize);
		tokenizer = new ArrayList<String>();
		String word;
		for (Term term : terms) {
			word = term.getName();
			if (StringUtils.isNotBlank(word)) {
				tokenizer.add(word);
			}
		}
	}

	@Override
	public boolean hasMoreTokens() {
		return index < tokenizer.size();
	}

	@Override
	public int countTokens() {
		return tokenizer.size();
	}

	@Override
	public String nextToken() {
		String base = tokenizer.get(index++);
		if (tokenPreProcess != null)
			base = tokenPreProcess.preProcess(base);
		return base;
	}

	@Override
	public List<String> getTokens() {
		return tokenizer;
	}

	@Override
	public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
		this.tokenPreProcess = tokenPreProcessor;
	}

}

分词器工厂用于做一些初始化同时提供分词器

/**
 * 
 * @author yuyuzhao
 * @since 2016年4月13日
 *
 */
public class ANSJTokenizerFactory implements TokenizerFactory {

	private TokenPreProcess tokenPreProcess;

	public ANSJTokenizerFactory(Dictionary dic) {
		if (dic != null)
			dic.expand();
	}

	@Override
	public Tokenizer create(String toTokenize) {
		Tokenizer t = new ANSJTokenizer(toTokenize);
		t.setTokenPreProcessor(tokenPreProcess);
		return t;
	}

	@Override
	public Tokenizer create(InputStream toTokenize) {
		throw new UnsupportedOperationException("Could not create Tokenizer with InputStream,Try with String");
	}

	@Override
	public void setTokenPreProcessor(TokenPreProcess preProcessor) {
		this.tokenPreProcess = preProcessor;
	}

}

分词前的预处理

public class ChineseTokenPreProcess implements TokenPreProcess {

	@Override
	public String preProcess(String token) {
		if (token == null)
			return null;
		return token.replaceAll("[^\u4e00-\u9fa5\\w]+", " ");
	}
}

将本地文档形成句子集合提供给分词器

public class TextSentenceFactory implements SentenceFactory {

	private StringBuffer buffer;
	private String charset;

	private static final String FORMAT = ".txt";
	private static final Logger logger = LoggerFactory.getLogger(TextSentenceFactory.class);

	public TextSentenceFactory(String filePath, String charset) throws IOException {
		if (Charset.isSupported(charset))
			this.charset = charset;
		else
			this.charset = "UTF-8";

		File file = new File(filePath);
		if (!file.exists()) {
			logger.error("Source [" + filePath + "]" + "did not exist!");
			return;
		}
		if (file.isFile() && file.getName().endsWith(FORMAT)) {
			buffer = IOUtils.read(file, this.charset, false);
		} else if (file.isDirectory()) {
			logger.info("Searching files from directory [{}]", file.getName());
			buffer = IOUtils.traverse(file, this.charset, new FormatFileFilter(), false);
		}
	}

	private static class FormatFileFilter implements FileFilter {

		@Override
		public boolean accept(File pathname) {
			return pathname.isFile() ? pathname.getName().endsWith(FORMAT) : false;
		}

	}

	@Override
	public Collection<String> create() {
		// 此正则表达式断句是经过多次优化后得出的,请谨慎修改
		// 使用逗号进行断句既能保证语意完整,同时又不至于过于复杂造成混淆
		// 根据此表达式的测试结果相对来说是最理想的
		// Note:This regular expression tend to be the best practice after
		// several tests,replace it cautiously.
		return RegexUtils.group(buffer, "[^,,。.??!!\\s]+");
	}
}
/**
 * Split the text into sentences
 * 
 * @author yuyuzhao
 * @since 2016年4月15日
 *
 */
public interface SentenceFactory {

	public Collection<String> create();
}

最后是我封装的IOUtils

public class IOUtils {

	private static final Logger logger = LoggerFactory.getLogger(IOUtils.class);

	private static final String LINE_BREAKER = "\n";

	public static StringBuffer read(File file, String charset, boolean lineBreak) throws IOException {
		StringBuffer sb = new StringBuffer();
		readToBuffer(file, sb, charset, lineBreak);
		return sb;
	}

	public static StringBuffer traverse(File file, String charset, FileFilter filter, boolean breakLine)
			throws IOException {
		StringBuffer buffer = new StringBuffer();
		traverseFolder(buffer, file, charset, filter, breakLine);
		return buffer;
	}

	private static void traverseFolder(StringBuffer buffer, File folder, String charset, FileFilter filter,
			boolean breakLine) throws IOException {
		File[] files = folder.listFiles();
		for (File file : files) {
			if (file.isFile() && filter.accept(file)) {
				readToBuffer(file, buffer, charset, breakLine);
			} else if (file.isDirectory()) {
				traverseFolder(buffer, file, charset, filter, breakLine);
			}
		}
	}

	private static void readToBuffer(File file, StringBuffer buffer, String charset, boolean lineBreak)
			throws IOException {
		InputStream is = new FileInputStream(file);
		InputStreamReader isr = new InputStreamReader(new BufferedInputStream(is, 10 * 1024), charset);
		BufferedReader br = new BufferedReader(isr);
		String line = br.readLine();
		while (line != null) {
			buffer.append(line);
			if (lineBreak)
				buffer.append(LINE_BREAKER);
			line = br.readLine();
		}
		br.close();
		isr.close();
		is.close();
		logger.info("Read CharSequence successfully from path [{}]", file.getAbsolutePath());
	}

}


你可能感兴趣的:(java,数据挖掘,word2vec,deeplearning4j)