Broadcast provides distributed data sharing; BroadcastManager is responsible for creating and destroying broadcast variables. Broadcasts are typically used for shared configuration, common datasets, and frequently used data structures.
A broadcast variable is read-only, which guarantees consistency across the cluster. Its data is stored at StorageLevel.MEMORY_AND_DISK, so it spills to disk rather than causing an OOM, but broadcasting a large object still produces heavy network I/O and can put pressure on a single node.
The broadcast architecture follows the common factory pattern (with a single product, Broadcast): BroadcastManager; BroadcastFactory / TorrentBroadcastFactory; Broadcast / TorrentBroadcast.
Using Broadcast
scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
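Besides value, the handle also exposes unpersist() and destroy(), both analyzed later in this section. A quick illustration of their effect, continuing the same REPL session:
scala> broadcastVar.unpersist()   // asynchronously drop cached copies on the executors; re-sent on next use
scala> broadcastVar.destroy()     // remove all data and metadata; the variable cannot be used afterwards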
A broadcast variable is created through SparkContext, and its value is retrieved with value. The SparkContext code that creates a broadcast is analyzed below:
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
// An RDD cannot be broadcast directly, only a concrete value can; call rdd.collect() first and broadcast the result
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
// Create the broadcast object via SparkEnv's BroadcastManager
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
// Capture the call-site information
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
// Register the broadcast with the cleaner: the Broadcast object is associated with a CleanupTaskWeakReference, and a background loop polls the referenceQueue. Once the Broadcast object has no strong references it is reclaimed by GC, and its CleanupTaskWeakReference is put onto the referenceQueue.
// CleanupTaskWeakReference extends WeakReference and adds a CleanupTask field, so when the weak reference is removed from the referenceQueue the CleanupTask can be read back; if it is a CleanBroadcast, the broadcastId context is recovered and broadcastManager.unbroadcast() is called to delete the broadcast across the cluster (a sketch of this pattern follows after the method).
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
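The cleanup path described in the comments above boils down to the standard WeakReference + ReferenceQueue pattern. A simplified, illustrative sketch; the names CleanupTask, CleanBroadcast and keepCleaning mirror ContextCleaner's vocabulary, but this is not the actual Spark code:
import java.lang.ref.{ReferenceQueue, WeakReference}

// The weak reference carries a description of what to clean once its referent has been GC'd.
sealed trait CleanupTask
case class CleanBroadcast(broadcastId: Long) extends CleanupTask

class CleanupTaskWeakReference(
    val task: CleanupTask,
    referent: AnyRef,
    queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

val referenceQueue = new ReferenceQueue[AnyRef]

// Background loop: remove(100) blocks up to 100 ms and returns a reference whose referent was collected.
def keepCleaning(unbroadcast: Long => Unit): Unit = {
  while (true) {
    Option(referenceQueue.remove(100))
      .map(_.asInstanceOf[CleanupTaskWeakReference])
      .map(_.task)
      .foreach { case CleanBroadcast(id) => unbroadcast(id) }   // e.g. broadcastManager.unbroadcast(id, ...)
  }
}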
Alternatively, Guava's FinalizableWeakReference can be used: finalizeReferent() is called back when the strongly referenced object is collected, so there is no need to create your own thread, maintain a ReferenceQueue, and loop over referenceQueue.remove(100) to pull weak references and run the extra cleanup work.
Guava's FinalizableReference interface has a single method, void finalizeReferent(); it takes no arguments, so there is no broadcastId-like context parameter; the context is passed through the closure instead (see the sketch below).
WeakReference
An analysis of Guava's FinalizableReferenceQueue
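A sketch of that Guava-based alternative for comparison (register and pending are illustrative names, not part of Spark or Guava): FinalizableReferenceQueue starts its own background thread, finalizeReferent() is invoked after the referent is collected, and the broadcastId travels via the closure instead of a method parameter.
import java.util.concurrent.ConcurrentHashMap
import com.google.common.base.{FinalizableReferenceQueue, FinalizableWeakReference}

val frq = new FinalizableReferenceQueue
// Hold strong references to the FinalizableWeakReference objects themselves so they are not collected first.
val pending = ConcurrentHashMap.newKeySet[FinalizableWeakReference[AnyRef]]()

def register(bc: AnyRef, broadcastId: Long, unbroadcast: Long => Unit): Unit = {
  val ref: FinalizableWeakReference[AnyRef] = new FinalizableWeakReference[AnyRef](bc, frq) {
    override def finalizeReferent(): Unit = {
      pending.remove(this)
      unbroadcast(broadcastId)   // broadcastId is captured by the closure; no parameter needed
    }
  }
  pending.add(ref)
}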
Creating the BroadcastManager
BroadcastManager manages broadcasts; the object is created inside SparkEnv:
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
BroadcastManager's initialize() sets up the TorrentBroadcastFactory; stop(), newBroadcast() and unbroadcast() each delegate to the BroadcastFactory interface; nextBroadcastId is generated by an AtomicLong(0) and increments monotonically.
private[spark] class BroadcastManager(
val isDriver: Boolean,
conf: SparkConf,
securityManager: SecurityManager)
extends Logging {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
initialize()
// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
broadcastFactory = new TorrentBroadcastFactory
broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
}
}
def stop() {
broadcastFactory.stop()
}
private val nextBroadcastId = new AtomicLong(0)
private[broadcast] val cachedValues = {
// HARD key, WEAK value: keys are strongly referenced and never reclaimed; values are weakly referenced and reclaimed at GC
new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
}
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
}
}
Reference strength, strongest to weakest: StrongReference > SoftReference > WeakReference > PhantomReference
When an object is wrapped in a SoftReference or WeakReference, for example new WeakReference(obj, new ReferenceQueue()), and at GC time the object is only softly or weakly reachable (no other strong references):
Soft reference: while memory is sufficient, GC does not reclaim the object; just before an OOM would occur, GC reclaims it. Commonly used for caches.
Weak reference: the object is reclaimed at the next GC regardless of how much memory is available (a small demo follows below).
The four reference types in Java
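A tiny demo of the difference; the exact outcome depends on the JVM and on when GC actually runs, so treat it as illustrative:
import java.lang.ref.{SoftReference, WeakReference}

var strong: Array[Byte] = new Array[Byte](1024)
val weak = new WeakReference[Array[Byte]](strong)
val soft = new SoftReference[Array[Byte]](strong)

strong = null        // drop the only strong reference
System.gc()          // only a hint, but usually enough for the weak referent to be cleared
Thread.sleep(100)

println(weak.get == null)   // typically true: weakly reachable objects go at the next GC
println(soft.get == null)   // typically false: softly reachable objects survive until memory pressure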
cachedValues is backed by Apache Commons' ReferenceMap: keys are HARD (strong) references, values are WEAK references. WeakHashMap is the opposite: keys are weak references and values are strong; after a key is reclaimed, expungeStaleEntries() sets e.value = null to break WeakHashMap's strong reference to the value.
ReferenceMap is not thread-safe; it can be wrapped with Collections.synchronizedMap (sketched below), or replaced by Spring's ConcurrentReferenceHashMap, which is implemented like ConcurrentHashMap but wraps its entries in WeakReference or SoftReference.
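A sketch of the first option only, wrapping the same ReferenceMap that cachedValues uses with Collections.synchronizedMap; the cast is needed because the commons-collections 3.x ReferenceMap is a raw java.util.Map:
import java.util.Collections
import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}

val safeCache = Collections.synchronizedMap(
  new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
    .asInstanceOf[java.util.Map[AnyRef, AnyRef]])

safeCache.put("broadcast_0", Array(1, 2, 3))   // key held strongly, value held weakly
Option(safeCache.get("broadcast_0"))           // may be empty after the value has been GC'd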
The TorrentBroadcastFactory
BroadcastFactory defines the factory interface: newBroadcast() creates a Broadcast, initialize() initializes the factory, unbroadcast() removes a broadcast, and stop() shuts the factory down.
/**
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a BroadcastFactory
* implementation to instantiate a particular broadcast for the entire Spark job.
*/
private[spark] trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
/**
* Creates a new broadcast variable.
*
* @param value value to broadcast
* @param isLocal whether we are in local mode (single JVM process)
* @param id unique id representing this broadcast variable
*/
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
TorrentBroadcastFactory is the concrete implementation and delegates to TorrentBroadcast; in other words, BroadcastManager calls BroadcastFactory, and BroadcastFactory calls Broadcast.
/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
* protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
* [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
*/
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
new TorrentBroadcast[T](value_, id)
}
override def stop() { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver.
* @param blocking Whether to block until unbroadcasted
*/
override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
}
}
The TorrentBroadcast class
Broadcast is an abstract class because it has instance fields (_isValid, _destroySite), concrete methods (value, unpersist(), destroy()), and abstract methods (getValue(), doUnpersist(blocking), doDestroy(blocking)); BroadcastFactory, by contrast, is defined as a trait (an interface).
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
* cached on each machine rather than shipping a copy of it with tasks. They can be used, for
* example, to give every node a copy of a large input dataset in an efficient manner. Spark also
* attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
* communication cost.
*
* Broadcast variables are created from a variable `v` by calling
* [[org.apache.spark.SparkContext#broadcast]].
* The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
* `value` method. The interpreter session below shows this:
*
* {{{
* scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
* broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
*
* scala> broadcastVar.value
* res0: Array[Int] = Array(1, 2, 3)
* }}}
*
* After the broadcast variable is created, it should be used instead of the value `v` in any
* functions run on the cluster so that `v` is not shipped to the nodes more than once.
* In addition, the object `v` should not be modified after it is broadcast in order to ensure
* that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped
* to a new node later).
*
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
/**
* Flag signifying whether the broadcast variable is valid
* (that is, not already destroyed) or not.
*/
@volatile private var _isValid = true
private var _destroySite = ""
/** Get the broadcasted value. */
// The three abstract methods further below (getValue, doUnpersist, doDestroy) are left to the concrete subclass
def value: T = {
assertValid()
getValue()
}
/**
* Asynchronously delete cached copies of this broadcast on the executors.
* If the broadcast is used after this is called, it will need to be re-sent to each executor.
*/
def unpersist() {
unpersist(blocking = false)
}
/**
* Delete cached copies of this broadcast on the executors. If the broadcast is used after
* this is called, it will need to be re-sent to each executor.
* @param blocking Whether to block until unpersisting has completed
*/
def unpersist(blocking: Boolean) {
assertValid()
doUnpersist(blocking)
}
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
* This method blocks until destroy has completed
*/
def destroy() {
destroy(blocking = true)
}
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
* @param blocking Whether to block until destroy has completed
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
// Capture the calling thread's stack; 'last' and 'first' are named with respect to the LIFO method stack
// shortForm: s"$lastSparkMethod at $firstUserFile:$firstUserLine"
// longForm = callStack.take(callStackDepth).mkString("\n")
_destroySite = Utils.getCallSite().shortForm
logInfo("Destroying %s (from %s)".format(toString, _destroySite))
doDestroy(blocking)
}
/**
* Whether this Broadcast is actually usable. This should be false once persisted state is
* removed from the driver.
*/
private[spark] def isValid: Boolean = {
_isValid
}
/**
* Actually get the broadcasted value. Concrete implementations of Broadcast class must
* define their own way to get the value.
*/
protected def getValue(): T
/**
* Actually unpersist the broadcasted value on the executors. Concrete implementations of
* Broadcast class must define their own logic to unpersist their own data.
*/
protected def doUnpersist(blocking: Boolean)
/**
* Actually destroy all data and metadata related to this broadcast variable.
* Implementation of Broadcast class must define their own logic to destroy their own
* state.
*/
protected def doDestroy(blocking: Boolean)
/** Check if this broadcast is valid. If not valid, exception is thrown. */
protected def assertValid() {
if (!_isValid) {
throw new SparkException(
"Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
}
}
override def toString: String = "Broadcast(" + id + ")"
}
The key class to analyze is TorrentBroadcast; its constructor takes the object to broadcast (obj) plus nextBroadcastId. The BlockManager methods involved are analyzed in a later section.
/**
* A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
*
* The mechanism is as follows:
*
* The driver divides the serialized object into small chunks and
* stores those chunks in the BlockManager of the driver.
*
* On each executor, the executor first attempts to fetch the object from its BlockManager. If
* it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
* other executors if available. Once it gets the chunks, it puts the chunks in its own
* BlockManager, ready for other executors to fetch from.
*
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor).
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
* @param obj object to broadcast
* @param id A unique identifier for the broadcast variable.
*/
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
* Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
* which builds this value by reading blocks from the driver and/or other executors.
*
* On the driver, if the value is required, it is read lazily from the block manager.
*/
// Lazy: when an executor first needs the broadcast value, it is read on demand from the BlockManager
@transient private lazy val _value: T = readBroadcastBlock()
/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _
// Initializes three properties: compressionCodec, blockSize, checksumEnabled
private def setConf(conf: SparkConf) {
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
} else {
None
}
// Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
// Each block is 4 MB, i.e. 4 * 1024 * 1024 bytes
blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
}
setConf(SparkEnv.get.conf)
// e.g. broadcast_0 (id comes from BroadcastManager's nextBroadcastId field, an auto-incrementing AtomicLong)
private val broadcastId = BroadcastBlockId(id)
/** Total number of blocks this broadcast variable contains. */
// The broadcast data is written when the TorrentBroadcast object is constructed; returns the total number of blocks
private val numBlocks: Int = writeBlocks(obj)
/** Whether to generate checksum for blocks or not. */
private var checksumEnabled: Boolean = false
/** The checksum for all the blocks. */
// Stores the checksum of each block; filled in by writeBlocks()
private var checksums: Array[Int] = _
override protected def getValue() = {
_value
}
// Compute the checksum of a block; Adler32 is faster to compute than CRC32
// Checksum checksumEngine = new Adler32(); checksumEngine.update(bytes); long checksum = checksumEngine.getValue();
private def calcChecksum(block: ByteBuffer): Int = {
val adler = new Adler32()
if (block.hasArray) {
adler.update(block.array, block.arrayOffset + block.position(), block.limit()
- block.position())
} else {
val bytes = new Array[Byte](block.remaining())
block.duplicate.get(bytes)
adler.update(bytes)
}
adler.getValue.toInt
}
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
*
* @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
// 1. putSingle the value into the driver's local store  2. split the object into an Array[ByteBuffer] of blocks via blockifyObject  3. compute a checksum per block  4. putBytes each block
private def writeBlocks(value: T): Int = {
import StorageLevel._
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
// Get the BlockManager of the current node
val blockManager = SparkEnv.get.blockManager
// Call BlockManager.putSingle to write the broadcast object into the local store. In local mode this writes the object into the driver's local store, so that tasks running on the driver can use it. Because the _replication of the MEMORY_AND_DISK StorageLevel is fixed at 1, the object is written only to the local store of the driver (or executor).
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
// Serialize and compress the object into a sequence of byte blocks (Array[ByteBuffer]), each 4 * 1024 * 1024 bytes
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
// If checksums are enabled for the broadcast pieces, allocate a checksums array with the same length as blocks
if (checksumEnabled) {
checksums = new Array[Int](blocks.length)
}
blocks.zipWithIndex.foreach { case (block, i) =>
if (checksumEnabled) {
// Compute a checksum for each ByteBuffer block and store it in the array
checksums(i) = calcChecksum(block)
}
// broadcast_0_piece0, broadcast_0_piece1, broadcast_0_piece2, ...
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
// Write each broadcast piece to the local store in serialized form (MEMORY_AND_DISK_SER); tellMaster = true registers this node as a download source
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
blocks.length
}
/** Fetch torrent blocks from the driver and/or other executors. */
// Loop over the pieces, fetching each block of the broadcast object from the driver and/or other executors
// BlockData is defined as an interface rather than a concrete class, giving a higher level of abstraction
private def readBlocks(): Array[BlockData] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager
// Randomly shuffle the piece order to avoid a "hotspot" on any single piece; if every executor started from piece 0, the first pieces would all become hotspots
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
logDebug(s"Reading piece $pieceId of $broadcastId")
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
// Get the block for pieceId from the local store, wrapped as BlockData
// Try getLocalBytes first, because earlier fetch attempts may already have stored some of these blocks locally on this executor
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
releaseLock(pieceId)
case None =>
// Fetch the block from a remote BlockManager as serialized bytes, returned as a ChunkedByteBuffer
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
// Verify the checksum
if (checksumEnabled) {
// The first ByteBuffer of the ChunkedByteBuffer holds the block data
val sum = calcChecksum(b.chunks(0))
// Compare against the locally stored checksum for this piece
if (sum != checksums(pid)) {
throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
s" $sum != ${checksums(pid)}")
}
}
// We found the block from remote executors/driver's BlockManager, so put the block
// in this executor's BlockManager.
// The block came from the remote driver or another executor's BlockManager, so store it in this executor's BlockManager and tellMaster so this node becomes a download source too
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
// Wrap it as a ByteBufferBlockData
blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
}
blocks
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors.
*/
// Remove the broadcast data from all executors; the only difference from the method below is whether the driver's copy is removed as well
override protected def doUnpersist(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors
* and driver.
*/
override protected def doDestroy(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
/** Used by the JVM when serializing this object. */
// One way to customize serialization: when a class implements Serializable, ObjectOutputStream.writeSerialData() reflectively calls the class's private writeObject() method
// Common serialization options: 1. implement Serializable directly 2. Serializable + transient fields 3. implement Externalizable's writeExternal/readExternal 4. Serializable + private writeObject/readObject methods, invoked reflectively by ObjectOutputStream
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { // the next call is the whole point of overriding writeObject(): verify the broadcast has not been destroyed before serializing
assertValid()
// Write the non-static, non-transient fields to the stream (the default Serializable behavior)
// Serialization captures object state; static fields belong to the class, not the instance
out.defaultWriteObject()
}
// Read the broadcast value; synchronized on the TorrentBroadcast companion object
private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
// Get broadcastCache: keys are BroadcastBlockId objects held as HARD references, values are the broadcast objects T held as WEAK references
val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
// Try the cache first; because values are weak references, once GC reclaims the broadcast object it is re-read with blockManager.getLocalValues and put back into the cache
Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
setConf(SparkEnv.get.conf)
val blockManager = SparkEnv.get.blockManager
// Read the broadcast object from the local store, i.e. the object written via BlockManager.putSingle; this returns a BlockResult whose data Iterator wraps the object
blockManager.getLocalValues(broadcastId) match {
case Some(blockResult) =>
if (blockResult.data.hasNext) {
// Convert to the broadcast object and release the block lock
val x = blockResult.data.next().asInstanceOf[T]
releaseLock(broadcastId)
if (x != null) {
// Cache it
broadcastCache.put(broadcastId, x)
}
x
} else {
throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
}
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
// Read all blocks of the broadcast object, returning an Array[BlockData]
val blocks = readBlocks()
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
try {
// Merge the Array[BlockData] into a single InputStream, then decompress and deserialize into an object T
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
// Store obj in this executor's BlockManager so that other tasks on this executor do not need to re-fetch it
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
if (obj != null) {
broadcastCache.put(broadcastId, obj)
}
obj
} finally {
// The blocks' lifecycle ends here; dispose() frees any off-heap memory
blocks.foreach(_.dispose())
}
}
}
}
}
/**
* If running in a task, register the given block's locks for release upon task completion.
* Otherwise, if not running in a task then immediately release the lock.
*/
// Release the lock on this BlockId via BlockInfoManager.unlock(), waking up any blocked readers or writers
// The lock ensures that while a running task is using a block, the block cannot be removed out from under it until the task completes and releases the lock
// BlockId is designed as an abstract class
private def releaseLock(blockId: BlockId): Unit = {
val blockManager = SparkEnv.get.blockManager
// TaskContext is held in a ThreadLocal and is also designed as an abstract class
Option(TaskContext.get()) match {
case Some(taskContext) =>
// If a TaskContext is present, the block is being fetched inside a running task; release the lock when the task completes, whether it succeeds or fails
taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
case None =>
// This should only happen on the driver, where broadcast variables may be accessed
// outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
// broadcast variables to be garbage collected we need to free the reference here
// which is slightly unsafe but is technically okay because broadcast variables aren't
// stored off-heap.
blockManager.releaseLock(blockId)
}
}
}
// Writing goes from the outside in, serialize first and then compress: ((compress)serialize).write(obj)
// Reading goes from the inside out, decompress first and then deserialize: ((compress(obj))serialize).read  (a JDK-streams sketch of this layering follows after the object)
private object TorrentBroadcast extends Logging {
// Convert the object obj into a sequence of byte blocks (Array[ByteBuffer]), each 4 * 1024 * 1024 bytes
// Serialize first, then compress, finally chunk into byte blocks; the OutputStream layers operate from the outside in
def blockifyObject[T: ClassTag](
obj: T,
blockSize: Int,
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
// Innermost sink: a ChunkedByteBufferOutputStream
val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
// Optional compression layer via CompressionCodec
val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
val ser = serializer.newInstance()
// Serialization stream on top
val serOut = ser.serializeStream(out)
Utils.tryWithSafeFinally {
// OutputStream.write takes byte[]: obj is serialized to bytes, those bytes are compressed into another byte[], and the compressed bytes are finally chunked by ChunkedByteBufferOutputStream.write
serOut.writeObject[T](obj)
} {
serOut.close()
}
cbbos.toChunkedByteBuffer.getChunks()
}
// Convert an Array[InputStream] back into an object
// Decompress first, then deserialize; the InputStream layers operate from the inside out
def unBlockifyObject[T: ClassTag](
blocks: Array[InputStream],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
// SequenceInputStream concatenates the block streams in order
val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
// Decompress
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
// Wrapped in Utils.tryWithSafeFinally
val obj = Utils.tryWithSafeFinally {
// Deserialize
serIn.readObject[T]()
} {
serIn.close()
}
obj
}
/**
* Remove all persisted blocks associated with this torrent broadcast on the executors.
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
// Remove the broadcast blocks for this id from all executor processes
// removeFromDriver: whether to remove the driver's copy of the broadcast as well
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
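The layering that blockifyObject / unBlockifyObject build can be mimicked with plain JDK streams. A simplified stand-in, using GZIP and Java serialization instead of Spark's CompressionCodec, Serializer and ChunkedByteBufferOutputStream, and a single byte array instead of 4 MB chunks:
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Write path: the object is serialized first, the serialized bytes are then compressed,
// and the compressed bytes land in the innermost byte sink (the "chunk store").
def blockify(obj: Serializable): Array[Byte] = {
  val sink = new ByteArrayOutputStream()
  val out  = new ObjectOutputStream(new GZIPOutputStream(sink))
  try out.writeObject(obj) finally out.close()
  sink.toByteArray
}

// Read path reverses the layering: decompress first, then deserialize.
def unBlockify[T](bytes: Array[Byte]): T = {
  val in = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
  try in.readObject().asInstanceOf[T] finally in.close()
}

val roundTrip = unBlockify[Array[Int]](blockify(Array(1, 2, 3)))   // Array(1, 2, 3)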
Summary of the Broadcast implementation
- The driver creates the broadcast object; writeBlocks(obj) splits it into pieces and stores each piece as a block in the driver's BlockManager.
- The Broadcast handle is serialized on the driver and deserialized on the executors, where broadcastVar.value returns the value. "@transient private lazy val _value: T = readBroadcastBlock()" is excluded from serialization and evaluated lazily, whereas "private var checksums: Array[Int] = _" is serialized along with the handle from the driver.
- Each executor tries to fetch all pieces and reassemble the broadcast value. A piece is fetched from the executor's own BlockManager first; only if it is missing there is it fetched from another BlockManager.
Initially the driver is the only source of the pieces. As executors fetch different pieces from the driver (TorrentBroadcast deliberately avoids every executor fetching the pieces in the same order: Random.shuffle(Seq.range(0, numBlocks))), the number of sources grows,
and each executor can then fetch a piece from any of several sources, the driver or other executors' BlockManagers, which spreads the traffic across the cluster instead of funneling everything through the driver.
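The usage pattern this implies is to ship the broadcast handle into task closures and read value inside them, so the data itself only crosses the wire via the torrent protocol. A minimal sketch (lookupTable and bcLookup are illustrative names; sc is the SparkContext from the earlier example):
// Driver side: broadcast a small lookup table once instead of shipping it with every task
val lookupTable: Map[Int, String] = Map(1 -> "a", 2 -> "b", 3 -> "c")
val bcLookup = sc.broadcast(lookupTable)

val rdd = sc.parallelize(1 to 6)
val joined = rdd.map { k =>
  // Executor side: bcLookup.value triggers readBroadcastBlock() lazily on first use, then hits the cache
  k -> bcLookup.value.getOrElse(k, "unknown")
}.collect()

bcLookup.unpersist()   // drop executor-side copies once they are no longer needed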
Further reading
How Spark Broadcast is implemented