Spark的MLlib实现了协同过滤(Collaborative Filtering)这个功能。官网文档链接
熟悉推荐算法的同学可能也有这个认识:协同过滤主要分为3大类——1、基于User的协同过滤;2、基于Item的协同过滤;3、基于Model的协同过滤。前面两个比较简单不多描述了,主要讲下基于Model的协同过滤。在网上找到一个对基于Model的协同过滤的算法总结包括:Aspect Model,pLSA,LDA,聚类,SVD,Matrix Factorization等。不管这句话说的是否严谨(比如还有二分图模型),总之我认为Spark MLlib目前(2.2.0版本)并不能算是完整的协同过滤。只是做了基于Model的协同过滤中的矩阵分解内容。当然做好了矩阵分解,接下来再做别的也就轻松了。
关于基于Model的矩阵分解,可以参考矩阵分解在协同过滤推荐算法中的应用。Spark的MLlib中使用的是ALS(Alternating Least Squares (ALS) matrix factorization)算法。这个可以看成是对FunkSVD的一种求解实现。不过考虑到有时候我们输入的User-Item的rating可能不是某种评判的数值打分,而是User对于Item的某种偏好,此时使用ALS-WR(alternating-least-squares with weighted-λ-regularization)通过置信度权重来重新定义目标函数,从而得到新的结果。关于ALS和ALS-WR可以参考协同过滤之ALS-WR算法和机器学习(十四)——协同过滤的ALS算法(2)、主成分分析以及协同过滤 CF & ALS 及在Spark上的实现
上面主要是理论基础部分,熟悉了理论基础后,我们看下通过Spark的MLlib的落地实现,我们需要做哪些工作。同时依然建议参考另2篇文章ALS-WR(协同过滤推荐算法) in ML和深入理解Spark ML:基于ALS矩阵分解的协同过滤算法与源码分析
Collaborative filtering
正如前面所讲的,我们的工作是要把评分矩阵用User和Item的latent factors表达出来。MLlib通过ALS算法来学习得到User以及Item的latent factors,在具体的实现中需要以下参数:
- numBlocks is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10). 用于并行计算,同时设置User和Item的block数目,还可以使用numUserBlocks和numItemBlocks分别设置User和Item的block数目。
- rank is the number of latent factors in the model (defaults to 10). 表示latent factors的长度。对于这个值的设置参见What is recommended number of latent factors for the implicit collaborative filtering using ALS
- maxIter is the maximum number of iterations to run (defaults to 10). 交替计算User和Item的latent factors的迭代次数。
- regParam specifies the regularization parameter in ALS (defaults to 1.0). L2正则的系数lambda
- implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false which means using explicit feedback). 表示原始User和Item的rating矩阵的值是否是评判的打分值,False表示是打分值,True表示是矩阵的值是某种偏好。
- alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference observations (defaults to 1.0). 当implicitPrefs为true时,表示对原始rating的一个置信度系数,用于和rate相乘,是一个常值。可以根据对于原始数据的观察,统计先设置一个值,然后再进行后续的tuning。
- nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false). 对应于选择求解最小二乘的方法:if (nonnegative) new NNLSSolver else new CholeskySolver。如果True就是用非负正则化最小二乘(NNLS),False就是用乔里斯基分解(Cholesky)
Note: 基于DataFrame的MLlib API目前只支持integer类型的user和Item的id。其他numeric类型的user和item id列也支持,不过ids必须在integer的取值范围内。这里的numeric类型指的是java.lang.Number,看了下源码感觉负值也应该是可以的。
除了上面文档中的参数,还有一些别的参数设置也有必要列出来(下面的Dataset
- userCol:用户列的名字,String类型。对应于后续调用fit()操作时输入的Dataset
入参时用户id所在schema中的name
- itemCol:item列的名字,String类型。对应于后续调用fit()操作时输入的Dataset
入参时item id所在schema中的name
- ratingCol:rating列的名字,String类型。对应于后续调用fit()操作时输入的Dataset
入参时rating值所在schema中的name
- predictionCol:String类型。做transform()操作时输出的预测值在Dataset
结果的schema中的name,默认是“prediction”
- coldStartStrategy:String类型。有两个取值"nan" or "drop"。这个参数指示用在prediction阶段时遇到未知或者新加入的user或item时的处理策略。尤其是在交叉验证或者生产场景中,遇到没有在训练集中出现的user/item id时。"nan"表示对于未知id的prediction结果为NaN。"drop"表示对于transform()的入参DataFrame中出现未知ids的行,将会在包含prediction的返回DataFrame中被drop。默认值是"nan"
Explicit和implicit feedback
标准的协同过滤中的矩阵分解(matrix factorization)都是对user-item的打分矩阵做因子分解,比如用户对电影的打分,也称为显式反馈(explicit feedback)。
不过在现实情况中,很多user-item都不是某种特定意义的评分,而是一些比如用户的购买记录、搜索关键字,甚至是鼠标的移动。我们将这些间接用户行为称之为隐式反馈(implicit feedback)。
在Spark中处理隐式反馈的算法是ALS-WR。可以重点看下前面给出的参考链接中的算法结果,观察损失函数,就可以知道大致过程。
正则化系数
这里指的是在ALS算法中L2正则项的系数,用来防止过拟合,也能使矩阵的因子分解后的U和V矩阵的值不会太震荡,方便接下来对U和V矩阵再做进一步的利用。
而且Spark通过ALS-WR算法使得 regParam 较少的被数据集的规模所影响。这样可以使得在样本子集中学习得到的最佳参数可以应用在数据全集上而且获得相似的性能。
冷启动策略
我们使用训练后的 ALSModel 对test数据进行预测,不过可能会遇到没有出现在训练模型中的user或者item id,这是由以下两种情况产生引起的:
- 在生成中:本来就会有新的user或者item上线,是之前训练时不曾有的(这也称之为“cold start problem”)
- 在交叉验证阶段:不管是用Spark的 CrossValidator 或者 TrainValidationSplit 都有可能出现验证集中的id是训练集中没有出现过的。
默认Spark使用NaN来表示对于未知id的rate的预测结果,这样在生产中可以提示系统有新的id加入,作为接下来是否采取措施的依据。
不过在交叉验证阶段,NaN会妨碍接下来的评分度量 (比如使用 RegressionEvaluator ),此时可以选择"drop"来使得出现NaN的行都丢掉。方便调参时做模型选择。
举个栗子
下面这个栗子也是官网文档中的栗子。首先看下数据的模样:
然后是代码:
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example on$
import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
// $example off$
public class JavaALSExample {
// $example on$
public static class Rating implements Serializable {
private int userId;
private int movieId;
private float rating;
private long timestamp;
public Rating() {}
public Rating(int userId, int movieId, float rating, long timestamp) {
this.userId = userId;
this.movieId = movieId;
this.rating = rating;
this.timestamp = timestamp;
}
public int getUserId() {
return userId;
}
public int getMovieId() {
return movieId;
}
public float getRating() {
return rating;
}
public long getTimestamp() {
return timestamp;
}
public static Rating parseRating(String str) {
String[] fields = str.split("::");
if (fields.length != 4) {
throw new IllegalArgumentException("Each line must contain 4 fields");
}
int userId = Integer.parseInt(fields[0]);
int movieId = Integer.parseInt(fields[1]);
float rating = Float.parseFloat(fields[2]);
long timestamp = Long.parseLong(fields[3]);
return new Rating(userId, movieId, rating, timestamp);
}
}
// $example off$
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaALSExample")
.getOrCreate();
// $example on$
JavaRDD ratingsRDD = spark
.read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
.map(Rating::parseRating);
Dataset ratings = spark.createDataFrame(ratingsRDD, Rating.class);
Dataset[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
Dataset training = splits[0];
Dataset test = splits[1];
// Build the recommendation model using ALS on the training data
ALS als = new ALS()
.setMaxIter(5)
.setRegParam(0.01)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating");
ALSModel model = als.fit(training);
model.userFactors();
model.itemFactors();
// Evaluate the model by computing the RMSE on the test data
// Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
model.setColdStartStrategy("drop");
Dataset predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction");
Double rmse = evaluator.evaluate(predictions);
System.out.println("Root-mean-square error = " + rmse);
// Generate top 10 movie recommendations for each user
Dataset userRecs = model.recommendForAllUsers(10);
// Generate top 10 user recommendations for each movie
Dataset movieRecs = model.recommendForAllItems(10);
// Generate top 10 movie recommendations for a specified set of users
//todo: Those API @Since("2.3.0")
// Dataset users = ratings.select(als.getUserCol()).distinct().limit(3);
// Dataset userSubsetRecs = model.recommendForUserSubset(users, 10);
// // Generate top 10 user recommendations for a specified set of movies
// Dataset movies = ratings.select(als.getItemCol()).distinct().limit(3);
// Dataset movieSubSetRecs = model.recommendForItemSubset(movies, 10);
// $example off$
userRecs.show();
movieRecs.show();
// userSubsetRecs.show();
// movieSubSetRecs.show();
spark.stop();
}
}
代码还是不难的,建议在IDEA中阅读看下。实际使用时还需要加上tuning环节来对rank,maxIter,regParam ,alpha 甚至numBlocks进行调参。