虽然Spark使用关系型数据库作为数据源的场景并不多,但是有时候我们还是希望能够能够从MySql等数据库中读取数据,并封装成RDD。Spark官方确实也提供了这么一个库给我们,org.apache.spark.rdd.JdbcRDD。但是这个库使用起来让人觉得很鸡肋,因为它不支持条件查询,只支持起止边界查询,这大大限定了它的使用场景。很多时候我们需要分析的数据不可能单独建一个表,它们往往被混杂在一个大的表中,我们会希望更加精确的找出某一类的数据做分析。
查看了一下这个JdbcRDD的源码,我们就能明白为什么他只提供起止边界了。
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val url = conn.getMetaData.getURL
if (url.startsWith("jdbc:mysql:")) {
// setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force
// streaming results, rather than pulling entire resultset into memory.
// See the below URL
// dev.mysql.com/doc/connector-j/5.1/en/connector-j-reference-implementation-notes.html
stmt.setFetchSize(Integer.MIN_VALUE)
} else {
stmt.setFetchSize(100)
}
logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
它使用的是游标的方式,conn.prepareStatement(sql, type, concurrency),因此传入的参数只能是这个分区的起始编号part.lower和这个分区的终止编号part.upper。我查了半天资料,也不知道这种方式该如何将条件传给这个stmt ,有点难受。索性也不尝试了,也不考虑兼容其他类型的数据库,只考虑mysql数据库的话,把游标这种方式给去了,这样使用limit总能给它查出来吧。
以下是具体实现,
重写的JdbcRDD:
package JdbcRDD
import java.sql.{Connection, ResultSet}
import java.util.ArrayList
import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
private class JdbcPartition(idx: Int, val lower: Long, val upper: Long ,val params:ArrayList[Any]) extends Partition {
override def index: Int = idx
}
class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
params: ArrayList[Any],
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
extends RDD[T](sc, Nil) with Logging {
override def getPartitions: Array[Partition] = {
// bounds are inclusive, hence the + 1 here and - 1 on end
val length = BigInt(1) + upperBound - lowerBound
(0 until numPartitions).map { i =>
val start = lowerBound + ((i * length) / numPartitions)
val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
new JdbcPartition(i, start.toLong, end.toLong,params)
}.toArray
}
override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
{
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
//直接采用我们常用的预处理方式
val stmt = conn.prepareStatement(sql)
val url = conn.getMetaData.getURL
if (url.startsWith("jdbc:mysql:")) {
stmt.setFetchSize(Integer.MIN_VALUE)
} else {
return null
}
logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
//传参
val params = part.params
val paramsSize = params.size()
if(params!=null){
for(i <- 1 to paramsSize){
val param = params.get(i-1)
param match {
case param:String => stmt.setString(i,param)
case param:Int => stmt.setInt(i,param)
case param:Boolean => stmt.setBoolean(i,param)
case param:Double => stmt.setDouble(i,param)
case param:Float => stmt.setFloat(i,param)
case _=> {
println("type is fault")
}
}
}
}
//限定该分区查询起始偏移量和条数
stmt.setLong(paramsSize+1, part.lower)
stmt.setLong(paramsSize+2, part.upper-part.lower+1)
val rs = stmt.executeQuery()
override def getNext(): T = {
if (rs.next()) {
mapRow(rs)
} else {
finished = true
null.asInstanceOf[T]
}
}
override def close() {
try {
if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
if (null != conn) {
conn.close()
}
logInfo("closed connection")
} catch {
case e: Exception => logWarning("Exception closing connection", e)
}
}
}
}
object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
trait ConnectionFactory extends Serializable {
@throws[Exception]
def getConnection: Connection
}
def create[T](
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
params: ArrayList[Any],
numPartitions: Int,
mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
val jdbcRDD = new JdbcRDD[T](
sc.sc,
() => connectionFactory.getConnection,
sql,
lowerBound,
upperBound,
params,
numPartitions,
(resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
new JavaRDD[T](jdbcRDD)( fakeClassTag)
}
def create(
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
params: ArrayList[Any],
numPartitions: Int): JavaRDD[Array[Object]] = {
val mapRow = new JFunction[ResultSet, Array[Object]] {
override def call(resultSet: ResultSet): Array[Object] = {
resultSetToObjectArray(resultSet)
}
}
create(sc, connectionFactory, sql, lowerBound, upperBound, params,numPartitions, mapRow)
}
private def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}
接下来是测试代码:
package JdbcRDD
import java.sql.{DriverManager, ResultSet}
import java.util
import org.apache.spark.SparkContext
object JdbcRDDTest {
def main(args: Array[String]) {
//val conf = new SparkConf().setAppName("spark_mysql").setMaster("local")
val sc = new SparkContext("local[2]","spark_mysql")
def createConnection() = {
Class.forName("com.mysql.jdbc.Driver").newInstance()
DriverManager.getConnection("jdbc:mysql://localhost:3306/transportation", "root", "pass")
}
def extractValues(r: ResultSet) = {
(r.getString(1), r.getString(2))
}
val params = new util.ArrayList[Any]
params.add(100)//传参
params.add(7)
val data = new JdbcRDD(sc, createConnection, "SELECT * FROM login_log where id<=? and user_id=? limit ?,?", lowerBound = 1, upperBound =20,params=params, numPartitions = 5, mapRow = extractValues)
data.cache()
println(data.collect.length)
println(data.collect().toList)
sc.stop()
}
}
测试结果:
可以看出,重写这个JdbcRDD后我们可以条件查询某一个表,也可以同时限定查询条数,这给我们用Spark分析Mysql中的数据提供了方便,我们不需要先将需要的数据滤出来再进行分析。当然,这个demo写的比较粗糙,只是提供这么一种方法的演示,后期还可以稍加修改。