Spark PageRank

说明

如果不考虑出度为0的节点情况,方法很easy,参考官方的code。考虑出度为0 有两个版本,V2是在V1基础上的修改完善版本,V1版本记录了各种出错记录,V2版自我感觉没有问题了。

考虑出度为0的节点的具体算法可以参考data-intensive text processing with mapreduce-Graph Algorithms

数据

[plain]  view plain copy print ?
  1. 1 2  
  2. 1 3  
  3. 1 4  
  4. 2 1  
  5. 3 1  
  6. 1 5  
  7. 2 5  

V2-PageRank

[plain]  view plain copy print ?
  1. package myclass  
  2.   
  3. import org.apache.spark.SparkContext  
  4. import SparkContext._  
  5.   
  6. /**  
  7.  * Created by jack on 2/25/14.  
  8.  */  
  9. object MyAccumulator {  
  10.     def main(args: Array[String]) {  
  11.         val iters = 20  
  12.         val sc = new SparkContext("local", "My PageRank", System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass))  
  13.   
  14.         val lines = sc.textFile("src/main/resources/data/pagerank_data.txt", 1)  
  15.         //根据边关系数据生成 邻接表 如:(1,(2,3,4,5)) (2,(1,5))...  
  16.         var links = lines.map(line => {  
  17.             val parts = line.split("\\s+")  
  18.             (parts(0), parts(1))  
  19.         }).distinct().groupByKey()  
  20.   
  21.         //添加出度为0的节点的邻接表项 如:(4,()) (5,())...  
  22.         val nodes = scala.collection.mutable.ArrayBuffer.empty ++ links.keys.collect()  
  23.         val newNodes = scala.collection.mutable.ArrayBuffer[String]()  
  24.         for {s <- links.values.collect()  
  25.                  k <- s if (!nodes.contains(k))  
  26.         } {  
  27.             nodes += k  
  28.             newNodes += k  
  29.         }  
  30.         val linkList = links ++ sc.parallelize(for (i <- newNodes) yield (i, List.empty))  
  31.         val nodeSize = linkList.count()  
  32.         var ranks = linkList.mapValues(v => 1.0 / nodeSize)  
  33.   
  34.         //迭代计算PR值  
  35.         for (i <- 1 to iters) {  
  36.             val dangling = sc.accumulator(0.0)  
  37.             val contribs = linkList.join(ranks).values.flatMap {  
  38.                 case (urls, rank) => {  
  39.                     val size = urls.size  
  40.                     if (size == 0) {  
  41.                         dangling += rank  
  42.                         List()  
  43.                     } else {  
  44.                         urls.map(url => (url, rank / size))  
  45.                     }  
  46.                 }  
  47.             }  
  48.             //若无下面这行,统计的dangling将为0,若用contribs.first,则dangling等于一个分片中的聚集值  
  49.             contribs.count()  
  50.             val danglingValue = dangling.value  
  51.             ranks = contribs.reduceByKey(_ + _).mapValues[Double](p =>  
  52.                 0.1 * (1.0 / nodeSize) + 0.9 * (danglingValue / nodeSize + p)  
  53.             )  
  54.             println("------------------------------" + i + "---------------------------------")  
  55.             ranks.foreach(s => println(s._1 + " - " + s._2))  
  56.         }  
  57.     }  
  58. }  
主要是使用了accumulator来记录dangling mass,需要注意的地方见代码注释。另在使用dangling值不能直接在Spark的Action操作中通过dangling.value使用。当accumulator出现在Action中,将会复制到分片(slice)上执行,执行完毕后再进行聚集。因此, 用了变量danglingValue来获得dangling的value ,进行PR值的计算。

迭代结果

迭代的次数为20次,也可以计算前后的差异阈值进行结束

[plain]  view plain copy print ?
  1. 4 - 0.15702615478678728  
  2. 2 - 0.15702615478678728  
  3. 5 - 0.22768787421485948  
  4. 1 - 0.30123366142477936  
  5. 3 - 0.15702615478678728  

V1:PageRank

各种问题

先贴上代码,再说明

[plain]  view plain copy print ?
  1. package myclass  
  2.   
  3. import org.apache.spark.SparkContext  
  4. import SparkContext._  
  5. import scala.collection.mutable.ArrayBuffer  
  6. import scala.collection.mutable  
  7.   
  8. /**  
  9.  * Created by jack on 2/22/14.  
  10.  */  
  11. object MyPageRank {  
  12.     def main(args: Array[String]) {  
  13.         if (args.length < 3) {  
  14.             System.err.println("Usage: PageRank <master> <file> <number_of_iterations>")  
  15.             System.exit(1)  
  16.         }  
  17.   
  18.         val iters = args(2).toInt  
  19.         val sc = new SparkContext(args(0), "My PageRank", System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass))  
  20.   
  21.         //未考虑出度为0的节点时的pagerank  
  22.         /*      val lines = sc.textFile(args(1), 1)  
  23.                 val links = lines.map(line => {  
  24.                     val parts = line.split("\\s+")  
  25.                     (parts(0), parts(1))  
  26.                 }).distinct().groupByKey().cache()  
  27.                 var ranks = links.mapValues(v => 1.0)  
  28.   
  29.                 for (i <- 1 to iters) {  
  30.                     val contribs = links.join(ranks).values.flatMap {  
  31.                         case (urls, rank) => {  
  32.                             val size = urls.size  
  33.                             urls.map(url => (url, rank / size))  
  34.                         }  
  35.                     }  
  36.                     ranks = contribs.reduceByKey(_ + _).mapValues(0.15 + 0.85 * _)  
  37.                 }  
  38.                 val output = ranks.collect  
  39.                 val urlSize = output.length  
  40.                 output.foreach( tup => println(tup._1 + "has rank: " + tup._2/output.length+"."))*/  
  41.   
  42. //考虑出度为0的节点  
  43.         val lines = sc.textFile(args(1), 1)  
  44.         val linkF = lines.map(line => {  
  45.             val parts = line.split("\\s+")  
  46.             (parts(0), parts(1))  
  47.         }).distinct().groupByKey()  
  48.   
  49.         var linkS = linkF  
  50.         var nodes = linkF.keys.collect()  
  51.         var newNodes = scala.collection.mutable.ArrayBuffer[String]()  
  52.         for {s <- linkF.values.collect()  
  53.                  k <- s if (!nodes.contains(k))  
  54.         } {  
  55.             nodes = nodes :+ k  
  56.             newNodes += k  
  57.         }  
  58.         linkS = linkS ++ sc.makeRDD(for (i <- newNodes) yield (i, ArrayBuffer[String]()))  
  59.         val linkT = linkS  
  60.         val nodeSize = linkS.count()  
  61.         var ranks = linkT.mapValues(v => 1.0 / nodeSize)  
  62.   
  63.         for (i <- 1 to iters) {  
  64.             var dangling = 0.0  
  65.             val linksAndPR = linkT.join(ranks).values  
  66.             for (i <- linksAndPR.filter(_._1.size == 0).collect()) {  
  67.                 dangling += i._2  
  68.             }  
  69.   
  70.             val contribs = linksAndPR.filter(_._1.size != 0).flatMap {  
  71.                 case (urls, rank) => {  
  72.                     val size = urls.size  
  73.                     urls.map(url => (url, rank / size))  
  74.                 }  
  75.             }  
  76.             ranks = contribs.reduceByKey(_ + _).mapValues[Double](p =>  
  77.                 0.1 * (1.0 / nodeSize) + 0.9 * (dangling / nodeSize + p)  
  78.             )  
  79.             println("------------------------------"+i+"---------------------------------")  
  80.             ranks.foreach(s => println(s._1 + " - " + s._2))  
  81.         }  
  82.     }  
  83. }  
以下问题针对出度为0节点的考虑:

问题1:用于dangling变量统计全局出度为0的节点的PR值和,关键就是更新的问题,用RDD的各种Transformation操作(如 foreach)无法更新dangling值,只能用for语句才有效。

问题2:针对问题1,尝试用accumulator,但是accumulator是要计算任务完成才能取值(猜测,类似hadoop的counter,无法全局统一,spark可能是分散着更新,最后再统一),单纯用accumulator不能解决问题。

问题3:对linkAndPR不得不进行了两次filter,合并处理会出现问题。

问题4:考虑出度为0的程序,只能单机跑,分布式上可能会有问题,依然是dangling问题,应该需要类似与hadoop的做法data-intensive text processing with mapreduce-Graph Algorithms ,2个job,但Spark的Job机制不熟悉,等以后解决。


你可能感兴趣的:(spark)