离线数仓项目:自定义UDAF函数

参考官网:GenericUDAFCaseStudy - Apache Hive - Apache Software Foundationhttps://cwiki.apache.org/confluence/display/Hive/GenericUDAFCaseStudy

package comxxx.hive;

import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import java.text.DecimalFormat;
import java.util.*;

/**
 * 1.Writing the resolver -- 负责解析函数的元数据,函数传入的参数的类型检查。函数返回值的说明等
 * 2.Writing the evaluator --负责计算
 *     2.1getNewAggregationBuffer
 *     2.2iterate
 *     2.3terminatePartial
 *     2.4merge
 *     2.5terminate
 * 
 * 3.UDAF的运行原理:
 *     ①在group by 分组后运行
 *     ②运行的范围是分组的一组内
 *     ③依次对组中的每一行进行计算,最终得到一行结果
 * 4.函数如何用? --分组后直接调用函数,传入spu_name
 * select
 *     coupon_id,myudaf(spu_name)
 * from test6
 * group by coupon_id
 */
public class MyUDAF extends AbstractGenericUDAFResolver {

    // 创建一个自己定义的Evaluator
    @Override
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        // Type-checking goes here! --进行函数输入的参数类型检查
        //获取函数输入的参数
        TypeInfo[] parameters = info.getParameters();
        //对参数个数进行验证 -- 我们要求只传入一列
        if (parameters.length != 1) {
            throw new UDFArgumentException("参数个数只能是一个!");
        }
        //校验类型 --判断是不是基本数据类型
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("参数类型必须是基础数据类型!");
        }

        // 校验类型 --我们要求必须是String
        if (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentException("参数类型必须是String!");
        }
        return new MyEvaluator();
    }

    // 定义Evaluator
    public static class MyEvaluator extends GenericUDAFEvaluator {
        /**
         * 需要手动调出init方法
         * 目的是给函数标识当前处于计算的哪个阶段(mode)
         * ObjectInspector :是类型检查器
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            //调用此方法获取当前的mode
            super.init(m, parameters);
            // This function should be overriden in every sub class
            // And the sub class should call super.init(m, parameters) to get mode set.

            //根据当前所处的阶段,获取当前阶段要使用的类型检查器
            /**
             * @param parameters In PARTIAL1 and COMPLETE mode, the parameters are original data;
             *                   In PARTIAL2 and FINAL mode, the parameters are just partial aggregations
             */
            if (m == Mode.PARTIAL2 || m == Mode.FINAL) {
                // 给mapOI赋值
                mapOI = (StandardMapObjectInspector) parameters[0];
            }
            /**
            @return In PARTIAL1 and PARTIAL2 mode, the ObjectInspector for the return value of terminatePartial() call;
         *          In FINAL and COMPLETE mode, the ObjectInspector for the return value of terminate() call.
             */
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                //返回terminatePartial()返回值(是map)对应的类型检查器--map的检查器(k,v的检查器)
                return ObjectInspectorFactory.getStandardMapObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        PrimitiveObjectInspectorFactory.javaIntObjectInspector);
            } else {
                //返回terminate()返回值(是string)对应的类型检查器--string
                return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
            }
        }

        //创建一个新的缓冲区  --缓冲区是自己定义的!--我们需要存的是品牌名称和下单次数,需要map结构
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MyBuf();
        }

        // 重置缓冲区,清空缓冲区
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            // 先将缓冲区对象强转成自己定义的缓冲区对象类型,再调出自己定义的缓冲区对象
            ((MyBuf) agg).buff.clear();
        }

        /**
         * 迭代输入的每一行,将结果存入缓冲区
         *
         * @param agg        缓冲区对象
         * @param parameters 输入的一行 --一列spu_name
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            //取出输入的一行中的spu_name --参数只有一列
            String spu_name = parameters[0].toString();
            //获取map
            HashMap buff = ((MyBuf) agg).buff;
            //将spu_name累加到map中 --先取出buff里原来的value,如果没有就默认为0,取出来之后再+1作为新的value,放入map中
            buff.put(spu_name, buff.getOrDefault(spu_name, 0) + 1);
            // 统计品牌次数
        }

        /**
         * 负责缓冲区序列化
         * Here persistable means the return value can only be built up in terms of Java primitives,
         * arrays, primitive wrappers (e.g. Double), Hadoop Writables, Lists, and Maps
         * 返回值只能是基础数据类型、arrays, primitive wrappers (e.g. Double), Hadoop Writables, Lists, and Maps等
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            return ((MyBuf) agg).buff;
        }

        /**
         * 将两个缓冲区进行合并,得到一个缓冲区
         *
         * @param agg 当前task的缓冲区
         * @param partial 从网络中接收的其他task序列化后的缓冲区对象,使用时需要反序列化
         *      反序列化之前需要先用缓冲区对应的ObjectInspector进行类型检查,检查通过,才能反序列化
         */
        //声明一个缓冲区对应的ObjectInspector --map类型的对象检查器
        //声明后是在Init()中为mapOI赋值,如果在此处赋值,后续仍会将此值清空
        private StandardMapObjectInspector mapOI;

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            //取出当前缓冲区的map
            HashMap map1 = ((MyBuf) agg).buff;
            //对从网络中接收的缓冲区对象partial进行类型检查
            Map map2 = mapOI.getMap(partial);
            //对key,value的类型继续检查
            PrimitiveObjectInspector mapKeyObjectInspector = (PrimitiveObjectInspector)mapOI.getMapKeyObjectInspector();
            PrimitiveObjectInspector mapValueObjectInspector = (PrimitiveObjectInspector)mapOI.getMapValueObjectInspector();
            // 使用key,value的类型检测器,检测key,value是不是此类型,如果是,反序列化获取Key,value
            for (Map.Entry entry : map2.entrySet()) {
                String key = PrimitiveObjectInspectorUtils.getString(entry.getKey(), mapKeyObjectInspector);
                int value = PrimitiveObjectInspectorUtils.getInt(entry.getValue(), mapValueObjectInspector);
                //将当前缓冲区的map中的元素与从网络中接收的map中的元素进行合并
                map1.put(key,map1.getOrDefault(key,0) + value);
            }
        }

        /**
         * 基于最后合并的最终的缓冲区,计算得到函数输出的结果
         *
         * @param agg 最终合并的缓冲区
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            //先取出合并后的缓冲区map
            HashMap buff = ((MyBuf) agg).buff;
            //1.计算总次数
            double totalTimes = 0d;
            for (Integer value : buff.values()) { //for循环遍历得到每个spu_name对应的次数,进行累加
                totalTimes += value;
            }
            // 2.对每一个spu_name的times进行排序,取前三名 --map是非线性的,不能排序,要转成list等线性的结构进行排序
            // 先转成set,再转成list
            ArrayList> entryArrayList = new ArrayList<>(buff.entrySet());
            // 需要传入一个比较器
            entryArrayList.sort(new Comparator>() {
                // 降序排序(默认是升序,所以前面加个-号)
                @Override
                public int compare(Map.Entry o1, Map.Entry o2) {
                    return -o1.getValue().compareTo(o2.getValue());
                }
            });
            // 取前三 --subList(0,3)截取前三个 --但集合中可能没有三个,所以要取最小值Math.min(3,entryArrayList.size())
            List> top3Spu_name = entryArrayList.subList(0, Math.min(3, entryArrayList.size()));

            // 3.计算前三的比例之和,求其他的比例
            double top3Percent = 0d;
            // 声明一个存放每一个spu_name最终字符串的集合
            ArrayList strs = new ArrayList<>();
            //声明一个百分数格式化器
            DecimalFormat decimalFormat = new DecimalFormat("##.##%");
            for (Map.Entry entry : top3Spu_name) {
                // 计算前三的每个spu_name的占比
                double spu_percent = entry.getValue() / totalTimes;
                strs.add(entry.getKey() + ":" + decimalFormat.format(spu_percent));
                top3Percent += spu_percent;
            }
            // 计算其他的比例  --只有当前coupon_id下的spu_name>3才有其他 (spu_name有三个以上才有其他)
            if (entryArrayList.size() > 3) {
                strs.add("其他:" + decimalFormat.format(1 - top3Percent));
            }
            //将集合中的字符串拼接为结果--使用一个工具类
            String result = StringUtils.join(strs, ',');

            return result;
        }

        // UDAF logic goes here! --定义缓冲区对象
        static class MyBuf implements AggregationBuffer {
            //自定定义存储想存储的数据的结构 --我们需要存的是品牌名称和下单次数,需要map结构
            private HashMap buff = new HashMap<>();
        }
    }
}

你可能感兴趣的:(hive,hadoop,big,data,数据仓库,java)