Spark机器学习之协同过滤算法使用-Java篇

协同过滤通常用于推荐系统,这些技术旨在填补用户和项目关联矩阵里面缺少的值。Spark目前实现基于模型的协同过滤,其中模型的用户和项目由一组小的潜在因素所描述,可用于预测缺少的值。Spark使用交替最小二乘法alternating least squares(ALS)算法来学习这些潜在因素。

1. ALS的参数
  • numBlocks:用户和项目将会被分区的块数,以便并行化计算(默认值为10)
  • rank:模型中潜在因素的数值(默认值为10)
  • maxIter:要运行的最大迭代次数(默认值为10)
  • regParam:指定的正则化参数(默认值为1.0)
  • implicitPrefs:是否使用隐式反馈(默认为false,使用显式反馈)
  • alpha:当使用隐式反馈时,用于控制偏好观察的基线置信度(默认值为1.0)
  • nonnegative:是否对最小二乘法使用非负约束 (默认值为false)
2. 冷启动(Cold-start)策略

当使用ALSModel进行预测时,在训练模型期间,普遍会在测试数据集中遇到用户和/或项目不存在的情况。这一般出现在以下两种情型:
  • 在生产环境中,对于没有评级历史的新用户或项目,和未经过训练的模型(这是“冷启动问题”)
  • 在交叉验证期间,数据被拆分成训练集和评估集。当使用Spark的CrossValidator或TrainValidationSplit中的简单随机拆分时,评估集里面的用户和/或项目不在训练集里面是非常常见的
默认地,当模型中不存在的用户和/或项目因素时,Spark在调用ALSModel.transform方法时,预测的值会是NaN。这在生产系统中可以是有用的,因为NaN表示一个新的用户或项目,因此系统可以预测作出一些回退的决定。

然而,在交叉验证期间这是不可取的,因为任何NaN预测值将导致评估指标的NaN结果(例如当使用RegressionEvaluator的时候)。这使得模型的选择变得不可能。

Spark允许用户将coldStartStrategy参数设置为”drop”,以便删除DataFrame中包含预测NaN值的任何行,然后会根据非NaN的数据计算评估指标。

注意:目前支持的冷启动策略是“nan”(默认)和“drop”,未来可能会支持其它的策略。

3. Java代码例子

本文使用Spark 2.2.0、Java 1.8版本,测试数据可以在以下链接下载:

http://files.grouplens.org/datasets/movielens/ml-100k.zip

import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaALSExample {

	public static class Rating implements Serializable {

		private static final long serialVersionUID = 1L;
		private int userId;
		private int movieId;
		private float rating;
		private long timestamp;

		public Rating() {
		}

		public Rating(int userId, int movieId, float rating, long timestamp) {
			this.userId = userId;
			this.movieId = movieId;
			this.rating = rating;
			this.timestamp = timestamp;
		}

		public int getUserId() {
			return userId;
		}

		public int getMovieId() {
			return movieId;
		}

		public float getRating() {
			return rating;
		}

		public long getTimestamp() {
			return timestamp;
		}

		public static Rating parseRating(String str) {
			String[] fields = str.split("\\t");
			if (fields.length != 4) {
				throw new IllegalArgumentException("Each line must contain 4 fields");
			}
			int userId = Integer.parseInt(fields[0]);
			int movieId = Integer.parseInt(fields[1]);
			float rating = Float.parseFloat(fields[2]);
			long timestamp = Long.parseLong(fields[3]);
			return new Rating(userId, movieId, rating, timestamp);
		}
	}

	public static void main(String[] args) {
	    // 测试数据文件路径
		String path = "ml-100k/u.data";
		// 使用本地所有可用线程local[*]
		SparkSession spark = SparkSession.builder().master("local[*]").appName("JavaALSExample").getOrCreate();
		JavaRDD ratingsRDD = spark.read().textFile(path).javaRDD().map(Rating::parseRating);
		Dataset ratings = spark.createDataFrame(ratingsRDD, Rating.class);
		// 按比例随机拆分数据
		Dataset[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 });
		Dataset training = splits[0];
		Dataset test = splits[1];

		// 对训练数据集使用ALS算法构建建议模型
		ALS als = new ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userId").setItemCol("movieId")
				.setRatingCol("rating");
		ALSModel model = als.fit(training);

		// Evaluate the model by computing the RMSE on the test data
		// 通过计算均方根误差RMSE(Root Mean Squared Error)对测试数据集评估模型
		// 注意下面使用冷启动策略drop,确保不会有NaN评估指标
		model.setColdStartStrategy("drop");
		Dataset predictions = model.transform(test);
		
		// 打印predictions的schema
        predictions.printSchema();
		
		// predictions的schema输出
        // root
		// |-- movieId: integer (nullable = false)
		// |-- rating: float (nullable = false)
		// |-- timestamp: long (nullable = false)
		// |-- userId: integer (nullable = false)
		// |-- prediction: float (nullable = true)

		RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating")
				.setPredictionCol("prediction");
		double rmse = evaluator.evaluate(predictions);
		// 打印均方根误差
		System.out.println("Root-mean-square error = " + rmse);
	}

}

打印均方根误差结果为:Root-mean-square error = 1.0645093959897054,这个值是越小越好,如果得出的值不符合预期,可以调整ALS的参数重新计算直到符合预期为止。然后可以分别对所有用户和项目进行建议:

// Generate top 10 movie recommendations for each user
Dataset userRecs = model.recommendForAllUsers(10);
		
// Generate top 10 user recommendations for each movie
Dataset movieRecs = model.recommendForAllItems(10);

* 参考Spark Collaborative Filtering官方链接:http://spark.apache.org/docs/latest/ml-collaborative-filtering.html

END O(∩_∩)O

你可能感兴趣的:(Spark)