Swing公式
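Swing 相似度公式为 $sim(i,j)=\sum_{u\in U_i\cap U_j}\sum_{v\in U_i\cap U_j}\frac{1}{\alpha+|I_u\cap I_v|}$,其中 $U_i$、$U_j$ 分别为点击过 item i、item j 的用户集合(其大小即各自的点击数),$I_u$、$I_v$ 分别为用户 u、v 点击过的 item 集合。下图红框部分的构建是实现的难点,本文主要讲解红框的构建与计算思路;个人试验了很久,发现了一种较好的解决方式。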
思路
注:图中的两次过滤可以过滤掉大量数据。解法中比较有意思的地方在于利用求根公式:设 user1 和 user2 共同出现的 itemPair 数目为 c、二者共同点击的 item 数目为 n,则由 $n(n-1)/2=c$ 可得 $n=(1+\sqrt{1+8c})/2$。经过粗略实验,我发现直接使用 itemPair 出现的数目效果反而更好,或许值得在调整原模型的 alpha 之后再对比效果。
Swing模型构建流程
思路举例
代码:直接调用 fitOnline 即可。输入数据按照 PvEntity 给出的数据格式构造,params 为文件中 SwingParams 的广播变量(Broadcast[SwingParams])。
package com.sohu.mp.rec.itemBased.Swing.main
import com.sohu.mp.rec.itemBased.ItemCF.main.ItemCFManager.{computeSimilarities, loadData}
import com.sohu.mp.rec.itemBased.ItemCF.main.PathManager
import com.sohu.mp.rec.itemBased.Swing.entity.{Item, User}
//import com.sohu.mp.rec.itemBased.Swing.util.SwingUtil.SwingParams
import com.sohu.mp.rec.recall.common.entity.application.base.PvEntity
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
object SwingManagerTmp {
/**
 * Configuration for the Swing pipeline. The pipeline functions receive it wrapped in a
 * Spark `Broadcast`.
 *
 * Fields described as "not read" are not referenced by any live (non-commented) code
 * visible next to this definition; they may be used elsewhere.
 *
 * @param minItemClick        exclusive lower bound on an item's de-duplicated click count
 *                            (checked in `filterItemClickCnt`)
 * @param maxItemClick        cap on the number of click records kept per item
 *                            (`filterItemClickCnt`)
 * @param minUserClick        exclusive lower bound on a user's item-set size (`genUserItemSet`)
 * @param maxUserClick        exclusive upper bound on a user's item-set size (`genUserItemSet`)
 * @param minPairClickNum     exclusive lower bound on the number of users sharing an item pair
 *                            (`genItemPairUserSetRdd`)
 * @param maxPairClickNum     cap on the number of users kept per item pair
 *                            (`genItemPairUserSetRdd`)
 * @param sessionLimitFlag    not read
 * @param sessionLimitStamp   not read; the default equals one hour expressed in milliseconds
 * @param timeSeqFlag         when true, `constructUserPair` sorts shared items by timestamp
 * @param alpha               only referenced from commented-out time-decay code
 * @param directionFlag       only referenced from commented-out direction-weighting code
 * @param zeta                only referenced from commented-out direction-weighting code
 * @param beta                only referenced from commented-out direction-weighting code
 * @param userClickLengthFlag when true, `computePairScore` divides the score by the product
 *                            of the logs of both users' click counts
 * @param itemTopN            number of neighbours kept per item (`selectItemTopN`)
 * @param swingAlpha          smoothing term added to the shared-item count in the user-pair
 *                            weight (`genUserPairScore`)
 * @param topNItemNum         not read
 * @param defaultParallelism  not read
 * @param update              only referenced from a commented-out stage in `fitOnline`
 */
case class SwingParams(
  minItemClick: Int = 100,
  maxItemClick: Int = 10000000,
  minUserClick: Int = 2,
  maxUserClick: Int = 1000,
  minPairClickNum: Int = 2,
  maxPairClickNum: Int = 100000,
  sessionLimitFlag: Boolean = false,
  sessionLimitStamp: Long = 60L * 60 * 1000,
  timeSeqFlag: Boolean = false,
  alpha: Double = -15000D,
  directionFlag: Boolean = false,
  zeta: Double = 0.7D,
  beta: Double = 0.8D,
  userClickLengthFlag: Boolean = false,
  itemTopN: Int = 150,
  swingAlpha: Int = 5,
  topNItemNum: Int = 5000,
  defaultParallelism: Int = 500,
  update: Boolean = false
)
/**
 * Groups raw page views by item and keeps only items with enough clicks.
 *
 * For every item the (user_id, timeStamp) pairs are de-duplicated and truncated to at most
 * `maxItemClick` entries. The item is kept only when the remaining number of entries is
 * strictly greater than `minItemClick`. The count is therefore a number of distinct
 * (user, timestamp) click records, not a number of distinct users.
 *
 * @param rawData page-view records; `item_id`, `user_id` and `timeStamp` are read
 * @param params  broadcast configuration; `minItemClick` and `maxItemClick` are read
 * @return one `(userId, (itemId, timeStamp, clickCount))` row per retained click record,
 *         where `clickCount` is the item's retained record count
 */
def filterItemClickCnt(rawData: RDD[PvEntity],
                       params: Broadcast[SwingParams]):
RDD[(String, (String, Long, Int))] = {
  rawData
    .map(pv => (pv.item_id, (pv.user_id, pv.timeStamp)))
    .groupByKey()
    .flatMap { case (itemId, clicks) =>
      val cfg = params.value
      val retainedClicks = clicks.toSet.toArray.take(cfg.maxItemClick)
      val clickCount = retainedClicks.length
      // `take` above already enforces the upper bound, so only the lower bound can reject.
      // No logging here: this closure runs on executors once per item.
      if (clickCount > cfg.minItemClick) {
        retainedClicks.iterator.map { case (userId, time) =>
          (userId, (itemId, time, clickCount))
        }
      } else {
        Iterator.empty
      }
    }
}
/**
 * Extracts the per-item click total carried on every click row.
 *
 * The input is keyed by user; each value is `(itemId, timeStamp, itemClickTotal)`.
 * One `(itemId, itemClickTotal)` row is produced per input row, so the result contains
 * repeated rows for the same item unless the caller de-duplicates it.
 *
 * @param userItemClick click rows in the shape produced by `filterItemClickCnt`
 * @return `(itemId, itemClickTotal)` for every input row
 */
def getItemClickCnt(userItemClick: RDD[(String, (String, Long, Int))]): RDD[(String, Int)] =
  userItemClick.map { clickRow =>
    val (itemId, _, itemClickTotal) = clickRow._2
    (itemId, itemClickTotal)
  }
/**
 * Builds each user's set of clicked items and drops users whose set size is out of range.
 *
 * Clicks are turned into `Item(itemId, time)` values, grouped by user and de-duplicated
 * through a `Set`. A user is kept only when the size of that set is strictly between
 * `minUserClick` and `maxUserClick`.
 *
 * @param rawData click rows `(userId, (itemId, timeStamp, itemClickTotal))`
 * @param params  broadcast configuration; `minUserClick` and `maxUserClick` are read
 * @return `(userId, distinct clicked items)` for every retained user
 */
def genUserItemSet(rawData: RDD[(String, (String, Long, Int))], params: Broadcast[SwingParams]):
RDD[(String, Array[Item])] = {
  rawData
    .map { case (userId, (itemId, time, _)) => userId -> Item(itemId, time) }
    .groupByKey()
    .map { case (userId, clickedItems) => (userId, clickedItems.toSet.toArray) }
    // Filter directly instead of marking rejected users with a null payload.
    .filter { case (_, itemSet) =>
      itemSet.length > params.value.minUserClick && itemSet.length < params.value.maxUserClick
    }
}
/**
 * Computes the Swing weight of every user pair.
 *
 * The number of rows per user pair equals the number of item pairs both users share, `c`.
 * If the two users share `n` items and every item pair is present once, then
 * `n * (n - 1) / 2 = c`, so `n = (1 + sqrt(1 + 8c)) / 2`. The weight is
 * `1 / (swingAlpha + n)`.
 *
 * The `8c` term is evaluated in `Double`; as an `Int` product it would wrap around for
 * counts above `Int.MaxValue / 8`.
 *
 * @param userPairItemPair one row per (user pair, shared item pair)
 * @param params           broadcast configuration; `swingAlpha` is read
 * @return `((userA, userB), weight)` with one row per distinct user-pair key
 */
def genUserPairScore(userPairItemPair: RDD[((String, String), (String, String))], params: Broadcast[SwingParams]):
RDD[((String, String), Double)] = {
  userPairItemPair
    .map { case (userPair, _) => (userPair, 1) }
    .reduceByKey(_ + _)
    .map { case (userPair, sharedPairCount) =>
      val sharedItemCount = (1 + math.sqrt(1 + 8.0 * sharedPairCount)) / 2
      // Alternative the author experimented with: 1.0d / sharedPairCount
      (userPair, 1.0d / (params.value.swingAlpha + sharedItemCount))
    }
}
/**
 * Expands every item pair into one row per pair of users who both clicked it.
 *
 * User ids are sorted before pairing, so a given pair of users always produces the same
 * key (smaller id first). Without that, the key would follow the incoming array order,
 * and the same two users could be keyed as (A, B) for one item pair and (B, A) for
 * another, splitting their shared-pair count across two keys.
 *
 * Pairs are generated lazily through an iterator. `n` users yield `n * (n - 1) / 2` rows,
 * which is too many to buffer in memory for large `n`.
 *
 * Note: this method runs Spark actions (`count` and `take`) to print diagnostics before
 * returning the RDD.
 *
 * @param itemPairUserSet `((xItemId, yItemId), users who clicked both)`
 * @param params          unused here; kept for signature compatibility
 * @return `((userA, userB), (xItemId, yItemId))` with `userA < userB`
 */
def genUserPairItemPair(itemPairUserSet: RDD[((String, String), Array[String])],
                        params: Broadcast[SwingParams]):
RDD[((String, String), (String, String))] = {
  val userPairItemPair = itemPairUserSet.flatMap { case (itemPair, userIdArray) =>
    val sortedUsers = userIdArray.sorted
    for {
      i <- sortedUsers.indices.iterator
      j <- Iterator.range(i + 1, sortedUsers.length)
    } yield ((sortedUsers(i), sortedUsers(j)), itemPair)
  }
  println(s"userPairItemPair length and data: ${userPairItemPair.count()}")
  userPairItemPair.take(10).foreach(println)
  userPairItemPair
}
/**
 * Looks up the ids of the items a user clicked.
 *
 * @param userId         user to look up
 * @param userItemSetMap broadcast map from user id to that user's clicked items
 * @return the `itemId` of every stored item, or an empty array when the user is absent
 */
def getUserItemIds(userId: String,
                   userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]]):
Array[String] = {
  val clickedItems = userItemSetMap.value.get(userId) match {
    case Some(items) => items
    case None => Array[Item]()
  }
  clickedItems.map(_.itemId)
}
//def updateItemScore(userPairScore: RDD[((String, String), Double)],
// userItemSetMap: Broadcast[scala.collection.Map[String, Array[Item]]],
// params: Broadcast[SwingParams]):
//RDD[((String, String), Double)]={
// val userPairScoreUpdate = userPairScore.map{case((xUserId, yUserId), score) =>{
// val xUserClickedItem = getUserItemIds(xUserId, userItemSetMap)
// val yUserClickedItem = getUserItemIds(yUserId, userItemSetMap)
// val coClickedItem = xUserClickedItem.intersect(yUserClickedItem)
// val scoreUpdate = computePairScore(coClickedItem, xUserClickedItem.size,
// yUserClickedItem.size, score, params)
// ((xUserId, yUserId), scoreUpdate)
// }}
// userPairScoreUpdate
//}
/**
 * Optionally damps a user-pair score by how many items each user clicked.
 *
 * When `userClickLengthFlag` is set the score is divided by
 * `log(1 + xUserClickedSize) * log(1 + yUserClickedSize)`; otherwise it is returned
 * unchanged. A size of 0 makes the divisor 0, so the usual floating-point division
 * rules apply.
 *
 * Time-decay (`timeSeqFlag`/`alpha`) and direction weighting
 * (`directionFlag`/`zeta`/`beta`) existed here only as commented-out code and are not
 * applied.
 *
 * @param coClickedItem    items clicked by both users; not used by the current logic
 * @param xUserClickedSize number of items clicked by the first user
 * @param yUserClickedSize number of items clicked by the second user
 * @param score            score to adjust
 * @param params           broadcast configuration; `userClickLengthFlag` is read
 * @return the adjusted (or original) score
 */
def computePairScore(coClickedItem: Array[String],
                     xUserClickedSize: Int,
                     yUserClickedSize: Int,
                     score: Double,
                     params: Broadcast[SwingParams]): Double = {
  if (!params.value.userClickLengthFlag) {
    score
  } else {
    val lengthPenalty = math.log(1 + xUserClickedSize) * math.log(1 + yUserClickedSize)
    score / lengthPenalty
  }
}
/**
 * Inverts user item sets into item pairs with the set of users who clicked both items.
 *
 * Each user's items are reduced to distinct, sorted item ids before pairing, which gives
 * two guarantees:
 *  - an unordered item pair always has a single key (smaller id first), so every user who
 *    clicked both items lands in the same group rather than being split between
 *    (x, y) and (y, x);
 *  - an item clicked at several timestamps produces neither a self-pair (x, x) nor
 *    repeated pairs.
 *
 * Because each pair is emitted in one orientation only, a consumer that needs both
 * (x, y) and (y, x) has to mirror the rows itself.
 *
 * Per pair, users are de-duplicated and truncated to `maxPairClickNum`; the pair is kept
 * only when more than `minPairClickNum` users remain.
 *
 * @param userItemSet `(userId, clicked items)`
 * @param params      broadcast configuration; `minPairClickNum` and `maxPairClickNum` are read
 * @return `((xItemId, yItemId), users)` with `xItemId < yItemId`
 */
def genItemPairUserSetRdd(userItemSet: RDD[(String, Array[Item])], params: Broadcast[SwingParams]):
RDD[((String, String), Array[String])] = {
  userItemSet
    .flatMap { case (userId, items) =>
      val itemIds = items.map(_.itemId).distinct.sorted
      for {
        i <- itemIds.indices.iterator
        j <- Iterator.range(i + 1, itemIds.length)
      } yield ((itemIds(i), itemIds(j)), userId)
    }
    .groupByKey()
    .map { case (itemPair, userIds) =>
      (itemPair, userIds.toSet.toArray.take(params.value.maxPairClickNum))
    }
    // `take` already enforces the upper bound; only the lower bound can reject a pair.
    .filter { case (_, users) => users.length > params.value.minPairClickNum }
}
/**
 * Builds, for every pair of users who clicked a common item, the list of items they share.
 *
 * Steps:
 *  1. re-key the rows by item id and group them;
 *  2. for each item, de-duplicate the `(Item, User)` entries and emit every index pair
 *     `i < j` as `((userId_i, userId_j), (item_i, user_i, user_j))`;
 *  3. group by the user-id pair, de-duplicate the items, and take the two `User` objects
 *     from the first grouped row;
 *  4. when `timeSeqFlag` is set, sort the items by `timeStamp`.
 *
 * NOTE(review): the user-id key follows the order of the de-duplicated array, so the same
 * two users may appear under both (A, B) and (B, A) — confirm whether callers rely on this.
 *
 * @param filteredRowData `(userId, (item, user))` rows
 * @param params          broadcast configuration; `timeSeqFlag` is read
 * @return `(shared items, userA, userB)` with one row per user-id pair key
 */
def constructUserPair(filteredRowData: RDD[(String, (Item, User))], params: Broadcast[SwingParams]):
RDD[(Array[Item], User, User)] = {
  val clicksByItem = filteredRowData
    .map { case (_, itemAndUser) => (itemAndUser._1.itemId, itemAndUser) }
    .groupByKey()

  val sharedClicksByUserPair = clicksByItem
    .flatMap { case (_, clicks) =>
      val distinctClicks = clicks.toSet.toArray
      for {
        i <- distinctClicks.indices
        j <- i + 1 until distinctClicks.length
      } yield {
        val (item, firstUser) = distinctClicks(i)
        val secondUser = distinctClicks(j)._2
        ((firstUser.userId, secondUser.userId), (item, firstUser, secondUser))
      }
    }
    // all rows for one (userA, userB) key: every item the two users both clicked
    .groupByKey()

  sharedClicksByUserPair.map { case (_, sharedClicks) =>
    val distinctItems = sharedClicks.map(_._1).toSet.toArray
    val (_, userA, userB) = sharedClicks.head
    val items =
      if (params.value.timeSeqFlag) distinctItems.sortBy(_.timeStamp) else distinctItems
    (items, userA, userB)
  }
}
/**
 * Sums user-pair weights into a symmetric item-pair score.
 *
 * Each `(user pair, item pair)` row is joined with the weight of its user pair. Weights
 * are summed per unordered item pair: the key is normalised to "smaller id first", so
 * rows that arrive as (x, y) and as (y, x) contribute to the same total. The total is
 * then emitted for both orientations, so a consumer that groups by the first id sees
 * every neighbour of every item. A self-pair (x, x) is emitted once.
 *
 * @param userPairItemPair `((userA, userB), (xItemId, yItemId))`
 * @param userPairScoreRdd `((userA, userB), weight)`
 * @param params           unused here; kept for signature compatibility
 * @return `((xItemId, yItemId), score)` containing both (x, y) and (y, x) with equal scores
 */
def constructItemPair(userPairItemPair: RDD[((String, String), (String, String))],
                      userPairScoreRdd: RDD[((String, String), Double)],
                      params: Broadcast[SwingParams]):
RDD[((String, String), Double)] = {
  userPairItemPair
    .join(userPairScoreRdd)
    .map { case (_, ((xItem, yItem), score)) =>
      val unorderedPair = if (xItem.compareTo(yItem) <= 0) (xItem, yItem) else (yItem, xItem)
      (unorderedPair, score)
    }
    .reduceByKey(_ + _)
    .flatMap { case ((xItem, yItem), score) =>
      if (xItem == yItem) {
        Iterator.single(((xItem, yItem), score))
      } else {
        Iterator(((xItem, yItem), score), ((yItem, xItem), score))
      }
    }
}
// def computePairScore(xItem: ItemWithLocation,
// yItem: ItemWithLocation,
// userA: User,
// userB: User,
// coClickNum: Int,
// params: Broadcast[SwingParams]): Double ={
// val userAClickLength = userA.clickLength
// val userBClickLength = userB.clickLength
// var score = 1.0 / (params.value.swingAlpha + coClickNum)
// // 采用时间序列计算
// if(params.value.timeSeqFlag){
// val timeDistance = math.abs(yItem.timeStamp - xItem.timeStamp).toDouble
// score = score * math.exp((params.value.alpha * timeDistance))
// }
// // 计算方向
// if(params.value.directionFlag){
// val locationDistance = yItem.localtion - xItem.localtion
// // 如果序列为反方向
// var currentZeta = params.value.zeta
// if(locationDistance > 0){
// currentZeta = 1.0f
// }
// score = score * currentZeta * math.pow(params.value.beta, math.abs(locationDistance) - 1)
// }
// // 考虑用户点击长度, 如果用户A,B点击越短,但是都点击了
// // 相同的(xItem, yItem),则说明该pair相关性很强
// if(params.value.userClickLengthFlag){
// score = score /( math.log(1 + userAClickLength) * math.log(1 + userBClickLength))
// }
// score
// }
/**
 * Keeps the `itemTopN` highest-scoring neighbours of every item.
 *
 * Rows are grouped by the source item id, sorted by descending score and truncated.
 * Both the source id and the neighbour ids are converted with `String.toLong`, so every
 * id must be a decimal `Long`; otherwise the conversion throws `NumberFormatException`.
 *
 * Nothing is printed per item: this closure runs on executors once per item.
 *
 * @param pairScore `(itemId, (neighbourId, score))`
 * @param params    broadcast configuration; `itemTopN` is read
 * @return `(itemId, neighbour ids ordered by descending score)`
 */
def selectItemTopN(pairScore: RDD[(String, (String, Double))], params: Broadcast[SwingParams]):
RDD[(Long, Seq[Long])] = {
  pairScore
    .groupByKey()
    .map { case (rawItemId, scoredNeighbours) =>
      val itemId = rawItemId.toLong
      val topNeighbours = scoredNeighbours.toArray
        .sortBy { case (_, score) => -score }
        .take(params.value.itemTopN)
        .map { case (neighbourId, _) => neighbourId.toLong }
        .toSeq
      (itemId, topNeighbours)
    }
}
/**
 * Loads the training data for the given dates and runs the Swing pipeline on it.
 *
 * The parameters are broadcast, the data path is obtained from
 * `PathManager.getSwingDataPath()`, the records are read with `loadData`, and the result
 * of `fitOnline` is returned. The row count and the first ten records are printed, which
 * runs Spark jobs on the loaded data.
 *
 * @param spark  active Spark session
 * @param dates  dates to load, e.g. `List("20200718", "20200720")`
 * @param params filter thresholds and model parameters
 * @return `(itemId, top neighbour ids)` as produced by `fitOnline`
 */
def fit(spark: SparkSession, dates: List[String], params: SwingParams):
RDD[(Long, Seq[Long])] = {
  val broadcastParams = spark.sparkContext.broadcast(params)

  // location of the training data on HDFS
  val trainDataPath = PathManager.getSwingDataPath()
  println(s"train data path: ${trainDataPath}")

  val pvRecords = loadData(spark, dates, trainDataPath)
  println(s"rowDataEntityRdd length and data: ${pvRecords.count()}")
  pvRecords.take(10).foreach(println)

  fitOnline(spark, pvRecords, broadcastParams)
}
/**
 * Runs the full Swing pipeline on already-loaded page-view records.
 *
 * Stages, in order: item click filter, per-item click totals, per-user item sets,
 * item pair to user set, user pair to item pair, user-pair weights, item-pair scores,
 * normalisation through the imported `computeSimilarities`, and top-N selection.
 *
 * After most stages the row count and the first ten rows are printed. Each such preview
 * runs Spark jobs and nothing is cached, so the lineage is recomputed every time. The
 * user-pair expansion is previewed inside `genUserPairItemPair` itself, so it is not
 * previewed again here.
 *
 * The optional user-pair score update guarded by `params.update` is disabled (it exists
 * only as commented-out code).
 *
 * @param spark            active Spark session (not used by the current stages)
 * @param rowDataEntityRdd page-view records
 * @param params           broadcast configuration
 * @return `(itemId, top neighbour ids)`
 */
def fitOnline(spark: SparkSession, rowDataEntityRdd: RDD[PvEntity], params: Broadcast[SwingParams]):
RDD[(Long, Seq[Long])] = {
  // Prints the size and a ten-row sample of one stage.
  def preview[T](label: String, stage: RDD[T]): Unit = {
    println(s"$label length and data: ${stage.count()}")
    stage.take(10).foreach(println)
  }

  val filteredItemRdd = filterItemClickCnt(rowDataEntityRdd, params)
  preview("filteredItemRdd", filteredItemRdd)

  val itemClickNumRdd = getItemClickCnt(filteredItemRdd).distinct()
  preview("itemClickNumRdd", itemClickNumRdd)

  val userItemSetRdd = genUserItemSet(filteredItemRdd, params)
  preview("userItemSetRdd", userItemSetRdd)

  val itemPairUserSet = genItemPairUserSetRdd(userItemSetRdd, params)
  preview("itemPairUserSet", itemPairUserSet)

  // genUserPairItemPair prints its own count and sample.
  val userPairItemPair = genUserPairItemPair(itemPairUserSet, params)

  val userPairScoreRdd = genUserPairScore(userPairItemPair, params)

  val itemPairs = constructItemPair(userPairItemPair, userPairScoreRdd, params)
  preview("itemPairs", itemPairs)

  val itemSimilarities = computeSimilarities(itemClickNumRdd, itemPairs)
  preview("itemSimilarities", itemSimilarities)

  val itemsTopN = selectItemTopN(itemSimilarities, params)
  preview("itemsTopN", itemsTopN)

  itemsTopN
}
}
/**
 * Normalises item-pair scores by the click counts of both items.
 *
 * Every pair is joined with the click count of its first item and then of its second
 * item; pairs with an item missing from `validItemRdd` are dropped by the inner joins.
 * The result is `pairScore / sqrt(xClicks * yClicks)`.
 *
 * The product is evaluated in `Double`. As an `Int` product it overflows once
 * `xClicks * yClicks` exceeds `Int.MaxValue` (for example 50000 * 50000), which makes the
 * value negative and the square root NaN.
 *
 * @param validItemRdd `(itemId, click count)`; presumably one row per item — a repeated
 *                     item id would multiply the joined rows
 * @param itemPairs    `((xItemId, yItemId), pairScore)`
 * @return `(xItemId, (yItemId, normalised score))`
 */
def computeSimilarities(validItemRdd: RDD[(String, Int)],
                        itemPairs: RDD[((String, String), Double)]):
RDD[(String, (String, Double))] = {
  itemPairs
    .map { case ((xItemId, yItemId), pairScore) => (xItemId, (yItemId, pairScore)) }
    .join(validItemRdd)
    .map { case (xItemId, ((yItemId, pairScore), xItemClickCnt)) =>
      (yItemId, (pairScore, xItemId, xItemClickCnt))
    }
    .join(validItemRdd)
    .map { case (yItemId, ((pairScore, xItemId, xItemClickCnt), yItemClickCnt)) =>
      val cosine = pairScore / math.sqrt(xItemClickCnt.toDouble * yItemClickCnt)
      xItemId -> (yItemId, cosine)
    }
}
参考文章:
https://mp.weixin.qq.com/s?__biz=MjM5MzY4NzE3MA==&mid=2247485008&idx=1&sn=ca0549a346bc9879c48fc99628410621&chksm=a69275bd91e5fcab7a779eccbaee6d1715eb9611c7f9e4c32e1c5c814f5f9e1d49000602476e&mpshare=1&scene=1&srcid=&sharer_sharetime=1592800738855&sharer_shareid=d1a917c43153309de51a76d5d54e85ef#rd
https://zhuanlan.zhihu.com/p/67126386?from_voters_page=true