Spark Source Code Analysis: Broadcast Variables

A Spark broadcast variable is a read-only piece of data shared across the nodes of a cluster. By default it is cached in memory, so reads are fast.
Spark itself uses broadcast variables in many places: Spark SQL's broadcast join uses them to optimize joins against small tables, and when a job is submitted, the tasks' dependencies are broadcast to the cluster.
Next, let's dig into the implementation details of Spark broadcast variables:
```
val broadcastVar = sc.broadcast(Array(1, 2, 3)) // create a broadcast variable
```

Let's see how a broadcast variable is created and shared across the cluster:
```
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
    assertNotStopped()
    require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
      "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)   // the key line
    val callSite = getCallSite
    logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
    cleaner.foreach(_.registerBroadcastForCleanup(bc))
    bc
  }
```

SparkContext.broadcast() delegates to BroadcastManager's newBroadcast() method. That method is a one-liner: it calls newBroadcast() on a BroadcastFactory implementation, and BroadcastFactory has exactly one implementation, TorrentBroadcastFactory.
```
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }
```
TorrentBroadcastFactory's newBroadcast() creates a TorrentBroadcast, which uses a BitTorrent-like peer-to-peer protocol. The mechanism works as follows: the driver serializes the broadcast object, splits the serialized bytes into small blocks, and stores those blocks in the driver's BlockManager. On each executor, the executor first tries to fetch the data from its own BlockManager. If it is not there, it fetches blocks from the driver or from other executors, stores the fetched blocks in its own BlockManager, and from then on can serve them to other executors. This scheme spreads the driver's network load across the executors. Let's look at Spark's concrete implementation.
TorrentBroadcastFactory's newBroadcast() simply constructs a TorrentBroadcast object:
```
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }
```
Next, let's look directly at how TorrentBroadcast writes its data to the BlockManager, and how it reads the data back.

```
private def writeBlocks(value: T): Int = {
    import StorageLevel._
    val blockManager = SparkEnv.get.blockManager
    // Store one complete copy of the broadcast value so that tasks
    // running on the driver can read it quickly
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the value and split it into multiple blocks
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    // Store the blocks in the BlockManager
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        // If checksums are enabled, compute one per block using Adler-32,
        // which is nearly as reliable as CRC-32 but much faster to compute.
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Write the broadcast block to the driver-side BlockManager
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
```
Writing a broadcast variable is straightforward: first a complete copy is stored in the driver's BlockManager, so that tasks running on the driver can read the whole value directly; then the value is split into multiple blocks, and all of those blocks are also stored in the driver's BlockManager.
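The Adler-32 checksum mentioned in the comments ships with the JDK in java.util.zip. A sketch of a calcChecksum-style helper (operating on a plain byte array; TorrentBroadcast's real method does the equivalent over a ByteBuffer chunk):

```scala
import java.util.zip.Adler32

// Compute an Adler-32 checksum over a block of bytes and narrow it
// to an Int, as TorrentBroadcast stores its checksums as Ints.
def calcChecksum(block: Array[Byte]): Int = {
  val adler = new Adler32
  adler.update(block, 0, block.length)
  adler.getValue.toInt
}
```

Equal blocks always produce equal checksums; a single corrupted byte almost always changes the sum, which is enough to catch a bad network transfer.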

Now let's look at how a broadcast variable is read back.
```
private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      setConf(SparkEnv.get.conf)
      val blockManager = SparkEnv.get.blockManager
      // First try to read the complete, un-split broadcast value
      // from the local BlockManager
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          // If it exists, return it directly
          if (blockResult.data.hasNext) {
            val x = blockResult.data.next().asInstanceOf[T]
            releaseLock(broadcastId)
            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }
        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          // If the complete value is not available locally,
          // try to read the individual blocks
          val blocks = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          try {
            // Reassemble the fetched blocks into the complete value
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            // Store the reassembled value in the BlockManager so that
            // tasks running on this executor can read it directly
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }
            obj
          } finally {
            // Dispose of the block buffers
            blocks.foreach(_.dispose())
          }
      }
    }
  }
```
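The unBlockifyObject call above reverses the splitting done in writeBlocks. As a Spark-free sketch of that round trip, here is a simplified blockify/unBlockify pair (my own names and signatures) using plain Java serialization; Spark's real blockifyObject/unBlockifyObject do the same splitting through its pluggable serializer and an optional compression codec:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Serialize a value and chop the resulting bytes into fixed-size pieces.
def blockify(value: AnyRef, blockSize: Int): Array[Array[Byte]] = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(value)
  oos.close()
  bos.toByteArray.grouped(blockSize).toArray
}

// Concatenate the pieces back together and deserialize the value.
def unBlockify(blocks: Array[Array[Byte]]): AnyRef = {
  val ois = new ObjectInputStream(new ByteArrayInputStream(blocks.flatten))
  try ois.readObject() finally ois.close()
}
```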
Next, the method that reads the individual broadcast blocks:
```
private def readBlocks(): Array[BlockData] = {
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager
    // Fetch all blocks, in random order
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // First try getLocalBytes because there is a chance that previous attempts to fetch the
      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
      // would be available locally (on this executor).
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          // Fetch the block from a remote driver or executor: the BlockManager asks the
          // driver for all locations that hold this block, then pulls it from one of them
          // chosen at random
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              // Block fetched; verify its checksum
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // Store the fetched block in the local BlockManager so that
              // other executors can fetch it from here
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }
```
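Note the Random.shuffle over the piece indices: each reader walks the pieces in a different random order, so concurrent executors are unlikely to all request the same piece from the same holder at the same time. A tiny sketch of the idea:

```scala
import scala.util.Random

// Produce a random permutation of piece indices, as readBlocks does.
// Different callers get different orders, spreading fetch load.
def fetchOrder(numBlocks: Int): Seq[Int] = Random.shuffle(Seq.range(0, numBlocks))
```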
Reading a broadcast variable is more involved. The reader first tries to read the complete, un-split value directly from its local BlockManager. If that is missing, it tries to read the individual blocks from the local BlockManager. Any block not available locally is fetched from a remote driver or executor; for each block, one of the holders is chosen at random, which spreads the network I/O across the nodes. Every remotely fetched block is also stored in the local BlockManager so that other executors can fetch it from here. Finally, if the value had to be reconstructed from blocks, the reassembled object is stored in the local BlockManager so that tasks running on this executor can read the broadcast variable quickly.
From this we can see that a broadcast variable ends up stored in two forms on each node:
- the complete, un-split value, read directly by tasks running on that executor or driver;
- the split blocks, served to other executors that come to fetch them.
The write path is simple: store both forms in the local BlockManager and you are done. The read path is more complex, and it is where the BitTorrent-style p2p protocol really shows.
That concludes our analysis of Spark broadcast variables; hopefully you now have an intuitive picture of how they work.
