LDA模型学习之(三)走过的弯路

     为了把LDA算法用于文本聚类,我真的是绞尽脑汁。除了去看让我头大的概率论、随机过程、高数这些基础的数学知识,还到网上找已经实现的源代码。

     最先让我看到署光的是Mallet,我研究了大概一个星期,最后决定放弃了。因为Mallet作者提供的例子实在太少了。

      回到了网上找到的这样一段源代码:

 
   
  1. /*  
  2.  * (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net) (This file is  
  3.  * part of the org.knowceans experimental software packages.)  
  4.  */ 
  5. /*  
  6.  * LdaGibbsSampler is free software; you can redistribute it and/or modify it  
  7.  * under the terms of the GNU General Public License as published by the Free  
  8.  * Software Foundation; either version 2 of the License, or (at your option) any  
  9.  * later version.  
  10.  */ 
  11. /*  
  12.  * LdaGibbsSampler is distributed in the hope that it will be useful, but  
  13.  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or  
  14.  * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more  
  15.  * details.  
  16.  */ 
  17. /*  
  18.  * You should have received a copy of the GNU General Public License along with  
  19.  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple  
  20.  * Place, Suite 330, Boston, MA 02111-1307 USA  
  21.  */ 
  22.  
  23. /*  
  24.  * Created on Mar 6, 2005  
  25.  */ 
  26. package com.xh.lda;  
  27.  
  28. import java.text.DecimalFormat;  
  29. import java.text.NumberFormat;  
  30.  
  31. /**  
  32.  * Gibbs sampler for estimating the best assignments of topics for words and  
  33.  * documents in a corpus. The algorithm is introduced in Tom Griffiths' paper  
  34.  * "Gibbs sampling in the generative model of Latent Dirichlet Allocation"  
  35.  * (2002).  
  36.  *   
  37.  * @author heinrich  
  38.  */ 
  39. public class LdaGibbsSampler {  
  40.  
  41.     /**  
  42.      * document data (term lists)  
  43.      */ 
  44.     int[][] documents;  
  45.  
  46.     /**  
  47.      * vocabulary size  
  48.      */ 
  49.     int V;  
  50.  
  51.     /**  
  52.      * number of topics  
  53.      */ 
  54.     int K;  
  55.  
  56.     /**  
  57.      * Dirichlet parameter (document--topic associations)  
  58.      */ 
  59.     double alpha;  
  60.  
  61.     /**  
  62.      * Dirichlet parameter (topic--term associations)  
  63.      */ 
  64.     double beta;  
  65.  
  66.     /**  
  67.      * topic assignments for each word.  
  68.      */ 
  69.     int z[][];  
  70.  
  71.     /**  
  72.      * cwt[i][j] number of instances of word i (term?) assigned to topic j.  
  73.      */ 
  74.     int[][] nw;  
  75.  
  76.     /**  
  77.      * na[i][j] number of words in document i assigned to topic j.  
  78.      */ 
  79.     int[][] nd;  
  80.  
  81.     /**  
  82.      * nwsum[j] total number of words assigned to topic j.  
  83.      */ 
  84.     int[] nwsum;  
  85.  
  86.     /**  
  87.      * nasum[i] total number of words in document i.  
  88.      */ 
  89.     int[] ndsum;  
  90.  
  91.     /**  
  92.      * cumulative statistics of theta  
  93.      */ 
  94.     double[][] thetasum;  
  95.  
  96.     /**  
  97.      * cumulative statistics of phi  
  98.      */ 
  99.     double[][] phisum;  
  100.  
  101.     /**  
  102.      * size of statistics  
  103.      */ 
  104.     int numstats;  
  105.  
  106.     /**  
  107.      * sampling lag (?)  
  108.      */ 
  109.     private static int THIN_INTERVAL = 20;  
  110.  
  111.     /**  
  112.      * burn-in period  
  113.      */ 
  114.     private static int BURN_IN = 100;  
  115.  
  116.     /**  
  117.      * max iterations  
  118.      */ 
  119.     private static int ITERATIONS = 1000;  
  120.  
  121.     /**  
  122.      * sample lag (if -1 only one sample taken)  
  123.      */ 
  124.     private static int SAMPLE_LAG;  
  125.  
  126.     private static int dispcol = 0;  
  127.  
  128.     /**  
  129.      * Initialise the Gibbs sampler with data.  
  130.      *   
  131.      * @param V  
  132.      *            vocabulary size  
  133.      * @param data  
  134.      */ 
  135.     public LdaGibbsSampler(int[][] documents, int V) {  
  136.  
  137.         this.documents = documents;  
  138.         this.V = V;  
  139.     }  
  140.  
  141.     /**  
  142.      * Initialisation: Must start with an assignment of observations to topics ?  
  143.      * Many alternatives are possible, I chose to perform random assignments  
  144.      * with equal probabilities  
  145.      *   
  146.      * @param K  
  147.      *            number of topics  
  148.      * @return z assignment of topics to words  
  149.      */ 
  150.     public void initialState(int K) {  
  151.         int i;  
  152.  
  153.         int M = documents.length;  
  154.  
  155.         // initialise count variables.  
  156.         nw = new int[V][K];  
  157.         nd = new int[M][K];  
  158.         nwsum = new int[K];  
  159.         ndsum = new int[M];  
  160.  
  161.         // The z_i are are initialised to values in [1,K] to determine the  
  162.         // initial state of the Markov chain.  
  163.  
  164.         z = new int[M][];  
  165.         for (int m = 0; m < M; m++) {  
  166.             int N = documents[m].length;  
  167.             z[m] = new int[N];  
  168.             for (int n = 0; n < N; n++) {  
  169.                 int topic = (int) (Math.random() * K);  
  170.                 z[m][n] = topic;  
  171.                 // number of instances of word i assigned to topic j  
  172.                 nw[documents[m][n]][topic]++;  
  173.                 // number of words in document i assigned to topic j.  
  174.                 nd[m][topic]++;  
  175.                 // total number of words assigned to topic j.  
  176.                 nwsum[topic]++;  
  177.             }  
  178.             // total number of words in document i  
  179.             ndsum[m] = N;  
  180.         }  
  181.     }  
  182.  
  183.     /**  
  184.      * Main method: Select initial state ? Repeat a large number of times: 1.  
  185.      * Select an element 2. Update conditional on other elements. If  
  186.      * appropriate, output summary for each run.  
  187.      *   
  188.      * @param K  
  189.      *            number of topics  
  190.      * @param alpha  
  191.      *            symmetric prior parameter on document--topic associations  
  192.      * @param beta  
  193.      *            symmetric prior parameter on topic--term associations  
  194.      */ 
  195.     public void gibbs(int K, double alpha, double beta) {  
  196.         this.K = K;  
  197.         this.alpha = alpha;  
  198.         this.beta = beta;  
  199.  
  200.         // init sampler statistics  
  201.         if (SAMPLE_LAG > 0) {  
  202.             thetasum = new double[documents.length][K];  
  203.             phisum = new double[K][V];  
  204.             numstats = 0;  
  205.         }  
  206.  
  207.         // initial state of the Markov chain:  
  208.         initialState(K);  
  209.  
  210.         System.out.println("Sampling " + ITERATIONS  
  211.             + " iterations with burn-in of " + BURN_IN + " (B/S=" 
  212.             + THIN_INTERVAL + ").");  
  213.  
  214.         for (int i = 0; i < ITERATIONS; i++) {  
  215.  
  216.             // for all z_i  
  217.             for (int m = 0; m < z.length; m++) {  
  218.                 for (int n = 0; n < z[m].length; n++) {  
  219.  
  220.                     // (z_i = z[m][n])  
  221.                     // sample from p(z_i|z_-i, w)  
  222.                     int topic = sampleFullConditional(m, n);  
  223.                     z[m][n] = topic;  
  224.                 }  
  225.             }  
  226.  
  227.             if ((i < BURN_IN) && (i % THIN_INTERVAL == 0)) {  
  228. //                System.out.print("B");  
  229.                 dispcol++;  
  230.             }  
  231.             // display progress  
  232.             if ((i > BURN_IN) && (i % THIN_INTERVAL == 0)) {  
  233. //                System.out.print("S");  
  234.                 dispcol++;  
  235.             }  
  236.             // get statistics after burn-in  
  237.             if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0)) {  
  238.                 updateParams();  
  239. //                System.out.print("|");  
  240.                 if (i % THIN_INTERVAL != 0)  
  241.                     dispcol++;  
  242.             }  
  243.             if (dispcol >= 100) {  
  244. //                System.out.println();  
  245.                 dispcol = 0;  
  246.             }  
  247.         }  
  248.     }  
  249.  
  250.     /**  
  251.      * Sample a topic z_i from the full conditional distribution: p(z_i = j |  
  252.      * z_-i, w) = (n_-i,j(w_i) + beta)/(n_-i,j(.) + W * beta) * (n_-i,j(d_i) +  
  253.      * alpha)/(n_-i,.(d_i) + K * alpha)  
  254.      *   
  255.      * @param m  
  256.      *            document  
  257.      * @param n  
  258.      *            word  
  259.      */ 
  260.     private int sampleFullConditional(int m, int n) {  
  261.  
  262.         // remove z_i from the count variables  
  263.         int topic = z[m][n];  
  264.         nw[documents[m][n]][topic]--;  
  265.         nd[m][topic]--;  
  266.         nwsum[topic]--;  
  267.         ndsum[m]--;  
  268.  
  269.         // do multinomial sampling via cumulative method:  
  270.         double[] p = new double[K];  
  271.         for (int k = 0; k < K; k++) {  
  272.             p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)  
  273.                 * (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
  274.         }  
  275.         // cumulate multinomial parameters  
  276.         for (int k = 1; k < p.length; k++) {  
  277.             p[k] += p[k - 1];  
  278.         }  
  279.         // scaled sample because of unnormalised p[]  
  280.         double u = Math.random() * p[K - 1];  
  281.         for (topic = 0; topic < p.length; topic++) {  
  282.             if (u < p[topic])  
  283.                 break;  
  284.         }  
  285.  
  286.         // add newly estimated z_i to count variables  
  287.         nw[documents[m][n]][topic]++;  
  288.         nd[m][topic]++;  
  289.         nwsum[topic]++;  
  290.         ndsum[m]++;  
  291.  
  292.         return topic;  
  293.     }  
  294.  
  295.     /**  
  296.      * Add to the statistics the values of theta and phi for the current state.  
  297.      */ 
  298.     private void updateParams() {  
  299.         for (int m = 0; m < documents.length; m++) {  
  300.             for (int k = 0; k < K; k++) {  
  301.                 thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
  302.             }  
  303.         }  
  304.         for (int k = 0; k < K; k++) {  
  305.             for (int w = 0; w < V; w++) {  
  306.                 phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);  
  307.             }  
  308.         }  
  309.         numstats++;  
  310.     }  
  311.  
  312.     /**  
  313.      * Retrieve estimated document--topic associations. If sample lag > 0 then  
  314.      * the mean value of all sampled statistics for theta[][] is taken.  
  315.      *   
  316.      * @return theta multinomial mixture of document topics (M x K)  
  317.      */ 
  318.     public double[][] getTheta() {  
  319.         double[][] theta = new double[documents.length][K];  
  320.  
  321.         if (SAMPLE_LAG > 0) {  
  322.             for (int m = 0; m < documents.length; m++) {  
  323.                 for (int k = 0; k < K; k++) {  
  324.                     theta[m][k] = thetasum[m][k] / numstats;  
  325.                 }  
  326.             }  
  327.  
  328.         } else {  
  329.             for (int m = 0; m < documents.length; m++) {  
  330.                 for (int k = 0; k < K; k++) {  
  331.                     theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);  
  332.                 }  
  333.             }  
  334.         }  
  335.  
  336.         return theta;  
  337.     }  
  338.  
  339.     /**  
  340.      * Retrieve estimated topic--word associations. If sample lag > 0 then the  
  341.      * mean value of all sampled statistics for phi[][] is taken.  
  342.      *   
  343.      * @return phi multinomial mixture of topic words (K x V)  
  344.      */ 
  345.     public double[][] getPhi() {  
  346.         System.out.println("K is:"+K+",V is:"+V);  
  347.         double[][] phi = new double[K][V];  
  348.         if (SAMPLE_LAG > 0) {  
  349.             for (int k = 0; k < K; k++) {  
  350.                 for (int w = 0; w < V; w++) {  
  351.                     phi[k][w] = phisum[k][w] / numstats;  
  352.                 }  
  353.             }  
  354.         } else {  
  355.             for (int k = 0; k < K; k++) {  
  356.                 for (int w = 0; w < V; w++) {  
  357.                     phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);  
  358.                 }  
  359.             }  
  360.         }  
  361.         return phi;  
  362.     }  
  363.     /**  
  364.      * Configure the gibbs sampler  
  365.      *   
  366.      * @param iterations  
  367.      *            number of total iterations  
  368.      * @param burnIn  
  369.      *            number of burn-in iterations  
  370.      * @param thinInterval  
  371.      *            update statistics interval  
  372.      * @param sampleLag  
  373.      *            sample interval (-1 for just one sample at the end)  
  374.      */ 
  375.     public void configure(int iterations, int burnIn, int thinInterval,  
  376.         int sampleLag) {  
  377.         ITERATIONS = iterations;  
  378.         BURN_IN = burnIn;  
  379.         THIN_INTERVAL = thinInterval;  
  380.         SAMPLE_LAG = sampleLag;  
  381.     }  
  382.  
  383.     /**  
  384.      * Driver with example data.  
  385.      *   
  386.      * @param args  
  387.      */ 
  388.     public static void main(String[] args) {  
  389.  
  390.         // words in documents  
  391.         int[][] documents = {   
  392.             {1432314323143236},  
  393.             {224242222422},  
  394.             {1656016560165600},  
  395.             {56623365622656660},  
  396.             {224444155555511110},  
  397.             {542345665432},  
  398.             
  399.             };  
  400.           
  401.  
  402.         // vocabulary  
  403.         int V = 7;  
  404.         int M = documents.length;  
  405.         // # topics  
  406.         int K = 2;  
  407.         // good values alpha = 2, beta = .5  
  408.         double alpha = 2;  
  409.         double beta = .5;  
  410.  
  411.         System.out.println("Latent Dirichlet Allocation using Gibbs Sampling.");  
  412.  
  413.         LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);  
  414.         lda.configure(10000200010010);  
  415.         lda.gibbs(K, alpha, beta);//用gibbs抽样  
  416.  
  417.         double[][] theta = lda.getTheta();//Theta是我们所希望的一种分布可能  
  418.         double[][] phi = lda.getPhi();  
  419.  
  420.         System.out.println();  
  421.         System.out.println();  
  422.         System.out.println("Document--Topic Associations, Theta[d][k] (alpha=" 
  423.             + alpha + ")");  
  424.         System.out.print("d\\k\t");  
  425.         for (int m = 0; m < theta[0].length; m++) {  
  426.             System.out.print("   " + m % 10 + "    ");  
  427.         }  
  428.         System.out.println();  
  429.         for (int m = 0; m < theta.length; m++) {  
  430.             System.out.print(m + "\t");  
  431.             for (int k = 0; k < theta[m].length; k++) {  
  432.                  System.out.print(theta[m][k] + " ");  
  433. //                System.out.print(shadeDouble(theta[m][k], 1) + " ");  
  434.             }  
  435.             System.out.println();  
  436.         }  
  437.         System.out.println();  
  438.         System.out.println("Topic--Term Associations, Phi[k][w] (beta=" + beta  
  439.             + ")");  
  440.  
  441.         System.out.print("k\\w\t");  
  442.         for (int w = 0; w < phi[0].length; w++) {  
  443.             System.out.print("   " + w % 10 + "    ");  
  444.         }  
  445.         System.out.println();  
  446.         for (int k = 0; k < phi.length; k++) {  
  447.             System.out.print(k + "\t");  
  448.             for (int w = 0; w < phi[k].length; w++) {  
  449.                  System.out.print(phi[k][w] + " ");  
  450. //                System.out.print(shadeDouble(phi[k][w], 1) + " ");  
  451.             }  
  452.             System.out.println();  
  453.         }  
  454.     }  
  455.    
  456. }  
  457.  

   代码中关于数学部分我现在依然没有弄懂,但是先能用着再说吧。

   // vocabulary
        int V = 7;// 表示所有的文档中词汇的总数为7
        int M = documents.length;//表示文档的总个数
        // # topics
        int K = 2;//如果用于聚类,表示类簇的个数:主题的个数
        // good values alpha = 2, beta = .5

下面两个是LDA模型的参数,可以先不用管。
        double alpha = 2;
        double beta = .5;

我用的做法是:文本分词后对词进行统计,然后给词编号。这样就可以把文档

转化成了document矩阵了!

你可能感兴趣的:(LDA模型学习之(三)走过的弯路)