文章来源:http://blog.csdn.net/javastart/article/details/45196191
Spark自带的JdbcRDD只支持Long类型的分区参数：分区条件必须是一个Long区间（由lowerBound、upperBound和numPartitions均匀切分）。在很多场景下，这种方式并不适用。
我对JdbcRDD进行了改写，使其支持完全自定义的分区条件。
主要实现思路：
把设置查询参数的部分改写成可自定义的函数，这样就可以按任意方式设置每个分区的查询参数。
直接上代码吧：
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
package org.apache.spark.rdd
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.SparkContext
import org.apache.spark.util.NextIterator
import scala.reflect.ClassTag
import java.sql.ResultSet
import java.sql.Connection
import org.apache.spark.Partition
import org.apache.spark.Logging
import java.sql.PreparedStatement
/**
 * One partition of a CustomizedJdbcRDD. It carries an arbitrary set of named values
 * that the RDD's statement-preparation callback can bind onto the partition's query.
 *
 * @param idx        the index Spark uses to identify this partition
 * @param parameters named values available to the statement-preparation callback
 */
class CustomizedJdbcPartition(idx: Int, parameters: Map[String, Object]) extends Partition {

  /** The named values supplied when this partition was constructed. */
  val partitionParameters: Map[String, Object] = parameters

  /** Returns the index given at construction time. */
  override def index: Int = idx
}
/**
 * An RDD that executes an SQL query on a JDBC connection and reads results. Unlike Spark's
 * built-in JdbcRDD, the partitioning is not restricted to a Long range: the caller supplies
 * the partitions and a callback that binds each partition's parameters onto the query.
 *
 * For every partition, a forward-only, read-only statement is prepared from `sql`. The
 * statement is passed to `prepareStatement` together with that partition and then executed.
 * The result set, statement and connection are closed when the task completes.
 *
 * @param sc the SparkContext this RDD belongs to.
 * @param getConnection a function that returns an open Connection.
 *   The RDD takes care of closing the connection.
 * @param sql the text of the query. It may contain any number of `?` placeholders;
 *   `prepareStatement` must bind all of them for each partition.
 *   E.g. "select * from collect_data where host=?"
 * @param getCustomizedPartitions a function that returns this RDD's partitions.
 *   Every element must be a [[CustomizedJdbcPartition]]; any other type makes `compute`
 *   fail with an IllegalArgumentException.
 *   NOTE(review): Spark expects `partitions(i).index == i` — callers should honor this.
 * @param prepareStatement a function that binds the placeholders of the given statement
 *   from the given partition's `partitionParameters`. Its return value is not used:
 *   the statement passed in is the one executed and later closed.
 * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 *   This should only call getInt, getString, etc; the RDD takes care of calling next.
 *   The default maps a ResultSet to an array of Object.
 */
class CustomizedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    getCustomizedPartitions: () => Array[Partition],
    prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
    mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  import scala.util.control.NonFatal

  /** Delegates entirely to the caller-supplied partition factory. */
  override def getPartitions: Array[Partition] = getCustomizedPartitions()

  /**
   * Opens a connection, prepares and parameterizes the query for `thePart`, and
   * streams the mapped rows of its result set.
   */
  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    // Registered first so that resources opened below are released even if a later step throws;
    // close() tolerates members that were never assigned (null).
    context.addTaskCompletionListener{ context => closeIfNeeded() }

    // Fail with a descriptive message rather than an opaque ClassCastException.
    val part: CustomizedJdbcPartition = thePart match {
      case p: CustomizedJdbcPartition => p
      case other => throw new IllegalArgumentException(
        s"CustomizedJdbcRDD requires CustomizedJdbcPartition but got ${other.getClass.getName}")
    }

    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    // setFetchSize(Integer.MIN_VALUE) is a MySQL-driver-specific way to force streaming results,
    // rather than pulling the entire result set into memory.
    // See http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
    // Detecting the URL is best-effort: some drivers may throw from getMetaData/getURL, in which
    // case the tweak is skipped (and the reason logged) instead of failing the task.
    try {
      if (conn.getMetaData.getURL.startsWith("jdbc:mysql:")) {
        stmt.setFetchSize(Integer.MIN_VALUE)
        logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
      }
    } catch {
      case NonFatal(e) =>
        logWarning(s"Could not inspect JDBC URL, skipping MySQL streaming fetch size: $e")
    }

    // Bind this partition's parameters. The callback's return value is not used, so the
    // statement executed (and closed in close()) is always `stmt`.
    prepareStatement(stmt, part)
    val rs = stmt.executeQuery()

    /** Maps the next row, or marks the iterator finished when the result set is exhausted. */
    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    /** Closes result set, statement and connection, logging (not throwing) any failures. */
    override def close(): Unit = {
      try {
        if (null != rs && !rs.isClosed()) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt && !stmt.isClosed()) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn && !conn.isClosed()) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
object CustomizedJdbcRDD {

  /**
   * Default row mapper: copies every column of the current row of `rs` into an
   * `Array[Object]` in column order (JDBC columns are 1-based, the array is 0-based).
   */
  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  /**
   * A serializable source of JDBC connections for the Java API. `getConnection` is
   * invoked from within each task's `compute`, so the factory must be serializable.
   */
  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results,
   * using caller-defined partitions. Intended for use from Java.
   *
   * @param sc the JavaSparkContext to create the RDD in.
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query. It may contain any number of `?` placeholders;
   *   `prepareStatement` must bind all of them for each partition.
   *   E.g. "select * from collect_data where host=?"
   * @param getCustomizedPartitions a function returning the RDD's partitions; every element
   *   must be a [[CustomizedJdbcPartition]].
   * @param prepareStatement a function that binds the placeholders of the given statement from
   *   the given partition's `partitionParameters`. Its return value is not used.
   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
   * @return a JavaRDD wrapping the underlying CustomizedJdbcRDD.
   */
  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new CustomizedJdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      getCustomizedPartitions,
      prepareStatement,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results,
   * using caller-defined partitions. Each row is converted into an `Object` array via
   * [[resultSetToObjectArray]]. Intended for use from Java.
   *
   * @param sc the JavaSparkContext to create the RDD in.
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query. It may contain any number of `?` placeholders;
   *   `prepareStatement` must bind all of them for each partition.
   *   E.g. "select * from collect_data where host=?"
   * @param getCustomizedPartitions a function returning the RDD's partitions; every element
   *   must be a [[CustomizedJdbcPartition]].
   * @param prepareStatement a function that binds the placeholders of the given statement from
   *   the given partition's `partitionParameters`. Its return value is not used.
   * @return a JavaRDD of rows, each an array of column values.
   */
  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, getCustomizedPartitions, prepareStatement, mapRow)
  }
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
下面是一段简单的测试代码（需要一个可访问的HiveServer2服务）：
package org.apache.spark
import java.sql.Connection
import java.sql.DriverManager
import org.apache.spark.rdd.CustomizedJdbcRDD
import org.apache.spark.rdd.CustomizedJdbcPartition
import java.sql.PreparedStatement
/**
 * Minimal end-to-end example for CustomizedJdbcRDD. It reads the rows of `collect_data`
 * for one host over HiveServer2 JDBC and prints the row count.
 *
 * Usage: HiveRDDTest [jdbcUrl] [host]
 * Both arguments are optional and default to the values this example always used.
 */
object HiveRDDTest {
  private val driverName = "org.apache.hive.jdbc.HiveDriver"
  private val defaultJdbcUrl = "jdbc:hive2://192.168.31.135:10000/default"
  private val defaultHost = "172.18.26.11"

  def main(args: Array[String]): Unit = {
    val jdbcUrl = args.lift(0).getOrElse(defaultJdbcUrl)
    val host = args.lift(1).getOrElse(defaultHost)

    val conf = new SparkConf().setAppName("HiveRDDTest").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      // Load the Hive JDBC driver class in this JVM.
      // NOTE(review): in a non-local deployment executors may need it loaded as well.
      Class.forName(driverName)

      val data = new CustomizedJdbcRDD(sc,
        // Connection factory; the RDD closes the connection when each task completes.
        () => DriverManager.getConnection(jdbcUrl, "spark", ""),
        // Query with a single placeholder, bound per partition below.
        "select * from collect_data where host=?",
        // A single partition whose "host" parameter fills the placeholder.
        () => Array[Partition](new CustomizedJdbcPartition(0, Map[String, Object]("host" -> host))),
        // Bind each partition's "host" parameter to placeholder 1 of the SQL above.
        (stmt: PreparedStatement, partition: CustomizedJdbcPartition) => {
          partition.partitionParameters.get("host") match {
            case Some(h: String) => stmt.setString(1, h)
            case other => throw new IllegalArgumentException(
              s"partition ${partition.index} has no String 'host' parameter: $other")
          }
          stmt
        }
      )
      println(data.count())
    } finally {
      sc.stop()
    }
  }
}