Hive UDF小结

HiveUDF简介:

1)Hive中用于扩展HiveSQL功能的用户自定义函数称为HiveUDF
2)UDF又分为UDAF(用户自定义聚合函数),UDTF(用户自钉子表生成函数)

Hive内置函数

实际上Hive内置了很多函数,包括关系/算数/逻辑操作符都属于函数

hive提供的build-in函数包括以下几类:
1. 关系操作符:包括 = 、 <> 、 <= 、>=等
2. 算数操作符:包括 + 、 - 、 *、/等
3. 逻辑操作符:包括AND 、 && 、 OR 、 || 等
4. 复杂类型构造函数:包括map、struct、create_union等
5. 复杂类型操作符:包括A[n]、Map[key]、S.x
6. 数学操作符:包括ln(double a)、sqrt(double a)等
7. 集合操作符:包括size(Array)、sort_array(Array)等
8. 类型转换函数: binary(string|binary)、cast(expr as )
9. 日期函数:包括from_unixtime(bigint unixtime[, string format])、unix_timestamp()等
10.条件函数:包括if(boolean testCondition, T valueTrue, T valueFalseOrNull)等
11. 字符串函数:包括acat(string|binary A, string|binary B…)等
12. 其他:xpath、get_json_objectscii(string str)、con

Hive UDF的实现机制

1)Hive UDF实际上是一个Java类,开发UDF也是实现指定接口开发一个类,该类需要实现规定的方法,Hive引擎依据规则会将这些方法解析成MR任务,实现需求
2)UDF的信息存储在HiveMetaStore中,Hive添加,删除,使用UDF是通过操作metastore实现的
3)源码对应的类是FunctionRegistry,FunctionRegistry注册了所有的内置自定义函数

HiveUDF添加

HiveUDF添加主要分成两种方式
1,重编译Hive源码添加:添加UDF可以在Hive源码中增加新的UDF类,然后在一个FunctionRegistry类中注册,重编译Hive然后使用
2,通过命令行添加:独立开发UDF,将UDF打包成jar,通过Hive命令行添加到系统中(实际上是调用了FunctionRegistry中的方法),于是产生了第三发UDF项目,如[Brickhouse](https://github.com/klout/brickhouse)
添加又分为临时和永久,临时适合于测试,当会话结束,函数会消失,永久添加会稳定添加到Hivemetastore中,重启会话然存在。

1,永久
create function function_name AS ‘full.class.name’ using JAR ‘jar/absolute/path’;
2,临时
ADD JAR hdfs://hive/warehouse/udf/brickhouse-0.7.1-SNAPSHOT.jar;
CREATE TEMPORARY FUNCTION to_json AS ‘brickhouse.udf.json.ToJsonUDF’;
CREATE TEMPORARY FUNCTION combine_unique AS ‘brickhouse.udf.collect.CombineUniqueUDAF’;

HiveUDF开发接口

1.UDF

Hive有两个不同的接口编写UDF程序。一个是基础的UDF接口,一个是复杂的GenericUDF接口。
org.apache.hadoop.hive.ql. exec.UDF 基础UDF的函数读取和返回基本类型,即Hadoop和Hive的基本类型。如,Text、IntWritable、LongWritable、DoubleWritable等。
org.apache.hadoop.hive.ql.udf.generic.GenericUDF 复杂的GenericUDF可以处理Map、List、Set类型。
@Describtion注解是可选的,用于对函数进行说明,其中的FUNC字符串表示函数名,当使用DESCRIBE FUNCTION命令时,替换成函数名。

@Describtion包含三个属性:

name:用于指定Hive中的函数名。
value:用于描述函数的参数。
extended:额外的说明,如,给出示例。当使用DESCRIBE FUNCTION EXTENDED name的时候打印。

简单UDF的实现很简单,只需要继承UDF,然后实现evaluate()方法就行了。


1. @Description(  
2.     name = "hello",  
3.     value = "_FUNC_(str) - from the input string"  
4.         + "returns the value that is \"Hello $str\" ",  
5.     extended = "Example:\n"  
6.         + " > SELECT _FUNC_(str) FROM src;"  
7. )  
8. public class HelloUDF extends UDF{  
9.       
10.     public String evaluate(String str){  
11.         try {  
12.             return "Hello " + str;  
13.         } catch (Exception e) {  
14.             // TODO: handle exception  
15.             e.printStackTrace();  
16.             return "ERROR";  
17.         }  
18.     }  
19. }  

2.GenericUDF

GenericUDF实现比较复杂,需要先继承GenericUDF。这个API需要操作Object Inspectors,并且要对接收的参数类型和数量进行检查。GenericUDF需要实现以下三个方法:


1. //这个方法只调用一次,并且在evaluate()方法之前调用。该方法接受的参数是一个ObjectInspectors数组。该方法检查接受正确的参数类型和参数个数。  
2. abstract ObjectInspector initialize(ObjectInspector[] arguments);  
3.   
4. //这个方法类似UDF的evaluate()方法。它处理真实的参数,并返回最终结果。  
5. abstract Object evaluate(GenericUDF.DeferredObject[] arguments);  
6.   
7. //这个方法用于当实现的GenericUDF出错的时候,打印出提示信息。而提示信息就是你实现该方法最后返回的字符串。  
8. abstract String getDisplayString(String[] children);  

3.UDTF

用户自定义表生成函数(UDTF)接受零个或多个输入,然后产生多列或多行的输出,如explode()。要实现UDTF,需要继承org.apache.hadoop.hive.ql.udf.generic.GenericUDTF,同时实现三个方法:


1. // 该方法指定输入输出参数:输入的Object Inspectors和输出的Struct。  
2. abstract StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException;   
3.   
4. // 该方法处理输入记录,然后通过forward()方法返回输出结果。  
5. abstract void process(Object[] record) throws HiveException;  
6.   
7. // 该方法用于通知UDTF没有行可以处理了。可以在该方法中清理代码或者附加其他处理输出。  
8. abstract void close() throws HiveException;  

4.UDAF

UDAF是需要在hive的sql语句和group by联合使用,hive的group by对于每个分组,只能返回一条记录。
用户自定义聚合函数(UDAF)接受从零行到多行的零个到多个列,然后返回单一值,如sum()、count()。要实现UDAF,我们需要实现下面的类:

org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator

AbstractGenericUDAFResolver检查输入参数,并且指定使用哪个resolver。在AbstractGenericUDAFResolver里,只需要实现一个方法:

public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException;  

但是,主要的逻辑处理还是在Evaluator中。我们需要继承GenericUDAFEvaluator,并且实现下面几个方法:

1. // 输入输出都是Object inspectors  
2. public  ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;  
3.   
4. // AggregationBuffer保存数据处理的临时结果  
5. abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;  
6.   
7. // 重新设置AggregationBuffer  
8. public void reset(AggregationBuffer agg) throws HiveException;  
9.   
10. // 处理输入记录  
11. public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;  
12.   
13. // 处理全部输出数据中的部分数据  
14. public Object terminatePartial(AggregationBuffer agg) throws HiveException;  
15.   
16. // 把两个部分数据聚合起来  
17. public void merge(AggregationBuffer agg, Object partial) throws HiveException;  
18.   
19. // 输出最终结果  
20. public Object terminate(AggregationBuffer agg) throws HiveException;  

在处理之前,先看下UADF的Enum GenericUDAFEvaluator.Mode。Mode有4中情况:

PARTIAL1:Mapper阶段。从原始数据到部分聚合,会调用iterate()terminatePartial()。
PARTIAL2:Combiner阶段,在Mapper端合并Mapper的结果数据。从部分聚合到部分聚合,会调用merge()terminatePartial()。
FINAL:Reducer阶段。从部分聚合数据到完全聚合,会调用merge()terminate()。
COMPLETE:出现这个阶段,表示MapReduce中只用Mapper没有Reducer,所以Mapper端直接输出结果了。从原始数据到完全聚合,会调用iterate()terminate()

在实现UDAF时,主要实现下面几个方法:

init():当实例化UDAF的Evaluator时执行,并且指定输入输出数据的类型。
iterate():把输入数据处理后放入到内存聚合块中(AggregationBuffer),典型的MapperterminatePartial():其为iterate()轮转结束后,返回轮转数据,类似于Combinermerge():介绍terminatePartial()的结果,然后把这些partial结果数据merge到一起。
terminate():返回最终的结果。
iterate()terminatePartial()都在Mapper端。
merge()terminate()都在Reducer端。
AggregationBuffer存储中间或最终结果。通过我们定义自己的Aggregation Buffer,可以处理任何类型的数据。

UDAF开发流程——以GenericUDAFSum为例

开发通用UDAF有两个步骤,第一个是编写resolver类,第二个是编写evaluator类。resolver负责类型检查,操作符重载。evaluator真正实现UDAF的逻辑。通常来说,顶层UDAF类继承org.apache.hadoop.hive.ql.udf.GenericUDAFResolver2,里面编写嵌套类evaluator 实现UDAF的逻辑。

实现 resolver

resolver通常继承org.apache.hadoop.hive.ql.udf.GenericUDAFResolver2,但是我们更建议继承AbstractGenericUDAFResolver,隔离将来hive接口的变化。

GenericUDAFResolver和GenericUDAFResolver2接口的区别是,后面的允许evaluator实现可以访问更多的信息,例如DISTINCT限定符,通配符FUNCTION(*)。

public class GenericUDAFSum extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFSum.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
    throws SemanticException {
    // Type-checking goes here!
    return new GenericUDAFSumLong(); 
  } 

  public static class GenericUDAFSumLong extends GenericUDAFEvaluator {
    // UDAF logic goes here!
  } 
}

这个就是UDAF的代码骨架,第一行创建LOG对象,用来写入警告和错误到hive的log。GenericUDAFResolver只需要重写一个方法:getEvaluator,它根据SQL传入的参数类型,返回正确的evaluator。这里最主要是实现操作符的重载。
getEvaluator的完整代码如下:

public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
    throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }

    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case TIMESTAMP:
      return new GenericUDAFSumLong();
    case FLOAT:
    case DOUBLE:
    case STRING:
      return new GenericUDAFSumDouble();
    case BOOLEAN:
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }

这里做了类型检查,如果不是原生类型(即符合类型,array,map此类),则抛出异常,还实现了操作符重载,对于整数类型,使用GenericUDAFSumLong实现UDAF的逻辑,对于浮点类型,使用GenericUDAFSumDouble实现UDAF的逻辑。

实现evaluator

所有evaluators必须继承抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator。子类必须实现它的一些抽象方法,实现UDAF的逻辑。
GenericUDAFEvaluator有一个嵌套类Mode,这个类很重要,它表示了udaf在mapreduce的各个阶段,理解Mode的含义,就可以理解了hive的UDAF的运行流程。

public static enum Mode {
    /**
     * PARTIAL1: 这个是mapreduce的map阶段:从原始数据到部分数据聚合
     * 将会调用iterate()和terminatePartial()
     */
    PARTIAL1,
        /**
     * PARTIAL2: 这个是mapreduce的map端的Combiner阶段,负责在map端合并map的数据::从部分数据聚合到部分数据聚合:
     * 将会调用merge() 和 terminatePartial() 
     */
    PARTIAL2,
        /**
     * FINAL: mapreduce的reduce阶段:从部分数据的聚合到完全聚合 
     * 将会调用merge()和terminate()
     */
    FINAL,
        /**
     * COMPLETE: 如果出现了这个阶段,表示mapreduce只有map,没有reduce,所以map端就直接出结果了:从原始数据直接到完全聚合
      * 将会调用 iterate()和terminate()
     */
    COMPLETE
  };

一般情况下,完整的UDAF逻辑是一个mapreduce过程,如果有mapper和reducer,就会经历PARTIAL1(mapper),FINAL(reducer),如果还有combiner,那就会经历PARTIAL1(mapper),PARTIAL2(combiner),FINAL(reducer)。
而有一些情况下的mapreduce,只有mapper,而没有reducer,所以就会只有COMPLETE阶段,这个阶段直接输入原始数据,出结果。

下面以GenericUDAFSumLong的evaluator实现讲解

public static class GenericUDAFSumLong extends GenericUDAFEvaluator {

private PrimitiveObjectInspector inputOI;
    private LongWritable result;

   //这个方法返回了UDAF的返回类型,这里确定了sum自定义函数的返回类型是Long类型
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      assert (parameters.length == 1);
      super.init(m, parameters);
      result = new LongWritable(0);
      inputOI = (PrimitiveObjectInspector) parameters[0];
      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    /** 存储sum的值的类 */
    static class SumLongAgg implements AggregationBuffer {
      boolean empty;
      long sum;
    }

    //创建新的聚合计算的需要的内存,用来存储mapper,combiner,reducer运算过程中的相加总和。

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      SumLongAgg result = new SumLongAgg();
      reset(result);
      return result;
    }
    
    //mapreduce支持mapper和reducer的重用,所以为了兼容,也需要做内存的重用。

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      SumLongAgg myagg = (SumLongAgg) agg;
      myagg.empty = true;
      myagg.sum = 0;
    }

    private boolean warned = false;
  
    //map阶段调用,只要把保存当前和的对象agg,再加上输入的参数,就可以了。
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 1);
      try {
        merge(agg, parameters[0]);
      } catch (NumberFormatException e) {
        if (!warned) {
          warned = true;
          LOG.warn(getClass().getSimpleName() + " "
              + StringUtils.stringifyException(e));
        }
      }
    }
   //mapper结束要返回的结果,还有combiner结束返回的结果
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      return terminate(agg);
    }

    //combiner合并map返回的结果,还有reducer合并mapper或combiner返回的结果。
    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        SumLongAgg myagg = (SumLongAgg) agg;
        myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
        myagg.empty = false;
      }
    }

    //reducer返回结果,或者是只有mapper,没有reducer时,在mapper端返回结果。
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      SumLongAgg myagg = (SumLongAgg) agg;
      if (myagg.empty) {
        return null;
      }
      result.set(myagg.sum);
      return result;
    }

  }

除了GenericUDAFSumLong,还有重载的GenericUDAFSumDouble,以上代码都在hive的源码:org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum。

实例

求余弦函数

import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import java.lang.Math;

public class CosSimilar extends UDAF {
    public static class CosSimilarState {
        private double mSumXY;
        private double mSumXX;
        private double mSumYY;
    }

    public static class CosSimilarEvaluator implements UDAFEvaluator {
        CosSimilarState state;
        public CosSimilarEvaluator() {
            super();
            state = new CosSimilarState();
            init();
        }

        /** * init函数类似于构造函数,用于UDAF的初始化 */

        public void init() {
            state.mSumXY = 0;
            state.mSumXX = 0;
            state.mSumYY = 0;
        }

        /** * iterate接收传入的参数,并进行内部的轮转。其返回类型为boolean * * @param o * @return */

        public boolean iterate(Double x, Double y) {
            state.mSumXY += x * y;
            state.mSumXX += x * x;
            state.mSumYY += y * y;

            return true;
        }

        /** * terminatePartial无参数,其为iterate函数轮转结束后,返回轮转数据, * terminatePartial类似于hadoop的Combiner * * @return */

        public CosSimilarState terminatePartial() {
            // combiner
            return state;
        }

        /** * merge接收terminatePartial的返回结果,进行数据merge操作,其返回类型为boolean * * @param o * @return */

        public boolean merge(CosSimilarState other) {
            if (other != null) {
                state.mSumXY += other.mSumXY;
                state.mSumXX += other.mSumXX;
                state.mSumYY += other.mSumYY;
            }

            return true;
        }

        /** * terminate返回最终的聚集函数结果 * * @return */

        public Double terminate() {
            if (state.mSumXX < 0.0001 || state.mSumYY < 0.0001){
                return 0.0;
            }
            else {
                return Double.valueOf(state.mSumXY / Math.sqrt(state.mSumXX) / Math.sqrt(state.mSumYY));
            }
        }

    }
}

SumTwo

工作需要开发的SumTwo

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.json.JSONException;
import org.json.JSONObject;

import java.math.BigDecimal;
import java.util.*;

/**
 * GenericUDAFSumTwo.
 */
@Description(name = "sum_two", value = "_FUNC_(x) - Returns the sum of two sets of numbers")
public class GenericUDAFSumTwo extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFSumTwo.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
    TypeInfo[] parameters = info.getParameters();
    if (parameters.length != 3) {
      throw new UDFArgumentException("Please specify exactly three arguments.");
    }

    // vid
    // validate the first parameter, which is the number of histogram bins
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
              "Only primitive type arguments are accepted but " + parameters[0].getTypeName()
                      + " was passed as parameter 1.");
    }

    // metric1
    // validate the second parameter, which is the expression to compute over
    if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(1,
              "Only primitive type arguments are accepted but " + parameters[1].getTypeName()
                      + " was passed as parameter 2.");
    }

    // metric2
    // validate the third parameter, which is the expression to compute over
    if (parameters[2].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(2,
              "Only primitive type arguments are accepted but " + parameters[2].getTypeName()
                      + " was passed as parameter 3.");
    }

    switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) {
      case BYTE:
      case SHORT:
      case INT:
      case LONG:
      case FLOAT:
      case DOUBLE:
        break;
      case STRING:
      case BOOLEAN:
      default:
        throw new UDFArgumentTypeException(1,
                "Only numeric type arguments are accepted but " + parameters[1].getTypeName()
                        + " was passed as parameter 2.");
    }

    switch (((PrimitiveTypeInfo) parameters[2]).getPrimitiveCategory()) {
      case BYTE:
      case SHORT:
      case INT:
      case LONG:
      case FLOAT:
      case DOUBLE:
        break;
      case STRING:
      case BOOLEAN:
      default:
        throw new UDFArgumentTypeException(2,
                "Only numeric type arguments are accepted but " + parameters[2].getTypeName()
                        + " was passed as parameter 3.");
    }

    return new GenericUDAFSumTwoEvaluator();
  }

  /**
   * ABTestTwoSumUDAFEvaluator.
   */
  public static class GenericUDAFSumTwoEvaluator extends GenericUDAFEvaluator {
    private String jobname;
    private PrimitiveObjectInspector inputOI1;
    private PrimitiveObjectInspector inputOI2;
    private PrimitiveObjectInspector inputOI3;
    private ObjectInspector outputOI;
    private Mode mode;

    @Override
    public void configure(MapredContext mapredContext) {
      jobname = mapredContext.getJobConf().get("mapreduce.job.name", "abtest_confidence_rate");
      jobname = jobname.replaceAll(" ", "_");
      if (jobname.length() > 100) {
        jobname = "abtest_confidence_rate" + new Random().nextInt((int) (System.currentTimeMillis() % 10000));
      }
    }

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);
      this.mode = m;
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        // vid
        inputOI1 = (PrimitiveObjectInspector) parameters[0];
        // metric1
        inputOI2 = (PrimitiveObjectInspector) parameters[1];
        // metric2
        inputOI3 = (PrimitiveObjectInspector) parameters[2];
      }
      if (m == Mode.FINAL || m == Mode.COMPLETE) {
        outputOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
      } else {
        outputOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
      }
      return outputOI;
    }

    /**
     * class for storing vid string and two sum values.
     */
    static class SumAgg extends AbstractAggregationBuffer {
      String vid;
      Map resultMap = new HashMap<>();
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      SumAgg result = new SumAgg();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      SumAgg myagg = (SumAgg) agg;
      myagg.vid = null;
      myagg.resultMap = new HashMap<>();
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      if (parameters == null) {
        return;
      }
      assert (parameters.length == 3);

      try {
        // vid
        if (parameters[0] == null || inputOI1 == null || inputOI1.getPrimitiveJavaObject(parameters[0]) == null) {
          return;
        }

        // metric1
        double number1;
        Object object2 = inputOI2.getPrimitiveJavaObject(parameters[1]);
        switch (inputOI2.getPrimitiveCategory()) {
          case SHORT:
            number1 = (short) object2;
            break;
          case INT:
            number1 = (int) object2;
            break;
          case LONG:
            number1 = (long) object2;
            break;
          case FLOAT:
            number1 = (float) object2;
            break;
          case DOUBLE:
            number1 = (double) object2;
            break;
          default:
            throw new UDFArgumentTypeException(1, "Only numeric type arguments are accepted");
        }

        //metric2
        double number2;
        Object object3 = inputOI3.getPrimitiveJavaObject(parameters[2]);
        switch (inputOI3.getPrimitiveCategory()) {
          case SHORT:
            number2 = (short) object3;
            break;
          case INT:
            number2 = (int) object3;
            break;
          case LONG:
            number2 = (long) object3;
            break;
          case FLOAT:
            number2 = (float) object3;
            break;
          case DOUBLE:
            number2 = (double) object3;
            break;
          default:
            throw new UDFArgumentTypeException(2, "Only numeric type arguments are accepted");
        }

        String vid = String.valueOf(inputOI1.getPrimitiveJavaObject(parameters[0]));

        if(agg == null){
          agg = new SumAgg();
        }

        SumAgg myAgg = (SumAgg)agg;
        myAgg.vid = vid;

        if (myAgg.resultMap == null) {
          throw new UDFArgumentException("Result map is null");
        }

        myAgg.resultMap.put(vid + "#1", myAgg.resultMap.getOrDefault(vid + "#1", 0.0) + number1);
        myAgg.resultMap.put(vid + "#2", myAgg.resultMap.getOrDefault(vid + "#2", 0.0) + number2);

      } catch (Exception e) {
        throw new HiveException(e);
      }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      return ((SumAgg)agg).resultMap;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {

      if(partial != null){
        Map partailResultMap= ((LazyBinaryMap) partial).getMap();

        SumAgg myAgg = (SumAgg)agg;
        if(myAgg == null){
          myAgg = new SumAgg();
        }

        for (Object nameObj: partailResultMap.keySet()) {
          String key = nameObj.toString();
          Object valueObj = partailResultMap.getOrDefault(nameObj, null);
          if (valueObj != null) {
            myAgg.resultMap.put(key, ((DoubleWritable) valueObj).get() + myAgg.resultMap.getOrDefault(key, 0.0));
          }
        }
      }

    }

    //
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      SumAgg myagg = (SumAgg) agg;
      if (myagg.resultMap == null) {
        return null;
      }
      Map resultMap =  ((SumAgg)agg).resultMap;

      JSONObject jsonObject = new JSONObject();
      try {
        for (String key: resultMap.keySet()) {
          String[] splits = key.split("#");
          String vid = splits[0];
          String uid = splits[1];

          jsonObject.put("vid", vid);
          if(uid.equals("1")){
            jsonObject.put("sum1", new BigDecimal(String.valueOf(resultMap.get(key))).toString());
          }
          else if(uid.equals("2")){
            jsonObject.put("sum2", new BigDecimal(String.valueOf(resultMap.get(key))).toString());
          }
        }
      } catch (JSONException e) {
        e.printStackTrace();
      }

      return jsonObject.toString();
    }

  }


}

你可能感兴趣的:(Hive)