A Simple Spark Implementation of Alibaba's Swing: Much Better Results than an Optimized ItemCF Model

The Swing Formula
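The original post shows the formula as an image. Reconstructed from the code below (a sketch of what is actually implemented, not a copy of the figure), the similarity of an item pair (i, j) is roughly

sim(i, j) = \frac{1}{\sqrt{|U_i| \, |U_j|}} \sum_{\{u, v\} \subseteq U_i \cap U_j,\; u \neq v} \frac{1}{\alpha + |I_u \cap I_v|}

where U_i is the set of users who clicked item i, I_u is the set of items clicked by user u, and alpha is the smoothing term (swingAlpha in the code).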

In the formula, |U_i| and |U_j| are the click counts of item i and item j (i.e. the numbers of users who clicked each item). The hard part, the term highlighted with a red box in the original figure, is |I_u ∩ I_v|: the number of items that the two users both clicked. This post focuses on how to compute that term; after a fair amount of experimentation I found a reasonably good way to do it.

Approach

(Figure: a worked example of the approach)

Note: the two filtering steps in the figure remove a large amount of data. The interesting part of the solution is using the quadratic root formula to recover the number of items that user1 and user2 both clicked from the number of co-occurring item pairs. In my rough experiments, directly using the item-pair count actually worked better; it may be worth re-tuning the original model's alpha and comparing again.
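As a quick worked example of that trick (k = 10 is only an illustrative number): if user1 and user2 appear together in k co-clicked item pairs and n is the number of items they both clicked, then

\frac{n(n-1)}{2} = k \;\Rightarrow\; n = \frac{1 + \sqrt{1 + 8k}}{2}, \qquad k = 10 \Rightarrow n = \frac{1 + \sqrt{81}}{2} = 5

so 10 co-occurring item pairs correspond to 5 common items, and that user pair contributes 1 / (swingAlpha + 5) to every item pair it co-clicked.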

Swing Model Construction Flow

(Figures: the Swing model construction flow and a worked example of the approach)
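As implemented by fitOnline in the code below, the construction proceeds roughly as follows:

1. filterItemClickCnt: keep items whose deduplicated click count lies in (minItemClick, maxItemClick] and emit (userId, (itemId, timestamp, itemClickCount)) records.
2. getItemClickCnt: extract the per-item click counts used later for normalization.
3. genUserItemSet: group clicks by user and keep users whose number of clicked items lies in (minUserClick, maxUserClick).
4. genItemPairUserSetRdd: for every item pair clicked by the same user, collect the set of such users, filtered by minPairClickNum / maxPairClickNum.
5. genUserPairItemPair: expand each (item pair, user set) record into ((userA, userB), (xItem, yItem)) records.
6. genUserPairScore: count co-occurring item pairs per user pair, recover the co-clicked item count with the quadratic formula and score the pair as 1 / (swingAlpha + coClickNum).
7. constructItemPair: join the user-pair scores back onto the item pairs and sum them per item pair.
8. computeSimilarities: normalize each item-pair score by the square root of the product of the two items' click counts.
9. selectItemTopN: keep the itemTopN most similar items per item.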

To use the code, just call fitOnline directly: prepare your input data in the PvEntity format it expects, and pass param as a broadcast variable of the SwingParams defined in this file. A minimal driver sketch follows.
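To make the entry point concrete, here is a minimal driver sketch. The app name, the dates, the parameter overrides and the way the RDD[PvEntity] is obtained are placeholders for illustration, not part of the original code:

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import com.sohu.mp.rec.recall.common.entity.application.base.PvEntity
import com.sohu.mp.rec.itemBased.Swing.main.SwingManagerTmp

object SwingDriver {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SwingRecall").getOrCreate()

    // Option 1: let fit() load the click log from HDFS for the given dates.
    val topNFromDates = SwingManagerTmp.fit(spark, List("20200718", "20200719"), SwingManagerTmp.SwingParams())
    topNFromDates.take(10).foreach(println)

    // Option 2: build the RDD[PvEntity] yourself (each record carries user_id, item_id, timeStamp)
    // and call fitOnline with the parameters wrapped in a broadcast variable.
    val pvRdd: RDD[PvEntity] = ???  // plug in your own click-log RDD here
    val params: Broadcast[SwingManagerTmp.SwingParams] =
      spark.sparkContext.broadcast(SwingManagerTmp.SwingParams(itemTopN = 150))
    val itemTopN = SwingManagerTmp.fitOnline(spark, pvRdd, params)
    itemTopN.take(10).foreach(println)

    spark.stop()
  }
}

The full implementation follows.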

package com.sohu.mp.rec.itemBased.Swing.main

import com.sohu.mp.rec.itemBased.ItemCF.main.ItemCFManager.{computeSimilarities, loadData}
import com.sohu.mp.rec.itemBased.ItemCF.main.PathManager
import com.sohu.mp.rec.itemBased.Swing.entity.{Item, User}
//import com.sohu.mp.rec.itemBased.Swing.util.SwingUtil.SwingParams
import com.sohu.mp.rec.recall.common.entity.application.base.PvEntity
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable.ArrayBuffer


object SwingManagerTmp {

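  // Rough meaning of the main parameters, as used further down in this file:
  //   minItemClick / maxItemClick       bounds on an item's (deduplicated) click count
  //   minUserClick / maxUserClick       bounds on how many items a single user clicked
  //   minPairClickNum / maxPairClickNum bounds on how many users clicked both items of a pair
  //   swingAlpha                        the smoothing term alpha in 1 / (alpha + coClickNum)
  //   itemTopN                          number of similar items kept per item
  //   timeSeqFlag / directionFlag / userClickLengthFlag  optional time, direction and click-length
  //                                     weights (mostly used by the commented-out code paths)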
  case class SwingParams(minItemClick:Int = 100,
                         maxItemClick:Int = 10000000,
                         minUserClick:Int = 2,
                         maxUserClick:Int = 1000,
                         minPairClickNum:Int = 2,
                         maxPairClickNum:Int = 100000,
                         sessionLimitFlag: Boolean= false,
                         sessionLimitStamp: Long = 60L * 60 * 1000,
                         timeSeqFlag:Boolean= false,
                         alpha:Double= -15000D,
                         directionFlag: Boolean= false,
                         zeta: Double= 0.7D,
                         beta: Double= 0.8D,
                         userClickLengthFlag: Boolean = false,
                         itemTopN: Int = 150,
                         swingAlpha: Int = 5,
                         topNItemNum: Int = 5000,
                         defaultParallelism: Int = 500,
                         update: Boolean = false
                        )

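  // Keep items whose deduplicated click count lies in (minItemClick, maxItemClick] and emit
  // one (userId, (itemId, clickTimestamp, itemClickCount)) record per remaining click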
  def filterItemClickCnt(rawData: RDD[PvEntity],
                         params: Broadcast[SwingParams]):
  RDD[(String, (String, Long, Int))]={
    val userItemClick = rawData.map(rowDataEntity => (rowDataEntity.item_id, (rowDataEntity.user_id, rowDataEntity.timeStamp)))
      .groupByKey()
      .flatMap{case(itemId, itemClickArray) =>{
        val itemClickSet = itemClickArray.toSet.toArray.take(params.value.maxItemClick)
        val size = itemClickSet.size
        val userItemResult = new ArrayBuffer[(String, (String, Long, Int))]()
        println(s"model Params, $params")
        if(size <= params.value.maxItemClick && size > params.value.minItemClick){
          for(userIdTimeStamp <- itemClickSet){
            // userId, (itemId, time, clickSize)
            val userId = userIdTimeStamp._1
            val time = userIdTimeStamp._2
            userItemResult.append((userId, (itemId, time, size)))
          }
        }
        userItemResult.toIterator
      }}
    userItemClick
  }

  // Get the total click count of each item
  def getItemClickCnt(userItemClick: RDD[(String, (String, Long, Int))]): RDD[(String, Int)]={
    val itemClick = userItemClick.map{
      case(userId, (itemId, timeStamp, totalClickNum)) => (itemId, totalClickNum)
    }
    itemClick
  }

  // Filter users by the length of their click sequence and build per-user item sets
  def genUserItemSet(rawData: RDD[(String, (String, Long, Int))], params: Broadcast[SwingParams]):
  RDD[(String, Array[Item])]={
    val userClickSet = rawData.map{
      case(userId, (itemId, time, size)) =>
      // UserId, itemId, itemTimeStamp
      (userId -> Item(itemId, time))}
      .groupByKey()
      .map{
        case((userId, clickedItems)) => {
        val itemSet = clickedItems.toSet.toArray
          if(itemSet.size > params.value.minUserClick && itemSet.size < params.value.maxUserClick){
            (userId, itemSet)
          }else{
            (userId, null)
          }
      }}.filter{case(userId, itemSet) =>{itemSet != null}}
    userClickSet
  }

  // Score each user pair: recover the number of co-clicked items from the number of co-occurring item pairs, then apply 1 / (alpha + coClickNum)
  def genUserPairScore(userPairItemPair: RDD[((String, String), (String, String))], params: Broadcast[SwingParams]):
  RDD[((String, String), Double)]={
    val userPairClickNum = userPairItemPair.map{
      case ((userA, userB), (xItem, yItem)) => ((userA, userB), 1)
    }

    val userPairScore = userPairClickNum.reduceByKey{
      case(xCoClick, yCoClick) => xCoClick + yCoClick
    }.map{case ((xUser, yUser), coPairScore) =>{
      // If n is the number of items that both users clicked, then n * (n - 1) / 2 = coPairScore (the number of co-occurring item pairs).
      // Solving with the quadratic formula x = [-b + sqrt(b^2 - 4ac)] / (2a) gives coClickNum = n = [1 + sqrt(1 + 8 * coPairScore)] / 2.
      val coClickNum = (1 + math.sqrt(1 +  8 * coPairScore)) / 2
      val coScore = 1.0d / (params.value.swingAlpha + coClickNum)
      // My own variant of the Swing score (kept for comparison)
      //val coScore = 1.0d / coPairScore
      ((xUser, yUser), coScore)
    }}
    userPairScore
  }

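  // Expand every ((xItem, yItem) -> users who clicked both) record into one
  // ((userA, userB), (xItem, yItem)) record per pair of users in that set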
  def genUserPairItemPair(itemPairUserSet: RDD[((String, String), Array[String])],
                          params: Broadcast[SwingParams]):
  RDD[((String, String), (String, String))] ={
    val userPairItemPair = itemPairUserSet.flatMap{
      case((xItemId, yItemId), userIdArray) =>{
        // userPair itemPair
        // the co-click count should not exceed the Int range (about 2.1 billion)
        val uPairIPairArray = ArrayBuffer[((String, String), (String, String))]()
        val userIdArraySize = userIdArray.size
        for(i <- 0 until userIdArraySize; j <- i+1 until userIdArraySize){
          // userPair itemPair
          uPairIPairArray.append(((userIdArray(i), userIdArray(j)), (xItemId, yItemId)))
        }
        uPairIPairArray
      }
    }

    println(s"userPairItemPair length and data: ${userPairItemPair.count()}")
    userPairItemPair.take(10).foreach(println)
    userPairItemPair
  }

  def getUserItemIds(userId: String,
                     userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]]):
  Array[String]={
    userItemSetMap.value.getOrElse(userId, Array[Item]()).map{ case(item) => item.itemId}
  }

  //def updateItemScore(userPairScore: RDD[((String, String), Double)],
  //                    userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]],
  //                    params: Broadcast[SwingParams]):
  //RDD[((String, String), Double)]={
  //  val userPairScoreUpdate = userPairScore.map{case((xUserId, yUserId), score) =>{
  //    val xUserClickedItem = getUserItemIds(xUserId, userItemSetMap)
  //    val yUserClickedItem =  getUserItemIds(yUserId, userItemSetMap)
  //    val coClickedItem = xUserClickedItem.intersect(yUserClickedItem)
  //    val scoreUpdate = computePairScore(coClickedItem, xUserClickedItem.size,
  //      yUserClickedItem.size, score, params)
  //    ((xUserId, yUserId), scoreUpdate)
  //  }}
  // userPairScoreUpdate
  //}


  def computePairScore(coClickedItem: Array[String],
                       xUserClickedSize: Int,
                       yUserClickedSize: Int,
                       score: Double,
                       params: Broadcast[SwingParams]): Double ={
   var scoreUpdate = score
   // // Time-decay weighting based on the click-time gap
   // if(params.value.timeSeqFlag){
   //   val timeDistance = math.abs(yItem.timeStamp - xItem.timeStamp).toDouble
   //   score = score * math.exp((params.value.alpha * timeDistance))
   // }

   // // Direction weighting
   // if(params.value.directionFlag){
   //   val locationDistance = yItem.localtion - xItem.localtion
   //   // If the sequence runs in the reverse direction
   //   var currentZeta = params.value.zeta
   //   if(locationDistance > 0){
   //     currentZeta = 1.0f
   //   }
   //   score = score * currentZeta * math.pow(params.value.beta, math.abs(locationDistance) - 1)
   // }

    // Take the users' click-sequence lengths into account
    if(params.value.userClickLengthFlag){
      scoreUpdate = score / ( math.log(1 + xUserClickedSize) * math.log(1 + yUserClickedSize))
    }
    scoreUpdate
  }

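  // For every pair of items clicked by the same user, collect the set of such users and keep
  // only item pairs whose user-set size lies in (minPairClickNum, maxPairClickNum]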
  def genItemPairUserSetRdd(userItemSet: RDD[(String, Array[Item])], params: Broadcast[SwingParams]):
  RDD[((String, String), Array[String])]={
    val userItemPairs = userItemSet.flatMap{case(userId, items) => {
      val itemPairsUser =new ArrayBuffer[((String, String), String)]()
      for(i <- 0 until items.length; j <- i + 1 until items.length){
        itemPairsUser.append(((items(i).itemId, items(j).itemId), userId))
      }
      itemPairsUser.toIterator
    }}.groupByKey()
        .map{case ((xItemId, yItemId), userIdArray) =>{
          val userSet = userIdArray.toSet.toArray.take(params.value.maxPairClickNum)
          if(userSet.size > params.value.minPairClickNum &&
            userSet.size <= params.value.maxPairClickNum){
            ((xItemId, yItemId), userSet)
          }else{
            ((xItemId, yItemId), null)
          }
        }}
        .filter{case((xItemId, yItemId), userIdArray) => userIdArray != null}
    userItemPairs
  }

  // Build user pairs, treating each (User1, User2) pair as a single "user"
  def constructUserPair(filteredRowData: RDD[(String, (Item, User))], params: Broadcast[SwingParams]):
  RDD[(Array[Item], User, User)]={
    // (itemId, Item, User)
    val itemRowData = filteredRowData
        .map{case(userId, (item, user)) =>(item.itemId, (item, user))}
      //.map(userItem =>(userItem._2._1.itemId, userItem._2))
    val userPairItems = itemRowData.groupByKey()
      .flatMap{case(itemId, itemUserArray)=>{
        val itemUserSet = itemUserArray.toSet.toArray
        // (userId1, userId2) -> (Item, User1, User2)
        val pairResult = new ArrayBuffer[((String, String),(Item, User, User))]()
        for(i <- 0 until itemUserSet.length; j <- i+1 until itemUserSet.length){
          val xPair = itemUserSet(i)
          val yPair = itemUserSet(j)
          val item = xPair._1
          val xUser = xPair._2
          val yUser = yPair._2
          pairResult.append(((xUser.userId, yUser.userId),(item, xUser, yUser)))
        }
        pairResult.toIterator
      }}
      // The Array of all items that (user1, user2) have both viewed
      .groupByKey()
        .map{case((userAId, userBId), itemUserPairsArray)=> {
          var items = itemUserPairsArray.map{case(item, userA, userB) =>item}.toSet.toArray
          val users = itemUserPairsArray.take(1).toArray
          val userA = users(0)._2
          val userB = users(0)._3
          if(params.value.timeSeqFlag){
            items = items.sortBy(_.timeStamp)
          }
          (items, userA, userB)
        }}
    userPairItems
  }


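  // Join the user-pair scores back onto the ((userA, userB), (xItem, yItem)) records and sum
  // them per item pair, giving the un-normalized Swing score of each item pair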
  def constructItemPair(userPairItemPair: RDD[((String, String), (String, String))],
                        userPairScoreRdd: RDD[((String, String), Double)],
                        params: Broadcast[SwingParams]):
  RDD[((String, String), Double)]={
    val itemPairScore = userPairItemPair.join(userPairScoreRdd).map{
      case ((userA, userB), ((xItem, yItem), score)) => ((xItem, yItem), score)
    }.reduceByKey(_+_)
    itemPairScore
  }

 // def computePairScore(xItem: ItemWithLocation,
 //                      yItem: ItemWithLocation,
 //                      userA: User,
 //                      userB: User,
 //                      coClickNum: Int,
 //                      params: Broadcast[SwingParams]): Double ={
 //   val userAClickLength = userA.clickLength
 //   val userBClickLength = userB.clickLength
 //   var score = 1.0 / (params.value.swingAlpha + coClickNum)
 //   // Time-decay weighting based on the click-time gap
 //   if(params.value.timeSeqFlag){
 //     val timeDistance = math.abs(yItem.timeStamp - xItem.timeStamp).toDouble
 //     score = score * math.exp((params.value.alpha * timeDistance))
 //   }

 //   // Direction weighting
 //   if(params.value.directionFlag){
 //     val locationDistance = yItem.localtion - xItem.localtion
 //     // If the sequence runs in the reverse direction
 //     var currentZeta = params.value.zeta
 //     if(locationDistance > 0){
 //       currentZeta = 1.0f
 //     }
 //     score = score * currentZeta * math.pow(params.value.beta, math.abs(locationDistance) - 1)
 //   }

 //   // Take the users' click-sequence lengths into account: if users A and B have short click
 //   // sequences yet both clicked the same (xItem, yItem), the pair is strongly related
 //   if(params.value.userClickLengthFlag){
 //     score = score /( math.log(1 + userAClickLength) * math.log(1 + userBClickLength))
 //   }
 //   score
 // }


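  // For each item, sort its candidate similar items by score and keep the top itemTopN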
  def selectItemTopN(pairScore: RDD[(String, (String, Double))], params: Broadcast[SwingParams]):
  RDD[(Long, Seq[Long])]={
    val itemCandidates = pairScore.groupByKey()
      .map(itemCandidates =>{
        val itemId = itemCandidates._1.toLong
        // Sort candidates by score in descending order
        val candidates = itemCandidates._2.toArray.sortBy{
          case(itemId, score) => -score
        }
          .take(params.value.itemTopN)
          .map{
            case(itemId, score) => itemId.toLong
          }.toSeq
        println(s"sort candidates ${candidates.mkString(",")}")
        (itemId, candidates)
      })
    itemCandidates
  }

  // dates is a list of date strings such as {"20200718", "20200720"}; params is a SwingParams
  // instance that controls how the Swing model filters data and stores the model parameters
  def fit(spark: SparkSession, dates:List[String], params: SwingParams):
  RDD[(Long, Seq[Long])]={
    val modelParams = spark.sparkContext.broadcast(params)
    //println(s"model Params, $modelParams")
    //println("SwingManager params minItemClick, userClickLengthFlag, directionFlag")
    //println(minItemClick)
    //println(userClickLengthFlag)
    //println(directionFlag)
    //spark.sparkContext.

    // Path where the training data is stored on HDFS
    val dataPath = PathManager.getSwingDataPath()
    println(s"train data path: ${dataPath}")

    // Load the raw data
    val rowDataEntityRdd = loadData(spark, dates, dataPath)
    println(s"rowDataEntityRdd length and data: ${rowDataEntityRdd.count()}")
    rowDataEntityRdd.take(10).foreach(println)
    fitOnline(spark, rowDataEntityRdd, modelParams)
  }

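  // End-to-end Swing pipeline on an already-loaded RDD[PvEntity]: filter clicks, build per-user
  // item sets and per-item-pair user sets, score user pairs, aggregate the scores per item pair,
  // normalize by item click counts and keep the top-N similar items for every item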
  def fitOnline(spark: SparkSession, rowDataEntityRdd: RDD[PvEntity], params: Broadcast[SwingParams]):
  RDD[(Long, Seq[Long])]={
    val filteredItemRdd = filterItemClickCnt(rowDataEntityRdd, params)
    println(s"filteredItemRdd length and data: ${filteredItemRdd.count()}")
    filteredItemRdd.take(10).foreach(println)

    val itemClickNumRdd = getItemClickCnt(filteredItemRdd).distinct()
    println(s"itemClickNumRdd length and data: ${itemClickNumRdd.count()}")
    itemClickNumRdd.take(10).foreach(println)

    val userItemSetRdd = genUserItemSet(filteredItemRdd, params)
    println(s"userItemSetRdd length and data: ${userItemSetRdd.count()}")
    userItemSetRdd.take(10).foreach(println)
    //filteredItemRdd.unpersist()

    val itemPairUserSet = genItemPairUserSetRdd(userItemSetRdd, params)
    println(s"itemPairUserSet length and data: ${itemPairUserSet.count()}")
    itemPairUserSet.take(10).foreach(println)

    val userPairItemPair = genUserPairItemPair(itemPairUserSet, params)
    println(s"userPairItemPair length and data: ${userPairItemPair.count()}")
    userPairItemPair.take(10).foreach(println)
    // user construct
    var userPairScoreRdd = genUserPairScore(userPairItemPair, params)

    //if(params.value.update){
    //  println(">>> entry into update userPairsScore pharse <<<")
    //  val userItemSetMap = spark.sparkContext.broadcast(userItemSetRdd.collectAsMap())
    //  userPairScoreRdd = updateItemScore(userPairScoreRdd, userItemSetMap, params)
    //  userItemSetMap.unpersist()
    //  println(">>> update userPairsScore pharse end <<<")
    //}
    //userItemSetRdd.unpersist()
    //val userPairScoreBC = spark.sparkContext.broadcast(userPairScoreRdd.collectAsMap())

    val itemPairs = constructItemPair(userPairItemPair, userPairScoreRdd, params)
    println(s"itemPairs length and data: ${itemPairs.count()}")
    itemPairs.take(10).foreach(println)
    //itemPairUserSet.unpersist()

    val itemSimilarities = computeSimilarities(itemClickNumRdd, itemPairs)
    println(s"itemSimilarities  length and data: ${itemSimilarities.count()}")
    itemSimilarities.take(10).foreach(println)
    //itemClickNumRdd.unpersist()

    val itemsTopN = selectItemTopN(itemSimilarities, params)
    println(s"itemsTopN length and data: ${itemsTopN.count()}")
    itemsTopN.take(10).foreach(println)
    itemsTopN
  }
}

For reference, here is computeSimilarities, the final normalization step imported from ItemCFManager at the top of the listing:

  // Compute the score between xItem and yItem, normalized by the two items' click counts
  def computeSimilarities(validItemRdd: RDD[(String, Int)],
                          itemPairs: RDD[((String, String), Double)]):
  RDD[(String, (String, Double))]={

    val similarities = itemPairs.map{
      case ((xItemId, yItemId), pairScore) => (xItemId, (yItemId, pairScore))
    }.join(validItemRdd)
      .map { case (xItemId, ((yItemId, pairScore), xItemClickCnt)) =>
        (yItemId, (pairScore, xItemId, xItemClickCnt))}
      .join(validItemRdd)
      .map { case (yItemId, ((pairScore, xItemId, xItemClickCnt), yItemClickCnt)) =>
          val cosine = pairScore / math.sqrt(xItemClickCnt * yItemClickCnt)
          (xItemId -> (yItemId, cosine))}
    similarities
  }

References:
https://mp.weixin.qq.com/s?__biz=MjM5MzY4NzE3MA==&mid=2247485008&idx=1&sn=ca0549a346bc9879c48fc99628410621&chksm=a69275bd91e5fcab7a779eccbaee6d1715eb9611c7f9e4c32e1c5c814f5f9e1d49000602476e&mpshare=1&scene=1&srcid=&sharer_sharetime=1592800738855&sharer_shareid=d1a917c43153309de51a76d5d54e85ef#rd
https://zhuanlan.zhihu.com/p/67126386?from_voters_page=true
