RDD是一个抽象类,主要包含五个部分:
其中最后两个部分是可选的,以上五个部分对应着五个方法:
@DeveloperApi
def compute(split: Partition, context: TaskContext): Iterator[T]
protected def getPartitions: Array[Partition]
protected def getDependencies: Seq[Dependency[_]] = deps
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
@transient val partitioner: Option[Partitioner] = None
一份待处理的原始数据会被按照相应的逻辑(例如jdbc和hdfs的split逻辑)切分成n份,每份数据对应到RDD中的一个Partition,Partition的数量决定了task的数量,影响着程序的并行度。Partition的源码如下:
trait Partition extends Serializable {
def index: Int
override def hashCode(): Int = index
override def equals(other: Any): Boolean = super.equals(other)
}
Partition和RDD是伴生的,即每一种RDD都有其对应的Partition实现,所以,分析Partition主要是分析其子类,我们关注两个常用的子类,JdbcPartition和HadoopPartition。
JdbcPartition类包含于JdbcRDD.scala文件,它继承了Partition类,多加了两个长整型属性lower和upper。
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index: Int = idx
}
JdbcPartition被定义为一个半私有类,只有父包和子包可以访问,对外暴露的接口是jdbcRDD,该RDD定义如下:
class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
extends RDD[T](sc, Nil) with Logging
其中有两个重要方法getPartitions
和compute
:
override def getPartitions: Array[Partition] = {
val length = BigInt(1) + upperBound - lowerBound
(0 until numPartitions).map { i =>
val start = lowerBound + ((i * length) / numPartitions)
val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
new JdbcPartition(i, start.toLong, end.toLong)
}.toArray
}
getPartitions
主要做的操作是将数据根据传入的numPartitions参数进行分片并封装成多个JdbcPartition类返回。
override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
{
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
}
compute
将传入的Partition强行转化为JdbcPartition,连接数据库并且对sql语句进行预处理,将JdbcPartition中的lower和upper作为数据库中id的上下界传入,并执行查询语句。也就是说,每次调用compute时,不会对所有数据进行操作,而是只对数据的一部分(也就是一个Partition)进行操作。
HadoopPartition与Partition相比,增加了一个Int整型和一个InputSplit类型(Hadoop中的数据分片,可以理解为另一类Partition)的数据,重写了hashCode和equals方法,并增加了一个获取Hadoop环境参数列表的方法getPipeEnvVars()
:
private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = 31 * (31 + rddId) + index
override def equals(other: Any): Boolean = super.equals(other)
def getPipeEnvVars(): Map[String, String] = {
val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
Map("map_input_file" -> is.getPath().toString(),
"mapreduce_map_input_file" -> is.getPath().toString())
} else {
Map()
}
envVars
}
}
与JdbcPartition相类似,HadoopPartition也被定义为一个半私有类,只有父包和子包可以访问,对外暴露的接口是HadoopRDD,其中的获取分片和计算的方法如下:
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
SparkHadoopUtil.get.addCredentials(jobConf)
val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
val inputSplits = if (ignoreEmptySplits) {
allInputSplits.filter(_.getLength > 0)
} else {
allInputSplits
}
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
array(i) = new HadoopPartition(id, i, inputSplits(i))
}
array
}
getPartitions
只是将Hadoop中的数据分片通过一系列处理之后,遍历得到Spark中的Partition,并且设置了hash值的偏移量i。
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {
private val split = theSplit.asInstanceOf[HadoopPartition]
private var reader: RecordReader[K, V] = null
private val inputFormat = getInputFormat(jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
override def getNext(): (K, V) = {
try {
finished = !reader.next(key, value)
} catch {
}
(key, value)
}
}
new InterruptibleIterator[(K, V)](context, iter)
}
getPartitions
将传入的Partition强行转化为HadoopPartition,调用Hadoop的API将InputSplit转化为InputFormat,获取到InputFormat之后通过配置Reader读取inputFormat中的数据并返回一个迭代器。
partition的数量如果在初始化SparkContext时没有指定,则默认读取spark.default.parallelism
中的配置,也可以通过传参指定例如上述的JdbcPartition,如果都都没有设置,则根据输入数据的分片数量来决定。同时Transformation也会影响partition的数目,例如union则是两个rdd的partition相加,filter、map则是继承父RDD的partition数,intersection是取两者最大。
Partition数量太少:资源不能充分利用,例如local模式下,有16core,但是Partition数量仅为8的话,有一半的core没利用到。
Partition数量太多:资源利用没问题,但是导致task过多,task的序列化和传输的时间开销增大。