Spark:将Dataset数据写入关系型数据库常用方法

详细代码可参考:org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Spark程序中我们有时会遇到要将Dataset数据插入/更新到关系型数据库的情况,有时候我们需要自己建立jdbc连接,拼接SQL语句,以下是简便的设置SQL参数的方法:

  def makeSetter(stmt: PreparedStatement, row: Row, dataType: DataType, pos: Int) = dataType match {
    case IntegerType => stmt.setInt(pos + 1, row.getInt(pos))
    case LongType => stmt.setLong(pos + 1, row.getLong(pos))
    case DoubleType => stmt.setDouble(pos + 1, row.getDouble(pos))
    case FloatType => stmt.setFloat(pos + 1, row.getFloat(pos))
    case ShortType => stmt.setInt(pos + 1, row.getShort(pos))
    case ByteType => stmt.setInt(pos + 1, row.getByte(pos))
    case BooleanType => stmt.setBoolean(pos + 1, row.getBoolean(pos))
    case StringType => stmt.setString(pos + 1, row.getString(pos))
    case BinaryType => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
    case TimestampType => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    case DateType => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    case t: DecimalType => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
    case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
  }

  def pSetter(dataType: DataType, pos: Int) = dataType match {
    case IntegerType => println(s"""stmt.setInt(${pos + 1}, row.getInt(${pos}))""")
    case LongType => println(s"""stmt.setLong(${pos + 1}, row.getLong(${pos}))""")
    case DoubleType => println(s"""stmt.setDouble(${pos + 1}, row.getDouble(${pos}))""")
    case FloatType => println(s"""stmt.setFloat(${pos + 1}, row.getFloat(${pos}))""")
    case ShortType => println(s"""stmt.setInt(${pos + 1}, row.getShort(${pos}))""")
    case ByteType => println(s"""stmt.setInt(${pos + 1}, row.getByte(${pos}))""")
    case BooleanType => println(s"""stmt.setBoolean(${pos + 1}, row.getBoolean(${pos}))""")
    case StringType => println(s"""stmt.setString(${pos + 1}, row.getString(${pos}))""")
    case BinaryType => println(s"""stmt.setBytes(${pos + 1}, row.getAs[Array[Byte]](${pos}))""")
    case TimestampType => println(s"""stmt.setTimestamp(${pos + 1}, row.getAs[java.sql.Timestamp](${pos}))""")
    case DateType => println(s"""stmt.setDate(${pos + 1}, row.getAs[java.sql.Date](${pos}))""")
    case t: DecimalType => println(s"""stmt.setBigDecimal(${pos + 1}, row.getDecimal(${pos}))""")
    case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
  }

下面是一个完整的更新/插入数据示例:

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import java.util.Properties
import java.sql.DriverManager
import java.sql.PreparedStatement
import java.math.BigDecimal

object T {
  val spark = SparkSession
    .builder()
    .appName("T")
    .config("spark.sql.parquet.writeLegacyFormat", true)
    .enableHiveSupport()
    .getOrCreate()

  val prop = {
    val p = new Properties()
    p.put("driver", "oracle.jdbc.driver.OracleDriver")
    p.put("url", "jdbc:oracle:thin:@10.18.2.3:1521:testdb")
    p.put("user", "***")
    p.put("password", "***")
    p
  }
  val table = "OracleDB.TB"
  Class.forName(prop.getProperty("driver"))

  val df = spark.table("HiveDB.tb")
  val names = df.schema.map(f => f.name)
  val dataTypes = df.schema.map(f => f.dataType)

  val batchSize = 1000 // addBatch方法添加操作数
  val totalSize = df.count
  // update数据到数据库
  val pks = Seq("id1", "id2").map(_ + "=?").mkString(" and ") // 主键
  val setCol = names.map(_ + "=?").mkString(",")
  val usql = s"update $table set ${setCol} where ${pks}"

  val t1 = System.currentTimeMillis()
  df.coalesce(3).foreachPartition { part =>
    val conn = DriverManager.getConnection(prop.getProperty("url"), prop)
    conn.setAutoCommit(false)
    val stmt = conn.prepareStatement(usql)

    var i = 0
    var notInsertCount = totalSize
    part.foreach { row =>
      val k1 = row.getAs[String]("id1")
      val k2 = row.getAs[BigDecimal]("id2")

      dataTypes.zipWithIndex.foreach {
        case (dt, idx) =>
          makeSetter(stmt, row, dt, idx)
      }
      stmt.setString(dataTypes.size + 1, k1)
      stmt.setBigDecimal(dataTypes.size + 2, k2)
      stmt.addBatch()
      i += 1
      if (i >= batchSize) {
        val exeRes = stmt.executeBatch()
        notInsertCount -= exeRes.size
        conn.commit()
        stmt.clearBatch()
        i = 0
      }
    }

    if (notInsertCount > 0) {
      val exeRes = stmt.executeBatch()
      notInsertCount -= exeRes.size
      conn.commit()
      stmt.clearBatch()
    }
    println("未插入记录数: " + notInsertCount)
    conn.commit()
    stmt.close()
    conn.close()
  }
  val t2 = System.currentTimeMillis()
  println("update时间: " + (t2 - t1) / 1000)

  // insert数据到数据库
  // 类似于:
  // df.write.jdbc(url, table, connectionProperties)
  val columns = names.mkString(",")
  val placeholders = names.map(_ => "?").mkString(",")
  val isql = s"insert into $table ($columns) values ($placeholders)"

  val t3 = System.currentTimeMillis()
  df.coalesce(3).foreachPartition { part =>
    val conn = DriverManager.getConnection(prop.getProperty("url"), prop)
    conn.setAutoCommit(false)
    val stmt = conn.prepareStatement(isql)

    var i = 0
    var notInsertCount = totalSize
    part.foreach { row =>
      dataTypes.zipWithIndex.foreach {
        case (dt, idx) =>
          makeSetter(stmt, row, dt, idx)
      }
      stmt.addBatch()
      i += 1
      if (i >= batchSize) {
        val exeRes = stmt.executeBatch()
        notInsertCount -= exeRes.size
        conn.commit()
        stmt.clearBatch()
        i = 0
      }
    }

    if (notInsertCount > 0) {
      val exeRes = stmt.executeBatch()
      notInsertCount -= exeRes.size
      conn.commit()
      stmt.clearBatch()
    }
    println("未插入记录数: " + notInsertCount)
    stmt.close()
    conn.close()
  }
  val t4 = System.currentTimeMillis()
  println("insert时间: " + (t4 - t3) / 1000)

  def makeSetter(stmt: PreparedStatement, row: Row, dataType: DataType, pos: Int) = dataType match {
    case IntegerType => stmt.setInt(pos + 1, row.getInt(pos))
    case LongType => stmt.setLong(pos + 1, row.getLong(pos))
    case DoubleType => stmt.setDouble(pos + 1, row.getDouble(pos))
    case FloatType => stmt.setFloat(pos + 1, row.getFloat(pos))
    case ShortType => stmt.setInt(pos + 1, row.getShort(pos))
    case ByteType => stmt.setInt(pos + 1, row.getByte(pos))
    case BooleanType => stmt.setBoolean(pos + 1, row.getBoolean(pos))
    case StringType => stmt.setString(pos + 1, row.getString(pos))
    case BinaryType => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
    case TimestampType => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    case DateType => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    case t: DecimalType => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
    case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
  }
}

你可能感兴趣的:(大数据,大数据/spark)