Hive Custom Aggregate Function (UDAF): Computing the Median

Background

The median is a statistical term for the value that sits in the middle of a dataset sorted in order; it splits the set of values into equal upper and lower halves. For a finite dataset, the median is found by sorting all observations and taking the middle one. When the number of observations is even, the median is conventionally the average of the two middle values.

Prepare the numbers 1 through 7 in random order.

[Figure 1: the numbers 1–7 in random order]

With an odd count, sorting gives [1, 2, 3, 4, 5, 6, 7]; the single middle value is 4, so the median is 4.

[Figure 2: the sorted odd-length sequence, with the middle value 4]

Now add one more number, 8, again in random order.

[Figure 3: the numbers 1–8 in random order]

With an even count, sorting gives [1, 2, 3, 4, 5, 6, 7, 8]; the two middle values [4, 5] summed and divided by 2 give a median of 4.5.
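The odd/even rule above can be sketched in plain Java; the class and method names here are illustrative and not part of the UDAF itself:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class MedianDemo {
	/* Sorts a copy of the input, then takes the middle element for an odd
	 * count, or the average of the two middle elements for an even count. */
	static double median(List<Double> values) {
		List<Double> sorted = new ArrayList<>(values);
		Collections.sort(sorted);
		int size = sorted.size();
		if (size % 2 != 0) {
			return sorted.get(size / 2);
		}
		return (sorted.get(size / 2) + sorted.get(size / 2 - 1)) / 2;
	}

	public static void main(String[] args) {
		System.out.println(median(Arrays.asList(3.0, 1.0, 7.0, 5.0, 2.0, 6.0, 4.0)));       // prints 4.0
		System.out.println(median(Arrays.asList(8.0, 3.0, 1.0, 7.0, 5.0, 2.0, 6.0, 4.0))); // prints 4.5
	}
}
```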

[Figure 4: the sorted even-length sequence, with the middle values 4 and 5]

 

Implementation

package udf;

import com.google.common.collect.Lists;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * @ClassName: UDAFNumericMedian
 * @Description: Hive UDAF that computes the median of a DOUBLE column
 * @Author: xuezhouyi
 * @Version: V1.0
 **/
public class UDAFNumericMedian extends AbstractGenericUDAFResolver {

	/* Chooses which Evaluator to use based on the argument types passed in the SQL call */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
		if (parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
		}
		if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted.");
		}
		if (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.DOUBLE) {
			throw new UDFArgumentTypeException(0, "Only Double type arguments are accepted.");
		}
		return new DoubleEvaluator();
	}

	private static class DoubleEvaluator extends GenericUDAFEvaluator {
		private PrimitiveObjectInspector inputOI;
		private StandardListObjectInspector mergeOI;

		/* Determine the input/output ObjectInspectors for each processing phase */
		@Override
		public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
			super.init(m, parameters);
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI = (PrimitiveObjectInspector) parameters[0];
				return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(inputOI));
			} else {
				if (parameters[0] instanceof StandardListObjectInspector) {
					mergeOI = (StandardListObjectInspector) parameters[0];
					inputOI = (PrimitiveObjectInspector) mergeOI.getListElementObjectInspector();
					return ObjectInspectorUtils.getStandardObjectInspector(mergeOI);
				} else {
					inputOI = (PrimitiveObjectInspector) parameters[0];
					return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(inputOI));
				}
			}
		}

		/* Holds the intermediate aggregation state */
		private static class MyAggBuf implements AggregationBuffer {
			List<Object> container = Lists.newArrayList();
		}

		/* Reset the aggregation buffer */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			((MyAggBuf) agg).container.clear();
		}

		/* Create a new aggregation buffer */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			return new MyAggBuf();
		}

		/* Map phase: iterate over the input column values */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
			if (parameters == null || parameters.length != 1) {
				return;
			}
			if (parameters[0] != null) {
				MyAggBuf myAgg = (MyAggBuf) agg;
				myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(parameters[0], this.inputOI));
			}
		}

		/* Return the partial aggregation result at the end of the map/combine phase */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			MyAggBuf myAgg = (MyAggBuf) agg;
			return Lists.newArrayList(myAgg.container);
		}

		/* Combiner merges map output; the reducer merges mapper/combiner output */
		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			if (partial == null) {
				return;
			}
			MyAggBuf myAgg = (MyAggBuf) agg;
			List<?> partialResult = mergeOI.getList(partial);
			for (Object ob : partialResult) {
				myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(ob, this.inputOI));
			}
		}

		/* Reduce phase: emit the final result */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			/* Collect and sort the values */
			List container = ((MyAggBuf) agg).container;
			List<Double> doubles = new LinkedList<>();
			for (Object o : container) {
				doubles.add(((DoubleWritable) o).get());
			}
			Collections.sort(doubles);

			/* Compute the median */
			DoubleWritable median;
			int size = doubles.size();
			if (size % 2 != 0) {
				int i = size / 2;
				median = new DoubleWritable(doubles.get(i));
			} else {
				int i = size / 2;
				int j = size / 2 - 1;
				median = new DoubleWritable((doubles.get(i) + doubles.get(j)) / 2);
			}

			/* Wrap the median in a list to match the list ObjectInspector declared in init() */
			List<DoubleWritable> objects = new ArrayList<>();
			objects.add(median);
			return objects;
		}
	}
} 
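Once the class is packaged into a jar, the function can be registered and invoked in a Hive session along these lines; the jar path, function name `median_udaf`, table `t`, and columns `grp`/`val` are assumptions for illustration:

```sql
ADD JAR /path/to/udaf-median.jar;  -- jar path is an assumption
CREATE TEMPORARY FUNCTION median_udaf AS 'udf.UDAFNumericMedian';

-- the argument must be DOUBLE; the result is a single-element array
SELECT grp, median_udaf(CAST(val AS DOUBLE)) FROM t GROUP BY grp;
```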
  

 

Group-by test

[Figure 5: the group-by test query]

 

Results

[Figure 6: the query output]
