The shuffle data read path of a Spark join

In Spark, when data has to be shuffled, how does the RDD that pulls the data get linked to the ShuffleMapTasks that produced it? CoGroupedRDD reads the data of a given partition with the following call:

SparkEnv.get.shuffleManager
  .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
  .read()

This call tells the shuffle manager which dependency's output to read and which slice of it (here startPartition = split.index and endPartition = split.index + 1, i.e. a single reduce partition). It then creates a BlockStoreShuffleReader, whose read() method runs the following:

// Read this reduce partition's shuffle blocks and apply any aggregation on them
val blockFetcherItr = new ShuffleBlockFetcherIterator(
  context,
  blockManager.shuffleClient,
  blockManager,
  mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
  // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
  SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

As the constructor arguments show, the reader first asks the mapOutputTracker for the locations (and sizes) of the map outputs belonging to this partition range:

def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  // Fetch the map statuses for this shuffle (from the local cache or from the driver)
  val statuses = getStatuses(shuffleId)
  // Synchronize on the returned array because, on the driver, it gets mutated in place
  statuses.synchronized {
    return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
  }
}

getStatuses then issues a remote call (unless the statuses are already cached locally) to fetch this shuffle's output location data:

try {
  // Ask the tracker endpoint for this shuffle's map statuses
  val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
  // Each status records where a map task wrote its output and how large each block is
  fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
  logInfo("Got the output locations")
  mapStatuses.put(shuffleId, fetchedStatuses)
}

On the driver, the MapOutputTrackerMasterEndpoint that serves MapOutputTrackerMaster receives the message in receiveAndReply:

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case GetMapOutputStatuses(shuffleId: Int) =>
    // A reducer is asking for the output locations of this shuffle
    val hostPort = context.senderAddress.hostPort
    logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
    // Look the serialized statuses up in the tracker
    val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
    val serializedSize = mapOutputStatuses.length
    if (serializedSize > maxAkkaFrameSize) {
      val msg = s"Map output statuses were $serializedSize bytes which " +
        s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."

      /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
       * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
      val exception = new SparkException(msg)
      logError(msg, exception)
      context.sendFailure(exception)
    } else {
      context.reply(mapOutputStatuses)
    }

The tracker holds the shuffle's map-side results. They are stored by the DAGScheduler when it handles the completion of each ShuffleMapTask:

case smt: ShuffleMapTask =>
  val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
  updateAccumulators(event)
  val status = event.result.asInstanceOf[MapStatus]
  val execId = status.location.executorId
  logDebug("ShuffleMapTask finished on " + execId)
  if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
    // The executor failed after this task was launched, so its result may be stale
    logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
  } else {
    // Record the MapStatus (output location and block sizes) for this map partition
    shuffleStage.addOutputLoc(smt.partitionId, status)
  }

  if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
    markStageAsFinished(shuffleStage)
    logInfo("looking for newly runnable stages")
    logInfo("running: " + runningStages)
    logInfo("waiting: " + waitingStages)
    logInfo("failed: " + failedStages)

    // We supply true to increment the epoch number here in case this is a
    // recomputation of the map outputs. In that case, some nodes may have cached
    // locations with holes (from when we detected the error) and will need the
    // epoch incremented to refetch them.
    // TODO: Only increment the epoch number if this is not the first time
    //       we registered these map outputs.
    // Register this shuffle's map outputs with the MapOutputTracker
    mapOutputTracker.registerMapOutputs(
      shuffleStage.shuffleDep.shuffleId,
      shuffleStage.outputLocInMapOutputTrackerFormat(),
      changeEpoch = true)

For details see 《SPARK TASK 任务状态管理》. When a ShuffleMapTask completes, the DAGScheduler writes the output locations for its shuffleId into the mapOutputTracker, and from then on any other component can request that data.
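
Conceptually, the tracker's bookkeeping is just a map from shuffleId to one MapStatus per map task. The following is a minimal sketch of that shape (my own simplification with assumed names, not Spark's real classes; the actual implementation uses different concurrent maps, compressed block sizes and an epoch counter):

// Simplified sketch of the tracker's data model (assumed names, not the real classes):
// for every shuffleId, one entry per map task recording where it ran and how large each
// of its per-reducer blocks is.
import scala.collection.concurrent.TrieMap

final case class SketchMapStatus(location: String, blockSizes: Array[Long])

object SketchMapOutputTracker {
  private val mapStatuses = TrieMap.empty[Int, Array[SketchMapStatus]]

  // called (conceptually) when the DAGScheduler registers a finished shuffle map stage
  def registerMapOutputs(shuffleId: Int, statuses: Array[SketchMapStatus]): Unit =
    mapStatuses.put(shuffleId, statuses)

  // called (conceptually) by reducers asking where to fetch their partition from
  def getStatuses(shuffleId: Int): Option[Array[SketchMapStatus]] =
    mapStatuses.get(shuffleId)
}

On the reducer side, MapOutputTracker.convertMapStatuses then turns these per-map statuses into a per-location list of the blocks a reducer needs: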

private def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  assert (statuses != null)
  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
  // statuses holds one entry per map task, with the location and block sizes of its output
  for ((status, mapId) <- statuses.zipWithIndex) {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
    } else {
      for (part <- startPartition until endPartition) {
        // collect, per location, the shuffle blocks of the requested reduce partitions
        splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
          ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
      }
    }
  }

  splitsByAddress.toSeq
}
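
For intuition, here is a hedged illustration (the values are invented, not from the source) of the shape this method returns: with shuffleId 0, two map tasks that both ran on executor exec-1, and a reducer fetching only partition 2, the two ShuffleBlockIds end up grouped under a single BlockManagerId.

// Hypothetical example values; ShuffleBlockId names follow shuffle_<shuffleId>_<mapId>_<reduceId>
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

val location = BlockManagerId("exec-1", "host-a", 7337)
val mapSizes: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = Seq(
  location -> Seq(
    (ShuffleBlockId(0, 0, 2), 1024L), // map task 0's block for reduce partition 2
    (ShuffleBlockId(0, 1, 2), 2048L)  // map task 1's block for reduce partition 2
  )
)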

With the block list per address in hand, ShuffleBlockFetcherIterator can fetch the data of the requested partitions:

private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size

  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)
  // Fetch the blocks from this address; each block id is a ShuffleBlockId, which already
  // names the shuffle, the map task and the reduce partition being requested
  val address = req.address
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // (success handling elided in this excerpt: the fetched buffer is queued into
        // `results` as a SuccessFetchResult)
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}
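
One related detail, stated with some hedging: before sendRequest is called, the iterator splits the remote blocks into FetchRequests of roughly maxBytesInFlight / 5 bytes each, so that blocks can be pulled from several addresses in parallel while bytesInFlight stays under the cap. Back of the envelope:

// With the default spark.reducer.maxSizeInFlight = 48m,
// each FetchRequest targets about 48 MB / 5 ≈ 9.6 MB of blocks.
val maxBytesInFlight = 48L * 1024 * 1024
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
println(s"target request size ≈ ${targetRequestSize / 1024.0 / 1024.0} MB")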

The shuffleClient (a NettyBlockTransferService) then performs the remote fetch:

override def fetchBlocks(
    host: String,
    port: Int,
    execId: String,
    blockIds: Array[String],
    listener: BlockFetchingListener): Unit = {
  logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
  // fetch the block files over netty
  try {
    val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
      override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
        // create a netty client and start fetching the blocks
        val client = clientFactory.createClient(host, port)
        new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
      }
    }
    // if retries are configured, wrap the starter in a retrying fetcher
    val maxRetries = transportConf.maxIORetries()
    if (maxRetries > 0) {
      // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
      // a bug in this code. We should remove the if statement once we're sure of the stability.
      new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
    } else {
      blockFetchStarter.createAndStart(blockIds, listener)
    }
  } catch {
    case e: Exception =>
      logError("Exception while beginning fetchBlocks", e)
      blockIds.foreach(listener.onBlockFetchFailure(_, e))
  }
}
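
The retry behaviour is driven by the transport configuration; for the shuffle module, transportConf.maxIORetries() corresponds to spark.shuffle.io.maxRetries. A hedged configuration sketch (the values are illustrative only):

import org.apache.spark.SparkConf

// Illustrative values only
val conf = new SparkConf()
  .set("spark.shuffle.io.maxRetries", "5")      // > 0 wraps the fetch in a RetryingBlockFetcher
  .set("spark.shuffle.io.retryWait", "10s")     // wait between retry attempts
  .set("spark.reducer.maxSizeInFlight", "96m")  // cap on bytes fetched concurrently (used by the reader above)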

Once the blocks have been fetched, the resulting iterator goes through further processing. Back in BlockStoreShuffleReader:

override def read(): Iterator[Product2[K, C]] = {
  // Read this reduce partition's shuffle blocks into an iterator and apply any aggregation on it
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

  // Wrap the streams for compression based on configuration
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    blockManager.wrapForCompression(blockId, inputStream)
  }

  val ser = Serializer.getSerializer(dep.serializer)
  val serializerInstance = ser.newInstance()

  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  // (the metrics count how many records were read)
  val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map(record => {
      readMetrics.incRecordsRead(1)
      record
    }),
    context.taskMetrics().updateShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      // The map side already produced partial combiners, so only combiners are merged here
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      // No map-side combine: all of the raw values are aggregated here during the shuffle read
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // Sort the output if there is a sort ordering defined.
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
      // the ExternalSorter won't spill to disk.
      // Sort the aggregated records by key
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.internalMetricsToAccumulators(
        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }
}

After the fetched data has been turned into an iterator, the reader checks whether the dependency defines an aggregator and, if so, applies it. Aggregation therefore runs while the shuffle partition is being read, which cuts down the amount of data handed to the downstream computation: operations such as count or sum can effectively be evaluated at this point, so only the per-key results need to travel onward instead of every raw record, as the sketch below illustrates.
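
A minimal, hedged sketch (a standalone word-count job, not taken from the article's code): reduceByKey installs an Aggregator with mapSideCombine = true, so each map task writes per-key partial counts and the reader above only has to merge those combiners via combineCombinersByKey.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadAggregationExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("shuffle-read-agg").setMaster("local[2]"))
    val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
    // Word count: partial per-key counts are combined on the map side,
    // then merged again while the shuffle blocks are being read.
    val counts = words.map((_, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println) // (a,3), (b,2), (c,1) in some order
    sc.stop()
  }
}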

Similarly, if the shuffled data has to be sorted (keyOrdering is defined), the reader creates an ExternalSorter right here and returns the sorted data. Both the aggregate function and the key ordering are thus applied during this shuffle read phase.
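
Where does keyOrdering come from? One common source is sortByKey, which builds a ShuffledRDD with setKeyOrdering(...); the per-partition sort above then runs during the shuffle read, and combined with range partitioning the output is globally ordered. A spark-shell style sketch (hedged; `sc` is the shell's SparkContext):

// spark-shell sketch; `sc` is provided by the shell
val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
val sorted = pairs.sortByKey()      // the shuffle dependency now carries an Ordering[Int]
sorted.collect().foreach(println)   // (1,a) (2,b) (3,c)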

Summary

  1. A request is made to read a specific split (partition) of the shuffled data.
  2. The MapOutputTrackerMaster is asked for the output locations of that shuffleId.
  3. The blocks of the requested partition are fetched from those locations over netty.
  4. The aggregate function, if one is defined, is applied to the fetched records.
  5. If a keyOrdering is defined, the aggregated iterator is sorted by key and then returned.

A minimal job that exercises all five steps is sketched below.
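
This is a hedged, spark-shell style example (`sc` is the shell's SparkContext); join is implemented on top of cogroup/CoGroupedRDD, so computing each output partition runs exactly the read sequence above, from getReader through the netty fetch to aggregation.

// spark-shell sketch; `sc` is provided by the shell
val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))
val orders = sc.parallelize(Seq((1, "book"), (1, "pen"), (2, "desk")))
val joined = users.join(orders)        // cogroup-based join -> shuffle on the key
joined.collect().foreach(println)      // (1,(alice,book)), (1,(alice,pen)), (2,(bob,desk)) in some order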
