In Spark, the default broadcast mechanism is the torrent-style one, implemented by the TorrentBroadcast class. When a value is broadcast through the SparkContext, a unique broadcast id is generated to distinguish that broadcast variable.
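As a quick illustration, here is a minimal usage sketch (the variable names are made up, but SparkContext.broadcast() and Broadcast.value are the standard API):
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("broadcast-demo"))

// Broadcasting registers the value under a new, unique broadcast id.
val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))

// Executors read the value lazily via .value; the first access triggers
// the readBroadcastBlock() path described later in this section.
val result = sc.parallelize(Seq(1, 2)).map(k => lookup.value(k)).collect()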
During construction, TorrentBroadcast serializes the value to be broadcast and splits it into blocks via its writeBlocks() method. The actual splitting logic is implemented in the companion object's blockifyObject() method.
def blockifyObject[T: ClassTag](
    obj: T,
    blockSize: Int,
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
  val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
  val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
  val ser = serializer.newInstance()
  val serOut = ser.serializeStream(out)
  Utils.tryWithSafeFinally {
    serOut.writeObject[T](obj)
  } {
    serOut.close()
  }
  cbbos.toChunkedByteBuffer.getChunks()
}
As shown here, the broadcast value is serialized and chunked in one pass: the serializer writes into a ChunkedByteBufferOutputStream (optionally wrapped in a compression stream), and it is this output stream that performs the actual chunking.
ChunkedByteBufferOutputStream internally keeps an ArrayBuffer of ByteBuffers that holds the serialized broadcast data.
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
  require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
  var written = 0
  while (written < len) {
    allocateNewChunkIfNeeded()
    val thisBatch = math.min(chunkSize - position, len - written)
    chunks(lastChunkIndex).put(bytes, written + off, thisBatch)
    written += thisBatch
    position += thisBatch
  }
  _size += len
}

@inline
private def allocateNewChunkIfNeeded(): Unit = {
  if (position == chunkSize) {
    chunks += allocator(chunkSize)
    lastChunkIndex += 1
    position = 0
  }
}
Above is ChunkedByteBufferOutputStream's write() method. Each time the current chunk fills up, a new ByteBuffer of the chunk size (4 MB by default) is allocated and data is copied into it, until either the current ByteBuffer is exhausted (triggering allocation of the next chunk) or the serialized broadcast data runs out.
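The chunk size is controlled by the spark.broadcast.blockSize setting. The config key below is real; the payload size is made up purely to illustrate how it determines the piece count:
import org.apache.spark.SparkConf

// blockifyObject() uses spark.broadcast.blockSize as its chunk size; default "4m".
val conf = new SparkConf().set("spark.broadcast.blockSize", "4m")

// Illustration: a ~10 MB serialized payload split into 4 MB chunks
// produces ceil(10 / 4) = 3 pieces, the last one only partially filled.
val payloadBytes = 10L * 1024 * 1024
val chunkBytes = 4L * 1024 * 1024
val numPieces = ((payloadBytes + chunkBytes - 1) / chunkBytes).toInt // 3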
Next, consider TorrentBroadcast's writeBlocks() method.
private def writeBlocks(value: T): Int = {
  import StorageLevel._
  // Store a copy of the broadcast variable in the driver so that tasks run on the driver
  // do not create a duplicate copy of the broadcast variable's value.
  val blockManager = SparkEnv.get.blockManager
  if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
    throw new SparkException(s"Failed to store $broadcastId in BlockManager")
  }
  val blocks =
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  if (checksumEnabled) {
    checksums = new Array[Int](blocks.length)
  }
  blocks.zipWithIndex.foreach { case (block, i) =>
    if (checksumEnabled) {
      checksums(i) = calcChecksum(block)
    }
    val pieceId = BroadcastBlockId(id, "piece" + i)
    val bytes = new ChunkedByteBuffer(block.duplicate())
    if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
      throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
    }
  }
  blocks.length
}
Once the serialized chunks are available, the method iterates over the resulting ByteBuffers, generates a piece id for each chunk, and persists it through the BlockManager (reporting to the master so that executors can locate the pieces). In addition, a full copy of the broadcast value is stored on the driver itself, so tasks running on the driver do not need to rebuild it.
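For reference, BroadcastBlockId builds the block name by concatenating the broadcast id and the piece field, so the pieces of broadcast 0 appear in the BlockManager as broadcast_0_piece0, broadcast_0_piece1, and so on:
import org.apache.spark.storage.BroadcastBlockId

val pieceId = BroadcastBlockId(0, "piece0")
// name is "broadcast_<id>" plus "_<field>" when the field is non-empty.
println(pieceId.name) // broadcast_0_piece0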
On an executor, the broadcast value is loaded lazily through the readBroadcastBlock() method, and only when it is actually used for the first time.
@transient private lazy val _value: T = readBroadcastBlock()
readBroadcastBlock() first checks whether the value of this broadcast variable already exists locally; if so, it is returned directly. Consequently, an executor that shares a machine (and BlockManager) with the driver does not need to fetch the broadcast variable remotely. A simplified sketch of this local-first logic follows below.
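The method body is roughly the following, a simplified paraphrase of the Spark source with error handling and block-lock bookkeeping omitted:
// Simplified sketch of TorrentBroadcast.readBroadcastBlock(); not the verbatim source.
private def readBroadcastBlock(): T = TorrentBroadcast.synchronized {
  val blockManager = SparkEnv.get.blockManager
  blockManager.getLocalValues(broadcastId) match {
    case Some(blockResult) =>
      // The full value is already materialized locally (e.g. on the driver).
      blockResult.data.next().asInstanceOf[T]
    case None =>
      // Fetch the pieces, reassemble the value, and cache it locally
      // so later accesses on this executor are served from memory/disk.
      val blocks = readBlocks()
      val obj = TorrentBroadcast.unBlockifyObject[T](
        blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
      blockManager.putSingle(broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
      obj
  }
}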
If the value is not available locally, the pieces have to be fetched through the readBlocks() method.
private def readBlocks(): Array[BlockData] = {
  // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
  // to the driver, so other executors can pull these chunks from this executor as well.
  val blocks = new Array[BlockData](numBlocks)
  val bm = SparkEnv.get.blockManager
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    // First try getLocalBytes because there is a chance that previous attempts to fetch the
    // broadcast blocks have already fetched some of the blocks. In that case, some blocks
    // would be available locally (on this executor).
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            // We found the block from remote executors/driver's BlockManager, so put the block
            // in this executor's BlockManager.
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}
Here, the piece indices of the broadcast variable are shuffled into a random order and fetched one by one; the random order spreads fetch traffic across the cluster rather than having every executor request the same piece from the driver first.
If a required piece is not available locally, it is fetched remotely through the BlockManager's getRemoteBytes() method.
The BlockManager records, for every broadcast piece, the locations at which it is stored. When another node needs a piece, it obtains all known locations of that piece from the BlockManager master and tries them one after another until the data is retrieved.
Once getRemoteBytes() returns the data, the piece is persisted locally and its new location is reported to the master (tellMaster = true), so that other executors needing this piece later can download it from this executor as well, instead of only from the driver.
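Conceptually, the remote path behaves like the sketch below. master.getLocations() is the real BlockManagerMaster call; fetchFrom is a hypothetical helper standing in for the block transfer service:
// Hypothetical, heavily simplified sketch of the lookup-and-retry loop inside
// getRemoteBytes(); the real code also sorts locations and caps failed attempts.
def getRemoteBytesSketch(blockId: BlockId): Option[ChunkedByteBuffer] = {
  // Ask the BlockManager master for every node currently holding this block.
  val locations = master.getLocations(blockId)
  for (loc <- locations) {
    try {
      return Some(fetchFrom(loc, blockId)) // hypothetical fetch from one location
    } catch {
      case _: Exception => // this location failed; try the next one
    }
  }
  None
}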
After all pieces have been downloaded, they are concatenated in piece-number order and deserialized, which yields the broadcast value.
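The reassembly is performed by unBlockifyObject() in the TorrentBroadcast companion object, essentially the mirror image of blockifyObject(); the version below is close to the Spark source, lightly simplified:
import java.io.{InputStream, SequenceInputStream}
import scala.collection.JavaConverters._

def unBlockifyObject[T: ClassTag](
    blocks: Array[InputStream],
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): T = {
  require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
  // Read the pieces back-to-back, in piece-number order, as one logical stream.
  val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
  val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
  val ser = serializer.newInstance()
  val serIn = ser.deserializeStream(in)
  val obj = serIn.readObject[T]()
  serIn.close()
  obj
}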