> 在 Spark 的数据源中,只支持 Append、Overwrite、ErrorIfExists、Ignore 这几种模式。但是我们线上的业务几乎全都需要 upsert 功能,也就是已存在的数据不能被直接覆盖,在 MySQL 中对应的实现就是 `ON DUPLICATE KEY UPDATE`。Spark 有没有这样一种实现?官方:不好意思,不提供。dounine:我这有呀,你来用吧。哈哈,为了方便大家使用,我已经把项目发布到 Maven 中央仓库了,为的就是上手快、使用简单。
## 吃土的方案
MysqlClient.scala
```
import java.sql._
import java.time.{LocalDate, LocalDateTime}
import scala.collection.mutable.ListBuffer
/**
  * Small JDBC helper that runs MySQL `INSERT ... ON DUPLICATE KEY UPDATE` statements.
  *
  * The connection is opened eagerly when the client is constructed.
  *
  * @param jdbcUrl connection settings packed into one pipe-separated string:
  *                `url|user|password`
  */
class MysqlClient(jdbcUrl: String) {
  // Stays a nullable var because init() may re-open the connection after it was closed.
  private var connection: Connection = null
  val driver = "com.mysql.jdbc.Driver"
  init()

  /**
    * Opens the connection when there is none yet or the current one is closed.
    *
    * @throws IllegalArgumentException if `jdbcUrl` has fewer than three pipe-separated parts
    */
  def init(): Unit = {
    if (connection == null || connection.isClosed) {
      jdbcUrl.split("\\|") match {
        case Array(url, user, password, _*) =>
          Class.forName(driver)
          connection = DriverManager.getConnection(url, user, password)
        case _ =>
          // The offending value is not echoed because it may contain the password.
          throw new IllegalArgumentException("jdbcUrl must have the form url|user|password")
      }
    }
  }

  /** Closes the underlying connection. */
  def close(): Unit = {
    if (connection != null) {
      connection.close()
    }
  }

  /**
    * Runs one parameterised update statement.
    *
    * A `SQLException` is printed to stderr and NOT rethrown, so callers cannot
    * observe a failed write through this method.
    *
    * @param sql    statement text with `?` placeholders
    * @param params values bound to the placeholders, in order
    */
  def execute(sql: String, params: Any*): Unit = {
    try {
      val statement = connection.prepareStatement(sql)
      try {
        fillStatement(statement, params: _*)
        statement.executeUpdate()
      } finally {
        // Always release the statement; the original code leaked one per call.
        statement.close()
      }
    } catch {
      case e: SQLException =>
        e.printStackTrace()
    }
  }

  /**
    * Binds `params` to the placeholders of `statement`; JDBC positions are 1-based.
    *
    * Int, Long, Double, Float and Boolean use the matching typed setter, `null`
    * is bound as SQL NULL, and every other value is bound as its `toString`.
    *
    * @param statement prepared statement to fill
    * @param params    values to bind, in placeholder order
    */
  @throws[SQLException]
  def fillStatement(statement: PreparedStatement, params: Any*): Unit = {
    params.zipWithIndex.foreach { case (value, offset) =>
      val position = offset + 1
      value match {
        case null => statement.setNull(position, Types.NULL)
        case s: String => statement.setString(position, s)
        // Bind the matched value directly: casting value.toString to a primitive always fails.
        case n: Int => statement.setInt(position, n)
        case b: Boolean => statement.setBoolean(position, b)
        case t @ (_: LocalDate | _: LocalDateTime) => statement.setString(position, t.toString)
        case l: Long => statement.setLong(position, l)
        case d: Double => statement.setDouble(position, d)
        case f: Float => statement.setFloat(position, f)
        case other => statement.setString(position, other.toString)
      }
    }
  }

  /**
    * Inserts a row, or updates it when it collides with an existing primary/unique key.
    *
    * The INSERT part writes the columns of `query.values`, `update.sets` and
    * `update.incs`. On a key collision the `sets` columns are overwritten and
    * the `incs` columns are increased by the given amount.
    * At least one of `update.sets` / `update.incs` should be non-empty, otherwise
    * the generated statement ends with an empty UPDATE clause.
    *
    * @param query     key columns identifying the row
    * @param update    columns to overwrite (`sets`) or to add to (`incs`)
    * @param tableName target table
    */
  def upsert(query: Query, update: Update, tableName: String): Unit = {
    // Materialise each map once so column names and values come from the same traversal.
    val queryEntries = query.values.toList
    val setEntries = update.sets.toList
    val incEntries = update.incs.toList
    val insertEntries = queryEntries ++ setEntries ++ incEntries

    val names = insertEntries.map { case (column, _) => quote(column) }
    val values = insertEntries.map(_ => "?")
    val updates =
      setEntries.map { case (column, _) => s" ${quote(column)} = ? " } ++
        incEntries.map { case (column, _) => s" ${quote(column)} = ${quote(column)} + ? " }

    // Placeholder order: INSERT values first, then the UPDATE clause (sets, then incs).
    val params: List[AnyRef] =
      insertEntries.map(_._2) ++ setEntries.map(_._2) ++ incEntries.map(_._2)

    val sql = s"INSERT INTO ${quote(tableName)} (${names.mkString(",")}) VALUES(${values.mkString(",")}) ON DUPLICATE KEY UPDATE ${updates.mkString(",")}"
    this.execute(sql, params: _*)
  }

  /** Wraps an identifier in backticks, doubling embedded backticks so it cannot end the quoting. */
  private def quote(identifier: String): String =
    "`" + identifier.replace("`", "``") + "`"
}
/**
  * Column changes applied by an upsert.
  *
  * @param sets column name -> value that replaces the stored value
  * @param incs column name -> amount added to the stored value
  */
case class Update(
  sets: Map[String, AnyRef] = Map.empty,
  incs: Map[String, AnyRef] = Map.empty
)
/**
  * Key columns of the row targeted by an upsert.
  *
  * @param values column name -> value
  */
case class Query(
  values: Map[String, AnyRef] = Map.empty
)
```
吃土的程序
```
// Reads the listed columns of a row as strings, keyed by column name;
// a null cell is replaced by the empty string.
val fieldMaps = (row: Row, fields: Array[String]) => {
  val pairs = for (name <- fields) yield {
    val cell = row.getAs[String](name)
    name -> (if (cell == null) "" else cell)
  }
  pairs.toMap
}
// Aggregates pv/uv per `time` and upserts every result row into the MySQL table `test`.
// One MysqlClient is created per partition.
sc.sql(
  s"""select time,count(userid) as pv,count(distinct(userid)) as uv from log group by time""")
  .foreachPartition(item => {
    val props: Properties = PropertiesUtils.properties("mysql")
    val mysqlClient: MysqlClient = new MysqlClient(props.getProperty("jdbcUrl"))
    try {
      // Columns that identify the target row; constant, so computed once per partition.
      val keyFields = "time".split(",")
      item.foreach { row =>
        val pv: Long = row.getAs[Long]("pv")
        val uv: Long = row.getAs[Long]("uv")
        val indicatorMap = Map(
          "pv" -> pv.toString,
          "uv" -> uv.toString
        )
        val update = if (overrideIndicator) {
          // Overwrite the stored indicators.
          Update(sets = indicatorMap)
        } else {
          // Add to the stored indicators.
          Update(incs = indicatorMap)
        }
        val queryMap = fieldMaps(row, keyFields)
        mysqlClient.upsert(
          Query(queryMap),
          update,
          "test"
        )
      }
    } finally {
      // Release the connection even when processing a row throws.
      mysqlClient.close()
    }
  })
```
真的是丑极了,不想看了
## 如今升级为jdbc2之后
依赖 [spark-sql-datasource](https://mvnrepository.com/artifact/com.dounine/spark-sql-datasource)
```
```
创建一张测试表
```
-- Target table for the upsert example: one row per `time` value holding the pv/uv counters.
CREATE TABLE `test` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`time` date NOT NULL,
`pv` int(255) DEFAULT '0',
`uv` int(255) DEFAULT '0',
PRIMARY KEY (`id`),
-- Unique key on `time`: inserting an already present date collides with this key,
-- which is the conflict that ON DUPLICATE KEY UPDATE reacts to.
UNIQUE KEY `uniq` (`time`) USING BTREE
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
```
程序
```
// Local Spark session for the demo.
val spark = SparkSession.builder().appName("jdbc2").master("local[*]").getOrCreate()

// Schema of the in-memory "log" view: who did something, on which day, and an indicator value.
val logSchema = StructType(
  Array(
    StructField("userid", StringType, nullable = false),
    StructField("time", StringType, nullable = false),
    StructField("indicator", LongType, nullable = false)
  )
)

// Three sample log records, all on the same day.
val sampleRecords = Array(
  ("lake", "2019-02-01", 10L),
  ("admin", "2019-02-01", 10L),
  ("admin", "2019-02-01", 11L)
)
val logRows = spark.sparkContext.parallelize(
  sampleRecords.map { case (userid, time, indicator) =>
    Row.fromSeq(Seq(userid, time, indicator))
  }
)
spark.createDataFrame(logRows, logSchema).createTempView("log")

// Writer settings for the jdbc2 data source.
val writeOptions = Map(
  "savemode" -> JDBCSaveMode.Update.toString,
  "driver" -> "com.mysql.jdbc.Driver",
  "url" -> "jdbc:mysql://localhost:3306/ttable",
  "user" -> "root",
  "password" -> "root",
  "dbtable" -> "test",
  "useSSL" -> "false",
  "duplicateIncs" -> "pv,uv",
  "showSql" -> "true"
)

// Aggregate pv/uv per day and write the result through the jdbc2 data source.
val dailyStats =
  spark.sql("select time,count(userid) as pv,count(distinct(userid)) as uv from log group by time")
dailyStats.write
  .format("org.apache.spark.sql.execution.datasources.jdbc2")
  .options(writeOptions)
  .save()
```
程序实际运行时会生成下面的 SQL 语句
```
INSERT INTO test (`time`,`pv`,`uv`)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE `time`=?,`pv`=`pv`+?,`uv`=`uv`+?
```
生成结果
![spark-sql datasource jdbc2 upsert](https://upload-images.jianshu.io/upload_images/9028759-afacd2fc7610a602.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
## jdbc2新增配置
|format| duplicateIncs | showSql |
|--|--|--|
|org.apache.spark.sql.execution.datasources.jdbc2|发生唯一键冲突时需要累加的字段(多个字段用逗号分隔,例如 `pv,uv`)|是否打印生成的 SQL|
其他配置与内置`jdbc`数据源一样~~~