First, let's look at how a broadcast variable is used:
val factor = List[Int](1, 2, 3)
val factorBroadcast = sc.broadcast(factor)
val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
val numsRdd = sc.parallelize(nums, 3)
val resRdd = numsRdd.mapPartitions(ite => {
// Build the buffer inside the closure so each task gets its own instance,
// rather than relying on closure serialization of a shared outer variable.
val list = new ListBuffer[List[Int]]()
while (ite.hasNext) {
list += ite.next() :: factorBroadcast.value
}
list.iterator
})
resRdd.foreach(res => println(res))
/** Output:
List(1, 1, 2, 3)
List(2, 1, 2, 3)
List(3, 1, 2, 3)
List(4, 1, 2, 3)
List(5, 1, 2, 3)
List(6, 1, 2, 3)
List(7, 1, 2, 3)
List(8, 1, 2, 3)
List(9, 1, 2, 3)
*/
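The per-partition logic above can be reproduced in plain Scala, with no cluster, to see exactly what each task computes; `grouped(3)` below is just a stand-in for the three RDD partitions created by `sc.parallelize(nums, 3)`:

```scala
// Simulate the example without Spark: split the numbers into 3 "partitions"
// and apply the same per-element transformation n :: factor.
val factor = List(1, 2, 3)
val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
val partitions = nums.grouped(3).toList        // stand-in for 3 RDD partitions
val res = partitions.flatMap { part =>
  part.map(n => n :: factor)                   // same as ite.next() :: factorBroadcast.value
}
res.foreach(println)
```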
A collection is created first and broadcast across the cluster via SparkContext's broadcast function; the broadcast handle is then read inside each partition iterator of the RDD.
Next, let's look at how a broadcast variable is created and how its data is read:
(1)SparkContext.broadcast( )
/**
* Broadcast a read-only variable to the cluster, returning an object used to
* read it in distributed functions.
* @return a handle to the read-only variable, cached once on each executor
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
//RDDs cannot be broadcast directly; instead, call collect() and broadcast the result.
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
//Delegate to BroadcastManager's newBroadcast function to perform the broadcast.
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
(2)BroadcastManager.newBroadcast( )
The function in BroadcastManager that creates broadcast variables simply delegates to the corresponding function of its broadcastFactory. The factory instance is selected by the spark.broadcast.factory configuration property and defaults to TorrentBroadcastFactory.
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
//Stamp each broadcast with a unique, monotonically increasing id.
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
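The id comes from an AtomicLong, so every broadcast created through the manager gets a unique, monotonically increasing number even under concurrent calls. A minimal sketch of this delegate-and-stamp pattern (the `Handle`/`Manager` names are illustrative, not Spark's):

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical stand-ins for Broadcast / BroadcastManager.
trait Handle[T] { def id: Long; def value: T }

class Manager {
  private val nextId = new AtomicLong(0)
  // Like BroadcastManager.newBroadcast: delegate creation, stamping a fresh id.
  def newHandle[T](v: T): Handle[T] = {
    val assigned = nextId.getAndIncrement()
    new Handle[T] { val id = assigned; val value = v }
  }
}

val m = new Manager
val a = m.newHandle("x")
val b = m.newHandle("y")
```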
(3)TorrentBroadcastFactory.newBroadcast( )
The function in TorrentBroadcastFactory that creates broadcast variables simply constructs a TorrentBroadcast instance:
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
new TorrentBroadcast[T](value_, id)
}
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
* Value of the broadcast object on executors.
* It is reconstructed by [[readBroadcastBlock]], which builds the value by
* reading blocks from the driver and/or other executors; on the driver, if
* the value is needed, it is read lazily from the block manager.
*/
@transient private lazy val _value: T = readBroadcastBlock()
// Optional compression codec for broadcast blocks (selected by spark.io.compression.codec)
@transient private var compressionCodec: Option[CompressionCodec] = _
// Size of each block; defaults to 4MB. This value is only read by the broadcaster.
@transient private var blockSize: Int = _
//Read the relevant configuration parameters
private def setConf(conf: SparkConf) {
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
} else {
None
}
blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
}
setConf(SparkEnv.get.conf)
//Unique block id for this broadcast variable
private val broadcastId = BroadcastBlockId(id)
// Total number of blocks this broadcast variable contains.
private val numBlocks: Int = writeBlocks(obj)
//--------------------
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
* @param value the object to divide
* @return the number of blocks this broadcast variable is divided into
*/
private def writeBlocks(value: T): Int = {
import StorageLevel._
//First store a copy of the broadcast value in the current node's storage,
//so that a task running on the broadcasting node can read the value
//directly from its local blockManager instead of fetching remote blocks.
val blockManager = SparkEnv.get.blockManager
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
//Serialize (and optionally compress) the value and split it into chunks, each at most blockSize bytes
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
if (checksumEnabled) {
checksums = new Array[Int](blocks.length)
}
// Iterate over the serialized, compressed blocks and store each piece in the blockManager
blocks.zipWithIndex.foreach { case (block, i) =>
if (checksumEnabled) {
checksums(i) = calcChecksum(block)
}
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
//The return value is the number of blocks the serialized, compressed value was stored as.
blocks.length
}
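The core of blockifyObject, serializing the value and cutting the resulting bytes into blockSize-sized pieces, can be sketched in plain Scala. The real implementation streams through the serializer and optional codec; this sketch just chunks an already-serialized byte array:

```scala
// Split a serialized payload into fixed-size blocks, as writeBlocks does
// before storing each piece in the BlockManager.
def blockify(payload: Array[Byte], blockSize: Int): Array[Array[Byte]] =
  payload.grouped(blockSize).toArray

// Reassembly: what unBlockifyObject does before deserializing.
def unBlockify(blocks: Array[Array[Byte]]): Array[Byte] = blocks.flatten

val payload = (0 until 10).map(_.toByte).toArray
val blocks = blockify(payload, 4)   // 3 blocks of sizes 4, 4, 2
```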
//--------------------
//In our example, reading the value via getValue triggers the lazy val defined at construction time, readBroadcastBlock:
private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
val blockManager = SparkEnv.get.blockManager
// First try to read the assembled value for this broadcast variable straight
// from the local blockManager's storage. A hit means the value was already
// read on this executor, or this task runs on the executor that broadcast it.
blockManager.getLocalValues(broadcastId) match {
case Some(blockResult) =>
if (blockResult.data.hasNext) {
val x = blockResult.data.next().asInstanceOf[T]
releaseLock(broadcastId)
x
} else {
throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
}
//The branch below runs the first time this broadcast variable is read on the
//current executor: readBlocks fetches all of the variable's blocks, they are
//deserialized, and the assembled value is stored in the local blockManager
//so that subsequent reads are served locally.
case None =>
val blocks = readBlocks()
try {
// Store the merged copy in the BlockManager so other tasks on this executor don't need to re-fetch it.
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
obj
} finally {
blocks.foreach(_.dispose())
}
}
}
}
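Stripped of BlockManager details, the read path is "check the local store, otherwise fetch, then cache". A simplified sketch of that control flow, with a mutable Map standing in for the local BlockManager and a closure standing in for readBlocks + unBlockifyObject:

```scala
import scala.collection.mutable

// localStore plays the role of the local BlockManager;
// fetchRemote plays the role of readBlocks + deserialization.
def readBroadcast[T](id: Long,
                     localStore: mutable.Map[Long, Any],
                     fetchRemote: () => T): T =
  localStore.get(id) match {
    case Some(v) => v.asInstanceOf[T]   // local hit: no network traffic
    case None =>
      val v = fetchRemote()             // first read on this executor
      localStore(id) = v                // cache for later tasks
      v
  }

val store = mutable.Map.empty[Long, Any]
var fetches = 0
val fetch = () => { fetches += 1; "payload" }
readBroadcast(7L, store, fetch)
readBroadcast(7L, store, fetch)         // second read is served locally
```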
//--------------------
/**
* Finally, let's look at how readBlocks works:
* it fetches torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[BlockData] = {
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager
// Iterate over the block ids in random order. Each block is first looked up
// locally; on a miss it is fetched from a remote blockManager (via the
// block's locations) and then cached in the local blockManager.
// Finally, the function returns the array of all fetched blocks.
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
logDebug(s"Reading piece $pieceId of $broadcastId")
// Some broadcast blocks may already have been fetched to this node, so try getLocalBytes first
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
releaseLock(pieceId)
case None =>
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
if (checksumEnabled) {
val sum = calcChecksum(b.chunks(0))
if (sum != checksums(pid)) {
throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
s" $sum != ${checksums(pid)}")
}
}
// We found the block in a remote executor's or the driver's BlockManager, so put it in this executor's BlockManager.
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
}
blocks
}
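Two details of readBlocks are worth isolating: the random visiting order (Random.shuffle) spreads fetch load across sources instead of having every executor hammer the same piece first, and the checksum guards against corrupt remote pieces (calcChecksum uses an Adler32-style checksum of the block). A sketch of both, assuming the remote pieces live in a Map keyed by piece index:

```scala
import java.util.zip.Adler32
import scala.util.Random

def checksum(bytes: Array[Byte]): Int = {
  val adler = new Adler32()
  adler.update(bytes, 0, bytes.length)
  adler.getValue.toInt
}

// Fetch all pieces in random order, verifying each checksum, then
// reassemble them by index: the same shape as readBlocks.
def fetchAll(numBlocks: Int,
             remote: Map[Int, Array[Byte]],
             sums: Array[Int]): Array[Array[Byte]] = {
  val blocks = new Array[Array[Byte]](numBlocks)
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val piece = remote(pid)
    require(checksum(piece) == sums(pid), s"corrupt block $pid")
    blocks(pid) = piece   // index by pid so order of arrival doesn't matter
  }
  blocks
}

val pieces = Map(0 -> "he".getBytes, 1 -> "ll".getBytes, 2 -> "o!".getBytes)
val sums = Array(0, 1, 2).map(i => checksum(pieces(i)))
val got = fetchAll(3, pieces, sums)
```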