A Hive custom UDAF: take the maximum within each group, then sum those maxima over all groups

The function takes a variable number of arguments: the first is the value to aggregate, and the remaining arguments are the grouping fields. (A small local driver that exercises the evaluator is sketched after the class.)

import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

@Description(name = "sumByGroupMax", value = "_FUNC_(expr, group_col1, ...) - Returns a double that is the sum of the max value of each group. \n"
 + "In other words, it selects the max value within each group, then sums those values over all groups.\n"
 +"Usage example:\n"
 + "add jar bigdata_mxhz.jar;\n "
 + "create temporary function sumByGroupMax as 'com.letv.bigdata.hive.udaf.SumByGroupMax';\n"

+" select province,sumByGroupMax(case when act in('init','play','time') then 1 else 0 end,letv_cookie,uuid) \n"
+" from data_raw.tbl_play_hour \n"
+" where dt='20141203' \n"
+" and product='1' \n"
+" and hour='04' \n"
+" group by province; \n"


+" The previous SQL is equivalent to the following SQL: \n"

+" select province,sum(num) vv from ( \n"
+" select province,max(case when act in('init','play','time') then 1 else 0 end) num \n"
 +" from data_raw.tbl_play_hour \n"
+" where dt='20141203' \n"
+" and product='1' \n"
+" and hour='04' \n"
+" and act in('init','play','time') \n"
+" group by province,letv_cookie,uuid \n"
+" )tmp group by province ; \n"

+ "CAUTION: this can easily cause an OutOfMemory error on large data sets, because every distinct composite grouping key is buffered in memory.")
/**
 * @author houzhizhen
 *
 * create temporary function sumByGroupMax as 'com.letv.bigdata.hive.udaf.SumByGroupMax';
 *
 * GenericUDAFEvaluator.Mode reference:
 *   PARTIAL1: from original data to partial aggregation data: iterate() and terminatePartial() will be called.
 *   PARTIAL2: from partial aggregation data to partial aggregation data: merge() and terminatePartial() will be called.
 *   FINAL: from partial aggregation to full aggregation: merge() and terminate() will be called.
 *   COMPLETE: from original data directly to full aggregation: iterate() and terminate() will be called.
 */
public class SumByGroupMax extends AbstractGenericUDAFResolver {
 // Separator used to build the composite grouping key from the grouping arguments.
 private static final char SEPARATOR = '\u0001';
 static final Log LOG = LogFactory.getLog(SumByGroupMax.class.getName());

 public SumByGroupMax() {
 }

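 // Resolver entry point: validate the argument list (a value argument followed by at
 // least one grouping field) and return the evaluator.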
 @Override
 public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
 throws SemanticException {
 if (parameters.length < 2) {
 throw new UDFArgumentTypeException(parameters.length - 1,
 "At least two arguments are expected: the value followed by one or more grouping fields.");
 }

 if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
 throw new UDFArgumentTypeException(0,
 "Only primitive type arguments are accepted but "
 + parameters[0].getTypeName()
 + " was passed as parameter 1.");
 }
 return new GenericUDAFMkListEvaluator();
 }

 public static class GenericUDAFMkListEvaluator extends GenericUDAFEvaluator {
 // private PrimitiveObjectInspector inputOI;

 private StandardMapObjectInspector mapOI;

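 // init() is called once per mode: PARTIAL1/PARTIAL2 return a map<string,double> of
 // (composite grouping key -> current max); FINAL/COMPLETE return the final double sum.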
 @Override
 public ObjectInspector init(Mode m, ObjectInspector[] parameters)
 throws HiveException {
 super.init(m, parameters);
 if (m == Mode.PARTIAL1) {
 // inputOI = (PrimitiveObjectInspector) parameters[0];
 return ObjectInspectorFactory
 .getStandardMapObjectInspector(
 (PrimitiveObjectInspector) ObjectInspectorUtils
 .getStandardObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
 PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
 } else if (m == Mode.PARTIAL2) {
 mapOI = (StandardMapObjectInspector) parameters[0];
 return ObjectInspectorFactory
 .getStandardMapObjectInspector(
 PrimitiveObjectInspectorFactory.javaStringObjectInspector,
 PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
 } else if (m == Mode.FINAL) {
 mapOI = (StandardMapObjectInspector) parameters[0];
 return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
 } else if (m == Mode.COMPLETE) {
 return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
 } else {
 throw new HiveException("Unexpected aggregation mode: " + m);
 }
 }

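 // Aggregation buffer: one entry per composite grouping key, holding the maximum value
 // seen so far for that key.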
 static class MkArrayAggregationBuffer implements AggregationBuffer {
 Map<String, Double> container;
 }

 @Override
 public void reset(AggregationBuffer agg) throws HiveException {
 ((MkArrayAggregationBuffer) agg).container = new HashMap<String, Double>();
 }

 @Override
 public AggregationBuffer getNewAggregationBuffer() throws HiveException {
 MkArrayAggregationBuffer ret = new MkArrayAggregationBuffer();
 reset(ret);
 return ret;
 }

 // Map side (iterate): parameters[0] is the value; the remaining parameters are joined
 // with SEPARATOR into a composite grouping key, and the running max per key is kept.
 @Override
 public void iterate(AggregationBuffer agg, Object[] parameters)
 throws HiveException {
 Double value = 0d;
 if (parameters[0] != null) {
 value = Double.valueOf(parameters[0].toString());
 }
 StringBuilder keyBuffer = new StringBuilder();
 if (parameters[1] != null) {
 keyBuffer.append(parameters[1].toString());
 }

 for (int i = 2; i < parameters.length; i++) {
 keyBuffer.append(SEPARATOR);
 if (parameters[i] != null) {
 keyBuffer.append(parameters[i].toString());
 }

 }

 MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
 putIntoMap(keyBuffer.toString(), myagg, value);
 }

 // Map side (terminatePartial): emit a copy of the (key -> max) map as the partial result.
 @Override
 public Object terminatePartial(AggregationBuffer agg)
 throws HiveException {
 MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
 Map<String, Double> ret = new HashMap<String, Double>(
 myagg.container);

 return ret;
 }

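 // Reduce/combine side (merge): fold a partial (key -> max) map into this buffer,
 // again keeping only the maximum per key.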
 @Override
 public void merge(AggregationBuffer agg, Object partial)
 throws HiveException {
 MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;

 Map<?, ?> partialResult = mapOI.getMap(partial);

 for (Map.Entry<?, ?> entry : partialResult.entrySet()) {
 putIntoMap(entry.getKey().toString(), myagg,
 Double.valueOf(entry.getValue().toString()));
 }
 }

 @Override
 public Double terminate(AggregationBuffer agg) throws HiveException {
 // Final step: sum the per-key maxima; BigDecimal limits floating-point accumulation error.
 MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
 BigDecimal sum = BigDecimal.ZERO;
 for (Map.Entry<String, Double> entry : myagg.container.entrySet()) {
 sum = sum.add(BigDecimal.valueOf(entry.getValue()));
 }
 return sum.doubleValue();
 }

 // Keep the maximum value seen for the given composite grouping key.
 private void putIntoMap(String key, MkArrayAggregationBuffer myagg,
 Double num) {
 Double current = myagg.container.get(key);
 if (current == null || current < num) {
 myagg.container.put(key, num);
 }
 }
 }
}
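
The evaluator lifecycle described in the Mode comment above (PARTIAL1 on the map side, FINAL on the reduce side) can also be exercised by hand for local testing. The following is only a minimal sketch, not part of the original code: the class name SumByGroupMaxDemo, the sample cookie values, and the single string grouping column are illustrative, and it assumes SumByGroupMax is on the classpath (for example, compiled into the same package).

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// Hypothetical local driver; assumes SumByGroupMax is in the same package.
public class SumByGroupMaxDemo {
    public static void main(String[] args) throws Exception {
        // Two arguments: the value to aggregate, then a single grouping field.
        ObjectInspector[] inputOIs = new ObjectInspector[] {
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector };

        // Map side: PARTIAL1 -> iterate() + terminatePartial()
        SumByGroupMax.GenericUDAFMkListEvaluator mapSide = new SumByGroupMax.GenericUDAFMkListEvaluator();
        ObjectInspector partialOI = mapSide.init(Mode.PARTIAL1, inputOIs);
        AggregationBuffer mapBuf = mapSide.getNewAggregationBuffer();
        mapSide.iterate(mapBuf, new Object[] { 1, "cookieA" });
        mapSide.iterate(mapBuf, new Object[] { 0, "cookieA" }); // max for cookieA stays 1
        mapSide.iterate(mapBuf, new Object[] { 0, "cookieB" }); // cookieB contributes 0
        Object partial = mapSide.terminatePartial(mapBuf);

        // Reduce side: FINAL -> merge() + terminate()
        SumByGroupMax.GenericUDAFMkListEvaluator reduceSide = new SumByGroupMax.GenericUDAFMkListEvaluator();
        reduceSide.init(Mode.FINAL, new ObjectInspector[] { partialOI });
        AggregationBuffer reduceBuf = reduceSide.getNewAggregationBuffer();
        reduceSide.merge(reduceBuf, partial);
        System.out.println(reduceSide.terminate(reduceBuf)); // prints 1.0 (1 for cookieA + 0 for cookieB)
    }
}

In an actual query, Hive drives this same sequence itself; the usage example embedded in the @Description above shows the corresponding HiveQL.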
