Hive Custom UDF Functions

The following is based on Hive 3.1.2.

There are two ways to implement a custom UDF in Hive: by extending the org.apache.hadoop.hive.ql.exec.UDF class, or by extending the org.apache.hadoop.hive.ql.udf.generic.GenericUDF class.

Either way, the implementation steps are the same:

  1. Extend the appropriate class and implement the required method(s)
  2. Package the code as a jar
  3. Add the generated jar to the Hive environment
  4. In Hive, create a function mapped to the implementation class in the jar

First, add the Maven dependency to the pom:

<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>3.1.2</version>
</dependency>
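
If the jar will run only inside Hive, this dependency is typically given <scope>provided</scope>, since the Hive runtime already supplies hive-exec and the UDF jar does not need to bundle it.
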
1. UDF Implementation

When extending the UDF class, you only need to implement an evaluate method. Before writing my own, I pulled up the source of the built-in replace function for reference; it is pasted below:

package org.apache.hadoop.hive.ql.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;


/**
 * UDFReplace replaces all substrings that are matched with a replacement substring.
 *
 */
@Description(name = "replace",
    value = "_FUNC_(str, search, rep) - replace all substrings of 'str' that "
    + "match 'search' with 'rep'", extended = "Example:\n"
    + "  > SELECT _FUNC_('Hack and Hue', 'H', 'BL') FROM src LIMIT 1;\n"
    + "  'BLack and BLue'")
public class UDFReplace extends UDF {

  private Text result = new Text();

  public UDFReplace() {
  }

  public Text evaluate(Text s, Text search, Text replacement) {
    if (s == null || search == null || replacement == null) {
      return null;
    }
    String r = s.toString().replace(search.toString(), replacement.toString());
    result.set(r);
    return result;
  }
}
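
A detail worth noting in this source: the class keeps a single Text field (result) and reuses it across rows, since Hive calls evaluate() once per row and reusing one object avoids a per-row allocation. My function below follows the same pattern, and that reuse is exactly what makes the getBytes() pitfall described further down possible.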

Modeled on the above, I defined my own function with the same behavior as Hive's built-in repeat function:

package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;


@Description(name = "my_repeat",				// the function name this class maps to in Hive; keep it consistent with the name registered in Hive
        value = "_FUNC_(str, n): repeat str n times",       // shown by "desc function xxx"
        extended = "Example SQL: select _FUNC_('a',3);\nResult: 'aaa'")   // shown by "desc function extended xxx"
public class MyUDFRepeat extends UDF {
    // for Hive char/string types, it is recommended to use the Text class
    private Text res = new Text();

    public Text evaluate(Text str, IntWritable n) {
        if (str == null || n == null) {
            return null;
        }

        if (n.get() > 0) {
            byte[] arr = str.getBytes();
            byte[] newArr = new byte[str.getLength() * n.get()];

            for (int i = 0; i < n.get(); i++) {
                System.arraycopy(arr, 0, newArr, i * str.getLength(), str.getLength());
            }
            res.set(newArr);
        } else {
            // reset the reused Text; otherwise it would still hold the previous row's value
            res.clear();
        }
        return res;
    }
}
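
Because evaluate() here is a plain Java method, it can be smoke-tested locally before packaging. A minimal sketch, assuming hive-exec is on the classpath (the test class name is mine, not part of the project):

package com.demo.hive;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

// Hypothetical local smoke test for MyUDFRepeat -- no Hive session required.
public class MyUDFRepeatTest {
    public static void main(String[] args) {
        MyUDFRepeat udf = new MyUDFRepeat();
        System.out.println(udf.evaluate(new Text("ab"), new IntWritable(3)));   // ababab
        System.out.println(udf.evaluate(null, new IntWritable(3)));             // null
        System.out.println(udf.evaluate(new Text("ab"), new IntWritable(0)));   // (empty string)
    }
}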

While writing this function I initially ran into a problem that no amount of logic-checking revealed; it took nearly a day to discover that Text.getBytes() differs slightly from String.getBytes(): the returned byte arrays are not necessarily the same length. Text.getBytes() returns the Text object's internal buffer, which can be longer than the valid data when the object is reused, so the valid byte count should be read with getLength(). Replacing every str.getBytes().length with str.getLength() fixed it. Text class API: https://hadoop.apache.org/docs/r3.1.2/api/index.html
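
The difference is easy to reproduce in isolation. A small sketch of the buffer-reuse behavior (the values follow from Text growing but never shrinking its buffer):

import org.apache.hadoop.io.Text;

// Text.getBytes() exposes the internal buffer, which is not shrunk on reuse,
// so its length can exceed the valid byte count reported by getLength().
public class TextLengthDemo {
    public static void main(String[] args) {
        Text t = new Text("hadoop");              // buffer now holds 6 bytes
        t.set(new Text("pig"));                   // reuses the existing buffer
        System.out.println(t.getLength());        // 3 -- valid data length
        System.out.println(t.getBytes().length);  // 6 -- buffer length, unchanged
    }
}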

After packaging the source above into a jar, upload it to the host running the Hive service (or to Hadoop), then execute the following in Hive (here run from a local IDEA connection):

add jar /root/HiveLib/hive_udf-1.0-SNAPSHOT.jar;                        -- add the jar to the Hive environment
create temporary function my_repeat as 'com.demo.hive.MyUDFRepeat';     -- create a temporary function, valid only for the current session

After creating the function, you can view its details:
desc function extended my_repeat;
Run it on some test data to verify the behavior:
select *,my_repeat(name,2),repeat(name,2) from db_prac.employee;

2. GenericUDF Implementation

Again, the source of the built-in length function is pasted first for reference. Implementing via GenericUDF requires overriding three abstract methods of the parent class: initialize() (validates the arguments and returns the ObjectInspector for the result type), evaluate() (computes the per-row result), and getDisplayString() (the text shown for the call, e.g. in EXPLAIN output).

package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.StringLength;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.lazy.LazyBinary;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * GenericUDFLength.
 *
 */
@Description(name = "length",
    value = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data",
    extended = "Example:\n"
    + "  > SELECT _FUNC_('Facebook') FROM src LIMIT 1;\n" + "  8")
@VectorizedExpressions({StringLength.class})
public class GenericUDFLength extends GenericUDF {
  private final IntWritable result = new IntWritable();
  private transient PrimitiveObjectInspector argumentOI;
  private transient PrimitiveObjectInspectorConverter.StringConverter stringConverter;
  private transient PrimitiveObjectInspectorConverter.BinaryConverter binaryConverter;
  private transient boolean isInputString;

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length != 1) {
      throw new UDFArgumentLengthException(
          "LENGTH requires 1 argument, got " + arguments.length);
    }

    if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException(
          "LENGTH only takes primitive types, got " + argumentOI.getTypeName());
    }
    argumentOI = (PrimitiveObjectInspector) arguments[0];

    PrimitiveObjectInspector.PrimitiveCategory inputType = argumentOI.getPrimitiveCategory();
    ObjectInspector outputOI = null;
    switch (inputType) {
      case CHAR:
      case VARCHAR:
      case STRING:
        isInputString = true;
        stringConverter = new PrimitiveObjectInspectorConverter.StringConverter(argumentOI);
        break;

      case BINARY:
        isInputString = false;
        binaryConverter = new PrimitiveObjectInspectorConverter.BinaryConverter(argumentOI,
            PrimitiveObjectInspectorFactory.writableBinaryObjectInspector);
        break;

      default:
        throw new UDFArgumentException(
            " LENGTH() only takes STRING/CHAR/VARCHAR/BINARY types as first argument, got "
            + inputType);
    }

    outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    return outputOI;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    byte[] data = null;
    if (isInputString) {
      String val = null;
      if (arguments[0] != null) {
        val = (String) stringConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }

      data = val.getBytes();

      int len = 0;
      for (int i = 0; i < data.length; i++) {
        if (GenericUDFUtils.isUtfStartByte(data[i])) {
          len++;
        }
      }
      result.set(len);
      return result;
    } else {
      BytesWritable val = null;
      if (arguments[0] != null) {
        val = (BytesWritable) binaryConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }

      result.set(val.getLength());
      return result;
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("length", children);
  }
}
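
Two things stand out in this source: initialize() does all the type inspection once and caches the converters, so the per-row evaluate() stays cheap; and for strings the length is computed by counting UTF-8 start bytes, i.e. characters rather than bytes.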

Modeled on the above, here is a function that checks whether one string is a substring of another:

package com.demo.hive;


import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;

@Description(name = "str_contains", value = "_FUNC_(str1, str2): return true if str1 contains str2, else return false")
public class MyGenericUDFContains extends GenericUDF {
    private StringObjectInspector pos1;
    private StringObjectInspector pos2;

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // check the argument count
        if (arguments.length != 2) {
            throw new UDFArgumentLengthException("str_contains requires exactly 2 arguments");
        }
        // check the argument types
        if (!(arguments[0] instanceof StringObjectInspector) || !(arguments[1] instanceof StringObjectInspector)) {
            throw new UDFArgumentException("both arguments must be of string type");
        }
        this.pos1 = (StringObjectInspector) arguments[0];
        this.pos2 = (StringObjectInspector) arguments[1];
        // the function returns a boolean
        return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        String str1 = this.pos1.getPrimitiveJavaObject(arguments[0].get());
        String str2 = this.pos2.getPrimitiveJavaObject(arguments[1].get());
        // NULL in, NULL out -- avoids an NPE when either argument is NULL
        if (str1 == null || str2 == null) {
            return null;
        }
        return str1.contains(str2) ? Boolean.TRUE : Boolean.FALSE;
    }

    @Override
    public String getDisplayString(String[] children) {
        return getStandardDisplayString("str_contains", children);
    }
}
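
A GenericUDF can also be exercised locally by driving initialize() and evaluate() by hand, roughly the way Hive does. A minimal sketch (the test class name is mine, not part of the project):

package com.demo.hive;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// Hypothetical local driver for MyGenericUDFContains.
public class MyGenericUDFContainsTest {
    public static void main(String[] args) throws HiveException {
        MyGenericUDFContains udf = new MyGenericUDFContains();
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;

        // Hive calls initialize() once with the argument inspectors...
        udf.initialize(new ObjectInspector[]{stringOI, stringOI});

        // ...then evaluate() once per row with deferred argument values.
        DeferredObject[] row = {new DeferredJavaObject("Alice"), new DeferredJavaObject("i")};
        System.out.println(udf.evaluate(row)); // true
    }
}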

After packaging and uploading the jar, create the mapped function:
create temporary function str_contains as 'com.demo.hive.MyGenericUDFContains';
Check the function info:
desc function extended str_contains;
Run some test data:
select name, str_contains(name,"i") from db_prac.employee;

Summary

  1. The UDF class is simple to implement: you only need to write an evaluate() method, and evaluate() supports overloading (see the sketch after this list). The GenericUDF class is more involved than UDF, but provides more flexible argument checking and richer argument types; choose based on the actual scenario.
  2. The registration shown above is temporary: the function is valid only for the current session, which is generally fine for testing. For permanent registration, first upload the jar to HDFS, then register it with create function my_repeat as 'com.demo.hive.MyUDFRepeat' using jar "hdfs:/user/hive/lib/hive_udf-1.0-SNAPSHOT.jar";
    To drop a registered function: drop [temporary] function xxx;
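
As a hedged illustration of the overloading mentioned in point 1, here is a hypothetical variant of my_repeat with two evaluate() signatures (not part of the project above); Hive resolves the overload from the argument types at query time:

package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

// Hypothetical sketch: evaluate() may be overloaded on a UDF subclass.
public class MyOverloadedRepeat extends UDF {
    private final Text res = new Text();

    // my_repeat(str) -- one-argument form, defaults to repeating twice
    public Text evaluate(Text str) {
        return evaluate(str, new IntWritable(2));
    }

    // my_repeat(str, n) -- two-argument form
    public Text evaluate(Text str, IntWritable n) {
        if (str == null || n == null) {
            return null;
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n.get(); i++) {
            sb.append(str.toString());
        }
        res.set(sb.toString());
        return res;
    }
}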
