增加隐式反馈的 SVD 推荐（SVD++）

基于 SVD++ 的 Java 实现：对用户-物品评分矩阵进行分解，并使用随机梯度下降（SGD）学习模型参数（全局均值、偏置项 bu/bi、隐向量 pu/qi 以及隐式反馈向量 y）。

参考:

https://www.cnblogs.com/Xnice/p/4522671.html

https://blog.csdn.net/zhongkejingwang/article/details/43083603

《推荐系统：技术、评估及高效算法》

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * SVD++ recommender: factorizes a user-item rating matrix with implicit feedback and
 * learns its parameters with stochastic gradient descent (SGD).
 *
 * <p>Prediction for user u and item i:
 * {@code mean + bu[u] + bi[i] + (pu[u] + |N(u)|^-0.5 * sum_{j in N(u)} y[j]) . qi[i]},
 * where N(u) is the list of items that u rated in the training file.
 *
 * <p>Call order: {@link #matrix(String, String)} (loads data and fixes the user/item
 * counts), then {@link #init()}, then {@link #train()}; afterwards
 * {@link #test(String, String)} or {@link #calValue(String, String)}.
 */
public class SVDPlus {

	// ---- model parameters ----
	/** Global mean of all training ratings. */
	private double mean;
	/** User bias terms, indexed by user row. */
	private double[] bu;
	/** Item bias terms, indexed by item column. */
	private double[] bi;
	/** Explicit user factor vectors ({@code userNum x dim}). */
	private double[][] pu;
	/** Item factor vectors ({@code itemNum x dim}). */
	private double[][] qi;
	/** Implicit-feedback item factor vectors ({@code itemNum x dim}). */
	private double[][] y;

	// ---- id <-> index mappings, filled by matrix() ----
	private int userNum;
	/** User id -> row index. */
	private Map<String, Integer> userMap = new HashMap<>();
	/** Row index -> user id. */
	private Map<Integer, String> userIndexMap = new HashMap<>();
	private int itemNum;
	/** Item id -> column index. */
	private Map<String, Integer> itemMap = new HashMap<>();
	/** Column index -> item id. */
	private Map<Integer, String> itemIndexMap = new HashMap<>();
	/** User id -> ids of the items that user rated (the implicit-feedback set N(u)). */
	private Map<String, List<String>> user_items = new HashMap<>();
	/** Dense rating matrix; a stored value of 0.0 is treated as "not rated" by train(). */
	private double[][] rateMatrix;

	// ---- hyper-parameters ----
	/** SGD step size; multiplied by 0.9 after every training iteration. */
	private double learningrate;
	/** L2 regularization weight for the bias terms bu and bi. */
	private double lamda1;
	/** L2 regularization weight for the factor vectors pu, qi and y. */
	private double lamda2;
	/** Number of training iterations (epochs). */
	private int iter;
	/** Latent factor dimension. */
	private int dim;

	/**
	 * Creates an untrained model.
	 *
	 * @param lr    initial SGD learning rate
	 * @param coff1 regularization weight for the bias terms
	 * @param coff2 regularization weight for the factor vectors
	 * @param it    number of training iterations
	 * @param d     latent factor dimension
	 */
	public SVDPlus(double lr, double coff1, double coff2, int it, int d) {
		learningrate = lr;
		lamda1 = coff1;
		lamda2 = coff2;
		iter = it;
		dim = d;
	}

	/**
	 * Allocates the model parameters. Must be called after {@link #matrix(String, String)},
	 * which determines the number of users and items. Biases start at 0; every factor
	 * entry (pu, qi, y) is drawn from {@link Math#random()}, i.e. uniformly in [0, 1).
	 */
	public void init() {
		bu = new double[userNum];
		bi = new double[itemNum];
		pu = randomMatrix(userNum);
		qi = randomMatrix(itemNum);
		y = randomMatrix(itemNum);
	}

	/** Returns a {@code rows x dim} matrix filled with {@link Math#random()} values. */
	private double[][] randomMatrix(int rows) {
		double[][] m = new double[rows][dim];
		for (int r = 0; r < rows; r++) {
			for (int d = 0; d < dim; d++) {
				m[r][d] = Math.random();
			}
		}
		return m;
	}

	/**
	 * Loads a training file and builds the user/item index maps, the implicit-feedback
	 * lists, the global mean and the dense rating matrix.
	 *
	 * <p>Each non-blank line must contain {@code userId, itemId, rating} as its first three
	 * fields, split by {@code separator}. Blank lines are skipped.
	 *
	 * @param file      path of the training file
	 * @param separator field separator; interpreted as a regular expression by
	 *                  {@link String#split(String)}
	 * @throws IOException if the file cannot be read
	 */
	public void matrix(String file, String separator) throws IOException {
		List<String[]> records = readRecords(file, separator);
		double sum = 0.0;
		for (String[] arr : records) {
			String uid = arr[0];
			String cid = arr[1];
			double rate = Double.parseDouble(arr[2]);
			if (!userMap.containsKey(uid)) {
				userMap.put(uid, userNum);
				userIndexMap.put(userNum, uid);
				userNum++;
			}
			if (!itemMap.containsKey(cid)) {
				itemMap.put(cid, itemNum);
				itemIndexMap.put(itemNum, cid);
				itemNum++;
			}
			sum += rate;

			// record the item in the user's implicit-feedback list
			List<String> items = user_items.get(uid);
			if (items == null) {
				items = new ArrayList<>();
				user_items.put(uid, items);
			}
			items.add(cid);
		}
		// guard against an empty file instead of producing a NaN mean
		mean = records.isEmpty() ? 0.0 : sum / records.size();

		// the matrix can only be sized once all users and items are known
		rateMatrix = new double[userNum][itemNum];
		for (String[] arr : records) {
			int row = userMap.get(arr[0]);
			int column = itemMap.get(arr[1]);
			rateMatrix[row][column] = Double.parseDouble(arr[2]);
		}
	}

	/** Reads every non-blank line of {@code file}, split by {@code separator}; closes the file. */
	private List<String[]> readRecords(String file, String separator) throws IOException {
		List<String[]> records = new ArrayList<>();
		try (BufferedReader br = new BufferedReader(new FileReader(new File(file)))) {
			String line;
			while ((line = br.readLine()) != null) {
				if (line.trim().isEmpty()) {
					continue;
				}
				records.add(line.split(separator));
			}
		}
		return records;
	}

	/**
	 * Returns the element-wise sum of the first {@code dim} entries of two vectors, used
	 * for {@code pu + |N(u)|^-0.5 * sum(y_j)}.
	 *
	 * @param pu  first vector (length at least {@code dim})
	 * @param sum second vector (length at least {@code dim})
	 * @return a new array of length {@code dim}
	 */
	public double[] addVector(double[] pu, double[] sum) {
		double[] plus = new double[dim];
		for (int d = 0; d < dim; d++) {
			plus[d] = pu[d] + sum[d];
		}
		return plus;
	}

	/**
	 * Returns the dot product of the first {@code dim} entries of two vectors.
	 *
	 * @param plus first vector (length at least {@code dim})
	 * @param qi   second vector (length at least {@code dim})
	 * @return the dot product
	 */
	public double multVector(double[] plus, double[] qi) {
		double val = 0.0;
		for (int d = 0; d < dim; d++) {
			val += plus[d] * qi[d];
		}
		return val;
	}

	/**
	 * Computes {@code ru * sum of y[j]} over the given item ids.
	 *
	 * @param items item ids of one user's implicit-feedback set
	 * @param ru    normalization factor {@code |N(u)|^-0.5}
	 */
	private double[] implicitSum(List<String> items, double ru) {
		double[] sum = new double[dim];
		for (String cid : items) {
			double[] yj = y[itemMap.get(cid)];
			for (int d = 0; d < dim; d++) {
				sum[d] += ru * yj[d];
			}
		}
		return sum;
	}

	/**
	 * Runs {@code iter} SGD epochs over every non-zero entry of the rating matrix and prints
	 * the training RMSE of each epoch. The learning rate is multiplied by 0.9 after each
	 * epoch (the field itself is modified).
	 */
	public void train() {
		for (int i = 0; i < iter; i++) {
			double rmse = 0.0;
			int rateCalTimes = 0;
			for (int u = 0; u < userNum; u++) {
				String uid = userIndexMap.get(u);
				List<String> items = user_items.get(uid);
				double ru = 1.0 / Math.sqrt(items.size());

				for (int c = 0; c < itemNum; c++) {
					if (rateMatrix[u][c] == 0.0) {
						continue;
					}
					// y changes after every sample, so the implicit term is recomputed each time
					double[] plus = addVector(pu[u], implicitSum(items, ru));
					double rui = mean + bu[u] + bi[c] + multVector(plus, qi[c]);
					double error = rateMatrix[u][c] - rui;

					bu[u] += learningrate * (error - lamda1 * bu[u]);
					bi[c] += learningrate * (error - lamda1 * bi[c]);
					// snapshot q_i so pu and y are updated with the pre-step value
					// (a simultaneous SGD update of all parameters)
					double[] qOld = qi[c].clone();
					for (int d = 0; d < dim; d++) {
						qi[c][d] += learningrate * (error * plus[d] - lamda2 * qi[c][d]);
						pu[u][d] += learningrate * (error * qOld[d] - lamda2 * pu[u][d]);
					}
					// gradient of y_j is e_ui * |N(u)|^-0.5 * q_i, where i is the RATED item c
					for (String cid : items) {
						double[] yj = y[itemMap.get(cid)];
						for (int d = 0; d < dim; d++) {
							yj[d] += learningrate * (error * ru * qOld[d] - lamda2 * yj[d]);
						}
					}

					rmse += error * error;
					rateCalTimes++;
				}
			}
			rmse = Math.sqrt(rmse / rateCalTimes);
			learningrate *= 0.9;
			System.out.println(String.format("iteration is %s; rmse is %s", i, rmse));
		}
	}

	/**
	 * Computes the interaction term {@code (pu + |N(u)|^-0.5 * sum(y_j)) . qi} for one
	 * user/item pair (without the mean and bias terms).
	 *
	 * @param uid user id seen during {@link #matrix(String, String)}
	 * @param cid item id seen during {@link #matrix(String, String)}
	 * @return the interaction term
	 * @throws IllegalArgumentException if the user or item is unknown
	 */
	public double calValue(String uid, String cid) {
		Integer uIndex = userMap.get(uid);
		Integer iIndex = itemMap.get(cid);
		if (uIndex == null || iIndex == null) {
			throw new IllegalArgumentException("unknown user or item: " + uid + ", " + cid);
		}
		List<String> items = user_items.get(uid);
		double ru = 1.0 / Math.sqrt(items.size());
		return multVector(addVector(pu[uIndex], implicitSum(items, ru)), qi[iIndex]);
	}

	/**
	 * Evaluates the model on a test file (same format as {@link #matrix(String, String)})
	 * and prints the RMSE. Lines whose user or item was not seen in training are skipped.
	 *
	 * @param file      path of the test file
	 * @param separator field separator (regular expression)
	 * @throws IOException if the file cannot be read
	 */
	public void test(String file, String separator) throws IOException {
		double rmse = 0.0;
		int count = 0;
		for (String[] arr : readRecords(file, separator)) {
			String uid = arr[0];
			String cid = arr[1];
			double rate = Double.parseDouble(arr[2]);
			if (!userMap.containsKey(uid) || !itemMap.containsKey(cid)) {
				continue;
			}
			int userIndex = userMap.get(uid);
			int itemIndex = itemMap.get(cid);
			double rui = mean + bu[userIndex] + bi[itemIndex] + calValue(uid, cid);
			double error = rate - rui;
			rmse += error * error;
			count++;
		}
		if (count == 0) {
			// avoid reporting NaN when nothing could be evaluated
			System.out.println("rmse is undefined: no test ratings matched known users and items");
			return;
		}
		rmse = Math.sqrt(rmse / count);
		System.out.println("rmse is " + rmse);
	}

}

 

你可能感兴趣的:(svd,矩阵分解)