This chapter walks through the FPGrowth source code and demonstrates how it runs on a small example.
2.3 FPGrowth Source Code Explained
The run method is the entry point of FPGrowth. Its annotated source is as follows:
/**
 * Computes an FP-Growth model that contains frequent itemsets.
 * @param data input data set, each element contains a transaction
 * @return an [[FPGrowthModel]]
 */
@Since("1.3.0")
def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
  if (data.getStorageLevel == StorageLevel.NONE) {
    logWarning("Input data is not cached.")
  }
  val count = data.count()
  val minCount = math.ceil(minSupport * count).toLong
  // number of partitions, i.e. the groups used by the PFP algorithm
  val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
  val partitioner = new HashPartitioner(numParts)
  val freqItems = genFreqItems(data, minCount, partitioner)  // generate frequent items
  val freqItemsets = genFreqCloseItemsets_V1(data, minCount, freqItems, partitioner)  // generate frequent itemsets
  // genFreqItemsets(data, minCount, freqItems, partitioner)
  new FPGrowthModel(freqItemsets)
}
Algorithms in Spark MLlib are usually invoked by calling a run method, which returns a model. The test driver used in this chapter is the following:
def main(args: Array[String]): Unit = {
  val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("test"))
  sc.setLogLevel("WARN")
  val data = sc.textFile("F:/test/fpg/data.txt")
    .map(_.split(" ")).cache()
  val model = new FPGrowth().setMinSupport(0).setNumPartitions(3).run(data)
  model.freqItemsets.collect().filter(_.items.size >= 1)
    .foreach(f => println(f.items.mkString(",") + " -> " + f.freq))
}
Here minSupport is set to 0 so that every item is retained, and numPartitions is set to 3, i.e. all items are divided into 3 groups. The actual grouping can be inspected with mapPartitionsWithIndex (a sketch is given after the list below). For the data set used in this chapter, the grouping is:
Partition 0: I2, I5
Partition 1: I3, I4
Partition 2: I1
If you run the same data set with the same number of partitions, you should see the same grouping; the examples that follow all assume this partitioning.
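To check this grouping yourself, you can replay the same shuffle that genFreqItems performs and print which items each partition received. This is only a rough sketch: it reuses the sc and data variables from the test driver above, and the exact assignment depends on the HashPartitioner and on your data:
import org.apache.spark.HashPartitioner

val partitioner = new HashPartitioner(3)
data.flatMap(_.toSeq)                 // one record per item occurrence
  .map(item => (item, 1L))
  .reduceByKey(partitioner, _ + _)    // same shuffle as genFreqItems
  .mapPartitionsWithIndex { (idx, iter) =>
    iter.map { case (item, cnt) => s"partition $idx: $item ($cnt)" }
  }
  .collect()
  .foreach(println)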
genFreqItems is straightforward: it counts the frequency of every item and keeps those whose count meets the threshold. It also uses the HashPartitioner to distribute items across partitions during the shuffle (why a hash partitioner is used, and issues such as data skew, are not discussed here). Its annotated source is as follows:
/**
 * Generates frequent items by filtering the input data using minimal support level.
 * @param minCount minimum count for frequent itemsets
 * @param partitioner partitioner used to distribute items
 * @return array of frequent pattern ordered by their frequencies
 */
private def genFreqItems[Item: ClassTag](
    data: RDD[Array[Item]],
    minCount: Long,
    partitioner: Partitioner): Array[Item] = {
  data.flatMap { t =>
    val uniq = t.toSet
    if (t.size != uniq.size) {  // items within one transaction must not repeat, otherwise an exception is thrown
      throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
    }
    t
  }.map(v => (v, 1L))
    .reduceByKey(partitioner, _ + _)  // hash each item to a partition and sum its counts
    .filter(_._2 >= minCount)
    .collect()
    .sortBy(-_._2)                    // sort by count in descending order
    .map(_._1)
}
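To see what genFreqItems produces on the sample data, the same counting and ordering can be replayed locally with plain Scala collections. The following is just a sketch (no Spark involved), using the nine transactions listed in the table below; items with equal counts may come out in either order:
val transactions = Seq(
  Array("I1", "I2", "I5"), Array("I2", "I4"), Array("I2", "I3"),
  Array("I1", "I2", "I4"), Array("I1", "I3"), Array("I2", "I3"),
  Array("I1", "I3"), Array("I1", "I2", "I3", "I5"), Array("I1", "I2", "I3"))
val minCount = 1L  // minSupport = 0 in the test driver, so every item survives the filter

val fList = transactions.flatten
  .groupBy(identity).mapValues(_.size.toLong)  // count each item
  .filter(_._2 >= minCount)                    // keep items meeting minCount
  .toSeq.sortBy(-_._2)                         // descending frequency, i.e. the F-list
println(fList.map { case (item, cnt) => s"$item:$cnt" }.mkString(" "))
// Expected: I2:7 I1:6 I3:6 I4:2 I5:2 (the order of equal counts is not guaranteed)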
The result of this partitioning and counting on the sample data is shown in the table below (the last two columns already anticipate the conditional transactions produced by genCondTransactions, discussed next):
| TID | Item IDs in the transaction | Sorted IDs and their group IDs | Transaction split by group (Q) |
|------|------|------|------|
| T100 | I1, I2, I5 | I2, I1, I5 -> Q0, Q2, Q0 | Q0 -> {I2, I1, I5}, Q2 -> {I2, I1} |
| T200 | I2, I4 | I2, I4 -> Q0, Q1 | Q0 -> {I2}, Q1 -> {I2, I4} |
| T300 | I2, I3 | I2, I3 -> Q0, Q1 | Q0 -> {I2}, Q1 -> {I2, I3} |
| T400 | I1, I2, I4 | I2, I1, I4 -> Q0, Q2, Q1 | Q0 -> {I2}, Q1 -> {I2, I1, I4}, Q2 -> {I2, I1} |
| T500 | I1, I3 | I1, I3 -> Q2, Q1 | Q1 -> {I1, I3}, Q2 -> {I1} |
| T600 | I2, I3 | I2, I3 -> Q0, Q1 | Q1 -> {I2, I3}, Q0 -> {I2} |
| T700 | I1, I3 | I1, I3 -> Q2, Q1 | Q1 -> {I1, I3}, Q2 -> {I1} |
| T800 | I1, I2, I3, I5 | I2, I1, I3, I5 -> Q0, Q2, Q1, Q0 | Q0 -> {I2, I1, I3, I5}, Q1 -> {I2, I1, I3}, Q2 -> {I2, I1} |
| T900 | I1, I2, I3 | I2, I1, I3 -> Q0, Q2, Q1 | Q1 -> {I2, I1, I3}, Q2 -> {I2, I1}, Q0 -> {I2} |
| F-list | I2:7, I1:6, I3:6, I4:2, I5:2 | | |
| Q-list | Q0: {I2, I5}, Q1: {I3, I4}, Q2: {I1} | | |
The last column of the table is produced by genCondTransactions, whose annotated source is as follows:
/**
 * Generates conditional transactions.
 * The trick here is that itemToRank maps every frequent item to its index (rank) in freqItems.
 * So after a transaction is filtered and re-ordered by that frequency ranking, it is emitted
 * as an array of ranks rather than of the items themselves.
 * @param transaction a transaction
 * @param itemToRank map from item to their rank
 * @param partitioner partitioner used to distribute transactions
 * @return a map of (target partition, conditional transaction)
 */
private def genCondTransactions[Item: ClassTag](
    transaction: Array[Item],
    itemToRank: Map[Item, Int],
    partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
  val output = mutable.Map.empty[Int, Array[Int]]
  // Filter the basket by frequent items pattern and sort their ranks.
  /**
   * Suppose the transaction is I1, I2, I5 and itemToRank only contains {I2 -> 0, I1 -> 1}.
   * itemToRank.get then yields Some(1), Some(0), None for the three items; flatMap drops the
   * None, so filtered becomes Array(1, 0), and after sorting it is Array(0, 1).
   */
  val filtered = transaction.flatMap(itemToRank.get)
  ju.Arrays.sort(filtered)                     // re-order by rank, i.e. by descending frequency
  val n = filtered.length
  var i = n - 1
  while (i >= 0) {
    val item = filtered(i)
    val part = partitioner.getPartition(item)
    if (!output.contains(part)) {
      output(part) = filtered.slice(0, i + 1)  // the longest prefix ending at this partition's item
    }
    i -= 1
  }
  output
}
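To make the suffix routing concrete, the loop above can be replayed locally. The sketch below is purely illustrative: condTransactions and groupOf are hypothetical stand-ins for the method and for partitioner.getPartition, and the rank-to-group mapping is copied from the Q-list above rather than computed by a real HashPartitioner:
import scala.collection.mutable

// Hypothetical local stand-in for genCondTransactions: `ranks` is the filtered,
// rank-sorted transaction; `groupOf` plays the role of partitioner.getPartition.
def condTransactions(ranks: Array[Int], groupOf: Int => Int): mutable.Map[Int, Array[Int]] = {
  val output = mutable.Map.empty[Int, Array[Int]]
  var i = ranks.length - 1
  while (i >= 0) {                 // walk the suffix items from last to first
    val part = groupOf(ranks(i))
    if (!output.contains(part)) {  // keep only the longest prefix routed to each group
      output(part) = ranks.slice(0, i + 1)
    }
    i -= 1
  }
  output
}

// Transaction T100 = I1, I2, I5 with ranks I2 -> 0, I1 -> 1, I5 -> 4 gives Array(0, 1, 4).
// Assume, as in the Q-list above, that I2 and I5 belong to group 0 and I1 to group 2.
val groupOf = Map(0 -> 0, 1 -> 2, 4 -> 0)
condTransactions(Array(0, 1, 4), groupOf).foreach { case (part, ranks) =>
  println(s"Q$part -> ${ranks.mkString("{", ", ", "}")}")
}
// Q0 -> {0, 1, 4} (i.e. {I2, I1, I5}) and Q2 -> {0, 1} (i.e. {I2, I1}), matching the T100 row.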
genCondTransactions is called from genFreqItemsets, which groups the conditional transactions, builds one FP-tree per group, and mines each tree:
/**
 * Generate frequent itemsets by building FP-Trees, the extraction is done on each partition.
 * @param data transactions
 * @param minCount minimum count for frequent itemsets
 * @param freqItems frequent items
 * @param partitioner partitioner used to distribute transactions
 * @return an RDD of (frequent itemset, count)
 */
private def genFreqItemsets[Item: ClassTag](
    data: RDD[Array[Item]],
    minCount: Long,
    freqItems: Array[Item],
    partitioner: Partitioner): RDD[FreqItemset[Item]] = {
  val itemToRank = freqItems.zipWithIndex.toMap
  data /* input: RDD[Array[Item]] */
    .flatMap { transaction =>
      genCondTransactions(transaction, itemToRank, partitioner)  // (partition id, conditional transaction as ranks)
    } /* output: RDD[(part, Array[Int])] */
    // aggregateByKey builds one FP-tree per partition from all the conditional transactions routed to it
    .aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
      (tree, transaction) => tree.add(transaction, 1L),  // add a transaction, which forms a single branch of the tree
      (tree1, tree2) => tree1.merge(tree2))              // merge the single-branch trees into one tree (both steps run per partition)
    .flatMap { case (part, tree) =>
      // extract frequent patterns from each partition's tree
      tree.extract(minCount, x => partitioner.getPartition(x) == part)
    }.map { case (ranks, count) =>
      new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)  // map ranks back to the original items
    }
}
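The interplay of the two functions passed to aggregateByKey can be illustrated without FPTree, which is private to Spark's fpm package. In the sketch below, a mutable map from conditional transaction to count is a hypothetical stand-in for the tree, sc is the SparkContext from the test driver, and the input pairs are the T100 conditional transactions from the example above; it only shows how the seqOp and combOp are applied per group:
import org.apache.spark.HashPartitioner
import scala.collection.mutable

val partitioner = new HashPartitioner(3)
val condTrans = sc.parallelize(Seq(   // (group id, conditional transaction as ranks)
  (0, Seq(0, 1, 4)),
  (2, Seq(0, 1)),
  (0, Seq(0))
))

val perGroup = condTrans.aggregateByKey(
    mutable.Map.empty[Seq[Int], Long], partitioner.numPartitions)(
  (acc, t) => { acc(t) = acc.getOrElse(t, 0L) + 1L; acc },                       // plays the role of tree.add
  (a, b)   => { b.foreach { case (k, v) => a(k) = a.getOrElse(k, 0L) + v }; a }  // plays the role of tree1.merge(tree2)
)
perGroup.collect().foreach { case (group, m) => println(s"group $group: $m") }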
The figure below illustrates how this aggregateByKey step proceeds on the example data.