Spark RDD之Partition

RDD概述

RDD是一个抽象类,主要包含五个部分:

  1. partitions列表
  2. 计算每一个split的函数
  3. 依赖rdd的列表(dependencies)
  4. 键值对rdd的partitioner
  5. 计算每个split的首选位置列表

其中最后两个部分是可选的,以上五个部分对应着五个方法:

  @DeveloperApi
  def compute(split: Partition, context: TaskContext): Iterator[T]

  protected def getPartitions: Array[Partition]

  protected def getDependencies: Seq[Dependency[_]] = deps

  protected def getPreferredLocations(split: Partition): Seq[String] = Nil

  @transient val partitioner: Option[Partitioner] = None

Partition

一份待处理的原始数据会被按照相应的逻辑(例如jdbc和hdfs的split逻辑)切分成n份,每份数据对应到RDD中的一个Partition,Partition的数量决定了task的数量,影响着程序的并行度。Partition的源码如下:

trait Partition extends Serializable {

  def index: Int

  override def hashCode(): Int = index

  override def equals(other: Any): Boolean = super.equals(other)
}

Partition和RDD是伴生的,即每一种RDD都有其对应的Partition实现,所以,分析Partition主要是分析其子类,我们关注两个常用的子类,JdbcPartition和HadoopPartition。

JdbcPartition

JdbcPartition类包含于JdbcRDD.scala文件,它继承了Partition类,多加了两个长整型属性lower和upper。

private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
  override def index: Int = idx
}

JdbcPartition被定义为一个半私有类,只有父包和子包可以访问,对外暴露的接口是jdbcRDD,该RDD定义如下:

class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging

其中有两个重要方法getPartitionscompute

  override def getPartitions: Array[Partition] = {
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong)
    }.toArray
  }

getPartitions主要做的操作是将数据根据传入的numPartitions参数进行分片并封装成多个JdbcPartition类返回。

override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
  {
    context.addTaskCompletionListener{ context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    stmt.setLong(1, part.lower)
    stmt.setLong(2, part.upper)
    val rs = stmt.executeQuery()
}

compute将传入的Partition强行转化为JdbcPartition,连接数据库并且对sql语句进行预处理,将JdbcPartition中的lower和upper作为数据库中id的上下界传入,并执行查询语句。也就是说,每次调用compute时,不会对所有数据进行操作,而是只对数据的一部分(也就是一个Partition)进行操作。

HadoopPartition

HadoopPartition与Partition相比,增加了一个Int整型和一个InputSplit类型(Hadoop中的数据分片,可以理解为另一类Partition)的数据,重写了hashCode和equals方法,并增加了一个获取Hadoop环境参数列表的方法getPipeEnvVars()

private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
  extends Partition {

  val inputSplit = new SerializableWritable[InputSplit](s)

  override def hashCode(): Int = 31 * (31 + rddId) + index

  override def equals(other: Any): Boolean = super.equals(other)

  def getPipeEnvVars(): Map[String, String] = {
    val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
      val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]

      Map("map_input_file" -> is.getPath().toString(),
        "mapreduce_map_input_file" -> is.getPath().toString())
    } else {
      Map()
    }
    envVars
  }
}

与JdbcPartition相类似,HadoopPartition也被定义为一个半私有类,只有父包和子包可以访问,对外暴露的接口是HadoopRDD,其中的获取分片和计算的方法如下:

  override def getPartitions: Array[Partition] = {
    val jobConf = getJobConf()
  
    SparkHadoopUtil.get.addCredentials(jobConf)
    val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
    val inputSplits = if (ignoreEmptySplits) {
      allInputSplits.filter(_.getLength > 0)
    } else {
      allInputSplits
    }
    val array = new Array[Partition](inputSplits.size)
    for (i <- 0 until inputSplits.size) {
      array(i) = new HadoopPartition(id, i, inputSplits(i))
    }
    array
  }

getPartitions只是将Hadoop中的数据分片通过一系列处理之后,遍历得到Spark中的Partition,并且设置了hash值的偏移量i。

  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    val iter = new NextIterator[(K, V)] {

      private val split = theSplit.asInstanceOf[HadoopPartition]
      private var reader: RecordReader[K, V] = null
      private val inputFormat = getInputFormat(jobConf)
      reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

      private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
      private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()

      override def getNext(): (K, V) = {
        try {
          finished = !reader.next(key, value)
        } catch {
        }
        (key, value)
      }
    }
    new InterruptibleIterator[(K, V)](context, iter)
  }

getPartitions将传入的Partition强行转化为HadoopPartition,调用Hadoop的API将InputSplit转化为InputFormat,获取到InputFormat之后通过配置Reader读取inputFormat中的数据并返回一个迭代器。

partition的数量如果在初始化SparkContext时没有指定,则默认读取spark.default.parallelism中的配置,也可以通过传参指定例如上述的JdbcPartition,如果都都没有设置,则根据输入数据的分片数量来决定。同时Transformation也会影响partition的数目,例如union则是两个rdd的partition相加,filter、map则是继承父RDD的partition数,intersection是取两者最大。

Partition数量的影响:

Partition数量太少:资源不能充分利用,例如local模式下,有16core,但是Partition数量仅为8的话,有一半的core没利用到。
Partition数量太多:资源利用没问题,但是导致task过多,task的序列化和传输的时间开销增大。

你可能感兴趣的:(Spark)