改写Spark JdbcRDD,支持自定义分区查询条件

文章来源:http://blog.csdn.net/javastart/article/details/45196191

改写Spark JdbcRDD,支持自定义分区查询条件(转)

  597人阅读  评论(0)  收藏  举报
  分类:
 
大数据(102) 


改写Spark JdbcRDD,支持自定义分区查询条件

分类: 大数据 Spark   139人阅读  评论(0) 收藏  举报
Spark JdbcRDD

Spark自带的JdbcRDD只支持Long类型的分区参数:SQL中必须恰好包含两个?占位符,分区方式是把一个Long区间[lowerBound, upperBound]均分成numPartitions段。在很多情况下(例如按字符串字段或多个条件分区),这种方式并不适用。

我对JdbcRDD进行了改写,可支持完全自定义分区条件。

主要实现思路:

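把生成分区和设置查询参数这两部分改写成由调用者传入的函数:getCustomizedPartitions 负责生成分区(每个 CustomizedJdbcPartition 携带一个参数 Map),prepareStatement 负责把分区中的参数绑定到 SQL 的 ? 占位符上。这样就可以按任意方式设置分区参数。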

直接上代码吧:

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

package org.apache.spark.rdd


import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.SparkContext
import org.apache.spark.util.NextIterator
import scala.reflect.ClassTag
import java.sql.ResultSet
import java.sql.Connection
import org.apache.spark.Partition
import org.apache.spark.Logging
import java.sql.PreparedStatement


/**
 * One partition of a CustomizedJdbcRDD.
 *
 * @param idx the index of this partition within its RDD.
 * @param parameters caller-defined query parameters for this partition. The RDD hands this
 *   partition to its `prepareStatement` function, which is expected to bind these values
 *   to the SQL placeholders.
 */
class CustomizedJdbcPartition(idx: Int, parameters: Map[String, Object]) extends Partition {

  /** The caller-defined query parameters carried by this partition. */
  val partitionParameters: Map[String, Object] = parameters

  override def index: Int = idx
}
// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private
/**
 * An RDD that executes an SQL query on a JDBC connection and reads results, where both the
 * partitioning scheme and the per-partition query parameters are defined by the caller.
 *
 * Unlike Spark's built-in JdbcRDD, the query is not required to contain exactly two `?`
 * placeholders bound to a Long range. Instead, `getCustomizedPartitions` produces the
 * partitions (each carrying its own parameter map) and `prepareStatement` binds those
 * parameters to whatever placeholders `sql` contains.
 *
 * @param sc the SparkContext that owns this RDD.
 * @param getConnection a function that returns an open Connection. It is called once per
 *   computed partition. The RDD takes care of closing the connection.
 * @param sql the text of the query, e.g. "select * from collect_data where host=?".
 *   It may contain any number of `?` placeholders; they are bound by `prepareStatement`.
 * @param getCustomizedPartitions a function returning the partitions of this RDD.
 *   Every element must be a [[CustomizedJdbcPartition]] (`compute` casts to that type).
 * @param prepareStatement binds the parameters of one partition to the prepared statement.
 *   It is called once per partition, before the query is executed. Its return value is
 *   ignored: the RDD always executes (and later closes) the statement it passed in.
 * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 *   This should only call getInt, getString, etc; the RDD takes care of calling next.
 *   The default maps a ResultSet to an array of Object.
 */
class CustomizedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    getCustomizedPartitions: () => Array[Partition],
    prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
    mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = getCustomizedPartitions()

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] =
    new NextIterator[T] {
      // Registered before anything is opened, so whatever was opened below is released at task
      // end even if a later step throws; close() null-checks fields that were never assigned.
      context.addTaskCompletionListener{ context => closeIfNeeded() }
      val part = thePart.asInstanceOf[CustomizedJdbcPartition]
      val conn = getConnection()
      val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

      // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
      // rather than pulling entire resultset into memory.
      // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
      // Best-effort: the probe is wrapped in try/catch, presumably because some drivers throw from
      // getMetaData/getURL/setFetchSize (NOTE(review): confirm which). A failure only means the
      // driver's default fetch behaviour is used, so it is logged rather than propagated.
      try {
        if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
          stmt.setFetchSize(Integer.MIN_VALUE)
          logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
        }
      } catch {
        case e: Exception =>
          logWarning("Could not check for / enable MySQL result streaming; using driver defaults", e)
      }

      // Bind this partition's parameters. The returned statement is deliberately ignored:
      // `stmt` is the one executed here and closed in close().
      prepareStatement(stmt, part)

      val rs = stmt.executeQuery()

      override def getNext: T = {
        if (rs.next()) {
          mapRow(rs)
        } else {
          finished = true
          null.asInstanceOf[T]
        }
      }

      // Closes result set, statement and connection in that order; each step is independent so a
      // failure closing one resource does not prevent closing the others.
      override def close() {
        try {
          if (null != rs && !rs.isClosed()) {
            rs.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing resultset", e)
        }
        try {
          if (null != stmt && !stmt.isClosed()) {
            stmt.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing statement", e)
        }
        try {
          if (null != conn && !conn.isClosed()) {
            conn.close()
            logInfo("closed connection")
          }
        } catch {
          case e: Exception => logWarning("Exception closing connection", e)
        }
      }
    }
}


object CustomizedJdbcRDD {
  /**
   * Default row mapper: copies every column of the current row of `rs` into an array using
   * `getObject`. JDBC column indices are 1-based, hence `i + 1`.
   */
  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  /**
   * A serializable source of JDBC connections for the Java API. It is Serializable because
   * the RDD's connection function captures it; `getConnection` is invoked once per computed
   * partition.
   */
  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results, with
   * caller-defined partitions and per-partition query parameters.
   *
   * NOTE(review): `getCustomizedPartitions` and `prepareStatement` are Scala function types
   * (scala.Function0 / scala.Function2), which are awkward to implement from Java.
   *
   * @param sc the JavaSparkContext that owns the RDD.
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query. It may contain any number of `?` placeholders,
   *   which `prepareStatement` binds for each partition.
   * @param getCustomizedPartitions returns the partitions of the RDD; every element must be a
   *   [[CustomizedJdbcPartition]].
   * @param prepareStatement binds one partition's parameters to the prepared statement.
   *   Its return value is ignored; the RDD executes the statement it passed in.
   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
   * @return a JavaRDD wrapping the resulting [[CustomizedJdbcRDD]].
   */
  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {

    val jdbcRDD = new CustomizedJdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      getCustomizedPartitions,
      prepareStatement,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)

    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results, with
   * caller-defined partitions and per-partition query parameters. Each row is converted
   * into an `Object` array (see `resultSetToObjectArray`).
   *
   * @param sc the JavaSparkContext that owns the RDD.
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query. It may contain any number of `?` placeholders,
   *   which `prepareStatement` binds for each partition.
   * @param getCustomizedPartitions returns the partitions of the RDD; every element must be a
   *   [[CustomizedJdbcPartition]].
   * @param prepareStatement binds one partition's parameters to the prepared statement.
   *   Its return value is ignored; the RDD executes the statement it passed in.
   * @return a JavaRDD of rows, each row an array of column values.
   */
  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement): JavaRDD[Array[Object]] = {

    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }

    create(sc, connectionFactory, sql, getCustomizedPartitions, prepareStatement, mapRow)
  }
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

下面是一段简单的测试代码:

package org.apache.spark


import java.sql.Connection
import java.sql.DriverManager
import org.apache.spark.rdd.CustomizedJdbcRDD
import org.apache.spark.rdd.CustomizedJdbcPartition
import java.sql.PreparedStatement


object HiveRDDTest {
  private val driverName = "org.apache.hive.jdbc.HiveDriver"
  private val defaultJdbcUrl = "jdbc:hive2://192.168.31.135:10000/default"

  /**
   * Counts the rows of `collect_data` whose `host` column equals "172.18.26.11". The rows are
   * read through a single-partition CustomizedJdbcRDD, and the count is printed.
   *
   * @param args optional; `args(0)` overrides the HiveServer2 JDBC URL
   *   (defaults to `defaultJdbcUrl`).
   */
  def main(args: Array[String]): Unit = {
    val jdbcUrl = args.headOption.getOrElse(defaultJdbcUrl)
    // Copied into a local so the connection closure captures only a String.
    val driver = driverName
    val conf = new SparkConf().setAppName("HiveRDDTest").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      val data = new CustomizedJdbcRDD(
        sc,
        // Creates a JDBC connection. The driver is loaded here, where the connection is opened,
        // rather than only in the driver JVM. The RDD closes the connection.
        () => {
          Class.forName(driver)
          DriverManager.getConnection(jdbcUrl, "spark", "")
        },
        // The query; each partition binds its own value to the `?` placeholder.
        "select * from collect_data where host=?",
        // Builds the partitions: a single one carrying the host to query for.
        () => Array[Partition](
          new CustomizedJdbcPartition(0, Map[String, Object]("host" -> "172.18.26.11"))),
        // Binds each partition's parameters to the statement prepared from the SQL above.
        (stmt: PreparedStatement, partition: CustomizedJdbcPartition) => {
          stmt.setString(1, partition.partitionParameters("host").asInstanceOf[String])
          stmt
        })
      println(data.count())
    } finally {
      // Always release the SparkContext, even if the query fails.
      sc.stop()
    }
  }

}


你可能感兴趣的:(spark,大数据)