NMF把一个矩阵分解为两个矩阵的乘积,可以用来解决很多问题,例如:用户聚类、item聚类、预测(补全)用户对item的评分、个性化推荐等问题。NMF的过程可以转化为最小化损失函数(即误差函数)的过程,其实整个问题也就是一个最优化的问题。详细实现过程如下:(其中,输入矩阵很多时候会比较稀疏,即很多元素都是缺失项,故数据存储采用的是libsvm的格式,这个类在此忽略)
- package NMF_danji;
- import java.io.File;
- import java.util.ArrayList;
- /**
- * @author 玉心sober: http://weibo.com/karensober
- * @date 2013-05-19
- *
- * */
- public class NMF {
- private Dataset dataset = null;
- private int M = -1; // 行数
- private int V = -1; // 列数
- private int K = -1; // 隐含主题数
- double[][] P;
- double[][] Q;
- public NMF(String datafileName, int topics) {
- File datafile = new File(datafileName);
- if (datafile.exists()) {
- if ((this.dataset = new Dataset(datafile)) == null) {
- System.out.println(datafileName + " is null");
- }
- this.M = this.dataset.size();
- this.V = this.dataset.getFeatureNum();
- this.K = topics;
- } else {
- System.out.println(datafileName + " doesn't exist");
- }
- }
- public void initPQ() {
- P = new double[this.M][this.K];
- Q = new double[this.K][this.V];
- for (int k = 0; k < K; k++) {
- for (int i = 0; i < M; i++) {
- P[i][k] = Math.random();
- }
- for (int j = 0; j < V; j++) {
- Q[k][j] = Math.random();
- }
- }
- }
- // 随机梯度下降,更新参数
- public void updatePQ(double alpha, double beta) {
- for (int i = 0; i < M; i++) {
- ArrayList
Ri = this.dataset.getDataAt(i).getAllFeature(); - for (Feature Rij : Ri) {
- // eij=Rij.weight-PQ for updating P and Q
- double PQ = 0;
- for (int k = 0; k < K; k++) {
- PQ += P[i][k] * Q[k][Rij.dim];
- }
- double eij = Rij.weight - PQ;
- // update Pik and Qkj
- for (int k = 0; k < K; k++) {
- double oldPik = P[i][k];
- P[i][k] += alpha
- * (2 * eij * Q[k][Rij.dim] - beta * P[i][k]);
- Q[k][Rij.dim] += alpha
- * (2 * eij * oldPik - beta * Q[k][Rij.dim]);
- }
- }
- }
- }
- // 每步迭代后计算SSE
- public double getSSE(double beta) {
- double sse = 0;
- for (int i = 0; i < M; i++) {
- ArrayList
Ri = this.dataset.getDataAt(i).getAllFeature(); - for (Feature Rij : Ri) {
- double PQ = 0;
- for (int k = 0; k < K; k++) {
- PQ += P[i][k] * Q[k][Rij.dim];
- }
- sse += Math.pow((Rij.weight - PQ), 2);
- }
- }
- for (int i = 0; i < M; i++) {
- for (int k = 0; k < K; k++) {
- sse += ((beta / 2) * (Math.pow(P[i][k], 2)));
- }
- }
- for (int i = 0; i < V; i++) {
- for (int k = 0; k < K; k++) {
- sse += ((beta / 2) * (Math.pow(Q[k][i], 2)));
- }
- }
- return sse;
- }
- // 采用随机梯度下降方法迭代求解参数,即求解最终分解后的矩阵
- public boolean doNMF(int iters, double alpha, double beta) {
- for (int step = 0; step < iters; step++) {
- updatePQ(alpha, beta);
- double sse = getSSE(beta);
- if (step % 100 == 0)
- System.out.println("step " + step + " SSE = " + sse);
- }
- return true;
- }
- public void printMatrix() {
- System.out.println("===========原始矩阵==============");
- for (int i = 0; i < this.dataset.size(); i++) {
- for (Feature feature : this.dataset.getDataAt(i).getAllFeature()) {
- System.out.print(feature.dim + ":" + feature.weight + " ");
- }
- System.out.println();
- }
- }
- public void printFacMatrxi() {
- System.out.println("===========分解矩阵==============");
- for (int i = 0; i < P.length; i++) {
- for (int j = 0; j < Q[0].length; j++) {
- double cell = 0;
- for (int k = 0; k < K; k++) {
- cell += P[i][k] * Q[k][j];
- }
- System.out.print(baoliu(cell, 3) + " ");
- }
- System.out.println();
- }
- }
- // 为double类型变量保留有效数字
- public static double baoliu(double d, int n) {
- double p = Math.pow(10, n);
- return Math.round(d * p) / p;
- }
- public static void main(String[] args) {
- double alpha = 0.002;
- double beta = 0.02;
- NMF nmf = new NMF("D:\\myEclipse\\graphModel\\data\\nmfinput.txt", 10);
- nmf.initPQ();
- nmf.doNMF(3000, alpha, beta);
- // 输出原始矩阵
- nmf.printMatrix();
- // 输出分解后矩阵
- nmf.printFacMatrxi();
- }
- }
结果:
...
step 2900 SSE = 0.5878774074369989
===========原始矩阵==============
0:9.0 1:2.0 2:1.0 3:1.0 4:1.0
0:8.0 1:3.0 2:2.0 3:1.0
0:3.0 3:1.0 4:2.0 5:8.0
1:1.0 3:2.0 4:4.0 5:7.0
0:2.0 1:1.0 2:1.0 4:1.0 5:3.0
===========分解矩阵==============
8.959 2.007 1.007 0.996 1.007 6.293
7.981 2.972 1.989 1.005 2.046 7.076
3.01 1.601 1.773 1.003 2.005 7.968
4.821 1.009 2.209 1.984 3.968 6.988
2.0 0.991 0.984 0.51 1.0 2.994