Spark BroadCast 解析

前言
在实际使用中对于一些许多rdd需要用到的大的只读数据集变量可以使用共享变量的方式来提高性能,例如查内存表,默认情况下会每个task都保存一份,这样太浪费资源,所以一般会采用共享变量的方式来查表,代码中经常使用,但还没细致研究过,这次刚好借着阅读Spark RDD API源码的机会来深入解析一下broadcast。

Broadcast代码还涉及到spark底层存储代码BlockManager、BlockId等。

简介
Broadcast变量使得编程人员在每一台机器上保存一份只读类型的变量而不需要为每一个task保存一份。在为每一个节点保存一份较大的输入数据集时这是一种很高效的手段,另外spark还尝试用高效的高效broadcast算法去减少通信开销。

基础类

abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {

该虚类有两种实现方式:

对应着两种网络协议类型,http协议和比特流bittorrent协议。

BroadcastFactory接口用来初始化和新建不同类型的broadcast变量,sparkContext会为不同用户产生特定的broadcast变量。

trait BroadcastFactory {

一共有下列方法:

该接口也有两种继承方式:

BroadcastManager负责具体的broadcast的初始化、删除和管理工作

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

对应的方法和变量有:


bitTorrent-like broadcast

这里先简单介绍下比特流协议:

比特流Bit-torrent是一种内容分发协议,有布拉姆科恩自主开发。它采用高效的软件分发系统和P2P技术共享大体积文件(如一部电影或电视节目),并使每个用户像网络重新分配结点那样提供上传服务。一般的下载服务器为每一个发出下载请求的用户提供下载服务,而bitTorrent的工作方式与之不同。分配器或文件的持有者将文件发送给其中一名用户,再由这名用户转发给其他用户,用户之间相互转发自己所拥有的文件部分,直到每个用户的下载全部完成。这种方法可以使下载服务器同时处理多个大体积文件的下载请求,而无需占用大量带宽。

首先是TorrentBroadcastFactory:

class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }

  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }

  override def stop() { }

  /**
   * Remove all persisted state associated with the torrent broadcast with the given ID.
   * @param removeFromDriver Whether to remove state from the driver.
   * @param blocking Whether to block until unbroadcasted
   */
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
  }
}

 

5个功能函数:

注意Initialize和stop都是空函数,没有实际的操作。

TorrentBroadcast是重点:

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

继承自Broadcast类,spark命名空间下的私有类

代码结构:

注意Object TorrentBroadcast中的方法。

下面开始详细分析这个类

该类是对org.apache.spark.broadcast.Broadcast类的一种类似比特流形式的实现,具体机制如下:

Driver将序列化后的对象切分成许多小块,将这些小块保存在driver的BlockManager中。在每个executor上,每个executor首先尝试从自己的本地BlockManager上去获取这些小块,如果不存在,就会从driver或者其他的executor上去获取,一旦获取到了目标小块,该executor就会将小块保存在自己的BlockManager中,等待被其他的executor获取。

这种机制使得在driver发送多份broadcast数据时(对每一个executor而言)避免成为系统的瓶颈,如果采用前面提到的org.apache.spark.broadcast.HttpBroadcast方式的话就使得driver成为整个系统的瓶颈了。

在初始化的时候,TorrentBroadcast 对象会去读取SparkEnv.get.conf。

Executor上的broadcast的对应值,值由readBroadcastBlock方法获取,通过读取存储在driver或者其他executor上的block获得,在driver上,只有当真正需要该值时,才会通过blockManager去惰性读取。

@transient private lazy val _value: T = readBroadcastBlock()

setConf:

通过配置文件获取是否需要对broadcast进行压缩,并设置环境配置。

private def setConf(conf: SparkConf) {
  compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
    Some(CompressionCodec.createCodec(conf))
  } else {
    None
  }
  // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided
  blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
}
setConf(SparkEnv.get.conf)

writeBlocks:

/**
 * Divide the object into multiple blocks and put those blocks in the block manager.
 * @param value the object to divide
 * @return number of blocks this broadcast variable is divided into
 */
private def writeBlocks(value: T): Int = {
  // Store a copy of the broadcast variable in the driver so that tasks run on the driver
  // do not create a duplicate copy of the broadcast variable's value.
  SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
    tellMaster = false)
  val blocks =
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  blocks.zipWithIndex.foreach { case (block, i) =>
    SparkEnv.get.blockManager.putBytes(
      BroadcastBlockId(id, "piece" + i),
      block,
      StorageLevel.MEMORY_AND_DISK_SER,
      tellMaster = true)
  }
  blocks.length
}

第一行代码,putSingle函数参数broadcast的Id,具体值value即将要存储的obj,存储级别,是否告知Master。

在driver上保存一份broadcast的值,这样在driver上运行的task就无需再创建一份对应的拷贝了。

由之前可知,在该类中有一个private的TorrentBroadcast的object,第二行就用到了该object的方法blockifyObject。

def blockifyObject[T: ClassTag](
    obj: T,
    blockSize: Int,
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
  val bos = new ByteArrayChunkOutputStream(blockSize)
  val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
  val ser = serializer.newInstance()
  val serOut = ser.serializeStream(out)
  serOut.writeObject[T](obj).close()
  bos.toArrays.map(ByteBuffer.wrap)
}

入参有具体要切分存储的obj,blockSize默认为4Mb,序列化方法,压缩方法。最终是将压缩和序列化后的obj以Byte Array的形式写入spark的存储block。

接上面,切分写完之后,将blocks做zipWithIndex的遍历,调用puteBytes方法,将切分好写入block的每一份broadcast的每一个block都以bytes的形式保存进blockManager之中。

最后返回的是blocks的个数即一共写了几个block。


readBlocks:

/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ByteBuffer] = {
 // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
 // to the driver, so other executors can pull these chunks from this executor as well.
 val blocks = new Array[ByteBuffer](numBlocks)
 val bm = SparkEnv.get.blockManager

 for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
   val pieceId = BroadcastBlockId(id, "piece" + pid)
   logDebug(s"Reading piece $pieceId of $broadcastId")
   // First try getLocalBytes because there is a chance that previous attempts to fetch the
   // broadcast blocks have already fetched some of the blocks. In that case, some blocks
   // would be available locally (on this executor).
   def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
   def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
     // If we found the block from remote executors/driver's BlockManager, put the block
     // in this executor's BlockManager.
     SparkEnv.get.blockManager.putBytes(
       pieceId,
       block,
       StorageLevel.MEMORY_AND_DISK_SER,
       tellMaster = true)
     block
   }
   val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
     throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
   blocks(pid) = block
 }
 blocks
}

从driver或者executor上获取所有的blocks,通过blockManager来实现,首先在本地local尝试,没有的话就从driver或者其他executor上获取,获取之后并保存在当前executor的blockManager里面。

归根结底是通过指定的broadcastId和并遍历pieceId利用blockManager的getLocalBytes和getRemoteBytes函数来获得对应的block然后通过解压和反序列化获取最终我们所需的value。

readBroadcastBlock:

真正的去读取broadcastBlock的具体value:

private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    setConf(SparkEnv.get.conf)
    SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
      case Some(x) =>
        x.asInstanceOf[T]

      case None =>
        logInfo("Started reading broadcast variable " + id)
        val startTimeMs = System.currentTimeMillis()
        val blocks = readBlocks()
        logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

        val obj = TorrentBroadcast.unBlockifyObject[T](
          blocks, SparkEnv.get.serializer, compressionCodec)
        // Store the merged copy in BlockManager so other tasks on this executor don't
        // need to re-fetch it.
        SparkEnv.get.blockManager.putSingle(
          broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
        obj
    }
  }
}

def unBlockifyObject[T: ClassTag](
    blocks: Array[ByteBuffer],
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): T = {
  require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
  val is = new SequenceInputStream(
    blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
  val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
  val ser = serializer.newInstance()
  val serIn = ser.deserializeStream(in)
  val obj = serIn.readObject[T]()
  serIn.close()
  obj
}

与写block的过程和方法相似,就不详细介绍了,有一点差别就是这个read操作会真正的将对应的broadcast的值value解压反序列化读出来,对应的业务代码api就是broadcast变量的value函数,我们上面提到过的惰性求值的那个_value也会触发该函数的执行。

另外对于 broadcast的清除包括彻底和非彻底区别是是否清除driver上内容。

刚才一开始讲过TorrentBroadcastFactory类,它只要是完成TorrentBroadcast的具体的初始化、停止、实例化等等、该类的实现和实例化是在统一的BroadcastManager中实现的,该类管理者httpBroadcast实例和torrentBroadcast实例。

BroadcastManager:

private[spark] class BroadcastManager(
    val isDriver: Boolean,
    conf: SparkConf,
    securityManager: SecurityManager)
  extends Logging {

  private var initialized = false
  private var broadcastFactory: BroadcastFactory = null
  initialize()
  // Called by SparkContext or Executor before using Broadcast
  private def initialize() {
    synchronized {
      if (!initialized) {
        val broadcastFactoryClass =
          conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

        broadcastFactory =
          Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

        // Initialize appropriate BroadcastFactory and BroadcastObject
        broadcastFactory.initialize(isDriver, conf, securityManager)

        initialized = true
      }
    }
  }

  def stop() {
    broadcastFactory.stop()
  }

  private val nextBroadcastId = new AtomicLong(0)

  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
  }

  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
  }
}

在该类中会根据配置文件中指出的类型来实例化具体的broadcastFactory类,考虑到性能问题,默认为torrentBroadcast。

该类的函数包括broadcast环境的初始化、新建broadcast实例、停止和清除broadcast等等。

BroadcastManager在SparkEnv.scala中实例化:

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

SparkEnv类负责了一个spark 运行实例(master或者worker)所需要的所有的运行时环境对象,包括serializer、akka actor system、blockManager、map output tracker等等,目前spark代码通过一个全局变量来访问SparkEnv,所以所有的线程都可以访问同一个SparkEnv。在创建完SparkContext之后可通过SparkEnv.get来访问。

SparkContext:

具体的某一个broadcast的实例化是在SparkContext.scala中实现的:

/**
 * Broadcast a read-only variable to the cluster, returning a
 * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
 * The variable will be sent to each cluster only once.
 */
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
    // This is a warning instead of an exception in order to avoid breaking user programs that
    // might have created RDD broadcast variables but not used them:
    logWarning("Can not directly broadcast RDDs; instead, call collect() and "
      + "broadcast the result (see SPARK-5063)")
  }
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}

这里也就是我们在业务代码中的入口比如:

val bcMiddleTime = sc.broadcast(mapMiddleTime)

mapMiddleTime就是我们需要广播的value。

httpBroadcast
下面简单分析下httpBroadcast。

HTTPBroadcastFactory类与之前的torrentBroadcastFactory类似,不过httpBroadcast实现了initialize和stop方法。

HttpBroadcast类:

/**
 * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server
 * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a
 * task) is deserialized in the executor, the broadcasted data is fetched from the driver
 * (through a HTTP server running at the driver) and stored in the BlockManager of the
 * executor to speed up future accesses.
 */
private[spark] class HttpBroadcast[T: ClassTag](
    @transient var value_ : T, isLocal: Boolean, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

httpBroadcast使用的是http协议来实现broadcast,在一开始广播变量以task的一部分的形式在executor中被序列化,通过运行在driver上的HTTP server,executor获取broadcast的data,并将获取到的data保存在executor的BlockManager中缓存。

代码架构:

一开始会将value同步保存在driver的blockManager之中。

若是集群状态,则将调用HttpBroadcast单例的write函数。

HttpBroadcast.synchronized {
  SparkEnv.get.blockManager.putSingle(
    blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}

if (!isLocal) {
  HttpBroadcast.write(id, value_)
}

HttpBroadcast单例的代码如下:

private[broadcast] object HttpBroadcast extends Logging {
  private var initialized = false
  private var broadcastDir: File = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536
  private var serverUri: String = null
  private var server: HttpServer = null
  private var securityManager: SecurityManager = null

  // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
  private val files = new TimeStampedHashSet[File]
  private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
  private var compressionCodec: CompressionCodec = null
  private var cleaner: MetadataCleaner = null

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    synchronized {
      if (!initialized) {
        bufferSize = conf.getInt("spark.buffer.size", 65536)
        compress = conf.getBoolean("spark.broadcast.compress", true)
        securityManager = securityMgr
        if (isDriver) {
          createServer(conf)
          conf.set("spark.httpBroadcast.uri", serverUri)
        }
        serverUri = conf.get("spark.httpBroadcast.uri")
        cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
        compressionCodec = CompressionCodec.createCodec(conf)
        initialized = true
      }
    }
  }

  def stop() {
    synchronized {
      if (server != null) {
        server.stop()
        server = null
      }
      if (cleaner != null) {
        cleaner.cancel()
        cleaner = null
      }
      compressionCodec = null
      initialized = false
    }
  }

  private def createServer(conf: SparkConf) {
    broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
    val broadcastPort = conf.getInt("spark.broadcast.port", 0)
    server =
      new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
    server.start()
    serverUri = server.uri
    logInfo("Broadcast server started at " + serverUri)
  }

  def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name)

  private def write(id: Long, value: Any) {
    val file = getFile(id)
    val fileOutputStream = new FileOutputStream(file)
    Utils.tryWithSafeFinally {
      val out: OutputStream = {
        if (compress) {
          compressionCodec.compressedOutputStream(fileOutputStream)
        } else {
          new BufferedOutputStream(fileOutputStream, bufferSize)
        }
      }
      val ser = SparkEnv.get.serializer.newInstance()
      val serOut = ser.serializeStream(out)
      Utils.tryWithSafeFinally {
        serOut.writeObject(value)
      } {
        serOut.close()
      }
      files += file
    } {
      fileOutputStream.close()
    }
  }

  
  /**
   * Remove all persisted blocks associated with this HTTP broadcast on the executors.
   * If removeFromDriver is true, also remove these persisted blocks on the driver
   * and delete the associated broadcast file.
   */
  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized {
    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
    if (removeFromDriver) {
      val file = getFile(id)
      files.remove(file)
      deleteBroadcastFile(file)
    }
  }

  /**
   * Periodically clean up old broadcasts by removing the associated map entries and
   * deleting the associated files.
   */
  private def cleanup(cleanupTime: Long) {
    val iterator = files.internalMap.entrySet().iterator()
    while(iterator.hasNext) {
      val entry = iterator.next()
      val (file, time) = (entry.getKey, entry.getValue)
      if (time < cleanupTime) {
        iterator.remove()
        deleteBroadcastFile(file)
      }
    }
  }

  private def deleteBroadcastFile(file: File) {
    try {
      if (file.exists) {
        if (file.delete()) {
          logInfo("Deleted broadcast file: %s".format(file))
        } else {
          logWarning("Could not delete broadcast file: %s".format(file))
        }
      }
    } catch {
      case e: Exception =>
        logError("Exception while deleting broadcast file: %s".format(file), e)
    }
  }

调用的write函数,首先在broadcastDir目录下创建一个以broadcastId的name为名称的文件,然后new出来一个fileOutPutStream实例和一个outPutStream实例,获取序列化方法将value写入对应文件,并将文件添加到系统的timeStampedHashSet[File]集合中。

doUnPersist和doDestory与torrentBroadcast类似,不同的是前者在删除driver上broadcast时会删除具体文件。

上面提到torrentBroadcast没有真正实现initialize和stop函数,而httpBroadcast实现了这两个函数。

Initialize函数首先从配置文件中获取bufferSize 为65536和是否压缩标志为true,接着会判断是否是driver,是的话在driver上创建http服务,创建一个临时文件目录broadcast来保存广播变量,服务名称为HTTP broadcast server。

Httpbroadcast在实现时对于value没有做实际意义上的读取操作即对于文件的读取操作没有被执行,value的值就是构建broadcast时传入的value,因为executor都是从driver上通过http服务来获取的,所以driver在构建broadcast时的value就直接拿来作为后来读取的value了,个人是这么理解的。

接着是创建metaDataCleaner实例和压缩实例。

cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)

metaDataCleaner实例的入参表明需要清除的数据类型和清理函数,这个实例会在后台起一个time task来定期清理那些老的过时的数据,传入的清理函数cleanUp主要是用来清理之前的broadcast Files。

Stop函数包括了http server的stop、cleaner和压缩实例的清除。

你可能感兴趣的:(Spark BroadCast 解析)