/**
 * A student's score in one subject, ordered so that the best score sorts first.
 *
 * Ordering (see [[compare]]): higher `grade` first; among equal grades, higher `id` first.
 * Equality and hashing look at `id` only, so two `Person`s with the same id but different
 * grades are `equals` yet do not compare as 0.
 *
 * @param id    identifier of the person
 * @param grade score of the person
 */
class Person(val id: Long, val grade: Int) extends Ordered[Person] with Serializable {

  /**
   * Compares by grade descending, then by id descending.
   *
   * Uses `Integer.compare` / `Long.compare` instead of subtraction so that extreme values
   * cannot overflow, and returns 0 when both grade and id match (the comparison contract
   * requires `x.compare(x) == 0`, otherwise sorted sets cannot recognise duplicates).
   *
   * @param that the person to compare with
   * @return negative if this sorts before `that`, positive if after, 0 if grade and id are equal
   */
  override def compare(that: Person): Int = {
    val byGrade = Integer.compare(that.grade, this.grade)
    if (byGrade != 0) byGrade
    else java.lang.Long.compare(that.id, this.id)
  }

  /** Two persons are equal when their ids are equal; the grade is ignored. */
  override def equals(obj: Any): Boolean = obj match {
    case other: Person => other.id == this.id
    case _ => false
  }

  /** Hash of the id only (same value as `(id ^ (id >>> 32)).toInt`), consistent with [[equals]]. */
  override def hashCode(): Int = java.lang.Long.hashCode(id)

  override def toString: String = s"Person{id=$id, grade=$grade}"
}

object Person {
  /** Factory so that call sites can write `Person(id, grade)` without `new`. */
  def apply(id: Long, grade: Int): Person = new Person(id, grade)
}
/**
 * Description: aggregator that keeps the top-ranked students for math, Chinese and English.
 *
 * Each subject has its own [[MyTreeSet]] of [[Person]]; how many entries a set retains is
 * decided by the `MyTreeSet` instance passed in.
 *
 * The class is `Serializable` because instances are handed to Spark as the zero value of
 * `aggregateByKey` and travel between tasks as partial results; with Java serialization that
 * requires the aggregator (and everything it holds) to be serializable.
 *
 * Date: 2019/11/27 1:39
 *
 * @param mathSet    ranking by math score
 * @param chineseSet ranking by Chinese score
 * @param englishSet ranking by English score
 * @author ALion
 */
class PersonAggregator(val mathSet: MyTreeSet[Person],
                       val chineseSet: MyTreeSet[Person],
                       val englishSet: MyTreeSet[Person]) extends Serializable {

  /**
   * Adds one student's scores to the three rankings.
   *
   * @param element (person id, math score, Chinese score, English score)
   * @return this PersonAggregator (mutated in place)
   */
  def +=(element: (Long, Int, Int, Int)): PersonAggregator = {
    val (personId, mathGrade, chineseGrade, englishGrade) = element
    mathSet += Person(personId, mathGrade)
    chineseSet += Person(personId, chineseGrade)
    englishSet += Person(personId, englishGrade)
    this
  }

  /**
   * Merges the rankings of another aggregator into this one.
   *
   * @param that the other aggregator
   * @return this PersonAggregator (mutated in place)
   */
  def ++=(that: PersonAggregator): PersonAggregator = {
    mathSet ++= that.mathSet
    chineseSet ++= that.chineseSet
    englishSet ++= that.englishSet
    this
  }

  override def toString: String =
    s"PersonAggregator{mathSet=$mathSet, chineseSet=$chineseSet, englishSet=$englishSet}"
}

object PersonAggregator {
  /** Creates an aggregator with three empty rankings built by `MyTreeSet`'s no-argument factory. */
  def apply(): PersonAggregator =
    new PersonAggregator(MyTreeSet[Person](), MyTreeSet[Person](), MyTreeSet[Person]())
}
/** Small driver that feeds three sample students into a PersonAggregator and prints it. */
object Demo {
  def main(args: Array[String]): Unit = {
    // Here MyTreeSet was set up to keep only the top 2; adjust the MyTreeSet shown in the appendix to change that.
    val aggregator1 = PersonAggregator()
    // (id, math, Chinese, English)
    val samples: Seq[(Long, Int, Int, Int)] = Seq(
      (1L, 80, 92, 100),
      (2L, 85, 90, 78),
      (3L, 88, 95, 67)
    )
    samples.foreach(sample => aggregator1 += sample)
    println("aggregator1 = " + aggregator1)
  }
}
aggregator1 = PersonAggregator{mathSet=TreeSet(Person{id=3, grade=88}, Person{id=2, grade=85}), chineseSet=TreeSet(Person{id=3, grade=95}, Person{id=1, grade=92}), englishSet=TreeSet(Person{id=1, grade=100}, Person{id=2, grade=78})}
// Per-year rankings for the three subjects: key each row by year, then fold the rows of a
// year into one PersonAggregator.
val resultRDD = studentDF.rdd
  .map { row =>
    // Columns are looked up by name on every row.
    def intColumn(name: String): Int = row.getInt(row.fieldIndex(name))

    val studentId = row.getLong(row.fieldIndex("id"))
    val mathGrade = intColumn("math")
    val chineseGrade = intColumn("chinese")
    val englishGrade = intColumn("english")
    val schoolYear = intColumn("year")
    schoolYear -> ((studentId, mathGrade, chineseGrade, englishGrade))
  }
  .aggregateByKey(PersonAggregator())(
    (acc, scores) => acc += scores,   // add one row to the running aggregator
    (left, right) => left ++= right   // merge two PersonAggregator instances pairwise
  )
/**
 * Description: aggregator for the math ranking, the math average and the number of zero scores.
 *
 * @note {{{
 * top entries      -> mathSet (how many are kept is decided by the MyTreeSet instance)
 * sum of scores    -> totalGrade
 * number of people -> totalCount
 * average          -> totalGrade / totalCount (if Long is not large enough, switch to a wider type such as BigInt)
 * people scoring 0 -> zeroCount
 * }}}
 *
 * The class is `Serializable` because instances are handed to Spark as the zero value of
 * `aggregateByKey` and travel between tasks as partial results; with Java serialization that
 * requires the aggregator (and everything it holds) to be serializable.
 *
 * Date: 2019/11/27 1:39
 * @author ALion
 */
class PersonAggregator2(val mathSet: MyTreeSet[Person],
                        var totalGrade: Long, var totalCount: Long,
                        var zeroCount: Long) extends Serializable {

  /**
   * Adds one student's math score.
   *
   * @param element (person id, math score)
   * @return this PersonAggregator2 (mutated in place)
   */
  def +=(element: (Long, Int)): PersonAggregator2 = {
    val (personId, grade) = element
    mathSet += Person(personId, grade)
    totalGrade += grade
    totalCount += 1
    if (grade == 0) zeroCount += 1
    this
  }

  /**
   * Merges the ranking and the counters of another aggregator into this one.
   *
   * @param that the other aggregator
   * @return this PersonAggregator2 (mutated in place)
   */
  def ++=(that: PersonAggregator2): PersonAggregator2 = {
    mathSet ++= that.mathSet
    totalGrade += that.totalGrade
    totalCount += that.totalCount
    zeroCount += that.zeroCount
    this
  }

  /**
   * Computes the average score.
   *
   * @return totalGrade / totalCount as a Double; NaN while nothing has been added
   *         (both counters are 0, so the division is 0 / 0.0)
   */
  def calcAVG(): Double = totalGrade / totalCount.toDouble

  override def toString: String =
    s"PersonAggregator2{mathSet=$mathSet, avgGrade=${calcAVG()}, zeroCount=$zeroCount}"
}

object PersonAggregator2 {
  /** Creates an empty aggregator: empty ranking and all counters at 0. */
  def apply(): PersonAggregator2 =
    new PersonAggregator2(MyTreeSet[Person](), 0L, 0L, 0L)
}
import scala.collection.immutable.TreeSet
/** Small driver that feeds three sample math scores into a PersonAggregator2 and prints it. */
object Demo {
  def main(args: Array[String]): Unit = {
    val aggregator2 = PersonAggregator2()
    // (id, math)
    val samples: Seq[(Long, Int)] = Seq((1L, 80), (2L, 0), (3L, 0))
    samples.foreach(sample => aggregator2 += sample)
    println("aggregator2 = " + aggregator2)
  }
}
aggregator2 = PersonAggregator2{mathSet=TreeSet(Person{id=1, grade=80}, Person{id=3, grade=0}, Person{id=2, grade=0}), avgGrade=26.666666666666668, zeroCount=2}
// Per-year math statistics: key each row by year, then fold the rows of a year into one
// PersonAggregator2.
val resultRDD = studentDF.rdd
  .map { row =>
    val studentId = row.getLong(row.fieldIndex("id"))
    val mathGrade = row.getInt(row.fieldIndex("math"))
    val schoolYear = row.getInt(row.fieldIndex("year"))
    schoolYear -> ((studentId, mathGrade))
  }
  .aggregateByKey(PersonAggregator2())(
    (acc, score) => acc += score,     // add one row to the running aggregator
    (left, right) => left ++= right   // merge two partial aggregators
  )
RMSE的计算公式: $\sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i} - \bar{x})^2}$
乍一看上去似乎不可能一次性统计完,因为似乎得先算出平均数,才能继续计算RMSE的值啊!你的思路或许是这样的:
上面的逻辑没有问题,但是真的就不能一次完成聚合吗?
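让我们先尝试对聚合算法进行拆解(当然有的算法确实没法拆解),对RMSE的算法进行转换,过程如下: $\frac{1}{m}\sum_{i=1}^{m}(x_i-\bar{x})^2 = \frac{1}{m}\sum_{i=1}^{m}x_i^2 - 2\bar{x}\cdot\frac{1}{m}\sum_{i=1}^{m}x_i + \bar{x}^2 = \frac{1}{m}\sum_{i=1}^{m}x_i^2 - \bar{x}^2$,其中 $\bar{x} = \frac{1}{m}\sum_{i=1}^{m}x_i$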
现在来看,显然简单了,你只需要找到 $m$、$\sum_{i=1}^{m} x_{i}^2$、$\sum_{i=1}^{m} x_i$ 即可
/**
 * Description: aggregator for the math average and the value this file calls RMSE.
 *
 * @note {{{
 * sum of scores         -> totalGrade
 * number of people      -> totalCount
 * average               -> totalGrade / totalCount
 * sum of squared scores -> powSum
 * (if Long is not large enough, switch to a wider type such as BigInt)
 * }}}
 *
 * The class is `Serializable` because instances are the values combined by `reduceByKey`,
 * so Spark has to move them between tasks.
 *
 * Date: 2019/11/27 1:39
 * @author ALion
 */
class PersonAggregator3(var totalGrade: Long, var totalCount: Long, var powSum: Long)
  extends Serializable {

  /**
   * Combines the sums and counts of two aggregators.
   *
   * Neither operand is modified; the combined totals are returned in a fresh instance.
   *
   * @param that the other aggregator
   * @return a new PersonAggregator3 holding the element-wise sums
   */
  def ++(that: PersonAggregator3): PersonAggregator3 =
    new PersonAggregator3(
      this.totalGrade + that.totalGrade,
      this.totalCount + that.totalCount,
      this.powSum + that.powSum
    )

  /**
   * Computes the average score.
   *
   * @return totalGrade / totalCount as a Double; NaN when both are 0
   */
  def calcAVG(): Double = totalGrade / totalCount.toDouble

  /**
   * Computes the RMSE with the simplified formula sqrt(powSum / totalCount - avg * avg).
   *
   * @return the square root described above; NaN when totalCount is 0
   */
  def calcRMSE(): Double = {
    val avg = calcAVG()
    val meanOfSquares = powSum / totalCount.toDouble
    // Floating-point rounding can push the difference slightly below zero, which would make
    // sqrt return NaN; clamp to 0 (a NaN difference stays NaN).
    Math.sqrt(Math.max(meanOfSquares - avg * avg, 0.0))
  }

  // If you are comfortable with lazy vals, the same can be written as:
  // lazy val avg: Double = totalGrade / totalCount.toDouble
  //
  // lazy val rmse: Double = Math.sqrt(powSum / totalCount.toDouble - avg * avg)

  override def toString: String =
    s"PersonAggregator3{avgGrade=${calcAVG()}, rmse=${calcRMSE()}}"
}

object PersonAggregator3 {
  /**
   * Creates the aggregator for a single score.
   *
   * @param math one math score
   * @return an aggregator with totalGrade = math, totalCount = 1 and powSum = math squared
   */
  def apply(math: Int): PersonAggregator3 = {
    // Widen before multiplying so the square is computed in Long rather than overflowing Int.
    val grade = math.toLong
    new PersonAggregator3(grade, 1L, grade * grade)
  }
}
// Per-year average and RMSE: wrap every score in its own PersonAggregator3 and combine them.
val resultRDD = studentDF.rdd
  .map { row =>
    val mathGrade = row.getInt(row.fieldIndex("math"))
    val schoolYear = row.getInt(row.fieldIndex("year"))
    // No large per-element object (a collection etc.) is created here, so aggregateByKey is
    // not needed; you can still try writing it that way as an exercise :)
    schoolYear -> PersonAggregator3(mathGrade)
  }
  .reduceByKey((left, right) => left ++ right)
// Same RMSE job as above, written with plain tuples instead of an aggregator class.
// Each value is (sum of scores, count, sum of squared scores); the three parts are kept as
// Long so that summing many rows cannot overflow the way Int accumulators would.
val resultRDD = studentDF.rdd
  .map { row =>
    val math = row.getInt(row.fieldIndex("math")).toLong
    val year = row.getInt(row.fieldIndex("year"))
    (year, (math, 1L, math * math))
  }.reduceByKey { case (t1, t2) =>
    (t1._1 + t2._1, t1._2 + t2._2, t1._3 + t2._3)
  }.mapValues { case (totalGrade, totalCount, powSum) =>
    val avg = totalGrade / totalCount.toDouble
    // Rounding can make the difference marginally negative; clamp so sqrt does not yield NaN.
    val rmse = Math.sqrt(Math.max(powSum / totalCount.toDouble - avg * avg, 0.0))
    (avg, rmse)
  }
/**
 * Description: user-defined aggregate function that computes the RMSE of a numeric column
 * in one pass, as sqrt(sum(x^2) / n - (sum(x) / n)^2).
 *
 * Rows whose input value is NULL are skipped. When no value was aggregated the result is NaN
 * (0 / 0.0).
 *
 * @example {{{
 * spark.udf.register("rmseUDAF", new MyRmseUDAF())
 * spark.sql("SELECT rmseUDAF(math) FROM tb_person")
 * }}}
 *
 * @author ALion
 */
class MyRmseUDAF extends UserDefinedAggregateFunction {

  /** One input column, read as a Long. */
  override def inputSchema: StructType = StructType(
    StructField("math", LongType) :: Nil
  )

  /** Aggregation buffer: slot 0 = sum of values, slot 1 = number of values, slot 2 = sum of squares. */
  override def bufferSchema: StructType = StructType(
    StructField("totalGrade", LongType) ::
      StructField("totalCount", LongType) ::
      StructField("powSum", LongType) :: Nil
  )

  /** The function returns a Double. */
  override def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Starts every buffer slot at 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
    buffer.update(2, 0L)
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer running (totalGrade, totalCount, powSum)
   * @param input  row holding the value at position 0
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // getLong cannot represent NULL, so ignore such rows instead of failing the whole query.
    if (!input.isNullAt(0)) {
      val grade = input.getLong(0)
      // totalGrade
      buffer.update(0, buffer.getLong(0) + grade)
      // totalCount
      buffer.update(1, buffer.getLong(1) + 1L)
      // powSum
      buffer.update(2, buffer.getLong(2) + grade * grade)
    }
  }

  /**
   * Adds the partial sums of `buffer2` into `buffer1`.
   *
   * @param buffer1 buffer that receives the merged totals
   * @param buffer2 buffer to merge in (read only)
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // totalGrade
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    // totalCount
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    // powSum
    buffer1.update(2, buffer1.getLong(2) + buffer2.getLong(2))
  }

  /**
   * Produces the final result from a fully merged buffer.
   *
   * @param buffer final (totalGrade, totalCount, powSum)
   * @return the RMSE as a Double
   */
  override def evaluate(buffer: Row): Any = {
    val totalGrade = buffer.getLong(0)
    val totalCount = buffer.getLong(1).toDouble
    val powSum = buffer.getLong(2)
    val avg = totalGrade / totalCount
    // RMSE; clamp at 0 because rounding can make the difference marginally negative,
    // which would turn the square root into NaN.
    Math.sqrt(Math.max(powSum / totalCount - avg * avg, 0.0))
  }
}
import scala.collection.mutable
/**
 * A sorted set that retains at most `firstNum` elements: whenever it grows beyond that, the
 * elements that sort last according to `ord` are dropped.
 *
 * The initial elements are trimmed as well, so the size bound holds from construction on.
 * The class is `Serializable` so that objects holding it can be serialized.
 *
 * @param firstNum maximum number of elements to keep (values below 0 behave like 0)
 * @param elem     initial elements
 * @param ord      ordering that defines which elements come first
 * @tparam A element type
 */
class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) extends Serializable {

  /** Backing sorted set; exposed so that other instances can read it when merging. */
  val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)

  // The initial elements may already exceed the bound.
  check10Size()

  /**
   * Adds one element and returns this set, so calls can be chained.
   *
   * @param elem element to add
   * @return this MyTreeSet (mutated in place)
   */
  def +=(elem: A): MyTreeSet[A] = {
    add(elem)
    this
  }

  /**
   * Adds one element, then drops whatever no longer fits.
   *
   * @param elem element to add
   */
  def add(elem: A): Unit = {
    set.add(elem)
    // Remove the surplus elements that sort last.
    check10Size()
  }

  /**
   * Merges all elements of another set into this one, keeping only the first `firstNum`.
   *
   * @param that the other set
   * @return this MyTreeSet (mutated in place)
   */
  def ++=(that: MyTreeSet[A]): MyTreeSet[A] = {
    set ++= that.set
    check10Size()
    this
  }

  /** Drops the last elements until no more than `firstNum` remain. */
  def check10Size(): Unit = {
    val capacity = math.max(firstNum, 0)
    // A loop rather than a single removal: the set can be over the bound by more than one.
    while (set.size > capacity) {
      set -= set.last
    }
  }

  override def toString: String = set.toString
}

object MyTreeSet {
  /** Creates a set that keeps at most 10 elements. */
  def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem)

  /** Creates a set that keeps at most `firstNum` elements. */
  def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
}