Grouping data and selecting the first N rows of each group is a classic task in statistical data analysis. As an example, consider the table emp shown below:
empno | ename | job | sal |
---|---|---|---|
7369 | SMITH | CLERK | 800.0 |
7876 | SMITH | CLERK | 1100.0 |
7900 | JAMES | CLERK | 950.0 |
7934 | MILLER | CLERK | 1300.0 |
7499 | ALLEN | SALESMAN | 1600.0 |
7654 | MARTIN | SALESMAN | 1250.0 |
7844 | TURNER | SALESMAN | 1500.0 |
7521 | WARD | SALESMAN | 1250.0 |
This is usually implemented as follows:

- Assign each row a number within its group, then filter with a WHERE clause:

```sql
SELECT *
FROM (
  SELECT empno, ename, sal, job,
         ROW_NUMBER() OVER (PARTITION BY job ORDER BY sal) AS rn
  FROM emp
) tmp
WHERE rn < 10;
```

Besides this window-function approach, we would like to implement the same logic with a UDAF.
1. UDAF Overview
> User-Defined Aggregation Functions (UDAFs) are an excellent way to integrate advanced data-processing into Hive. Hive allows two varieties of UDAFs: simple and generic. Simple UDAFs, as the name implies, are rather simple to write, but incur performance penalties because of the use of Java Reflection, and do not allow features such as variable-length argument lists. Generic UDAFs allow all these features but are perhaps not quite as intuitive to write as Simple UDAFs.
As the official documentation explains, a UDAF can be implemented in one of two ways: simple or generic. Simple UDAFs incur performance penalties because they rely on Java reflection, do not support features such as variable-length argument lists, and have since been deprecated. This article therefore focuses on implementing the group-wise Top N as a generic UDAF.
2. Extending AbstractGenericUDAFResolver

Extending the `AbstractGenericUDAFResolver` class is the approach recommended by the official documentation, though it takes somewhat more effort to develop. The overall workflow is:
- Override the `getEvaluator` method, which validates the argument types passed to the UDAF.
- Create a `public` static class that extends `GenericUDAFEvaluator`; its overridable methods together make up the UDAF's processing logic. Seven methods must be implemented:
Function | Purpose | Return |
---|---|---|
`init` | Declares the input/output `ObjectInspector`s used in each execution phase | ObjectInspector |
`getNewAggregationBuffer` | Creates the object that holds intermediate aggregation state | AggregationBuffer |
`iterate` | Map phase: iterates over the column values passed in from the SQL statement | void |
`terminatePartial` | Called at the end of the map or combiner phase to emit the partial aggregation result | Object |
`merge` | The combiner merges map outputs; the reducer merges mapper or combiner outputs | void |
`terminate` | Reduce phase: emits the final result | Object |
`reset` | Resets the aggregation state | void |
The following code is the basic skeleton of a generic UDAF:
```java
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

public class GenericUDAFHistogramNumeric extends AbstractGenericUDAFResolver {
  static final Log LOG = LogFactory.getLog(GenericUDAFHistogramNumeric.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
    // Type-checking goes here!
    return new GenericUDAFHistogramNumericEvaluator();
  }

  public static class GenericUDAFHistogramNumericEvaluator extends GenericUDAFEvaluator {
    // UDAF logic goes here!
  }
}
```
Hive SQL statements are ultimately compiled into MapReduce jobs, and the methods above are invoked at specific points of the MapReduce lifecycle. To indicate which phase a task is in, the `GenericUDAFEvaluator` class defines an enum `Mode`:
```java
public static enum Mode {
  /**
   * PARTIAL1: from original data to partial aggregation data: iterate() and
   * terminatePartial() will be called.
   */
  PARTIAL1,
  /**
   * PARTIAL2: from partial aggregation data to partial aggregation data:
   * merge() and terminatePartial() will be called.
   */
  PARTIAL2,
  /**
   * FINAL: from partial aggregation to full aggregation: merge() and
   * terminate() will be called.
   */
  FINAL,
  /**
   * COMPLETE: from original data directly to full aggregation: iterate() and
   * terminate() will be called.
   */
  COMPLETE
};
```
Specifically:

- `PARTIAL1` corresponds to the map phase. The framework calls `iterate` on every original input row, then `terminatePartial`, producing the map-side partial aggregation.
- `PARTIAL2` corresponds to the combiner phase. It calls `merge`, then `terminatePartial`, to further aggregate the map-side output.
- `FINAL` is the reduce phase. It calls `merge`, then `terminate`, to assemble the final result.
- `COMPLETE` denotes a map-only job with no reduce task. It calls `iterate`, then `terminate`, to produce the final result directly.
```
│        input              output
│---------------------------------------
├── PARTIAL1
│    ├── iterate ──> terminatePartial
├── PARTIAL2
│    ├── merge ──> terminatePartial
├── FINAL
│    ├── merge ──> terminate
├── COMPLETE
│    └── iterate ──> terminate
```
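To make these call sequences concrete, here is a small self-contained simulation (a toy stand-in for a trivial SUM aggregation, not the Hive API) that replays the method order of each phase:

```java
import java.util.Arrays;
import java.util.List;

// Toy stand-in (not the Hive API): replays the per-phase call order for a trivial SUM.
public class ModeSequenceDemo {
    // iterate(): fold one original input row into the buffer
    static long iterate(long buffer, long row) { return buffer + row; }
    // terminatePartial(): emit the partial aggregation state
    static long terminatePartial(long buffer) { return buffer; }
    // merge(): fold one partial aggregation into the buffer
    static long merge(long buffer, long partial) { return buffer + partial; }
    // terminate(): emit the final result
    static long terminate(long buffer) { return buffer; }

    public static void main(String[] args) {
        List<Long> split1 = Arrays.asList(1L, 2L);
        List<Long> split2 = Arrays.asList(3L, 4L);

        // PARTIAL1 (map): iterate() over raw rows, then terminatePartial()
        long m1 = 0, m2 = 0;
        for (long row : split1) m1 = iterate(m1, row);
        for (long row : split2) m2 = iterate(m2, row);
        long partial1 = terminatePartial(m1);
        long partial2 = terminatePartial(m2);

        // PARTIAL2 (combiner): merge() partials, then terminatePartial()
        long c = merge(merge(0L, partial1), partial2);
        long combined = terminatePartial(c);

        // FINAL (reduce): merge() partials, then terminate()
        long r = merge(0L, combined);
        System.out.println("sum = " + terminate(r)); // prints: sum = 10
    }
}
```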
Consequently, raw column values appear only in the `PARTIAL1` and `COMPLETE` phases. `terminatePartial` is called only in `PARTIAL1` and `PARTIAL2`, which is exactly why the map output and the combiner output must have the same type; `merge` is called only in `PARTIAL2` and `FINAL`, which is why the combiner input and the reducer input must match as well. Finally, each of the four phases begins by calling `init` to declare its input and output types, and `init` receives the `Mode` enum as a parameter so the evaluator knows which phase it is running in.
3. Implementing Top N
Selecting the top N of an arbitrary sequence of values, in ascending or descending order, is a natural fit for a priority queue. Here we wrap the JDK's `java.util.PriorityQueue` into `FixSizedPriorityQueue`, a bounded queue with a configurable size and sort direction.
```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

import org.apache.hadoop.hive.ql.util.JavaDataModel;

public class FixSizedPriorityQueue<E extends Comparable<E>> {
    private PriorityQueue<E> queue;
    private int maxSize = 1;
    private int orderFlag = 0;

    // Max-heap: the head is the largest element; used to keep the N smallest (orderFlag == 0)
    private Comparator<E> desc = new Comparator<E>() {
        @Override
        public int compare(E o1, E o2) {
            return o2.compareTo(o1);
        }
    };

    // Min-heap: the head is the smallest element; used to keep the N largest (orderFlag == 1)
    private Comparator<E> esc = new Comparator<E>() {
        @Override
        public int compare(E o1, E o2) {
            return o1.compareTo(o2);
        }
    };

    public FixSizedPriorityQueue() {
    }

    public FixSizedPriorityQueue(int maxSize, int orderFlag) {
        if (maxSize <= 0) {
            throw new IllegalArgumentException("the size of the queue must be larger than zero");
        }
        this.maxSize = maxSize;
        this.orderFlag = orderFlag;
        if (orderFlag == 0) {
            this.queue = new PriorityQueue<>(maxSize, desc);
        }
        if (orderFlag == 1) {
            this.queue = new PriorityQueue<>(maxSize, esc);
        }
    }

    public int getMaxSize() {
        return maxSize;
    }

    public void setMaxSize(int maxSize) {
        this.maxSize = maxSize;
    }

    public int getOrderFlag() {
        return orderFlag;
    }

    public void setOrderFlag(int orderFlag) {
        this.orderFlag = orderFlag;
    }

    public void add(E e) {
        if (queue.size() < maxSize) {
            queue.add(e);
        } else {
            E peek = queue.peek();
            if (orderFlag == 0 && e.compareTo(peek) < 0) {
                queue.poll(); // evict the largest element at the head
                queue.add(e);
            }
            if (orderFlag == 1 && e.compareTo(peek) > 0) {
                queue.poll(); // evict the smallest element at the head
                queue.add(e);
            }
        }
    }

    private List<E> toArray() {
        return new ArrayList<>(queue);
    }

    public int getLength() {
        return queue.size();
    }

    public void reset() {
        queue = null;
        maxSize = 1;
        orderFlag = 0;
    }

    public List<E> toOrderArray(int orderFlag) {
        List<E> list = toArray();
        if (orderFlag == 0) {
            list.sort(esc); // ascending output
        } else {
            list.sort(desc); // descending output
        }
        return list;
    }

    // Estimates the in-memory footprint for Hive's memory accounting
    public int lengthFor(JavaDataModel model) {
        int length = model.object();
        length += model.primitive1(); // maxSize
        length += model.primitive1(); // orderFlag
        length += model.object() * 2; // two comparators
        if (maxSize > 0) {
            length += model.arrayList(); // backing list
            length += maxSize * model.object();
        }
        return length;
    }
}
```
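As a quick sanity check (an illustrative snippet, not part of the original project), feeding the CLERK salaries from the emp table into a size-3 descending queue keeps the three highest values:

```java
public class FixSizedPriorityQueueDemo {
    public static void main(String[] args) {
        // Keep the top 3 values; orderFlag = 1 means descending output
        FixSizedPriorityQueue<Double> top3 = new FixSizedPriorityQueue<>(3, 1);
        for (double sal : new double[]{800.0, 1100.0, 950.0, 1300.0}) { // CLERK salaries
            top3.add(sal);
        }
        System.out.println(top3.toOrderArray(1)); // prints: [1300.0, 1100.0, 950.0]
    }
}
```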
4. Building the UDAF
The function is invoked as `group_n(col, n, 0)`, where `${arg1} = col` is the table column to aggregate; `${arg2} = n` is how many rows to keep per group (the top N); and `${arg3} = 0` means ascending while `${arg3} = 1` means descending. If the third argument is omitted, it defaults to 0.
Inside the `GroupTopN` class, a `GroupNAggBuffer` temporarily stores the parameters `n` and `0/1` together with the collected values, which `terminate` later feeds into a `FixSizedPriorityQueue` and returns as a `List`. This aggregation buffer is the object that `terminatePartial` and `terminate` operate on across the MapReduce execution phases. Note that the tail of the listing below was truncated in the source; the end of `merge`, the `terminate` method, and the `GroupNAggBuffer` class are a plausible reconstruction, marked as such in the comments.
@SuppressWarnings("deprecation")
@Description(name = "group_n", value = "_FUNC_(x, n, 0/1) -- 0: esc; 1: desc. Returns the array contains N elements of a set of numbers")
public class GroupTopN extends AbstractGenericUDAFResolver {
static final Logger LOG = LoggerFactory.getLogger(GroupTopN.class);
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
if (info.length < 2 || info.length > 3) {
throw new UDFArgumentTypeException(info.length - 1, "At least two arguments are expected, " +
"but no more than three arguments");
}
// validate the first parameter, which is the expression to compute over
if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Only primitive type arguments are accepted but "
+ info[0].getTypeName() + " was passed as parameter 1.");
}
switch (((PrimitiveTypeInfo) info[0]).getPrimitiveCategory()) {
case BYTE:
case SHORT:
case INT:
case LONG:
case FLOAT:
case DOUBLE:
case TIMESTAMP:
case DECIMAL:
case STRING:
break;
case BOOLEAN:
case DATE:
default:
throw new UDFArgumentTypeException(0,
"Only numeric type arguments are accepted but "
+ info[0].getTypeName() + " was passed as parameter 1.");
}
if (((PrimitiveTypeInfo) info[1]).getPrimitiveCategory()
!= PrimitiveObjectInspector.PrimitiveCategory.INT) {
throw new UDFArgumentTypeException(1
, "Only an integer argument is accepted as parameter 1, but "
+ info[1].getTypeName() + " was passed instead.");
}
if (info.length == 3 && ((PrimitiveTypeInfo) info[2]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.INT) {
throw new UDFArgumentTypeException(2
, "Only an integer argument is accepted as parameter 2, but "
+ info[2].getTypeName() + " was passed instead.");
}
return new GroupNEvaluator();
}
public static class GroupNEvaluator extends GenericUDAFEvaluator {
// For PARTIAL1 and COMPLETE: ObjectInspectors for original data
private transient PrimitiveObjectInspector inputOi;
private transient PrimitiveObjectInspector nthOi;
private transient PrimitiveObjectInspector orderOi;
// For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list of element)
protected transient Object[] partialResult;
private transient StructObjectInspector soi;
private transient StructField nthField;
private transient StructField orderField;
private transient StructField listField;
private transient ListObjectInspector loi;
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
super.init(m, parameters);
// init input object inspectors
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) { // PARTIAL1 和 COMPLETE 开始都是调用 iterate 接口,入参
inputOi = (PrimitiveObjectInspector) parameters[0];
nthOi = (PrimitiveObjectInspector) parameters[1];
if (parameters.length == 2) {
orderOi = PrimitiveObjectInspectorFactory.getPrimitiveObjectInspectorFromClass(Integer.class);
}
if (parameters.length == 3) {
orderOi = (PrimitiveObjectInspector) parameters[2];
}
} else { // PARTIAL2 和 FINAL 开始都是调用 merge 接口,入参
soi = (StructObjectInspector) parameters[0];
nthField = soi.getStructFieldRef("n");
orderField = soi.getStructFieldRef("order");
listField = soi.getStructFieldRef("data");
nthOi = (PrimitiveObjectInspector) nthField.getFieldObjectInspector();
orderOi = (PrimitiveObjectInspector) orderField.getFieldObjectInspector();
loi = (ListObjectInspector) listField.getFieldObjectInspector();
}
// init output object inspectors
if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) { // PARTIAL1 和 PARTIAL2 最后调用的接口都是 terminatePartial, 输出 list 格式
ArrayList foi = new ArrayList<>();
foi.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
foi.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
if (m == Mode.PARTIAL1) {
foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputOi.getTypeInfo())));
}
if (m == Mode.PARTIAL2) {
foi.add(ObjectInspectorFactory.getStandardListObjectInspector(loi.getListElementObjectInspector()));
}
ArrayList fname = new ArrayList<>();
fname.add("n");
fname.add("order");
fname.add("data");
partialResult = new Object[3];
partialResult[0] = new IntWritable(1);
partialResult[1] = new IntWritable(0);
partialResult[2] = new ArrayList();
return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
} else { // COMPLETE 和 FINAL 最后调用的接口都是 terminal
return ObjectInspectorFactory.getStandardListObjectInspector(
loi.getListElementObjectInspector());
}
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
GroupNAggBuffer buffer = new GroupNAggBuffer();
reset(buffer);
return buffer;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
buffer.container.clear();
}
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (parameters[0] == null || parameters[1] == null) {
return;
}
if (parameters.length == 3 && parameters[2] == null) {
return;
}
GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
int n = PrimitiveObjectInspectorUtils.getInt(parameters[1], nthOi);
partialResult[0] = new IntWritable(n);
buffer.n = n;
int order = 0;
if (parameters.length == 3) {
order = PrimitiveObjectInspectorUtils.getInt(parameters[2], orderOi);
}
if (order != 0 && order != 1) {
throw new IllegalArgumentException("the order must be '0' or '1' ");
}
partialResult[1] = new IntWritable(order);
buffer.order = order;
Object v = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputOi);
buffer.container.add(v);
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
((IntWritable) partialResult[0]).set(buffer.n);
((IntWritable) partialResult[1]).set(buffer.order);
partialResult[2] = buffer.container;
return partialResult;
}
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}
GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
int n = ((IntWritable) soi.getStructFieldData(partial, nthField)).get();
int order = ((IntWritable) soi.getStructFieldData(partial, orderField)).get();
List
In the `PARTIAL1` phase, the `GroupNEvaluator` packs the input values and parameters into a struct of the form `{ "n": 3, "order": 1, "data": [1,2,3,4,5] }` and emits it. The `PARTIAL2` and `FINAL` phases take that same structure as input and, with `order = 1` meaning descending, finally emit the Top 3 set `[5,4,3]`.
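That end-to-end behavior can be checked in isolation with the queue from section 3 (a hypothetical test harness, not part of the original project):

```java
public class GroupTopNBehaviorDemo {
    public static void main(String[] args) {
        // Mirrors { "n": 3, "order": 1, "data": [1,2,3,4,5] } from the example above
        FixSizedPriorityQueue<Integer> queue = new FixSizedPriorityQueue<>(3, 1);
        for (int v : new int[]{1, 2, 3, 4, 5}) {
            queue.add(v);
        }
        System.out.println(queue.toOrderArray(1)); // prints: [5, 4, 3]
    }
}
```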
5. Test Results

- `add jar /Users/yizhou/dtstack/codes/udaf/target/udaf-1.0-SNAPSHOT.jar;` — add the jar to the session
- `create temporary function group_n as 'com.func.udaf.GroupTopN';` — register the temporary function `group_n`
- `select job, group_n(sal, 3, 1) as top_sal_3 from emp group by job;` — run the query
The final result:
job | top_sal_3 |
---|---|
CLERK | [1300.0, 1100.0, 950.0] |
SALESMAN | [1600.0, 1500.0, 1250.0] |
References
- 分组取出每组数据的前N条
- GenericUDAFCaseStudy
- Hive编程开发(2)
- Hive UDAF开发详解