Spark 实现两表查询(SparkCore和SparkSql)

项目需求:

ip.txt:包含ip起始地址,ip结束地址,ip所属省份

access.txt:包含ip地址和各种访问数据

需求:两表联合查询每个省份的ip数量

SparkCore

使用广播,将小表广播到executor.对大表的每条数据都到小表中进行查找。

package day07

import java.sql.DriverManager

import org.apache.log4j.{Level, Logger}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object IPLocation {
  val ipFile = "d:\\data\\spark\\ip.txt"
  val acessFile = "d:\\data\\spark\\access.log"

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    val conf = new SparkConf().setAppName("IpLocation").setMaster("local[3]")
    val sc = new SparkContext(conf)
    //1.读取IP规则资源库
    val lines = sc.textFile(ipFile)
    //2.整理Ip规则
    val ipRules = lines.map(x => {
      val splited = x.split("[|]")
      val startNum = splited(2).toLong
      val endNum = splited(3).toLong
      val province = splited(6)
      (startNum,endNum,province)
    })
    //println(ipRules.collect().toBuffer)
    //3.将Ip收集起来
    val ipDriver: Array[(Long, Long, String)] = ipRules.collect()
    //4.将IP通过广播的方式发送到executor
    //广播之后,在Driver端获取了广播变量的引用(如果没有广播完,就不往下走)
    val broadcastRef: Broadcast[Array[(Long, Long, String)]] = sc.broadcast(ipDriver)

    //5.读取访问日志
    val access = sc.textFile(acessFile)
    //6.整理访问日志
    val provinces = access.map(x => {
      val fields = x.split("[|]")
      val ip = fields(1)
      val ipNum = MyUtils.ip2Long(ip)
      //通过广播获取所有ip规则,然后进行匹配
      val allIpRulesExecutor = broadcastRef.value
      //根据规则查找,二分查找
      var province = "未知"
      val index = MyUtils.binarySearch(allIpRulesExecutor,ipNum)
      if(index != -1){
        province = allIpRulesExecutor(index)._3
      }
      (province,1)
    })
    //7.按照省份进行计数
    val reduceRDD: RDD[(String, Int)] = provinces.reduceByKey(_+_)
    //8.打印结果
    //reduceRDD.foreach(println)
    //9.将数据存储到mysql中
    /**
      * reduceRDD.foreach(x => {
      *
      * val conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/test?characterEncoding=utf-8&useSSL=true","root","123456")
      * val pstm = conn.prepareStatement("insert into access_log values (?,?)")
      *       pstm.setString(1,x._1)
      *       pstm.setInt(2,x._2)
      *       pstm.execute()
      *       pstm.close()
      *       conn.close()
      * })
      */
    //MyUtils.data2MySQL(reduceRDD.collect().toIterator)
    reduceRDD.foreachPartition(MyUtils.data2MySQL(_))
    sc.stop()


  }

}

SparkSql

1.将两张表的数据提取出来,转换成DataFrame,创建两个view。实现join查询

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Dataset, SparkSession}

object IPDemo {
  Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
  val ipFile = ("d:\\data\\spark\\ip.txt")
  val acessFile = "d:\\data\\spark\\access.log"
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("Ip").master("local[*]").getOrCreate()
    import spark.implicits._
    //读取ip文件
    val ipFile = spark.read.textFile("d:\\data\\spark\\ip.txt")
    //整理ip文件
    val ipRules: Dataset[(Long, Long, String)] = ipFile.map(line => {
      val splited = line.split("[|]")
      val startNum = splited(2).toLong
      val endNum = splited(3).toLong
      val province = splited(6)
      (startNum,endNum,province)
    })
    //加入元数据
    val ipDF = ipRules.toDF("start_num","end_num","province")
    //将ip注册成view
    ipDF.createTempView("t_ip")
    //读取访问日志文件
    val access_file = spark.read.textFile(acessFile)
    import day07.MyUtils
    val accessDF = access_file.map(line =>{
      val fields = line.split("[|]")
      val ip = fields(1)
      MyUtils.ip2Long(ip)
    }).toDF("ip")
    //将访问日志整理成视图
    accessDF.createTempView("t_access")
    //sql语句 关联两张表
    val result = spark.sql("SELECT province,count(*) counts FROM t_ip JOIN t_access ON ip>=start_num and ip<=end_num GROUP BY province ORDER BY counts DESC")
    result.show();
    spark.stop()

  }

}

2.改进方法

两表join,如果数据量太大,就会导致运行速度变慢。所以将ip的数据以广播的方式发送到Executor。构建一个自定义方法,进行查询。

import day07.MyUtils
import org.apache.spark.sql.{Dataset, SparkSession}

object IpLocation {
  val ipFile = "d:\\data\\spark\\ip.txt"
  val acessFile = "d:\\data\\spark\\access.log"
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SQLIPLocation").master("local[*]").getOrCreate()
    //隐式转换
    import  spark.implicits._
    //读取ip文件
    val ipFile = spark.read.textFile("d:\\data\\spark\\ip.txt")
    //整理ip文件
    val ipRules: Dataset[(Long, Long, String)] = ipFile.map(line => {
      val splited = line.split("[|]")
      val startNum = splited(2).toLong
      val endNum = splited(3).toLong
      val province = splited(6)
      (startNum,endNum,province)
    })
    //加入元数据
    //val ipDF = ipRules.toDF("start_num","end_num","province")

    //将全部的IP规则收集到Driver端
    val ipRulesDriver = ipRules.collect()
    //广播 阻塞的方法 没有广播完,就不会向下
    val broadcastRef = spark.sparkContext.broadcast(ipRulesDriver)

    //读取web日志
    val accessLogLines = spark.read.textFile(acessFile)
    val ips = accessLogLines.map(line => {
      val Fields = line.split("[|]")
      val ip = Fields(1)
      MyUtils.ip2Long(ip)
    }).toDF("ip_num")
    //将访问日志数据注册成视图
    ips.createTempView("access_ip")

    //定义并注册自定义函数
    //自定义函数在哪里定义的?  (Driver)  业务逻辑在Executor执行
    spark.udf.register("ip_num2Province",(ip_num:Long)=>{
      //获取广播到Driver
      //根据Driver端的广播变量引用,在发送task时,会将Driver端的引用伴随着发送到Executor
      val rulesExecute: Array[(Long, Long, String)] = broadcastRef.value
      val index = MyUtils.binarySearch(rulesExecute,ip_num)
      var province = "未知"
      if(index != -1){
        province = rulesExecute(index)._3
      }
      province
    })

    val result = spark.sql("select ip_num2Province(ip_num) province,count(*) counts from access_ip group by province order by counts desc")

    result.show()

    spark.stop()

  }

}

三、用到的工具包代码如下:

import java.sql.{Connection, DriverManager, PreparedStatement}

/**
  * Created by zx on 2017/12/12.
  */
object MyUtils {

//将ip转换成数字类型
  def ip2Long(ip:String):Long ={
    val fragments = ip.split("[.]")
    var ipNum =0L
    for(i<- 0 until fragments.length){
      ipNum = fragments(i).toLong | ipNum << 8L
    }
    ipNum
  }



//查找某个ip所属的省份
  def binarySearch(lines: Array[(Long,Long,String)],ip: Long):Int ={
    var low =0
    var high =lines.length-1
    while(low <=high){
      val middle =(low+high)/2
      if((ip>=lines(middle)._1) && (ip<=lines(middle)._2))
        return middle
      if(ip < lines(middle)._1)
        high=middle -1
      else{
        low =middle +1
      }
    }
    -1
  }

//连接mysql 插入数据
  def data2MySQL(iter:Iterator[(String,Int)])={
    val conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/test","root","123456")
    val ps = conn.prepareStatement("insert into access_log values (?,?)")
    iter.foreach(x =>{
      ps.setString(1,x._1)
      ps.setInt(2,x._2)
      ps.executeUpdate()
    })
    if(conn!=null){
      conn.close()
    }
    if(ps!=null){
      ps.close()
    }
  }

}

 

你可能感兴趣的:(hadoop)