For the full implementation, see org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils:
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
In a Spark program we sometimes need to insert or update Dataset rows in a relational database. When the built-in JDBC writer is not enough, we have to open the connection ourselves and assemble the SQL, which means binding each Row value to a PreparedStatement parameter. The helper below is a convenient way to do that:
// Bind the value at Row position pos to JDBC parameter pos + 1 (JDBC parameters are 1-based).
// Note: assumes non-null values; nullable columns need stmt.setNull with the target SQL type.
def makeSetter(stmt: PreparedStatement, row: Row, dataType: DataType, pos: Int): Unit = dataType match {
case IntegerType => stmt.setInt(pos + 1, row.getInt(pos))
case LongType => stmt.setLong(pos + 1, row.getLong(pos))
case DoubleType => stmt.setDouble(pos + 1, row.getDouble(pos))
case FloatType => stmt.setFloat(pos + 1, row.getFloat(pos))
case ShortType => stmt.setInt(pos + 1, row.getShort(pos))
case ByteType => stmt.setInt(pos + 1, row.getByte(pos))
case BooleanType => stmt.setBoolean(pos + 1, row.getBoolean(pos))
case StringType => stmt.setString(pos + 1, row.getString(pos))
case BinaryType => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
case TimestampType => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
case DateType => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
case _: DecimalType => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
}
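The helper above assumes every value is non-null; on a null column the primitive getters would fail. A minimal null-safe wrapper is sketched below, under the assumption that the caller supplies the target column's java.sql.Types code as sqlType (mapping it from the Spark DataType is left out here):
// Sketch: skip the type-specific setter when the Row value is null.
def setNullSafe(stmt: PreparedStatement, row: Row, dataType: DataType, pos: Int, sqlType: Int): Unit = {
if (row.isNullAt(pos)) stmt.setNull(pos + 1, sqlType)
else makeSetter(stmt, row, dataType, pos)
}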
// Same mapping, but this variant only prints the setter line, so the generated
// statements can be pasted into hand-written binding code.
def pSetter(dataType: DataType, pos: Int): Unit = dataType match {
case IntegerType => println(s"""stmt.setInt(${pos + 1}, row.getInt(${pos}))""")
case LongType => println(s"""stmt.setLong(${pos + 1}, row.getLong(${pos}))""")
case DoubleType => println(s"""stmt.setDouble(${pos + 1}, row.getDouble(${pos}))""")
case FloatType => println(s"""stmt.setFloat(${pos + 1}, row.getFloat(${pos}))""")
case ShortType => println(s"""stmt.setInt(${pos + 1}, row.getShort(${pos}))""")
case ByteType => println(s"""stmt.setInt(${pos + 1}, row.getByte(${pos}))""")
case BooleanType => println(s"""stmt.setBoolean(${pos + 1}, row.getBoolean(${pos}))""")
case StringType => println(s"""stmt.setString(${pos + 1}, row.getString(${pos}))""")
case BinaryType => println(s"""stmt.setBytes(${pos + 1}, row.getAs[Array[Byte]](${pos}))""")
case TimestampType => println(s"""stmt.setTimestamp(${pos + 1}, row.getAs[java.sql.Timestamp](${pos}))""")
case DateType => println(s"""stmt.setDate(${pos + 1}, row.getAs[java.sql.Date](${pos}))""")
case _: DecimalType => println(s"""stmt.setBigDecimal(${pos + 1}, row.getDecimal(${pos}))""")
case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
}
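For example, to generate the binding lines for every column of a DataFrame df and paste the printed output into hand-written JDBC code:
df.schema.zipWithIndex.foreach { case (field, i) => pSetter(field.dataType, i) }
// For a schema of (id: Long, name: String) this prints:
//   stmt.setLong(1, row.getLong(0))
//   stmt.setString(2, row.getString(1))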
Below is a complete example that updates and then inserts data:
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import java.util.Properties
import java.sql.DriverManager
import java.sql.PreparedStatement
import java.math.BigDecimal
object T extends Serializable {
// Marked Serializable because the foreachPartition closures below call makeSetter
// and therefore capture this object when they are shipped to executors.
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("T")
.config("spark.sql.parquet.writeLegacyFormat", true)
.enableHiveSupport()
.getOrCreate()
val prop = {
val p = new Properties()
p.put("driver", "oracle.jdbc.driver.OracleDriver")
p.put("url", "jdbc:oracle:thin:@10.18.2.3:1521:testdb")
p.put("user", "***")
p.put("password", "***")
p
}
val table = "OracleDB.TB"
Class.forName(prop.getProperty("driver"))
val df = spark.table("HiveDB.tb")
val names = df.schema.map(f => f.name)
val dataTypes = df.schema.map(f => f.dataType)
val batchSize = 1000 // rows buffered per JDBC batch before executeBatch() is called
// Update rows in the database.
val pks = Seq("id1", "id2").map(_ + "=?").mkString(" and ") // primary-key columns of the WHERE clause
val setCol = names.map(_ + "=?").mkString(",")
val usql = s"update $table set ${setCol} where ${pks}"
val t1 = System.currentTimeMillis()
df.coalesce(3).foreachPartition { part =>
val conn = DriverManager.getConnection(prop.getProperty("url"), prop)
conn.setAutoCommit(false)
val stmt = conn.prepareStatement(usql)
var i = 0 // rows buffered in the current batch
var rowCount = 0L // rows seen in this partition
var submittedCount = 0L // statements executed without error (executeBatch throws on failure)
part.foreach { row =>
val k1 = row.getAs[String]("id1")
val k2 = row.getAs[BigDecimal]("id2")
// Bind the SET parameters (positions 1..dataTypes.size) from the Row...
dataTypes.zipWithIndex.foreach {
case (dt, idx) =>
makeSetter(stmt, row, dt, idx)
}
// ...then the primary-key parameters of the WHERE clause.
stmt.setString(dataTypes.size + 1, k1)
stmt.setBigDecimal(dataTypes.size + 2, k2)
stmt.addBatch()
i += 1
rowCount += 1
if (i >= batchSize) {
submittedCount += stmt.executeBatch().length
conn.commit()
stmt.clearBatch()
i = 0
}
}
// Flush the last, partially filled batch.
if (i > 0) {
submittedCount += stmt.executeBatch().length
conn.commit()
stmt.clearBatch()
}
println("Rows in this partition not submitted: " + (rowCount - submittedCount))
stmt.close()
conn.close()
}
val t2 = System.currentTimeMillis()
println("update时间: " + (t2 - t1) / 1000)
// Insert rows into the database.
// Roughly equivalent to:
// df.write.jdbc(url, table, connectionProperties)
val columns = names.mkString(",")
val placeholders = names.map(_ => "?").mkString(",")
val isql = s"insert into $table ($columns) values ($placeholders)"
val t3 = System.currentTimeMillis()
df.coalesce(3).foreachPartition { part =>
val conn = DriverManager.getConnection(prop.getProperty("url"), prop)
conn.setAutoCommit(false)
val stmt = conn.prepareStatement(isql)
var i = 0 // rows buffered in the current batch
var rowCount = 0L // rows seen in this partition
var submittedCount = 0L // statements executed without error (executeBatch throws on failure)
part.foreach { row =>
dataTypes.zipWithIndex.foreach {
case (dt, idx) =>
makeSetter(stmt, row, dt, idx)
}
stmt.addBatch()
i += 1
rowCount += 1
if (i >= batchSize) {
submittedCount += stmt.executeBatch().length
conn.commit()
stmt.clearBatch()
i = 0
}
}
// Flush the last, partially filled batch.
if (i > 0) {
submittedCount += stmt.executeBatch().length
conn.commit()
stmt.clearBatch()
}
println("Rows in this partition not inserted: " + (rowCount - submittedCount))
stmt.close()
conn.close()
}
val t4 = System.currentTimeMillis()
println("insert时间: " + (t4 - t3) / 1000)
// Row-to-PreparedStatement binding helper (same mapping as shown above).
def makeSetter(stmt: PreparedStatement, row: Row, dataType: DataType, pos: Int): Unit = dataType match {
case IntegerType => stmt.setInt(pos + 1, row.getInt(pos))
case LongType => stmt.setLong(pos + 1, row.getLong(pos))
case DoubleType => stmt.setDouble(pos + 1, row.getDouble(pos))
case FloatType => stmt.setFloat(pos + 1, row.getFloat(pos))
case ShortType => stmt.setInt(pos + 1, row.getShort(pos))
case ByteType => stmt.setInt(pos + 1, row.getByte(pos))
case BooleanType => stmt.setBoolean(pos + 1, row.getBoolean(pos))
case StringType => stmt.setString(pos + 1, row.getString(pos))
case BinaryType => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
case TimestampType => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
case DateType => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
case _: DecimalType => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
case _ => throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
}
}
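For plain appends where the target table already matches the DataFrame's columns, the built-in writer mentioned in the comment above is usually sufficient. A minimal sketch, reusing the df, prop, and table values from the example (the batchsize option plays the role of the manual batchSize above):
import org.apache.spark.sql.SaveMode
df.write
.mode(SaveMode.Append)
.option("batchsize", "1000") // rows per JDBC batch
.jdbc(prop.getProperty("url"), table, prop)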