Implementing MySQL update operations in Spark

Background

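Currently, Spark's JDBC writer only supports the table-level save modes Append, Overwrite, ErrorIfExists and Ignore. Sometimes we need row-level operations on a table, such as update, which means generating a statement like this: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;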
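To make the target semantics concrete, here is a minimal in-memory Java model of what that statement does to table tb. It assumes id is the primary key and deliberately ignores other unique indexes, triggers and MySQL's affected-row counts; UpsertModel and its methods are hypothetical names for illustration, not part of Spark or MySQL.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical in-memory model of INSERT ... ON DUPLICATE KEY UPDATE on
// tb(id PRIMARY KEY, name, age). Only the primary key is modeled.
final class UpsertModel {
    // One row of tb; id is the primary key.
    static final class Row {
        final int id;
        final String name;
        final int age;

        Row(int id, String name, int age) {
            this.id = id;
            this.name = name;
            this.age = age;
        }
    }

    private final Map<Integer, Row> table = new LinkedHashMap<>();

    // Plain INSERT: a duplicate primary key is an error.
    void insert(Row r) {
        if (table.containsKey(r.id)) {
            throw new IllegalStateException("duplicate primary key: " + r.id);
        }
        table.put(r.id, r);
    }

    // INSERT ... ON DUPLICATE KEY UPDATE id=?,name=?,age=?:
    // insert a new key, otherwise overwrite every column of the existing row.
    void upsert(Row r) {
        table.put(r.id, r);
    }

    Row get(int id) {
        return table.get(id);
    }

    int size() {
        return table.size();
    }
}
```

A plain insert of an existing id fails, which is what a row-by-row Append runs into on a duplicate key; upsert instead updates the existing row in place.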

Requirement: without affecting previously written code or introducing a new API, support this simply by adding one option, such as savemode=update.

Implementation

Meeting these requirements means changing Spark's source code. First, create our own save mode, which just adds Update to the existing modes:

public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

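The JDBC data source is implemented mainly in JdbcRelationProvider, and the method to focus on is createRelation. There we replace Spark's SaveMode with our own mode and pass that mode on to saveTable. The modified method looks like this (changes are marked with comments):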

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val options = new JDBCOptions(parameters)
    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
    // Replace Spark's SaveMode with our own I4SaveMode
    var saveMode = mode match {
      case SaveMode.Overwrite => I4SaveMode.Overwrite
      case SaveMode.Append => I4SaveMode.Append
      case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
      case SaveMode.Ignore => I4SaveMode.Ignore
    }
    // The key change: if there is a savemode=update option (key and value compared
    // case-insensitively), switch to Update mode
    val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
    if (parameterLower.get("savemode").exists(_.equalsIgnoreCase("update"))) {
      saveMode = I4SaveMode.Update
    }
    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, options)
      if (tableExists) {
        saveMode match {
          case I4SaveMode.Overwrite =>
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, options.table)
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              // Pass our own mode on to saveTable
              saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
            } else {
            ......
  }
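The option check in createRelation can be tried out on its own. The Java sketch below (SaveModeResolver is a hypothetical name; I4SaveMode is repeated so the block compiles alone) returns Update only when a savemode option is set to update, matching the option key case-insensitively as createRelation does. It also ignores case in the value, which is a choice of this sketch.

```java
import java.util.Locale;
import java.util.Map;

// Same enum as I4SaveMode in the article, repeated so this sketch is self-contained.
enum I4SaveMode { Append, Overwrite, ErrorIfExists, Ignore, Update }

final class SaveModeResolver {
    // Start from the mode mapped from Spark's SaveMode and switch to Update
    // only when a "savemode" option (any key casing) has the value "update".
    static I4SaveMode resolve(I4SaveMode mapped, Map<String, String> parameters) {
        for (Map.Entry<String, String> e : parameters.entrySet()) {
            if (e.getKey().toLowerCase(Locale.ROOT).equals("savemode")
                    && "update".equalsIgnoreCase(e.getValue())) {
                return I4SaveMode.Update;
            }
        }
        return mapped;
    }
}
```

Any other value, or a missing option, leaves the mapped mode untouched, so existing jobs behave exactly as before.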

Next comes the saveTable method:

  def saveTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JDBCOptions,
      mode: I4SaveMode): Unit = {
    ......
    // Pass the mode on so the statement is built for it
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode)
    .....
    // savePartition also receives the mode
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode)
    )
  }

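Here getInsertStatement builds the SQL statement, and then each partition is saved in turn. Let's first look at how statement construction was changed (changes are marked with comments):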

def getInsertStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect,
      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // Original: s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    // Update mode needs its own statement
    mode match {
      case I4SaveMode.Update =>
        val duplicateSetting = rddSchema.fields
          .map(x => dialect.quoteIdentifier(x.name))
          .map(name => s"$name=?")
          .mkString(",")
        s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
      case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)"
    }
  }
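For the example table tb(id, name, age), the two statements this produces can be reproduced with the Java sketch below. It assumes MySQL-style backtick quoting, which is what Spark's MySQLDialect.quoteIdentifier produces; InsertSql and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.stream.Collectors;

final class InsertSql {
    // MySQL quotes identifiers with backticks.
    static String quote(String name) {
        return "`" + name + "`";
    }

    // Plain INSERT, or in update mode the same INSERT followed by
    // ON DUPLICATE KEY UPDATE col=? for every column, so the ? count doubles.
    static String build(String table, String[] columns, boolean update) {
        String cols = Arrays.stream(columns).map(InsertSql::quote).collect(Collectors.joining(","));
        String placeholders = Arrays.stream(columns).map(c -> "?").collect(Collectors.joining(","));
        String insert = "INSERT INTO " + table + " (" + cols + ") VALUES (" + placeholders + ")";
        if (!update) {
            return insert;
        }
        String setting = Arrays.stream(columns)
                .map(c -> quote(c) + "=?")
                .collect(Collectors.joining(","));
        return insert + " ON DUPLICATE KEY UPDATE " + setting;
    }

    // Number of JDBC parameters the statement expects.
    static long countPlaceholders(String sql) {
        return sql.chars().filter(ch -> ch == '?').count();
    }
}
```

For columns id, name, age the update form ends in ON DUPLICATE KEY UPDATE `id`=?,`name`=?,`age`=?, i.e. six parameters instead of three, which is why savePartition has to bind each row twice.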

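We only need to check whether we are in update mode to build the corresponding SQL statement. Next, the main piece is the savePartition method, which shows how the data is actually saved: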

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }
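The two ideas in savePartition, setters chosen once per column and rows sent in batches of batchSize, can be modeled without a database. In this hypothetical Java sketch (BatchSketch is not part of Spark and does not use java.sql) a "statement" is just a list of bound values, and executing a batch means moving the pending rows to a result list; transactions, isolation levels and error handling are left out.

```java
import java.util.ArrayList;
import java.util.List;

final class BatchSketch {
    // Stand-in for JDBCValueSetter: copies column pos of row into the bound parameters.
    interface Setter {
        void set(List<Object> params, Object[] row, int pos);
    }

    // Chosen once per column type, like makeSetter, so there is no per-row type switch.
    static Setter makeSetter(String type) {
        switch (type) {
            case "int":
                return (params, row, pos) -> params.add((Integer) row[pos]);
            case "string":
                return (params, row, pos) -> params.add((String) row[pos]);
            default:
                throw new IllegalArgumentException("unsupported type: " + type);
        }
    }

    static Setter[] makeSetters(String[] types) {
        Setter[] setters = new Setter[types.length];
        for (int i = 0; i < types.length; i++) {
            setters[i] = makeSetter(types[i]);
        }
        return setters;
    }

    // Binds one row; a null column is bound as null (stands in for stmt.setNull).
    static List<Object> bind(Setter[] setters, Object[] row) {
        List<Object> params = new ArrayList<>();
        for (int i = 0; i < setters.length; i++) {
            if (row[i] == null) {
                params.add(null);
            } else {
                setters[i].set(params, row, i);
            }
        }
        return params;
    }

    // Adds each bound row to the pending batch (addBatch), "executes" it every
    // batchSize rows (executeBatch), then flushes the remainder.
    static List<List<List<Object>>> run(List<Object[]> rows, Setter[] setters, int batchSize) {
        List<List<List<Object>>> executed = new ArrayList<>();
        List<List<Object>> pending = new ArrayList<>();
        for (Object[] row : rows) {
            pending.add(bind(setters, row));
            if (pending.size() == batchSize) {
                executed.add(pending);
                pending = new ArrayList<>();
            }
        }
        if (!pending.isEmpty()) {
            executed.add(pending);
        }
        return executed;
    }
}
```

With five rows and a batch size of 2 this yields batches of 2, 2 and 1 rows, matching the final `if (rowCount > 0) stmt.executeBatch()` flush in savePartition.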

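The general idea: before iterating over the partition's rows and inserting them, savePartition builds one setter per column from the data's schema (setters). During iteration it simply applies these setters to each row, so the data type does not have to be checked again for every row.
In non-update mode the statement is: insert into tb (id,name,age) values (?,?,?)
In update mode it is: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;
That is, there are twice as many placeholders, so in update mode each row's values must be fed to the PreparedStatement a second time. The original makeSetter method looks like this: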

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }

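We only need to add a relative-position parameter, offset, to control which row column is read; the JDBCValueSetter type alias gains the same extra parameter. The modified version is: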

  // The setter type alias gains a fourth parameter, offset
  private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit

  private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos - offset))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...
  }

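In non-update mode, offset is always 0. In update mode, offset is 0 while the placeholder index is below the number of columns, and equals the number of columns (midField below) once the index reaches it. The modified savePartition method is: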

  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int,
      mode: I4SaveMode): Iterator[Byte] = {
    ...
    // Are we in update mode?
    val isUpdateMode = mode == I4SaveMode.Update
    val stmt = conn.prepareStatement(insertStmt)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val length = rddSchema.fields.length
    // Update mode has twice as many placeholders
    val numFields = if (isUpdateMode) length * 2 else length
    val midField = numFields / 2
    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (isUpdateMode && i >= midField) {
            // Update mode, past the column count (the ON DUPLICATE KEY UPDATE part):
            // offset is midField, i.e. the number of columns
            val col = i - midField
            if (row.isNullAt(col)) {
              stmt.setNull(i + 1, nullTypes(col))
            } else {
              setters(col).apply(stmt, row, i, midField)
            }
          } else {
            // Non-update mode, or update mode within the column count: offset is 0
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i, 0)
            }
          }
          i = i + 1
        }
        ...
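The offset rule can be checked on a single row. This hypothetical Java sketch (UpdateBinding is not part of Spark) computes, for every placeholder index i, which row column is read, and lists the values in the order they are bound to JDBC parameters 1, 2, ... (parameter index i + 1).

```java
import java.util.ArrayList;
import java.util.List;

final class UpdateBinding {
    // Row column read by placeholder i (0-based). In update mode the second half
    // of the placeholders uses offset = numColumns (midField); otherwise offset = 0.
    static int columnFor(int i, int numColumns, boolean update) {
        int offset = (update && i >= numColumns) ? numColumns : 0;
        return i - offset;
    }

    // Values in binding order; numFields doubles in update mode.
    static List<Object> bindOrder(Object[] row, boolean update) {
        int numFields = update ? row.length * 2 : row.length;
        List<Object> params = new ArrayList<>();
        for (int i = 0; i < numFields; i++) {
            params.add(row[columnFor(i, row.length, update)]);
        }
        return params;
    }
}
```

For the row (1, "Tom", 20), update mode binds 1, "Tom", 20, 1, "Tom", 20: the VALUES part and the ON DUPLICATE KEY UPDATE part receive the same values.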

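After modifying the source, recompile and repackage it, then replace the corresponding jar in production. There is a shortcut: create classes under the same package names, build the modified sources into a jar, and use the class files from that jar to replace the corresponding class files inside the production jar.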

How to use

To use update mode:

df.write.option("saveMode","update").jdbc(...)

References

https://blog.csdn.net/cjuexuan/article/details/52333970
