Implementing `Group Top N` in Hive with a UDAF

In statistical data analysis, grouping rows and taking the top N rows of each group is a classic exercise. As an example, consider the table emp below:

| empno | ename  | job      | sal    |
| ----- | ------ | -------- | ------ |
| 7369  | SMITH  | CLERK    | 800.0  |
| 7876  | SMITH  | CLERK    | 1100.0 |
| 7900  | JAMES  | CLERK    | 950.0  |
| 7934  | MILLER | CLERK    | 1300.0 |
| 7499  | ALLEN  | SALESMAN | 1600.0 |
| 7654  | MARTIN | SALESMAN | 1250.0 |
| 7844  | TURNER | SALESMAN | 1500.0 |
| 7521  | WARD   | SALESMAN | 1250.0 |

This is usually implemented as follows:

  • Assign each row a number within its group, then filter with a WHERE clause, for example keeping the three highest salaries per job:

    SELECT * FROM (
      SELECT empno
      , ename
      , sal
      , job
      , ROW_NUMBER() OVER (PARTITION BY job ORDER BY sal DESC) AS rn
      FROM emp
    ) tmp
    WHERE rn <= 3;
    

Beyond this approach, we would like to implement the same thing with a UDAF.


1. UDAF Overview

User-Defined Aggregation Functions (UDAFs) are an excellent way to integrate advanced data-processing into Hive. Hive allows two varieties of UDAFs: simple and generic. Simple UDAFs, as the name implies, are rather simple to write, but incur performance penalties because of the use of Java Reflection, and do not allow features such as variable-length argument lists. Generic UDAFs allow all these features but are perhaps not quite as intuitive to write as Simple UDAFs.

As the official documentation explains, UDAFs come in two varieties: simple and generic. Simple UDAFs pay a performance penalty for their use of Java reflection, do not support features such as variable-length argument lists, and are effectively deprecated. The rest of this post therefore implements `group top N` as a generic UDAF.

2. Extending AbstractGenericUDAFResolver

Extending the AbstractGenericUDAFResolver class is the approach the official docs recommend, though it takes somewhat more development effort. The overall workflow is:

  • Override the getEvaluator method, which validates the UDAF's argument types.
  • Create a public static class that extends GenericUDAFEvaluator; its overridden methods together form the UDAF's processing logic. Seven methods must be implemented:

| Function | Purpose | Return |
| --- | --- | --- |
| init | declares the input/output formats (ObjectInspectors) for each stage | ObjectInspector |
| getNewAggregationBuffer | creates the object that holds the intermediate aggregation state | AggregationBuffer |
| iterate | map phase: iterates over the column values passed in from the SQL call | void |
| terminatePartial | called when the map/combiner phase ends; returns the partial aggregation result | Object |
| merge | the combiner merges map output, and the reducer merges mapper/combiner output | void |
| terminate | reduce phase: emits the final result | Object |
| reset | resets the aggregation state | void |

The code below is the basic skeleton of a UDAF:

public class GenericUDAFHistogramNumeric extends AbstractGenericUDAFResolver {
  static final Log LOG = LogFactory.getLog(GenericUDAFHistogramNumeric.class.getName());
 
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
    // Type-checking goes here!
    return new GenericUDAFHistogramNumericEvaluator();
  }
 
  public static class GenericUDAFHistogramNumericEvaluator extends GenericUDAFEvaluator {
    // UDAF logic goes here!
  }
}

Hive SQL is ultimately translated into MapReduce jobs, and the methods above are called in step with those jobs. To identify how far a job has progressed, the GenericUDAFEvaluator class contains an enum, Mode, that marks the current execution stage:

public static enum Mode {
    /**
     * PARTIAL1: from original data to partial aggregation data: iterate() and
     * terminatePartial() will be called.
     */
    PARTIAL1,
    /**
     * PARTIAL2: from partial aggregation data to partial aggregation data:
     * merge() and terminatePartial() will be called.
     */
    PARTIAL2,
    /**
     * FINAL: from partial aggregation to full aggregation: merge() and
     * terminate() will be called.
     */
    FINAL,
    /**
     * COMPLETE: from original data directly to full aggregation: iterate() and
     * terminate() will be called.
     */
    COMPLETE
  };

Specifically:

  • `PARTIAL1` corresponds to the map phase. The framework calls `iterate` on every raw input row, then `terminatePartial` to produce the map-side partial aggregate.
  • `PARTIAL2` corresponds to the combiner phase. It calls `merge` and then `terminatePartial` to further condense the map output.
  • `FINAL` corresponds to the reduce phase. It calls `merge` and then `terminate` to assemble the final result.
  • `COMPLETE` denotes a map-only job with no reduce task. It calls `iterate` and then `terminate` directly to obtain the final result. The stages and their method pairs line up as follows:
    Mode         input method       output method
    ----------------------------------------------
    PARTIAL1     iterate      -->   terminatePartial
    PARTIAL2     merge        -->   terminatePartial
    FINAL        merge        -->   terminate
    COMPLETE     iterate      -->   terminate

Thus raw column data appears only in the `PARTIAL1` and `COMPLETE` stages. `terminatePartial` is called only in `PARTIAL1` and `PARTIAL2`, which is precisely why the map output type and the combiner output type must be identical; `merge` is called only in `PARTIAL2` and `FINAL`, meaning the combiner and the reducer consume the same input type. Finally, every one of the four stages begins by calling `init` to declare its inputs and outputs, which is why `init` receives the Mode enum: it tells the evaluator which stage it is running in.

3. Implementing Top N

Selecting the top N of an arbitrary sequence, in ascending or descending order, is a textbook use case for a priority queue, and that is what we use here: FixSizedPriorityQueue wraps the JDK's java.util.PriorityQueue into a fixed-capacity queue with a selectable ascending/descending order.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

import org.apache.hadoop.hive.ql.util.JavaDataModel;

public class FixSizedPriorityQueue<E extends Comparable<E>> {

    PriorityQueue<E> queue;

    private int maxSize = 1;

    private int orderFlag = 0;

    public FixSizedPriorityQueue() {

    }

    public FixSizedPriorityQueue(int maxSize, int orderFlag) {
        if (maxSize <= 0) {
            throw new IllegalArgumentException("the size of queue must be larger than zero");
        }
        this.maxSize = maxSize;
        this.orderFlag = orderFlag;
        if (orderFlag == 0) {
            this.queue = new PriorityQueue<>(maxSize, desc); // keep the N smallest elements
        }
        if (orderFlag == 1) {
            this.queue = new PriorityQueue<>(maxSize, asc);  // keep the N largest elements
        }
    }

    private Comparator<E> desc = new Comparator<E>() { // max-heap: the head is the largest retained element
        @Override
        public int compare(E o1, E o2) {
            return o2.compareTo(o1);
        }
    };

    private Comparator<E> asc = new Comparator<E>() {  // min-heap: the head is the smallest retained element
        @Override
        public int compare(E o1, E o2) {
            return o1.compareTo(o2);
        }
    };

    public int getMaxSize() {
        return maxSize;
    }

    public void setMaxSize(int maxSize) {
        this.maxSize = maxSize;
    }

    public int getOrderFlag() {
        return orderFlag;
    }

    public void setOrderFlag(int orderFlag) {
        this.orderFlag = orderFlag;
    }

    public void add(E e) {
        if (queue.size() < maxSize) {
            queue.add(e);
        } else {
            E peek = queue.peek();
            if (orderFlag == 0) {
                if (e.compareTo(peek) < 0) { // evict the head, i.e. the largest of the kept elements
                    queue.poll();
                    queue.add(e);
                }
            }
            if (orderFlag == 1) {
                if (e.compareTo(peek) > 0) { // evict the head, i.e. the smallest of the kept elements
                    queue.poll();
                    queue.add(e);
                }
            }
        }
    }

    private List<E> toArray() {
        List<E> list = new ArrayList<>();
        list.addAll(queue);
        return list;
    }

    public int getLength() {
        return queue.size();
    }

    public void reset() {
        queue = null;
        maxSize = 1;
        orderFlag = 0;
    }

    public List<E> toOrderArray(int orderFlag) {
        List<E> list = toArray();
        if (orderFlag == 0) {
            list.sort(asc);
        } else {
            list.sort(desc);
        }
        return list;
    }

    public int lengthFor(JavaDataModel model) {
        int length = model.object();
        length += model.primitive1();      // maxSize
        length += model.primitive1();      // orderFlag
        length += model.object() * 2;      // the two comparators
        if (maxSize > 0) {
            length += model.arrayList();   // backing list
            length += maxSize * model.object();
        }
        return length;
    }
}
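Before wiring the queue into the UDAF, it is worth a quick local sanity check. A minimal sketch (the demo class name is mine; the values are the SALESMAN salaries from emp above):

    public class FixSizedPriorityQueueDemo {
        public static void main(String[] args) {
            // keep the 3 largest values (orderFlag = 1)
            FixSizedPriorityQueue<Double> top3 = new FixSizedPriorityQueue<>(3, 1);
            for (double sal : new double[]{1600.0, 1250.0, 1500.0, 1250.0}) {
                top3.add(sal);
            }
            System.out.println(top3.toOrderArray(1)); // prints [1600.0, 1500.0, 1250.0]
        }
    }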

4. Building the UDAF

The function is called as `group_n(col, n, 0)`, where the first argument `col` is the table column to aggregate, the second argument `n` is how many rows to keep per group (top N), and the third argument selects the order: `0` for ascending, `1` for descending. If the third argument is omitted, it defaults to `0`.
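For example, run against the emp table from the beginning (a usage sketch, assuming the function has been registered as described in section 5 below):

    -- three highest salaries per job
    SELECT job, group_n(sal, 3, 1) FROM emp GROUP BY job;
    -- three lowest salaries per job; the third argument defaults to 0
    SELECT job, group_n(sal, 3) FROM emp GROUP BY job;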

Inside the GroupTopN class, GroupNAggBuffer temporarily stores the parameters `n` and `0/1`, together with a FixSizedPriorityQueue and a List. This aggregation buffer is what terminatePartial and terminate operate on at each stage of the MapReduce job.

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings({"deprecation", "rawtypes", "unchecked"})
@Description(name = "group_n", value = "_FUNC_(x, n, 0/1) -- 0: asc; 1: desc. Returns an array containing the top N elements of a set of numbers")
public class GroupTopN extends AbstractGenericUDAFResolver {

    static final Logger LOG = LoggerFactory.getLogger(GroupTopN.class);

    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        if (info.length < 2 || info.length > 3) {
            throw new UDFArgumentTypeException(info.length - 1,
                    "Two or three arguments are expected.");
        }
        // validate the first parameter, which is the expression to compute over
        if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                                               "Only primitive type arguments are accepted but "
                                                       + info[0].getTypeName() + " was passed as parameter 1.");
        }
        switch (((PrimitiveTypeInfo) info[0]).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
            case TIMESTAMP:
            case DECIMAL:
            case STRING:
                break;
            case BOOLEAN:
            case DATE:
            default:
                throw new UDFArgumentTypeException(0,
                                                   "Only numeric, timestamp, decimal or string type arguments are accepted but "
                                                           + info[0].getTypeName() + " was passed as parameter 1.");
        }

        if (((PrimitiveTypeInfo) info[1]).getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.INT) {
            throw new UDFArgumentTypeException(1,
                    "Only an integer argument is accepted as parameter 2, but "
                            + info[1].getTypeName() + " was passed instead.");
        }
        if (info.length == 3 && ((PrimitiveTypeInfo) info[2]).getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.INT) {
            throw new UDFArgumentTypeException(2,
                    "Only an integer argument is accepted as parameter 3, but "
                            + info[2].getTypeName() + " was passed instead.");
        }
        return new GroupNEvaluator();
    }

    public static class GroupNEvaluator extends GenericUDAFEvaluator {

        // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
        private transient PrimitiveObjectInspector inputOi;
        private transient PrimitiveObjectInspector nthOi;
        private transient PrimitiveObjectInspector orderOi;

        // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list of element)
        protected transient Object[] partialResult;
        private transient StructObjectInspector soi;
        private transient StructField nthField;
        private transient StructField orderField;
        private transient StructField listField;
        private transient ListObjectInspector loi;

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);

            // init input object inspectors
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {  // PARTIAL1 and COMPLETE both begin with iterate(): parameters are the raw SQL arguments
                inputOi = (PrimitiveObjectInspector) parameters[0];
                nthOi = (PrimitiveObjectInspector) parameters[1];
                if (parameters.length == 2) {
                    orderOi = PrimitiveObjectInspectorFactory.getPrimitiveObjectInspectorFromClass(Integer.class);
                }
                if (parameters.length == 3) {
                    orderOi = (PrimitiveObjectInspector) parameters[2];
                }
            } else {    // PARTIAL2 and FINAL both begin with merge(): parameters[0] is the partial struct
                soi = (StructObjectInspector) parameters[0];
                nthField = soi.getStructFieldRef("n");
                orderField = soi.getStructFieldRef("order");
                listField = soi.getStructFieldRef("data");
                nthOi = (PrimitiveObjectInspector) nthField.getFieldObjectInspector();
                orderOi = (PrimitiveObjectInspector) orderField.getFieldObjectInspector();
                loi = (ListObjectInspector) listField.getFieldObjectInspector();
            }

            // init output object inspectors
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {  // PARTIAL1 and PARTIAL2 both end with terminatePartial(): output the {n, order, data} struct
                ArrayList<ObjectInspector> foi = new ArrayList<>();
                foi.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                if (m == Mode.PARTIAL1) {
                    foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
                            PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputOi.getTypeInfo())));
                }
                if (m == Mode.PARTIAL2) {
                    foi.add(ObjectInspectorFactory.getStandardListObjectInspector(loi.getListElementObjectInspector()));
                }
                ArrayList<String> fname = new ArrayList<>();
                fname.add("n");
                fname.add("order");
                fname.add("data");

                partialResult = new Object[3];
                partialResult[0] = new IntWritable(1);
                partialResult[1] = new IntWritable(0);
                partialResult[2] = new ArrayList<>();
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);

            } else {    // COMPLETE and FINAL both end with terminate(): output a plain list
                if (m == Mode.FINAL) {
                    return ObjectInspectorFactory.getStandardListObjectInspector(
                            loi.getListElementObjectInspector());
                }
                // COMPLETE: no partial struct exists (loi is unset), so the list elements carry the input column's writable type
                return ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputOi.getTypeInfo()));
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            GroupNAggBuffer buffer = new GroupNAggBuffer();
            reset(buffer);
            return buffer;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
            buffer.container.clear();
            buffer.queue = null;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters[0] == null || parameters[1] == null) {
                return;
            }
            if (parameters.length == 3 && parameters[2] == null) {
                return;
            }
            GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
            // record n and order on the buffer only; partialResult is not allocated in COMPLETE
            // mode, and terminatePartial() copies these values into it when it is
            buffer.n = PrimitiveObjectInspectorUtils.getInt(parameters[1], nthOi);

            int order = 0;
            if (parameters.length == 3) {
                order = PrimitiveObjectInspectorUtils.getInt(parameters[2], orderOi);
            }
            if (order != 0 && order != 1) {
                throw new IllegalArgumentException("the order must be '0' or '1' ");
            }
            buffer.order = order;

            Object v = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputOi);
            buffer.container.add(v);
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
            ((IntWritable) partialResult[0]).set(buffer.n);
            ((IntWritable) partialResult[1]).set(buffer.order);
            partialResult[2] = buffer.container;
            return partialResult;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
            int n = ((IntWritable) soi.getStructFieldData(partial, nthField)).get();
            int order = ((IntWritable) soi.getStructFieldData(partial, orderField)).get();
            List<?> partialGroupn = loi.getList(soi.getStructFieldData(partial, listField));

            buffer.n = n;
            buffer.order = order;
            if (buffer.queue == null) { // merge() runs once per partial; re-creating the queue here would discard earlier partials
                buffer.queue = new FixSizedPriorityQueue(n, order);
            }
            for (Object ele : partialGroupn) {
                if (ele instanceof Comparable) {
                    buffer.queue.add((Comparable) ele);
                }
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            GroupNAggBuffer buffer = (GroupNAggBuffer) agg;
            if (buffer.queue == null) { // COMPLETE mode: merge() never ran, so drain the raw container here
                if (buffer.container.isEmpty()) {
                    return null; // SQL standard - return null for zero elements
                }
                buffer.queue = new FixSizedPriorityQueue(buffer.n, buffer.order);
                for (Object ele : buffer.container) {
                    if (ele instanceof Comparable) {
                        buffer.queue.add((Comparable) ele);
                    }
                }
            }
            if (buffer.queue.getLength() < 1) { // SQL standard - return null for zero elements
                return null;
            }
            return buffer.queue.toOrderArray(buffer.order);
        }

        @AggregationType(estimable = true)
        static class GroupNAggBuffer extends AbstractAggregationBuffer {
            int n;
            int order;
            FixSizedPriorityQueue queue; // created lazily in merge(), or in terminate() for COMPLETE mode
            List<Object> container;

            GroupNAggBuffer() {
                container = new ArrayList<>();
            }

            @Override
            public int estimate() {
                return JavaDataModel.get().arrayList();
            }
        }
    }
}
 
 

In the PARTIAL1 stage, the GroupNEvaluator packages the input rows and parameters into a structure of the form `{ "n": 3, "order": 1, "data": [1,2,3,4,5] }` and emits it. The PARTIAL2 and FINAL stages take that structure as input; with `order = 1` (descending), the final output for this example is `[5,4,3]`, the top-3 set.
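As a concrete trace of this flow, suppose the SALESMAN rows land on two map tasks (the split itself is hypothetical) and we call `group_n(sal, 3, 1)`:

    map task 1 -> terminatePartial: { "n": 3, "order": 1, "data": [1600.0, 1250.0] }
    map task 2 -> terminatePartial: { "n": 3, "order": 1, "data": [1500.0, 1250.0] }
    FINAL      -> merge keeps the 3 largest values: 1600.0, 1500.0, 1250.0
    terminate  -> [1600.0, 1500.0, 1250.0]

which matches the test output in the next section.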

5. Test Results

  • Add the jar resource: `add jar /Users/yizhou/dtstack/codes/udaf/target/udaf-1.0-SNAPSHOT.jar;`
  • Register the temporary function group_n: `create temporary function group_n as 'com.func.udaf.GroupTopN';`
  • Run it: `select job, group_n(sal, 3, 1) from emp group by job;`

The final result:

| job | top_sal_3 |
| --- | --- |
| CLERK | [1300.0, 1100.0, 950.0] |
| SALESMAN | [1600.0, 1500.0, 1250.0] |

References

  • 分组取出每组数据的前N条
  • GenericUDAFCaseStudy
  • Hive编程开发(2)
  • Hive UDAF开发详解
