name,profession,enroll,score
曾凰妹,金融学,北京电子科技学院,637
谢德炜,金融学,北京电子科技学院,542
林逸翔,金融学,北京电子科技学院,543
王丽云,金融学,北京电子科技学院,626
吴鸿毅,金融学,北京电子科技学院,591
施珊珊,经济学类,北京理工大学,581
柯祥坤,经济学类,北京理工大学,650
庄劲聪,经济学类,北京理工大学,551
吴雅思,经济学类,北京理工大学,529
周育传,经济学类,北京理工大学,682
丁俊伟,通信工程,北京电子科技学院,708
庄逸琳,通信工程,北京电子科技学院,708
吴志发,通信工程,北京电子科技学院,578
肖妮娜,通信工程,北京电子科技学院,557
蔡建明,通信工程,北京电子科技学院,583
林逸翔,通信工程,北京电子科技学院,543
数据无现实意义
+------+----------+----------------+-----+------+
| name|profession| enroll|score|row_id|
+------+----------+----------------+-----+------+
|林逸翔| 通信工程|北京电子科技学院| 543| 1|
|肖妮娜| 通信工程|北京电子科技学院| 557| 2|
|吴志发| 通信工程|北京电子科技学院| 578| 3|
|蔡建明| 通信工程|北京电子科技学院| 583| 4|
|丁俊伟| 通信工程|北京电子科技学院| 708| 5|
|庄逸琳| 通信工程|北京电子科技学院| 708| 5|
|谢德炜| 金融学|北京电子科技学院| 542| 1|
|林逸翔| 金融学|北京电子科技学院| 543| 2|
|吴鸿毅| 金融学|北京电子科技学院| 591| 3|
|王丽云| 金融学|北京电子科技学院| 626| 4|
|曾凰妹| 金融学|北京电子科技学院| 637| 5|
|吴雅思| 经济学类| 北京理工大学| 529| 1|
|庄劲聪| 经济学类| 北京理工大学| 551| 2|
|施珊珊| 经济学类| 北京理工大学| 581| 3|
|柯祥坤| 经济学类| 北京理工大学| 650| 4|
|周育传| 经济学类| 北京理工大学| 682| 5|
+------+----------+----------------+-----+------+
package com.cch.bigdata.spark.process.rownuber
import com.cch.bigdata.spark.process.AbstractTransform
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank, rank}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}
//使用分组排名 rank() over, dense_rank(), row_number()
class RowNumber extends AbstractTransform {
//序列号增加方式:
//global表示全局递增
//group表示分组排序序列号
private val add_type_set = Set("global","group")
private val add_type = "group"
//如果序列号生成方式使用了分组排序的方式,则需要选择分组策略
//rank表示跳跃排序,例如1,1,3,4 (分区内)
//row_number表示连续排序,例如1,2,3,4(分区内)
private val group_strategy_set = Set("rank","row_number")
private val group_strategy = "rank"
//分组列
private val group_columns = Array[String]("enroll","profession")
//排序列
private val sort_column = "score"
//排序类型,默认升序
private val sort_type_set = Set("asc","desc")
private val sort_type = "asc"
//如果add_type选择的是global_number,则生成的序列号列名叫row_id
private val serial_name = "row_id"
override def process(): Unit = {
if (add_type.isEmpty) {
throw new RuntimeException("序列号列类型不能为空")
}
if(!add_type_set.contains(add_type)){
throw new RuntimeException("序列号增加方式错误,只支持[global/group]")
}
if (add_type.equals("group")&&group_strategy.isEmpty) {
throw new RuntimeException("序列号生成分组策略不能为空")
}
if(!group_strategy_set.contains(group_strategy)&&add_type.equals("group")){
throw new RuntimeException("分组策略错误,只支持[rank/row_number]")
}
if (add_type.equals("group")&&(group_columns.isEmpty || group_columns.length == 0)) {
throw new RuntimeException("分组列不能为空")
}
if (add_type.equals("group")&&sort_column.isEmpty) {
throw new RuntimeException("排序列不能为空")
}
if(!sort_type_set.contains(sort_type)&&add_type.equals("group")){
throw new RuntimeException("排序类型错误,只支持[asc/desc]")
}
//获取输入流
val df: DataFrame = loadCsv("src/main/resources/csv/admission.csv",spark)
add_type match {
case "global" => {
globalNumberCreate(df)
}
case "group" => {
group_strategy match {
case "rank" =>{
groupNumberCreateByRowNumber(df)
}
case "row_number" =>{
groupNumberCreateByRank(df)
}
}
}
case _ =>{
throw new RuntimeException("分组策略错误,只支持[global/group]")
}
}
}
//分组序列号生成 rank方式
def groupNumberCreateByRank(df: DataFrame): Unit = {
//连续排名
val result: DataFrame = df.withColumn(serial_name, dense_rank().over(Window.partitionBy(group_columns.map(c => {
col(c)
}): _*).orderBy(
if(sort_type.equals("asc")) col(sort_column).asc else col(sort_column).desc
)))
result.show()
}
//分组序列号生成 row_number方式
def groupNumberCreateByRowNumber(df: DataFrame): Unit = {
//连续排名
val result: DataFrame = df.withColumn(serial_name, rank().over(Window.partitionBy(group_columns.map(c => {
col(c)
}): _*).orderBy(
if(sort_type.equals("asc")) col(sort_column).asc else col(sort_column).desc
)))
result.show()
}
//生成全局的序列号
def globalNumberCreate(df: DataFrame): Unit = {
import scala.collection.mutable._
val fieldNameList: List[String] = df.schema.fieldNames.toList
//添加列
val schema: StructType = df.schema.add(StructField(serial_name, LongType))
val queryColumnList: ListBuffer[String] = ListBuffer()
queryColumnList.append(serial_name)
queryColumnList.append(fieldNameList.map(x => {
x
}): _*)
val dfRDD: RDD[(Row, Long)] = df.rdd.zipWithIndex()
val rowRDD: RDD[Row] = dfRDD.map(tp => Row.merge(tp._1, Row((tp._2+1))))
// 将添加了索引的RDD 转化为DataFrame
val result: DataFrame = spark.createDataFrame(rowRDD, schema)
//为了将自增列放到头部,要进行查询
result.select(queryColumnList.map(c => {
col(c)
}): _*).show()
}
override def getAppName(): String = "添加索引列"
}
object RowNumber{
def main(args: Array[String]): Unit = {
new RowNumber().process()
}
}