UDTF in Spark

1. Introduction

A previous article covered how to write Spark UDFs; you can find it here:

https://blog.csdn.net/Aaron_ch/article/details/113346185

So what exactly is a UDTF, and how do you use one in Spark?

1.1 What is a UDTF?

Anyone who has worked in big data and knows Hive will be familiar with its commonly used built-in UDTFs, such as:

explode
json_tuple
get_splits

and so on.

In short, a UDTF takes one row of input and turns it into multiple rows and columns. A simple example:

Input: a string such as {"test01":"hhh","test02":{"test03":"yyyy","test04":"uuuu"}}

Output:

col1    col2
hhh     yyyy
hhh     uuuu
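
As a quick illustration of this behavior, here is a minimal sketch that calls the built-in explode and json_tuple functions through Spark SQL. The class name BuiltInUdtfDemo and the sample values are made up for this illustration; both functions have native Spark implementations, so enableHiveSupport() is not required here.

import org.apache.spark.sql.SparkSession;

public class BuiltInUdtfDemo {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("BuiltInUdtfDemo")
                .master("local[*]")
                .getOrCreate();

        // explode: one input row holding an array becomes three output rows
        spark.sql("SELECT explode(array('a', 'b', 'c')) AS col").show();

        // json_tuple: one JSON string is split into multiple columns (c0, c1)
        spark.sql("SELECT json_tuple('{\"test01\":\"hhh\",\"test03\":\"yyyy\"}', 'test01', 'test03')").show();

        spark.stop();
    }
}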


1.2 How to use it

If you look through the Spark source code, you won't find any UDTF-specific interface. The official documentation explains why:

Spark SQL supports integration of Hive UDFs, UDAFs and UDTFs. Similar to Spark UDFs and UDAFs, Hive UDFs work on a single row as input and generate a single row as output, while Hive UDAFs operate on multiple rows and return a single aggregated row as a result. In addition, Hive also supports UDTFs (User Defined Tabular Functions) that act on one row as input and return multiple rows as output. To use Hive UDFs/UDAFs/UTFs, the user should register them in Spark, and then use them in Spark SQL queries.

From this it is clear that Spark's UDTF support comes entirely from Hive. Official documentation: (https://spark.apache.org/docs/3.1.1/sql-ref-functions-udf-hive.html#conten)

Looking at the official example, writing your own UDTF simply means extending org.apache.hadoop.hive.ql.udf.generic.GenericUDTF.

1.2.1 Code example

Let's take parsing a JSON string as the example.

The main code:

// Imports assume fastjson for JSON parsing and commons-lang3 for StringUtils;
// JsonUtils is the author's own helper class (not shown here).
import com.alibaba.fastjson.JSONArray;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Map;

public class AnalysisJsonToArrayUDTF extends GenericUDTF {

    @Override
    public StructObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {

        if (args.length < 2) {
            throw new UDFArgumentLengthException("At least two parameters are needed, please check!");
        }

        // All arguments must be strings: the JSON text and the key to extract.
        for (int i = 0; i < args.length; ++i) {
            if (args[i].getCategory() != ObjectInspector.Category.PRIMITIVE || !args[i].getTypeName().equals("string")) {
                throw new UDFArgumentException("get_json_array()'s arguments have to be string type");
            }
        }

        // Column names of the returned rows
        ArrayList<String> fieldNames = new ArrayList<>();
        fieldNames.add("col");

        // Column types of the returned rows
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] objects) throws HiveException {

        if (objects.length < 2) {
            throw new UDFArgumentLengthException("At least two parameters are needed, please check!");
        }

        String jsonStr = objects[0].toString();
        String jsonKey = objects[1].toString();

        // Flatten the JSON via the JsonUtils helper (not shown); the inner maps
        // key dotted JSON paths to their values.
        Map<String, Map<String, String>> keyValueMap2 = new IdentityHashMap<>();
        JsonUtils.dbJSONFormatIntoMap(JsonUtils.str2FastJSON(jsonStr), keyValueMap2);

        String psKey;
        String psValue;
        String key;

        for (Map.Entry<String, Map<String, String>> mapEntry : keyValueMap2.entrySet()) {

            for (Map.Entry<String, String> mapEntry2 : mapEntry.getValue().entrySet()) {

                psKey = mapEntry2.getKey();
                psValue = mapEntry2.getValue();
                int keyLength = psKey.split("[.]").length;

                // Top-level key: emit the value directly when it matches.
                if (keyLength < 2) {
                    if (jsonKey.equals(psKey)) {
                        forward(new Object[]{psValue});
                    }
                }

                // Nested key: compare only the last segment of the dotted path.
                key = psKey.split("[.]")[keyLength - 1];

                if (keyLength >= 2 && key.equals(jsonKey) && !StringUtils.isEmpty(psValue)) {
                    String mp2 = "{\"" + key + "\":" + psValue + "}";

                    for (Map.Entry<String, Object> me2 : JsonUtils.str2FastJSON(mp2).entrySet()) {
                        if (me2.getValue() instanceof JSONArray) {
                            // A JSON array expands into one output row per element.
                            JSONArray jsonArray = (JSONArray) me2.getValue();
                            for (int i = 0; i < jsonArray.size(); i++) {
                                forward(new Object[]{jsonArray.get(i).toString()});
                            }
                        } else {
                            forward(new Object[]{JsonUtils.trimBothEndsChars(psValue, "\\[\\]")});
                        }
                    }
                }
            }
        }
    }

    @Override
    public void close() throws HiveException {
        // Nothing to clean up.
    }
}

Calling it from a main method:

public static void main(String[] args) {

        SparkConf conf = new SparkConf()
                .setAppName("Sync")
                .set("hive.exec.dynamic.partition", "true")
                .set("hive.exec.dynamic.partition.mode", "nonstrict")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .set("spark.sql.autoBroadcastJoinThreshold", "204800")
                .set("spark.debug.maxToStringFields", "1000")
                .set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
                .setMaster("local[*]");

        SparkSession sparkSession = SparkSession.builder()
                .config(conf).enableHiveSupport().getOrCreate();

        // Register the UDTF (the class must be on the classpath; use its fully
        // qualified name if it lives in a package), then call it in Spark SQL.
        sparkSession.sql("create temporary function marketing_json_trun as 'AnalysisJsonToArrayUDTF'");
        String jsonStr = "{\"test01\":[{\"gj\":{\"sf\":\"js\"}},{\"ds\":\"nj\"},{\"ds\":\"sh\"}]}";
        sparkSession.sql("select marketing_json_trun('" + jsonStr + "','ds')").show();

    }
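
Because Spark SQL also supports Hive's LATERAL VIEW syntax, a registered UDTF can be applied row by row against a table. The sketch below would continue from the main method above; the temp view src and its columns id and json_col are hypothetical and only meant to show the shape of the query.

        // Hypothetical: `src` is a temp view with columns `id` and `json_col`.
        // `t` is the lateral-view alias; `col` matches the single output column
        // declared in initialize().
        sparkSession.sql(
                "select s.id, t.col " +
                "from src s " +
                "lateral view marketing_json_trun(s.json_col, 'ds') t as col"
        ).show();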

 
