Spark's save mode lives in org.apache.spark.sql.SaveMode. It is an enum that supports four table-level modes: Append, Overwrite, ErrorIfExists, and Ignore.
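For reference, the built-in modes apply to the write as a whole, as in this minimal sketch (df, url, table, and props are placeholders):

import org.apache.spark.sql.SaveMode

// Table-level semantics: the mode decides what happens to the target table as a whole.
df.write
  .mode(SaveMode.Append)
  .jdbc(url, table, props)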
In real business development, however, we usually want row-level behavior rather than these table-level actions.
Summarizing the requirements we commonly run into, we designed the following enumeration:
package org.apache.spark.sql.ximautil
/**
 * @author todd.chen at 8/26/16 9:52 PM.
 *         email : todd.chen@ximalaya.com
 */
object JdbcSaveMode extends Enumeration {
  type SaveMode = Value
  val IgnoreTable, Append, Overwrite, Update, ErrorIfExists, IgnoreRecord = Value
}
Each mode maps to a different SQL statement: Update uses MySQL's INSERT ... ON DUPLICATE KEY UPDATE, IgnoreRecord uses INSERT IGNORE, and Append/Overwrite use a plain INSERT. The statement is built like this:
/**
 * Returns a PreparedStatement that inserts a row into table via conn.
 */
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect, saveMode: SaveMode)
  : PreparedStatement = {
  val columnNames = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name))
  val columns = columnNames.mkString(",")
  val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
  val sql = saveMode match {
    case Update ⇒
      val duplicateSetting = columnNames.map(name ⇒ s"$name=?").mkString(",")
      s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
    case Append | Overwrite ⇒
      s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    case IgnoreRecord ⇒
      s"INSERT IGNORE INTO $table ($columns) VALUES ($placeholders)"
    case _ ⇒ throw new IllegalArgumentException(s"$saveMode is illegal")
  }
  conn.prepareStatement(sql)
}
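As a quick illustration (a hedged sketch, not part of the patch: the schema and table name below are made up), the Update mode would prepare a duplicate-key statement:

import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._
import org.apache.spark.sql.ximautil.JdbcSaveMode

// Hypothetical 3-column schema used only for illustration.
val rddSchema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType),
  StructField("id", LongType)))
val dialect = JdbcDialects.get("jdbc:mysql://localhost:3306/test")

// Given a live connection `conn`,
//   insertStatement(conn, "table_name", rddSchema, dialect, JdbcSaveMode.Update)
// would prepare:
//   INSERT INTO table_name (`name`,`age`,`id`) VALUES (?,?,?)
//     ON DUPLICATE KEY UPDATE `name`=?,`age`=?,`id`=?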
Before 2.0, org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils had a performance problem: while filling the PreparedStatement it matched on the column's data type for every field of every row, which is very expensive. Spark 2.0 rewrote this as a setter template: the type dispatch happens once to build an array of JDBCValueSetter closures for the statement, and each row only invokes the precomputed setters, moving the type comparison out of the hot loop entirely. The core code:
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setInt(pos + 1, row.getInt(pos))
  case LongType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setLong(pos + 1, row.getLong(pos))
  case DoubleType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setDouble(pos + 1, row.getDouble(pos))
  case FloatType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setFloat(pos + 1, row.getFloat(pos))
  case ShortType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setInt(pos + 1, row.getShort(pos))
  case ByteType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setInt(pos + 1, row.getByte(pos))
  case BooleanType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setBoolean(pos + 1, row.getBoolean(pos))
  case StringType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setString(pos + 1, row.getString(pos))
  case BinaryType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
  case TimestampType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
  case DateType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
  case t: DecimalType =>
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
  case ArrayType(et, _) =>
    // remove type length parameters from end of type name
    val typeName = getJdbcType(et, dialect).databaseTypeDefinition
      .toLowerCase.split("\\(")(0)
    (stmt: PreparedStatement, row: Row, pos: Int) =>
      val array = conn.createArrayOf(
        typeName,
        row.getSeq[AnyRef](pos).toArray)
      stmt.setArray(pos + 1, array)
  case _ =>
    (_: PreparedStatement, _: Row, pos: Int) =>
      throw new IllegalArgumentException(
        s"Can't translate non-null value for field $pos")
}
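The payoff is that the type dispatch happens once per partition; every row then just calls the precomputed closures. A minimal sketch of the pattern (stmt, nullTypes and iterator are the names used in savePartition, shown further below):

// Build the setters once per partition: one type match per column, not per cell.
val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))

// Inside the row loop there is no type matching left, only direct JDBC calls.
while (iterator.hasNext) {
  val row = iterator.next()
  var i = 0
  while (i < rddSchema.fields.length) {
    if (row.isNullAt(i)) stmt.setNull(i + 1, nullTypes(i)) else setters(i)(stmt, row, i)
    i += 1
  }
  stmt.addBatch()
}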
This covers most cases, but it breaks down for the duplicate-key statement:
insert into table_name (name,age,id) values (?,?,?)
insert into table_name (name,age,id) values (?,?,?) on duplicate key update name =? ,age=?,id=?
With the ON DUPLICATE KEY UPDATE clause the PreparedStatement has twice as many placeholders as the row has fields, so the setter-to-row mapping has to look like this:
row[1,2,3]
setter(0,1) //index of setter,index of row
setter(1,2)
setter(2,3)
setter(3,1)
setter(4,2)
setter(5,3)
We can see that once the setter index reaches half of setters.length, the row index it should read is setterIndex - setters.length / 2; the implementation passes that half-length in as an offset and reads the row at pos - offset.
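The arithmetic can be checked in isolation (a minimal sketch, independent of the patch):

// For 3 row fields in Update mode there are 6 placeholders; indices 3..5 wrap back to 0..2.
val numFields = 3
val midField = numFields                       // setters.length / 2 once the setters are duplicated
val rowIndexFor = (setterIndex: Int) =>
  if (setterIndex < midField) setterIndex else setterIndex - midField
(0 until numFields * 2).map(rowIndexFor)       // Vector(0, 1, 2, 0, 1, 2)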
So the new implementation looks like this:
// A `JDBCValueSetter` sets a value from `Row` into a field of the `PreparedStatement`.
// The first `Int` is the placeholder index in the SQL statement; the second `Int` is an
// offset subtracted from it to get the value's index in `Row` (non-zero only for the
// duplicate-key half of the placeholders).
private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) ⇒ Unit
private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getInt(pos - offset))
  case LongType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setLong(pos + 1, row.getLong(pos - offset))
  case DoubleType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setDouble(pos + 1, row.getDouble(pos - offset))
  case FloatType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setFloat(pos + 1, row.getFloat(pos - offset))
  case ShortType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getShort(pos - offset))
  case ByteType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getByte(pos - offset))
  case BooleanType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBoolean(pos + 1, row.getBoolean(pos - offset))
  case StringType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setString(pos + 1, row.getString(pos - offset))
  case BinaryType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - offset))
  case TimestampType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - offset))
  case DateType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - offset))
  case t: DecimalType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBigDecimal(pos + 1, row.getDecimal(pos - offset))
  case ArrayType(et, _) ⇒
    // remove type length parameters from end of type name
    val typeName = getJdbcType(et, dialect).databaseTypeDefinition
      .toLowerCase.split("\\(")(0)
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      val array = conn.createArrayOf(
        typeName,
        row.getSeq[AnyRef](pos - offset).toArray)
      stmt.setArray(pos + 1, array)
  case _ ⇒
    (_: PreparedStatement, _: Row, pos: Int, offset: Int) ⇒
      throw new IllegalArgumentException(
        s"Can't translate non-null value for field $pos")
}
private def getSetter(
    fields: Array[StructField],
    connection: Connection,
    dialect: JdbcDialect,
    isUpdateMode: Boolean): Array[JDBCValueSetter] = {
  val setter = fields.map(_.dataType).map(makeSetter(connection, dialect, _))
  if (isUpdateMode) {
    Array.fill(2)(setter).flatten
  } else {
    setter
  }
}
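In Update mode this simply repeats the per-column setters, so indices past the midpoint reuse the same closures; combined with the offset they read the correct row position. A tiny stand-in check (strings instead of the real closures, for illustration only):

val base = Array("setName", "setAge", "setId")   // stand-ins for the per-column setters
val duplicated = Array.fill(2)(base).flatten     // what getSetter returns when isUpdateMode
assert(duplicated.length == base.length * 2)
assert(duplicated(4) == base(1))                 // setter 4 with offset 3 reads row position 1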
The main change shows up in savePartition. The original source:
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    nullTypes: Array[Int],
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int): Iterator[Byte] = {
  require(batchSize >= 1,
    s"Invalid value `${batchSize.toString}` for parameter " +
      s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.")
  val conn = getConnection()
  var committed = false
  var finalIsolationLevel = Connection.TRANSACTION_NONE
  if (isolationLevel != Connection.TRANSACTION_NONE) {
    try {
      val metadata = conn.getMetaData
      if (metadata.supportsTransactions()) {
        // Update to at least use the default isolation, if any transaction level
        // has been chosen and transactions are supported
        val defaultIsolation = metadata.getDefaultTransactionIsolation
        finalIsolationLevel = defaultIsolation
        if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
          // Finally update to actually requested level if possible
          finalIsolationLevel = isolationLevel
        } else {
          logWarning(s"Requested isolation level $isolationLevel is not supported; " +
            s"falling back to default isolation level $defaultIsolation")
        }
      } else {
        logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
      }
    } catch {
      case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
    }
  }
  val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
  try {
    if (supportsTransactions) {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      conn.setTransactionIsolation(finalIsolationLevel)
    }
    val stmt = insertStatement(conn, table, rddSchema, dialect)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
      .map(makeSetter(conn, dialect, _)).toArray
    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        val numFields = rddSchema.fields.length
        var i = 0
        while (i < numFields) {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i)
          }
          i = i + 1
        }
        stmt.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) {
          stmt.executeBatch()
          rowCount = 0
        }
      }
      if (rowCount > 0) {
        stmt.executeBatch()
      }
    } finally {
      stmt.close()
    }
    if (supportsTransactions) {
      conn.commit()
    }
    committed = true
  } catch {
    case e: SQLException =>
      val cause = e.getNextException
      if (e.getCause != cause) {
        if (e.getCause == null) {
          e.initCause(cause)
        } else {
          e.addSuppressed(cause)
        }
      }
      throw e
  } finally {
    if (!committed) {
      // The stage must fail. We got here through an exception path, so
      // let the exception through unless rollback() or close() want to
      // tell the user about another problem.
      if (supportsTransactions) {
        conn.rollback()
      }
      conn.close()
    } else {
      // The stage must succeed. We cannot propagate any exception close() might throw.
      try {
        conn.close()
      } catch {
        case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
      }
    }
  }
  Array[Byte]().iterator
}
The modified version (only the changed portion is shown; the rest is identical to the original):
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    nullTypes: Array[Int],
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int,
    saveMode: SaveMode) = {
  require(batchSize >= 1,
    s"Invalid value `${batchSize.toString}` for parameter " +
      s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")
  val isUpdateMode = saveMode == Update // are we in Update (duplicate-key) mode?
  val conn = getConnection()
  var committed = false
  val length = rddSchema.fields.length
  val numFields = if (isUpdateMode) length * 2 else length // real number of '?' placeholders
  val stmt = insertStatement(conn, table, rddSchema, dialect, saveMode)
  val setters: Array[JDBCValueSetter] = getSetter(rddSchema.fields, conn, dialect, isUpdateMode)
  var rowCount = 0
  while (iterator.hasNext) {
    val row = iterator.next()
    var i = 0
    val midField = numFields / 2
    while (i < numFields) {
      // in Update mode the '?' count is 2 * row.length, so indices past midField wrap back
      if (isUpdateMode) {
        if (i < midField) {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i, 0)
          }
        } else {
          // duplicate-key half: the row index is i - midField, passed as the offset
          if (row.isNullAt(i - midField)) {
            stmt.setNull(i + 1, nullTypes(i - midField))
          } else {
            setters(i).apply(stmt, row, i, midField)
          }
        }
      } else {
        if (row.isNullAt(i)) {
          stmt.setNull(i + 1, nullTypes(i))
        } else {
          setters(i).apply(stmt, row, i, 0)
        }
      }
      i = i + 1
    }
    // ... the remainder (addBatch/executeBatch, commit/rollback, cleanup) is unchanged
    // from the original savePartition above
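For context, the partition writer is driven per partition the same way Spark's own JdbcUtils.saveTable drives it; a hedged sketch (the saveTable-style wrapper and its parameter names are assumptions, not part of the patch):

// Hypothetical driver, mirroring Spark's JdbcUtils.saveTable; names are assumptions.
df.foreachPartition { iterator =>
  savePartition(getConnection, table, iterator, rddSchema, nullTypes,
    batchSize, dialect, isolationLevel, saveMode)
}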
The wrapper bean that describes a JDBC save:
case class JdbcSaveExplain(
    url: String,
    tableName: String,
    saveMode: SaveMode,
    jdbcParam: Properties
)
And a writer that wraps DataFrame, playing the role of DataFrameWriter:
package com.ximalaya.spark.xql.exec.jdbc

import java.util.Properties

import com.ximalaya.spark.common.log.CommonLoggerTrait
import language._
import com.ximalaya.spark.xql.interpreter.jdbc.JdbcSaveExplain
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.ximautil.JdbcSaveMode.SaveMode
import org.apache.spark.sql.ximautil.JdbcSaveMode._
import org.apache.spark.sql.ximautil.XQLJdbcUtil

/**
 * @author todd.chen at 8/26/16 11:33 PM.
 *         email : todd.chen@ximalaya.com
 */
class JdbcDataFrameWriter(dataFrame: DataFrame) extends Serializable with CommonLoggerTrait {

  def writeJdbc(jdbcSaveExplain: JdbcSaveExplain) = {
    this.jdbcSaveExplain = jdbcSaveExplain
    this
  }

  def save(): Unit = {
    assert(jdbcSaveExplain != null)
    val saveMode = jdbcSaveExplain.saveMode
    val url = jdbcSaveExplain.url
    val table = jdbcSaveExplain.tableName
    val props = jdbcSaveExplain.jdbcParam
    if (checkTable(url, table, props, saveMode))
      XQLJdbcUtil.saveTable(dataFrame, url, table, props, saveMode)
  }

  private def checkTable(url: String, table: String, connectionProperties: Properties, saveMode: SaveMode): Boolean = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) =>
      props.put(key, value)
    }
    // connectionProperties should override settings in extraOptions
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnectionFactory(url, props)()
    try {
      var tableExists = JdbcUtils.tableExists(conn, url, table)
      // IgnoreTable: the table already exists, so write nothing
      if (saveMode == IgnoreTable && tableExists) {
        logger.info(" table {} exists ,mode is ignoreTable,save nothing to it", table)
        return false
      }
      // ErrorIfExists: fail if the table already exists
      if (saveMode == ErrorIfExists && tableExists) {
        sys.error(s"Table $table already exists.")
      }
      // Overwrite: drop the existing table; it is recreated below
      if (saveMode == Overwrite && tableExists) {
        JdbcUtils.dropTable(conn, table)
        tableExists = false
      }
      // Create the table if the table didn't exist.
      if (!tableExists) {
        checkField(dataFrame)
        val schema = JdbcUtils.schemaString(dataFrame, url)
        val sql = s"CREATE TABLE $table (id int not null primary key auto_increment , $schema)"
        conn.prepareStatement(sql).executeUpdate()
      }
      true
    } finally {
      conn.close()
    }
  }

  // the MySQL table uses an auto-increment `id` as its primary key, so a DataFrame
  // that already contains an `id` column is rejected
  private def checkField(dataFrame: DataFrame): Unit = {
    if (dataFrame.schema.exists(_.name == "id")) {
      throw new IllegalArgumentException("dataFrame exists id columns,but id is primary key auto increment in mysql ")
    }
  }

  private var jdbcSaveExplain: JdbcSaveExplain = _
  private val extraOptions = new scala.collection.mutable.HashMap[String, String]
}

object JdbcDataFrameWriter {

  implicit def dataFrame2JdbcWriter(dataFrame: DataFrame): JdbcDataFrameWriter = JdbcDataFrameWriter(dataFrame)

  def apply(dataFrame: DataFrame): JdbcDataFrameWriter = new JdbcDataFrameWriter(dataFrame)
}
Test case:
implicit def map2Prop(map: Map[String, String]): Properties = map.foldLeft(new Properties) {
  case (prop, kv) ⇒ prop.put(kv._1, kv._2); prop
}

val sparkContext = new SparkContext(sparkConf)
val sqlContext = new SQLContext(sparkContext)
// val hiveContext = new HiveContext(sparkContext)
// import hiveContext.implicits._
import sqlContext.implicits._

val df = sparkContext.parallelize(Seq(
  (1, 1, "2", "ctccct", "286"),
  (2, 2, "2", "ccc", "11"),
  (4, 10, "2", "ccct", "12")
)).toDF("id", "iid", "uid", "name", "age")

val jdbcSaveExplain = JdbcSaveExplain(
  "jdbc:mysql://localhost:3306/test",
  "mytest",
  JdbcSaveMode.Update,
  Map("user" → "user", "password" → "password")
)

import JdbcDataFrameWriter.dataFrame2JdbcWriter
df.writeJdbc(jdbcSaveExplain).save()
The complete code is available on my GitHub.