Spark学习之11:Shuffle Read

本文描述ShuffleMapTask执行完成后,后续Stage执行时读取Shuffle Write结果的过程。涉及Shuffle Read的RDD有ShuffledRDD、CoGroupedRDD等。
发起Shuffle Read的方法是这些RDD的compute方法。下面以ShuffledRDD为例,描述Shuffle Read过程。

0. 流程图

Spark学习之11:Shuffle Read_第1张图片

1. 入口函数

Shuffle Read操作的入口是ShuffledRDD.compute方法。
  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
(1)通过SparkEnv获取ShuffleManager对象,它两个实现HashShuffleManager和SortShuffleManager,这个两个实现的getReader方法都返回HashShuffleReader对象;
(2)调用HashShuffleReader的read方法。
(3)compute方法返回的是一个迭代器,只有在涉及action或固化操作时才会具体执行用户提供的操作。

1.1. HashShuffleReader.read

  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
    }
    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
        sorter.iterator
      case None =>
        aggregatedIter
    }
  }
(1)BlockStoreShuffleFetcher是一个object,只有一个方法fetch,根据shuffleId和partition来获取对应的shuffle内容; fetch方法返回一个迭代器,遍历次迭代器就可以获取对应的数据记录;
(2)后面是依据不同的条件,构造不同的迭代器,比如是否要合并,排序等。
注:这里mapSideCombine的操作和Shuffle Write时调用的方法是不同的。
write时调用:combineValuesByKey;
read时调用:combineCombinersByKey。

2. BlockStoreShuffleFetcher

一个Shuffle Map Stage会将输出写到多个节点。由于多个ShuffleMapTask在同一节点执行,每个Task创建各自独立的Blocks,Blocks的数量取决于Reduce的数量(shuffle输出分区个数),因此一个reduce(一个分区)在一个节点上可能对应多个Block。
Map和Reduce关系示意图:
Spark学习之11:Shuffle Read_第2张图片

一个Reduce依赖所有的Map,每个Map都会输出一份数据到每一个Ruduce。可以理解为,有多少个Map,一个Reduce就对应多少个Block。
首先,需要通过调用MapOutputTracker.getServerStatuses获取reduce对应的Blocks所在的节点以及每个Block的大小。
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
  {
    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager
    val startTime = System.currentTimeMillis
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    ......
  }
调用MapOutputTracker的getServerStatuses方法。

2.1. MapOutputTracker. getServerStatuses

MapOutputTracker类定义了一个数据结构:
  protected val mapStatuses: Map[Int, Array[MapStatus]]
mapStatuses在Driver和Executor有不同的行为:
(1)在Driver端,用于记录所有ShuffleMapTask的map输出结果;
(2)在Executor端,它只作为一个缓存,如果对应数据不存在,则会从Driver端获取。
下面描述缓存没有命中,而从Driver获取的情形。
  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }
        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }
      if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes =
            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      if (fetchedStatuses != null) {
        fetchedStatuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
        }
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
      }
    }
  }
(1)fetching记录当前正在获取的ShuffleId,如果当前ShuffleId有线程正在获取则等待,如果没有其他线程在获取则将ShuffleId加入fetching队列;
(2)fetchedStatuses为null,则开始获取;
(3)调用askTracker方法,向MapOutputTrackerMasterActor发送GetMapOutputStatuses消息,askTracker返回序列化的MapStatus信息;
(4)将获取的MapStatus信息反序列化生成MapStatus对象数组;
(5)调用mapStatuses.put,将MapStatus对象存入mapStatuses缓存;
(6)调用MapOutputTracker.convertMapStatuses方法,将获取的的MapStatus转化为(BlockManagerId, BlockSize)二元组,一个BlockManagerId可能对应过个BlockSize。

2.1.1. MapOutputTrackerMasterActor处理GetMapOutputStatuses消息

    case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = sender.path.address.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
      val serializedSize = mapOutputStatuses.size
      if (serializedSize > maxAkkaFrameSize) {
        val msg = s"Map output statuses were $serializedSize bytes which " +
          s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
         * will ultimately remove this entire code path. */
        val exception = new SparkException(msg)
        logError(msg, exception)
        throw exception
      }
      sender ! mapOutputStatuses
(1)调用MapOutputTrackerMaster.getSerializedMapOutputStatuses方法,获取ShuffleId对应的序列化好的MapStatus;
(2)返回序列化好的MapStatus信息。

2.1.2 MapOutputTrackerMaster.getSerializedMapOutputStatuses

  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var epochGotten: Long = -1
    epochLock.synchronized {
      if (epoch > cacheEpoch) {
        cachedSerializedStatuses.clear()
        cacheEpoch = epoch
      }
      cachedSerializedStatuses.get(shuffleId) match {
        case Some(bytes) =>
          return bytes
        case None =>
          statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
          epochGotten = epoch
      }
    }
    // If we got here, we failed to find the serialized locations in the cache, so we pulled
    // out a snapshot of the locations as "statuses"; let's serialize and return that
    val bytes = MapOutputTracker.serializeMapStatuses(statuses)
    logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    // Add them into the table only if the epoch hasn't changed while we were working
    epochLock.synchronized {
      if (epoch == epochGotten) {
        cachedSerializedStatuses(shuffleId) = bytes
      }
    }
    bytes
  }
(1)判断缓存是否过期,如过期则清除;
(2)从缓存中读取数据,如果缓存中没有则从mapStatuses中读取,缓存中有则直接返回;
(3)将获取的MapStatus序列化并存入缓存。

2.1.3. MapOutputTracker.convertMapStatuses

  private def convertMapStatuses(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
    assert (statuses != null)
    statuses.map {
      status =>
        if (status == null) {
          logError("Missing an output location for shuffle " + shuffleId)
          throw new MetadataFetchFailedException(
            shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
        } else {
          (status.location, status.getSizeForBlock(reduceId))
        }
    }
  }
将每个MapStatus转换成一个 (BlockManagerId, BlockSize)二元组,因此一个BlockManagerId可能对应多个BlockSize,也就是说一个BlockManagerId在数组中会出现多次。
注:BlockSize并不代表Block的实际大小。MapStatus有两个实现:CompressedMapStatus和HighlyCompressedMapStatus。
其中, CompressedMapStatus存储的Block大小是经过压缩处理的,不能还原成原值;
当Shuffle的输出分区超过20000(spark1.3)时,采用HighlyCompressedMapStatus,它保存的Block大小的平均值。

2.2. 构建ShuffleBlockId映射

获取到Reudce对应的所有Block的位置及大小信息后,BlockStoreShuffleFetcher.fetch方法开始构建ShuffleBlockId映射。
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }
    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }
(1)statuses的类型为Array[(BlockManagerId, Long)],其中BlockMangerId代表block所在的位置,Long表示Block的大小;
(2)for循环将statuses转换成[BlockManagerId,ArrayBuffer[(Int, Long)]]结构,它表示在BlockManagerId上,一个ruduce对应多个Block,其中Int表示statuses的下标索引,Long表示Block的大小;
(3)创建BlockManagerId与ShuffleBlockId的映射;由于statuses中的记录是按Map编号(即partition编号)从小到排列的(具体可参考DAGScheduler.handleTaskCompletion方法中调用Stage.addOutputLoc方法及MapOutputTracker.registerMapOutputs方法),其下标索引代表了partition编号,因此在这儿可以利用for循环保存的下标索引来创建ShuffleBlockId对象。
到此就完成了BlockManagerId到Seq[(BlockId, Long)]的映射;BlockId代表ShuffleBlockId,Long表示对应Block的大小。

2.3. 创建ShuffleBlockFetcherIterator对象

构建完ShuffleBlockId映射后, BlockStoreShuffleFetcher.fetch方法开始创建ShuffleBlockFetcherIterator对象。
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      SparkEnv.get.blockManager.shuffleClient,
      blockManager,
      blocksByAddress,
      serializer,
      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
从类名可知,该对象是一个迭代器。在构造体中会调用自身的initialize方法。

2.3.1. ShuffleBlockFetcherIterator.initialize

  private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())
    // Split local and remote blocks.
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    // Send out initial requests for blocks, up to our maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }
    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    // Get Local Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
(1)调用splitLocalRemoteBlocks方法,根据BlockManagerId来判断Block位于本地还是远端; splitLocalRemoteBlocks方法,会将每个位于远端的[BlockerManagerId, Seq[(BlockId, Long)]]封装成多个FetchRequest对象,对象的数量根据Long值的和以及 maxBytesInFlight参数来控制;
(2)将 splitLocalRemoteBlocks返回的 FetchRequest数组随机化,并加入fetchRequests队列;
(3)调用sendRequest方法发出远端读取Block请求,while循环会根据maxBytesInFlight来控制发出远程请求的数量,剩余的请求会在next方法中执行
(4)调用fetchLocalBlocks方法,从本地读取Block。

2.3.2. ShuffleBlockFetcherIterator.sendRequest

  private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val blockIds = req.blocks.map(_._1.toString)
    val address = req.address
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
            shuffleMetrics.incRemoteBytesRead(buf.size)
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), e))
        }
      }
    )
  }
该方法负责读取Remote Block。通过ShuffleClient对象,具体实现是NettyBlockTransferService,通过fetchBlocks方法来读取Block;读取成功后, NettyBlockTransferService回调onBlockFetchSuccess方法,将结果封装成SuccessFetchResult对象,并压入results队列。

2.3.3. ShuffleBlockFetcherIterator.fetchLocalBlocks

  private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, 0, buf))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, e))
          return
      }
    }
  }
该方法负责读取本地block,并将结构封装成SuccessFetchResult对象压入results队列。

2.4. 返回迭代器

当ShuffleBlockFetcherIterator构造完成后,会对该对象进行处理并封装进InterruptibleIterator对象返回。
    val itr = blockFetcherItr.flatMap(unpackBlock)
    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      context.taskMetrics.updateShuffleReadMetrics()
    })
    new InterruptibleIterator[T](context, completionIter) {
      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
      override def next(): T = {
        readMetrics.incRecordsRead(1)
        delegate.next()
      }
    }

你可能感兴趣的:(Spark)