使用spark mllib库实现协同过滤

本文使用 TPC-H(tpch)工具生成的数据集,数据存放在 Hive 中;相关细节请看
https://www.jianshu.com/p/154069c0e721

ColleborativeFilter2.scala
传入参数:model保存路径 迭代次数
作用:使用数据训练模型,最后将模型保存至本地
说明:将用户购买物品的数量作为rating值

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.HiveContext

/**
 * Trains an ALS collaborative-filtering model from the `tpch` Hive database and
 * saves it to the given path.
 *
 * The number of line items a customer has for a part is used as the rating value.
 *
 * Program arguments:
 *  - `args(0)`: path the trained model is saved to
 *  - `args(1)`: number of ALS iterations (integer)
 */
object ColleborativeFilter2 {

  /** Number of latent factors passed to `ALS.train`. */
  private val Rank = 10

  /** Regularization parameter passed to `ALS.train`. */
  private val Lambda = 0.01

  /**
   * Entry point: validates the arguments, builds a Hive-enabled session, trains
   * the model and saves it.
   *
   * @param args model save path followed by the iteration count
   */
  def main(args: Array[String]): Unit = {
    // Must run before the SparkConf is created so the system property is picked up.
    SetLogger

    if (args.length < 2) {
      System.err.println("用法: ColleborativeFilter2 <model保存路径> <迭代次数>")
      sys.exit(1)
    }
    val path = args(0)
    val numIterations = scala.util.Try(args(1).trim.toInt).getOrElse {
      System.err.println("迭代次数必须是整数: " + args(1))
      sys.exit(1)
    }

    println("==========程序初始化===============")
    // Only fall back to a local master when none was supplied (e.g. by spark-submit).
    val sparkConf = new SparkConf()
      .setAppName("CF")
      .setIfMissing("spark.master", "local[2]")
    val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()

    try {
      println("==========数据准备阶段===============")
      val ratings: RDD[Rating] = prepareData(spark.sparkContext)

      println("==========训练阶段===============")
      // NOTE(review): purchase counts are treated as explicit ratings here;
      // ALS.trainImplicit may be a better fit for count data — confirm intent.
      val model = ALS.train(ratings, Rank, numIterations, Lambda)
      println("==========训练完成===============")

      model.save(spark.sparkContext, path)
      println("保存到:" + path)
    } finally {
      // Release the Spark resources even when training or saving fails.
      spark.stop()
    }
  }

  /**
   * Queries Hive for (customer, part, count) triples, converts them to ratings
   * and prints summary statistics.
   *
   * @param sc the active Spark context (the session attached to it is reused)
   * @return a cached RDD with one [[Rating]] per (customer, part) pair
   */
  private def prepareData(sc: SparkContext): RDD[Rating] = {
    // Reuse the Hive-enabled session instead of the deprecated HiveContext.
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    spark.sql("use tpch")

    // Count the line items per (customer, part) pair; the count becomes the rating.
    val resultDf = spark.sql(
      "select  o.O_CUSTKEY customer,l.L_PARTKEY part,count(*) rating" +
        " from orders o,lineitem l where o.O_ORDERKEY=l.L_ORDERKEY" +
        " group by o.O_CUSTKEY,l.L_PARTKEY")
    resultDf.show()

    // Cached because the RDD is scanned three times below and again on every
    // ALS iteration; without it the Hive join would be recomputed each time.
    val ratings = resultDf.rdd.map { row =>
      Rating(row.getInt(0), row.getInt(1), row.getAs[Number](2).doubleValue())
    }.cache()

    val numRatings = ratings.count()
    val numUsers = ratings.map(_.user).distinct().count()
    val numParts = ratings.map(_.product).distinct().count()
    println(s"共计:ratings: ${numRatings} User ${numUsers} Part ${numParts}")
    ratings
  }

  /** Turns off log4j output and the Spark console progress bar. */
  def SetLogger: Unit = {
    Seq("org", "com").foreach(name => Logger.getLogger(name).setLevel(Level.OFF))
    System.setProperty("spark.ui.showConsoleProgress", "false")
    Logger.getRootLogger.setLevel(Level.OFF)
  }
}

输出结果:

==========程序初始化===============
==========数据准备阶段===============
+--------+------+------+
|customer|  part|rating|
+--------+------+------+
|   25001|115772|     1|
|  103915|175999|     1|
|   79666| 56901|     1|
|  126154|192471|     1|
|  147884|165801|     1|
|   92054| 75664|     1|
|   40555|187715|     1|
|   22195| 14042|     1|
|   51124| 31213|     1|
|   96481|193796|     1|
|   32779| 14503|     1|
|  129082| 73486|     1|
|  134419| 97723|     1|
|   26981|116112|     1|
|  125698|109181|     1|
|   23536|148693|     1|
|   43201|129019|     1|
|  135277| 82917|     1|
|   63298| 19008|     1|
|   78565|119137|     1|
+--------+------+------+
only showing top 20 rows

共计:ratings: 6000127 User 99996 Part 200000
==========训练阶段===============
==========训练完成===============
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
保存到:file:///Users/david/IdeaProjects/ideaTest/MySpark/target/tmp/myCollaborativeFilter

TestModel.scala
传入参数: model位置 文件存储位置(注:代码中目前并未使用第二个参数)
作用:读取模型,进行推荐

package com.example.spark

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

/**
 * Loads a saved [[MatrixFactorizationModel]] and serves recommendations through
 * an interactive console menu.
 *
 * Program arguments:
 *  - `args(0)`: location of the saved model
 *  - `args(1)`: optional and ignored; it was read but never used before
 */
object TestModel {
  import scala.annotation.tailrec
  import scala.io.StdIn
  import scala.util.Try

  /** Number of recommendations requested per query. */
  private val TopN = 10

  /**
   * Entry point: loads the model and starts the interactive recommendation loop.
   *
   * @param args model location, optionally followed by an unused second argument
   */
  def main(args: Array[String]): Unit = {
    // Must run before the SparkConf is created so the system property is picked up.
    SetLogger

    if (args.isEmpty) {
      System.err.println("用法: TestModel <model位置> [文件存储位置]")
      sys.exit(1)
    }

    println("==========模型加载阶段===============")
    val modelPath = args(0)
    // Only fall back to a local master when none was supplied (e.g. by spark-submit).
    val conf = new SparkConf()
      .setAppName("TM")
      .setIfMissing("spark.master", "local[2]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    try {
      val model = MatrixFactorizationModel.load(spark.sparkContext, modelPath)
      println("模型加载成功:path=" + modelPath)
      println("==========推荐阶段===============")
      recommend(model)
    } finally {
      // Stop Spark when the user leaves the menu or an error occurs.
      spark.stop()
    }
  }

  /**
   * Runs the menu loop until the user chooses "3" or the input stream ends.
   *
   * @param model the loaded model used to answer queries
   */
  def recommend(model: MatrixFactorizationModel): Unit = {
    @tailrec
    def loop(): Unit = {
      print("请选择要推荐类型  1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?")
      // readLine returns null at end of input; treat that like choosing "3".
      Option(StdIn.readLine()).map(_.trim) match {
        case None | Some("3") =>
          ()
        case Some("1") =>
          readId("请输入用户id?").foreach(id => RecommendMovies(model, id))
          loop()
        case Some("2") =>
          readId("请输入产品的 id?").foreach(id => RecommendUsers(model, id))
          loop()
        case Some(_) =>
          // Unrecognized choice: show the menu again.
          loop()
      }
    }
    loop()
  }

  /**
   * Prompts for an integer id.
   *
   * @param prompt text shown before reading the line
   * @return the parsed id, or `None` (after printing a hint) if the input is
   *         missing or not an integer
   */
  private def readId(prompt: String): Option[Int] = {
    print(prompt)
    val parsed = Option(StdIn.readLine()).flatMap(line => Try(line.trim.toInt).toOption)
    if (parsed.isEmpty) println("输入无效,请输入整数 id")
    parsed
  }

  /**
   * Prints the top products recommended for a user.
   *
   * @param model       the loaded model
   * @param inputUserID id of the user to recommend products for
   */
  def RecommendMovies(model: MatrixFactorizationModel, inputUserID: Int): Unit = {
    try {
      val recommended = model.recommendProducts(inputUserID, TopN)
      println("针对用户id" + inputUserID + "推荐下列产品:")
      recommended.zipWithIndex.foreach { case (r, idx) =>
        println(s"${idx + 1}.${r.product}评分:${r.rating}")
      }
    } catch {
      // An id absent from the model makes the lookup fail; keep the session alive.
      case _: NoSuchElementException | _: IllegalArgumentException =>
        println("模型中不存在该用户id: " + inputUserID)
    }
  }

  /**
   * Prints the top users predicted to be interested in a product.
   *
   * @param model        the loaded model
   * @param inputMovieID id of the product to recommend users for
   */
  def RecommendUsers(model: MatrixFactorizationModel, inputMovieID: Int): Unit = {
    try {
      val recommended = model.recommendUsers(inputMovieID, TopN)
      println("针对产品 id" + inputMovieID + "推荐下列用户id:")
      recommended.zipWithIndex.foreach { case (r, idx) =>
        println(s"${idx + 1}用户id:${r.user}   评分:${r.rating}")
      }
    } catch {
      // An id absent from the model makes the lookup fail; keep the session alive.
      case _: NoSuchElementException | _: IllegalArgumentException =>
        println("模型中不存在该产品id: " + inputMovieID)
    }
  }

  /** Turns off log4j output and the Spark console progress bar. */
  def SetLogger: Unit = {
    Seq("org", "com").foreach(name => Logger.getLogger(name).setLevel(Level.OFF))
    System.setProperty("spark.ui.showConsoleProgress", "false")
    Logger.getRootLogger.setLevel(Level.OFF)
  }
}

输出结果:

==========初始化模型===============
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
请选择要推荐类型  1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?1
请输入用户id?125698
针对用户id125698推荐下列产品:
1.194564评分:3.3862003302193537
2.95318评分:3.227529363190912
3.107867评分:3.0270877690246434
4.86599评分:2.908890211972091
5.165007评分:2.8965168519326028
6.152244评分:2.8816292303536546
7.127895评分:2.832183626366389
8.37218评分:2.8070734618310933
9.43516评分:2.7800139701236577
10.162949评分:2.755918520650188
请选择要推荐类型  1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?2
请输入产品的 id?148693
针对产品 id148693推荐下列用户id:
1用户id:74882   评分:3.2278715120519137
2用户id:60653   评分:2.980748402528624
3用户id:147077   评分:2.900603769820539
4用户id:75080   评分:2.7945391669012976
5用户id:44345   评分:2.7765308146132384
6用户id:110015   评分:2.7676577792488897
7用户id:57929   评分:2.5332419522978946
8用户id:136910   评分:2.4901329135980883
9用户id:124451   评分:2.442147327035805
10用户id:109289   评分:2.360915024772536
请选择要推荐类型  1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?3

你可能感兴趣的:(使用spark mllib库实现协同过滤)