BroadcastManager--SparkEnv

Broadcast provides distributed data sharing; the BroadcastManager is responsible for creating and destroying broadcast variables. Broadcasts are typically used for shared configuration files, commonly used datasets, and frequently used data structures.

A broadcast variable is read-only, which guarantees data consistency. It is stored with StorageLevel.MEMORY_AND_DISK, so it will not overflow memory, but broadcasting a large object still causes heavy network IO and pressure on individual nodes.

The broadcast architecture follows the common factory pattern (producing a single product, Broadcast): BroadcastManager; BroadcastFactory and TorrentBroadcastFactory; Broadcast and TorrentBroadcast.

Using Broadcast

scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)

A broadcast variable is created through SparkContext, and its value is obtained with value. The broadcast-creation code in SparkContext is analyzed below.

def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  // An RDD cannot be broadcast directly; only concrete values can. Call rdd.collect() first and broadcast the result
  require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
    "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
  // Create the broadcast object via the BroadcastManager held in SparkEnv
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  // Capture the call-site information
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  // Associate the broadcast object with a CleanupTaskWeakReference weak reference, while a loop polls the referenceQueue. Once the Broadcast object has no strong references, it is collected by GC and its associated CleanupTaskWeakReference is put onto the referenceQueue
  // CleanupTaskWeakReference extends WeakReference and adds a CleanupTask field, so when the weak reference is removed from the referenceQueue the CleanupTask value is available; if it is a CleanBroadcast, the broadcastId context is recovered and broadcastManager.unbroadcast() is called to remove the broadcast data from the cluster
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}
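
The cleanup mechanism described in the comments above can be illustrated with a short sketch. This is a simplified, hypothetical version of the pattern (CleanupRef, CleanerSketch and register are made-up names, not the real ContextCleaner / CleanupTaskWeakReference code):

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.concurrent.ConcurrentLinkedQueue

// A weak reference that carries the broadcastId as extra context
class CleanupRef(referent: AnyRef, val broadcastId: Long, q: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, q)

object CleanerSketch {
  private val queue = new ReferenceQueue[AnyRef]()
  // keep the weak-reference objects themselves strongly reachable
  private val refs = new ConcurrentLinkedQueue[CleanupRef]()

  def register(bc: AnyRef, broadcastId: Long): Unit =
    refs.add(new CleanupRef(bc, broadcastId, queue))

  private val poller = new Thread(new Runnable {
    override def run(): Unit = {
      while (true) {
        // remove(100) blocks for up to 100 ms and returns the weak reference (or null)
        // once its referent has been garbage collected
        Option(queue.remove(100)).map(_.asInstanceOf[CleanupRef]).foreach { ref =>
          refs.remove(ref)
          println(s"unbroadcast(${ref.broadcastId})") // here the real cleaner would call broadcastManager.unbroadcast(...)
        }
      }
    }
  }, "cleaner-sketch")
  poller.setDaemon(true)
  poller.start()
}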

Guava's FinalizableWeakReference class can be used instead: when the referent is garbage collected, its finalizeReferent() method is called back, so there is no need to start your own thread, maintain a ReferenceQueue, and loop on referenceQueue.remove(100) to pull out the weak reference and run the extra cleanup work.

Guava's FinalizableReference interface has a single method, void finalizeReferent(); it takes no parameters and thus carries no broadcastId-like context, so the context is passed in through a closure, as in the sketch below.
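
A minimal sketch of the Guava approach (assuming Guava is on the classpath; GuavaCleanerSketch and register are hypothetical names): the anonymous subclass closes over broadcastId, and Guava's FinalizableReferenceQueue runs the background thread for you.

import java.util.concurrent.ConcurrentHashMap
import com.google.common.base.{FinalizableReferenceQueue, FinalizableWeakReference}

object GuavaCleanerSketch {
  private val frq = new FinalizableReferenceQueue()
  // hold the reference objects strongly so they are not collected before the callback runs
  private val refs = ConcurrentHashMap.newKeySet[AnyRef]()

  def register(bc: AnyRef, broadcastId: Long): Unit = {
    val ref = new FinalizableWeakReference[AnyRef](bc, frq) {
      // called on Guava's Finalizer thread once `bc` is no longer strongly reachable
      override def finalizeReferent(): Unit = {
        refs.remove(this)
        println(s"unbroadcast($broadcastId)") // cleanup work, e.g. broadcastManager.unbroadcast(...)
      }
    }
    refs.add(ref)
  }
}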

WeakReference

FinalizableReferenceQueue in Guava, explained

Creating the BroadcastManager

BroadcastManager manages Broadcast objects and is created inside SparkEnv.

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

BroadcastManager's initialize() sets up the TorrentBroadcastFactory; stop(), newBroadcast() and unbroadcast() each delegate to the BroadcastFactory interface; nextBroadcastId is generated by an AtomicLong(0) and auto-increments.

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null

  initialize()

  // Called by SparkContext or Executor before using Broadcast
  private def initialize() {
    synchronized {
      if (!initialized) {
        broadcastFactory = new TorrentBroadcastFactory
        broadcastFactory.initialize(isDriver, conf, securityManager)
        initialized = true
      }
    }
  }

  def stop() {
    broadcastFactory.stop()
  }

  private val nextBroadcastId = new AtomicLong(0)

  private[broadcast] val cachedValues = {
    // HARD keys, WEAK values: keys are strong references and are never collected; values are weak references and are collected at GC
    new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
  }

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
  }
}

Reference strength: StrongReference > SoftReference > WeakReference > PhantomReference

Suppose an object is wrapped in a SoftReference or WeakReference, e.g. new WeakReference(obj, new ReferenceQueue()), and at GC time it is only softly or weakly reachable, with no strong references from other objects:

Soft reference: if memory is sufficient, GC does not collect the object; only before an OOM would occur does GC collect it. Commonly used for caches.
Weak reference: the object is collected at the next GC regardless of how much memory is available (see the sketch below).
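
A small illustrative snippet; the exact behavior depends on the JVM and GC, so the printed results are typical rather than guaranteed:

import java.lang.ref.{SoftReference, WeakReference}

object ReferenceDemo extends App {
  var weakTarget: Array[Byte] = new Array[Byte](1 << 20)
  var softTarget: Array[Byte] = new Array[Byte](1 << 20)
  val weak = new WeakReference(weakTarget)
  val soft = new SoftReference(softTarget)

  weakTarget = null   // drop the only strong references
  softTarget = null
  System.gc()         // request a GC (best effort)

  println(s"weak cleared: ${weak.get == null}")  // usually true after a GC
  println(s"soft cleared: ${soft.get == null}")  // usually false until memory pressure builds
}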

The four reference types in Java

cachedValues is implemented with Apache Commons' ReferenceMap: keys are HARD (strong) references and values are WEAK references. WeakHashMap is the opposite: keys are weak and values are strong; after a key is collected, expungeStaleEntries() sets e.value = null to break WeakHashMap's strong reference to the value.

ReferenceMap is not thread safe. It can be wrapped with Collections.synchronizedMap, or replaced by Spring's ConcurrentReferenceHashMap, which is implemented like ConcurrentHashMap but wraps each Entry in a WeakReference or SoftReference instead of using different reference types for K and V; another option is building a thread-safe map with Guava's MapMaker: ConcurrentMap concurrentMap = new MapMaker().weakValues().makeMap();
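
Sketches of two of those alternatives (assuming commons-collections and Guava are on the classpath; the cast is needed only because the commons-collections 3.x ReferenceMap is a raw java.util.Map):

import java.util.Collections
import java.util.concurrent.ConcurrentMap
import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}
import com.google.common.collect.MapMaker

object SafeMaps {
  // Option 1: wrap the non-thread-safe ReferenceMap in a synchronized view
  val syncRefMap: java.util.Map[AnyRef, AnyRef] = Collections.synchronizedMap(
    new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
      .asInstanceOf[java.util.Map[AnyRef, AnyRef]])

  // Option 2: Guava MapMaker, a ConcurrentHashMap-like map with hard keys and weak values
  val weakValueMap: ConcurrentMap[String, AnyRef] =
    new MapMaker().weakValues().makeMap[String, AnyRef]()
}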

The TorrentBroadcastFactory factory

The BroadcastFactory trait defines the factory interface: newBroadcast() creates a Broadcast, initialize() initializes the factory, unbroadcast() removes a Broadcast, and stop() shuts the factory down.

/**
 * An interface for all the broadcast implementations in Spark (to allow
 * multiple broadcast implementations). SparkContext uses a BroadcastFactory
 * implementation to instantiate a particular broadcast for the entire Spark job.
 */
private[spark] trait BroadcastFactory {

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit

  /**
   * Creates a new broadcast variable.
   *
   * @param value value to broadcast
   * @param isLocal whether we are in local mode (single JVM process)
   * @param id unique id representing this broadcast variable
   */
  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit

  def stop(): Unit
}

TorrentBroadcastFactory is the concrete implementation and delegates to TorrentBroadcast; in other words, BroadcastManager calls BroadcastFactory, which in turn calls Broadcast.

/**
 * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
 * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
 * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
 */
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {

  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }

  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }

  override def stop() { }

  /**
   * Remove all persisted state associated with the torrent broadcast with the given ID.
   * @param removeFromDriver Whether to remove state from the driver.
   * @param blocking Whether to block until unbroadcasted
   */
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
  }
}

The TorrentBroadcast object

Broadcast is an abstract class because it carries the instance fields _isValid and _destroySite, the concrete methods value(), unpersist() and destroy(), and the abstract methods getValue(), doUnpersist(blocking) and doDestroy(blocking); BroadcastFactory, by contrast, is defined as a trait (interface).

/**
 * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
 * cached on each machine rather than shipping a copy of it with tasks. They can be used, for
 * example, to give every node a copy of a large input dataset in an efficient manner. Spark also
 * attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
 * communication cost.
 *
 * Broadcast variables are created from a variable `v` by calling
 * [[org.apache.spark.SparkContext#broadcast]].
 * The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
 * `value` method. The interpreter session below shows this:
 *
 * {{{
 * scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
 * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
 *
 * scala> broadcastVar.value
 * res0: Array[Int] = Array(1, 2, 3)
 * }}}
 *
 * After the broadcast variable is created, it should be used instead of the value `v` in any
 * functions run on the cluster so that `v` is not shipped to the nodes more than once.
 * In addition, the object `v` should not be modified after it is broadcast in order to ensure
 * that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped
 * to a new node later).
 *
 * @param id A unique identifier for the broadcast variable.
 * @tparam T Type of the data contained in the broadcast variable.
 */
abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {

  /**
   * Flag signifying whether the broadcast variable is valid
   * (that is, not already destroyed) or not.
   */
  @volatile private var _isValid = true

  private var _destroySite = ""

  /** Get the broadcasted value. */
  // getValue(), doUnpersist() and doDestroy() below are the three abstract methods whose logic is pushed down to the concrete subclass
  def value: T = {
    assertValid()
    getValue()
  }

  /**
   * Asynchronously delete cached copies of this broadcast on the executors.
   * If the broadcast is used after this is called, it will need to be re-sent to each executor.
   */
  def unpersist() {
    unpersist(blocking = false)
  }

  /**
   * Delete cached copies of this broadcast on the executors. If the broadcast is used after
   * this is called, it will need to be re-sent to each executor.
   * @param blocking Whether to block until unpersisting has completed
   */
  def unpersist(blocking: Boolean) {
    assertValid()
    doUnpersist(blocking)
  }

  /**
   * Destroy all data and metadata related to this broadcast variable. Use this with caution;
   * once a broadcast variable has been destroyed, it cannot be used again.
   * This method blocks until destroy has completed
   */
  def destroy() {
    destroy(blocking = true)
  }

  /**
   * Destroy all data and metadata related to this broadcast variable. Use this with caution;
   * once a broadcast variable has been destroyed, it cannot be used again.
   * @param blocking Whether to block until destroy has completed
   */
  private[spark] def destroy(blocking: Boolean) {
    assertValid()
    _isValid = false
    
    // Capture the call stack; "last" and "first" refer to positions in the LIFO method stack
    // shortForm: s"$lastSparkMethod at $firstUserFile:$firstUserLine"
    // longForm = callStack.take(callStackDepth).mkString("\n")
    _destroySite = Utils.getCallSite().shortForm
    logInfo("Destroying %s (from %s)".format(toString, _destroySite))
    doDestroy(blocking)
  }

  /**
   * Whether this Broadcast is actually usable. This should be false once persisted state is
   * removed from the driver.
   */
  private[spark] def isValid: Boolean = {
    _isValid
  }

  /**
   * Actually get the broadcasted value. Concrete implementations of Broadcast class must
   * define their own way to get the value.
   */
  protected def getValue(): T

  /**
   * Actually unpersist the broadcasted value on the executors. Concrete implementations of
   * Broadcast class must define their own logic to unpersist their own data.
   */
  protected def doUnpersist(blocking: Boolean)

  /**
   * Actually destroy all data and metadata related to this broadcast variable.
   * Implementation of Broadcast class must define their own logic to destroy their own
   * state.
   */
  protected def doDestroy(blocking: Boolean)

  /** Check if this broadcast is valid. If not valid, exception is thrown. */
  protected def assertValid() {
    if (!_isValid) {
      throw new SparkException(
        "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
    }
  }

  override def toString: String = "Broadcast(" + id + ")"
}

The focus of the analysis is the TorrentBroadcast class, whose constructor takes the object to broadcast (obj) plus nextBroadcastId. The BlockManager methods involved are analyzed later.

/**
 * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
 *
 * The mechanism is as follows:
 *
 * The driver divides the serialized object into small chunks and
 * stores those chunks in the BlockManager of the driver.
 *
 * On each executor, the executor first attempts to fetch the object from its BlockManager. If
 * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
 * other executors if available. Once it gets the chunks, it puts the chunks in its own
 * BlockManager, ready for other executors to fetch from.
 *
 * This prevents the driver from being the bottleneck in sending out multiple copies of the
 * broadcast data (one per executor).
 *
 * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
 *
 * @param obj object to broadcast
 * @param id A unique identifier for the broadcast variable.
 */
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  /**
   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
   * which builds this value by reading blocks from the driver and/or other executors.
   *
   * On the driver, if the value is required, it is read lazily from the block manager.
   */
  // Lazily initialized: when an executor needs the broadcast value, it is read from the BlockManager on demand
  @transient private lazy val _value: T = readBroadcastBlock()

  /** The compression codec to use, or None if compression is disabled */
  @transient private var compressionCodec: Option[CompressionCodec] = _
  /** Size of each block. Default value is 4MB.  This value is only read by the broadcaster. */
  @transient private var blockSize: Int = _

  // Initializes three properties: compressionCodec, blockSize, checksumEnabled
  private def setConf(conf: SparkConf) {
    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
      Some(CompressionCodec.createCodec(conf))
    } else {
      None
    }
    // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
    // Each block is 4 MB, i.e. 4*1024*1024 bytes
    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
  }
  setConf(SparkEnv.get.conf)

  // e.g. broadcast_0 (id comes from BroadcastManager's nextBroadcastId field, an auto-incrementing AtomicLong)
  private val broadcastId = BroadcastBlockId(id)

  /** Total number of blocks this broadcast variable contains. */
  // When the TorrentBroadcast object is constructed, the broadcast value is written out in blocks and the total number of blocks is returned
  private val numBlocks: Int = writeBlocks(obj)

  /** Whether to generate checksum for blocks or not. */
  private var checksumEnabled: Boolean = false
  /** The checksum for all the blocks. */
  // The checksum of each block, filled in by writeBlocks()
  private var checksums: Array[Int] = _

  override protected def getValue() = {
    _value
  }

  // Compute a checksum for the block: Adler32 is faster to compute than CRC32
  // Checksum checksumEngine = new Adler32(); checksumEngine.update(bytes); long checksum = checksumEngine.getValue();
  private def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position(), block.limit()
        - block.position())
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   *
   * @param value the object to divide
   * @return number of blocks this broadcast variable is divided into
   */
  // 1. putSingle the value into the driver's local store  2. split the object into an Array[ByteBuffer] of blocks via blockifyObject  3. compute a checksum for each block  4. putBytes each block
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
    // do not create a duplicate copy of the broadcast variable's value.
    // Get the BlockManager of the current process (the driver, since that is where the broadcast is created)
    val blockManager = SparkEnv.get.blockManager
    
    // Call BlockManager.putSingle to write the broadcast value to the local store. When Spark runs in local mode this writes it to the driver's local store, so tasks can also run on the driver. Since the _replication of the MEMORY_AND_DISK StorageLevel is fixed at 1, the value is only written to the driver's or executor's local store
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    
    // Serialize and compress the object into a sequence of byte blocks (Array[ByteBuffer]), each of size 4*1024*1024
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
      
    // If checksums are enabled for the broadcast pieces, create a checksums array with the same length as blocks
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        // Compute the checksum of each ByteBuffer block and store it in the array
        checksums(i) = calcChecksum(block)
      }
      // broadcast_0_piece0, broadcast_0_piece1, broadcast_0_piece2 ...
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Write each broadcast piece in serialized form (MEMORY_AND_DISK_SER) to the driver's local store; tellMaster = true registers it as a download source
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }

  /** Fetch torrent blocks from the driver and/or other executors. */
  // Fetch each block of the broadcast value from the driver and/or other executors
  // BlockData is defined as an interface rather than a concrete class, which keeps the abstraction level high
  private def readBlocks(): Array[BlockData] = {
    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
    // to the driver, so other executors can pull these chunks from this executor as well.
    val blocks = new Array[BlockData](numBlocks)
    val bm = SparkEnv.get.blockManager

    // Shuffle the piece ids randomly to avoid hot spots when fetching: if everyone started from 0, the first pieces would all become hot spots
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = BroadcastBlockId(id, "piece" + pid)
      logDebug(s"Reading piece $pieceId of $broadcastId")
      // First try getLocalBytes because there is a chance that previous attempts to fetch the
      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
      // would be available locally (on this executor).
      // Fetch the block for pieceId from the storage system and wrap it as BlockData
      // Other attempts may already have downloaded this block, so try getLocalBytes first
      bm.getLocalBytes(pieceId) match {
        case Some(block) =>
          blocks(pid) = block
          releaseLock(pieceId)
        case None =>
          // Fetch the block as serialized bytes from a remote BlockManager; returns a ChunkedByteBuffer
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              // Verify the checksum
              if (checksumEnabled) {
                // The first ByteBuffer of the ChunkedByteBuffer holds the block data
                val sum = calcChecksum(b.chunks(0))
                // Compare against the locally stored checksum for this piece
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // We found the block from remote executors/driver's BlockManager, so put the block
              // in this executor's BlockManager.
              // The block was fetched from the remote driver's or another executor's BlockManager, so store it in this executor's BlockManager and tellMaster so it becomes a download source
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              // Wrap it as a ByteBufferBlockData
              blocks(pid) = new ByteBufferBlockData(b, true)
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    blocks
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors.
   */
  // Remove the broadcast data on all executors. The only difference from the method below is whether the driver's copy is removed as well
  override protected def doUnpersist(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors
   * and driver.
   */
  override protected def doDestroy(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /** Used by the JVM when serializing this object. */
  // One way to hook into Java serialization: when a Serializable class defines writeObject(), ObjectOutputStream.writeSerialData() invokes it reflectively during serialization
  // Ways to implement serialization: 1. implement Serializable directly  2. Serializable + transient fields  3. implement Externalizable's writeExternal()/readExternal()  4. Serializable + writeObject()/readObject() defined in the class and invoked reflectively by ObjectOutputStream
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {   // this check is the whole point of overriding writeObject(): fail if the broadcast has already been destroyed
    assertValid()
    // Write the non-static, non-transient fields of this class to the OutputStream (the default Serializable behavior)
    // Serialization captures object state; static fields belong to the class, not the object
    out.defaultWriteObject()
  }

  // Read the broadcast value; a synchronized operation
  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      // Get the broadcastCache: keys are BroadcastBlockId objects held with HARD references, values are the broadcast values T held with WEAK references
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues

      // Check the cache first; because values are weak references, once GC has collected the broadcast value it is re-read with blockManager.getLocalValues and cached again
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // Read the broadcast value from local storage, i.e. the value written with BlockManager.putSingle. Returns a BlockResult, which wraps the value obj in an Iterator
        blockManager.getLocalValues(broadcastId) match {
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              // Cast to the broadcast value type and release the block lock
              val x = blockResult.data.next().asInstanceOf[T]
              releaseLock(broadcastId)

              if (x != null) {
                // Cache it
                broadcastCache.put(broadcastId, x)
              }

              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          case None =>
            logInfo("Started reading broadcast variable " + id)
            val startTimeMs = System.currentTimeMillis()
            
            // Read all blocks of the broadcast value, returning an Array[BlockData]
            val blocks = readBlocks()
            logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

            try {
              // Merge the Array[BlockData] into a single InputStream, then decompress and deserialize into an object of type T
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              // Store the merged copy in BlockManager so other tasks on this executor don't
              // need to re-fetch it.
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              // Store obj in the current executor's BlockManager so later tasks on this executor can read it locally
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }

              if (obj != null) {
                broadcastCache.put(broadcastId, obj)
              }

              obj
            } finally {
              // The blocks are no longer needed; dispose of them to free off-heap memory
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }

  /**
   * If running in a task, register the given block's locks for release upon task completion.
   * Otherwise, if not running in a task then immediately release the lock.
   */
  // Wake up readers or writers blocked on this BlockId by calling BlockInfoManager.unlock()
  // The lock ensures that while a running task is using a block, other tasks cannot reuse it until that task completes and releases the lock
  // BlockId is designed as an abstract class
  private def releaseLock(blockId: BlockId): Unit = {
    val blockManager = SparkEnv.get.blockManager
    // TaskContext follows the ThreadLocal pattern and is designed as an abstract class
    Option(TaskContext.get()) match {
      case Some(taskContext) =>
        // If a TaskContext is present (the block is being fetched inside a running task), release the lock when the task completes, whether it succeeds or fails
        taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
      case None =>
        // This should only happen on the driver, where broadcast variables may be accessed
        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
        // broadcast variables to be garbage collected we need to free the reference here
        // which is slightly unsafe but is technically okay because broadcast variables aren't
        // stored off-heap.
        blockManager.releaseLock(blockId)
    }
  }
}

// When writing, the streams are layered from the outside in: serialize first, then compress: serializeStream(compressedStream(chunkedOutput)).writeObject(obj)
// When reading, work from the inside out: decompress first, then deserialize: deserializeStream(compressedInputStream(input)).readObject()
private object TorrentBroadcast extends Logging {

  // Convert the object obj into a sequence of byte blocks (Array[ByteBuffer]), each of size 4*1024*1024
  // Serialize first, then compress, and finally chunk into byte blocks; an OutputStream processes data from the outside in
  def blockifyObject[T: ClassTag](
      obj: T,
      blockSize: Int,
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
    // Wrap with a ChunkedByteBufferOutputStream
    val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
    // Compress with the CompressionCodec
    val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
    val ser = serializer.newInstance()
    // Serialization stream
    val serOut = ser.serializeStream(out)
    Utils.tryWithSafeFinally {
      // OutputStream.write takes byte[]: obj is first serialized into bytes, then compressed into another byte[], and finally the compressed bytes are written into chunks via ChunkedByteBufferOutputStream.write
      serOut.writeObject[T](obj)
    } {
      serOut.close()
    }
    cbbos.toChunkedByteBuffer.getChunks()
  }

  // Convert an Array[InputStream] back into an object
  // Decompress first, then deserialize; an InputStream processes data from the inside out
  def unBlockifyObject[T: ClassTag](
      blocks: Array[InputStream],
      serializer: Serializer,
      compressionCodec: Option[CompressionCodec]): T = {
    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
    // Concatenate the streams in order with a SequenceInputStream
    val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
    // Decompress
    val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
    val ser = serializer.newInstance()
    val serIn = ser.deserializeStream(in)
    // Wrapped in Utils.tryWithSafeFinally
    val obj = Utils.tryWithSafeFinally {
      // Deserialize
      serIn.readObject[T]()
    } {
      serIn.close()
    }
    obj
  }

  /**
   * Remove all persisted blocks associated with this torrent broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver.
   */
  // Remove the broadcast data blocks on all executors for the given id
  // removeFromDriver: whether to also remove the broadcast data on the driver
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
    logDebug(s"Unpersisting TorrentBroadcast $id")
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
  }
}

Summary of the Broadcast implementation

  1. The driver creates the broadcast object; writeBlocks(obj) splits it into blocks, and each block is stored in the driver's BlockManager
  2. The Broadcast is serialized on the driver and deserialized on the executors, where broadcastVar.value retrieves the value: "@transient private lazy val _value: T = readBroadcastBlock()" is not serialized and is read lazily, whereas "private var checksums: Array[Int] = _" is serialized from the driver
  3. Each executor tries to obtain all blocks to reassemble the broadcast value. A block is fetched first from the executor's own BlockManager; if it is not there, it is fetched from another BlockManager

Initially the driver is the only source of these blocks. As the BlockManagers of different executors fetch different blocks from the driver (TorrentBroadcast deliberately keeps executors from fetching the blocks in the same order: Random.shuffle(Seq.range(0, numBlocks))), the number of sources grows; each executor can then fetch a block from any one of several sources, including the driver and other executors' BlockManagers. This spreads the traffic more evenly across the cluster instead of making the driver the single source.
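
Putting it together, a typical usage pattern looks like the sketch below (the lookup table and RDD are made-up examples): the closure captures only the small Broadcast handle, and each executor pulls the actual data through the torrent mechanism described above.

// assuming an existing SparkContext `sc`
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))

val rdd = sc.parallelize(Seq("a", "b", "c"))
// only the Broadcast wrapper is shipped with each task; lookup.value is read
// from the executor's BlockManager (fetching the blocks on first access)
val resolved = rdd.map(k => k -> lookup.value.getOrElse(k, -1)).collect()

lookup.unpersist()   // drop cached copies on the executors; they can be re-fetched
// lookup.destroy()  // or remove all data and metadata; the broadcast then becomes unusable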

Further reading

How Spark Broadcast is implemented
