Preliminaries: my environment is (CDH) Spark 2.2.0 with Scala 2.11.8.
The problem: when writing to a MySQL-type database with SaveMode.Append, Spark first wiped the existing data and then wrote the new rows. This left us speechless: we clearly asked for Append, yet we got an overwrite.
The fix: adapt the source code by rewriting two classes (just drop these two classes into your own project; there is no need to modify Spark's underlying source).
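For context, this is the standard write path that showed the behavior (a minimal sketch; assume df is the DataFrame to write, and the url, tableName, and credentials are placeholders):

import java.util.Properties
import org.apache.spark.sql.SaveMode

val prop = new Properties()
prop.setProperty("user", "root")            // placeholder credentials
prop.setProperty("password", "123456")
prop.setProperty("driver", "com.mysql.jdbc.Driver")

// Expected: append rows to the existing table.
// Observed in this environment: the table was emptied first.
df.write.mode(SaveMode.Append).jdbc(url, tableName, prop)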
1. JdbcUtils
The original code is:
if (mode == SaveMode.Append && tableExists) {
  truncateTable(conn, table)
  tableExists = true
}
Simply delete the truncateTable(conn, table) line.
The complete class:
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverWrapper
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import scala.util.Try
/**
 * Util functions for JDBC tables.
 */
object JdbcUtils extends Logging {

  def jdbc(mode: SaveMode, url: String, df: DataFrame, table: String, connectionProperties: Properties): Unit = {
    val props = new Properties()
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnection(url, props)
    try {
      var tableExists = JdbcUtils.tableExists(conn, table)
      if (mode == SaveMode.Ignore && tableExists) {
        return
      }
      if (mode == SaveMode.ErrorIfExists && tableExists) {
        sys.error(s"Table $table already exists.")
      }
      if (mode == SaveMode.Overwrite && tableExists) {
        truncateTable(conn, table)
        tableExists = true
      }
      if (mode == SaveMode.Append && tableExists) {
        // ********** comment out or delete the following line **********
        // truncateTable(conn, table)
        tableExists = true
      }
      // Create the table if it doesn't exist yet.
      if (!tableExists) {
        val schema = JdbcUtils.schemaString(df, url)
        val sql = s"CREATE TABLE $table ($schema)"
        conn.prepareStatement(sql).executeUpdate()
      }
      JdbcUtils.saveTable(df, url, table, props)
    } finally {
      conn.close()
    }
  }
  /**
   * Establishes a JDBC connection.
   */
  def createConnection(url: String, connectionProperties: Properties): Connection = {
    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
  }

  /**
   * Returns true if the table already exists in the JDBC database.
   */
  def tableExists(conn: Connection, table: String): Boolean = {
    // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
    // SQL database systems, considering "table" could also include the database name.
    Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
  }
  /**
   * Drops a table from the JDBC database.
   */
  def dropTable(conn: Connection, table: String): Unit = {
    conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
  }

  /**
   * Truncates a table in the JDBC database. Note that on MySQL, TRUNCATE TABLE
   * implicitly commits and resets any AUTO_INCREMENT counter.
   */
  def truncateTable(conn: Connection, table: String): Unit = {
    conn.prepareStatement(s"TRUNCATE TABLE $table").executeUpdate()
  }
  /**
   * Returns a PreparedStatement that inserts a row into table via conn.
   * For example, a schema with fields (id, name) and table "t" produces:
   *   INSERT INTO t (id,name) VALUES (?, ?)
   */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
    val fields = rddSchema.fields
    val fieldsSql = new StringBuilder("(")
    var i = 0
    for (f <- fields) {
      fieldsSql.append(f.name)
      if (i == fields.length - 1) {
        fieldsSql.append(")")
      } else {
        fieldsSql.append(",")
      }
      i += 1
    }
    val sql = new StringBuilder(s"INSERT INTO $table ")
    sql.append(fieldsSql.toString())
    sql.append(" VALUES (")
    var fieldsLeft = rddSchema.fields.length
    while (fieldsLeft > 0) {
      sql.append("?")
      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
      fieldsLeft = fieldsLeft - 1
    }
    conn.prepareStatement(sql.toString())
  }
  /**
   * Saves a partition of a DataFrame to the JDBC database. This is done in
   * a single database transaction in order to avoid repeatedly inserting
   * data as much as possible.
   *
   * It is still theoretically possible for rows in a DataFrame to be
   * inserted into the database more than once if a stage somehow fails after
   * the commit occurs but before the stage can return successfully.
   *
   * This is not a closure inside saveTable() because apparently cosmetic
   * implementation changes elsewhere might easily render such a closure
   * non-Serializable. Instead, we explicitly close over all variables that
   * are used.
   */
  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int]): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false
    try {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      val stmt = insertStatement(conn, table, rddSchema)
      try {
        while (iterator.hasNext) {
          val row = iterator.next()
          val numFields = rddSchema.fields.length
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              rddSchema.fields(i).dataType match {
                case IntegerType => stmt.setInt(i + 1, row.getInt(i))
                case LongType => stmt.setLong(i + 1, row.getLong(i))
                case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
                case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
                case ShortType => stmt.setInt(i + 1, row.getShort(i))
                case ByteType => stmt.setInt(i + 1, row.getByte(i))
                case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
                case StringType => stmt.setString(i + 1, row.getString(i))
                case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
                case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
                case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
                case _: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
                case _ => throw new IllegalArgumentException(
                  s"Can't translate non-null value for field $i")
              }
            }
            i = i + 1
          }
          stmt.executeUpdate()
        }
      } finally {
        stmt.close()
      }
      conn.commit()
      committed = true
    } finally {
      if (!committed) {
        // The stage must fail. We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        conn.rollback()
        conn.close()
      } else {
        // The stage must succeed. We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
    Array[Byte]().iterator
  }
  /**
   * Compute the schema string for this RDD. With the fallback mappings below,
   * a schema such as (id: Int, name: String) yields roughly: "id INTEGER , name TEXT".
   */
  def schemaString(df: DataFrame, url: String): String = {
    val sb = new StringBuilder()
    val dialect = JdbcDialects.get(url)
    df.schema.fields.foreach { field =>
      val name = field.name
      val typ: String =
        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
          field.dataType match {
            case IntegerType => "INTEGER"
            case LongType => "BIGINT"
            case DoubleType => "DOUBLE PRECISION"
            case FloatType => "REAL"
            case ShortType => "INTEGER"
            case ByteType => "BYTE"
            case BooleanType => "BIT(1)"
            case StringType => "TEXT"
            case BinaryType => "BLOB"
            case TimestampType => "TIMESTAMP"
            case DateType => "DATE"
            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
          })
      val nullable = if (field.nullable) "" else "NOT NULL"
      sb.append(s", $name $typ $nullable")
    }
    if (sb.length < 2) "" else sb.substring(2)
  }
  /**
   * Saves the RDD to the database in a single transaction.
   */
  def saveTable(
      df: DataFrame,
      url: String,
      table: String,
      properties: Properties = new Properties()): Unit = {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
        field.dataType match {
          case IntegerType => java.sql.Types.INTEGER
          case LongType => java.sql.Types.BIGINT
          case DoubleType => java.sql.Types.DOUBLE
          case FloatType => java.sql.Types.REAL
          case ShortType => java.sql.Types.INTEGER
          case ByteType => java.sql.Types.INTEGER
          case BooleanType => java.sql.Types.BIT
          case StringType => java.sql.Types.CLOB
          case BinaryType => java.sql.Types.BLOB
          case TimestampType => java.sql.Types.TIMESTAMP
          case DateType => java.sql.Types.DATE
          case _: DecimalType => java.sql.Types.DECIMAL
          case _ => throw new IllegalArgumentException(
            s"Can't translate null value for field $field")
        })
    }
    val rddSchema = df.schema
    val driver: String = getDriverClassName(url)
    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
    df.foreachPartition { iterator =>
      savePartition(getConnection, table, iterator, rddSchema, nullTypes)
    }
  }

  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
    case driver => driver.getClass.getCanonicalName
  }
}
2. JDBCRDD. This class also has to be rewritten because the JdbcUtils above depends on it; leave it out and the code will not compile.
Everything unnecessary is deleted, keeping only what we actually use. The result:
import java.sql.{Connection, DriverManager}
import java.util.Properties
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry
private object JDBCRDD extends Logging {
  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
    () => {
      try {
        if (driver != null) DriverRegistry.register(driver)
      } catch {
        case e: ClassNotFoundException =>
          logWarning(s"Couldn't find class $driver", e)
      }
      DriverManager.getConnection(url, properties)
    }
  }
}
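A quick way to sanity-check the connector in isolation (a sketch; the driver class, URL, and credentials are placeholders):

import java.util.Properties

val props = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "123456")
val connect = JDBCRDD.getConnector("com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/test", props)
val conn = connect()   // the connection is opened lazily, exactly as savePartition does per partition
conn.close()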
And finally, just call it:
df is a DataFrame;
url = "jdbc:mysql://*****:3306/databaseName"
JdbcUtils.jdbc(SaveMode.Append, url, df, tableName, prop)
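Putting it all together (a minimal sketch; the input path, host, database, table name, and credentials are all placeholders):

import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("AppendDemo").getOrCreate()
val df = spark.read.json("/path/to/input")   // any DataFrame will do

val url = "jdbc:mysql://localhost:3306/databaseName"
val tableName = "my_table"
val prop = new Properties()
prop.setProperty("user", "root")
prop.setProperty("password", "123456")
prop.setProperty("driver", "com.mysql.jdbc.Driver")

// Appends without truncating: our JdbcUtils no longer calls truncateTable for Append.
JdbcUtils.jdbc(SaveMode.Append, url, df, tableName, prop)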