import org.apache.spark.{SparkConf, SparkContext}
object GroupNormal {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupNormal")
val sc = new SparkContext(conf)
// The real data could be hundreds of millions of rows; this is just a small mock sample
val dataRDD = sc.parallelize(List(
("hello", 2),
("java", 7),
("where", 1),
("rust", 2),
// ... many more rows omitted ...
("scala", 1),
("java", 1),
("black", 9)
))
// Do a word count
val result = dataRDD.groupByKey()
.mapValues(_.sum)
.sortBy(_._2, false)
result.take(10).foreach(println)
sc.stop()
}
}
// Change the groupByKey code above to reduceByKey
val result = dataRDD
.reduceByKey(_ + _)
.sortBy(_._2, false)
result.take(10).foreach(println)
Next, consider a student score table, tb_student_score, with the layout below; the task is to find, for each year, the 100 students with the highest math scores.
id | chinese | math | english | year |
---|---|---|---|---|
3412312 | 121 | 115 | 134 | 2018 |
5231211 | 103 | 131 | 114 | 2010 |
…… | …… | …… | …… | …… |
2342354 | 134 | 105 | 124 | 2014 |
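If you want to experiment with the snippets below without a Hive table, here is a minimal sketch (an assumption for local testing, not part of the original setup) that builds a DataFrame with the same columns and registers it under the table name tb_student_score:

import org.apache.spark.sql.SparkSession

object MakeTestTable {
  def main(args: Array[String]): Unit = {
    // Local session purely for trying out the snippets in this article
    val spark = SparkSession.builder()
      .appName("MakeTestTable")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // A few rows taken from the table above
    val studentDF = Seq(
      (3412312L, 121, 115, 134, 2018),
      (5231211L, 103, 131, 114, 2010),
      (2342354L, 134, 105, 124, 2014)
    ).toDF("id", "chinese", "math", "english", "year")

    // Register it under the table name used by the later examples
    studentDF.createOrReplaceTempView("tb_student_score")
    spark.sql("SELECT * FROM tb_student_score").show()

    spark.stop()
  }
}

When experimenting locally against this temporary view, you would typically also drop enableHiveSupport() and add .master("local[*]") to the SparkSession builders in the other examples.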
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* An example of grouping data the wrong way
*
* @author ALion
* @version 2019/5/15 22:33
*/
object GroupDemo {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
// Load the raw data
val studentDF = spark.sql(
"""
|SELECT *
|FROM tb_student_score
|WHERE id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL
""".stripMargin)
// Start the analysis
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
})
.groupByKey() // group by year
.mapValues(_.toSeq.sortWith(_._2 > _._2).take(100)) // sort each year's students by math in descending order and keep the top 100
// Trigger an action and show part of the result
resultRDD.take(10).foreach(println)
spark.stop()
}
}
spark.sql("SELECT * FROM tb_student_score")
Reading the table in this form is relatively slow; there is a faster way:
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
val tbSchema = StructType(Array(
StructField("id", LongType, true),
StructField("chinese", IntegerType, true),
StructField("math", IntegerType, true),
StructField("english", IntegerType, true),
StructField("year", IntegerType, true)
))
// Load the raw data
val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
.where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
// Start the analysis
val resultRDD = studentDF.rdd
.mapPartitions {
// For best performance you would not write this in such a functional style when implementing it yourself;
// it is kept this way here for readability
_.map { row =>
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
}.toArray
.groupBy(_._1) // within each partition, first take the per-year top 100 in math, to reduce the data shuffled by the later groupByKey
.mapValues(_.map(_._2).sortWith(_._2 > _._2).take(100))
.toIterator
}.groupByKey() // finally gather every partition's top 100, sort again, and compute the overall top 100
.mapValues(_.flatten.toSeq.sortWith(_._2 > _._2).take(100))
// Trigger an action and show part of the result
resultRDD.take(10).foreach(println)
The code above gets the job done. But is it the best we can do? Not quite. The sorting is still done per partition, a single Executor holds many partitions, and every partition ships its own top 100 rows through the shuffle; ideally an Executor would shuffle only 100 rows in total. It would be great if there were a way to operate on all the data sitting on an Executor at once and extract just its top 100.
However, Spark does not provide an operator like mapPartitions that works across all the partitions of an Executor at once (if it did, we could compute each Executor's top 100 the same way we just computed each partition's). What we do have is reduceByKey, which combines values for the same key within each map task before shuffling them to the task that performs the final merge. That behavior looks very close to what we need, even if it is not quite the same thing.
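To make the shuffle argument concrete, here is a back-of-the-envelope comparison with made-up cluster numbers (these figures are assumptions for illustration, not from the article):

// Hypothetical cluster: 50 executors, 20 partitions per executor, one year key considered
val executors = 50
val partitionsPerExecutor = 20
val perPartitionTop100 = executors * partitionsPerExecutor * 100 // 100,000 rows shuffled
val perExecutorTop100 = executors * 100                          // 5,000 rows would suffice
println(s"per-partition top-100 shuffles $perPartitionTop100 rows; " +
  s"a per-executor top-100 would shuffle only $perExecutorTop100 rows")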
You might object that reduceByKey merges values, while what we need is sorting. True, at first glance that looks like a contradiction.
In fact, it can be made to work: if every value is a collection that stays sorted and keeps only its top 100 elements, then "merging" two such collections is exactly the combine step reduceByKey needs. A ready-made collection like that does not seem to exist in Scala or Java, but TreeSet does keep its elements sorted, so we can merge two TreeSets and take the top 100 by hand after each merge. (You could also implement such a collection yourself, which we will do further below.)
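As a quick illustration of why this merge is safe (a standalone sketch with made-up numbers and helper names, not part of the original code):

import scala.collection.immutable.TreeSet

object MergeTopNSketch {
  def main(args: Array[String]): Unit = {
    val topN = 3
    // Reverse the Int ordering so that higher scores come first
    implicit val desc: Ordering[Int] = Ordering.Int.reverse

    // Merge two "top N so far" sets and keep only the top N of the union;
    // any element of the global top N is necessarily in the top N of its own partition,
    // so merging partial results this way preserves the global top N.
    def mergeTopN(a: TreeSet[Int], b: TreeSet[Int]): TreeSet[Int] = (a ++ b).take(topN)

    val partition1 = TreeSet(108, 116, 121)
    val partition2 = TreeSet(125, 99, 119)
    println(mergeTopN(partition1, partition2)) // TreeSet(125, 121, 119)
  }
}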
Step 1: wrap id and math into a class, and implement equals, hashCode, and compareTo for it so that ordering inside the TreeSet behaves correctly. Also implement toString so the printed results are easy to read.
public class Person implements Comparable<Person>, Serializable {
private long id;
private int math;
public Person(long id, int math) {
this.id = id;
this.math = math;
}
@Override
public int compareTo(Person person) {
int result = person.math - this.math; // descending by math
if (result == 0) {
// break ties by id; return 0 only when the ids are equal, keeping compareTo consistent with equals
result = Long.compare(person.id, this.id);
}
return result;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Person person = (Person) o;
return id == person.id;
}
@Override
public int hashCode() {
return (int) (id ^ (id >>> 32));
}
@Override
public String toString() {
return "Person{" +
"id='" + id + '\'' +
", math=" + math +
'}';
}
}
import scala.collection.immutable.TreeSet
object Demo {
def main(args: Array[String]): Unit = {
val set = TreeSet[Person](
new Person(1231232L, 108),
new Person(3214124L, 116),
new Person(1321313L, 121),
new Person(6435235L, 125)
)
// Take the top 3
for (elem <- set.take(3)) {
println(s"--> elem = $elem")
}
}
}
Step 2: wrap the original id and math into a single-element TreeSet for each record:
studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
Step 3: use reduceByKey to merge these sets, keeping only the top 100 after each merge.
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
.reduceByKey((set1, set2) => set1 ++ set2 take 100) // merge the two sets and keep only the top 100
resultRDD.take(10).foreach(println)
As mentioned earlier, we can also implement such a collection ourselves: a sorted set that automatically keeps only the first firstNum elements as items are added.
import scala.collection.mutable
class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) extends Serializable {
val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)
def +=(elem: A): MyTreeSet[A] = {
this add elem
this
}
def add(elem: A): Unit = {
set.add(elem)
// drop the extra element that sorts last
check10Size()
}
def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
that.set.foreach(e => this add e)
this
}
def check10Size(): Unit = {
// if the size exceeds firstNum, evict the last element
if (set.size > firstNum) {
set -= set.last
}
}
override def toString: String = set.toString
}
object MyTreeSet {
def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem) // keep the first 10 by default
def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
}
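Before wiring MyTreeSet into Spark, here is a quick standalone check of its behavior (a hypothetical demo, not from the original article), reusing the Person class defined above; because Person sorts by math in descending order, the set always keeps the highest scores seen so far:

object MyTreeSetDemo {
  def main(args: Array[String]): Unit = {
    // Cap of 2, seeded with two students
    val top2 = MyTreeSet[Person](2, new Person(1L, 90), new Person(2L, 95))
    top2 += new Person(3L, 99) // the Person with math=90 sorts last and is evicted
    top2 += new Person(4L, 80) // evicted immediately, 80 is not among the top 2
    println(top2) // TreeSet(Person{id='3', math=99}, Person{id='2', math=95})
  }
}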
With MyTreeSet in place, reduceByKey can be replaced by aggregateByKey, using a MyTreeSet capped at 100 as the zero value:
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, new Person(id, math))
}).aggregateByKey(MyTreeSet[Person](100)) (
(set, v) => set += v,
(set1, set2) => set1 ++= set2
)
Putting it all together:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import scala.collection.immutable.TreeSet
object GroupDemo3 {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
val tbSchema = StructType(Array(
StructField("id", LongType, true),
StructField("chinese", IntegerType, true),
StructField("math", IntegerType, true),
StructField("english", IntegerType, true),
StructField("year", IntegerType, true)
))
// Load the raw data
val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
.where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
// Start the analysis
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, new Person(id, math))
}).aggregateByKey(MyTreeSet[Person](100)) (
(set, v) => set += v,
(set1, set2) => set1 ++= set2
) // MyTreeSet keeps only the top 100 as elements and sets are merged
// Trigger an action and show part of the result
resultRDD.take(10).foreach(println)
spark.stop()
}
}
If you prefer to stay in Scala, the Person class can also be written as a Scala class:
class PersonScala(val id: Long, val math: Int) extends Ordered[PersonScala] with Serializable {
override def compare(that: PersonScala): Int = {
var result = that.math - this.math // descending by math
if (result == 0)
result = that.id compare this.id // break ties by id; 0 only when ids are equal, consistent with equals
result
}
override def equals(obj: Any): Boolean = {
obj match {
case person: PersonScala => this.id == person.id
case _ => false
}
}
override def hashCode(): Int = (id ^ (id >>> 32)).toInt
override def toString: String = "Person{" + "id=" + id + ", math=" + math + '}'
}
object PersonScala {
def apply(id: Long, math: Int): PersonScala = new PersonScala(id, math)
}
Finally, the same analysis can also be done directly in Spark SQL with a window function:
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
// Run the analysis with SQL
val resultDF = spark.sql(
"""
|SELECT year,id,math
|FROM (
| SELECT year,id,math,ROW_NUMBER() OVER (PARTITION BY year ORDER BY math DESC) rank
| FROM tb_student_score
|) g
|WHERE g.rank <= 100
""".stripMargin)
// Trigger an action and show part of the result
resultDF.show()
spark.stop()
}
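For comparison, here is a rough DataFrame API equivalent of the SQL above (a sketch that assumes the same SparkSession and table are available; it is not from the original article):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Rank students within each year by math score, descending, and keep the top 100 per year
val byYear = Window.partitionBy("year").orderBy(col("math").desc)
val resultDF = spark.table("tb_student_score")
  .withColumn("rank", row_number().over(byYear))
  .where(col("rank") <= 100)
  .select("year", "id", "math")
resultDF.show()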