HIVE - UDAF Development (Count Occurrences of a Specified Character in a String, Then Average the Counts)

1. Demonstration
hive> desc test_avg_str_in_str;
user_id             	int
name                	string
value               	int

hive> select * from test_avg_str_in_str;
1	awuz	1
1	azhaoz	1
2	zhangsan	2
2	lisi	2
2	wangwu	3

-- UDAF: avgStr (count how many times "z" appears in name, then average the counts)
-- The tricky part: while computing the average, the intermediate result must carry
-- both the running sum and the row count, so a struct is needed; on the merge side
-- it arrives deserialized as a LazyBinaryStruct.
hive> select user_id, avgStr(name, "z") from test_avg_str_in_str group by user_id;
1	1.5
2	0.333333
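
To verify the output by hand: for user_id = 1, "z" occurs once in awuz and twice in azhaoz, so the average is (1 + 2) / 2 = 1.5; for user_id = 2, it occurs once in zhangsan and zero times in lisi and wangwu, giving (1 + 0 + 0) / 3 ≈ 0.333333.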

PS: The functionality of this UDAF is something I made up on the spot; it has no real business application…

2. Key Functions


  • PARTIAL1: map phase; calls iterate() and terminatePartial()
  • PARTIAL2: map-side combiner phase; calls merge() and terminatePartial()
  • FINAL: reduce phase; calls merge() and terminate()
  • COMPLETE: map-only phase (no reduce); calls iterate() and terminate()
// Determine the input/output ObjectInspectors for each stage
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;
// Create the class that holds the aggregation state
public abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
// Reset the aggregation state
public void reset(AggregationBuffer agg) throws HiveException;
// Map phase: iterate over the column values coming in from the SQL
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
// Return the partial aggregation result when the map or combiner phase ends
public Object terminatePartial(AggregationBuffer agg) throws HiveException;
// Combiner merges map output; reducer merges mapper or combiner output
public void merge(AggregationBuffer agg, Object partial) throws HiveException;
// Reduce phase: emit the final result
public Object terminate(AggregationBuffer agg) throws HiveException;
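
Note the contract between these methods: the object returned by terminatePartial() must match, field for field, the struct ObjectInspector that init() returns for the PARTIAL1 and PARTIAL2 modes, since that is how Hive serializes the partial result between tasks; on the receiving side, merge() sees the deserialized struct.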
3. Code
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import java.util.ArrayList;
import java.util.List;


public class AvgStrInStrUDAF extends AbstractGenericUDAFResolver {
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 2) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly two arguments are expected.");
        }
        return new AvgCharInStringEvaluator();
    }

    public static class AvgCharInStringEvaluator extends GenericUDAFEvaluator {

        // Reusable (sum, count) pair returned by terminatePartial()
        private Object[] outKey = {new LongWritable(), new LongWritable()};
        // Reusable final result returned by terminate()
        private DoubleWritable result;

        // Aggregation buffer: running sum of per-row match counts plus the row count
        private static class AvgAgg implements AggregationBuffer {
            private Long sum = 0L;
            private Long count = 0L;
            public Long getSum() {
                return sum;
            }
            public void setSum(Long sum) {
                this.sum = sum;
            }
            public Long getCount() {
                return count;
            }
            public void setCount(Long count) {
                this.count = count;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new AvgAgg();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            ag.setSum(0L);
            ag.setCount(0L);
        }

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            super.init(m, parameters);
            if (m.equals(Mode.PARTIAL1) || m.equals(Mode.PARTIAL2)) {
                // PARTIAL1/PARTIAL2 output the intermediate (sum, count) struct
                List<String> structFieldNames = new ArrayList<>();
                List<ObjectInspector> structFieldTypes = new ArrayList<>();
                structFieldNames.add("sum");
                structFieldNames.add("count");

                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                structFieldTypes.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

                return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, structFieldTypes);
            }
            // FINAL/COMPLETE output the final double average
            result = new DoubleWritable(0.0);
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null) {
                return;
            }
            if (parameters[0] != null && parameters[1] != null) {
                String s1 = parameters[0].toString();
                String s2 = parameters[1].toString();
                if (s2.isEmpty()) {
                    return; // guard against division by zero on an empty search string
                }
                // Count occurrences of s2 in s1: the length removed by deleting every
                // occurrence, divided by the length of a single occurrence
                long count = (s1.length() - s1.replace(s2, "").length()) / s2.length();
                AvgAgg ag = (AvgAgg) agg;
                ag.setSum(ag.getSum() + count);
                ag.setCount(ag.getCount() + 1);
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            ((LongWritable) outKey[0]).set(ag.getSum());
            ((LongWritable) outKey[1]).set(ag.getCount());
            return outKey;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial != null) {
                AvgAgg ag = (AvgAgg) agg;
                if (partial instanceof LazyBinaryStruct) {
                    // Partial results from upstream tasks arrive deserialized as a
                    // LazyBinaryStruct. Accumulate into the buffer rather than
                    // overwrite, since merge() runs once per incoming partial.
                    LazyBinaryStruct ls = (LazyBinaryStruct) partial;
                    LongWritable sum = (LongWritable) ls.getField(0);
                    LongWritable count = (LongWritable) ls.getField(1);
                    ag.setSum(ag.getSum() + sum.get());
                    ag.setCount(ag.getCount() + count.get());
                }
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AvgAgg ag = (AvgAgg) agg;
            if (ag.getCount() == 0) {
                return null; // no rows were aggregated
            }
            result.set((double) ag.getSum() / ag.getCount());
            return result;
        }
    }
}
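
For completeness, here is one way to register and invoke the UDAF from the Hive CLI. This is a sketch: the jar path and file name are placeholders for wherever you package the class (which has no package declaration above, hence the bare class name).

-- Package AvgStrInStrUDAF into a jar, then:
ADD JAR /path/to/avg-str-udaf.jar;   -- placeholder path
CREATE TEMPORARY FUNCTION avgStr AS 'AvgStrInStrUDAF';

SELECT user_id, avgStr(name, "z")
FROM test_avg_str_in_str
GROUP BY user_id;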
