Rewriting JdbcRDD to Run Conditional Queries Against a MySQL Database

Although Spark is not often used with a relational database as its data source, we sometimes still want to read data from MySQL or a similar database and wrap it in an RDD. Spark does ship a class for this, org.apache.spark.rdd.JdbcRDD, but in practice it feels half-baked: it only supports partitioning a query by a numeric lower and upper bound and offers no way to pass additional query conditions, which greatly limits where it can be used. The data we want to analyze rarely lives in a table of its own; it is usually mixed into one large table, and we would like to pull out exactly the rows of interest for analysis.
A look at the JdbcRDD source makes it clear why it only accepts a lower and upper bound.

    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force
      // streaming results, rather than pulling entire resultset into memory.
      // See the below URL
      // dev.mysql.com/doc/connector-j/5.1/en/connector-j-reference-implementation-notes.html

      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      stmt.setFetchSize(100)
    }

    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

    stmt.setLong(1, part.lower)
    stmt.setLong(2, part.upper)
    val rs = stmt.executeQuery()

It prepares a cursor-style statement, conn.prepareStatement(sql, type, concurrency), and the only parameters it ever binds are the partition's starting id part.lower and ending id part.upper. I searched for quite a while and still could not figure out how to pass extra conditions into this stmt, which was frustrating. So I stopped trying: if we only care about MySQL and give up compatibility with other databases, we can drop this bound-based cursor approach and rely on a LIMIT clause to fetch each partition's rows instead.
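To make the difference concrete, here is a sketch of the two SQL shapes (the table and column names are the ones used in the test later in this post, shown here only to illustrate the shapes):

// The stock org.apache.spark.rdd.JdbcRDD expects exactly two "?" placeholders,
// which it fills with each partition's lower and upper bound:
val stockSql = "SELECT * FROM login_log WHERE ? <= id AND id <= ?"

// The rewritten version accepts any number of user-supplied condition placeholders,
// followed by "limit ?,?", which it fills with each partition's offset and row count:
val rewrittenSql = "SELECT * FROM login_log WHERE user_id = ? limit ?,?"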
The concrete implementation follows.
The rewritten JdbcRDD:

package JdbcRDD

import java.sql.{Connection, ResultSet}
import java.util.ArrayList

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.util.NextIterator

// Note: NextIterator and Logging are Spark-internal (private[spark]) helpers, so
// depending on the Spark version this file may need to live under an
// org.apache.spark.* package, or ship local copies of those classes, to compile.

private class JdbcPartition(idx: Int, val lower: Long, val upper: Long, val params: ArrayList[Any]) extends Partition {
  override def index: Int = idx
}

class JdbcRDD[T: ClassTag](
                            sc: SparkContext,
                            getConnection: () => Connection,
                            sql: String,
                            lowerBound: Long,
                            upperBound: Long,
                            params: ArrayList[Any],
                            numPartitions: Int,
                            mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {
  override def getPartitions: Array[Partition] = {
    // bounds are inclusive, hence the + 1 here and - 1 on end
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong,params)
    }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
  {
    context.addTaskCompletionListener{ context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    // Use an ordinary prepared statement instead of the forward-only cursor.
    val stmt = conn.prepareStatement(sql)
    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      // MySQL-specific way to stream results instead of loading the whole result set.
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      // Only MySQL is handled here; fail fast rather than returning null from compute.
      throw new UnsupportedOperationException("This rewritten JdbcRDD only supports MySQL")
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
    // Bind the caller-supplied query parameters in the order they were added.
    val params = part.params
    val paramsSize = if (params != null) params.size() else 0
    if (params != null) {
      for (i <- 1 to paramsSize) {
        params.get(i - 1) match {
          case param: String => stmt.setString(i, param)
          case param: Int => stmt.setInt(i, param)
          case param: Long => stmt.setLong(i, param)
          case param: Boolean => stmt.setBoolean(i, param)
          case param: Double => stmt.setDouble(i, param)
          case param: Float => stmt.setFloat(i, param)
          case other => logWarning(s"Unsupported parameter type for value: $other")
        }
      }
    }
    // Bind this partition's LIMIT offset and row count.
    stmt.setLong(paramsSize + 1, part.lower)
    stmt.setLong(paramsSize + 2, part.upper - part.lower + 1)
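    // At this point every "?" in the SQL is bound: positions 1..paramsSize hold the
    // caller's filter values in insertion order, and the last two positions hold the
    // LIMIT offset (part.lower) and row count (part.upper - part.lower + 1).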
    val rs = stmt.executeQuery()
    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }
    override def close() {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
object JdbcRDD {
  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  def create[T](
                 sc: JavaSparkContext,
                 connectionFactory: ConnectionFactory,
                 sql: String,
                 lowerBound: Long,
                 upperBound: Long,
                 params: ArrayList[Any],
                 numPartitions: Int,
                 mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {

    val jdbcRDD = new JdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      lowerBound,
      upperBound,
      params,
      numPartitions,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

 
  def create(
              sc: JavaSparkContext,
              connectionFactory: ConnectionFactory,
              sql: String,
              lowerBound: Long,
              upperBound: Long,
              params: ArrayList[Any],
              numPartitions: Int): JavaRDD[Array[Object]] = {

    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, lowerBound, upperBound, params, numPartitions, mapRow)
  }
  private def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}

Next, the test code:

package JdbcRDD

import java.sql.{DriverManager, ResultSet}
import java.util

import org.apache.spark.SparkContext

object JdbcRDDTest {
  def main(args: Array[String]) {
    //val conf = new SparkConf().setAppName("spark_mysql").setMaster("local")
    val sc = new SparkContext("local[2]","spark_mysql")

    def createConnection() = {
      Class.forName("com.mysql.jdbc.Driver").newInstance()
      DriverManager.getConnection("jdbc:mysql://localhost:3306/transportation", "root", "pass")
    }
    def extractValues(r: ResultSet) = {
      (r.getString(1), r.getString(2))
    }
    val params = new util.ArrayList[Any]
    params.add(100) // condition values: id <= 100 ...
    params.add(7)   // ... and user_id = 7
    val data = new JdbcRDD(
      sc,
      createConnection,
      "SELECT * FROM login_log where id<=? and user_id=? limit ?,?",
      lowerBound = 1,
      upperBound = 20,
      params = params,
      numPartitions = 5,
      mapRow = extractValues)
    data.cache()
    println(data.collect.length)
    println(data.collect().toList)
    sc.stop()
  }
}
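For reference, with the arguments above, partition 0 covers the inclusive range [1, 4], so the statement it runs is equivalent to the following (a sketch of the bound query, not actual driver output):

// The values 100 and 7 fill the two condition placeholders; the LIMIT offset and
// row count come from part.lower and part.upper - part.lower + 1.
val partition0Sql = "SELECT * FROM login_log where id<=100 and user_id=7 limit 1,4"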

Test results:

(screenshot of the test output)

As the output shows, with the rewritten JdbcRDD we can query a table with arbitrary conditions and also cap the number of rows fetched, which makes analyzing MySQL data with Spark much more convenient: we no longer have to extract the rows we need into a separate table before analysis. That said, this demo is fairly rough; it is only meant to demonstrate the approach and can be refined later.
