FP-Growth: Algorithm Principles and Code Practice

Principles

https://www.cnblogs.com/datahunter/p/3903413.html
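
In brief, FP-Growth makes two passes over the transaction database: the first pass counts item frequencies, the second builds a compact prefix tree (the FP-tree) and mines frequent itemsets from it recursively, avoiding the repeated candidate generation of Apriori. Before the full pipeline below, here is a minimal sketch of the Spark MLlib API on toy data (assuming a spark-shell style SparkSession named "spark"; the items and threshold are illustrative):

    import org.apache.spark.mllib.fpm.FPGrowth

    // Four toy transactions; each Array is one basket of items.
    val toy = spark.sparkContext.parallelize(Seq(
      Array("a", "b", "c"),
      Array("a", "b"),
      Array("a", "c"),
      Array("b", "c")
    ))

    // Keep itemsets that appear in at least 50% of the transactions (here: >= 2 of 4).
    val toyModel = new FPGrowth().setMinSupport(0.5).run(toy)
    toyModel.freqItemsets.collect().foreach { is =>
      println(s"${is.items.mkString("[", ",", "]")}: ${is.freq}")
    }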

Code

    import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
    import org.apache.spark.rdd.RDD
    import spark.implicits._                  // assumes a spark-shell style SparkSession named "spark"
    import com.kugou.ml.model.MLModelFactory  // in-house helper used below to persist DataFrames

    //    // Check how many distinct scids there are in total
    //    spark.read.table("mllab.t_user_sheet_list").flatMap(row => {
    //      row.getSeq(1).toArray[String].map(tp => tp)
    //    }).distinct().count() // 249708 -> 295828

    // Read the transactions: one Array[String] of scids per row.
    // Dropping empty and oversized lists guards against data skew.
    val transactions2: RDD[Array[String]] = spark.read.table("XXX")
      .map(row => row.getSeq(1).toArray[String])
      .toDF("scids")
      .where("size(scids) > 0 and size(scids) < 10000")
      .map(row => row.getSeq(0).toArray[String])
      .rdd

    // Cache the input: FP-Growth makes multiple passes over it.
    transactions2.cache()

    // Train the model. minSupport is a fraction of the transaction count,
    // so 0.00001 keeps itemsets occurring in at least 0.001% of transactions.
    val fpg = new FPGrowth().setMinSupport(0.00001).setNumPartitions(1000)
    val model = fpg.run(transactions2)
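
    // Optional sanity check (a sketch, not part of the original pipeline):
    // peek at a few frequent itemsets and their counts before persisting anything.
    model.freqItemsets.take(5).foreach { itemset =>
      println(s"${itemset.items.mkString("[", ",", "]")}: ${itemset.freq}")
    }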

    // Persist the model (optional)
    // model.save(spark.sparkContext, "hdfs://XXX")

    //    // Extract the frequent itemsets and persist them
    //    val freqItemsets = model.freqItemsets
    //    val freqItemsets_df = freqItemsets.map(itemset => {
    //      val str: String = itemset.items.mkString("[", ",", "]")
    //      str
    //    }).toDF("items")
    //    MLModelFactory.saveDataFrame(freqItemsets_df, "XXX")

    // Generate association rules. minConfidence is the minimum conditional
    // probability P(consequent | antecedent) a rule must reach to be kept.
    //    val model_load = FPGrowthModel.load(spark.sparkContext, "hdfs://XXX")
    val minConfidence = 0.1
    val ass_rules = model.generateAssociationRules(minConfidence)
    val ass_rules_df = ass_rules.map(rule => {
      val str1 = rule.antecedent.mkString("[", ",", "]")
      val str2 = rule.consequent.mkString("[", ",", "]")
      //      val conf = rule.confidence
      (str1, str2)
    }).toDF("antecedent", "consequent")
    MLModelFactory.saveDataFrame(ass_rules_df, "XXX")
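
    // Variant (a sketch): also keep the confidence column, so downstream
    // consumers can rank or re-threshold rules without re-mining.
    // val ass_rules_conf_df = ass_rules.map(rule => (
    //   rule.antecedent.mkString("[", ",", "]"),
    //   rule.consequent.mkString("[", ",", "]"),
    //   rule.confidence
    // )).toDF("antecedent", "consequent", "confidence")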

    // Merge the rules: sort each antecedent's items so equivalent antecedents
    // collapse to one key, then collect every consequent seen for that key.
    spark.read.table("XXX").map(row => {
      val ant = row.getString(0).split("\\[")(1).split("\\]")(0).split(",").sortWith(_ > _).mkString(",")
      val cons = row.getString(1).split("\\[")(1).split("\\]")(0)
      (ant, cons)
    }).toDF("ant", "cons").createOrReplaceTempView("view_tmp")
    val df_group = spark.sql("select ant, collect_set(cons) as cons_set from view_tmp group by ant")
    MLModelFactory.saveDataFrame(df_group, "XXX")
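
The merged table maps each normalized antecedent to the set of consequents it implies, which is the shape an online recommender wants: given the items a user just interacted with, build the same sorted key and look up candidate items. A minimal lookup sketch follows; the table name "merged_rules" and the scids are hypothetical stand-ins for the masked "XXX" targets above:

    // Hypothetical consumption of the merged rules table.
    val userItems = Seq("scid_b", "scid_a")            // items from the current session (illustrative)
    val key = userItems.sortWith(_ > _).mkString(",")  // same normalization as the merge step
    val candidates = spark.read.table("merged_rules")  // hypothetical name for the saved table
      .where($"ant" === key)
      .select("cons_set")
      .collect()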

 
