Hive: Custom Aggregate Functions (UDAF)

---- These notes are compiled from Programming Hive (《Hive编程指南》), Section 13.9, "User-Defined Aggregate Functions".
 

I. A custom aggregate function (the GenericUDAFAverage implementation)

1. An aggregate function takes zero or more rows (with zero or more columns as arguments) and returns a single value; it is most often used together with a GROUP BY clause.

    Examples: sum(col), avg(col), max(col), std(col), and so on.
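
    For example, a built-in aggregate used together with GROUP BY might look like this (the table and column names are only an illustration, not from this post):

    select dept, avg(salary) from employees group by dept;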

2. A generic custom aggregate function, GenericUDAFAverage(column), is implemented in the Java code below:

    Purpose: compute the average of a column of values.

    Source code: http://svn.apache.org/repos/asf/hive/branches/branch-0.8/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java

    Note: the usage of the various ObjectInspector sub-interfaces and implementations was summarized in an earlier post, "Notes on the Hive ObjectInspector interface"; in particular, section 7 there ("Using an ObjectInspector to parse Object data") is helpful for understanding the code below.
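
    As a quick reminder of that pattern, the two kinds of parsing used in the code below look roughly like this (an illustrative fragment excerpted from the evaluator below, not a standalone program):

    // A primitive OI turns the opaque Object handed to iterate() back into a Java value:
    double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);

    // A struct OI plus its StructField handles unpack a partial result inside merge():
    Object partialCount = soi.getStructFieldData(partial, countField);
    long count = countFieldOI.get(partialCount);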

package com.hive.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFAverage.
 *
 */
@Description(name = "myavg", value = "_FUNC_(x) - Returns the mean of a set of numbers")
public class GenericUDAFAverage extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFAverage.class.getName());

  //Validate the argument types; if they are acceptable, return the evaluator object that performs the aggregation
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }

    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case STRING:
    case TIMESTAMP:
      return new GenericUDAFAverageEvaluator();
    case BOOLEAN:
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
  }

  /**
   * GenericUDAFAverageEvaluator.
   * The static inner class that does the actual processing; it extends the
   * GenericUDAFEvaluator abstract class.
   */
  public static class GenericUDAFAverageEvaluator extends GenericUDAFEvaluator {

    //1.1. Global ObjectInspector instances for the input and output data, used to parse the values
    // input For PARTIAL1 and COMPLETE
    PrimitiveObjectInspector inputOI;

    // input For PARTIAL2 and FINAL
    // output For PARTIAL1 and PARTIAL2
    StructObjectInspector soi;
    StructField countField;
    StructField sumField;
    LongObjectInspector countFieldOI;
    DoubleObjectInspector sumFieldOI;

    //1.2. Global output holders that store the actual output values
    // output For PARTIAL1 and PARTIAL2
    Object[] partialResult;

    // output For FINAL and COMPLETE
    DoubleWritable result;

    /*
     * init(): for each Mode, capture the ObjectInspector of the input data and return
     * the ObjectInspector of the output data. init() is called once for every Mode.
     * 1. The parameters argument:
     *    1.1. For PARTIAL1 and COMPLETE it describes the original data (a single value).
     *         It becomes the input OI used by iterate(): a PrimitiveObjectInspector for the
     *         column's actual type (e.g. a WritableDoubleObjectInspector for a double column),
     *         used to parse the incoming argument values.
     *    1.2. For PARTIAL2 and FINAL it describes the partial aggregation data (count and sum).
     *         It becomes the input OI used by merge(): a StandardStructObjectInspector
     *         (an implementation of StructObjectInspector), used to parse the partial results.
     * 2. The returned ObjectInspector:
     *    2.1. For PARTIAL1 and PARTIAL2 it describes the return value of terminatePartial():
     *         a StandardStructObjectInspector (an implementation of StructObjectInspector).
     *    2.2. For FINAL and COMPLETE it describes the return value of terminate():
     *         a WritableDoubleObjectInspector (an implementation of PrimitiveObjectInspector).
     */
    @Override
    public ObjectInspector init(Mode mode, ObjectInspector[] parameters)
        throws HiveException {
      assert (parameters.length == 1);
      super.init(mode, parameters);

      // init input
      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
        inputOI = (PrimitiveObjectInspector) parameters[0];
      } else {
        //When partial aggregation data is the input, the struct OI below describes it and is used to parse the values
        soi = (StructObjectInspector) parameters[0];
        countField = soi.getStructFieldRef("count");
        sumField = soi.getStructFieldRef("sum");
        //Each field of the struct is parsed with its own primitive OI
        countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
        sumFieldOI = (DoubleObjectInspector) sumField.getFieldObjectInspector();
      }

      // init output
      if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
        // The output of a partial aggregation is a struct containing
        // a "long" count and a "double" sum.
        //The partial aggregation result is an Object array: {count, sum}
        partialResult = new Object[2];
        partialResult[0] = new LongWritable(0);
        partialResult[1] = new DoubleWritable(0);
        /*
         * Build the struct OI that describes the partial result array.
         * It is constructed from a list of field names and a list of field OIs.
         */
        ArrayList<String> fname = new ArrayList<String>();
        fname.add("count");
        fname.add("sum");
        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
        //Note: these two OIs describe the two elements of partialResult[], so the types must match
        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
      } else {
        //For FINAL and COMPLETE the final result is a single number, described by a primitive OI
        result = new DoubleWritable(0);
        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
      }
    }

    /*
     * The aggregation buffer that caches intermediate results.
     */
    static class AverageAgg implements AggregationBuffer {
      long count;
      double sum;
    };

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      AverageAgg result = new AverageAgg();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      AverageAgg myagg = (AverageAgg) agg;
      myagg.count = 0;
      myagg.sum = 0;
    }

    boolean warned = false;

    /*
     * Iterate over the original rows.
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      assert (parameters.length == 1);
      Object p = parameters[0];
      if (p != null) {
        AverageAgg myagg = (AverageAgg) agg;
        try {
          //Parse the value of Object p using the primitive input OI
          double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
          myagg.count++;
          myagg.sum += v;
        } catch (NumberFormatException e) {
          if (!warned) {
            warned = true;
            LOG.warn(getClass().getSimpleName() + " "
                + StringUtils.stringifyException(e));
            LOG.warn(getClass().getSimpleName()
                + " ignoring similar exceptions.");
          }
        }
      }
    }

    /*
     * Emit the partial aggregation result.
     */
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      AverageAgg myagg = (AverageAgg) agg;
      ((LongWritable) partialResult[0]).set(myagg.count);
      ((DoubleWritable) partialResult[1]).set(myagg.sum);
      return partialResult;
    }

    /*
     * Merge a partial aggregation result into the buffer.
     * Note: partial is the {count, sum} struct produced by terminatePartial(); it is handed
     * in as a plain Object and unpacked with the struct OI.
     */
    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      if (partial != null) {
        AverageAgg myagg = (AverageAgg) agg;
        //Use the struct OI (StandardStructObjectInspector) to pull the field values out of partial
        Object partialCount = soi.getStructFieldData(partial, countField);
        Object partialSum = soi.getStructFieldData(partial, sumField);
        //Parse each field's Object value with its primitive OI
        myagg.count += countFieldOI.get(partialCount);
        myagg.sum += sumFieldOI.get(partialSum);
      }
    }

    /*
     * Produce the final aggregation result.
     */
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      AverageAgg myagg = (AverageAgg) agg;
      if (myagg.count == 0) {
        return null;
      } else {
        result.set(myagg.sum / myagg.count);
        return result;
      }
    }
  }

}

II. Code walkthrough

1. The processing modes of an aggregate function (org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode):

public abstract class GenericUDAFEvaluator implements Closeable {

  public static enum Mode {
    /**
     * PARTIAL1: from original data to partial aggregation data: iterate() and
     * terminatePartial() will be called.
     * Roughly corresponds to the map phase of a MapReduce job.
     */
    PARTIAL1,
    /**
     * PARTIAL2: from partial aggregation data to partial aggregation data:
     * merge() and terminatePartial() will be called.
     * Roughly corresponds to the combine phase of a MapReduce job.
     */
    PARTIAL2,
    /**
     * FINAL: from partial aggregation to full aggregation: merge() and
     * terminate() will be called.
     * Roughly corresponds to the reduce phase of a MapReduce job.
     */
    FINAL,
    /**
     * COMPLETE: from original data directly to full aggregation: iterate() and
     * terminate() will be called.
     * Roughly corresponds to a map-only job (no reduce phase).
     */
    COMPLETE
  };

}

2. Code structure:

1) Extend the AbstractGenericUDAFResolver abstract class and override getEvaluator(TypeInfo[] parameters).

2) Write a static inner class that extends the GenericUDAFEvaluator abstract class, overrides init(), and implements getNewAggregationBuffer(), reset(), iterate(), terminatePartial(), merge() and terminate().

3. Execution flow:

1) PARTIAL1 (phase 1, map): init() --> iterate() --> terminatePartial()

2) PARTIAL2 (phase 2, combine): init() --> merge() --> terminatePartial()

3) FINAL (final phase, reduce): init() --> merge() --> terminate()

4) COMPLETE (direct output, map only): init() --> iterate() --> terminate()

Note: init() is called at the start of every phase.
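
A concrete walk-through of these phases for myavg (the values are invented purely for illustration): suppose the column holds 1, 2, 3, 4, 5 and the rows are split across two mappers.

1) Mapper 1 (PARTIAL1): iterate() consumes 1, 2, 3 and terminatePartial() emits the struct {count=3, sum=6.0}.

2) Mapper 2 (PARTIAL1): iterate() consumes 4, 5 and terminatePartial() emits {count=2, sum=9.0}.

3) Reducer (FINAL): merge() combines the two partials into {count=5, sum=15.0}, and terminate() returns 15.0 / 5 = 3.0.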

 

III. Package the JAR and run a test

1. Right-click the com.hive.udaf package and export it as a JAR file named "myUDAF.jar".
2. Upload the JAR to the Linux server from Windows using cmd or PowerShell (recommended).
   Run the following from the directory containing the JAR:
   > scp myUDAF.jar root@remoteIP:~/myJars/hive/
   (where remoteIP is the IP address of the remote server)
3. Start Hadoop and Hive, then at the hive> prompt add the JAR (only an absolute path is accepted):
   > add jar /root/myJars/hive/myUDAF.jar
   <Hive reports that the JAR has been added to the class path>
4. Register a temporary or permanent function:
   > create temporary function myavg as 'com.hive.udaf.GenericUDAFAverage';   (temporary: valid for the current session)
   > create function myavg as 'com.hive.udaf.GenericUDAFAverage';   (permanent)
5. Run a test:
   > select myavg(id) from data;
   4
   <this is the average produced by the aggregation>
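
   As an extra check, the custom function can be compared against the built-in avg() on the same column, and the temporary function can be dropped once the session is done (both are standard Hive statements):
   > select avg(id), myavg(id) from data;
   > drop temporary function myavg;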

IV. A second custom aggregate function: Concat(col)

Purpose: row-to-column conversion, i.e. concatenate all values of a column into a single comma-separated line.
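
For example, if the column contains the three rows a, b and c, Concat(col) returns the single string "a,b,c" (the order of the values is not guaranteed, since it depends on how Hive distributes the rows).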

The code is shown below:

package com.hive.udaf;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/*
 * Row-to-column concatenation.
 */
@Description(name = "mycolconcat", value = "_FUNC_(x) - Returns the concat of a set of cols")
public class ConcatUDAF extends AbstractGenericUDAFResolver{
	
	static final Log LOG = LogFactory.getLog(ConcatUDAF.class.getName());
	
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) 
			throws SemanticException {
		if (parameters.length != 1) {
		      throw new UDFArgumentTypeException(parameters.length - 1,
		          "Exactly one argument is expected.");
	    }

	    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
	      throw new UDFArgumentTypeException(0,
	          "Only primitive type arguments are accepted but "
	          + parameters[0].getTypeName() + " is passed.");
	    }
	    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
	    case BYTE:
	    case SHORT:
	    case INT:
	    case LONG:
	    case FLOAT:
	    case DOUBLE:
	    case STRING:
	    case TIMESTAMP:
	    	return new ConcatUDAFEvaluator();
	    case BOOLEAN:
	    default:
	    	throw new UDFArgumentTypeException(0,
	    			"Only numeric or string type arguments are accepted but "
	    					+ parameters[0].getTypeName() + " is passed.");
	    }
	}
	
	public static class ConcatUDAFEvaluator extends GenericUDAFEvaluator {
		
		//For every Mode both the input and the output are strings, so the same kind of OI is used throughout
		PrimitiveObjectInspector inputOI;
		
		Text partialResult;
		
		Text result;
		
		@Override
	    public ObjectInspector init(Mode mode, ObjectInspector[] parameters)
	        throws HiveException {
			assert (parameters.length == 1);
			super.init(mode, parameters);
			
			// init input
			inputOI = (PrimitiveObjectInspector) parameters[0];
			
			// init output
			result = new Text("");
			return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
	    }
		
		static class ConcatAgg implements AggregationBuffer {
			StringBuilder line = new StringBuilder("");
	    };
		
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			ConcatAgg result = new ConcatAgg();
			reset(result);
			return result;
		}

		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			ConcatAgg myagg = (ConcatAgg) agg;
			myagg.line.delete(0, myagg.line.length());
		}
		
		boolean warned = false;

		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
			Object p = parameters[0];
			if (p != null) {
				ConcatAgg myagg = (ConcatAgg) agg;
				try {
					String v = PrimitiveObjectInspectorUtils.getString(p, inputOI);
					if (myagg.line.length() == 0)
						myagg.line.append(v);
					else
						myagg.line.append("," + v);
				} catch (RuntimeException e) {
					if (!warned) {
						warned = true;
						LOG.warn(getClass().getSimpleName() + " "
								+ StringUtils.stringifyException(e));
						LOG.warn(getClass().getSimpleName()
								+ " ignoring similar exceptions.");
					}
				}
			}
		}

		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			ConcatAgg myagg = (ConcatAgg) agg;
			result.set(myagg.line.toString());
			return result;
		}

		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			if (partial != null) {
				try {
					ConcatAgg myagg = (ConcatAgg) agg;
					String v = PrimitiveObjectInspectorUtils.getString(partial, inputOI);
					if (myagg.line.length() == 0)
						myagg.line.append(v);
					else
						myagg.line.append("," + v);
				} catch (RuntimeException e) {
					if (!warned) {
						warned = true;
						LOG.warn(getClass().getSimpleName() + " "
								+ StringUtils.stringifyException(e));
						LOG.warn(getClass().getSimpleName()
								+ " ignoring similar exceptions.");
					}
				}
			}
		}

		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			ConcatAgg myagg = (ConcatAgg) agg;
			result.set(myagg.line.toString());
			return result;
		}
	}
	
}
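
Packaging and registration follow the same steps as in Part III. Assuming the class is exported into the same myUDAF.jar, a test session might look like this (the table name data and the column name name are only assumed for the example):

  > add jar /root/myJars/hive/myUDAF.jar
  > create temporary function mycolconcat as 'com.hive.udaf.ConcatUDAF';
  > select mycolconcat(name) from data;

The query returns all values of the column joined by commas on a single line.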

 
