spark自定义分区案例

     在hadoop的mapreduce中默认patitioner是HashPartitioner,我们可以自定义Partitioner可以有效防止数据倾斜, 在Spark里面也是一样,在Spark里也是默认的HashPartitioner, 如果自己想自己定义Partitioner继承org.apache.spark里面的Partitioner并且重写它里面的两个方法就行了。

模板如下:

//只需要继承Partitioner,重写两个方法

class MyPartitioner(val num: Int) extends Partitioner {

      //这里定义partitioner个数

      override def numPartitions: Int = ???

      //这里定义分区规则

      override def getPartition(key: Any): Int = ???

}

 

案例1:单词统计

object xy {

    def main(args: Array[String]): Unit = {

        val conf = new SparkConf().setAppName("urlLocal").setMaster("local[2]")

        val sc = new SparkContext(conf)

        val rdd1 = sc.parallelize(List("lijie hello lisi", "zhangsan wangwu mazi", "hehe haha nihaoa heihei lure hehe hello word"))

        val rdd2 = rdd1.flatMap(_.split(" ")).map(x => { (x, 1) }).reduceByKey(_ + _)

        //这里指定自定义分区,然后输出

        val rdd3 = rdd2.sortBy(_._2).partitionBy(new MyPartitioner(4)).mapPartitions(x => x)

            .saveAsTextFile("C:\\Users\\Administrator\\Desktop\\out01")

        println(rdd2.collect().toBuffer)

        sc.stop()

      }

}

 

class MyPartitioner(val num: Int) extends Partitioner {

     override def numPartitions: Int = num override

     def getPartition(key: Any): Int = { val len = key.toString.length

           //根据单词长度对分区个数取模

           len % num

      }

}

案例来源:https://blog.csdn.net/qq_20641565/article/details/76130724

 

 

案例2:统计网址

package day02

import java.net.URL

import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}

import scala.collection.mutable

object UserD_Partitioner {

def main(args: Array[String]) {

     val conf = new SparkConf().setAppName("UserD_Partitioner").setMaster("local[2]")

     val sc = new SparkContext(conf)

     //rdd1将数据切分,元组中放的是(URL, 1)

     val rdd1 = sc.textFile("c://itcast.log").map(line => {

     val f = line.split("\t")

     (f(1), 1)

     })

     val rdd2 = rdd1.reduceByKey(_ + _)

     val rdd3 = rdd2.map(t => {

          val url = t._1

          val host = new URL(url).getHost

          (host, (url, t._2))

     })

     val ints = rdd3.map(_._1).distinct().collect()

     val hostParitioner = new HostParitioner(ints)

     //val rdd4 = rdd3.partitionBy(new HashPartitioner(ints.length))

     val rdd4 = rdd3.partitionBy(hostParitioner).mapPartitions(it => {

          it.toList.sortBy(_._2._2).reverse.take(2).iterator

     })

     rdd4.saveAsTextFile("c://out4")

     //println(rdd4.collect().toBuffer)

     sc.stop()

     }

}

/** 自定义分区:

* 决定了数据到哪个分区里面

* @param ins

*/

class HostParitioner(ins: Array[String]) extends Partitioner {

     val parMap = new mutable.HashMap[String, Int]()

     var count = 0

     for(i <- ins){

          parMap += (i -> count)

          count += 1

     }

     //获取分区数量

     override def numPartitions: Int = ins.length

     //数据分区规则

     override def getPartition(key: Any): Int = {

          parMap.getOrElse(key.toString, 0)

     }

}

案例来源:https://blog.csdn.net/freefish_yzx/article/details/77542526

 

Partitioner抽象类:

package org.apache.spark

/**

* An object that defines how the elements in a key-value pair RDD are partitioned by key.

* Maps each key to a partition ID, from 0 to `numPartitions - 1`.

*/

abstract class Partitioner extends Serializable {

     def numPartitions: Int

def getPartition(key: Any): Int

}

 

自定义分区,要继承Partitioner抽象类,重写里面的numPartitions和getPartitioner方法。

def numPartitions : Int      获取分区数量

def getPartitioner(key : Any) : Int       获取定义分区规则

传入一个key,返回一个Int类型的value。这个函数需要对输入的key做计算,然后返回该key的分区ID,范围一定是0到numPartitions-1。

注意:在这里,key就是对(根据)什么进行操作(分区),什么就是key。具体取决于什么方法使用该分区类,该方法获取到的Key是什么。

 

案例3:根据学科分区,传入的是一个去重后的学科数组。

/**

* 自定义分区器

* @param subjects

*/

class SubjectPartitioner2(val subjects:Array[String]) extends Partitioner{

     //主构造器里面的代码,new的时候就立即执行。

 

     //分区规则 HashMap(学科,编号) 编号为:0 ---> 学科数量-1

     val rules = new mutable.HashMap[String,Int]()

     var i= 0

 

     //自定义学科的编号

     for(sub <- subjects){

          rules(sub) = i //等价于:rules += (sub -> i)

          i += 1

     }

 

     /**

     * 获取分区的数量(在这里即为学科数量)

     * @return

     */

     override def numPartitions: Int = subjects.length

 

     /**

     * 数据分区的规则(传入一个key,返回一个Int类型的value)

     * def getPartition(key: Any): Int:这个函数需要对输入的key做计算,

     * 然后返回该key的分区ID,范围一定是0到numPartitions-1。

     * @param key

     * @return

     */

     override def getPartition(key: Any): Int = {

     //强转asInstanceOf

     val tuple: (String, String) = key.asInstanceOf[Tuple2[String,String]]

     val sub = tuple._1 //取出元组里面的学科

     rules(sub)

     //rules(key.toString)

}

 

 

3个应用分区类的算子:

reduceByKey()

reduceByKey()的三种参数形式:

reduceByKey(func) 函数

reduceByKey(func,numPartitions) 函数,分区数量

reduceByKey(partitioner,func) 分区器,函数

 

partitionBy()

如:partitionBy(new partitioner)

 

repartitionAndSortWithinPartitions()

如:repartitionAndSortWithinPartitions(new partitioner)

 

案例3详细说明:

数据样式:

http://UI.test.cn/laowang

http://php.test.cn/laoli

http://U-3D.test.cn/laowang

要求:统计各个学科内点击次数topN的老师

代码:

package lwj.sparkDay2

import java.net.URL

import org.apache.log4j.{Level, Logger}

import org.apache.spark.rdd.RDD

import org.apache.spark.{Partitioner, SparkConf, SparkContext}

import scala.collection.mutable

 

object FavTeacherInSubject4 {

def main(args: Array[String]): Unit = {

     //设置日志打印级别(可选)

     Logger.getLogger("org").setLevel(Level.ERROR)

     //1、设置配置信息

     val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[2]")

     //2、获取SparkContext上下文

     val sc: SparkContext = new SparkContext(conf)

     //3、读取文件

     val lines: RDD[String] = sc.textFile("C://Users//xxx//Desktop//1.log")

 

     //4、数据清洗方法:(数据格式:http://bigdata.test.cn/laozhang)

     val subjectTeacherAndOne: RDD[((String, String), Int)] = lines.map(line => {

          val index1: Int = line.lastIndexOf("/")

          val teacherName: String = line.substring(index1 + 1)

          val host: String = new URL(line).getHost //URL解析,获取网址bigdata.test.cn/laozhang

          val index2: Int = host.indexOf(".") //int indexOf(String str) :返回第一次出现的指定子字符串在此字符串中的索引。

          val subject: String = host.substring(0, index2)

          ((subject, teacherName), 1)

     })

 

     //先触发任务,计算有多少个学科

     val subjectRDD: RDD[String] = subjectTeacherAndOne.map(_._1._1).distinct()

     //触发计算,获得有多少个具体的学科

     val subjects: Array[String] = subjectRDD.collect()

 

     //先分区再聚合

     val reduced: RDD[((String, String), Int)] = subjectTeacherAndOne.reduceByKey(new SubjectPartitioner1(subjects),_+_)

 

     //通过自定义分区器将相同学科的数据都放在一个分区当中

     //val partitionesRDD: RDD[((String, String), Int)] = reduced.partitionBy(new SubjectPartitioner(subjects))

 

     //再排序(mapPartitions():一个分区一个分区的拿,传入一个迭代器,返回一个迭代器。)

     val sorted: RDD[((String, String), Int)] = reduced.mapPartitions(_.toList.sortBy(x => -x._2).take(2).iterator)

 

     //收集结果,打印

     val rules: Array[((String, String), Int)] = sorted.collect()

     println(rules.toBuffer)

 

     //释放资源

     sc.stop()

     }

}

 

     /**

     * 自定义分区器

     * @param subjects

     */

class SubjectPartitioner1(val subjects:Array[String]) extends Partitioner{

     //分区规则 HashMap(学科,编号)

     val rules = new mutable.HashMap[String,Int]()

     var i= 0

 

     //自定义学科的编号

     for(sub <- subjects){

          rules(sub) = i

          i += 1

     }

 

     /**

     * 获取分区的数量(在这里即为学科数量)

     * @return

     */

     override def numPartitions: Int = subjects.length

 

     /**

     * 数据分区的规则(传入一个key,返回一个Int类型的value)

     * @param key

     * @return

     */

     override def getPartition(key: Any): Int = {

          //强转asInstanceOf

          val tuple: (String, String) = key.asInstanceOf[Tuple2[String,String]]

          val sub = tuple._1 //取出元组里面的学科

          rules(sub)

     }

}

 

你可能感兴趣的:(小白笔记)