---- These notes are based on Programming Hive (《Hive编程指南》), section 13.9, "User-Defined Aggregate Functions".
1. Aggregate functions: functions that take zero or more columns from zero or more rows as input and return a single value; they are typically used together with the GROUP BY clause.
Examples: sum(col), avg(col), max(col), std(col).
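For instance, an aggregate combined with GROUP BY (the table and column names here are hypothetical):
> select dept, avg(salary) from employees group by dept;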
2. Implementing a generic user-defined aggregate function, GenericUDAFAverage(column), in the Java code below.
Purpose: compute the average of a column's values.
Source link: http://svn.apache.org/repos/asf/hive/branches/branch-0.8/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
Note: the usage of the ObjectInspector sub-interfaces and implementations was summarized in an earlier post, "Notes on the Hive ObjectInspector interface"; in particular, its section 7, "Using an ObjectInspector to parse Object data", helps in understanding the code below.
package com.hive.udaf;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;
/**
* GenericUDAFAverage.
*
*/
@Description(name = "myavg", value = "_FUNC_(x) - Returns the mean of a set of numbers")
public class GenericUDAFAverage extends AbstractGenericUDAFResolver {
static final Log LOG = LogFactory.getLog(GenericUDAFAverage.class.getName());
//Check the types of the input parameters; when they are acceptable, return the evaluator object that performs the aggregation
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly one argument is expected.");
}
if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Only primitive type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
case BYTE:
case SHORT:
case INT:
case LONG:
case FLOAT:
case DOUBLE:
case STRING:
case TIMESTAMP:
return new GenericUDAFAverageEvaluator();
case BOOLEAN:
default:
throw new UDFArgumentTypeException(0,
"Only numeric or string type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
}
/**
* GenericUDAFAverageEvaluator.
* A static inner class that performs the actual aggregation; it extends the abstract class GenericUDAFEvaluator.
*/
public static class GenericUDAFAverageEvaluator extends GenericUDAFEvaluator {
//1.1. ObjectInspectors (OIs) describing the input and output data types, used to parse the values
// input For PARTIAL1 and COMPLETE
PrimitiveObjectInspector inputOI;
// input For PARTIAL2 and FINAL
// output For PARTIAL1 and PARTIAL2
StructObjectInspector soi;
StructField countField;
StructField sumField;
LongObjectInspector countFieldOI;
DoubleObjectInspector sumFieldOI;
//1.2. Objects that hold the actual output data
// output For PARTIAL1 and PARTIAL2
Object[] partialResult;
// output For FINAL and COMPLETE
DoubleWritable result;
/*
* Initialization: for each processing mode, extract the input type OIs and return the output type OI.
* init() is called once for every Mode.
* 1. The parameters argument:
*   1.1. For PARTIAL1 and COMPLETE it describes the original data (a single value):
*        it becomes the input OI of iterate(), an implementation of PrimitiveObjectInspector
*        (e.g. WritableDoubleObjectInspector for a double column); input values are parsed through this OI.
*   1.2. For PARTIAL2 and FINAL it describes the partial aggregation data (a count/sum pair):
*        it becomes the input OI of merge(), an instance of StandardStructObjectInspector
*        (an implementation of StructObjectInspector); input values are parsed through this OI.
* 2. The returned OI:
*   2.1. For PARTIAL1 and PARTIAL2 it describes the return value of terminatePartial():
*        an instance of StandardStructObjectInspector (an implementation of StructObjectInspector).
*   2.2. For FINAL and COMPLETE it describes the return value of terminate():
*        an instance of WritableDoubleObjectInspector (an implementation of PrimitiveObjectInspector).
*/
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] parameters)
throws HiveException {
assert (parameters.length == 1);
super.init(mode, parameters);
// init input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
inputOI = (PrimitiveObjectInspector) parameters[0];
} else {
//When partial aggregation data is the input, this struct OI describes its type and is used to parse it
soi = (StructObjectInspector) parameters[0];
countField = soi.getStructFieldRef("count");
sumField = soi.getStructFieldRef("sum");
//Each element of the array is parsed with its own primitive-type OI
countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
sumFieldOI = (DoubleObjectInspector) sumField.getFieldObjectInspector();
}
// init output
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
// The output of a partial aggregation is a struct containing
// a "long" count and a "double" sum.
//The partial aggregation result is an Object array (count, sum)
partialResult = new Object[2];
partialResult[0] = new LongWritable(0);
partialResult[1] = new DoubleWritable(0);
/*
* Build the struct OI that describes the type of the partial-result array;
* it is constructed from a list of field names and a list of field OIs.
*/
ArrayList<String> fname = new ArrayList<String>();
fname.add("count");
fname.add("sum");
ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
//Note: these two OIs describe the two elements of partialResult[], so they must match its element types
foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
} else {
//For FINAL and COMPLETE the result is a single double value, described by a primitive OI
result = new DoubleWritable(0);
return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
}
}
/*
* The buffer that stores the intermediate aggregation state
*/
static class AverageAgg implements AggregationBuffer {
long count;
double sum;
};
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
AverageAgg result = new AverageAgg();
reset(result);
return result;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
AverageAgg myagg = (AverageAgg) agg;
myagg.count = 0;
myagg.sum = 0;
}
boolean warned = false;
/*
* Iterate over the original input rows
*/
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
assert (parameters.length == 1);
Object p = parameters[0];
if (p != null) {
AverageAgg myagg = (AverageAgg) agg;
try {
//Parse the value of Object p through the primitive input OI
double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
myagg.count++;
myagg.sum += v;
} catch (NumberFormatException e) {
if (!warned) {
warned = true;
LOG.warn(getClass().getSimpleName() + " "
+ StringUtils.stringifyException(e));
LOG.warn(getClass().getSimpleName()
+ " ignoring similar exceptions.");
}
}
}
}
/*
* Produce the partial aggregation result
*/
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
AverageAgg myagg = (AverageAgg) agg;
((LongWritable) partialResult[0]).set(myagg.count);
((DoubleWritable) partialResult[1]).set(myagg.sum);
return partialResult;
}
/*
* Merge a partial aggregation result into the buffer.
* Note: partial is declared as Object, but here it is actually the Object[] produced by terminatePartial().
*/
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
if (partial != null) {
AverageAgg myagg = (AverageAgg) agg;
//Use the StandardStructObjectInspector (soi) to extract the field values from the partial array
Object partialCount = soi.getStructFieldData(partial, countField);
Object partialSum = soi.getStructFieldData(partial, sumField);
//Parse the extracted Objects through their primitive-type OIs
myagg.count += countFieldOI.get(partialCount);
myagg.sum += sumFieldOI.get(partialSum);
}
}
/*
* Produce the final aggregation result
*/
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
AverageAgg myagg = (AverageAgg) agg;
if (myagg.count == 0) {
return null;
} else {
result.set(myagg.sum / myagg.count);
return result;
}
}
}
}
1. The processing modes of an aggregate function (org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode):
public abstract class GenericUDAFEvaluator implements Closeable {
public static enum Mode {
/**
* PARTIAL1: from original data to partial aggregation data: iterate() and
* terminatePartial() will be called.
* Roughly corresponds to the map phase of a MapReduce job.
*/
PARTIAL1,
/**
* PARTIAL2: from partial aggregation data to partial aggregation data:
* merge() and terminatePartial() will be called.
* Roughly corresponds to the combine phase of a MapReduce job.
*/
PARTIAL2,
/**
* FINAL: from partial aggregation to full aggregation: merge() and
* terminate() will be called.
* Roughly corresponds to the reduce phase of a MapReduce job.
*/
FINAL,
/**
* COMPLETE: from original data directly to full aggregation: iterate() and
* terminate() will be called.
* Roughly corresponds to a map-only MapReduce job (the map output is final; there is no reduce phase).
*/
COMPLETE
};
}
2. Code structure:
1) Extend the abstract class AbstractGenericUDAFResolver and override getEvaluator(TypeInfo[] parameters);
2) The static inner class extends the abstract class GenericUDAFEvaluator, overrides init(), and implements getNewAggregationBuffer(), reset(), iterate(), terminatePartial(), merge(), and terminate().
3. Execution flow:
1) PARTIAL1 (phase 1, map): init() --> iterate() --> terminatePartial()
2) PARTIAL2 (phase 2, combine): init() --> merge() --> terminatePartial()
3) FINAL (final phase, reduce): init() --> merge() --> terminate()
4) COMPLETE (map-only, direct output): init() --> iterate() --> terminate()
Note: init() is executed at the start of every phase (the driver sketch below walks through these calls by hand).
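To make the call sequence concrete, here is a minimal driver sketch (my own illustration, not from the book or the Hive code base) that pushes the GenericUDAFAverageEvaluator defined above through the PARTIAL1 (map) and FINAL (reduce) paths by hand. The class name AverageFlowDemo and the sample values 2, 4, 6 are made up; it assumes the com.hive.udaf classes above are compiled and on the classpath.
package com.hive.udaf;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
public class AverageFlowDemo {
  public static void main(String[] args) throws Exception {
    //The raw input column is a double, described by a writable double OI
    ObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    //PARTIAL1 ("map" side): init() --> iterate() --> terminatePartial()
    GenericUDAFEvaluator mapSide = new GenericUDAFAverage.GenericUDAFAverageEvaluator();
    ObjectInspector partialOI = mapSide.init(Mode.PARTIAL1, new ObjectInspector[] { doubleOI });
    AggregationBuffer mapBuffer = mapSide.getNewAggregationBuffer();
    for (double v : new double[] { 2.0, 4.0, 6.0 }) {
      mapSide.iterate(mapBuffer, new Object[] { new DoubleWritable(v) });
    }
    Object partial = mapSide.terminatePartial(mapBuffer); //struct {count = 3, sum = 12.0}
    //FINAL ("reduce" side): init() --> merge() --> terminate()
    GenericUDAFEvaluator reduceSide = new GenericUDAFAverage.GenericUDAFAverageEvaluator();
    reduceSide.init(Mode.FINAL, new ObjectInspector[] { partialOI });
    AggregationBuffer reduceBuffer = reduceSide.getNewAggregationBuffer();
    reduceSide.merge(reduceBuffer, partial);
    System.out.println(reduceSide.terminate(reduceBuffer)); //a DoubleWritable holding 4.0
  }
}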
Deployment and testing:
1. Export the com.hive.udaf package as a JAR file (right-click > Export in the IDE) named "myUDAF.jar";
2. Upload the JAR to the Linux server from the Windows cmd or PowerShell (recommended) prompt.
Run the following from the directory containing the JAR:
> scp myUDAF.jar root@remoteIP:~/myJars/hive/
(where remoteIP is the IP address of the remote server)
3. Start Hadoop and Hive, then add the JAR at the hive> prompt (only a full path is supported):
> add jar /root/myJars/hive/myUDAF.jar
<Hive reports that the JAR has been added to the class path>
4. Register a temporary or a permanent function (see the note after step 5):
> create temporary function myavg as 'com.hive.udaf.GenericUDAFAverage';   (temporary: valid for the current session only)
> create function myavg as 'com.hive.udaf.GenericUDAFAverage';   (permanent)
5. Run a test:
> select myavg(id) from data;
4
<this is the average produced by the aggregation>
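A note on the permanent function in step 4: on Hive 0.13 and later, a permanent function is usually registered together with its JAR stored on HDFS so that every session can resolve the class (the HDFS path below is hypothetical):
> create function myavg as 'com.hive.udaf.GenericUDAFAverage' using jar 'hdfs:///myJars/hive/myUDAF.jar';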
A second UDAF, ConcatUDAF. Purpose: row-to-column transformation, concatenating a column's values across rows into a single comma-separated string.
The code:
package com.hive.udaf;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;
/*
* Row-to-column concatenation
*/
@Description(name = "mycolconcat", value = "_FUNC_(x) - Returns the concat of a set of cols")
public class ConcatUDAF extends AbstractGenericUDAFResolver{
static final Log LOG = LogFactory.getLog(ConcatUDAF.class.getName());
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly one argument is expected.");
}
if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Only primitive type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
case BYTE:
case SHORT:
case INT:
case LONG:
case FLOAT:
case DOUBLE:
case STRING:
case TIMESTAMP:
return new ConcatUDAFEvaluator();
case BOOLEAN:
default:
throw new UDFArgumentTypeException(0,
"Only numeric or string type arguments are accepted but "
+ parameters[0].getTypeName() + " is passed.");
}
}
public static class ConcatUDAFEvaluator extends GenericUDAFEvaluator {
//In every Mode both the input and the output are strings, so the same OI instances serve all modes
PrimitiveObjectInspector inputOI;
Text partialResult;
Text result;
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] parameters)
throws HiveException {
assert (parameters.length == 1);
super.init(mode, parameters);
// init input
inputOI = (PrimitiveObjectInspector) parameters[0];
// init output
result = new Text("");
return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
}
static class ConcatAgg implements AggregationBuffer {
StringBuilder line = new StringBuilder("");
};
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
ConcatAgg result = new ConcatAgg();
reset(result);
return result;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
ConcatAgg myagg = (ConcatAgg) agg;
myagg.line.delete(0, myagg.line.length());
}
boolean warned = false;
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
Object p = parameters[0];
if (p != null) {
ConcatAgg myagg = (ConcatAgg) agg;
try {
String v = PrimitiveObjectInspectorUtils.getString(p, inputOI);
if (myagg.line.length() == 0)
myagg.line.append(v);
else
myagg.line.append("," + v);
} catch (RuntimeException e) {
if (!warned) {
warned = true;
LOG.warn(getClass().getSimpleName() + " "
+ StringUtils.stringifyException(e));
LOG.warn(getClass().getSimpleName()
+ " ignoring similar exceptions.");
}
}
}
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
ConcatAgg myagg = (ConcatAgg) agg;
result.set(myagg.line.toString());
return result;
}
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial != null) {
try {
ConcatAgg myagg = (ConcatAgg) agg;
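//partial is the Text returned by terminatePartial(); in PARTIAL2/FINAL mode
//inputOI was set in init() to the writable string OI, so getString() can parse it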
String v = PrimitiveObjectInspectorUtils.getString(partial, inputOI);
if (myagg.line.length() == 0)
myagg.line.append(v);
else
myagg.line.append("," + v);
} catch (RuntimeException e) {
if (!warned) {
warned = true;
LOG.warn(getClass().getSimpleName() + " "
+ StringUtils.stringifyException(e));
LOG.warn(getClass().getSimpleName()
+ " ignoring similar exceptions.");
}
}
}
}
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
ConcatAgg myagg = (ConcatAgg) agg;
result.set(myagg.line.toString());
return result;
}
}
}
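Registration and testing follow the same steps as myavg above; the column name and table contents here are hypothetical:
> add jar /root/myJars/hive/myUDAF.jar
> create temporary function mycolconcat as 'com.hive.udaf.ConcatUDAF';
> select mycolconcat(name) from data;
The query returns all values of the name column joined by commas in a single row.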