Shuffle Source Code Analysis: Shuffle Write and Shuffle Read

Step 1: HashShuffleWriter.scala

  /**
   * Write the data in this ShuffleMapTask's partition to disk.
   * @param records the records of this partition
   */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    // Decide whether to aggregate on the map side.
    // For operations like reduceByKey, dep.aggregator.isDefined is true, so a map-side (local) combine is performed here.
    val iter = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // Perform the local combine here:
        // e.g. (hello,1) and (hello,1) are combined into (hello,2).
        dep.aggregator.get.combineValuesByKey(records, context)
      } else {
        records
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      records
    }
    // After the optional local combine, iterate over the records and call the partitioner
    // (HashPartitioner by default) on each key to get a bucketId, i.e. which bucket the record
    // is written to. Records with the same key always go to the same bucket and are therefore
    // fetched by the same ResultTask (see the bucket-id sketch after this method).
    for (elem <- iter) {
      val bucketId = dep.partitioner.getPartition(elem._1)
      // The writer for this bucketId was created by FileShuffleBlockResolver.forMapTask()
      // (analyzed in Step 2); use it to write the record into that bucket's file.
      shuffle.writers(bucketId).write(elem._1, elem._2)
    }
  }
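To make the key-to-bucket mapping above concrete, here is a minimal, Spark-free sketch of how a hash partitioner assigns keys to buckets. The nonNegativeMod helper mirrors the idea behind Spark's HashPartitioner (non-negative hashCode modulo the partition count); numReducers and the sample records are made-up values for illustration.

  object BucketIdDemo {
    // Same idea as HashPartitioner: non-negative modulo of the key's hashCode.
    def nonNegativeMod(x: Int, mod: Int): Int = {
      val raw = x % mod
      if (raw < 0) raw + mod else raw
    }

    def main(args: Array[String]): Unit = {
      val numReducers = 3  // hypothetical number of reduce tasks
      val records = Seq("hello" -> 1, "hello" -> 1, "world" -> 1)
      records.foreach { case (key, value) =>
        val bucketId = nonNegativeMod(key.hashCode, numReducers)
        println(s"($key, $value) -> bucket $bucketId")
      }
      // Records with the same key always hash to the same bucket,
      // so a single ResultTask will see every value for that key.
    }
  }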

Step 2: FileShuffleBlockResolver.scala

  /**
   * Get a ShuffleWriterGroup for the given map task, containing one writer per reducer.
   * @param shuffleId
   * @param mapId
   * @param numReducers
   * @param serializer
   * @param writeMetrics
   * @return
   */
  def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
      writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
    new ShuffleWriterGroup {
      shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
      private val shuffleState = shuffleStates(shuffleId)

      val openStartTime = System.nanoTime
      val serializerInstance = serializer.newInstance()
      // This is the key part. Earlier hash-shuffle versions had two modes: the basic one
      // (one file per map task per reducer) and an optimized consolidation mode. The
      // consolidation mode no longer exists in this version, so only the path below needs
      // to be analyzed: it creates one DiskBlockObjectWriter per reducer for this map task
      // (see the block-id sketch after this excerpt).
      val writers: Array[DiskBlockObjectWriter] = {
        Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
          // shuffleId, mapId and bucketId together form a unique ShuffleBlockId.
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          val blockFile = blockManager.diskBlockManager.getFile(blockId)
          val tmp = Utils.tempFileWith(blockFile)
          // Get a disk writer for this bucket's block file.
          blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
        }
      }
      // ... (rest of forMapTask omitted)
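As a rough illustration of what forMapTask sets up, the sketch below lists the block ids a single map task would open, one per reducer. It follows the "shuffle_<shuffleId>_<mapId>_<reduceId>" naming convention of ShuffleBlockId, but the ShuffleBlockId class here is a simplified stand-in and the concrete ids are hypothetical.

  object ShuffleBlockIdDemo {
    // Simplified stand-in mirroring ShuffleBlockId.name:
    // "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
    case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
      def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
    }

    def main(args: Array[String]): Unit = {
      val shuffleId = 0      // hypothetical shuffle id
      val mapId = 2          // hypothetical map task id
      val numReducers = 3    // hypothetical reducer count

      // One writer, and hence one block file, per reducer for this map task.
      val blockIds = Array.tabulate(numReducers)(bucketId => ShuffleBlockId(shuffleId, mapId, bucketId))
      blockIds.foreach(id => println(id.name))
      // prints: shuffle_0_2_0, shuffle_0_2_1, shuffle_0_2_2
    }
  }

With M map tasks and R reducers, this basic layout produces M × R shuffle files in total, which is the main scalability problem of hash-based shuffle.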

Shuffle Read

Step 3: ShuffledRDD.scala

  // This is where shuffle read starts.
  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // Get a reader for this task's single partition range [split.index, split.index + 1)
    // and call read() to fetch the data this ResultTask needs to aggregate
    // (a driver-side usage example follows this snippet).
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
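For context on when this compute() runs: a wide transformation such as reduceByKey produces a ShuffledRDD, and each reduce task computes one of its partitions through the method above, which in turn goes through the read() path analyzed in Step 4. Below is a minimal driver-side sketch, assuming a local-mode Spark setup; the application name and master URL are arbitrary.

  import org.apache.spark.{SparkConf, SparkContext}

  object ShuffleReadDemo {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(
        new SparkConf().setMaster("local[2]").setAppName("shuffle-read-demo"))
      // reduceByKey builds a ShuffledRDD; computing each of its partitions goes through
      // ShuffledRDD.compute() and therefore through the shuffle read path of this article.
      val counts = sc.parallelize(Seq("hello", "hello", "world"))
        .map(word => (word, 1))
        .reduceByKey(_ + _)
      counts.collect().foreach(println)  // e.g. (hello,2), (world,1)
      sc.stop()
    }
  }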

Step 4: BlockStoreShuffleReader.scala

  /** Read the combined key-values for this reduce task */
  // Fetch the shuffled data that the map tasks wrote to disk
  // (a sketch of the two combine paths follows this excerpt).
  override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // Wrap the streams for compression based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }

    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
}
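To make the two combine branches in read() concrete, here is a small, Spark-free sketch of the difference between combineValuesByKey (map-side combine was off, so raw values arrive) and combineCombinersByKey (map-side combine was on, so already-combined values arrive), using a word-count style aggregator. The function names and the Map-based implementation are simplified stand-ins, not Spark's actual Aggregator API, which aggregates through a spillable ExternalAppendOnlyMap.

  object CombineDemo {
    // Simplified stand-ins for an Aggregator's three functions in a word count.
    val createCombiner: Int => Int = v => v        // turn the first value for a key into a combiner
    val mergeValue: (Int, Int) => Int = _ + _      // fold another value into an existing combiner
    val mergeCombiners: (Int, Int) => Int = _ + _  // merge two map-side combiners

    // Reduce side, map-side combine disabled: raw (key, value) records arrive.
    def combineValuesByKey(records: Iterator[(String, Int)]): Map[String, Int] =
      records.foldLeft(Map.empty[String, Int]) { case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v)))
      }

    // Reduce side, map-side combine enabled: already-combined (key, combiner) records arrive.
    def combineCombinersByKey(records: Iterator[(String, Int)]): Map[String, Int] =
      records.foldLeft(Map.empty[String, Int]) { case (acc, (k, c)) =>
        acc.updated(k, acc.get(k).map(mergeCombiners(_, c)).getOrElse(c))
      }

    def main(args: Array[String]): Unit = {
      println(combineValuesByKey(Iterator("hello" -> 1, "hello" -> 1, "world" -> 1)))
      // Map(hello -> 2, world -> 1)
      println(combineCombinersByKey(Iterator("hello" -> 2, "world" -> 1, "hello" -> 3)))
      // Map(hello -> 5, world -> 1)
    }
  }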

 
