其他更多java基础文章:
java基础学习(目录)
SparkSQL中自定义UDF和UDAF,开窗函数的应用
概述
SparkSQL中的UDF相当于是1进1出,UDAF相当于是多进一出,类似于聚合函数。
开窗函数一般分组取topn时常用。
UDF
SparkConf conf = new SparkConf();
conf.setMaster("local");
conf.setAppName("udf");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new SQLContext(sc);
JavaRDD parallelize = sc.parallelize(Arrays.asList("zhansan","lisi","wangwu"));
JavaRDD rowRDD = parallelize.map(new Function() {
/**
*
*/
private static final long serialVersionUID = 1L;
@Override
public Row call(String s) throws Exception {
return RowFactory.create(s);
}
});
List fields = new ArrayList();
fields.add(DataTypes.createStructField("name", DataTypes.StringType,true));
StructType schema = DataTypes.createStructType(fields);
DataFrame df = sqlContext.createDataFrame(rowRDD,schema);
df.registerTempTable("user");
/**
* 根据UDF函数参数的个数来决定是实现哪一个UDF UDF1,UDF2。。。。UDF1xxx
*/
sqlContext.udf().register("StrLen", new UDF1() {
/**
*
*/
private static final long serialVersionUID = 1L;
@Override
public Integer call(String t1) throws Exception {
return t1.length();
}
}, DataTypes.IntegerType);
sqlContext.sql("select name ,StrLen(name) as length from user").show();
复制代码
new UDF1
这些参数需要对应,UDF2就是表示传两个参数,UDF3就是传三个参数。例如new UDF2
,表示传入两个参数,第一个为String类型,第二个为Integer类型,返回Integer类型的数据。
UDAF
UDAF: 用户自定义聚合函数,user defined aggreagatefunction
实现UDAF函数如果要自定义类要继承UserDefinedAggregateFunction类
package com.spark.sparksql.udf_udaf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
/**
* UDAF 用户自定义聚合函数
* @author root
*
*/
public class UDAF {
public static void main(String[] args) {
SparkConf conf = new SparkConf();
conf.setMaster("local").setAppName("udaf");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new SQLContext(sc);
JavaRDD parallelize = sc.parallelize(
Arrays.asList("zhangsan","lisi","wangwu","zhangsan","zhangsan","lisi"));
JavaRDD rowRDD = parallelize.map(new Function() {
/**
*
*/
private static final long serialVersionUID = 1L;
@Override
public Row call(String s) throws Exception {
return RowFactory.create(s);
}
});
List fields = new ArrayList();
fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("user");
/**
* 注册一个UDAF函数,实现统计相同值得个数
* 注意:这里可以自定义一个类继承UserDefinedAggregateFunction类也是可以的
*/
sqlContext.udf().register("StringCount",new UserDefinedAggregateFunction() {
/**
*
*/
private static final long serialVersionUID = 1L;
/**
* 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0);
}
/**
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* buffer.getInt(0)获取的是上一次聚合后的值
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 大聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
*/
@Override
public void update(MutableAggregationBuffer buffer, Row arg1) {
buffer.update(0, buffer.getInt(0)+1);
}
/**
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值
* buffer2.getInt(0) : 这次计算传入进来的update的结果
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
* 也可以是一个节点里面的多个executor合并
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
}
/**
* 在进行聚合操作的时候所要处理的数据的结果的类型
*/
@Override
public StructType bufferSchema() {
return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("bffer111", DataTypes.IntegerType, true)));
}
/**
* 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果
*/
@Override
public Object evaluate(Row row) {
return row.getInt(0);
}
/**
* 指定UDAF函数计算后返回的结果类型
*/
@Override
public DataType dataType() {
return DataTypes.IntegerType;
}
/**
* 指定输入字段的字段及类型
*/
@Override
public StructType inputSchema() {
return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("nameeee", DataTypes.StringType, true)));
}
/**
* 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
*/
@Override
public boolean deterministic() {
return true;
}
});
sqlContext.sql("select name ,StringCount(name) as strCount from user group by name").show();
sc.stop();
}
}
复制代码
开窗函数
row_number() 开窗函数是按照某个字段分组,然后取另一字段的前几个的值,相当于分组取topN
开窗函数格式:row_number() over (partitin by XXX order by XXX)
package com.spark.sparksql.windowfun;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.hive.HiveContext;
/**是hive的函数,必须在集群中运行。
* row_number()开窗函数:
* 主要是按照某个字段分组,然后取另一字段的前几个的值,相当于 分组取topN
* row_number() over (partition by xxx order by xxx desc) xxx
* 注意:
* 如果SQL语句里面使用到了开窗函数,那么这个SQL语句必须使用HiveContext来执行,HiveContext默认情况下在本地无法创建
* @author root
*
*/
public class RowNumberWindowFun {
public static void main(String[] args) {
SparkConf conf = new SparkConf();
conf.setAppName("windowfun");
JavaSparkContext sc = new JavaSparkContext(conf);
HiveContext hiveContext = new HiveContext(sc);
hiveContext.sql("use spark");
hiveContext.sql("drop table if exists sales");
hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) "
+ "row format delimited fields terminated by '\t'");
hiveContext.sql("load data local inpath '/root/test/sales' into table sales");
/**
* 开窗函数格式:
* 【 row_number() over (partition by XXX order by XXX) as rank】//起个别名
* 注意:rank 从1开始
*/
/**
* 以类别分组,按每种类别金额降序排序,显示 【日期,种类,金额】 结果,如:
*
* 1 A 100
* 2 B 200
* 3 A 300
* 4 B 400
* 5 A 500
* 6 B 600
* 排序后:
* 5 A 500 --rank 1
* 3 A 300 --rank 2
* 1 A 100 --rank 3
* 6 B 600 --rank 1
* 4 B 400 --rank 2
* 2 B 200 --rank 3
*
*/
DataFrame result = hiveContext.sql("select riqi,leibie,jine "
+ "from ("
+ "select riqi,leibie,jine,"
+ "row_number() over (partition by leibie order by jine desc) rank "
+ "from sales) t "
+ "where t.rank<=3");
result.show(100);
/**
* 将结果保存到hive表sales_result
*/
result.write().mode(SaveMode.Overwrite).saveAsTable("sales_result");
sc.stop();
}
}
复制代码