Sprak Java 推荐算法的思路和实现

推荐算法在org.apache.spark.ml.recommendation 或者org.apache.spark.mlib.recommendation下面
相比于org.apache.spark.mlib.recommendation面向RDD算子来计算,org.apache.spark.ml.recommendation面向DataFrame来计算,往后spark会逐渐采用dataframe来计算,虽然对比mlib包,ml没有那么容易用,但是方法更丰富。
本次的计算就用org.apache.spark.ml.recommendation包下的类。


1 常用的推荐算法:用户协同过滤法,通常有UCF(以用户为主体,通过算法得到和本用户类似的用户,找出类似用户喜欢的,但本用户还没有接触的物体进行推荐)和ICF(以物品为主体,推荐与此物体同时被关注的物体)两种方式。


UCF 方法:如下图,用户A和用户C比较类似,并且用户A没有关注过物品D,因此按照UCF的逻辑,会推荐物品D给用户A

Sprak Java 推荐算法的思路和实现_第1张图片

ICF 算法: 如下图,看过物品A的都看过物品C,因此将物品C推荐给用户C

Sprak Java 推荐算法的思路和实现_第2张图片


以上是协同算法的基本思路,在spark 中利用ASL(交叉最小二乘法)来实现以上思路

在实际的操作中,根据历史记录得到用户对商品的评分,可以获得以下矩阵: 其中?号表示用户没有购买过此用品,而数字则表示用户购买后的评分。

假设该矩阵为 m*n阶R,ASL的作用就是根据这些数据来推断填充矩阵里面?的数据,从而实现对用户的推荐

Sprak Java 推荐算法的思路和实现_第3张图片

将该m*n阶矩阵,分解成m*kX矩阵和n*kY矩阵,可以简化计算 ,其中k

用公式表达如下:Y矩阵需要转置


通俗来说,R矩阵 =所有用户对所有商品的评分,X矩阵为用户对商品的偏好,Y矩阵为商品的特征

为了让左边和右边接近相等,需要计算最小平方差(这也是衡量ASL算法的一个指标),计算最小平方差的时候需要添加正则项(就是均衡函数的计算 )来避免过分拟合(具备更好的通用性)

2 例子

package com.apache.spark.example.ml;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// $example on$
import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
// $example off$

public class JavaALSExample {

  // $example on$
  public static class Rating implements Serializable {
    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    public Rating() {}

    public Rating(int userId, int movieId, float rating, long timestamp) {
      this.userId = userId;
      this.movieId = movieId;
      this.rating = rating;
      this.timestamp = timestamp;
    }

    public int getUserId() {
      return userId;
    }

    public int getMovieId() {
      return movieId;
    }

    public float getRating() {
      return rating;
    }

    public long getTimestamp() {
      return timestamp;
    }

    public static Rating parseRating(String str) {
      String[] fields = str.split("::");
      if (fields.length != 4) {
        throw new IllegalArgumentException("Each line must contain 4 fields");
      }
      int userId = Integer.parseInt(fields[0]);
      int movieId = Integer.parseInt(fields[1]);
      float rating = Float.parseFloat(fields[2]);
      long timestamp = Long.parseLong(fields[3]);
      return new Rating(userId, movieId, rating, timestamp);
    }
  }
  // $example off$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaALSExample")
      .getOrCreate();

    // $example on$
    JavaRDD ratingsRDD = spark
      .read().textFile("测试数据路径").javaRDD()
      .map(Rating::parseRating);
    Dataset ratings = spark.createDataFrame(ratingsRDD, Rating.class);
    Dataset[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
    Dataset training = splits[0];
    Dataset test = splits[1];

    // 获得ALS对象,设置最大的迭代次数和最小平方差。该对象用来训练已有数据得到模型
    ALS als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating");
    ALSModel model = als.fit(training);

    // Evaluate the model by computing the RMSE on the test data
    // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
    model.setColdStartStrategy("drop");
    Dataset predictions = model.transform(test);

    RegressionEvaluator evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction");
    Double rmse = evaluator.evaluate(predictions);
    System.out.println("Root-mean-square error = " + rmse);

    // Generate top 10 movie recommendations for each user
    Dataset userRecs = model.recommendForAllUsers(10);
    // Generate top 10 user recommendations for each movie
    Dataset movieRecs = model.recommendForAllItems(10);
    // $example off$
    userRecs.show();
    movieRecs.show();

    spark.stop();
  }
}


你可能感兴趣的:(小结,spark)