Dependencies:
Java deep learning framework, deeplearning4j: http://deeplearning4j.org/word2vec
Open-source Chinese word segmentation framework, ansj_seg: http://www.oschina.net/p/ansj
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-nlp</artifactId>
    <version>0.4-rc3.8</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-x86</artifactId>
    <version>0.4-rc3.8</version>
</dependency>
<dependency>
    <groupId>org.ansj</groupId>
    <artifactId>ansj_seg</artifactId>
    <version>3.7.2</version>
</dependency>
Note: the deeper theory behind word2vec is not covered here. To make word2vec work there is essentially only one prerequisite: identifying words. English words are delimited by spaces, so tokenization is easy; Chinese (like Japanese and Korean) builds sentences from characters with no delimiters, so using a segmentation tool to split sentences into words is essential. word2vec uses a model to analyze word contexts in the corpus and projects each word's meaning into a vector space: similar words end up with a small angle between their vectors, while dissimilar words are far apart. Note that this process needs no manual intervention; all you have to prepare is the corpus.
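To make the "small angle between similar words" concrete: once a model is trained (with the utility class below), related words can be queried straight from the vector space. A minimal sketch, assuming vec is a trained Word2Vec instance and the query words actually occur in your corpus:

// Both calls are part of the dl4j word-vector API; the words are placeholders.
double sim = vec.similarity("北京", "上海");               // cosine similarity; larger = more related
Collection<String> nearest = vec.wordsNearest("北京", 10); // the 10 closest words in the space
System.out.println("similarity=" + sim + ", nearest=" + nearest);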
Now straight to the code:
The first piece is the core utility class of this article, which wraps the model in train (fit) and load operations.
// Imports assume the dl4j 0.4-rc3.x package layout; Memory, SentenceFactory,
// TextSentenceFactory, ANSJTokenizerFactory, ChineseTokenPreProcess and Dictionary
// are project-local classes assumed to live in the same package.
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collection;

import org.apache.commons.collections.CollectionUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.NonNull;

/**
 * @author yuyuzhao
 * @since April 13, 2016
 */
public class Word2VecUtils {

    private static Logger logger = LoggerFactory.getLogger(Word2VecUtils.class);

    private static String CHARSET = "UTF-8";
    private static int MIN_WORD_FREQUENCY = 5;
    private static float LEARNING_RATE = 0.025f;
    private static int LAYER_SIZE = 100;
    private static int SEED = 42;
    private static int WINDOW_SIZE = 5;
    private static Dictionary DICTIONARY = null;

    public static Word2Vec fit(String filePath, Memory memory) throws IOException {
        SentenceFactory splitter = new TextSentenceFactory(filePath, CHARSET);
        return fit(splitter.create(), memory);
    }

    public static Word2Vec fit(Collection<String> sentences, Memory memory) {
        if (CollectionUtils.isEmpty(sentences))
            return null;
        SentenceIterator iterator = new CollectionSentenceIterator(sentences);
        TokenizerFactory tokenizerFactory = new ANSJTokenizerFactory(DICTIONARY);
        tokenizerFactory.setTokenPreProcessor(new ChineseTokenPreProcess());
        return fit(iterator, tokenizerFactory, memory);
    }

    private static Word2Vec fit(SentenceIterator iterator, TokenizerFactory tokenizerFactory, Memory memory) {
        logger.info("Building model....");
        InMemoryLookupCache cache = new InMemoryLookupCache();
        WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>()
                .vectorLength(LAYER_SIZE) // keep the table in sync with the configured layer size
                .useAdaGrad(false)
                .cache(cache)
                .lr(LEARNING_RATE)
                .build();
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(MIN_WORD_FREQUENCY)
                .iterations(1)
                .epochs(1)
                .layerSize(LAYER_SIZE)
                .seed(SEED)
                .windowSize(WINDOW_SIZE)
                .iterate(iterator)
                .tokenizerFactory(tokenizerFactory)
                .lookupTable(table)
                .vocabCache(cache)
                .build();
        logger.info("Fitting Word2Vec model....");
        vec.fit();
        if (memory != null) {
            WordVectorSerializer.writeFullModel(vec, memory.getPath());
            logger.info("The training has completed successfully and the result has been saved to Path[{}]",
                    memory.getPath());
        }
        return vec;
    }

    public static Word2Vec load(@NonNull Memory memory) throws FileNotFoundException {
        return WordVectorSerializer.loadFullModel(memory.getPath());
    }

    /** Fluent overrides for the static defaults above; call apply() to activate them. */
    public static class Config {

        private int minWordFrequency = 0;
        private String charset = null;
        private float learningRate = 0;
        private int layerSize = 0;
        private int seed = 0;
        private int windowSize = 0;
        private Dictionary dictionary = null;

        public Config minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        public Config charset(String charset) {
            this.charset = charset;
            return this;
        }

        public Config learningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Config layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        public Config seed(int seed) {
            this.seed = seed;
            return this;
        }

        public Config windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        public Config dictionary(Dictionary dictionary) {
            this.dictionary = dictionary;
            return this;
        }

        public void apply() {
            if (minWordFrequency > 0) MIN_WORD_FREQUENCY = minWordFrequency;
            if (charset != null) CHARSET = charset;
            if (learningRate > 0) LEARNING_RATE = learningRate;
            if (layerSize > 0) LAYER_SIZE = layerSize;
            if (seed > 0) SEED = seed;
            if (windowSize > 0) WINDOW_SIZE = windowSize;
            if (dictionary != null) DICTIONARY = dictionary;
        }
    }
}
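A usage sketch for the class above. The paths, parameter values and query word are all hypothetical, and error handling is collapsed into a single throws clause:

import java.io.IOException;
import java.util.Collection;

import org.deeplearning4j.models.word2vec.Word2Vec;

public class TrainDemo {
    public static void main(String[] args) throws IOException {
        // Override the static defaults before training (all values are examples).
        new Word2VecUtils.Config()
                .charset("UTF-8")
                .minWordFrequency(5)
                .layerSize(100)
                .windowSize(5)
                .apply();
        // INIT wipes any previously saved model at this (hypothetical) path.
        Memory memory = new Memory("data/model/word2vec.mem", Memory.Policy.INIT);
        // Train on a corpus directory of .txt files and persist the full model.
        Word2Vec vec = Word2VecUtils.fit("data/corpus", memory);
        Collection<String> nearest = vec.wordsNearest("中国", 10);
        System.out.println(nearest);
    }
}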
The second piece defines a Memory object that persists the training result so it can be reused. Note that retraining repeatedly is possible but not advisable; the recommended approach is to load the corpus into memory once, train once, and then reuse the model for many queries.
import java.io.File;
import java.io.FileNotFoundException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.NonNull;

public class Memory {

    private static final Logger logger = LoggerFactory.getLogger(Memory.class);

    private String path;
    private String folderName;
    private String fileName;

    public Memory(@NonNull String path, Policy policy) throws FileNotFoundException {
        // Normalize the path: split on both separator styles, rebuild with File.separator.
        String[] dirs = path.split("[\\\\/]");
        String fileName = dirs[dirs.length - 1];
        StringBuffer buffer = new StringBuffer();
        for (int i = 0; i < dirs.length - 1; i++) {
            buffer.append(dirs[i]);
            buffer.append(File.separator);
        }
        String folderName = buffer.toString();
        this.fileName = fileName;
        this.folderName = folderName;
        this.path = folderName + fileName;
        if (policy == Policy.INIT) {
            File file = new File(folderName);
            if (!file.exists())
                file.mkdirs();
            file = new File(this.path);
            if (file.exists()) {
                logger.info("Memory in path [{}] already exists; the old file will be deleted before continuing.",
                        this.path);
                file.delete();
            }
        } else if (policy == Policy.RESTORE) {
            File file = new File(this.path);
            if (!file.exists())
                throw new FileNotFoundException(this.path);
        }
        logger.info("Memory certified successfully in path [{}]", this.path);
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public String getFolderName() {
        return folderName;
    }

    public void setFolderName(String folderName) {
        this.folderName = folderName;
    }

    public String getFileName() {
        return fileName;
    }

    public void setFileName(String fileName) {
        this.fileName = fileName;
    }

    public enum Policy {
        INIT(1), RESTORE(2);

        public final int value;

        private Policy(int value) {
            this.value = value;
        }
    }
}
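Restoring a saved model is the cheap path on subsequent runs; Policy.RESTORE fails fast with a FileNotFoundException when nothing has been trained yet. A minimal sketch, using the same hypothetical path as above:

// Fails with FileNotFoundException if data/model/word2vec.mem does not exist.
Memory memory = new Memory("data/model/word2vec.mem", Memory.Policy.RESTORE);
Word2Vec vec = Word2VecUtils.load(memory);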
Next, define a Tokenizer to serve as the Chinese word-segmentation component.
import java.util.ArrayList;
import java.util.List;

import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;

public class ANSJTokenizer implements Tokenizer {

    private List<String> tokens;
    private TokenPreProcess tokenPreProcess;
    private int index = 0;

    public ANSJTokenizer(String toTokenize) {
        // Let ansj segment the sentence, keeping only non-blank terms.
        List<Term> terms = ToAnalysis.parse(toTokenize);
        tokens = new ArrayList<String>();
        String word;
        for (Term term : terms) {
            word = term.getName();
            if (StringUtils.isNotBlank(word)) {
                tokens.add(word);
            }
        }
    }

    @Override
    public boolean hasMoreTokens() {
        return index < tokens.size();
    }

    @Override
    public int countTokens() {
        return tokens.size();
    }

    @Override
    public String nextToken() {
        String base = tokens.get(index++);
        if (tokenPreProcess != null)
            base = tokenPreProcess.preProcess(base);
        return base;
    }

    @Override
    public List<String> getTokens() {
        return tokens;
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
        this.tokenPreProcess = tokenPreProcessor;
    }
}
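To sanity-check the tokenizer in isolation (the sample sentence and the exact segmentation are illustrative; output depends on the ansj model and dictionary):

Tokenizer tokenizer = new ANSJTokenizer("今天北京的天气不错");
while (tokenizer.hasMoreTokens()) {
    System.out.print(tokenizer.nextToken() + " | "); // e.g. 今天 | 北京 | 的 | 天气 | 不错 |
}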
The tokenizer factory performs some initialization and supplies tokenizers.
import java.io.InputStream;

import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

/**
 * @author yuyuzhao
 * @since April 13, 2016
 */
public class ANSJTokenizerFactory implements TokenizerFactory {

    private TokenPreProcess tokenPreProcess;

    public ANSJTokenizerFactory(Dictionary dic) {
        // Optionally expand ansj's vocabulary with a user dictionary (project-local class).
        if (dic != null)
            dic.expand();
    }

    @Override
    public Tokenizer create(String toTokenize) {
        Tokenizer t = new ANSJTokenizer(toTokenize);
        t.setTokenPreProcessor(tokenPreProcess);
        return t;
    }

    @Override
    public Tokenizer create(InputStream toTokenize) {
        throw new UnsupportedOperationException("Cannot create a Tokenizer from an InputStream; use the String overload instead");
    }

    @Override
    public void setTokenPreProcessor(TokenPreProcess preProcessor) {
        this.tokenPreProcess = preProcessor;
    }
}
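This is how the factory gets wired up (the same wiring Word2VecUtils.fit performs internally); passing null simply skips the custom dictionary:

TokenizerFactory factory = new ANSJTokenizerFactory(null); // null: no user dictionary
factory.setTokenPreProcessor(new ChineseTokenPreProcess());
Tokenizer tokenizer = factory.create("word2vec将词的含义投射到向量空间");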
Pre-processing before tokenization:
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;

public class ChineseTokenPreProcess implements TokenPreProcess {

    @Override
    public String preProcess(String token) {
        if (token == null)
            return null;
        // Replace every run of characters that is neither CJK (\u4e00-\u9fa5) nor a word character with one space.
        return token.replaceAll("[^\u4e00-\u9fa5\\w]+", " ");
    }
}
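The regex keeps CJK characters (\u4e00-\u9fa5) plus word characters (\w: letters, digits, underscore) and collapses every other run into a single space, so ansj never sees stray punctuation. For example:

String cleaned = new ChineseTokenPreProcess().preProcess("你好,world!2016…");
// cleaned is "你好 world 2016 " — each punctuation run became one space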
Turn local documents into a collection of sentences to feed the tokenizer.
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Collection;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TextSentenceFactory implements SentenceFactory {

    private static final Logger logger = LoggerFactory.getLogger(TextSentenceFactory.class);
    private static final String FORMAT = ".txt";

    private StringBuffer buffer = new StringBuffer(); // stays empty if the source is missing
    private String charset;

    public TextSentenceFactory(String filePath, String charset) throws IOException {
        if (Charset.isSupported(charset))
            this.charset = charset;
        else
            this.charset = "UTF-8";
        File file = new File(filePath);
        if (!file.exists()) {
            logger.error("Source [" + filePath + "] does not exist!");
            return;
        }
        if (file.isFile() && file.getName().endsWith(FORMAT)) {
            buffer = IOUtils.read(file, this.charset, false);
        } else if (file.isDirectory()) {
            logger.info("Searching files from directory [{}]", file.getName());
            buffer = IOUtils.traverse(file, this.charset, new FormatFileFilter(), false);
        }
    }

    private static class FormatFileFilter implements FileFilter {
        @Override
        public boolean accept(File pathname) {
            return pathname.isFile() && pathname.getName().endsWith(FORMAT);
        }
    }

    @Override
    public Collection<String> create() {
        // This sentence-splitting regex is the result of several rounds of tuning; modify it cautiously.
        // Splitting on commas as well as full stops keeps each fragment semantically complete
        // without letting sentences grow long enough to blur the context window.
        // In our tests this expression gave the best results.
        // RegexUtils.group is a project helper that returns every match of the pattern.
        return RegexUtils.group(buffer, "[^,,。.??!!\\s]+");
    }
}
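Used on its own it looks like this (hypothetical directory; .txt files are collected recursively and split into comma/period-delimited fragments):

SentenceFactory factory = new TextSentenceFactory("data/corpus", "UTF-8");
Collection<String> sentences = factory.create();
System.out.println(sentences.size() + " sentence fragments extracted");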
import java.util.Collection;

/**
 * Splits the text into sentences.
 *
 * @author yuyuzhao
 * @since April 15, 2016
 */
public interface SentenceFactory {

    Collection<String> create();
}
Finally, my IOUtils wrapper:
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class IOUtils {

    private static final Logger logger = LoggerFactory.getLogger(IOUtils.class);
    private static final String LINE_BREAKER = "\n";

    public static StringBuffer read(File file, String charset, boolean lineBreak) throws IOException {
        StringBuffer sb = new StringBuffer();
        readToBuffer(file, sb, charset, lineBreak);
        return sb;
    }

    public static StringBuffer traverse(File file, String charset, FileFilter filter, boolean lineBreak)
            throws IOException {
        StringBuffer buffer = new StringBuffer();
        traverseFolder(buffer, file, charset, filter, lineBreak);
        return buffer;
    }

    private static void traverseFolder(StringBuffer buffer, File folder, String charset, FileFilter filter,
            boolean lineBreak) throws IOException {
        File[] files = folder.listFiles();
        if (files == null) // guard against I/O errors on unreadable directories
            return;
        for (File file : files) {
            if (file.isFile() && filter.accept(file)) {
                readToBuffer(file, buffer, charset, lineBreak);
            } else if (file.isDirectory()) {
                traverseFolder(buffer, file, charset, filter, lineBreak);
            }
        }
    }

    private static void readToBuffer(File file, StringBuffer buffer, String charset, boolean lineBreak)
            throws IOException {
        // try-with-resources closes the stream even if reading fails
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new BufferedInputStream(new FileInputStream(file), 10 * 1024), charset))) {
            String line = br.readLine();
            while (line != null) {
                buffer.append(line);
                if (lineBreak)
                    buffer.append(LINE_BREAKER);
                line = br.readLine();
            }
        }
        logger.info("Read CharSequence successfully from path [{}]", file.getAbsolutePath());
    }
}
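And for completeness, reading one file versus an entire directory tree (paths hypothetical; the lambda form of FileFilter requires Java 8):

StringBuffer one = IOUtils.read(new File("data/corpus/news.txt"), "UTF-8", false);
StringBuffer all = IOUtils.traverse(new File("data/corpus"), "UTF-8",
        f -> f.getName().endsWith(".txt"), false);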