Spark SQL的自定义函数UDF

Spark SQL的自定义函数UDF

1. 背景

  1. 在SQL使用时,会有内置函数,但如果业务比较复杂,但又希望可以有更加灵活的函数使用和复用,则需要自定义
  2. UDF,就是user defined function,可以分为UDTF、UDAF
  3. UDTF,user defined table-generating function,就是将数据打散
  4. UDAF,user defined aggregating function,就是将数据聚合。

2. 创建和使用UDF

下述会使用案例来展示如何使用自定义函数

2.1 自定义函数,将GPS经纬度解析为位置信息

  1. 环境准备
  • 高德地图的appkey
  • idea 2020
  • maven 3.6.3
  • scala 2.12.12
  • spark 3.0.1
  • pom文件

    <properties>
        <maven.compiler.source>1.8maven.compiler.source>
        <maven.compiler.target>1.8maven.compiler.target>
        <scala.version>2.12.10scala.version>
        <spark.version>3.0.1spark.version>
        <hbase.version>2.2.5hbase.version>
        <hadoop.version>3.2.1hadoop.version>
        <encoding>UTF-8encoding>
    properties>

    <dependencies>
        
        <dependency>
            <groupId>org.scala-langgroupId>
            <artifactId>scala-libraryartifactId>
            <version>${scala.version}version>
            
            
        dependency>

        
        <dependency>
            <groupId>org.apache.httpcomponentsgroupId>
            <artifactId>httpclientartifactId>
            <version>4.5.12version>
        dependency>

        <dependency>
            <groupId>org.apache.sparkgroupId>
            <artifactId>spark-sql_2.12artifactId>
            <version>${spark.version}version>
        dependency>

        <dependency>
            <groupId>org.apache.sparkgroupId>
            <artifactId>spark-core_2.12artifactId>
            <version>${spark.version}version>
            
            
        dependency>

        
        <dependency>
            <groupId>com.alibabagroupId>
            <artifactId>fastjsonartifactId>
            <version>1.2.73version>
        dependency>

        
        <dependency>
            <groupId>mysqlgroupId>
            <artifactId>mysql-connector-javaartifactId>
            <version>5.1.47version>
        dependency>

    dependencies>

    <build>
        <pluginManagement>
            <plugins>
                
                <plugin>
                    <groupId>net.alchim31.mavengroupId>
                    <artifactId>scala-maven-pluginartifactId>
                    <version>3.2.2version>
                plugin>
                
                <plugin>
                    <groupId>org.apache.maven.pluginsgroupId>
                    <artifactId>maven-compiler-pluginartifactId>
                    <version>3.5.1version>
                plugin>
            plugins>
        pluginManagement>
        <plugins>
            <plugin>
                <groupId>net.alchim31.mavengroupId>
                <artifactId>scala-maven-pluginartifactId>
                <executions>
                    <execution>
                        <id>scala-compile-firstid>
                        <phase>process-resourcesphase>
                        <goals>
                            <goal>add-sourcegoal>
                            <goal>compilegoal>
                        goals>
                    execution>
                    <execution>
                        <id>scala-test-compileid>
                        <phase>process-test-resourcesphase>
                        <goals>
                            <goal>testCompilegoal>
                        goals>
                    execution>
                executions>
            plugin>

            <plugin>
                <groupId>org.apache.maven.pluginsgroupId>
                <artifactId>maven-compiler-pluginartifactId>
                <executions>
                    <execution>
                        <phase>compilephase>
                        <goals>
                            <goal>compilegoal>
                        goals>
                    execution>
                executions>
            plugin>

            
            <plugin>
                <groupId>org.apache.maven.pluginsgroupId>
                <artifactId>maven-shade-pluginartifactId>
                <version>2.4.3version>
                <executions>
                    <execution>
                        <phase>packagephase>
                        <goals>
                            <goal>shadegoal>
                        goals>
                        <configuration>
                            <filters>
                                <filter>
                                    <artifact>*:*artifact>
                                    <excludes>
                                        <exclude>META-INF/*.SFexclude>
                                        <exclude>META-INF/*.DSAexclude>
                                        <exclude>META-INF/*.RSAexclude>
                                    excludes>
                                filter>
                            filters>
                        configuration>
                    execution>
                executions>
            plugin>
        plugins>
    build>
object GeoFunc {

  // 根据经纬度返回省和市信息
  val geo = (longitude:Double, latitude: Double) =>  {
    val httpClient: CloseableHttpClient = HttpClients.createDefault()

    // 构建请求参数
    val httpGet = new HttpGet(s"https://restapi.amap.com/v3/geocode/regeo?&location=$longitude,$latitude&key=71cc7d9df22483b27ec40ecb45d9d87b")

    // 发送请求,获取返回信息
    val response: CloseableHttpResponse = httpClient.execute(httpGet)

    var province:String = null
    var city:String = null
    try {
      // 将返回对象中数据提取出来
      val entity: HttpEntity = response.getEntity

      if (response.getStatusLine.getStatusCode == 200) {
        // 将返回对象中数据转换为字符串
        val resultStr: String = EntityUtils.toString(entity)

        // 解析返回的json字符串
        val jSONObject: JSONObject = JSON.parseObject(resultStr)

        // 根据高德地图反地理编码接口返回数据中字段进行数据解析
        val regeocode: JSONObject = jSONObject.getJSONObject("regeocode")

        if (regeocode != null && regeocode.isEmpty == false) {
          val address: JSONObject = regeocode.getJSONObject("addressComponent")

          province = address.getString("province")
          city = address.getString("city")
        }
      }
    } catch {
      case e: Exception => {}
    } finally {
      // 每一次数据请求之后,关闭连接
      response.close()

      httpClient.close()
    }

    (province, city)
  }
}
object UDFTest1 {
  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDFTest1")
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._

    // 118.396128,"latitude":35.916527
    val dataset: Dataset[(String, String)] = sparkSession.createDataset(List(("a", "118.396128,35.916527"), ("b", "118.596128,35.976527")))

    val dataFrame: DataFrame = dataset.toDF("uid", "location")

    dataFrame.createTempView("v_location")

    sparkSession.udf.register("geo", GeoFunc.geo)

    dataFrame.show()

    val dataFrame1: DataFrame = sparkSession.sql(
      """
        |select
        |uid,
        |geo(loc1, loc2) as province_city
        |from
        |(
        |  select
        |  uid,
        |  cast(loc_pair[0] as double) as loc1,
        |  cast(loc_pair[1] as double) as loc2
        |  from
        |  (
        |    select
        |    uid,
        |    split(location, '[,]') as loc_pair
        |    from
        |    v_location
        |  )
        |)
        |""".stripMargin)

    dataFrame1.show()

    sparkSession.stop()
  }
}

/*
*
-- 先切割数据
select
city,
split(location, '[,]') as loc_pair
from
v_location


-- 将数据转换为double
select
uid,
cast(loc_pair[0] as double) as loc1,
cast(loc_pair[1] as double) as loc2
from
(
  select
  uid,
  split(location, '[,]') as loc_pair
  from
  v_location
)

-- 调用自定义函数进行数据查询
select
uid,
geo(loc1, loc2) as province_city
from
(
  select
  uid,
  cast(loc_pair[0] as double) as loc1,
  cast(loc_pair[1] as double) as loc2
  from
  (
    select
    uid,
    split(location, '[,]') as loc_pair
    from
    v_location
  )
)
*
* */

2.2 自定义拼接字符串函数

object UDF_CustomConcat {
  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_CustomConcat")
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._

    // 创建dataset,再转换为dataframe
    val dataset: Dataset[(String, String)] = sparkSession.createDataset(List(("湖南", "长沙"), ("江西", "南昌"), ("湖北", "武汉")))

    val dataFrame: DataFrame = dataset.toDF("province", "city")

    // 自定义函数,注意函数名尽量规范,见名知意一些
    val udf_func = (arg1:String, arg2:String) => {
      arg1 + "-" + arg2
    }

    // 注册自定义函数,注意这个是临时注册,只有这个代码中才可以生效
    sparkSession.udf.register("udf_func", udf_func)

    // 使用sql之前,先注册视图
    dataFrame.createTempView("v_test")

    val dataFrame1: DataFrame = sparkSession.sql("select udf_func(province, city) as concat_result from v_test;")

    dataFrame1.show()

    sparkSession.close()
  }
}

2.3 将Ip地址转换为省(市区)地理位置信息

  1. 环境准备
  • ip字典(比较大,只展示部分,可以去淘宝、拼多多、咸鱼等上购买此类数据资产)
1.4.8.0|1.4.127.255|17041408|17072127|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
1.8.0.0|1.8.255.255|17301504|17367039|亚洲|中国|北京|北京|海淀|北龙中网|110108|China|CN|116.29812|39.95931
1.10.0.0|1.10.7.255|17432576|17434623|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
1.10.8.0|1.10.9.255|17434624|17435135|亚洲|中国|福建|福州||电信|350100|China|CN|119.306239|26.075302
1.10.11.0|1.10.15.255|17435392|17436671|亚洲|中国|福建|福州||电信|350100|China|CN|119.306239|26.075302
1.10.16.0|1.10.127.255|17436672|17465343|亚洲|中国|广东|广州||电信|440100|China|CN|113.280637|23.125178
1.12.0.0|1.12.255.255|17563648|17629183|亚洲|中国|北京|北京||方正宽带|110100|China|CN|116.405285|39.904989
1.13.0.0|1.13.71.255|17629184|17647615|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
1.13.72.0|1.13.87.255|17647616|17651711|亚洲|中国|吉林|吉林||方正宽带|220200|China|CN|126.55302|43.843577
1.13.88.0|1.13.95.255|17651712|17653759|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
1.13.96.0|1.13.127.255|17653760|17661951|亚洲|中国|天津|天津||方正宽带|120100|China|CN|117.190182|39.125596
1.13.128.0|1.13.191.255|17661952|17678335|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
1.13.192.0|1.14.95.255|17678336|17719295|亚洲|中国|辽宁|大连||方正宽带|210200|China|CN|121.618622|38.91459
1.14.96.0|1.14.127.255|17719296|17727487|亚洲|中国|辽宁|鞍山||方正宽带|210300|China|CN|122.995632|41.110626
1.14.128.0|1.14.191.255|17727488|17743871|亚洲|中国|上海|上海||方正宽带|310100|China|CN|121.472644|31.231706
1.14.192.0|1.14.223.255|17743872|17752063|亚洲|中国|吉林|长春||方正宽带|220100|China|CN|125.3245|43.886841
  • 日志数据
20090121000732398422000|122.73.114.24|aa.991kk.com|/html/taotuchaoshi/2009/0120/7553.html|Mozilla/5.0 (Windows; U; Windows NT 5.1; zh-CN; rv:1.9.0.1) Gecko/2008070208 Firefox/3.0.1|http://aa.991kk.com/html/taotuchaoshi/index.html|
20090121000732420671000|115.120.14.96|image.baidu.com|/i?ct=503316480&z=0&tn=baiduimagedetail&word=%B6%AF%CE%EF%D4%B0+%B3%A4%BE%B1%C2%B9&in=32346&cl=2&cm=1&sc=0&lm=-1&pn=527&rn=1&di=2298496252&ln=615|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; GTB5; TencentTraveler 4.0)|http://image.baidu.com/i?tn=baiduimage&ct=201326592&cl=2&lm=-1&pv=&word=%B6%AF%CE%EF%D4%B0+%B3%A4%BE%B1%C2%B9&z=0&rn=21&pn=525&ln=615|BAIDUID=C1B0C0D4AA4A7D1BF9A0F74C4B727970:FG=1; BDSTAT=c3a929956cf1d97d5982b2b7d0a20cf431adcbef67094b36acaf2edda2cc5bc0; BDUSS=jBXVi1tQ3ZTSDJiflVHRERTSUNiYUtGRmNrWkZTYllWOEJZSk1-V0xFNU1lcDFKQkFBQUFBJCQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEztdUlM7XVJZ; BDSP=e114da18f3deb48fff2c9a8ef01f3a292df5e0fe2b24463340405da85edf8db1cb1349540923dd54564e9258d109b3de9c82d158ccbf6c81800a19d8bc3eb13533fa828ba61ea8d3fd1f4134970a304e251f95cad1c8a786c9177f3e6709c93d72cf5979; iCast_Rotator_1_1=1232467533578; iCast_Rotator_1_2=1232467564718
20090121000732511280000|115.120.16.98|ui.ptlogin2.qq.com|/cgi-bin/login?appid=7000201&target=self&f_url=loginerroralert&s_url=http://minigame.qq.com/login/flashlogin/loginok.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1) )|http://minigame.qq.com/act/rose0901/?aid=7000201&ADUIN=563586856&ADSESSION=1232467131&ADTAG=CLIENT.QQ.1833_SvrPush_Url.0|
20090121000732967450000|117.101.219.112|list.taobao.com|/browse/50010404-302903/n-1----------------------0---------yes-------g,giydmmjxhizdsnjwgy5tgnbsgyzdumjshe4dmoa--g,giydmmjxhlk6xp63hmztimrwgi5mnvonvc764kbsfu2gg3jj--g,ojsxgzlsozsv64dsnfrwkwzvgawdcmbqlu-------------------40-grid-ratesum-0-all-302903.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30)|http://list.taobao.com/browse/50010404-302903/n-1----------------------0---------yes-------g,giydmmjxhizdsnjwgy5tgnbsgyzdumjshe4dmoa--g,giydmmjxhlk6xp63hmztimrwgi5mnvonvc764kbsfu2gg3jj---------------------42-grid-ratesum-0-all-302903.htm|
20090121000733245014000|117.101.227.3|se.360.cn|/web/navierr.htm?url=http://www.3320.net/blib/c/24/11839/&domain=www.3320.net&code=403&pid=sesoft&tabcount=7|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; 360SE)||B=ID=435431224878393:V=2:S=8f59056144; __utma=148900148.1624674999336435000.1224880187.1226546993.1229991277.5; __utmz=148900148.1224880187.1.1.utmcsr=(direct)
20090121000733290585000|117.101.206.175|wwww.17kk.net|/0OO000OO00O00OOOOO0/new/cjbbs/zx1.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://wwww.17kk.net/0OO000OO00O00OOOOO0/new/cjbbs/zzx.htm|rtime=11; ltime=1232384469187; cnzz_eid=54742851-1228495798-http%3A//wwww.17kk.net/0OO000OO00O00OOOOO0/tongji1_l7kk.htm; cck_lasttime=1232381515031; cck_count=0; cnzz_a508803=8; vw508803=%3A80391793%3A; sin508803=none; ASPSESSIONIDQQAQQCRT=GKKKBIFCLAJPKGHGEKDEAPPB; ASPSESSIONIDQSCRQCQS=BCIBBIFCMLLPGEPBCFMEHGOA; ASPSESSIONIDSQBSRDRT=GPLKBIFCJBIAHLLBJLDDANGN; ASPSESSIONIDSQBRRDRS=AHLDAIFCDIINIGLMEEJJDGDN; __utma=152924281.4523785370259723000.1228495189.1232381092.1232466255.16; __utmb=152924281.8.10.1232466255; __utmz=152924281.1228495189.1.1.utmcsr=(direct)
20090121000733387555000|117.101.206.175|wwww.17kk.net|/0OO000OO00O00OOOOO0/new/6cheng/nnts/180/sport.htm|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://wwww.17kk.net/0OO000OO00O00OOOOO0/new/6cheng/nnts/180/z8.htm|rtime=11; ltime=1232384469187; cnzz_eid=54742851-1228495798-http%3A//wwww.17kk.net/0OO000OO00O00OOOOO0/tongji1_l7kk.htm; cck_lasttime=1232381515031; cck_count=0; cnzz_a508803=8; vw508803=%3A80391793%3A; sin508803=none; ASPSESSIONIDQQAQQCRT=GKKKBIFCLAJPKGHGEKDEAPPB; ASPSESSIONIDQSCRQCQS=BCIBBIFCMLLPGEPBCFMEHGOA; ASPSESSIONIDSQBSRDRT=GPLKBIFCJBIAHLLBJLDDANGN; ASPSESSIONIDSQBRRDRS=AHLDAIFCDIINIGLMEEJJDGDN; __utma=152924281.4523785370259723000.1228495189.1232381092.1232466255.16; __utmb=152924281.8.10.1232466255; __utmz=152924281.1228495189.1.1.utmcsr=(direct)
20090121000733393911000|115.120.10.168|my.51.com|/port/ajax/main.accesslog.php|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7)|http://my.51.com/|
20090121000734192650000|115.120.9.235|www.baidu.com|/s?tn=mzmxzgx_pg&wd=xiao77%C2%DB%CC%B3|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; Avant Browser; CIBA)|http://www.250cctv.cn/|BAIDUID=80DA16918ED68645445A6837338DBC5C:FG=1; BDSTAT=805379474b3ed4a4ab64034f78f0f736afc379310855b319ebc4b74541a9d141; BD_UTK_DVT=1; BDRCVFR[9o0so1JMIzY]=bTm-Pk1nd0D00; BDRCVFR[ZusMMNJpUDC]=QnHQ0TLSot3ILILQWcdnAPWIZm8mv3
20090121000734299056000|125.213.97.6|haort.com|/Article/200901/2071_3.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; GTB5)|http://haort.com/Article/200901/2071_2.html|krviewcurc=1; krvedlaid=4285; kcc_169767kanrt=90; AJSTAT_ok_times=2; rtime=0; ltime=1220372703640; cnzz_eid=3485291-http%3A//rentiart.com/js/3.htm; krviewcurc=2; krvedlaid=3720; cck_lasttime=1232468301734; cck_count=0; AJSTAT_ok_pages=14; Cookie9=PopAnyibaSite; kcc_169767kanrt=39
20090121000734469862000|117.101.213.66|book.tiexue.net|/Content_620501.html|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1)|http://book.tiexue.net/Content_619975.html|BookBackModel=5%2C1; RHistory=11225%2C%u7279%u6218%u5148%u9A71; __utma=247579266.194628209.1232339801.1232350177.1232464272.3; __utmb=247579266; __utmz=247579266.1232339801.1.1.utmccn=(direct)
20090121000734529619000|115.120.0.192|www.cqpa.gov.cn|/u/cqpa/news_12757.shtml|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1)|http://www.cqpa.gov.cn/u/cqpa/|ASPSESSIONIDQSRAAAST=LGAIOKNCHPHMKALKIHPODCOB
20090121000734819099000|117.101.225.140|jifen.qq.com|/static/mart/shelf/9.shtml|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; QQDownload 1.7; TencentTraveler 4.0; CIBA; .NET CLR 2.0.50727)|http://jifen.qq.com/mart/shelf_list.shtml?9|pvid=47052875; o_cookie=996957123; flv=9.0; pt2gguin=o0361474804; ptcz=1cc6f06a90bb8d1f53069184d85dd4d01dd8ca38eb7eb2fa615548538f133ede; r_cookie=912035936314; sc_cookie_floating_refresh241=3; icache=MMBMACEFG; uin_cookie=361474804; euin_cookie=AQAYuAH3EXdauOugz/OMzWPIssCyb0d3XzENGAAAAADefqSBU4unTT//nt3WNqaSQ2R44g==; pgv=ssid=s2273451828&SPATHTAG=CLIENT.PURSE.MyPurse.JifenInfo&KEYPATHTAG=2.1.1; verifysession=9b3f4c872a003e70cfe2ef5de1a62a3d862a448fd2f5b1b032512256fbd832fd7365b7d7619ef2ca; uin=o0361474804; skey=@BpkD0OWtL; JifenUserId=361474804; ACCUMULATE=g1qjCmEMXxtoOc1g00000681; _rsCS=1
20090121000735126951000|115.120.4.164|www.5webgame.com|/bbs/2fly_gift.php|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.1.4322)|http://www.5webgame.com/bbs/viewthread.php?tid=43&extra=page%3D1|
20090121000735482286000|125.213.97.254|tieba.baidu.com|/f?kz=527788861|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; TencentTraveler )|http://tieba.baidu.com/f?ct=&tn=&rn=&pn=&lm=&sc=&kw=%D5%DB%CC%DA&rs2=0&myselectvalue=1&word=%D5%DB%CC%DA&tb=on|BAIDUID=D87E9C0E1E427AD5EEB37C6CC4B9C5CE:FG=1; BD_UTK_DVT=1; AdPlayed=true
20090121000735619376000|115.120.3.253|m.163.com|/xxlwh/|Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; InfoPath.1)|http://blog.163.com/xxlwh/home/|
20090121000735819656000|115.120.13.149|2008.wg999.com|/|Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; QQDownload 1.7; TencentTraveler 4.0; Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1) )||ystat_bc_827474=23538508893341337937; ystat_bc_832488=29857733243775586653

object UDF_IP2Location {
  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_IP2Location")
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._
    import org.apache.spark.sql.functions._

    // 读取文本文件
    val ipRules: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ip_location_dict.txt")

    // 将读取的ip规则字典加载到driver端,注意这里要做分布式缓存,省去join成本,也就是map端缓存
    // 数据转换,排序,去重,采集回driver端
    val ipRulesInDriver: Array[(Long, Long, String, String)] = ipRules.map(line => {
      val strings: Array[String] = line.split("[|]")
      val startIpNum: Long = strings(2).toLong
      val endIpNum: Long = strings(3).toLong
      val province: String = strings(6)
      val city: String = strings(7)

      (startIpNum, endIpNum, province, city)
    }).distinct()
      .sort($"_1" asc)
      .collect()

    // 注册广播变量
    val broadcastRefInDriver: Broadcast[Array[(Long, Long, String, String)]] = sparkSession.sparkContext.broadcast(ipRulesInDriver)

    // 自定义函数
    val ip2Location = (ip:String) => {
      val ipNumber: Long = IpUtils.ip2Long(ip)

      // 这里产生了闭包
      val ipRulesInExecutor: Array[(Long, Long, String, String)] = broadcastRefInDriver.value

      // 注意,文本文件中数据本身已经做过排序,但一般为了保险,一般都会再做一次排序
      val index: Int = IpUtils.binarySearch(ipRulesInExecutor, ipNumber)

      var province: String = null
      if(index > 0) {
        province = ipRulesInExecutor(index)._3
      }

      province
    }

    // 将自定义函数注册为udf函数
    sparkSession.udf.register("ip2Location", ip2Location)

    // 读取需要处理的日志数据
    val logLines: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ipaccess.log")

    val dataFrame: DataFrame = logLines.map(line => {
      val strings: Array[String] = line.split("[|]")
      val ipStr: String = strings(1)
      ipStr
    }).toDF("ip")

    // 将dataFrame注册为临时视图,方便做数据查询
    dataFrame.createTempView("v_ips")

    // 执行sql语句
    sparkSession.sql("select ip, ip2Location(ip) as location from v_ips")
      .limit(15)
      .show()

    sparkSession.close()
  }
}

// 这是一个工具类,主要是将ip地址转换为长整型以及二分查找
object IpUtils {
  /**
   * 将IP地址转成十进制
   *
   * @param ip
   * @return
   */
  def ip2Long(ip: String): Long = {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for (i <- 0 until fragments.length) {
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }

  /**
   * 二分法查找
   * 注意,scala中,如果是递归函数调用,必须要用return返回值,否则会导致函数无法跳出的问题
   * @param lines
   * @param ip
   * @return
   */
  def binarySearch(lines: ArrayBuffer[(Long, Long, String, String)], ip: Long): Int = {
    var low = 0 //起始
    var high = lines.length - 1 //结束
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1 //没有找到
  }

  def binarySearch(lines: Array[(Long, Long, String, String)], ip: Long): Int = {
    var low = 0 //起始
    var high = lines.length - 1 //结束
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1 //没有找到
  }
}

2.4 将订单数据中经纬度转换为地理位置

  1. 环境准备
  • 数据
{"cid": 1, "money": 600.0, "longitude":116.397128,"latitude":39.916527,"oid":"o123", }
"oid":"o112", "cid": 3, "money": 200.0, "longitude":118.396128,"latitude":35.916527}
{"oid":"o124", "cid": 2, "money": 200.0, "longitude":117.397128,"latitude":38.916527}
{"oid":"o125", "cid": 3, "money": 100.0, "longitude":118.397128,"latitude":35.916527}
{"oid":"o127", "cid": 1, "money": 100.0, "longitude":116.395128,"latitude":39.916527}
{"oid":"o128", "cid": 2, "money": 200.0, "longitude":117.396128,"latitude":38.916527}
{"oid":"o129", "cid": 3, "money": 300.0, "longitude":115.398128,"latitude":35.916527}
{"oid":"o130", "cid": 2, "money": 100.0, "longitude":116.397128,"latitude":39.916527}
{"oid":"o131", "cid": 1, "money": 100.0, "longitude":117.394128,"latitude":38.916527}
{"oid":"o132", "cid": 3, "money": 200.0, "longitude":118.396128,"latitude":35.916527}
object UDF_IP2Location {
  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_IP2Location")
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._
    import org.apache.spark.sql.functions._

    // 读取文本文件
    val ipRules: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ip_location_dict.txt")

    // 将读取的ip规则字典加载到driver端,注意这里要做分布式缓存,省去join成本,也就是map端缓存
    // 数据转换,排序,去重,采集回driver端
    val ipRulesInDriver: Array[(Long, Long, String, String)] = ipRules.map(line => {
      val strings: Array[String] = line.split("[|]")
      val startIpNum: Long = strings(2).toLong
      val endIpNum: Long = strings(3).toLong
      val province: String = strings(6)
      val city: String = strings(7)

      (startIpNum, endIpNum, province, city)
    }).distinct()
      .sort($"_1" asc)
      .collect()

    // 注册广播变量
    val broadcastRefInDriver: Broadcast[Array[(Long, Long, String, String)]] = sparkSession.sparkContext.broadcast(ipRulesInDriver)

    // 自定义函数
    val ip2Location = (ip:String) => {
      val ipNumber: Long = IpUtils.ip2Long(ip)

      // 这里产生了闭包
      val ipRulesInExecutor: Array[(Long, Long, String, String)] = broadcastRefInDriver.value

      // 注意,文本文件中数据本身已经做过排序,但一般为了保险,一般都会再做一次排序
      val index: Int = IpUtils.binarySearch(ipRulesInExecutor, ipNumber)

      var province: String = null
      if(index > 0) {
        province = ipRulesInExecutor(index)._3
      }

      province
    }

    // 将自定义函数注册为udf函数
    sparkSession.udf.register("ip2Location", ip2Location)

    // 读取需要处理的日志数据
    val logLines: Dataset[String] = sparkSession.read.textFile("E:\\DOITLearning\\12.Spark\\ipaccess.log")

    val dataFrame: DataFrame = logLines.map(line => {
      val strings: Array[String] = line.split("[|]")
      val ipStr: String = strings(1)
      ipStr
    }).toDF("ip")

    // 将dataFrame注册为临时视图,方便做数据查询
    dataFrame.createTempView("v_ips")

    // 执行sql语句
    sparkSession.sql("select ip, ip2Location(ip) as location from v_ips")
      .limit(15)
      .show()

    sparkSession.close()
  }
}

// 这是一个工具类,主要是将ip地址转换为长整型以及二分查找
object IpUtils {
  /**
   * 将IP地址转成十进制
   *
   * @param ip
   * @return
   */
  def ip2Long(ip: String): Long = {
    val fragments = ip.split("[.]")
    var ipNum = 0L
    for (i <- 0 until fragments.length) {
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }

  /**
   * 二分法查找
   * 注意,scala中,如果是递归函数调用,必须要用return返回值,否则会导致函数无法跳出的问题
   * @param lines
   * @param ip
   * @return
   */
  def binarySearch(lines: ArrayBuffer[(Long, Long, String, String)], ip: Long): Int = {
    var low = 0 //起始
    var high = lines.length - 1 //结束
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1 //没有找到
  }

  def binarySearch(lines: Array[(Long, Long, String, String)], ip: Long): Int = {
    var low = 0 //起始
    var high = lines.length - 1 //结束
    while (low <= high) {
      val middle = (low + high) / 2
      if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2))
        return middle
      if (ip < lines(middle)._1)
        high = middle - 1
      else {
        low = middle + 1
      }
    }
    -1 //没有找到
  }
}

2.5 自定义聚合函数(适用于Spark1.0 2.0)

  1. 环境准备
  • 数据
name,salary,dept
jack,200.2,develop
tom,301.5,finance
sunny,412,operating
hanson,50000,ceo
tompson,312,operating
water,700.2,develop
money,500.2,develop
  • 求平均工资
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

// 这个类在1.0 2.0的spark版本还可以使用,spark 3.0版本已经废弃,使用更新的接口,相对更加精简
class CustomAvgFunction extends UserDefinedAggregateFunction {
  // 这是指输入的数据类型,因为这个自定义函数用来计算平均工资,所以输入就是一个数据,而且是double类型
  override def inputSchema: StructType = StructType(List(
    StructField("sal", DataTypes.DoubleType)
  ))

  // 这是中间结果数据类型,就是总工资,人员个数
  override def bufferSchema: StructType = StructType(List(
    StructField("sum_sal", DataTypes.DoubleType),
    StructField("counts", DataTypes.IntegerType)
  ))

  // 这是返回的数据类型,平均工资,还是double
  override def dataType: DataType = DataTypes.DoubleType

  // 确定性,这里指输入和输出数据类型是否一样
  override def deterministic: Boolean = true

  // 初始值,类似于RDD的combineBykey的用法
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 初始工资,就是0.0开始。这里需要显式指定0.0,会自动推导出是double类型0,所以不能是0,必须是0.0
    // 注意如果中间结果是乘法,除法,初始值就是1,注意灵活区别
    buffer(0) = 0.0

    buffer(1) = 0 // 人员个数
  }

  // 每处理一条数据,在每个分区进行的局部运算
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val in: Double = input.getDouble(0)

    // 添加一份工资数据
    buffer(0) = buffer.getDouble(0) + in

    // 次数累加1
    buffer(1) = buffer.getInt(1) + 1
  }

  // 每个分区的聚合结果操作
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 每个分区的总工资累加
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)

    // 每个分区的次数累加
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // 最后的聚合操作
  override def evaluate(buffer: Row): Any = {
    // 总工资除以次数,如果要预防错误,可以判断分母为0的场景
    buffer.getDouble(0) / buffer.getInt(1)
  }
}
object UDF_Custom_AVG_Test {

  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_Custom_AVG_Test")
      .master("local")
      .getOrCreate()

    // 读取csv文件,注意option中还可以只当分割符号等信息
    val dataFrame: DataFrame = sparkSession.read
      .option("header", true)
      .option("inferschema", true)
      .csv("E:\\DOITLearning\\12.Spark\\employinfo.csv")

    // 创建临时视图,才能执行sql语句
    dataFrame.createTempView("v_emp")

    // 注册自定义函数,这个方法在spark 3.0中被指明废弃,但还可以使用
    sparkSession.udf.register("my_avg", new CustomAvgFunction)

    // 执行sql,内部运行自定义函数
    val dataFrame1: DataFrame = sparkSession.sql("select dept, my_avg(salary) as salary_avg from v_emp group by dept")

    dataFrame1.show()

    sparkSession.close()
  }
}

2.6 自定义聚合函数(适用于Spark 3.0)

  1. 求平均工资,和2.5一样的数据和需求
object UDF_Custom_AVG_Test2 {
  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_Custom_AVG_Test2")
      .master("local")
      .getOrCreate()

    // 读取csv文件,注意option中还可以只当分割符号等信息
    val dataFrame: DataFrame = sparkSession.read
      .option("header", true)
      .option("inferschema", true)
      .csv("E:\\DOITLearning\\12.Spark\\employinfo.csv")

    // 创建临时视图,才能执行sql语句
    dataFrame.createTempView("v_emp")

    import org.apache.spark.sql.functions._

    val myAVGFunct = new Aggregator[Double, (Double, Int), Double] {
      // 初始值
      override def zero: (Double, Int) = (0.0, 0)

      // 分区内聚合
      override def reduce(b: (Double, Int), a: Double): (Double, Int) = {
        (b._1 + a, b._2 + 1)
      }

      // 分区之间结果聚合
      override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = {
        (b1._1 + b2._1, b1._2 + b2._2)
      }

      // 最后结果处理
      override def finish(reduction: (Double, Int)): Double = {
        reduction._1 / reduction._2
      }

      // 中间结果如何序列化编码
      override def bufferEncoder: Encoder[(Double, Int)] = {
        Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
      }

      // 数据结果输出如何进行序列化编码
      override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
      }
    }

    // 注册自定义方法
    // 新的自定义聚合方法,需要使用udaf将对象转换一下
    sparkSession.udf.register("my_avg", udaf(myAVGFunct))

    val dataFrame1: DataFrame = sparkSession.sql("select dept, my_avg(salary) as salary_avg from v_emp group by dept")

    dataFrame1.show()

    sparkSession.close()
  }
}

2.7 求几何平均数

object UDF_Custom_AVG_Test3 {

  def main(args: Array[String]): Unit = {

    val sparkSession: SparkSession = SparkSession.builder()
      .appName("UDF_Custom_AVG_Test3")
      .master("local")
      .getOrCreate()

    val nums: Dataset[lang.Long] = sparkSession.range(1, 10)

    nums.createTempView("v_nums")

    import org.apache.spark.sql.functions._

    //  自定义聚合函数
    val agg = new Aggregator[Long, (Long, Int), Double]() {
      // 这里是要求集合平均值,初始值会不一样
      override def zero: (Long, Int) = (1, 1)

      // 中间值处理
      override def reduce(b: (Long, Int), a: Long): (Long, Int) = {
        (b._1 * a, b._2 + 1)
      }

      // 分区之间结果聚合处理
      override def merge(b1: (Long, Int), b2: (Long, Int)): (Long, Int) = {
        (b1._1 * b2._1, b1._2 + b2._2)
      }

      // 最后结果处理
      override def finish(reduction: (Long, Int)): Double = {

        Math.pow(reduction._1.toDouble, 1 / reduction._2.toDouble)
      }

      // 中间结果序列化编码
      override def bufferEncoder: Encoder[(Long, Int)] = {
        Encoders.tuple(Encoders.scalaLong, Encoders.scalaInt)
      }

      // 输出结果编码
      override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
      }
    }

    // 注册方法
    sparkSession.udf.register("geo_mean", udaf(agg))

    val dataFrame: DataFrame = sparkSession.sql("select geo_mean(id) from v_nums")

    dataFrame.show()

    // 可以打印出逻辑计划,物理计划,以及其优化思路
    dataFrame.explain(true)

    sparkSession.close()
  }
}

2.8 总结

  1. 自定义函数,就跟编码时自定义的代码方法一样,可以根据业务需求做调整
  2. 如果需要复用,可以将其抽离到一个公共文件中,方便复用
  3. 自定义函数使用前需要注册一下
  4. dataframe本身要适用sql方式处理,需要先注册为视图,可以是临时的,也可以是全局的
  5. UDF、UDTF、UDAF概念和Hive中一样,也都可以自定义,最后在sql中使用

你可能感兴趣的:(spark,dataframe,scala,spark,apache,spark,scala,分布式计算,大数据)