public class LdaGibbsSampling {

    public static class modelparameters {
        float alpha = 0.5f;   // usual value is 50 / K
        float beta = 0.1f;    // usual value is 0.1
        int topicNum = 100;
        int iteration = 100;
        int saveStep = 10;
        int beginSaveIters = 50;
    }

    /** Get parameters from the configuration file. If the
     *  configuration file contains a value, use it;
     *  otherwise the default value in the program is used.
     * @param ldaparameters
     * @param parameterFile
     */
    private static void getParametersFromFile(modelparameters ldaparameters, String parameterFile) {
        ArrayList<String> paramLines = FileUtil.readList(parameterFile);
        for (String line : paramLines) {
            String[] lineParts = line.split("\t");
            switch (parameters.valueOf(lineParts[0])) {
            case alpha:
                ldaparameters.alpha = Float.valueOf(lineParts[1]);
                break;
            case beta:
                ldaparameters.beta = Float.valueOf(lineParts[1]);
                break;
            case topicNum:
                ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
                break;
            case iteration:
                ldaparameters.iteration = Integer.valueOf(lineParts[1]);
                break;
            case saveStep:
                ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
                break;
            case beginSaveIters:
                ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
                break;
            }
        }
    }

    public enum parameters {
        alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
    }

    /**
     * Train the LDA topic model, predict a topic for each sample in the given test set,
     * and collect the top 20 words of each sample's most probable topic as the keyword
     * set that represents the test set.
     * @param trainPathDir
     * @param parameterFile
     * @param resultPath
     * @param testPath
     * @return
     * @throws IOException
     */
    public Set<Word> trainAndPredictLDA(String trainPathDir, String parameterFile,
            String resultPath, String testPath) throws IOException {
        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        Documents docSet = new Documents();
        docSet.readDocs(trainPathDir);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(resultPath);
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        // model.saveIteratedModel(ldaparameters.iteration, docSet);
        // System.out.println("Done!");

        // Predict new texts
        Documents testDocs = new Documents();
        List<Message> messages = FileUtil.readMessageFromFile(testPath);
        Set<Integer> topicIndexSet = new HashSet<Integer>();
        for (Message message : messages) {
            String content = message.getContent();
            Document doc = new Document(content);
            testDocs.docs.add(doc);
            topicIndexSet.add(model.predictNewSampleTopic(doc));
        }
        /**
         * Predict each message to get its most probable topic, then take the top 20 words
         * of every such topic as a set and compute TF-IDF weights for them.
         */
        Set<Word> wordSet = model.getWordByTopics(topicIndexSet, 20);
        LDAFeatureProcess.calTFIDFAsWeight(docSet, wordSet);
        return wordSet;
    }

    @Test
    public void test() throws IOException {
        String resultPath = "ldaResult/";
        String parameterFile = "source/lda_parameters.txt";
        String trainPathDir = "LDATrain/";
        String testPath = "train/train_messages.txt";
        Set<Word> wordSet = trainAndPredictLDA(trainPathDir, parameterFile, resultPath, testPath);
        FileUtil.writeKeyWordFile("ldaWords/keyWords.doc", new ArrayList<Word>(wordSet));
    }

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        String resultPath = "ldaResult/";
        String parameterFile = "source/lda_parameters.txt";
        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        String dirPath = "LDATrain/";
        Documents docSet = new Documents();
        docSet.readDocs(dirPath);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(resultPath);
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        model.saveIteratedModel(ldaparameters.iteration, docSet);
        System.out.println("Done!");

        // Predict a new text (sample SMS message)
        String messStr = "好消息!!薇町婚纱造型推出老带新活动啦!已在本店预定的新娘推荐新顾客来本店,定单后即赠送新、老顾客各一支价值58元定妆隔离水(在婚礼当";
        Document doc = new Document(messStr);
        int topicIndex = model.predictNewSampleTopic(doc);
        Set<Word> wordSet = model.getWordByTopic(topicIndex);
        FileUtil.writeKeyWordFile("ldaWords/comparedkeyWords.doc", new ArrayList<Word>(wordSet));
    }
}
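For reference, getParametersFromFile expects each line of the parameter file (source/lda_parameters.txt in the code above) to hold a parameter name and its value separated by a tab, with names matching the parameters enum. A minimal sketch of such a file, using the default values from modelparameters (the actual contents of the original file are not shown here, so these values are only illustrative):

alpha	0.5
beta	0.1
topicNum	100
iteration	100
saveStep	10
beginSaveIters	50

Note that inferenceModel below requires iteration >= beginSaveIters + saveStep, otherwise the program exits with an error.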
public class LdaModel {

    int[][] doc;        // word index array
    int V, K, M;        // vocabulary size, topic number, document number
    int[][] z;          // topic label array
    float alpha;        // doc-topic dirichlet prior parameter
    float beta;         // topic-word dirichlet prior parameter
    int[][] nmk;        // given document m, count times of topic k. M*K
    int[][] nkt;        // given topic k, count times of term t. K*V
    int[] nmkSum;       // Sum for each row in nmk
    int[] nktSum;       // Sum for each row in nkt
    double[][] phi;     // Parameters for topic-word distribution K*V
    double[][] theta;   // Parameters for doc-topic distribution M*K
    int iterations;     // Times of iterations
    int saveStep;       // The number of iterations between two savings
    int beginSaveIters; // Begin saving the model at this iteration
    Map<String, Integer> wordIndexMap;
    Documents docSet;

    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    public void initializeModel(Documents docSet) {
        this.docSet = docSet;
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];
        this.wordIndexMap = new HashMap<String, Integer>();

        // initialize documents index array
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            // Notice the limit of memory
            int N = docSet.docs.get(m).docWords.length;
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.docs.get(m).docWords[n];
            }
        }

        // initialize topic label z for each word
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.docs.get(m).docWords.length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                // random initialization
                int initTopic = (int) (Math.random() * K); // From 0 to K - 1
                z[m][n] = initTopic;
                // number of words in doc m assigned to topic initTopic add 1
                nmk[m][initTopic]++;
                // number of terms doc[m][n] assigned to topic initTopic add 1
                nkt[initTopic][doc[m][n]]++;
                // total number of words assigned to topic initTopic add 1
                nktSum[initTopic]++;
            }
            // total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel(Documents docSet) throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be larger than "
                    + (saveStep + beginSaveIters));
            System.exit(0);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                // Saving the model
                System.out.println("Saving model at iteration " + i + " ...");
                // Firstly update parameters
                updateEstimatedParameters();
                // Secondly print model variables
                saveIteratedModel(i, docSet);
            }
            // Use Gibbs Sampling to update z[][]
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                for (int n = 0; n < N; n++) {
                    // Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
    }

    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    private int sampleTopicZ(int m, int n) {
        // Sample from p(z_i|z_-i, w) using the Gibbs update rule

        // Remove topic label for w_{m,n}
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        // Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta)
                 * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        // Sample a new topic label for w_{m, n} like roulette
        // Compute cumulated probability for p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; // p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        // Add new topic label for w_{m, n}
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    /**
     * For a given text to be predicted, map each word of its segmentation result to the
     * corresponding document and word index in the training set.
     * @param predictWordSet
     * @return
     */
    public Map<String, String> matchTermIndex(Set<Word> predictWordSet) {
        /**
         * key: the word content; value: "documentIndex-wordIndex", e.g. "1-2"
         */
        Map<String, String> wordIndexMap = new HashMap<String, String>();
        for (Word word : predictWordSet) {
            String content = word.getContent();
            String indexStr = getTermIndex(content);
            wordIndexMap.put(content, indexStr);
        }
        return wordIndexMap;
    }

    /**
     * For a given word, find the document index and word index it corresponds to
     * in the training set.
     * @param content
     * @return
     */
    public String getTermIndex(String content) {
        for (Integer m : docSet.getDocWordsList().keySet()) {
            LinkedList<String> list = docSet.getDocWordsList().get(m);
            for (int i = 0; i < list.size(); i++) {
                if (list.get(i).equals(content))
                    return m + "-" + i;
            }
        }
        return "none";
    }

    /**
     * After the LDA model has been trained, return the union of the top topNum words
     * of every topic in the given topic index set.
     * @param topicIndexSet
     * @param topNum
     * @return
     */
    public Set<Word> getWordByTopics(Set<Integer> topicIndexSet, int topNum) {
        Set<Word> wordSet = new HashSet<Word>();
        for (Integer indexT : topicIndexSet) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++)
                tWordsIndexArray.add(new Integer(j));
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[indexT]));
            for (int t = 0; t < topNum; t++) {
                String content = docSet.indexToTermMap.get(tWordsIndexArray.get(t));
                Word word = new Word(content);
                if (SegmentWordsResult.getStopWordsSet().contains(content)
                        || ProcessKeyWords.remove(word)
                        || ProcessKeyWords.isMeaninglessWord(content))
                    continue;
                wordSet.add(word);
            }
        }
        return wordSet;
    }

    public Set<Word> getWordByTopic(Integer topicIndex) {
        Set<Word> wordSet = new HashSet<Word>();
        List<Integer> tWordsIndexArray = new ArrayList<Integer>();
        for (int j = 0; j < V; j++) {
            tWordsIndexArray.add(new Integer(j));
        }
        Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[topicIndex]));
        for (int t = 0; t < V; t++) {
            String content = docSet.indexToTermMap.get(tWordsIndexArray.get(t));
            Word word = new Word(content);
            word.setWeight(phi[topicIndex][tWordsIndexArray.get(t)]);
            if (SegmentWordsResult.getStopWordsSet().contains(content)
                    || ProcessKeyWords.remove(word)
                    || ProcessKeyWords.isMeaninglessWord(content))
                continue;
            if (phi[topicIndex][tWordsIndexArray.get(t)] <= 0.0)
                continue;
            wordSet.add(word);
        }
        return wordSet;
    }

    public int predictNewSampleTopic(Document doc) {
        double topicProb[] = new double[K];
        Map<String, String> wordIndexMap = matchTermIndex(doc.getWordMap().keySet());
        int predict_v = doc.getWordCount();
        int[][] predict_nkt;    // given topic k, count times of term t. K*V
        double[][] predict_phi; // Parameters for topic-word distribution K*V
        int[] predict_z;        // topic label array
        int[] predict_nk;       // for each topic, the number of times this document covers it
        predict_nkt = new int[K][predict_v];
        predict_phi = new double[K][predict_v];
        predict_z = new int[predict_v];
        predict_nk = new int[K];
        for (int index = 0; index < predict_v; index++) {
            String content = doc.getWordsList().get(index);
            String indexStr = wordIndexMap.get(content);
            if (indexStr.indexOf("-") == -1)
                continue;
            int m = Integer.valueOf(indexStr.substring(0, indexStr.indexOf("-")));
            int n = Integer.valueOf(indexStr.substring(indexStr.indexOf("-") + 1));
            // Sample from p(z_i|z_-i, w)
            int newTopic = predictSampleTopicZ(m, n);
            predict_z[index] = newTopic;
            predict_nkt[newTopic][index]++;
            predict_nk[newTopic]++;
        }
        for (int k = 0; k < K; k++) {
            topicProb[k] = (predict_nk[k] + alpha) / (predict_v + K * alpha);
        }
        return getTopic(topicProb);
    }

    public int getTopic(double[] topicProp) {
        int maxIndex = 0;
        double maxProp = topicProp[0];
        for (int k = 1; k < K; k++) {
            if (maxProp < topicProp[k]) {
                maxProp = topicProp[k];
                maxIndex = k;
            }
        }
        return maxIndex;
    }

    public int predictSampleTopicZ(int m, int n) {
        // Sample from p(z_i|z_-i, w) using the Gibbs update rule
        // Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta)
                 * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }
        // Sample a new topic label for w_{m, n} like roulette
        // Compute cumulated probability for p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; // p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }
        return newTopic;
    }

    public void saveIteratedModel(int iters, Documents docSet) throws IOException {
        // lda.params lda.phi lda.theta lda.tassign lda.twords
        // lda.params
        String resultPath = "ldaResult/";
        String modelName = "lda_" + iters;
        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(resultPath + modelName + ".params", lines);

        // lda.phi K*V
        BufferedWriter writer = new BufferedWriter(new FileWriter(resultPath + modelName + ".phi"));
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < V; j++) {
                writer.write(phi[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        // lda.theta M*K
        writer = new BufferedWriter(new FileWriter(resultPath + modelName + ".theta"));
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < K; j++) {
                writer.write(theta[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        // lda.tassign
        writer = new BufferedWriter(new FileWriter(resultPath + modelName + ".tassign"));
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < doc[m].length; n++) {
                writer.write(doc[m][n] + ":" + z[m][n] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        List<Word> appendwords = new ArrayList<Word>();
        // lda.twords phi[][] K*V
        writer = new BufferedWriter(new FileWriter(resultPath + modelName + ".twords"));
        int topNum = 10;
        // Find the top topNum topic words in each topic
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(new Integer(j));
            }
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
            writer.write("topic " + i + "\t:\t");
            for (int t = 0; t < topNum; t++) {
                writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " "
                        + phi[i][tWordsIndexArray.get(t)] + "\t");
                Word word = new Word(docSet.indexToTermMap.get(tWordsIndexArray.get(t)));
                word.setWeight(phi[i][tWordsIndexArray.get(t)]);
                appendwords.add(word);
            }
            writer.write("\n");
        }
        writer.close();

        // lda.words
        writer = new BufferedWriter(new FileWriter(resultPath + modelName + ".words"));
        for (Word word : appendwords) {
            if (word.getContent().trim().equals(""))
                continue;
            writer.write(word.getContent() + "\t" + word.getWeight() + "\n");
        }
        writer.close();
    }

    public class TwordsComparable implements Comparator<Integer> {

        public double[] sortProb; // Store probability of each word in topic k

        public TwordsComparable(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            // Sort topic word indices in descending order of their probability in topic k
            if (sortProb[o1] > sortProb[o2])
                return -1;
            else if (sortProb[o1] < sortProb[o2])
                return 1;
            else
                return 0;
        }
    }

    public static void main(String[] args) {
    }
}
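For reference, sampleTopicZ draws each topic assignment from the standard collapsed Gibbs sampling conditional for LDA, and updateEstimatedParameters computes the usual point estimates. In the notation of the counters above:

$$p(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w}) \;\propto\; \frac{n_{k,t} + \beta}{n_k + V\beta} \cdot \frac{n_{m,k} + \alpha}{n_m + K\alpha}$$

$$\varphi_{k,t} = \frac{n_{k,t} + \beta}{n_k + V\beta}, \qquad \theta_{m,k} = \frac{n_{m,k} + \alpha}{n_m + K\alpha}$$

where n_{k,t} corresponds to nkt[k][t], n_k to nktSum[k], n_{m,k} to nmk[m][k], and n_m to nmkSum[m], with the counts of the word currently being resampled removed first (the "Remove topic label" step in sampleTopicZ).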
public class Documents {

    ArrayList<Document> docs;
    Map<String, Integer> termToIndexMap;
    ArrayList<String> indexToTermMap;
    Map<String, Integer> termCountMap;
    private static NLPIRUtil npr = new NLPIRUtil();
    private static Set<String> stopWordsSet = SegmentWordsResult.getStopWordsSet();
    private Map<Word, Integer> wordDocMap;
    // key: index of the i-th document; value: its word list, kept so that the word
    // contents line up with the doc[m][n] indices used in the LDA model
    private Map<Integer, LinkedList<String>> docWordsList;

    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
        this.wordDocMap = new HashMap<Word, Integer>();
        this.docWordsList = new HashMap<Integer, LinkedList<String>>();
    }

    public Map<String, Integer> getTermCountMap() {
        return termCountMap;
    }

    public void setTermCountMap(Map<String, Integer> termCountMap) {
        this.termCountMap = termCountMap;
    }

    public Map<Word, Integer> getWordDocMap() {
        return wordDocMap;
    }

    public void setWordDocMap(Map<Word, Integer> wordDocMap) {
        this.wordDocMap = wordDocMap;
    }

    public Map<Integer, LinkedList<String>> getDocWordsList() {
        return docWordsList;
    }

    public void setDocWordsList(Map<Integer, LinkedList<String>> docWordsList) {
        this.docWordsList = docWordsList;
    }

    public void readDocs(String docsPath) {
        int index = 0;
        for (File docFile : new File(docsPath).listFiles()) {
            Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap,
                    indexToTermMap, termCountMap);
            docs.add(doc);
            // count the number of documents each word appears in
            for (Word word : doc.getWordMap().keySet()) {
                if (this.wordDocMap.containsKey(word))
                    this.wordDocMap.put(word, this.wordDocMap.get(word) + 1);
                else
                    this.wordDocMap.put(word, 1);
            }
            this.docWordsList.put(index++, doc.getWordsList());
        }
    }
}
public class Document {

    private static NLPIRUtil npr = new NLPIRUtil();
    private static Set<String> stopWordsSet = SegmentWordsResult.getStopWordsSet();
    private String docName;
    int[] docWords;
    private int wordCount;
    private Map<Word, Integer> wordMap;
    // word contents in order, so that each position corresponds to an entry of docWords
    private LinkedList<String> wordsList;

    public int getWordCount() {
        return wordCount;
    }

    public void setWordCount(int wordCount) {
        this.wordCount = wordCount;
    }

    public Map<Word, Integer> getWordMap() {
        return wordMap;
    }

    public void setWordMap(Map<Word, Integer> wordMap) {
        this.wordMap = wordMap;
    }

    public LinkedList<String> getWordsList() {
        return wordsList;
    }

    public void setWordsList(LinkedList<String> wordsList) {
        this.wordsList = wordsList;
    }

    public Document(String docContent) {
        this.wordMap = new HashMap<Word, Integer>();
        this.wordsList = new LinkedList<String>();
        String splitResult = npr.NLPIR_ParagraphProcess(ProcessMessage.dealWithSentence(docContent), 0);
        String[] wordsArray = splitResult.split(" ");
        this.docWords = new int[wordsArray.length];
        // per-document word index, so that repeated words map to the same index
        Map<String, Integer> contentIndexMap = new HashMap<String, Integer>();
        int index = 0;
        // Transfer word to index
        for (String str : wordsArray) {
            String content = ProcessMessage.dealSpecialString(str);
            Word word = new Word(content);
            if (ProcessKeyWords.remove(word) || stopWordsSet.contains(content))
                continue;
            else if (content.length() <= 1 || RegexMatch.specialMatch(content))
                continue;
            this.wordCount++;
            if (!contentIndexMap.containsKey(content)) {
                int newIndex = contentIndexMap.size();
                contentIndexMap.put(content, newIndex);
                wordMap.put(word, 1);
                docWords[index++] = newIndex;
            } else {
                wordMap.put(word, wordMap.get(word) + 1);
                docWords[index++] = contentIndexMap.get(content);
            }
            this.wordsList.add(content);
        }
    }

    public Document(String filePath, Map<String, Integer> termToIndexMap,
            ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap) {
        this.docName = filePath;
        this.wordMap = new HashMap<Word, Integer>();
        this.wordsList = new LinkedList<String>();
        // Read file and initialize word index array
        String docContent = FileUtil.readContent(docName);
        String splitResult = npr.NLPIR_ParagraphProcess(docContent, 0);
        String[] wordsArray = splitResult.split(" ");
        this.docWords = new int[wordsArray.length];
        int index = 0;
        // Transfer word to index
        for (String str : wordsArray) {
            String content = ProcessMessage.dealSpecialString(str);
            Word word = new Word(content);
            if (ProcessKeyWords.remove(word) || stopWordsSet.contains(content))
                continue;
            else if (ProcessKeyWords.isMeaninglessWord(content))
                continue;
            this.wordCount++;
            if (!termToIndexMap.containsKey(content)) {
                int newIndex = termToIndexMap.size();
                termToIndexMap.put(content, newIndex);
                indexToTermMap.add(content);
                termCountMap.put(content, new Integer(1));
                docWords[index++] = newIndex;
            } else {
                termCountMap.put(content, termCountMap.get(content) + 1);
                docWords[index++] = termToIndexMap.get(content);
            }
            this.wordsList.add(content);
            if (wordMap.containsKey(word))
                wordMap.put(word, wordMap.get(word) + 1);
            else
                wordMap.put(word, 1);
        }
    }

    public boolean isNoiseWord(String string) {
        string = string.toLowerCase().trim();
        Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");
        Matcher m = MY_PATTERN.matcher(string);
        // filter @xxx and URL
        if (string.matches(".*www\\..*") || string.matches(".*\\.com.*") || string.matches(".*http:.*"))
            return true;
        else
            return false;
    }
}
The LdaModel class above contains the prediction method predictNewSampleTopic, which returns the index of the most probable topic for a given sample; LdaGibbsSampling contains the workflow for training the LDA topic model. Part of the resulting topic-word distribution is shown below:
topic 0 : ⒐ 0.0029859442729502916 住宅 0.002257665153592825 制造 0.002257665153592825 行为 0.002257665153592825 收益 0.0015293860342353582 西北 0.0015293860342353582 红星 0.0015293860342353582 轻松 0.0015293860342353582 小商品 0.0015293860342353582 搜房网 0.0015293860342353582
topic 1 : 贵宾 0.0030435749795287848 商城 0.0023012396413832903 太平洋保险 0.0015589043032377958 建设 0.0015589043032377958 储蓄 0.0015589043032377958 周四 0.0015589043032377958 完成 0.0015589043032377958 区内 0.0015589043032377958 王志钢 0.0015589043032377958 872944 0.0015589043032377958