spark
算法原理
协同过滤是用来对用户的兴趣偏好做预测的一种方法。在Spark中实现的是基于潜在因子模型的协同过滤。用户对特定物品的偏好往往可以用评分的形式给出,评分矩阵r的行数对应用户数量,列数对应物品总数,比如以下4个用户对四个电影评分:
本方法的核心在于把评分矩阵分解为用户偏好矩阵(x)和物品偏好因子矩阵(y):
我们的目标是找到最佳的x和y矩阵,使得这两个矩阵相乘时得到的预测矩阵与原始评分矩阵r之间的误差最小。转化为数学描述,就是使得以下目标函数最小化:
该目标函数由两部分构成,前半部分是平方误差,后半部分使用L2正则化,引入 λ 常数,对模型的复杂度进行控制,防止过度拟合训练数据。
Spark使用的是带正则化矩阵分解,优化函数的方式选用的是交叉最小二乘法ALS(alternative least squares),它的一般执行步骤如下:
- 用随机数初始化物品偏好因子矩阵y
- 固定y,找到可以最小化目标函数的用户偏好矩阵x
- 固定x,类同步骤2,找到最小化目标函数的物品偏好因子矩阵y
- 重复步骤2和3,直到满足算法收敛条件
ALS spark mllib 代码
参数详解:
输入参数名称 | 数据格式 | 必填/可选/固定 | 默认值 | 取值范围 | 备注 |
---|---|---|---|---|---|
userCol | String | 必填 | 用户ID | ||
itemCol | String | 必填 | 物品ID | ||
ratingCol | String | 必填 | 用户给物品的评分列 | ||
rank | Int | 可选 | 10 | ≥1 | 潜在因子数量,最优值需要根据具体数据制定 |
maxIter | Int | 可选 | 10 | ≥0 | 最大循环次数 |
lambda | Double | 可选 | 0.01 | ≥0 | 正则化参数 λ |
numUserBlocks | Int | 可选 | 10 | ≥1 | 把用户偏好矩阵拆分成小块以满足并行化需求 |
numItemBlocks | Int | 可选 | 10 | ≥1 | 把物品偏好因子矩阵拆分成块以满足并行化需求 |
implicitPrefs | Boolean | 必填 | false | false或true | 是否为推测出来的用户偏好(比如,如果一个用户购买过物品A,则推测对A有偏好)。 |
alpha | Double | 可选 | 1.0 | ≥0 | implicitPrefs为true时,根据用户的评分,在confidence基准值之上,进行额外加分 |
nonnegative | Boolean | 可选 | false | false或true | 在最小平方差优化时,是否加以“非负值”限定 |
输出参数名称 | 数据格式 | 必填/可选/固定 | 备注 |
---|---|---|---|
prediction | Float | 固定 | 预测值列 |
spark代码
// 读入数据
val ratings = sparkContext.textFile("data/mllib/als/sample_movielens_ratings.txt").map(
_.split("::") match { case Array(user, product, rating, timeStamp) =>
Rating(user.toInt, product.toInt, rating.toDouble)
})
df = sqlContext.createDataFrame(ratings)
// 参数值设定
val userCol = "user"
val itemCol = "product"
val ratingCol = "rating"
val rank = 10
val maxIter = 10
val regParam = 0.1
val numUserBlocks = 10
val numItemBlocks = 10
val implicitPrefs = false
val alpha = 1.0
val nonnegative = false
// 建立模型
val als = new ALS(userCol, itemCol, ratingCol, rank, maxIter, numUserBlocks, numItemBlocks, implicitPrefs, alpha, nonnegative)
// 模型训练
val alsModel = als.fit(df)
// 进行预测
val predResult = alsModel.transform(df)
val toDouble = udf[Double, Float]( _.toDouble)
val newPredResult = predResult.withColumn("predictionNew", toDouble(predResult("prediction")))
// 计算RMSE(模型评价)
val predRDD = newPredResult.select("predictionNew", "rating").rdd.map(r => (r.getDouble(0), r.getDouble(1)))
val regMetric = new RegressionMetrics(predRDD)
val rmseSpark = regMetric.rootMeanSquaredError
println(s"RMSE for ALS model: ${rmseSpark}")
本地实例
1.测试数据
userID | itemID | ratings |
---|---|---|
101 | 1001 | 4.0 |
101 | 1002 | 2.5 |
101 | 1004 | 3.0 |
101 | 1007 | 1.5 |
101 | 1010 | 4.0 |
101 | 1016 | 3.5 |
101 | 1022 | 4.0 |
102 | 1002 | 2.5 |
102 | 1003 | 1.0 |
102 | 1004 | 3.5 |
102 | 1006 | 2.0 |
102 | 1009 | 2.5 |
102 | 1011 | 4.0 |
102 | 1013 | 3.5 |
102 | 1015 | 4.0 |
102 | 1017 | 4.5 |
102 | 1022 | 5.0 |
103 | 1003 | 1.5 |
103 | 1005 | 1.0 |
103 | 1006 | 3.5 |
103 | 1008 | 2.0 |
103 | 1010 | 4.5 |
103 | 1014 | 3.0 |
103 | 1015 | 3.5 |
103 | 1021 | 5.0 |
103 | 1022 | 1.5 |
103 | 1023 | 5.0 |
104 | 1001 | 0.5 |
104 | 1003 | 3.0 |
104 | 1004 | 1.5 |
104 | 1007 | 1.0 |
104 | 1008 | 2.5 |
104 | 1011 | 1.0 |
104 | 1015 | 3.5 |
104 | 1018 | 4.0 |
104 | 1019 | 1.5 |
104 | 1020 | 3.0 |
2.训练
package ALSdemo
import java.io.File
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
object alsTest {
//屏蔽不必要的日志显示在终端上
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
//给用户推荐
val conf = new SparkConf().setMaster("local[2]").setAppName("als_test_wy")
val sc = new SparkContext(conf)
val myModelPath = "E:\\Spark\\scala-data\\Model\\alsTest"
val data = sc.textFile("E:\\Spark\\scala-data\\CBRec\\als_rating_test.txt")
val ratings: RDD[Rating] = data.map(_.split("#") match { case Array(user, item, rate) =>
Rating(user.toInt, item.toInt, rate.toDouble)
})
ratings.filter(x => x.user == 101).foreach(println)
// Build the recommendation model using ALS
val rank = 5
val numIterations = 10
val model = ALS.train(ratings, rank, numIterations, 0.01)
val recommendProducts: Array[Rating] = model.recommendProducts(101, 10)
for (r <- recommendProducts) {
println(r.toString)
}
val path: File = new File(myModelPath)
dirDel(path) //删除原模型保存的文件,不删除新模型保存会报错
model.save(sc, myModelPath)
}
//删除模型目录和文件
def dirDel(path: File) {
if (!path.exists())
return
else if (path.isFile) {
path.delete()
return
}
val file: Array[File] = path.listFiles()
for (d <- file) {
dirDel(d)
}
path.delete()
}
}
3.调用模型预测
package ALSdemo
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
object alsLoadModelTest {
//屏蔽不必要的日志显示在终端上
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
//给用户推荐
val conf = new SparkConf().setMaster("local").setAppName("rem_test")
val sc = new SparkContext(conf)
val myModelPath = "E:\\Spark\\scala-data\\Model\\alsTest"
val model = MatrixFactorizationModel.load(sc, myModelPath)
val recommendProducts = model.recommendProducts(102, 12)
for (r <- recommendProducts) {
println(r.toString)
}
}
}
4.预测结果
user101:
Rating(101,1010,4.000591419102056)
Rating(101,1022,3.9969496458948193)
Rating(101,1001,3.9772784041229023)
Rating(101,1015,3.5501142515465673)
Rating(101,1016,3.4999375705609506)
Rating(101,1004,3.0070683414579378)
Rating(101,1006,2.64035448857031)
Rating(101,1021,2.5037825017384447)
Rating(101,1023,2.5037825017384447)
Rating(101,1002,2.4961448711069245)
user102:
Rating(102,1022,5.004920261743269)
Rating(102,1017,4.503333959561672)
Rating(102,1015,3.986380420809543)
Rating(102,1011,3.9743258175532787)
Rating(102,1013,3.5025929824248347)
Rating(102,1004,3.481621012016846)
Rating(102,1002,2.509660995203404)
Rating(102,1009,2.5018521424997786)
Rating(102,1006,1.9992398840722876)
Rating(102,1003,1.019633914552224)
Rating(102,1018,0.5646255853665232)
Rating(102,1016,0.49503960012882686)