增加偏置项的svd推荐

本文使用带偏置项的 SVD(Bias-SVD)对评分矩阵进行矩阵分解,用于预测用户对内容的评分,从而实现推荐。如有错误之处,欢迎大家指正。

package com.rec.SVDModel;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * Rating predictor based on SVD-style matrix factorisation with bias terms.
 *
 * <p>A rating of user {@code u} for item {@code i} is predicted as
 * {@code mean + bu[u] + bi[i] + dot(pu[u], qi[i])}. The parameters are fitted
 * with stochastic gradient descent and L2 regularisation.
 *
 * <p>Expected call order: {@link #matrix(String, String)} to load the training
 * ratings, {@link #init()} to allocate and randomise the parameters,
 * {@link #train()} to fit them, and {@link #test(String, String)} to evaluate.
 *
 * <p>Input files contain one rating per line in the form
 * {@code userId<sep>itemId<sep>rating[<sep>...]}; fields after the third are ignored.
 */
public class SVDBasic {

	// Per-user rating bias, indexed by internal user index.
	private double[] bu;
	// Per-item rating bias, indexed by internal item index.
	private double[] bi;
	// User latent vectors, userNum x dimNum.
	private double[][] pu;
	// Item latent vectors, itemNum x dimNum.
	private double[][] qi;

	// ---- Hyper-parameters ----
	// Number of passes over the training ratings.
	private int iteration = 100;
	// Initial SGD learning rate; train() decays a local copy, never this field.
	private double learningrate = 0.005;
	// L2 regularisation coefficient.
	private double coff = 0.02;
	// Latent dimension k (the k of the m*k and k*n factors).
	private int dimNum = 0;

	/**
	 * Creates a model with the given hyper-parameters.
	 *
	 * @param iter number of training iterations (passes over the ratings)
	 * @param dim  number of latent dimensions; 0 yields a pure bias model
	 * @param lr   initial learning rate
	 * @param cf   L2 regularisation coefficient
	 */
	public SVDBasic(int iter,int dim,double lr,double cf){
		iteration = iter;
		learningrate = lr;
		coff = cf;
		dimNum = dim;
	}

	/**
	 * Allocates the model parameters for the users and items known so far.
	 * Biases start at zero and latent vectors are filled with
	 * {@link Math#random()} values. Call after {@link #matrix(String, String)}.
	 */
	public void init(){
		bu = new double[userNum];
		bi = new double[itemNum];
		// pu is drawn before qi, one value per cell in row-major order.
		pu = randomMatrix(userNum, dimNum);
		qi = randomMatrix(itemNum, dimNum);
	}

	// Builds a rows x cols matrix whose cells are independent Math.random() draws.
	private static double[][] randomMatrix(int rows, int cols){
		double[][] m = new double[rows][cols];
		for(int r = 0; r < rows; r++){
			for(int k = 0; k < cols; k++){
				m[r][k] = Math.random();
			}
		}
		return m;
	}

	// Mean of all ratings read by the last matrix() call.
	private double mean = 0.0;
	// Number of distinct users and the external-id -> internal-index mapping.
	private int userNum = 0;
	private Map<String, Integer> userMap = new HashMap<String, Integer>();
	// Number of distinct items and the external-id -> internal-index mapping.
	private int itemNum = 0;
	private Map<String, Integer> itemMap = new HashMap<String, Integer>();

	// Dense userNum x itemNum rating matrix.
	private double[][] rateMatrix;
	// rated[u][c] is true when rateMatrix[u][c] holds an observed rating, so a
	// genuine rating of 0 is not mistaken for "no rating".
	private boolean[][] rated;

	/**
	 * Loads the training ratings from {@code file} into a dense rating matrix and
	 * computes the global mean rating. The file is read twice: once to index the
	 * users and items and accumulate the mean, once to fill the matrix.
	 * Blank lines are skipped. If a (user, item) pair occurs more than once the
	 * last rating wins in the matrix.
	 *
	 * @param file      path of the training file
	 * @param separator field separator; passed to {@link String#split(String)},
	 *                  so it is interpreted as a regular expression
	 * @throws IOException if the file cannot be read
	 */
	public void matrix(String file,String separator) throws IOException{
		double sum = 0.0;
		int lineNum = 0;
		// Pass 1: assign indices to unseen users/items and accumulate the mean.
		try(BufferedReader br = new BufferedReader(new FileReader(new File(file)))){
			String line;
			while((line = br.readLine()) != null){
				if(line.trim().isEmpty()){
					continue;
				}
				String[] arr = line.split(separator);
				String uid = arr[0];
				String cid = arr[1];
				double rate = Double.parseDouble(arr[2]);
				if(!userMap.containsKey(uid)){
					userMap.put(uid, userNum);
					userNum++;
				}
				if(!itemMap.containsKey(cid)){
					itemMap.put(cid, itemNum);
					itemNum++;
				}
				sum += rate;
				lineNum++;
			}
		}
		// An empty file would otherwise give 0/0 = NaN and poison every prediction.
		mean = lineNum == 0 ? 0.0 : sum / lineNum;

		// Pass 2: fill the matrix now that its dimensions are known.
		rateMatrix = new double[userNum][itemNum];
		rated = new boolean[userNum][itemNum];
		try(BufferedReader br = new BufferedReader(new FileReader(new File(file)))){
			String line;
			while((line = br.readLine()) != null){
				if(line.trim().isEmpty()){
					continue;
				}
				String[] arr = line.split(separator);
				int row = userMap.get(arr[0]);
				int column = itemMap.get(arr[1]);
				rateMatrix[row][column] = Double.parseDouble(arr[2]);
				rated[row][column] = true;
			}
		}
	}

	// Dot product of the first dimNum components of p and q.
	private double calVector(double[] p,double[] q){
		double value = 0.0;
		for(int i = 0; i < dimNum; i++){
			value += p[i] * q[i];
		}
		return value;
	}

	// Predicted rating for internal user index u and internal item index c.
	private double predict(int u, int c){
		return mean + bu[u] + bi[c] + calVector(pu[u], qi[c]);
	}

	// Fails fast with a clear message when the parameters do not match the loaded data.
	private void requireInitialised(){
		if(bu == null || bi == null || pu == null || qi == null
				|| bu.length != userNum || bi.length != itemNum){
			throw new IllegalStateException(
					"model is not initialised: call matrix() and then init() first");
		}
	}

	/**
	 * Fits the parameters with SGD over every observed rating, visiting ratings
	 * in (user index, item index) order. The learning rate starts at the
	 * configured value and is multiplied by 0.9 after each iteration. The
	 * training RMSE of each iteration is printed to standard output.
	 *
	 * @throws IllegalStateException if {@link #matrix(String, String)} and
	 *         {@link #init()} have not been called
	 */
	public void train(){
		requireInitialised();
		// Decay a local copy so the configured rate survives repeated train() calls.
		double lr = learningrate;
		for(int i = 0; i < iteration; i++){
			// Reset per iteration; carrying it over would distort the reported RMSE.
			double squaredError = 0.0;
			int rateCalTimes = 0;
			for(int u = 0; u < userNum; u++){
				for(int c = 0; c < itemNum; c++){
					if(!rated[u][c]){
						continue;
					}
					double error = rateMatrix[u][c] - predict(u, c);
					bu[u] += lr * (error - coff * bu[u]);
					bi[c] += lr * (error - coff * bi[c]);
					for(int d = 0; d < dimNum; d++){
						// Both factors are updated from their values before this step.
						double puOld = pu[u][d];
						double qiOld = qi[c][d];
						pu[u][d] += lr * (error * qiOld - coff * puOld);
						qi[c][d] += lr * (error * puOld - coff * qiOld);
					}
					squaredError += error * error;
					rateCalTimes++;
				}
			}
			// With no ratings this is 0/0 and prints NaN.
			double rmse = Math.sqrt(squaredError / rateCalTimes);
			lr *= 0.9;
			System.out.println(String.format("iterator is %s,rmse is %s",i,rmse));
		}
	}

	/**
	 * Evaluates the model on the ratings in {@code file} and prints the RMSE, the
	 * number of predictions within 0.5 of the true rating, and that number as a
	 * fraction of the evaluated ratings. Blank lines and ratings whose user or
	 * item was not seen by {@link #matrix(String, String)} are skipped. If no
	 * rating is evaluated, the printed RMSE and ratio are NaN.
	 *
	 * @param file      path of the test file
	 * @param separator field separator; passed to {@link String#split(String)},
	 *                  so it is interpreted as a regular expression
	 * @throws IOException if the file cannot be read
	 * @throws IllegalStateException if the model has not been initialised
	 */
	public void test(String file,String separator) throws IOException{
		requireInitialised();
		double squaredError = 0.0;
		int count = 0;
		int preciseNum = 0;
		try(BufferedReader br = new BufferedReader(new FileReader(new File(file)))){
			String line;
			while((line = br.readLine()) != null){
				if(line.trim().isEmpty()){
					continue;
				}
				String[] arr = line.split(separator);
				Integer u = userMap.get(arr[0]);
				Integer c = itemMap.get(arr[1]);
				if(u == null || c == null){
					// No learned parameters for this user or item.
					continue;
				}
				double error = Double.parseDouble(arr[2]) - predict(u, c);
				squaredError += error * error;
				count++;
				if(error <= 0.5 && error >= -0.5){
					preciseNum++;
				}
			}
		}
		double rmse = Math.sqrt(squaredError / count);
		System.out.println("rmse is " + rmse);
		double ratio = preciseNum * 1.0/count;
		System.out.println(String.format("precise number is %s; ratio is %s" ,preciseNum,ratio));
	}

	/**
	 * Splits a ratings file into two files. For each run of consecutive lines that
	 * share the same first field (user id), the first 10 lines go to {@code out2}
	 * and the remaining lines go to {@code out1}. Lines are written with a
	 * {@code "\r\n"} terminator. Written with the MovieLens data set in mind.
	 *
	 * @param in        path of the input ratings file
	 * @param out1      output path for the lines beyond the first 10 of each run
	 * @param out2      output path for the first 10 lines of each run
	 * @param separator field separator; passed to {@link String#split(String)},
	 *                  so it is interpreted as a regular expression
	 * @throws IOException if reading or writing fails
	 */
	public static void splitText(String in,String out1,String out2,String separator) throws IOException{
		try(BufferedReader br = new BufferedReader(new FileReader(new File(in)));
				BufferedWriter bw1 = new BufferedWriter(new FileWriter(new File(out1)));
				BufferedWriter bw2 = new BufferedWriter(new FileWriter(new File(out2)))){
			String uid = "";
			int count = 0;
			String line;
			while((line = br.readLine()) != null){
				String currentUid = line.split(separator)[0];
				if(!uid.equals(currentUid)){
					// A new run of lines for a different user starts here.
					uid = currentUid;
					count = 0;
				}
				if(count < 10){
					bw2.write(line + "\r\n");
					count++;
				}
				else{
					bw1.write(line + "\r\n");
				}
			}
		}
	}

}

参考资料:

https://www.cnblogs.com/Xnice/p/4522671.html

https://blog.csdn.net/zhongkejingwang/article/details/43083603

《推荐系统:技术、评估及高效算法》

你可能感兴趣的:(svd,推荐系统,recommender,svd)