Spark GraphX (Part 1)



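abstract class Graph[VD, ED] {
	val vertices: VertexRDD[VD]             // (VertexId, attribute) pairs
	val edges: EdgeRDD[ED]                  // Edge(srcId, dstId, attribute)
	val triplets: RDD[EdgeTriplet[VD, ED]]  // each edge together with both endpoint attributes
}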


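  • Vertices: each vertex consists of a VertexId (a Long) and an attribute (a descriptive property or a distance). For example, (3L, ("San Francisco", "CA")) or (1L, 10).
  • Edges: each edge consists of a srcId (the VertexId of the source vertex), a dstId (the VertexId of the destination vertex), and an attribute (the edge weight). For example, Edge(1L, 2L, 20).
  • Triplets: each triplet consists of srcId and srcAttr (the source vertex), dstId and dstAttr (the destination vertex), and attr (the edge attribute). For example, ((1, (Santa Clara, CA)), (2, (Fremont, CA)), 20).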
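
The three views above can be seen side by side on a tiny graph. The following is a minimal sketch, not taken from the original post: the city names reuse the examples in the list, while the second edge, its weight of 44, and the names cities, roads, and cityGraph are made up for illustration. It assumes a local Spark installation; SparkContext.getOrCreate reuses the existing context when run inside spark-shell.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD

// Assumption: local mode; inside spark-shell the existing context is reused.
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("graphx-views"))

// Vertices: (VertexId, attribute) pairs; the attribute here is (city, state).
val cities: RDD[(VertexId, (String, String))] = sc.parallelize(Seq(
  (1L, ("Santa Clara", "CA")),
  (2L, ("Fremont", "CA")),
  (3L, ("San Francisco", "CA"))))

// Edges: Edge(srcId, dstId, attribute); the weights are illustrative only.
val roads: RDD[Edge[Int]] = sc.parallelize(Seq(
  Edge(1L, 2L, 20),
  Edge(2L, 3L, 44)))

val cityGraph: Graph[(String, String), Int] = Graph(cities, roads)

// Each triplet carries the edge attribute plus the attributes of both endpoints.
val tripletStrings: Array[String] = cityGraph.triplets
  .map(t => s"${t.srcAttr._1} -> ${t.dstAttr._1}: ${t.attr}")
  .collect()
  .sorted
tripletStrings.foreach(println)
```

The triplets are mapped to strings before collect() so that only plain values leave the executors.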


val initialGraph = graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)

Graph supports many other operations; see the Spark API documentation for details.


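At a high level, the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction constrained to the topology of the graph. The Pregel operator executes a series of super steps. In each super step, a vertex receives the sum of its inbound messages from the previous super step, computes a new value for its vertex attribute, and then sends messages to neighboring vertices for the next super step. Unlike Pregel, and more like GraphLab, messages are computed in parallel as a function of the edge triplet, and the message computation has access to both the source and the destination vertex attributes. Vertices that receive no message are skipped within a super step. When no messages remain, the Pregel operator stops iterating and returns the final graph.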
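
The super-step loop described above can be sketched without Spark at all. The following in-memory stand-in is hypothetical (LocalEdge and localPregel are names invented here, and this is not how GraphX implements the operator); it only mirrors the described behaviour: every vertex gets the initial message, messages are computed per edge with access to both endpoint attributes, merged per destination, vertices without mail are skipped, and the loop ends when no messages remain.

```scala
// Hypothetical in-memory model of the Pregel loop; not the GraphX implementation.
final case class LocalEdge[ED](srcId: Long, dstId: Long, attr: ED)

def localPregel[VD, ED, A](
    vertices: Map[Long, VD],
    edges: Seq[LocalEdge[ED]],
    initialMsg: A,
    maxIterations: Int = Int.MaxValue)(
    vprog: (Long, VD, A) => VD,
    sendMsg: (LocalEdge[ED], VD, VD) => Iterator[(Long, A)],
    mergeMsg: (A, A) => A): Map[Long, VD] = {
  // Super step 0: every vertex receives the initial message.
  var state: Map[Long, VD] =
    vertices.map { case (id, attr) => id -> vprog(id, attr, initialMsg) }
  var active: Set[Long] = state.keySet
  var iteration = 0
  while (active.nonEmpty && iteration < maxIterations) {
    // Compute messages along edges touching an active vertex, then merge per destination.
    val inbox: Map[Long, A] = edges
      .filter(e => active(e.srcId) || active(e.dstId))
      .flatMap(e => sendMsg(e, state(e.srcId), state(e.dstId)))
      .groupBy(_._1)
      .map { case (id, msgs) => id -> msgs.map(_._2).reduce(mergeMsg) }
    // Only vertices that received a message run the vertex program.
    state = state ++ inbox.map { case (id, msg) => id -> vprog(id, state(id), msg) }
    active = inbox.keySet
    iteration += 1
  }
  state
}

// Usage: shortest distances from vertex 1 on a three-vertex toy graph.
val distances = localPregel(
  Map(1L -> 0.0, 2L -> Double.PositiveInfinity, 3L -> Double.PositiveInfinity),
  Seq(LocalEdge(1L, 2L, 1.0), LocalEdge(2L, 3L, 2.0), LocalEdge(1L, 3L, 5.0)),
  Double.PositiveInfinity)(
  (_, dist, msg) => math.min(dist, msg),
  (e, srcDist, dstDist) =>
    if (srcDist + e.attr < dstDist) Iterator((e.dstId, srcDist + e.attr)) else Iterator.empty,
  (a, b) => math.min(a, b))
```

On this toy graph, vertex 3 first receives 5.0 over the direct edge and is later improved to 3.0 through vertex 2, after which no edge produces a message and the loop stops.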



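The dataflow of a Pregel computation is as follows: (1) each iteration begins by joining the active vertices with their outbound edges; (2) using the triplets view, a message is computed along each triplet and the messages are then merged at their destination vertex; (3) finally, the merged messages are received by the vertex programs on the vertices. The Berkeley paper "GraphX: Graph Processing in a Distributed Dataflow Framework" describes this dataflow as follows: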

Each iteration begins by executing the join stage to bind active vertices with their outbound edges. Using the triplets view, messages are computed along each triplet in a map stage and then aggregated at their destination vertex in a groupby stage. Finally, the messages are received by the vertex programs in a map stage over the vertices.
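
A single iteration of this dataflow can be approximated by hand with two GraphX calls, aggregateMessages (the map stage over triplets plus the group-by at the destination) and joinVertices (the map stage over vertices). This is a simplified sketch on a made-up three-vertex graph; it leaves out the initial join stage that restricts work to active vertices, so it is not equivalent to the Pregel operator itself.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph, VertexRDD}

// Assumption: local mode; inside spark-shell the existing context is reused.
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("graphx-one-superstep"))

// Hypothetical toy graph: the vertex attribute is the current distance from vertex 1.
val g: Graph[Double, Double] = Graph(
  sc.parallelize(Seq((1L, 0.0), (2L, Double.PositiveInfinity), (3L, Double.PositiveInfinity))),
  sc.parallelize(Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 2.0), Edge(1L, 3L, 5.0))))

// Map stage over the triplets view, followed by a group-by at the destination vertex.
val messages: VertexRDD[Double] = g.aggregateMessages[Double](
  ctx => if (ctx.srcAttr + ctx.attr < ctx.dstAttr) ctx.sendToDst(ctx.srcAttr + ctx.attr),
  (a, b) => math.min(a, b))

// Map stage over the vertices: the vertex program consumes the merged message.
val next: Graph[Double, Double] =
  g.joinVertices(messages)((_, dist, msg) => math.min(dist, msg))

val afterOneStep = next.vertices.collect().sortBy(_._1)
afterOneStep.foreach(println)
```

After this one step, vertex 2 holds 1.0 and vertex 3 holds 5.0; a second iteration would lower vertex 3 to 3.0 via vertex 2.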

Google's Pregel paper, "Pregel: A System for Large-Scale Graph Processing", illustrates the model with a graphical example:

[Figure 1: graphical example from the Pregel paper]


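def pregel[A: ClassTag](
	initialMsg: A,  // delivered to every vertex in the first super step
	maxIterations: Int = Int.MaxValue,  // upper bound on the number of super steps
	activeDirection: EdgeDirection = EdgeDirection.Either)(  // edge direction, relative to active vertices, along which sendMsg runs
	vprog: (VertexId, VD, A) => VD,  // vertex program: new attribute from the merged message
	sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],  // messages computed along a triplet
	mergeMsg: (A, A) => A)  // combines two messages bound for the same vertex
	: Graph[VD, ED]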
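
To make the role of each parameter concrete, here is a small sketch that propagates the maximum vertex value through a graph, a common way to illustrate the Pregel model. The four-vertex ring, its values, and the names ring and maxGraph are made up for illustration; the iteration cap of 10 is arbitrary. It assumes a local Spark installation.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, EdgeDirection, Graph}

// Assumption: local mode; inside spark-shell the existing context is reused.
val sc = SparkContext.getOrCreate(
  new SparkConf().setMaster("local[*]").setAppName("graphx-pregel-max"))

// Hypothetical directed ring 1 -> 2 -> 3 -> 4 -> 1 holding the values 3, 6, 2, 1.
val ring: Graph[Int, Int] = Graph(
  sc.parallelize(Seq((1L, 3), (2L, 6), (3L, 2), (4L, 1))),
  sc.parallelize(Seq(Edge(1L, 2L, 0), Edge(2L, 3L, 0), Edge(3L, 4L, 0), Edge(4L, 1L, 0))))

val maxGraph: Graph[Int, Int] = ring.pregel(
  Int.MinValue,          // initialMsg: leaves every value unchanged in the first super step
  10,                    // maxIterations: arbitrary cap for this sketch
  EdgeDirection.Either)( // activeDirection: the default, spelled out
  // vprog: keep the larger of the current value and the merged message
  (_, value, msg) => math.max(value, msg),
  // sendMsg: forward the source value only when it would raise the destination
  t => if (t.srcAttr > t.dstAttr) Iterator((t.dstId, t.srcAttr)) else Iterator.empty,
  // mergeMsg: of several incoming values, only the largest matters
  (a, b) => math.max(a, b))

val maxValues = maxGraph.vertices.collect().sortBy(_._1)
maxValues.foreach(println)
```

Because sendMsg emits nothing once a destination already holds the larger value, the computation stops on its own when every vertex holds 6, well before the iteration cap.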

Example 1: the Pregel operator can express a single-source shortest path computation.

import org.apache.spark.graphx._
// Import random graph generation library
import org.apache.spark.graphx.util.GraphGenerators
// A graph with edge attributes containing distances
// (logNormalGraph yields Long vertex attributes in current Spark releases)
val graph: Graph[Long, Double] =
  GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
val sourceId: VertexId = 42 // The ultimate source
// Initialize the graph such that all vertices except the root have distance infinity.
val initialGraph = graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)
val sssp = initialGraph.pregel(Double.PositiveInfinity)(
  vprog = (id, dist, newDist) => math.min(dist, newDist), // Vertex Program
  sendMsg = triplet => {  // Send Message
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },
  mergeMsg = (a, b) => math.min(a, b) // Merge Message
)
println(sssp.vertices.collect.mkString("\n"))

Example 2: implementing breadth-first search (BFS) with Pregel.

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.graphx.GraphLoader

/**
 * Breadth-first search from a source vertex, expressed with Pregel.
 */
object BFS {

  def main(args: Array[String]): Unit = {

    if (args.length != 4) {
      System.err.println("Usage: BFS <input file> <output file> <source vertex> <iteration number>")
      System.exit(1)
    }

    val sAllTime = System.currentTimeMillis()
    val conf = new SparkConf()
    val sc = new SparkContext(conf)
    val fname = args(0)
    val outPath = args(1)
    val srcVertex = args(2).toLong  // VertexId is a Long
    val numIter = args(3).toInt

    val sLoadTime = System.currentTimeMillis()
    val graphFile = GraphLoader.edgeListFile(sc, fname).cache()
    graphFile.vertices.count()  // loading is lazy; force it so the load time is meaningful
    val eLoadTime = System.currentTimeMillis()

    // Distance 0 at the source, infinity everywhere else.
    val graph = graphFile.mapVertices((id, _) => if (id == srcVertex) 0.0 else Double.PositiveInfinity)
    val sComTime = System.currentTimeMillis()
    val bfs = graph.pregel(Double.PositiveInfinity, numIter)(
        vprog = (id, attr, msg) => math.min(attr, msg),
        sendMsg = triplet => {
            // Send only when the hop count improves, so iteration can stop before numIter.
            if (triplet.srcAttr + 1 < triplet.dstAttr) {
              Iterator((triplet.dstId, triplet.srcAttr + 1))
            } else {
              Iterator.empty
            }
        },
        mergeMsg = (a, b) => math.min(a, b))
    val eComTime = System.currentTimeMillis()

    // Assumed output step: write (vertexId, distance) pairs to outPath.
    bfs.vertices.saveAsTextFile(outPath)

    val eAllTime = System.currentTimeMillis()
    println("Load time: " + (eLoadTime - sLoadTime) / 1000.0)
    println("Compute time: " + (eComTime - sComTime) / 1000.0)
    println("Total time: " + (eAllTime - sAllTime) / 1000.0)
    sc.stop()
  }
}
