A Spark broadcast variable is a read-only value shared across the nodes of a cluster. By default it is cached in memory, so reads are fast.
Spark itself uses broadcast variables in many places. For example, Spark SQL's broadcast join optimizes joins against a small table by broadcasting it, and when a job is submitted, the task dependencies are broadcast to the cluster.
Let's walk through the implementation details of Spark broadcast variables:
```
val broadcastVar = sc.broadcast(Array(1, 2, 3)) // create a broadcast variable
```
First, let's see how a broadcast variable is created and shared across the cluster.
```
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
val bc = env.broadcastManager.newBroadcast[T](value, isLocal) // the key call
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
```
The broadcast variable is created via BroadcastManager's newBroadcast() method. That method is a single line: it delegates to the newBroadcast() method of the BroadcastFactory implementation, and BroadcastFactory has only one concrete subclass, TorrentBroadcastFactory.
```
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
```
TorrentBroadcastFactory's newBroadcast() method creates a TorrentBroadcast, which implements a BitTorrent-like p2p protocol. The mechanism works as follows: the driver serializes the broadcast object, splits the serialized bytes into small blocks, and stores those blocks in the driver-side BlockManager. On each executor, the executor first tries to fetch the data from its own BlockManager. If the data is not there, it fetches it from the driver or from other executors, then stores the fetched blocks in its own BlockManager so that subsequent executors can fetch from it. This scheme spreads the driver's network transfer load across the executors. Let's look at Spark's concrete implementation:
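The protocol described above can be sketched as a toy in-memory simulation. This is a hedged illustration, not Spark's code: a plain dict stands in for each node's BlockManager, pickle for Spark's serializer, and the names (`Node`, `fetch`, `read_broadcast`) are hypothetical.

```python
import random
import pickle

BLOCK_SIZE = 4  # bytes per piece in this toy; Spark's default block size is 4 MB

def blockify(value):
    """Serialize a value and split the bytes into fixed-size pieces."""
    data = pickle.dumps(value)
    return [data[i:i + BLOCK_SIZE] for i in range(0, len(data), BLOCK_SIZE)]

class Node:
    """A driver or executor with a local block store."""
    def __init__(self, name):
        self.name = name
        self.store = {}  # piece_id -> bytes

def fetch(piece_id, local, peers):
    # 1. Try the local store first.
    if piece_id in local.store:
        return local.store[piece_id]
    # 2. Otherwise fetch from a random peer that holds the piece,
    #    then re-seed it locally so other nodes can fetch it from us.
    holders = [p for p in peers if piece_id in p.store]
    block = random.choice(holders).store[piece_id]
    local.store[piece_id] = block
    return block

def read_broadcast(num_pieces, local, peers):
    blocks = [None] * num_pieces
    # Fetch pieces in random order, like TorrentBroadcast.readBlocks().
    for pid in random.sample(range(num_pieces), num_pieces):
        blocks[pid] = fetch(("piece", pid), local, peers)
    return pickle.loads(b"".join(blocks))

# The driver seeds all pieces.
driver = Node("driver")
value = [1, 2, 3]
pieces = blockify(value)
for i, b in enumerate(pieces):
    driver.store[("piece", i)] = b

# Each executor reads the value, fetching from the driver or earlier executors.
executors = [Node(f"exec-{i}") for i in range(3)]
for ex in executors:
    peers = [driver] + [e for e in executors if e is not ex]
    assert read_broadcast(len(pieces), ex, peers) == value
```

After the loop, every executor also holds all pieces, so later readers can bypass the driver entirely; that re-seeding is what makes the scheme BitTorrent-like.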
TorrentBroadcastFactory's newBroadcast() method merely constructs a TorrentBroadcast object.
```
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
new TorrentBroadcast[T](value_, id)
}
```
Next, let's look directly at the TorrentBroadcast code that writes data to the BlockManager and the code that reads it back.
```
private def writeBlocks(value: T): Int = {
import StorageLevel._
val blockManager = SparkEnv.get.blockManager
// Store a copy of the full value, so tasks running on the driver can read it quickly
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
// Serialize the value and split it into multiple blocks
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
if (checksumEnabled) {
checksums = new Array[Int](blocks.length)
}
// Store the blocks in the BlockManager
blocks.zipWithIndex.foreach { case (block, i) =>
if (checksumEnabled) {
// If checksumming is enabled, compute an Adler-32 checksum for the block.
// Adler-32 is nearly as reliable as CRC-32, but can be computed much faster.
checksums(i) = calcChecksum(block)
}
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
// Write the broadcast piece to the driver-side BlockManager
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
blocks.length
}
```
Storing a broadcast variable is straightforward. First, a full copy is stored in the driver's BlockManager, so tasks running on the driver can read the whole object directly. The value is then split into multiple blocks, and every block is also stored in the driver's BlockManager.
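The write path above can be sketched in a simplified form. Assumptions, loudly labeled: pickle stands in for Spark's serializer, a dict for the BlockManager, compression is omitted, and `write_blocks` is a hypothetical name mirroring TorrentBroadcast.writeBlocks().

```python
import pickle
import zlib

BLOCK_SIZE = 8  # bytes in this toy; spark.broadcast.blockSize defaults to 4 MB

def write_blocks(value):
    """Serialize a value, split it into pieces, checksum each piece,
    and 'store' the pieces (a dict stands in for the BlockManager)."""
    data = pickle.dumps(value)
    blocks = [data[i:i + BLOCK_SIZE] for i in range(0, len(data), BLOCK_SIZE)]
    # Adler-32: nearly as reliable as CRC-32, but faster to compute.
    checksums = [zlib.adler32(b) for b in blocks]
    store = {("piece", i): b for i, b in enumerate(blocks)}
    return store, checksums

store, checksums = write_blocks(list(range(100)))
```

As in Spark, one checksum is kept per piece, so a corrupt remote fetch can later be detected piece by piece rather than only after reassembly.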
Now let's look at how a broadcast variable is read back.
```
private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
val blockManager = SparkEnv.get.blockManager
// First try to read the full, un-split value from the local BlockManager
blockManager.getLocalValues(broadcastId) match {
case Some(blockResult) =>
// If it is there, return it directly
if (blockResult.data.hasNext) {
val x = blockResult.data.next().asInstanceOf[T]
releaseLock(broadcastId)
x
} else {
throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
}
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
// The full, un-split value is not available locally,
// so try to read the individual pieces
val blocks = readBlocks()
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
try {
// Reassemble the fetched pieces into the full object
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
// Store the reassembled object in the BlockManager,
// so tasks running on this executor can read the full object directly
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
obj
} finally {
// Release the memory backing the pieces
blocks.foreach(_.dispose())
}
}
}
}
```
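The reassembly step (unBlockifyObject) amounts to concatenating the pieces and deserializing the result. A minimal sketch, again using pickle in place of Spark's serializer and omitting compression; `blockify`/`unblockify` are illustrative names:

```python
import pickle

BLOCK_SIZE = 8  # bytes in this toy example

def blockify(value):
    """Serialize a value and split the bytes into fixed-size pieces."""
    data = pickle.dumps(value)
    return [data[i:i + BLOCK_SIZE] for i in range(0, len(data), BLOCK_SIZE)]

def unblockify(blocks):
    # Concatenate the pieces back into one byte stream and deserialize,
    # mirroring what TorrentBroadcast.unBlockifyObject() does.
    return pickle.loads(b"".join(blocks))

value = {"a": [1, 2, 3], "b": "broadcast"}
assert unblockify(blockify(value)) == value  # round trip
```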
Next, the method that fetches the individual broadcast pieces:
```
private def readBlocks(): Array[BlockData] = {
val blocks = new Array[BlockData](numBlocks)
val bm = SparkEnv.get.blockManager
// Read all the pieces, in random order
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
logDebug(s"Reading piece $pieceId of $broadcastId")
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
releaseLock(pieceId)
case None =>
// Fetch the piece from a remote driver or executor: the BlockManager asks the driver-side BlockManager for all locations (driver or executors) holding this piece,
// then picks one of those locations at random to fetch from
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
// Piece fetched; verify its checksum
if (checksumEnabled) {
val sum = calcChecksum(b.chunks(0))
if (sum != checksums(pid)) {
throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
s" $sum != ${checksums(pid)}")
}
}
// Store the remotely fetched piece in the local BlockManager so other executors can fetch it from here
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = new ByteBufferBlockData(b, true)
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
}
blocks
}
```
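The checksum verification in the remote-fetch branch can be sketched as follows. This is a hedged illustration: zlib.adler32 plays the role of Spark's Adler-32 checksum, ValueError stands in for SparkException, and `verify` is a hypothetical helper.

```python
import pickle
import zlib

def calc_checksum(block: bytes) -> int:
    # Spark computes an Adler-32 checksum per piece when
    # spark.broadcast.checksum is enabled (the default).
    return zlib.adler32(block)

def verify(piece_id, block, expected):
    """Raise if a fetched piece does not match its recorded checksum."""
    got = calc_checksum(block)
    if got != expected:
        # Mirrors the "corrupt remote block" SparkException.
        raise ValueError(f"corrupt remote block {piece_id}: {got} != {expected}")
    return block

block = pickle.dumps([1, 2, 3])
checksum = calc_checksum(block)
verify("piece0", block, checksum)  # an intact piece passes
```

A piece corrupted in transit (say, `block + b"x"`) would fail the comparison and raise, so a bad fetch is rejected before the pieces are reassembled.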
Reading a broadcast variable is more involved. The reader first tries to read the full, un-split value from its local BlockManager; if that is missing, it tries to read the individual pieces locally; any piece still missing is fetched from a remote driver or executor. For each piece, one holder of that piece is chosen at random, which spreads the network I/O across the nodes. Every remotely fetched piece is also stored in the local BlockManager, so other executors can fetch it from this node. When the value was read as pieces, they are reassembled into the full object, and a copy of that full object is stored in the local BlockManager, so tasks running on this executor can read the broadcast variable quickly.
As a result, a broadcast variable ends up stored in two forms on each node:
- the full, un-split object, for fast reads by tasks running on that executor or driver;
- the split pieces, from which other executors can fetch individual blocks.
Writing a broadcast variable is simple: both forms are written to the local BlockManager. Reading is the complex part, and it is where the BitTorrent-like p2p protocol really comes into play.
That wraps up the analysis of Spark broadcast variables; hopefully it gives an intuitive picture of how they work.