Spark数据分区(partitionBy分区、partitioner获取分区方式、自定义分区)

数据分区

partitionBy分区

在分布式程序中,通信的代价是很大的,因此控制数据分布以获得最少的网络传输可以极大地提升整体性能。和单节点的程序需要为记录集合选择合适的数据结构一样,Spark 程序可以通过控制RDD 分区方式来减少通信开销。分区并不是对所有应用都有好处的——比如,如果给定RDD 只需要被扫描一次,我们完全没有必要对其预先进行分区处理。只有当数据集多次在诸如连接这种基于键的操作中使用时,分区才会有帮助。

Spark 中所有的键值对RDD 都可以进行分区。系统会根据一个针对键的函数对元素进行分组。Spark 可以确保同一组的键出现在同一个节点上。比如,你可能使用哈希分区将一个RDD 分成了100 个分区,此时键的哈希值对100 取模的结果相同的记录会被放在一个节点上。你也可以使用范围分区法,将键在同一个范围区间内的记录都放在同一个节点上。

举个简单的例子,我们分析这样一个应用,它在内存中保存着一张很大的用户信息表——也就是一个由(UserID, UserInfo) 对组成的RDD,其中UserInfo 包含一个该用户所订阅的主题的列表。该应用会周期性地将这张表与一个小文件进行组合,这个小文件中存着过去五分钟内发生的事件——其实就是一个由(UserID, LinkInfo) 对组成的表,存放着过去五分钟内某网站各用户的访问情况。例如,我们可能需要对用户访问其未订阅主题的页面的情况进行统计。我们可以使用Spark 的join() 操作来实现这个组合操作,其中需要把UserInfo 和LinkInfo 的有序对根据UserID 进行分组。

// 初始化代码;从HDFS商的一个Hadoop SequenceFile中读取用户信息
// userData中的元素会根据它们被读取时的来源,即HDFS块所在的节点来分布
// Spark此时无法获知某个特定的UserID对应的记录位于哪个节点上
val sc = new SparkContext(...)
val userData = sc.sequenceFile[UserID, UserInfo]("hdfs://...").persist()
// 周期性调用函数来处理过去五分钟产生的事件日志
// 假设这是一个包含(UserID, LinkInfo)对的SequenceFile
def processNewLogs(logFileName: String) {
val events = sc.sequenceFile[UserID, LinkInfo](logFileName)
val joined = userData.join(events)// RDD of (UserID, (UserInfo, LinkInfo)) pairs
val offTopicVisits = joined.filter {
case (userId, (userInfo, linkInfo)) => // Expand the tuple into its components
!userInfo.topics.contains(linkInfo.topic)
}.count()
println("Number of visits to non-subscribed topics: " + offTopicVisits)
}

这段代码可以正确运行,但是不够高效。这是因为在每次调用processNewLogs() 时都会用到join() 操作,而我们对数据集是如何分区的却一无所知。默认情况下,连接操作会将两个数据集中的所有键的哈希值都求出来,将该哈希值相同的记录通过网络传到同一台机器上,然后在那台机器上对所有键相同的记录进行连接操作。因为userData 表比每五分钟出现的访问日志表events 要大得多,所以要浪费时间做很多额外工作:在每次调用时都对userData 表进行哈希值计算和跨节点数据混洗,虽然这些数据从来都不会变化。
Spark数据分区(partitionBy分区、partitioner获取分区方式、自定义分区)_第1张图片要解决这一问题也很简单:在程序开始时,对userData 表使用partitionBy() 转化操作,将这张表转为哈希分区。可以通过向partitionBy 传递一个spark.HashPartitioner 对象来实现该操作:

// scala自定义分区方式
val sc = new SparkContext(...)
val userData = sc.sequenceFile[UserID, UserInfo]("hdfs://...")
.partitionBy(new HashPartitioner(100)) // 构造100个分区
.persist()

processNewLogs() 方法可以保持不变: 在processNewLogs() 中,eventsRDD 是本地变量,只在该方法中使用了一次,所以为events 指定分区方式没有什么用处。由于在构建userData 时调用了partitionBy(),Spark 就知道了该RDD 是根据键的哈希值来分区的,这样在调用join() 时,Spark 就会利用到这一点。具体来说,当调用userData.
join(events) 时,Spark 只会对events 进行数据混洗操作,将events 中特定UserID 的记录发送到userData 的对应分区所在的那台机器上。这样,需要通过网络传输的数据就大大减少了,程序运行速度也可以显著提升了。

注意,partitionBy() 是一个转化操作,因此它的返回值总是一个新的RDD,但它不会改变原来的RDD。RDD 一旦创建就无法修改。因此应该对partitionBy() 的结果进行持久化,并保存为userData,而不是原来的sequenceFile() 的输出。此外,传给partitionBy() 的100 表示分区数目,它会控制之后对这个RDD 进行进一步操作(比如连接操作)时有多少任务会并行执行。总的来说,这个值至少应该和集群中的总核心数一样。

Spark数据分区(partitionBy分区、partitioner获取分区方式、自定义分区)_第2张图片
如果没有将partitionBy() 转化操作的结果持久化,那么后面每次用到这个RDD 时都会重复地对数据进行分区操作。不进行持久化会导致整个RDD 谱系图重新求值。那样的话,partitionBy() 带来的好处就会被抵消,导致重复对数据进行分区以及跨节点的混洗,和没有指定分区方式时发生的情况十分相似。

获取RDD的分区方式

在Scala 和Java 中,你可以使用RDD 的partitioner 属性(Java 中使用partitioner() 方法)来获取RDD 的分区方式。它会返回一个scala.Option 对象,这是Scala 中用来存放可能存在的对象的容器类。你可以对这个Option 对象调用isDefined() 来检查其中是否有值,调用get() 来获取其中的值。如果存在值的话,这个值会是一spark.Partitioner
对象。

scala> val pairs = sc.parallelize(List((1, 1), (2, 2), (3, 3)))
pairs: spark.RDD[(Int, Int)] = ParallelCollectionRDD[0] at parallelize at <console>:12

scala> pairs.partitioner
res0: Option[spark.Partitioner] = None

scala> val partitioned = pairs.partitionBy(new spark.HashPartitioner(2))
partitioned: spark.RDD[(Int, Int)] = ShuffledRDD[1] at partitionBy at <console>:14

scala> partitioned.partitioner
res1: Option[spark.Partitioner] = Some(spark.HashPartitioner@5147788d)

在这段简短的代码中,我们创建出了一个由(Int, Int) 对组成的RDD,初始时没有分
区方式信息(一个值为None 的Option 对象)。然后通过对第一个RDD 进行哈希分区,创建出了第二个RDD。如果确实要在后续操作中使用partitioned,那就应当在定义partitioned 时,在第三行输入的最后加上persist()。这和之前的例子中需要对userData调用persist() 的原因是一样的:如果不调用persist() 的话,后续的RDD 操作会对partitioned 的整个谱系重新求值,这会导致对pairs 一遍又一遍地进行哈希分区操作。

从分区中获益的操作

Spark 的许多操作都引入了将数据根据键跨节点进行混洗的过程。所有这些操作都会
从数据分区中获益。就Spark 1.0 而言,能够从数据分区中获益的操作有cogroup()、
groupWith()、join()、leftOuterJoin()、rightOuterJoin()、groupByKey()、reduceByKey()、combineByKey() 以及lookup()。

对于像reduceByKey() 这样只作用于单个RDD 的操作,运行在未分区的RDD 上的时候会导致每个键的所有对应值都在每台机器上进行本地计算,只需要把本地最终归约出的结果值从各工作节点传回主节点,所以原本的网络开销就不算大。而对于诸如cogroup() 和join() 这样的二元操作,预先进行数据分区会导致其中至少一个RDD(使用已知分区器的那个RDD)不发生数据混洗。如果两个RDD 使用同样的分区方式,并且它们还缓存在同样的机器上(比如一个RDD 是通过mapValues() 从另一个RDD 中创建出来的,这两个RDD 就会拥有相同的键和分区方式),或者其中一个RDD 还没有被计算出来,那么跨节点的数据混洗就不会发生了。

影响分区方式的操作

Spark 内部知道各操作会如何影响分区方式,并将会对数据进行分区的操作的结果RDD 自动设置为对应的分区器。例如,如果你调用join() 来连接两个RDD;由于键相同的元素会被哈希到同一台机器上,Spark 知道输出结果也是哈希分区的,这样对连接的结果进行诸如reduceByKey() 这样的操作时就会明显变快。

不过,转化操作的结果并不一定会按已知的分区方式分区,这时输出的RDD 可能就会没有设置分区器。例如,当你对一个哈希分区的键值对RDD 调用map() 时,由于传给map()的函数理论上可以改变元素的键,因此结果就不会有固定的分区方式。Spark 不会分析你的函数来判断键是否会被保留下来。不过,Spark 提供了另外两个操作mapValues() 和flatMapValues() 作为替代方法,它们可以保证每个二元组的键保持不变。

这里列出了所有会为生成的结果RDD 设好分区方式的操作:cogroup()、groupWith()、
join()、lef tOuterJoin()、rightOuterJoin()、groupByKey()、reduceByKey()、
combineByKey()、partitionBy()、sort()、mapValues()(如果父RDD 有分区方式的话)、flatMapValues()(如果父RDD 有分区方式的话),以及filter()(如果父RDD 有分区方式的话)。其他所有的操作生成的结果都不会存在特定的分区方式。

最后,对于二元操作,输出数据的分区方式取决于父RDD 的分区方式。默认情况下,结果会采用哈希分区,分区的数量和操作的并行度一样。不过,如果其中的一个父RDD 已经设置过分区方式,那么结果就会采用那种分区方式;如果两个父RDD 都设置过分区方式,结果RDD 会采用第一个父RDD 的分区方式。

自定义分区方式

虽然Spark 提供的HashPartitioner 与RangePartitioner 已经能够满足大多数用例,但
Spark 还是允许你通过提供一个自定义的Partitioner 对象来控制RDD 的分区方式。这可以让你利用领域知识进一步减少通信开销。

举个例子,假设我们要在一个网页的集合上运行前一节中的PageRank 算法。在这里,每个页面的ID(RDD 中的键)是页面的URL。当我们使用简单的哈希函数进行分区时,拥有相似的URL 的页面(比如http://www.cnn.com/WORLD 和http://www.cnn.com/US)可能会被分到完全不同的节点上。然而,我们知道在同一个域名下的网页更有可能相互链接。由于PageRank 需要在每次迭代中从每个页面向它所有相邻的页面发送一条消息,因此把这些页面分组到同一个分区中会更好。可以使用自定义的分区器来实现仅根据域名而不是整个URL 来分区。

要实现自定义的分区器,你需要继承org.apache.spark.Partitioner 类并实现下面三个方法。

  • numPartitions: Int:返回创建出来的分区数。
  • getPartition(key: Any): Int:返回给定键的分区编号(0 到numPartitions-1)。
  • equals():Java 判断相等性的标准方法。这个方法的实现非常重要,Spark 需要用这个方法来检查你的分区器对象是否和其他分区器实例相同,这样Spark 才可以判断两个RDD 的分区方式是否相同。

有一个问题需要注意,当你的算法依赖于Java 的hashCode() 方法时,这个方法有可能会返回负数。你需要十分谨慎,确保getPartition() 永远返回一个非负数。

下面展示了如何编写一个前面构思的基于域名的分区器,这个分区器只对URL 中的域
名部分求哈希。

class DomainNamePartitioner(numParts: Int) extends Partitioner {
	override def numPartitions: Int = numParts
	override def getPartition(key: Any): Int = {
		val domain = new Java.net.URL(key.toString).getHost()
		val code = (domain.hashCode % numPartitions)
		if(code < 0) {
			code + numPartitions // 使其非负
		}else{
			code
		}
	}
// 用来让Spark区分分区函数对象的Java equals方法
	override def equals(other: Any): Boolean = other match {
		case dnp: DomainNamePartitioner =>
			dnp.numPartitions == numPartitions
		case _ =>
			false
	}
}

注意,在equals() 方法中,使用Scala 的模式匹配操作符(match)来检查other 是否是DomainNamePartitioner,并在成立时自动进行类型转换;这和Java 中的instanceof() 是一样的。

参考 《Spark快速大数据分析》

你可能感兴趣的:(spark)