There is not much to say about ordinary window functions; they are straightforward. This post looks at partitioned windows, and in particular at how the rows between frame clause is used after partition by and order by.
The key is to understand what each keyword inside rows between means:
Keyword | Meaning |
---|---|
preceding | before the current row |
following | after the current row |
current row | the current row |
unbounded | no bound; combined with preceding/following it means the start or end of the partition |
unbounded preceding | from the first row of the partition |
unbounded following | to the last row of the partition |
Reading the keywords alone is a bit abstract, so here is an example:
select country,time,charge,
max(charge) over (partition by country order by time) as normal,
max(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur,
max(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1,
max(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol
from temp
By default, when order by is present but no frame is specified, the aggregate is computed over the rows of the partition from the first row up to the current row.
rows between unbounded preceding and current row behaves like the default. Strictly speaking, the default frame is range between unbounded preceding and current row, so when several rows share the same order by value the default also includes those peer rows, whereas the rows frame stops at the current physical row; with distinct ordering values the results are identical.
rows between 2 preceding and 1 following computes over the 2 rows before the current row, the current row itself, and the 1 row after it. For example, if one country's charges ordered by time are 5, 9, 3, 7, then at the third row (charge 3) the frame covers all four values and max returns 9.
rows between current row and unbounded following computes from the current row through the last row of the partition.
rows between has the same meaning for the avg, min, max and sum window functions; only the aggregate changes.
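The same frames can also be expressed through Spark's DataFrame API instead of SQL. The snippet below is a minimal sketch that assumes the same dataset with country, time and charge columns built by the test class further down; note that rowsBetween takes row offsets, so -2 means "2 preceding" and 1 means "1 following".

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.max;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;

// unbounded preceding .. current row: running max per country, ordered by time
WindowSpec unbPreCur = Window.partitionBy("country").orderBy("time")
        .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
// 2 preceding .. 1 following: a sliding frame of at most 4 rows around the current row
WindowSpec pre2Fol1 = Window.partitionBy("country").orderBy("time")
        .rowsBetween(-2, 1);

dataset.withColumn("unb_pre_cur", max(col("charge")).over(unbPreCur))
        .withColumn("pre2_fol1", max(col("charge")).over(pre2Fol1))
        .show(100);

Below is the complete test class used to run the SQL examples: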
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.Before;
import org.junit.Test;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
public class SparkHiveFunctionTest implements Serializable{
private static final String DATA_PATH = "F:\\tmp\\charge.csv";
private static final String DATA_OBJECT_PATH = "F:\\tmp\\charge";
private static String[] counties = {"中国","俄罗斯","美国","日本","韩国"};
private SparkSession sparkSession;
private Dataset<Charge> dataset;
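// Build a local SparkSession before each test; the WARN log level keeps the console output readable.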
@Before
public void setUp(){
sparkSession = SparkSession
.builder()
.appName("test")
.master("local")
.getOrCreate();
sparkSession.sparkContext().setLogLevel("WARN");
}
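// Entry point: load (or regenerate) the sample data, register it as a Dataset and run one of the demos below.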
@Test
public void start() throws IOException, ClassNotFoundException {
// List<Charge> infos = getData(true); // pass true to regenerate the sample data
List<Charge> infos = getData(false);
dataset = sparkSession.createDataset(infos, Encoders.bean(Charge.class));
// sum();
// avg();
// min();
max();
}
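// The four demos below run the same frame variants explained above; only the aggregate function (sum/avg/min/max) changes.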
private void sum(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,time,charge," +
"sum(charge) over (partition by country order by time) as normal," +
"sum(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur," +
"sum(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1," +
"sum(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol" +
" from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
}
private void avg(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,time,charge," +
"avg(charge) over (partition by country order by time) as normal," +
"avg(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur," +
"avg(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1," +
"avg(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol" +
" from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
}
private void min(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,time,charge," +
"min(charge) over (partition by country order by time) as normal," +
"min(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur," +
"min(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1," +
"min(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol" +
" from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
}
private void max(){
dataset.createOrReplaceTempView("temp");
String sql = "select country,time,charge," +
"max(charge) over (partition by country order by time) as normal," +
"max(charge) over (partition by country order by time rows between unbounded preceding and current row) as unb_pre_cur," +
"max(charge) over (partition by country order by time rows between 2 preceding and 1 following) as pre2_fol1," +
"max(charge) over (partition by country order by time rows between current row and unbounded following) as cur_unb_fol" +
" from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(100);
}
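// Returns the sample data: regenerates it when newGen is true, otherwise reads the previously serialized list.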
private static List<Charge> getData(Boolean newGen) throws IOException, ClassNotFoundException {
if (Boolean.TRUE.equals(newGen)) {
return generateData();
}else {
return readList();
}
}
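// Generates 50 random charge records, writes them to a CSV for inspection and caches the list via Java serialization.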
private static List<Charge> generateData() throws IOException {
LinkedList<Charge> infos = new LinkedList<>();
Random random = new Random();
LocalDate localDate = LocalDate.of(2020, 1, 4);
DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
// try-with-resources closes (and flushes) the CSV file even if writing fails
try (FileWriter fileWriter = new FileWriter(DATA_PATH)) {
for(int i=0;i<50;i++){
Charge info = new Charge();
// random country, a day within 10 days of 2020-01-04, and a charge between 10000 and 19999
String county = counties[random.nextInt(counties.length)];
info.setCountry(county);
int day = random.nextInt(10);
LocalDate date = localDate.plusDays(day);
String time = date.format(dateTimeFormatter);
int charge = 10000 + random.nextInt(10000);
info.setCharge(charge);
info.setTime(time);
infos.add(info);
fileWriter.write(String.format("%s,%s,%d\n",county,time,charge));
}
}
// cache the generated list so later runs can reuse exactly the same data
writeList(infos);
return infos;
}
private static void writeList(LinkedList<Charge> infos) throws IOException {
// try-with-resources closes both streams once the list has been serialized
try (FileOutputStream fos = new FileOutputStream(DATA_OBJECT_PATH);
ObjectOutputStream oos = new ObjectOutputStream(fos)) {
oos.writeObject(infos);
}
}
private static LinkedList<Charge> readList() throws IOException, ClassNotFoundException {
// try-with-resources closes both streams after the cached list has been read back
try (FileInputStream fis = new FileInputStream(DATA_OBJECT_PATH);
ObjectInputStream ois = new ObjectInputStream(fis)) {
@SuppressWarnings("unchecked")
LinkedList<Charge> list = (LinkedList<Charge>) ois.readObject();
return list;
}
}
/**
* Must be public and must implement Serializable (required by Spark's bean encoder and by the Java serialization used above)
*/
public static class Charge implements Serializable {
/**
* Country
*/
private String country;
/**
* Charge (top-up) time
*/
private String time;
/**
* Charge (top-up) amount
*/
private Integer charge;
public String getCountry() {
return country;
}
public void setCountry(String country) {
this.country = country;
}
public String getTime() {
return time;
}
public void setTime(String time) {
this.time = time;
}
public Integer getCharge() {
return charge;
}
public void setCharge(Integer charge) {
this.charge = charge;
}
}
}