lightLDA输出接口-java版本

根据LightLDA的输出文件,得到文档-主题分布、主题-词分布,以及表示某篇文档的topN关键词。

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Turns LightLDA output files into a document-topic and a topic-word
 * probability matrix, and ranks the top-N words of every document.
 *
 * <p>Usage: construct with the model sizes, call
 * {@link #loadTopicWordModel(String, String)} and
 * {@link #loadDocTopicModel(String)}, then call
 * {@link #getDocTopWordInfo(int)} or
 * {@link #dumpTopResult(int, String, List, List)}.
 *
 * Created by yangxin on 2017/8/11.
 */
public class LDAResult {
    /** Separator between the fields of one model-file line: runs of tabs and/or spaces. */
    private static final String FIELD_SEPARATOR = "[\t ]+";

    private double alpha;  // Dirichlet parameter of the topic distribution
    private double beta;   // Dirichlet parameter of the word distribution
    private int topic_num;  // number of topics
    private int vocab_num;  // number of words in the vocabulary
    private int doc_num;    // number of documents
    private double[][] doc_topic_mat = null;   // document/topic matrix, indexed [topic][doc]
    private double[][] topic_vocab_mat = null; // topic/word matrix, indexed [word][topic]
    private Item[][] doc_word_info = null;     // cached top words per document, indexed [doc][rank]
    private int cached_top_n = -1;             // the n that doc_word_info was computed for, -1 if none

    /**
     * One ranked word of a document: the word id together with its score.
     *
     * <p>The natural ordering compares {@code prob} only, so it is not
     * consistent with {@code equals} (which is not overridden).
     */
    public static class Item implements Comparable<Item> {
        public double prob;
        public int word_id;

        public Item(double prob, int word_id) {
            this.prob = prob;
            this.word_id = word_id;
        }

        @Override
        public String toString() {
            return "Item{" +
                    "prob=" + prob +
                    ", word_id=" + word_id +
                    '}';
        }

        /**
         * Orders items by ascending {@code prob}; equal probabilities compare as 0.
         */
        @Override
        public int compareTo(Item o) {
            return Double.compare(prob, o.prob);
        }
    }

    /**
     * Allocates zero-filled matrices for the given model sizes.
     *
     * @param alpha     Dirichlet parameter used to smooth document-topic counts
     * @param beta      Dirichlet parameter used to smooth topic-word counts
     * @param topic_num number of topics
     * @param vocab_num number of words
     * @param doc_num   number of documents
     */
    public LDAResult(double alpha, double beta, int topic_num, int vocab_num, int doc_num) {
        this.alpha = alpha;
        this.beta = beta;
        this.topic_num = topic_num;
        this.vocab_num = vocab_num;
        this.doc_num = doc_num;

        doc_topic_mat = new double[topic_num][doc_num];
        topic_vocab_mat = new double[vocab_num][topic_num];
    }

    /**
     * Computes the n highest-scoring words of every document, where the score
     * of a word is the sum over topics of
     * {@code doc_topic_mat[topic][doc] * topic_vocab_mat[word][topic]}.
     *
     * <p>The result is also cached for a later
     * {@link #dumpTopResult(int, String, List, List)} call with the same n.
     *
     * @param n number of words to keep per document; must not be negative
     * @return one row per document, sorted by descending score; each row has
     *         {@code min(n, vocab_num)} entries, so it never contains null
     * @throws IllegalArgumentException if n is negative
     */
    public Item[][] getDocTopWordInfo(int n){
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative: " + n);
        }
        // A document cannot have more distinct top words than the vocabulary holds.
        final int keep = Math.min(n, vocab_num);
        Item[][] result = new Item[doc_num][keep];

        if (keep > 0) {
            for (int doc = 0; doc < doc_num; ++doc) {
                // Min-heap holding the best `keep` items seen so far.
                PriorityQueue<Item> heap = new PriorityQueue<>(keep + 1);
                for (int word = 0; word < vocab_num; ++word) {
                    double prob = 0;
                    for (int topic = 0; topic < topic_num; ++topic) {
                        prob += doc_topic_mat[topic][doc] * topic_vocab_mat[word][topic];
                    }
                    heap.offer(new Item(prob, word));
                    if (heap.size() > keep) {
                        heap.poll();  // drop the current weakest item
                    }
                }
                // The heap yields ascending scores, so fill the row back to front.
                int slot = heap.size();
                while (!heap.isEmpty()) {
                    result[doc][--slot] = heap.poll();
                }
            }
        }

        doc_word_info = result;
        cached_top_n = n;
        return result;
    }

    /**
     * Writes the top-n words of every document to a file, one document per
     * line in the form {@code title : word/prob word/prob ... }.
     *
     * <p>Does nothing (and creates no file) when n is not positive. The
     * ranking is recomputed unless a ranking for the same n is already cached.
     *
     * @param n      number of words per document
     * @param output path of the output file
     * @param titles document titles, indexed by document id
     * @param words  vocabulary, indexed by word id
     * @throws Exception if the output file cannot be written
     */
    public void dumpTopResult(int n, String output, final List<?> titles, final List<?> words) throws Exception{
        if (n <= 0) {
            return;
        }
        // Only reuse the cache when it was built for this very n.
        if (doc_word_info == null || cached_top_n != n) {
            getDocTopWordInfo(n);
        }

        try (BufferedWriter bw = new BufferedWriter(new FileWriter(output))) {
            for (int doc = 0; doc < doc_num; ++doc) {
                bw.write(titles.get(doc) + " : ");
                for (Item item : doc_word_info[doc]) {
                    bw.write(words.get(item.word_id) + "/" + item.prob + " ");
                }
                bw.newLine();
            }
        }
    }

    /**
     * Loads the document-topic model and converts the counts to probabilities.
     *
     * <p>Each non-blank line is {@code docId topicId:count topicId:count ...},
     * fields separated by tabs and/or spaces. Topics missing from a line count
     * as 0. Every cell becomes
     * {@code (count + alpha) / (docTotal + topic_num * alpha)}.
     * Previously loaded values and the cached ranking are discarded.
     *
     * @param model_path path of the document-topic file
     * @throws Exception if the file cannot be read or a field is not a number
     */
    public void loadDocTopicModel(String model_path) throws Exception{
        invalidateTopWordCache();
        zero(doc_topic_mat);

        // Store the raw counts in the matrix.
        try (BufferedReader br = new BufferedReader(new FileReader(model_path))) {
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] doc_info = line.split(FIELD_SEPARATOR);
                int doc_id = Integer.parseInt(doc_info[0]);
                for (int i = 1; i < doc_info.length; ++i) {
                    String[] topic_info = doc_info[i].split(":");
                    int topic_id = Integer.parseInt(topic_info[0]);
                    int topic_cnt = Integer.parseInt(topic_info[1]);
                    doc_topic_mat[topic_id][doc_id] = topic_cnt;
                }
            }
        }

        // Normalise each document's counts into smoothed probabilities.
        double factor = topic_num * alpha;
        for (int doc = 0; doc < doc_num; ++doc) {
            double doc_total = 0;  // sum of the topic counts of this document
            for (int topic = 0; topic < topic_num; ++topic) {
                doc_total += doc_topic_mat[topic][doc];
            }
            for (int topic = 0; topic < topic_num; ++topic) {
                doc_topic_mat[topic][doc] = (doc_topic_mat[topic][doc] + alpha) / (doc_total + factor);
            }
        }
    }

    /**
     * Loads the topic-word model and converts the counts to probabilities.
     *
     * <p>Each non-blank line of the model file is
     * {@code wordId topicId:count topicId:count ...}. The first line of the
     * summary file is {@code rowId topicId:total topicId:total ...}; its first
     * field is ignored. Every cell becomes
     * {@code (count + beta) / (topicTotal + vocab_num * beta)}.
     * Previously loaded values and the cached ranking are discarded.
     *
     * @param model_path         topic-word model file (server_model_0)
     * @param model_summary_path per-topic totals file (server_model_1)
     * @throws IOException if the summary file is empty
     * @throws Exception   if a file cannot be read or a field is not a number
     */
    public void loadTopicWordModel(String model_path, String model_summary_path) throws Exception{
        invalidateTopWordCache();
        zero(topic_vocab_mat);

        // Store the raw counts in the matrix.
        try (BufferedReader br = new BufferedReader(new FileReader(model_path))) {
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] info = line.split(FIELD_SEPARATOR);
                int word_id = Integer.parseInt(info[0]);
                for (int i = 1; i < info.length; ++i) {
                    String[] topic_info = info[i].split(":");
                    int topic_id = Integer.parseInt(topic_info[0]);
                    int topic_cnt = Integer.parseInt(topic_info[1]);
                    topic_vocab_mat[word_id][topic_id] = topic_cnt;
                }
            }
        }

        // Read how often each topic occurs overall.
        int[] topic_cnts = new int[topic_num];
        try (BufferedReader br = new BufferedReader(new FileReader(model_summary_path))) {
            String summary = br.readLine();
            if (summary == null) {
                throw new IOException("Topic summary file is empty: " + model_summary_path);
            }
            String[] cnts = summary.trim().split(FIELD_SEPARATOR);
            for (int i = 1; i < cnts.length; ++i) {
                String[] cnt_info = cnts[i].split(":");
                int topic_id = Integer.parseInt(cnt_info[0]);
                int topic_cnt = Integer.parseInt(cnt_info[1]);
                topic_cnts[topic_id] = topic_cnt;
            }
        }

        // Normalise the counts into smoothed probabilities.
        double factor = vocab_num * beta;
        for (int word = 0; word < vocab_num; ++word) {
            for (int topic = 0; topic < topic_num; ++topic) {
                topic_vocab_mat[word][topic] = (topic_vocab_mat[word][topic] + beta) / (topic_cnts[topic] + factor);
            }
        }
    }

    /** Drops the cached ranking; it is stale once either matrix changes. */
    private void invalidateTopWordCache() {
        doc_word_info = null;
        cached_top_n = -1;
    }

    /** Resets every cell to 0 so a reload never mixes old probabilities with new counts. */
    private static void zero(double[][] mat) {
        for (double[] row : mat) {
            for (int i = 0; i < row.length; ++i) {
                row[i] = 0;
            }
        }
    }
}

调用

/**
 * Example driver: loads both LightLDA models, reads the document titles and
 * the vocabulary, and writes the top 10 keywords of every document to the
 * file "result".
 */
public static void main(String[] args) throws Exception{
    // Arguments in order: alpha, beta, topic count, vocabulary size, document count.
    LDAResult lda = new LDAResult(0.22, 0.1, 220, 1539967, 146119);

    // Topic-word probability distribution.
    String topicWordModel = "server_0_table_0.model";
    String topicSummaryModel = "server_0_table_1.model";
    lda.loadTopicWordModel(topicWordModel, topicSummaryModel);

    // Document-topic probability distribution.
    String docTopicModel = "doc_topic.0";
    lda.loadDocTopicModel(docTopicModel);

    // Names of all documents and all vocabulary words.
    String docSource = "merge_texts";
    String vocabSource = "vocab";
    List docTitles = Util.getTitles(docSource);
    List vocabulary = Util.getVocabs(vocabSource);

    // Write the top 10 keywords of each document to the output file.
    String outputPath = "result";
    lda.dumpTopResult(10, outputPath, docTitles, vocabulary);
}

你可能感兴趣的:(lightLDA输出接口-java版本)