For an introduction to LDA, see the previous articles; this post is an implementation of LDA inference with Gibbs sampling. Once the sampler has converged, the resulting topics stay essentially unchanged from iteration to iteration.
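The heart of the sampler is the collapsed Gibbs updating rule. As a reading aid for sampleTopicZ and updateEstimatedParameters below, here it is in the notation of the code's count arrays (this is the standard formulation, restated here for reference rather than quoted from the code):

% Gibbs updating rule: probability of topic k for word w_{m,n}, with the current
% assignment of that word removed from all counts (superscript (-i)):
p(z_i = k \mid z_{-i}, \mathbf{w}) \propto
  \frac{n_{k,t}^{(-i)} + \beta}{n_k^{(-i)} + V\beta} \cdot
  \frac{n_{m,k}^{(-i)} + \alpha}{n_m^{(-i)} + K\alpha}

% where n_{k,t} = nkt[k][t], n_k = nktSum[k], n_{m,k} = nmk[m][k], n_m = nmkSum[m].

% Point estimates computed by updateEstimatedParameters:
\hat{\varphi}_{k,t} = \frac{n_{k,t} + \beta}{n_k + V\beta}, \qquad
\hat{\theta}_{m,k} = \frac{n_{m,k} + \alpha}{n_m + K\alpha}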
package org.jazywoo.lda;

import java.util.List;

public class Document {

    private String docName;
    private List<Integer> words; // term IDs of the words in this document

    public Document(String docName) {
        this.docName = docName;
    }

    public String getDocName() {
        return docName;
    }

    public void setDocName(String docName) {
        this.docName = docName;
    }

    public List<Integer> getWords() {
        return words;
    }

    public void setWords(List<Integer> words) {
        this.words = words;
    }
}
package org.jazywoo.lda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.jazywoo.tokenization.Tokenization;

import ICTCLAS.I3S.AC.ICTCLAS50;

public class Corpus {

    private List<Document> docs;               // the documents
    private Map<String, Integer> termIndexMap; // term -> term ID
    private List<String> terms;                // term ID -> term
    private Map<String, Integer> termCountMap; // term frequencies

    public Corpus() {
        docs = new ArrayList<Document>();
        termIndexMap = new HashMap<String, Integer>();
        terms = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    public void loadData(String path) throws IOException {
        File folder = new File(path);
        if (folder.exists()) {
            File[] files = folder.listFiles();
            for (File f : files) {
                BufferedReader br = new BufferedReader(new FileReader(f));
                String line;
                StringBuffer buf = new StringBuffer();
                while ((line = br.readLine()) != null) {
                    buf.append(line + " ");
                }
                br.close();
                addDocument("doc", buf.toString());
            }
        }
    }

    private void addDocument(String docName, String content) {
        Document document = new Document(docName);
        String[] words = getWordsFromSentence(content);
        List<Integer> wordsList = new ArrayList<Integer>();
        for (int i = 0; i < words.length; ++i) {
            String term = words[i];
            if (termIndexMap.containsKey(term)) {
                termCountMap.put(term, termCountMap.get(term) + 1);
            } else { // unseen term: assign it the next free index
                int index = termIndexMap.size();
                termIndexMap.put(term, index);
                terms.add(term);
                termCountMap.put(term, 1); // first occurrence of this term
            }
            int termID = termIndexMap.get(term);
            wordsList.add(termID);
        }
        document.setWords(wordsList);
        docs.add(document);
    }

    /**
     * Segment a sentence into words, filtering out stop words and noise.
     */
    private String[] getWordsFromSentence(String content) {
        ICTCLAS50 ictclas = new ICTCLAS50();
        Tokenization tokenization = new Tokenization(ictclas);
        boolean isOK = tokenization.init();
        String[] words = null;
        if (isOK) {
            try {
                words = tokenization.getPartedWordsWithoutSimbol(content);
            } catch (UnsupportedEncodingException e) {
                e.printStackTrace();
            }
            tokenization.finish();
        }
        return words;
    }

    public List<Document> getDocs() {
        return docs;
    }

    public void setDocs(List<Document> docs) {
        this.docs = docs;
    }

    public Map<String, Integer> getTermIndexMap() {
        return termIndexMap;
    }

    public void setTermIndexMap(Map<String, Integer> termIndexMap) {
        this.termIndexMap = termIndexMap;
    }

    public List<String> getTerms() {
        return terms;
    }

    public void setTerms(List<String> terms) {
        this.terms = terms;
    }
}
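The corpus loader depends on the native ICTCLAS50 Chinese segmenter and the org.jazywoo.tokenization.Tokenization wrapper, neither of which is listed in this post. To try the rest of the pipeline on English or pre-segmented text, a plain whitespace splitter can stand in for getWordsFromSentence. A minimal sketch; the class and method names here are made up for illustration and are not part of the original project:

package org.jazywoo.lda;

/** Hypothetical stand-in for the ICTCLAS-based tokenizer: splits on whitespace. */
public class SimpleTokenizer {

    /** Returns the whitespace-separated tokens of the input, lower-cased. */
    public String[] getWords(String content) {
        return content.trim().toLowerCase().split("\\s+");
    }
}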
package org.jazywoo.lda;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Gibbs Sampling LDA
 *
 * @author jazywoo
 */
public class LdaModel {

    private Corpus docSet;   // the corpus being modeled
    private int[][] doc;     // word index array: the term ID of each word of each document
    private int V, K, M;     // vocabulary size, topic number, document number
    private int[][] z;       // topic label array: the topic assigned to each word of each document
    private float alpha;     // doc-topic Dirichlet prior parameter
    private float beta;      // topic-word Dirichlet prior parameter
    private int[][] nmk;     // M*K: given document m, the number of its words assigned to topic k
    private int[][] nkt;     // K*V: given topic k, the number of times term t is assigned to it
    private int[] nmkSum;    // row sums of nmk: nmkSum[m] = number of words in document m
    private int[] nktSum;    // row sums of nkt: nktSum[k] = number of words assigned to topic k
    // The two latent variables theta and phi are the topic distribution of document m
    // and the word distribution of topic k; the former is a K-dimensional vector
    // (K = number of topics), the latter a V-dimensional vector (V = number of terms
    // in the dictionary).
    private double[][] theta; // M*K: parameters of the doc-topic distribution
    private double[][] phi;   // K*V: parameters of the topic-word distribution
    private int iterations;     // number of Gibbs iterations
    private int saveStep;       // number of iterations between two saves
    private int beginSaveIters; // begin saving the model at this iteration

    public LdaModel(LdaModel.ModelParameter parameter) {
        alpha = parameter.alpha;
        beta = parameter.beta;
        iterations = parameter.iteration;
        K = parameter.topicNum;
        saveStep = parameter.saveStep;
        beginSaveIters = parameter.beginSaveIters;
    }

    public void initModel(Corpus docSet) {
        this.docSet = docSet;
        M = docSet.getDocs().size();
        V = docSet.getTerms().size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        // Initialize the documents index array: for each document, the index of
        // each word in the term dictionary.
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            // Notice the limit of memory
            int N = docSet.getDocs().get(m).getWords().size();
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.getDocs().get(m).getWords().get(n);
            }
        }

        // Initialize the topic label z for each word of each document.
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.getDocs().get(m).getWords().size();
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                // Initially assign a random topic z[m][n]_old to every word.
                int initTopic = (int) (Math.random() * K); // from 0 to K - 1
                z[m][n] = initTopic;
                // number of words in doc m assigned to topic initTopic, plus 1
                nmk[m][initTopic]++;
                // number of times term doc[m][n] is assigned to topic initTopic, plus 1
                nkt[initTopic][doc[m][n]]++;
                // total number of words assigned to topic initTopic, plus 1
                nktSum[initTopic]++;
            }
            // total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel() throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be at least "
                    + (saveStep + beginSaveIters));
            System.exit(0);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                // Saving the model
                System.out.println("Saving model at iteration " + i + " ...");
                // Firstly update parameters
                updateEstimatedParameters();
                // Secondly print model variables
                saveIteratedModel(i);
            }
            // Use Gibbs sampling to update z[][], the topic label of each word
            // of each document.
            for (int m = 0; m < M; m++) {
                int N = docSet.getDocs().get(m).getWords().size();
                for (int n = 0; n < N; n++) {
                    // Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
    }

    /**
     * Initially every word is given a random topic z[m][n]_old (done in initModel).
     * We then count, for each topic z, how often each term t occurs under it, and
     * for each document m, how many of its words belong to each topic. In every
     * round we compute p(z_i|z_-i, d, w): exclude the current word's topic
     * assignment and estimate, from all the other assignments, the probability of
     * each topic for the current word. Given this distribution over all topics,
     * we sample a new topic z[m][n]_new for the word, then update the next word's
     * topic in the same way, until the per-document topic distributions and the
     * per-topic word distributions converge; the algorithm then stops and outputs
     * the estimated parameters theta and phi, and the final topic of every word
     * falls out at the same time. In practice a maximum number of iterations is
     * set. The formula for p(z_i|z_-i, d, w) computed in each round is called
     * the Gibbs updating rule.
     */
    private int sampleTopicZ(int m, int n) {
        // Sample from p(z_i|z_-i, w) using the Gibbs updating rule.

        // Remove the topic label of w_{m,n}: discount the current word's assignment.
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        // Compute p(z_i = k|z_-i, d, w): the distribution of the current word of
        // the current document over all topics.
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            // nkt: count of term doc[m][n] under topic k / nktSum: words assigned to topic k
            // nmk: count of topic k in document m / nmkSum: words in document m
            // Gibbs sampling: P(z|w,alpha,beta) = P(w,z|alpha,beta) / P(w|alpha,beta)
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta)
                 * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            // p[k] = phi[k][doc[m][n]] * theta[m][k];
        }

        // Sample a new topic label for w_{m,n} roulette-wheel style:
        // compute the cumulated probabilities in p.
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; // p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        // Add the new topic label for w_{m,n}.
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    /**
     * Estimate the doc-topic parameters theta and the topic-word parameters phi.
     * theta[m][k] is the topic distribution of document m: p(z_i|d_j) = p(z_i,d_j)/p(d_j);
     * phi[k][t] is the word distribution of topic k: p(w_i|z_j).
     */
    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                // count of term t under topic k / total words assigned to topic k
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                // count of topic k in document m / total words in document m
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    /**
     * Save the results of the analysis:
     * lda.params lda.phi lda.theta lda.tassign lda.twords
     *
     * @param iters
     * @throws IOException
     */
    public void saveIteratedModel(int iters) throws IOException {
        String resPath = "D:\\result\\";
        String modelName = "lda_" + iters;

        // lda.params
        StringBuffer buf = new StringBuffer();
        buf.append("alpha = " + alpha + "\n");
        buf.append("beta = " + beta + "\n");
        buf.append("topicNum = " + K + "\n");
        buf.append("docNum = " + M + "\n");
        buf.append("termNum = " + V + "\n");
        buf.append("iterations = " + iterations + "\n");
        buf.append("saveStep = " + saveStep + "\n");
        buf.append("beginSaveIters = " + beginSaveIters + "\n");
        BufferedWriter writer;
        // writer = new BufferedWriter(new FileWriter(resPath + modelName + ".params.txt"));
        // writer.write(buf.toString());
        // writer.close();
        //
        // // theta and phi are the topic distribution of document m and the word
        // // distribution of topic k, respectively.
        // // lda.phi K*V
        // writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi.txt"));
        // for (int i = 0; i < K; i++) {
        //     for (int j = 0; j < V; j++) {
        //         writer.write("topic-word=" + phi[i][j] + "\t");
        //     }
        //     writer.write("\n");
        // }
        // writer.close();
        //
        // // lda.theta M*K
        // writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta.txt"));
        // for (int i = 0; i < M; i++) {
        //     for (int j = 0; j < K; j++) {
        //         writer.write("doc-topic=" + theta[i][j] + "\t");
        //     }
        //     writer.write("\n");
        // }
        // writer.close();
        //
        // // doc[m][n]: the term ID of each word of each document
        // // z[m][n]: the topic assigned to each word of each document
        // writer = new BufferedWriter(new FileWriter(resPath + modelName
        //         + ".wordIndex2topicIndex.txt"));
        // for (int m = 0; m < M; m++) {
        //     for (int n = 0; n < doc[m].length; n++) {
        //         writer.write("doc[m][word]_index=" + doc[m][n] + ":"
        //                 + "z[m][word]_topicIndex=" + z[m][n] + "\t");
        //     }
        //     writer.write("\n");
        // }
        // writer.close();

        // lda.twords phi[][] K*V
        // For each topic, the 20 words with the highest probability phi[i].
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".topic_words.txt"));
        int topNum = 20;
        // Find the top 20 topic words in each topic
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>(); // word IDs of the topic
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(Integer.valueOf(j));
            }
            // sort by phi[i], i.e. by descending probability
            Collections.sort(tWordsIndexArray, new LdaModel.ArrayDoubleComparator(phi[i]));
            writer.write("topic " + i + ":\t");
            for (int t = 0; t < topNum; t++) {
                // writer.write(docSet.getTerms().get(tWordsIndexArray.get(t))
                //         + " " + phi[i][tWordsIndexArray.get(t)] + " ;\t");
                writer.write(docSet.getTerms().get(tWordsIndexArray.get(t)) + " ");
            }
            writer.write("\n");
        }
        writer.close();
    }

    /**
     * Comparator used to sort word IDs by phi[i], i.e. by how probable each
     * word is under topic k.
     *
     * @author jazywoo
     */
    public class ArrayDoubleComparator implements Comparator<Integer> {

        private double[] sortProb; // probability of each word in topic k

        public ArrayDoubleComparator(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            // Sort word indices in descending order of probability in topic k.
            if (sortProb[o1] > sortProb[o2])
                return -1;
            else if (sortProb[o1] < sortProb[o2])
                return 1;
            else
                return 0;
        }
    }

    public static class ModelParameter {
        public float alpha = 0.5f; // usual value is 50 / K
        public float beta = 0.1f;  // usual value is 0.1
        public int topicNum = 10;
        public int iteration = 100;
        public int saveStep = 10;
        public int beginSaveIters = 80;
    }
}
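Besides the top-words file, theta itself is directly useful: theta[m] is document m's distribution over the K topics, so each document's dominant topic is just the argmax of its row. A minimal sketch of that lookup; the helper class below is hypothetical (not part of the original code) and assumes theta has been exposed, e.g. via a getter, after updateEstimatedParameters() has run:

package org.jazywoo.lda;

/** Hypothetical helper: reads off each document's dominant topic from theta. */
public class ThetaUtils {

    /** Returns, for each document m, the topic k that maximises theta[m][k]. */
    public static int[] dominantTopics(double[][] theta) {
        int[] best = new int[theta.length];
        for (int m = 0; m < theta.length; m++) {
            int argmax = 0;
            for (int k = 1; k < theta[m].length; k++) {
                if (theta[m][k] > theta[m][argmax]) {
                    argmax = k;
                }
            }
            best[m] = argmax;
        }
        return best;
    }
}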
package org.jazywoo.lda;

import java.io.IOException;

public class LDATest {

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        LdaModel.ModelParameter parameter = new LdaModel.ModelParameter();
        LdaModel ldaModel = new LdaModel(parameter);
        String path = "D:\\zz";
        Corpus docSet = new Corpus();
        docSet.loadData(path);
        ldaModel.initModel(docSet);
        ldaModel.inferenceModel();
        ldaModel.saveIteratedModel(parameter.iteration);
    }
}