This series of posts introduces common probabilistic language models and their variants, covering PLSA, LDA, LDA's variant models, and the corresponding parameter inference methods. The planned contents are:
Part 1: PLSA and the EM algorithm
Part 2: LDA and Gibbs Sampling
Part 3: LDA variants: Twitter LDA, TimeUserLDA, ATM, Labeled-LDA, MaxEnt-LDA, etc.
Part 4: A bibliography of papers grouped by LDA variant
Part 5: A JAVA implementation of LDA Gibbs Sampling
Part 5: A JAVA Implementation of LDA Gibbs Sampling
In the first two posts of this series we systematically introduced PLSA, LDA, and their parameter inference methods, focusing on model representation and formula derivation. A scholar once said that research should "reach the sky and stand on the ground": models and theory alone are not enough, and we also need solid code and experimental results on real data to back them up. This post therefore walks through a JAVA implementation of LDA Gibbs Sampling and presents the topic modeling results obtained by applying it to the newsgroup18828 news document collection.
The Github repository for this project: https://github.com/yangliuy/LDAGibbsSampling
1 Preprocessing the document collection
To do topic modeling on text with LDA, we first need to preprocess it: tokenization, stop-word removal, stemming, noise-word removal, low-frequency-word removal, and so on. When the corpus is fairly large we can also skip stemming. The text is then converted into a term-index representation, because the LDA implementation frequently needs to map between terms and indices. The Documents class is implemented as follows; it defines an inner Document class that represents one document in the collection.
package liuyang.nlp.lda.main;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.com.Stopwords;

/**Class for corpus which consists of M documents
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail [email protected]
 */
public class Documents {

    ArrayList<Document> docs;
    Map<String, Integer> termToIndexMap;
    ArrayList<String> indexToTermMap;
    Map<String, Integer> termCountMap;

    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    public void readDocs(String docsPath) {
        for (File docFile : new File(docsPath).listFiles()) {
            Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);
            docs.add(doc);
        }
    }

    public static class Document {

        private String docName;
        int[] docWords;

        public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap) {
            this.docName = docName;
            //Read file and initialize word index array
            ArrayList<String> docLines = new ArrayList<String>();
            ArrayList<String> words = new ArrayList<String>();
            FileUtil.readLines(docName, docLines);
            for (String line : docLines) {
                FileUtil.tokenizeAndLowerCase(line, words);
            }
            //Remove stop words and noise words
            for (int i = 0; i < words.size(); i++) {
                if (Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))) {
                    words.remove(i);
                    i--;
                }
            }
            //Transfer word to index
            this.docWords = new int[words.size()];
            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                if (!termToIndexMap.containsKey(word)) {
                    int newIndex = termToIndexMap.size();
                    termToIndexMap.put(word, newIndex);
                    indexToTermMap.add(word);
                    termCountMap.put(word, 1);
                    docWords[i] = newIndex;
                } else {
                    docWords[i] = termToIndexMap.get(word);
                    termCountMap.put(word, termCountMap.get(word) + 1);
                }
            }
            words.clear();
        }

        public boolean isNoiseWord(String string) {
            string = string.toLowerCase().trim();
            Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");
            Matcher m = MY_PATTERN.matcher(string);
            //Filter URL-like tokens such as www.xxx, xxx.com and http:xxx
            if (string.matches(".*www\\..*") || string.matches(".*\\.com.*") || string.matches(".*http:.*"))
                return true;
            //Filter tokens that contain no English letters at all
            return !m.matches();
        }
    }
}
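As noted above, this implementation does not stem. If you do want stemming on a smaller corpus, a hook could be applied to each token right before the word-to-index step. The following is only a toy suffix-stripping sketch to show where such a hook would go; SimpleStemmer is hypothetical and not part of the Github project, and a real implementation would use a proper Porter stemmer instead:

/**A minimal illustrative stemming hook (hypothetical, NOT part of the
 * Github project). It strips a few common English suffixes; a real
 * implementation would use a full Porter stemmer instead of this toy rule set. */
public class SimpleStemmer {
    public static String stem(String word) {
        //Order matters: try longer suffixes first
        String[] suffixes = {"ing", "edly", "ed", "es", "s"};
        for (String suffix : suffixes) {
            //Keep at least a 3-letter stem to avoid mangling short words
            if (word.endsWith(suffix) && word.length() - suffix.length() >= 3) {
                return word.substring(0, word.length() - suffix.length());
            }
        }
        return word;
    }
}

Calling SimpleStemmer.stem(word) on each token just before the termToIndexMap lookup would merge inflected forms into a single term.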
2 Implementing LDA Gibbs Sampling

Once preprocessing is done we can implement LDA Gibbs Sampling. First we define the parameters we need. My implementation provides default values in the program and also supports overriding them from a configuration file; by default the program prefers the settings from the configuration file. The overall algorithm proceeds by initializing the model, then iterating the inference step, continually updating the topic assignments and the parameters to be estimated, and finally outputting the parameter estimates at convergence.
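Concretely, the quantity sampled in sampleTopicZ below is the full conditional distribution of the topic assignment $z_i$ of word $w_{m,n}$, and updateEstimatedParameters reads off the corresponding Dirichlet-smoothed estimates (see Heinrich's "Parameter estimation for text analysis" for the full derivation):

$$ p(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w}) \;\propto\; \frac{n_{k,t}^{\neg i} + \beta}{n_k^{\neg i} + V\beta} \cdot \frac{n_{m,k}^{\neg i} + \alpha}{n_m^{\neg i} + K\alpha} $$

$$ \phi_{k,t} = \frac{n_{k,t} + \beta}{n_k + V\beta}, \qquad \theta_{m,k} = \frac{n_{m,k} + \alpha}{n_m + K\alpha} $$

Here $n_{k,t}$ is the number of times term $t$ is assigned to topic $k$ (the nkt array below), $n_{m,k}$ is the number of words in document $m$ assigned to topic $k$ (nmk), $n_k$ and $n_m$ are the corresponding row sums (nktSum and nmkSum), and $\neg i$ means the counts exclude the current word's own assignment.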
The class containing the main function and the configuration-parameter parsing is as follows:
package liuyang.nlp.lda.main;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.ConstantConfig;
import liuyang.nlp.lda.conf.PathConfig;

/**Liu Yang's implementation of Gibbs Sampling of LDA
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail [email protected]
 */
public class LdaGibbsSampling {

    public static class modelparameters {
        float alpha = 0.5f; //usual value is 50 / K
        float beta = 0.1f; //usual value is 0.1
        int topicNum = 100;
        int iteration = 100;
        int saveStep = 10;
        int beginSaveIters = 50;
    }

    /**Get parameters from the configuring file. If the
     * configuring file has a value for a parameter, use it.
     * Otherwise the default value in the program is used.
     * @param ldaparameters
     * @param parameterFile
     */
    private static void getParametersFromFile(modelparameters ldaparameters, String parameterFile) {
        ArrayList<String> paramLines = new ArrayList<String>();
        FileUtil.readLines(parameterFile, paramLines);
        for (String line : paramLines) {
            String[] lineParts = line.split("\t");
            switch (parameters.valueOf(lineParts[0])) {
            case alpha:
                ldaparameters.alpha = Float.valueOf(lineParts[1]);
                break;
            case beta:
                ldaparameters.beta = Float.valueOf(lineParts[1]);
                break;
            case topicNum:
                ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
                break;
            case iteration:
                ldaparameters.iteration = Integer.valueOf(lineParts[1]);
                break;
            case saveStep:
                ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
                break;
            case beginSaveIters:
                ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
                break;
            }
        }
    }

    public enum parameters {
        alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
    }

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        String originalDocsPath = PathConfig.ldaDocsPath;
        String resultPath = PathConfig.LdaResultsPath;
        String parameterFile = ConstantConfig.LDAPARAMETERFILE;

        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        Documents docSet = new Documents();
        docSet.readDocs(originalDocsPath);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(new File(resultPath));
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        model.saveIteratedModel(ldaparameters.iteration, docSet);
        System.out.println("Done!");
    }
}
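Note that getParametersFromFile splits each line on a tab character and matches the first field against the parameters enum, so the parameter file (its path comes from ConstantConfig.LDAPARAMETERFILE) holds one tab-separated name/value pair per line. For example, a file overriding the defaults could look like this (the separator must be an actual tab; these happen to be the settings used in the experiment in section 3):

alpha	0.5
beta	0.1
topicNum	10
iteration	100
saveStep	10
beginSaveIters	80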
The LdaModel class, which implements model initialization, Gibbs sampling inference, parameter estimation, and model saving, is as follows:

package liuyang.nlp.lda.main;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.PathConfig;

/**Class for Lda model
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail [email protected]
 */
public class LdaModel {

    int[][] doc; //word index array
    int V, K, M; //vocabulary size, topic number, document number
    int[][] z; //topic label array
    float alpha; //doc-topic dirichlet prior parameter
    float beta; //topic-word dirichlet prior parameter
    int[][] nmk; //given document m, count times of topic k. M*K
    int[][] nkt; //given topic k, count times of term t. K*V
    int[] nmkSum; //Sum for each row in nmk
    int[] nktSum; //Sum for each row in nkt
    double[][] phi; //Parameters for topic-word distribution K*V
    double[][] theta; //Parameters for doc-topic distribution M*K
    int iterations; //Times of iterations
    int saveStep; //The number of iterations between two savings
    int beginSaveIters; //Begin saving the model at this iteration

    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    public void initializeModel(Documents docSet) {
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        //initialize documents index array
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            //Notice the limit of memory
            int N = docSet.docs.get(m).docWords.length;
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.docs.get(m).docWords[n];
            }
        }

        //initialize topic label z for each word
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.docs.get(m).docWords.length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int initTopic = (int) (Math.random() * K); //From 0 to K - 1
                z[m][n] = initTopic;
                //number of words in doc m assigned to topic initTopic add 1
                nmk[m][initTopic]++;
                //number of terms doc[m][n] assigned to topic initTopic add 1
                nkt[initTopic][doc[m][n]]++;
                //total number of words assigned to topic initTopic add 1
                nktSum[initTopic]++;
            }
            //total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel(Documents docSet) throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be at least " + (saveStep + beginSaveIters));
            System.exit(0);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                //Saving the model
                System.out.println("Saving model at iteration " + i + " ...");
                //Firstly update parameters
                updateEstimatedParameters();
                //Secondly print model variables
                saveIteratedModel(i, docSet);
            }

            //Use Gibbs Sampling to update z[][]
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                for (int n = 0; n < N; n++) {
                    //Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
    }

    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    private int sampleTopicZ(int m, int n) {
        //Sample from p(z_i|z_-i, w) using the Gibbs update rule

        //Remove topic label for w_{m,n}
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        //Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        //Sample a new topic label for w_{m,n} like a roulette wheel
        //Compute cumulated probability for p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; //p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        //Add new topic label for w_{m,n}
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    public void saveIteratedModel(int iters, Documents docSet) throws IOException {
        //lda.params lda.phi lda.theta lda.tassign lda.twords
        //lda.params
        String resPath = PathConfig.LdaResultsPath;
        String modelName = "lda_" + iters;
        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(resPath + modelName + ".params", lines);

        //lda.phi K*V
        BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < V; j++) {
                writer.write(phi[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.theta M*K
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < K; j++) {
                writer.write(theta[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.tassign
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < doc[m].length; n++) {
                writer.write(doc[m][n] + ":" + z[m][n] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.twords phi[][] K*V
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));
        int topNum = 20; //Find the top 20 topic words in each topic
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(j);
            }
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
            writer.write("topic " + i + "\t:\t");
            for (int t = 0; t < topNum; t++) {
                writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public class TwordsComparable implements Comparator<Integer> {

        public double[] sortProb; //Store probability of each word in topic k

        public TwordsComparable(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            //Sort topic word indices in descending order of probability in topic k
            if (sortProb[o1] > sortProb[o2]) return -1;
            else if (sortProb[o1] < sortProb[o2]) return 1;
            else return 0;
        }
    }
}
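One detail worth spelling out is how sampleTopicZ draws the new topic: it builds a cumulative sum over the unnormalised probabilities p[k] and draws a uniform variate in [0, p[K-1]), like spinning a roulette wheel whose slot widths are the weights. Here is a self-contained sketch of the same idea (RouletteSample is a hypothetical demo class, not part of the project):

import java.util.Random;

/**Standalone illustration of the roulette-wheel draw used in sampleTopicZ
 * (hypothetical demo class, not part of the Github project). */
public class RouletteSample {
    //Draw an index with probability proportional to weights[k]
    static int draw(double[] weights, Random rng) {
        double[] cumulative = new double[weights.length];
        double sum = 0.0;
        for (int k = 0; k < weights.length; k++) {
            sum += weights[k];
            cumulative[k] = sum; //running total, like p[k] += p[k - 1] above
        }
        double u = rng.nextDouble() * sum; //uniform in [0, sum)
        for (int k = 0; k < weights.length; k++) {
            if (u < cumulative[k]) return k;
        }
        return weights.length - 1; //guard against floating-point edge cases
    }

    public static void main(String[] args) {
        double[] w = {0.1, 0.6, 0.3}; //unnormalised weights
        int[] counts = new int[w.length];
        Random rng = new Random(42);
        for (int i = 0; i < 10000; i++) counts[draw(w, rng)]++;
        //Expect roughly 1000 / 6000 / 3000 draws
        System.out.println(counts[0] + " " + counts[1] + " " + counts[2]);
    }
}

Because u is scaled by the total weight, there is no need to normalise p[] first, which is exactly why the implementation above can skip the normalisation step.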
The source files under the com and conf directories hold the common utility functions and the configuration classes, respectively. The complete JAVA project is on Github: https://github.com/yangliuy/LDAGibbsSampling
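For readers who do not want to download the project right away, here is a minimal sketch of those utilities, with signatures inferred from the call sites in the code above; the actual implementations in the Github project may differ:

package liuyang.nlp.lda.com;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;

/**Minimal sketch of the utility methods used above; signatures are
 * inferred from the call sites, and the Github version may differ. */
public class FileUtil {

    //Read all lines of a file into the given list
    public static void readLines(String file, ArrayList<String> lines) {
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Write the given lines to a file, one per line
    public static void writeLines(String file, ArrayList<String> lines) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
            for (String line : lines) {
                writer.write(line);
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Create the directory if it does not exist
    public static void mkdir(File dir) {
        if (!dir.exists()) {
            dir.mkdirs();
        }
    }

    //Split a line into lower-cased word tokens and append them to words
    public static void tokenizeAndLowerCase(String line, ArrayList<String> words) {
        for (String token : line.split("\\s+")) {
            if (token.length() > 0) {
                words.add(token.toLowerCase());
            }
        }
    }
}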
3 Topic analysis of the Newsgroup 18828 collection with LDA Gibbs Sampling
Below are the results of applying the LDA Gibbs Sampling implementation above to the Newsgroup 18828 document collection. The data used in my experiments has been uploaded to Github, so interested readers can download the project and run it directly. I randomly selected 9 directories from the Newsgroup 18828 collection, picked one document from each, and placed them under the data\LdaOriginalDocs directory. The model parameters I used are:
alpha	0.5
beta	0.1
topicNum	10
iteration	100
saveStep	10
beginSaveIters	80
After 100 Gibbs sampling iterations, the program outputs, for each of the 10 topics, the top topic words together with their corresponding probabilities (written to the .twords file).
We can see that although this is unsupervised learning, the topic words LDA discovers still make a lot of sense: for example, topic 5 is about religion, topic 6 about astronomy, and topic 7 about computers. The program output also includes the model parameters in a .params file, the topic-word distribution vectors phi in a .phi file, the doc-topic distribution vectors theta in a .theta file, and the topic label assigned to every word of every document in a .tassign file. Interested readers can download the complete project from Github https://github.com/yangliuy/LDAGibbsSampling and rerun the topic analysis on other datasets. This is a preliminary implementation; if you find any problems or bugs, please get in touch and I will fix them and update the version on Github as soon as possible.