Spark 2.2 Broadcast Variables: Principles and Source Code Analysis

Example

First, let's look at how a broadcast variable is used:

    import scala.collection.mutable.ListBuffer

    val factor = List[Int](1, 2, 3)
    val factorBroadcast = sc.broadcast(factor)

    val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
    val numsRdd = sc.parallelize(nums, 3)

    val resRdd = numsRdd.mapPartitions { ite =>
      // Build the buffer inside the partition function so each task gets its own copy
      val list = new ListBuffer[List[Int]]()
      while (ite.hasNext) {
        list += ite.next() :: factorBroadcast.value
      }
      list.iterator
    }
    resRdd.foreach(res => println(res))

    /** Result (ordering may vary across partitions):
    List(1, 1, 2, 3)
    List(2, 1, 2, 3)
    List(3, 1, 2, 3)
    List(4, 1, 2, 3)
    List(5, 1, 2, 3)
    List(6, 1, 2, 3)
    List(7, 1, 2, 3)
    List(8, 1, 2, 3)
    List(9, 1, 2, 3)
    */

We first create a list on the driver, broadcast it through SparkContext's broadcast function, and then read the broadcast value via factorBroadcast.value while iterating over each partition of the RDD.
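The benefit of broadcasting shows up when the same data is shipped to many tasks. A quick comparison, reusing factor and factorBroadcast from the example above:

    // Without broadcast: factor is captured in each task's closure and is
    // re-serialized and shipped with every single task.
    val perTask = numsRdd.map(n => n :: factor)

    // With broadcast: the data is transferred to each executor at most once and
    // cached there; tasks only carry the lightweight Broadcast handle.
    val perExecutor = numsRdd.map(n => n :: factorBroadcast.value)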


Source Code Analysis

Next, let's look at how a broadcast variable is created and how its data is written and read:

Creating the broadcast variable

(1) SparkContext.broadcast()

  /**
   * Broadcast a read-only variable to the cluster, returning an object
   * that can be used to read it in distributed functions.
   * @return the read-only variable, cached on each executor
   */
  def broadcast[T: ClassTag](value: T): Broadcast[T] = {
    assertNotStopped()
    // RDDs cannot be broadcast directly; instead, collect() them and broadcast the result.
    require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
      "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
    // Delegate the actual work to BroadcastManager's newBroadcast function.
    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
    val callSite = getCallSite
    logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
    cleaner.foreach(_.registerBroadcastForCleanup(bc))
    bc
  }
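As the require() check shows, an RDD itself cannot be broadcast, only materialized data. For example, with numsRdd from the example above:

    // sc.broadcast(numsRdd)  // rejected by the require() above with
    //                        // "Can not directly broadcast RDDs; ..."
    val collectedBroadcast = sc.broadcast(numsRdd.collect())  // broadcast the collected data instead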

(2) BroadcastManager.newBroadcast()
The newBroadcast function in BroadcastManager delegates directly to the corresponding function of broadcastFactory. In Spark 2.2 the factory is always a TorrentBroadcastFactory; the old spark.broadcast.factory setting (which once allowed choosing an HttpBroadcast-based factory) was removed in Spark 2.0.

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    // Each broadcast variable gets a unique, monotonically increasing id
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }
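For reference, this is roughly how BroadcastManager wires up the factory in Spark 2.2 (paraphrased from the Spark source; note that TorrentBroadcastFactory is constructed directly rather than looked up from configuration):

  private def initialize() {
    synchronized {
      if (!initialized) {
        // The factory is hardwired; there is no spark.broadcast.factory setting anymore
        broadcastFactory = new TorrentBroadcastFactory
        broadcastFactory.initialize(isDriver, conf, securityManager)
        initialized = true
      }
    }
  }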

(3) TorrentBroadcastFactory.newBroadcast()
The newBroadcast function in TorrentBroadcastFactory simply constructs a TorrentBroadcast instance:

  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }

How the data is written and read

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  /**
   * Value of the broadcast object on executors. It is reconstructed by
   * [[readBroadcastBlock]], which builds the value by reading blocks from
   * the driver and/or other executors; on the driver, if the value is
   * required, it is read lazily from the block manager.
   */
  @transient private lazy val _value: T = readBroadcastBlock()
  // The compression codec to use, or None if compression is disabled (spark.broadcast.compress)
  @transient private var compressionCodec: Option[CompressionCodec] = _
  // Size of each block; the default is 4 MB. This value is only read by the broadcaster.
  @transient private var blockSize: Int = _
  // Declared further down in the real source; shown here so the snippet is self-contained
  private var checksumEnabled: Boolean = false
  private var checksums: Array[Int] = _

  // Read the configuration parameters
  private def setConf(conf: SparkConf) {
    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
      Some(CompressionCodec.createCodec(conf))
    } else {
      None
    }
    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
  }
  setConf(SparkEnv.get.conf)
  // The unique block id of this broadcast variable
  private val broadcastId = BroadcastBlockId(id)
  // Total number of blocks this broadcast variable contains. Note that
  // writeBlocks(obj) runs here in the constructor, i.e. on the driver.
  private val numBlocks: Int = writeBlocks(obj)
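The three settings read by setConf() can be tuned on the SparkConf before the SparkContext is created. The values below are only examples:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.broadcast.blockSize", "8m")  // piece size (default 4m)
      .set("spark.broadcast.compress", "true") // compress the pieces (default true)
      .set("spark.broadcast.checksum", "true") // checksum remotely fetched pieces (default true)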


//--------------------


  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   * @param value the object to divide
   * @return the number of blocks this broadcast variable was divided into
   */
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    // First store a copy of the whole value in the local BlockManager, so that
    // tasks scheduled on this node can later read it directly from local storage.
    val blockManager = SparkEnv.get.blockManager
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the value, optionally compress it, and split it into blocks of
    // blockSize bytes each
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    // Iterate over the serialized, compressed blocks and store each piece in the BlockManager
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    // The return value is the number of pieces the serialized value was split into.
    blocks.length
  }
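To make the chunking step concrete, here is a minimal, self-contained sketch of the "blockify" idea. This is not Spark's actual implementation (TorrentBroadcast.blockifyObject streams through the configured serializer and optional compression codec); plain Java serialization is used purely for illustration:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    def blockify(value: AnyRef, blockSize: Int): Array[Array[Byte]] = {
      // Serialize the whole value into one byte array ...
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(value)
      oos.close()
      // ... then split it into fixed-size pieces
      bos.toByteArray.grouped(blockSize).toArray
    }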


//--------------------

In our example, the first read of factorBroadcast.value (getValue) evaluates the lazy _value field defined at construction time, which calls readBroadcastBlock:
  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      // First try to read the value as a whole from the local blockManager's storage.
      // If it is found, this broadcast variable has already been read on this node,
      // or this task runs where the variable was broadcast from.
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          if (blockResult.data.hasNext) {
            val x = blockResult.data.next().asInstanceOf[T]
            releaseLock(broadcastId)
            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }

        // This branch runs when the broadcast variable is read for the first time
        // on this executor: readBlocks() fetches all of its blocks, they are
        // deserialized, and the rebuilt value is stored in the local blockManager
        // so that the next read on this node is purely local.
        case None =>
          val blocks = readBlocks()

          try {
            // Rebuild the original object from the fetched blocks
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            // Store the merged copy in the BlockManager so that other tasks on this
            // executor don't need to fetch it again
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }
            obj
          } finally {
            blocks.foreach(_.dispose())
          }
      }
    }
  }
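The inverse direction can be sketched the same way. Again, this is a simplified illustration rather than Spark's actual unBlockifyObject, which reads the pieces through the serializer and compression codec:

    import java.io.{ByteArrayInputStream, ObjectInputStream, SequenceInputStream}
    import scala.collection.JavaConverters._

    def unBlockify[T](blocks: Array[Array[Byte]]): T = {
      // Concatenate the pieces in order and deserialize the combined stream
      val streams = blocks.iterator.map(new ByteArrayInputStream(_))
      val in = new ObjectInputStream(new SequenceInputStream(streams.asJavaEnumeration))
      try in.readObject().asInstanceOf[T] finally in.close()
    }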


//--------------------


Finally, let's look at the processing flow of the readBlocks function:

  /** Fetch torrent blocks from the driver and/or other executors. */
  private def readBlocks(): Array[BlockData] = {
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager

    // Iterate over all the pieces in random order. Each piece is first looked up
    // in the local BlockManager; if it is not found there, it is fetched from a
    // remote BlockManager (one of the locations registered for that piece) and then
    // stored locally. Finally, the function returns the array of blocks read.
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // Some broadcast pieces may already have been fetched to this node, so first try getLocalBytes
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // We found the piece in a remote executor's/driver's BlockManager, so put it in this executor's BlockManager as well.
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }
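Putting it all together: once a job no longer needs a broadcast variable, the cached copies can be released explicitly. unpersist() and destroy() are part of the public Broadcast API; the sizes below are only illustrative:

    val bigBroadcast = sc.broadcast(Array.fill(4 << 20)(1.0))  // ~32 MB, i.e. roughly eight 4 MB pieces
    sc.parallelize(1 to 100).map(_ * bigBroadcast.value.length).count()

    bigBroadcast.unpersist()  // drop executor-side copies; the value can be lazily re-fetched
    bigBroadcast.destroy()    // remove all copies; reading bigBroadcast.value afterwards fails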
