Spark Streaming提高写数据库的效率

1. 前言

这是一篇挂羊头卖狗肉的文章,事实上,本文要描述的内容,和Spark Streaming没有什么关系。

在上一篇文章http://www.jianshu.com/p/a73c0c95d2fe 我们写了如何通过Spark Streaming向数据库中插入数据。可能你已经发现了,数据是逐条插入数据库的,效率底下。那么如何提高插入数据库的效率呢?

数据库写是个IO任务,并行不一定能够加速写入数据库的速度。我们主要说下批量提交和Bulk Copy Insert的方式。

2.批量提交

批量提交,就是JDBC Statment的executeBatch,直接看代码吧。

/**
  * 从Kafka中读取数据,并把数据成批写入数据库
  */
object KafkaToDB {

  val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {
    // 参数校验
    if (args.length < 2) {
      System.err.println(
        s"""
           |Usage: KafkaToDB  
           |   is a list of one or more Kafka brokers
           |   is a list of one or more kafka topics to consume from
           |""".stripMargin)
      System.exit(1)
    }

    // 处理参数
    val Array(brokers, topics) = args
    // topic以“,”分割
    val topicSet: Set[String] = topics.split(",").toSet
    val kafkaParams: Map[String, Object] = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )

    // 创建上下文,以每1秒间隔的数据作为一批
    val sparkConf = new SparkConf().setAppName("KafkaToDB")
    val streamingContext = new StreamingContext(sparkConf, Seconds(2))

    // 1.创建输入流,获取数据。流操作基于DStream,InputDStream继承于DStream
    val stream = KafkaUtils.createDirectStream[String, String](
      streamingContext,
      PreferConsistent,
      Subscribe[String, String](topicSet, kafkaParams)
    )

    // 2. DStream上的转换操作
    // 取消息中的value数据,以英文逗号分割,并转成Tuple3
    val values = stream.map(_.value.split(","))
      .filter(x => x.length == 3)
      .map(x => new Tuple3[String, String, String](x(0), x(1), x(2)))


    // 输入前10条到控制台,方便调试
    values.print()

    // 3.同foreachRDD保存到数据库
    val sql = "insert into kafka_message(timeseq,timeseq2, thread, message) values (?,?,?,?)"
    values.foreachRDD(rdd => {
      val count = rdd.count()
      println("-----------------count:" + count)
      if (count > 0) {
        rdd.foreachPartition(partitionOfRecords => {
          val conn = ConnectionPool.getConnection.orNull
          if (conn != null) {
            val ps = conn.prepareStatement(sql)
            try{
              // 关闭自动执提交
              conn.setAutoCommit(false)
              partitionOfRecords.foreach(data => {
                ps.setString(1, data._1)
                ps.setString(2,System.currentTimeMillis().toString)
                ps.setString(3, data._2)
                ps.setString(4, data._3)
                ps.addBatch()
              })
              ps.executeBatch()
              conn.commit()
            } catch {
              case e: Exception =>
                logger.error("Error in execution of insert. " + e.getMessage)
            }finally {
              ps.close()
              ConnectionPool.closeConnection(conn)
            }
          }
        })
      }
    })

    streamingContext.start() // 启动计算
    streamingContext.awaitTermination() // 等待中断结束计算

  }
}

3. Bulk Copy Insert

我们使用的是PostgreSQL,其数据库JDBC驱动程序提供了Copy Insert的API,其主要过程是:

  • 1.获取数据库连接
  • 2.创建CopyManager
  • 3.把Spark Streaming中的流数据封装成InputStream
  • 4.执行CopyInsert
import java.sql.Connection

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.kafka010.ConsumerStrategies._
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
import org.slf4j.LoggerFactory

object CopyInsert {

  val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {
    // 参数校验
    if (args.length < 4) {
      System.err.println(
        s"""
           |Usage: CopyInsert    
           |   is a list of one or more Kafka brokers
           |   is a list of one or more kafka topics to consume from
           |""".stripMargin)
      System.exit(1)
    }

    // 处理参数
    val Array(brokers, topics,duration,batchsize) = args
    // topic以“,”分割
    val topicSet: Set[String] = topics.split(",").toSet
    val kafkaParams: Map[String, Object] = Map[String, Object](
      "bootstrap.servers" -> brokers,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )

    // 创建上下文,以每1秒间隔的数据作为一批
    val sparkConf = new SparkConf().setAppName("CopyInsertIntoPostgreSQL")
    val streamingContext = new StreamingContext(sparkConf, Seconds(duration.toInt))

    // 1.创建输入流,获取数据。流操作基于DStream,InputDStream继承于DStream
    val stream = KafkaUtils.createDirectStream[String, String](
      streamingContext,
      PreferConsistent,
      Subscribe[String, String](topicSet, kafkaParams)
    )

    // 2. DStream上的转换操作
    // 取消息中的value数据,以英文逗号分割,并转成Tuple3
    val values = stream.map(_.value.split(","))
      .filter(x => x.length == 3)
      .map(x => new Tuple3[String, String, String](x(0), x(1), x(2)))


    // 输入前10条到控制台,方便调试
    values.print()

    // 3.同foreachRDD保存到数据库
    // http://rostislav-matl.blogspot.jp/2011/08/fast-inserts-to-postgresql-with-jdbc.html
    values.foreachRDD(rdd => {
      val count = rdd.count()
      println("-----------------count:" + count)
      if (count > 0) {
        rdd.foreachPartition(partitionOfRecords => {
          val start = System.currentTimeMillis()
          val conn: Connection = ConnectionPool.getConnection.orNull
          if (conn != null) {
            val batch = batchsize.toInt
            var counter: Int = 0
            val sb: StringBuilder = new StringBuilder()
            // 获取数据库连接
            val baseConnection = conn.getMetaData.getConnection.asInstanceOf[BaseConnection]
            // 创建CopyManager
            val cpManager: CopyManager = new CopyManager(baseConnection)
            partitionOfRecords.foreach(record => {
              counter += 1
              sb.append(record._1).append(",")
                .append(System.currentTimeMillis()).append(",")
                .append(record._2).append(",")
                .append(record._3).append("\n")
              if (counter == batch) {
                // 构建输入流
                val in: InputStream = new ByteArrayInputStream(sb.toString().getBytes())
                // 执行copyin
                cpManager.copyIn("COPY kafka_message FROM STDIN WITH CSV", in)
                println("-----------------batch---------------: " + batch)
                counter = 0
                sb.delete(0, sb.length)
                closeInputStream(in)
              }
            })
            val lastIn: InputStream = new ByteArrayInputStream(sb.toString().getBytes())
            cpManager.copyIn("COPY kafka_message2 FROM STDIN WITH CSV", lastIn)
            sb.delete(0, sb.length)
            counter = 0
            closeInputStream(lastIn)
            val end = System.currentTimeMillis()
            println("-----------------duration---------------ms :" + (end - start))
          }
        })

      }
    })

    streamingContext.start() // 启动计算
    streamingContext.awaitTermination() // 等待中断结束计算
 }

 def closeInputStream(in: InputStream): Unit ={
   try{
       in.close()
    }catch{
     case e: IOException =>
       logger.error("Error on close InputStream. " + e.getMessage)
      }
  }
    
}

其它数据库应该也有bulk load的方式,例如MySQL,com.mysql.jdbc.Statment中有setLocalInfileInputStream方法,功能应该和上述的Copy Insert类似,但我还没有写例子验证。文档里有如下的描述,供参考。原文地址

Sets an InputStream instance that will be used to send data to the MySQL server for a "LOAD DATA LOCAL INFILE" statement rather than a FileInputStream or URLInputStream that represents the path given as an argument to the statement.

(完)

你可能感兴趣的:(Spark Streaming提高写数据库的效率)