Gibbs Sampling实现LDA


关于LDA的介绍见前面几篇文章,这里给出用Gibbs抽样求解LDA的实现。

Gibbs Sampling实现LDA_第1张图片

从运行结果可以看到,收敛之后各主题的结果基本不再变化。

package org.jazywoo.lda;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A single document of the corpus: a name plus the ordered term IDs of its words.
 */
public class Document {
	private String docName;
	// Term ID of each word in the document, in original word order.
	private List<Integer> words;

	/**
	 * Creates a document with the given name and no words.
	 *
	 * @param docName name of the document
	 */
	public Document(String docName) {
		this.docName = docName;
		// Start with an empty list rather than null so callers can iterate
		// safely before setWords(...) has been called.
		this.words = new ArrayList<Integer>();
	}

	/** @return the document name */
	public String getDocName() {
		return docName;
	}

	/** @param docName the new document name */
	public void setDocName(String docName) {
		this.docName = docName;
	}

	/**
	 * @return the term IDs of this document's words, in order (never null
	 *         unless null was explicitly passed to {@link #setWords(List)})
	 */
	public List<Integer> getWords() {
		return words;
	}

	/**
	 * Replaces the document's word list. The list is stored by reference, not copied.
	 *
	 * @param words the term IDs of the document's words, in order
	 */
	public void setWords(List<Integer> words) {
		this.words = words;
	}
}

package org.jazywoo.lda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.jazywoo.tokenization.Tokenization;

import ICTCLAS.I3S.AC.ICTCLAS50;

/**
 * A collection of documents plus the vocabulary built while loading them.
 * Each distinct term gets a dense, 0-based ID in first-seen order.
 */
public class Corpus {
	private List<Document> docs;                // documents, in load order
	private Map<String, Integer> termIndexMap;  // term -> term ID
	private List<String> terms;                 // term ID -> term (inverse of termIndexMap)
	private Map<String, Integer> termCountMap;  // term -> total occurrences across all documents

	public Corpus() {
		docs = new ArrayList<Document>();
		termIndexMap = new HashMap<String, Integer>();
		terms = new ArrayList<String>();
		termCountMap = new HashMap<String, Integer>();
	}

	/**
	 * Loads every regular file directly inside the directory {@code path} as one
	 * document, named after the file. Lines of a file are joined with a single space
	 * before tokenization. Does nothing if {@code path} does not exist or is not a
	 * listable directory; subdirectories are skipped.
	 *
	 * @param path directory containing the text files to load
	 * @throws IOException if a file cannot be read
	 */
	public void loadData(String path) throws IOException {
		File folder = new File(path);
		if (!folder.exists()) {
			return;
		}
		File[] files = folder.listFiles();
		if (files == null) {
			// Not a directory, or an I/O error occurred while listing it.
			return;
		}
		for (File f : files) {
			if (!f.isFile()) {
				continue;
			}
			addDocument(f.getName(), readWholeFile(f));
		}
	}

	/** Reads a text file, joining its lines with a trailing space each; always closes the reader. */
	private static String readWholeFile(File f) throws IOException {
		StringBuilder buf = new StringBuilder();
		BufferedReader br = new BufferedReader(new FileReader(f));
		try {
			String line;
			while ((line = br.readLine()) != null) {
				buf.append(line).append(' ');
			}
		} finally {
			br.close();
		}
		return buf.toString();
	}

	/**
	 * Tokenizes {@code content}, registers unseen terms in the vocabulary, updates
	 * term frequencies, and appends the resulting document to {@link #getDocs()}.
	 */
	private void addDocument(String docName, String content) {
		Document document = new Document(docName);
		String[] words = getWordsFromSentence(content);
		List<Integer> wordsList = new ArrayList<Integer>(words.length);
		for (String term : words) {
			Integer termID = termIndexMap.get(term);
			if (termID == null) {
				// First occurrence: assign the next dense ID and count this occurrence.
				termID = termIndexMap.size();
				termIndexMap.put(term, termID);
				terms.add(term);
				termCountMap.put(term, 1);
			} else {
				termCountMap.put(term, termCountMap.get(term) + 1);
			}
			wordsList.add(termID);
		}
		document.setWords(wordsList);
		docs.add(document);
	}

	/**
	 * Segments {@code content} into words with the ICTCLAS-backed {@code Tokenization}
	 * helper ({@code getPartedWordsWithoutSimbol}). According to the original author's
	 * note, stop words and noise words are filtered out. NOTE(review): that filtering
	 * happens inside Tokenization, which is not visible here; confirm.
	 *
	 * @param content text to segment
	 * @return the segmented words; an empty array if the tokenizer returned null
	 * @throws IllegalStateException if the tokenizer cannot be initialized or the
	 *         content's encoding is unsupported
	 */
	private String[] getWordsFromSentence(String content) {
		ICTCLAS50 ictclas = new ICTCLAS50();
		Tokenization tokenization = new Tokenization(ictclas);
		if (!tokenization.init()) {
			throw new IllegalStateException("Failed to initialize the ICTCLAS tokenizer");
		}
		try {
			String[] words = tokenization.getPartedWordsWithoutSimbol(content);
			return words == null ? new String[0] : words;
		} catch (UnsupportedEncodingException e) {
			throw new IllegalStateException("Unsupported encoding while tokenizing document", e);
		} finally {
			// Release tokenizer resources even if segmentation fails.
			tokenization.finish();
		}
	}

	public List<Document> getDocs() {
		return docs;
	}

	public void setDocs(List<Document> docs) {
		this.docs = docs;
	}

	public Map<String, Integer> getTermIndexMap() {
		return termIndexMap;
	}

	public void setTermIndexMap(Map<String, Integer> termIndexMap) {
		this.termIndexMap = termIndexMap;
	}

	public List<String> getTerms() {
		return terms;
	}

	public void setTerms(List<String> terms) {
		this.terms = terms;
	}

	/** @return term -> number of occurrences across all loaded documents */
	public Map<String, Integer> getTermCountMap() {
		return termCountMap;
	}
}

package org.jazywoo.lda;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;


/**Gibbs Sampling LDA
 * @author jazywoo
 *
 */
public class LdaModel {
	/** Output directory prefix used by {@link #saveIteratedModel(int)}. */
	private static final String DEFAULT_RESULT_PATH = "D:\\result\\";
	/** Maximum number of highest-probability words written per topic. */
	private static final int TOP_WORDS_PER_TOPIC = 20;

	private Corpus docSet;  // corpus being modelled
	private int[][] doc;    // doc[m][n]: term ID of the n-th word of document m
	private int V, K, M;    // vocabulary size, topic number, document number
	private int[][] z;      // z[m][n]: topic currently assigned to the n-th word of document m
	private float alpha;    // doc-topic Dirichlet prior parameter
	private float beta;     // topic-word Dirichlet prior parameter
	private int[][] nmk;    // M*K: number of words in document m assigned to topic k
	private int[][] nkt;    // K*V: number of times term t is assigned to topic k
	private int[] nmkSum;   // row sums of nmk: number of words in document m
	private int[] nktSum;   // row sums of nkt: number of words assigned to topic k

	// theta[m] is the topic distribution of document m (K entries);
	// phi[k] is the word distribution of topic k (V entries).
	private double[][] theta; // M*K doc-topic distribution
	private double[][] phi;   // K*V topic-word distribution

	private int iterations;     // total number of Gibbs sweeps
	private int saveStep;       // number of iterations between two saves
	private int beginSaveIters; // first iteration at which the model is saved

	/**
	 * Creates a model from the given hyper-parameters.
	 *
	 * @param parameter hyper-parameters; {@code topicNum} must be positive
	 * @throws IllegalArgumentException if {@code parameter.topicNum <= 0}
	 */
	public LdaModel(LdaModel.ModelParameter parameter) {
		if (parameter.topicNum <= 0) {
			throw new IllegalArgumentException("topicNum must be positive, got " + parameter.topicNum);
		}
		alpha = parameter.alpha;
		beta = parameter.beta;
		iterations = parameter.iteration;
		K = parameter.topicNum;
		saveStep = parameter.saveStep;
		beginSaveIters = parameter.beginSaveIters;
	}

	/**
	 * Allocates the count matrices for {@code docSet1}, copies each document's term
	 * IDs into {@code doc}, and assigns every word a uniformly random initial topic.
	 *
	 * @param docSet1 the loaded corpus
	 */
	public void initModal(Corpus docSet1) {
		this.docSet = docSet1;
		M = docSet.getDocs().size();
		V = docSet.getTerms().size();
		nmk = new int[M][K];
		nkt = new int[K][V];
		nmkSum = new int[M];
		nktSum = new int[K];
		phi = new double[K][V];
		theta = new double[M][K];
		doc = new int[M][];
		z = new int[M][];
		for (int m = 0; m < M; m++) {
			List<Integer> words = docSet.getDocs().get(m).getWords();
			int N = words.size();
			doc[m] = new int[N];
			z[m] = new int[N];
			for (int n = 0; n < N; n++) {
				doc[m][n] = words.get(n);
				int initTopic = (int) (Math.random() * K); // 0 .. K-1
				z[m][n] = initTopic;
				nmk[m][initTopic]++;
				nkt[initTopic][doc[m][n]]++;
				nktSum[initTopic]++;
			}
			nmkSum[m] = N;
		}
	}

	/**
	 * Runs {@code iterations} Gibbs sweeps over all words. From iteration
	 * {@code beginSaveIters} on, every {@code saveStep}-th iteration refreshes
	 * theta/phi and saves the model via {@link #saveIteratedModel(int)} before
	 * sampling. After the last sweep theta/phi are refreshed once more, so a
	 * subsequent save reflects the final state.
	 *
	 * @throws IllegalStateException if {@code saveStep <= 0} or
	 *         {@code iterations < saveStep + beginSaveIters}
	 * @throws IOException if saving an intermediate model fails
	 */
	public void inferenceModel() throws IOException {
		if (saveStep <= 0) {
			throw new IllegalStateException("Error: saveStep must be positive, got " + saveStep);
		}
		if (iterations < saveStep + beginSaveIters) {
			throw new IllegalStateException("Error: the number of iterations should be at least "
					+ (saveStep + beginSaveIters));
		}
		for (int i = 0; i < iterations; i++) {
			System.out.println("Iteration " + i);
			if (i >= beginSaveIters && (i - beginSaveIters) % saveStep == 0) {
				System.out.println("Saving model at iteration " + i + " ... ");
				updateEstimatedParameters();
				saveIteratedModel(i);
			}
			// One Gibbs sweep: resample the topic of every word.
			for (int m = 0; m < M; m++) {
				for (int n = 0; n < doc[m].length; n++) {
					z[m][n] = sampleTopicZ(m, n);
				}
			}
		}
		// Without this, theta/phi would still hold the values from the last periodic save.
		updateEstimatedParameters();
	}

	/**
	 * Resamples the topic of word n in document m from its full conditional.
	 * The word's current assignment is removed from the counts. Then
	 * p(z = k | rest) is computed proportional to
	 * (nkt[k][t] + beta) / (nktSum[k] + V*beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K*alpha),
	 * a topic is drawn by roulette-wheel selection, and the counts are updated for it.
	 *
	 * @param m document index
	 * @param n word position within document m
	 * @return the newly sampled topic (0 .. K-1)
	 */
	private int sampleTopicZ(int m, int n) {
		int term = doc[m][n];

		// Remove the current assignment of this word from all counts.
		int oldTopic = z[m][n];
		nmk[m][oldTopic]--;
		nkt[oldTopic][term]--;
		nmkSum[m]--;
		nktSum[oldTopic]--;

		// Unnormalised full conditional for each topic.
		double[] p = new double[K];
		for (int k = 0; k < K; k++) {
			p[k] = (nkt[k][term] + beta) / (nktSum[k] + V * beta)
					* (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
		}

		// Roulette-wheel selection over the cumulative (unnormalised) distribution.
		for (int k = 1; k < K; k++) {
			p[k] += p[k - 1];
		}
		double u = Math.random() * p[K - 1];
		int newTopic;
		// Stop at K-1: if rounding makes u == p[K-1], fall back to the last topic
		// instead of running past the end of the array.
		for (newTopic = 0; newTopic < K - 1; newTopic++) {
			if (u < p[newTopic]) {
				break;
			}
		}

		// Add the new assignment back into the counts.
		nmk[m][newTopic]++;
		nkt[newTopic][term]++;
		nmkSum[m]++;
		nktSum[newTopic]++;
		return newTopic;
	}

	/**
	 * Recomputes the smoothed point estimates from the current counts:
	 * phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V*beta) (word distribution of topic k), and
	 * theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K*alpha) (topic distribution of document m).
	 */
	private void updateEstimatedParameters() {
		for (int k = 0; k < K; k++) {
			for (int t = 0; t < V; t++) {
				phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
			}
		}
		for (int m = 0; m < M; m++) {
			for (int k = 0; k < K; k++) {
				theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
			}
		}
	}

	/**
	 * Writes the top words of every topic to {@code D:\result\lda_<iters>.topic_words.txt}.
	 *
	 * @param iters iteration number used in the output file name
	 * @throws IOException if the file cannot be written
	 */
	public void saveIteratedModel(int iters) throws IOException {
		saveIteratedModel(iters, DEFAULT_RESULT_PATH);
	}

	/**
	 * Writes, for every topic, its (up to) 20 most probable terms according to the
	 * current phi. The output goes to {@code resPath + "lda_" + iters + ".topic_words.txt"},
	 * one line per topic.
	 *
	 * @param iters   iteration number used in the output file name
	 * @param resPath output prefix; concatenated directly with the file name, so it
	 *                should end with a path separator
	 * @throws IOException if the file cannot be written
	 */
	public void saveIteratedModel(int iters, String resPath) throws IOException {
		String modelName = "lda_" + iters;
		List<String> terms = docSet.getTerms();
		// The vocabulary may be smaller than the requested number of top words.
		int topNum = Math.min(TOP_WORDS_PER_TOPIC, V);
		BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName
				+ ".topic_words.txt"));
		try {
			for (int i = 0; i < K; i++) {
				List<Integer> tWordsIndexArray = new ArrayList<Integer>(V); // term IDs of topic i
				for (int j = 0; j < V; j++) {
					tWordsIndexArray.add(Integer.valueOf(j));
				}
				// Sort term IDs by descending probability under topic i.
				Collections.sort(tWordsIndexArray, new ArrayDoubleComparator(phi[i]));
				writer.write("topic " + i + ":\t");
				for (int t = 0; t < topNum; t++) {
					writer.write(terms.get(tWordsIndexArray.get(t)) + " ");
				}
				writer.write("\n");
			}
		} finally {
			writer.close();
		}
	}

	/**
	 * Orders indices into a probability array by descending probability.
	 * Static because it needs no access to the enclosing model instance.
	 */
	public static class ArrayDoubleComparator implements Comparator<Integer> {
		private double[] sortProb; // probability of each term under one topic

		public ArrayDoubleComparator(double[] sortProb) {
			this.sortProb = sortProb;
		}

		@Override
		public int compare(Integer o1, Integer o2) {
			// Arguments reversed for descending order; Double.compare is a total
			// order, so the Comparator contract holds even for NaN.
			return Double.compare(sortProb[o2], sortProb[o1]);
		}
	}

	/** Hyper-parameters and schedule for {@link LdaModel}. */
	public static class ModelParameter {
		public float alpha = 0.5f;     // usual value is 50 / K
		public float beta = 0.1f;      // usual value is 0.1
		public int topicNum = 10;      // number of topics K
		public int iteration = 100;    // total Gibbs sweeps
		public int saveStep = 10;      // iterations between two saves
		public int beginSaveIters = 80; // first iteration at which the model is saved
	}
}

package org.jazywoo.lda;

import java.io.IOException;

public class LDATest {

	/**
	 * Loads a corpus from a directory, trains an LDA model on it with the default
	 * hyper-parameters, and saves the model once more after training, labelled
	 * with the total iteration count.
	 *
	 * @param args optional; {@code args[0]} overrides the default input directory {@code D:\zz}
	 * @throws IOException if a corpus file cannot be read or a result file cannot be written
	 */
	public static void main(String[] args) throws IOException {
		String path = args.length > 0 ? args[0] : "D:\\zz";

		LdaModel.ModelParameter parameter = new LdaModel.ModelParameter();
		LdaModel ldaModel = new LdaModel(parameter);

		Corpus docSet = new Corpus();
		docSet.loadData(path);
		ldaModel.initModal(docSet);
		ldaModel.inferenceModel();
		ldaModel.saveIteratedModel(parameter.iteration);
	}

}






你可能感兴趣的:(Gibbs Sampling实现LDA)