java版LSA(潜在语义分析)

前言

LSA(潜在语义分析)的详细概念我就不做介绍了。此处简单提一下,首先就是要构建单词-文档矩阵,然后关键的是进行SVD奇异值分解,这是核心的一步,最终通过计算向量夹角余弦值就可得到文档与文档之间的相似度。

这里有一位博主我觉得写得非常好,大家可以参考一下他的文章,他也附上了python的实现代码:
http://zhikaizhang.cn/2016/05/31/自然语言处理之LSA/

代码

因为项目需要,所以实现了一个java版本,加深记忆。下面是java版LSA代码。需要一个外部jar包,maven依赖如下:

<dependency>
	<groupId>gov.nist.mathgroupId>
    <artifactId>jamaartifactId>
    <version>1.0.3version>
dependency>
import Jama.Matrix;
import Jama.SingularValueDecomposition;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * 2019年3月29日
 * 
 * TODO Java版LSA
 *
 * @author jiebaHZ
 */
public class LSA {

    private List<String> stopwords;
    private List<String> docs;//注意这里输入的文档是分词之后的
    private Matrix matrix;
    private Map<String, List<Integer>> dictionary = new HashMap<String, List<Integer>>();
    private List<String> keywords = new ArrayList<String>();
    // 维数
    private static int LSD = 2;

    public LSA(List<String> docs) {
        this.docs = docs;
    }

    public void lsa() {
        // 读取停用词
        stopwords = readStopwords();
        // 过滤停用词
        removeStopwords();
        // 生成单词字典
        createDictionary();
        // 得到关键词
        addKeywords();
        // 生成单词-文档矩阵
        createMatrix();
        // SVD分解,降维
        Matrix v = SVD();
        // 计算相似度
        //到这里就得到了矩阵v的转置。然后根据自己的需要就可以计算每两个向量的夹角余弦
        //……
    }

    /**
     * 读取停用词
     */
    public static List<String> readStopwords() {

        ArrayList<String> stopwords = new ArrayList<String>();
        // 读取停用词表
        try {
            BufferedReader br = new BufferedReader(new InputStreamReader(
                    new FileInputStream("./Mystopwords.txt"), "UTF-8"));
            String line = null;
            while ((line = br.readLine()) != null) {
                stopwords.add(line);
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return stopwords;
    }

    /**
     * 过滤停用词
     */
    public void removeStopwords() {
        for (int i = 0; i < docs.size(); i++) {
            String[] doc = docs.get(i).split(" ");
            List<String> words = new ArrayList<String>();
            for (String string : doc) {
                words.add(string);
            }
            words.removeAll(stopwords);
            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < words.size(); j++) {
                sb.append(words.get(j));
                sb.append(" ");
                sb.toString();
            }
            docs.set(i, sb.toString().trim());
        }
    }

    /**
     * 记录每个单词出现在哪些文档中
     */
    public void createDictionary() {
        for (int i = 0; i < docs.size(); i++) {
            String[] words = docs.get(i).split(" ");
            for (String word : words) {
                if (dictionary.containsKey(word)) {
                    dictionary.get(word).add(i);
                } else {
                    List<Integer> idList = new ArrayList<Integer>();
                    idList.add(i);
                    dictionary.put(word, idList);
                }
            }
        }
    }

    /**
     * 得到关键词列表
     */
    public void addKeywords() {
        for (String word : dictionary.keySet()) {
            if (dictionary.get(word).size() >= 1) {
                keywords.add(word);
            }
        }
    }

    /**
     * 生成单词-文档矩阵
     */
    public void createMatrix() {
        double array[][] = new double[keywords.size()][docs.size()];
        matrix = new Matrix(array);
        for (int i = 0; i < keywords.size(); i++) {
            for (Integer j : dictionary.get(keywords.get(i))) {
                matrix.set(i, j, matrix.get(i, j) + 1);
            }
        }
    }

    /**
     * 打印矩阵
     *
     * @param matrix
     */
    public void printMatrix(Matrix matrix) {
        for (int i = 0; i < matrix.getRowDimension(); i++) {
            for (int j = 0; j < matrix.getColumnDimension(); j++) {
                System.out.printf("m(%d,%d) = %g\t", i, j, matrix.get(i, j));
            }
            System.out.printf("\n");
        }
    }

    /**
     * SVD分解,降维
     */
    public Matrix SVD() {
        SingularValueDecomposition svd = matrix.svd();
        // 注意,这里是v的转置
        Matrix v = svd.getV().transpose();
        for (int i = LSD; i < v.getRowDimension(); i++) {
            for (int j = 0; j < docs.size(); j++) {
                v.set(i, j, 0.0);
            }
        }
        return v;
    }

    /**
     * 计算夹角余弦值
     *
     * @param v1
     * @param v2
     */
    double cos(List<Double> v1, List<Double> v2, int dim) {
        // Cos(theta) = A(dot)B / |A||B|
        double a_dot_b = 0;
        for (int i = 0; i < dim; i++) {
            a_dot_b += v1.get(i) * v2.get(i);
        }
        double A = 0;
        for (int j = 0; j < dim; j++) {
            A += v1.get(j) * v1.get(j);
        }
        A = Math.sqrt(A);
        double B = 0;
        for (int k = 0; k < dim; k++) {
            B += v2.get(k) * v2.get(k);
        }
        B = Math.sqrt(B);
        return a_dot_b / (A * B);
    }
}

你可能感兴趣的:(NLP)