spark 自定义Partitioner

在对RDD数据进行分区时,默认使用的是 HashPartitioner,该partitioner对key进行哈希,然后mod 上分区数目,mod的结果相同的就会被分到同一个partition中
如果嫌HashPartitioner 功能单一,可以自定义partitioner

自定义partitioner

1.继承org.apache.spark.Partitioner
2.重写numPartitions方法
3.重写getPartition方法

举个例子

有一些RDD数据,要根据其key的长短来分partition,比如key="a"的都存在一个partition中
可以这样写

class Mypartitioner1(val num:Int) extends org.apache.spark.Partitioner{
  override def numPartitions: Int = num
  override def getPartition(key: Any): Int = {
      val len  =key.toString.length
      len % num
  }
}

def main(args: Array[String]): Unit = {
      val spark =SparkSession.builder().config(new SparkConf()).getOrCreate()
      val sc =spark.sparkContext
      val data =sc.parallelize(Array(
        ("aaa",2),("aaa",3),("aaa",1),("aaa",0),("aaa",4),
        ("aa",2),("aa",3),("aa",1),("aa",0),("aa",4),
        ("a",2),("a",3),("a",1),("a",0),("a",4)
      ))
       data.partitionBy(new Mypartitioner2(3))
      .saveAsTextFile("develop/wangdaopeng/lab1")

mod方法太泛了,如果想精确点可以这么写

class Mypartitioner2( num:Int) extends org.apache.spark.Partitioner{
  override def numPartitions: Int = num

  override def getPartition(key: Any): Int = {
      if(key.toString.size == 3){
           2
      }
      else if(key.toString.size == 2){
         1
      }
      else {
        0
      }
  }
}

要注意这里的0 1 2为分区的ID,范围一定是0到numPartitions-1,不然会报异常

你可能感兴趣的:(Spark)