ALS 推荐系统

1:ALS(alternating least squares ):交替最小二乘法

在机器学习中,特指使用最小二乘法的一种协同推荐算法。如下图所示,u表示用户,v表示商品,用户给商品打分,但是并不是每一个用户都会给每一种商品打分。? 表示用户没有打分的情况,所以这个矩阵A很多元素都是空的,我们称其为“缺失值(missing value)”。协同过滤提出了一种支持不完整评分矩阵的矩阵分解方法,不用对评分矩阵进行估值填充。


ALS 推荐系统_第1张图片

和协同过滤不一样的是,ALS认为用户的评分矩阵是有用户特征矩阵和物品特征矩阵相乘得到的。

ALS 的核心假设是:打分矩阵A是近似低秩的,即一个mn的打分矩阵 A 可以用两个小矩阵U(mk)V(nk)的乘积来近似:其中k<

Am×n=Um×k×Vk×n

我们把打分理解成相似度,那么“打分矩阵A(mn)”就可以由“用户喜好特征矩阵U(mk)”和“产品特征矩阵V(nk)”的乘积。

ALS 推荐系统_第2张图片

  • 给定隐含特征的数量,用随机数初始化用户-特征矩阵和商品-特征矩阵
  • 用梯度下降法交替的优化这两个矩阵,用商品矩阵的各维度作为用户矩阵的梯度下降的方向,反之亦然
  • 优化结束后,计算用户特征向量和商品特征向量的相似度(内积、余弦……),这就是用户对商品的偏好打分

2.Spark Mllib

 Spark使用的是交叉最小二乘法(ALS)来最优化损失函数。算法的思想就是:我们先随机生成然后固定它求解,再固定求解,这样交替进行下去,直到取得最优解min(C)。因为每步迭代都会降低误差,并且误差是有下界的,所以 ALS 一定会收敛。但由于问题是非凸的,ALS 并不保证会收敛到全局最优解。但在实际应用中,ALS 对初始点不是很敏感,是否全局最优解造成的影响并不大。(也可能是一个局部最优解)

3. MLlib的ALS实现

ALS伴生对象是建立ALS模型的入口,其主要定义训练线性回归模型的train方法,train方法通过设置训练参数进行模型训练,其参数主要包括:
  • ratings-----评分RDD格式(userID,productID,rating)对;
  • rank------特征数量
  • iterations------迭代次数
  • lambda------正则因子(推荐值为0.01)
  • blocks-----数据分隔
  • seed------随机种子

4. 优化步骤

  1. 一个用户特征和一个商品特征相乘,得到用户对商品的偏好(单元格)
  2. 已知偏好的单元格,乘的结果要和已知的值尽量接近(MSE,总的方差最小)
  3. 用梯度下降法交替的优化用户特征和商品特征(ALS)

5. 梯度下降

ALS 推荐系统_第3张图片
  1. n个隐含特征=在n维空间里优化用户、商品特征
  2. 找一个下降最快的方向(拉格朗日乘数法、随机……)
  3. 朝着这个方向走一小步
  4. 回到1,直到总的偏差不再下降


实现coding


package hhc.mllib.label.learn.recommend;

import hhc.mllib.label.learn.config.AppConfig;
import hhc.mllib.label.learn.ml.CreaterBase;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

import java.util.Arrays;
import java.util.List;

/**
 * Created by huhuichao on 2017/12/7.
 */
public class ALSModelCreater extends CreaterBase{



    private MatrixFactorizationModel model;

    private transient JavaSparkContext jsc;


    public ALSModelCreater(JavaSparkContext jsc) {
        this.jsc = jsc;
    }

    /**
     * 读取样本数据
     * @param path
     * @return
     */
    public static JavaRDD getALSJavaRDD(String path, JavaSparkContext sc,String split) {
        JavaRDD data=sc.textFile(path);
        JavaRDD ratings = data.map(
                new Function() {
                    public Rating call(String s) {
                        String[] sarray = s.split(split);
                        return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                                Double.parseDouble(sarray[2]));
                    }
                }
        );
        return ratings;
    }


    public  MatrixFactorizationModel training (JavaRDD ratings,int rank, int numIterations, double v){
        return  ALS.train(ratings.rdd(), rank, numIterations, v);
    }

    /**
     *
     * 计算方差
     * @param ratings  样本数据
     *  @param model  model
     * @return
     */
    public static double evaluateMSE(JavaRDD ratings,MatrixFactorizationModel model) {

        JavaRDD>   userProducts = ratings.map(
                    new Function>() {
                        private static final long serialVersionUID = 1L;
                        @Override
                        public Tuple2 call(Rating r) {
                            return new Tuple2(r.user(), r.product());
                        }
                    }
            );
        JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD(
                model.predict( JavaRDD.toRDD(userProducts)).toJavaRDD().map(
                        new Function, Object>>() {
                            private static final long serialVersionUID = 1L;
                            @Override
                            public Tuple2, Object> call(Rating r) {
                                return new Tuple2, Object>(
                                        new Tuple2<>(r.user(), r.product()), r.rating());
                            }
                        }
                ));
        JavaRDD> ratesAndPreds =
                JavaPairRDD.fromJavaRDD(ratings.map(
                        new Function, Object>>() {
                            private static final long serialVersionUID = 1L;
                            @Override
                            public Tuple2, Object> call(Rating r) {
                                return new Tuple2, Object>(
                                        new Tuple2<>(r.user(), r.product()), r.rating());
                            }
                        }
                )).join(predictions).values();

        // Create regression metrics object
        RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
        return regressionMetrics.meanSquaredError();
    }


    /**
     * 获取矩阵分解后的物品特征矩阵
     * @param model
     * @return
     */
    public static JavaPairRDD  getProductPeatures(MatrixFactorizationModel model){
        return JavaPairRDD.fromJavaRDD(model.productFeatures().toJavaRDD());
    }


    /**
     * recommendProductsForUsers  对所有用户推荐物品,取前n个物品
     * @param num
     * @param model
     * @return
     */
    public static JavaPairRDD recommendProductsForUsers(int num,MatrixFactorizationModel model) {

        RDD> tuple2RDD = model.recommendProductsForUsers(num);
        JavaRDD> tuple2JavaRDD = tuple2RDD.toJavaRDD();
        JavaPairRDD productFeatures=JavaPairRDD.fromJavaRDD(tuple2JavaRDD);
        return productFeatures;
    }
    public static void main(String[] args) {
//        ALSModelCreater alsModel=new ALSModelCreater(AppConfig.getInstance().sc);
//        //读取样本数据
//        JavaRDD ratings= getALSJavaRDD("data/ml/recommend/als/test.data", alsModel.jsc,",");
        List list=ratings.collect();
//        //建立模型
//        int rank=10;
//        int numIterations=5;
//        MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, numIterations, 0.01);
//
//        System.out.println("Mean Squared Error = " + evaluateMSE(ratings,model));
//        model.save(alsModel.jsc.sc(),"data/ml/recommend/als/model");
        MatrixFactorizationModel model=MatrixFactorizationModel.load(AppConfig.getInstance().sc.sc(),"data/ml/recommend/als/model");
        System.out.println(Arrays.toString(model.recommendProducts(4,2)));

        JavaPairRDD productFeatures = getProductPeatures(model);
        List> list=productFeatures.collect();
        System.out.println(list);

        JavaPairRDD features=recommendProductsForUsers(2,model);
        List> list1=features.collect();
        System.out.println(list1);
    }


}


你可能感兴趣的:(推荐系统)