SparkSQL batch insert or update: saving data to PostgreSQL

  • In Spark SQL, the DataFrameWriter save() API offers only four save modes (Append, Overwrite, ErrorIfExists, Ignore), none of which performs an upsert, so it does not meet the project's requirement. Based on the Spark save() source code, the write path is reworked to batch-save data: rows that already exist are updated, rows that do not are inserted (a sketch of the standard write it replaces is shown below).
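For contrast, the built-in JDBC write looks roughly like the sketch below, assuming an already-loaded DataFrame df (the URL, table name and credentials are placeholders, not taken from the project). Whichever of the four modes is chosen, it decides the fate of the table as a whole; there is no way to express PostgreSQL's INSERT ... ON CONFLICT ... DO UPDATE through it:

import java.util.Properties
import org.apache.spark.sql.SaveMode

// Standard Spark JDBC write: Append / Overwrite / ErrorIfExists / Ignore
// apply to the whole table; there is no per-row "update if exists".
val props = new Properties()
props.setProperty("user", "ods_user")          // placeholder credentials
props.setProperty("password", "ods_password")

df.write
  .mode(SaveMode.Append)
  .jdbc("jdbc:postgresql://localhost:5432/oucloud_ods", "test_001_copy1", props)

The test case below therefore goes through a custom helper, PgSqlUtil.insertOrUpdateToPgsql, instead.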
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.SparkSession

/**
 * Test case
 *    Batch-save data: update the row if it exists, insert it if it does not
 *    INSERT INTO test_001 VALUES( ?, ?, ? )
 *    ON conflict ( ID ) DO
 *    UPDATE SET id=?,NAME = ?,age = ?;
 * @author linzhy
 */
object InsertOrUpdateTest {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[2]")
      .config("spark.debug.maxToStringFields","100")
      .getOrCreate()

    val config = ConfigFactory.load()
    val ods_url = config.getString("pg.oucloud_ods.url")
    val ods_user = config.getString("pg.oucloud_ods.user")
    val ods_password = config.getString("pg.oucloud_ods.password")

    val test_001 = spark.read.format("jdbc")
      .option("url", ods_url)
      .option("dbtable", "test_001")
      .option("user", ods_user)
      .option("password", ods_password)
      .load()

    test_001.createOrReplaceTempView("test_001")

    val sql=
      """
        |SELECT * FROM test_001
        |""".stripMargin

    val dataFrame = spark.sql(sql)
    // Batch-save the data: rows that exist are updated, rows that do not are inserted
    PgSqlUtil.insertOrUpdateToPgsql(dataFrame, spark.sparkContext, "test_001_copy1", "id")

    spark.stop()
  }
}
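Note that ON CONFLICT (id) only works if the target table has a primary key or unique constraint on the conflict column; otherwise PostgreSQL rejects the statement at runtime. Below is a minimal sketch of preparing the target table over plain JDBC, reusing the connection settings loaded above; the DDL and column types are assumptions based on the three placeholders in the test case, not taken from the original project:

import java.sql.DriverManager

// The conflict column must be backed by a unique index or primary key,
// otherwise "on conflict(id) do update" fails at execution time.
val ddl =
  """
    |CREATE TABLE IF NOT EXISTS test_001_copy1 (
    |  id   INT PRIMARY KEY,
    |  name VARCHAR(100),
    |  age  INT
    |)
    |""".stripMargin

val conn = DriverManager.getConnection(ods_url, ods_user, ods_password)
try {
  conn.createStatement().execute(ddl)
} finally {
  conn.close()
}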
  • Source code of the insertOrUpdateToPgsql method
/**
   * Batch insert or update data; this method is adapted from the Spark write.save() source code
   * @param dataFrame the DataFrame to persist
   * @param sc        the SparkContext, used to broadcast the prepared statement
   * @param table     target table name
   * @param id        conflict column for ON CONFLICT (the table's unique / primary key column)
   */
  def insertOrUpdateToPgsql(dataFrame: DataFrame, sc: SparkContext, table: String, id: String): Unit = {

    val tableSchema = dataFrame.schema
    val columns = tableSchema.fields.map(x => x.name).mkString(",")
    val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
    val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders) on conflict($id) do update set "
    // One "col=?" pair per column for the UPDATE SET part, e.g. for columns (id,name,age):
    // INSERT INTO t (id,name,age) VALUES (?,?,?) on conflict(id) do update set id=?,name=?,age=?
    val update = tableSchema.fields.map(x =>
      x.name + "=?"
    ).mkString(",")

    val realsql = sql.concat(update)
    val conn = connectionPool()
    conn.setAutoCommit(false)
    val dialect = JdbcDialects.get(conn.getMetaData.getURL)
    val broad_ps = sc.broadcast(conn.prepareStatement(realsql))

    // Each row value is bound twice: once for VALUES(...) and once for UPDATE SET ...
    val numFields = tableSchema.fields.length * 2

    val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))

    var rowCount = 0
    val batchSize = 2000
    val updateindex = numFields / 2
    try {
      dataFrame.foreachPartition((iterator: Iterator[Row]) => {
        // Iterate over the partition and submit in batches
        val ps = broad_ps.value
        try {
          while (iterator.hasNext) {
            val row = iterator.next()
            var i = 0
            while (i < numFields) {
              if (i < updateindex) {
                // First half of the placeholders: the VALUES(...) part
                if (row.isNullAt(i)) {
                  ps.setNull(i + 1, nullTypes(i))
                } else {
                  setters(i).apply(ps, row, i, 0)
                }
              } else {
                // Second half: the UPDATE SET part, re-binding the same row values
                if (row.isNullAt(i - updateindex)) {
                  ps.setNull(i + 1, nullTypes(i - updateindex))
                } else {
                  setters(i - updateindex).apply(ps, row, i, updateindex)
                }
              }
              i = i + 1
            }
            ps.addBatch()
            rowCount += 1
            if (rowCount % batchSize == 0) {
              ps.executeBatch()
              rowCount = 0
            }
          }
          if (rowCount > 0) {
            ps.executeBatch()
          }
        } finally {
          ps.close()
        }
      })
      conn.commit()
    } catch {
      case e: Exception =>
        logError("Error in execution of insert. " + e.getMessage)
        conn.rollback()
        // insertError(connectionPool("OuCloud_ODS"),"insertOrUpdateToPgsql",e.getMessage)
    } finally {
      conn.close()
    }
  }
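The method relies on a few helpers that are not shown in this post: connectionPool(), getJdbcType() and makeSetter(). The sketch below shows one plausible shape for them, modeled on Spark's internal JdbcUtils; the names come from the code above, but these bodies and the connection settings are assumptions. Note also that broad_ps broadcasts a live PreparedStatement: JDBC connections and statements cannot be shipped to executors, so this pattern really only holds up when driver and executors share a JVM (local mode); on a cluster, the connection and statement would have to be created inside foreachPartition instead.

import java.sql.{Connection, DriverManager, PreparedStatement}
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._

  // Assumed helper: hands out a plain JDBC connection; a real pool (HikariCP, DBCP)
  // could sit behind the same signature. URL and credentials are placeholders.
  def connectionPool(): Connection =
    DriverManager.getConnection("jdbc:postgresql://localhost:5432/oucloud_ods", "ods_user", "ods_password")

  // Mirrors Spark's private JdbcUtils.getJdbcType: ask the dialect first,
  // then fall back to the common JDBC type mapping.
  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType =
    dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))

  // Variant of Spark's JdbcUtils setter with an extra offset: the same row value is
  // bound once for the VALUES(...) placeholders (offset = 0) and once for the
  // UPDATE SET placeholders (offset = number of columns). The Connection parameter
  // is only needed for array types, which this sketch omits.
  type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit

  def makeSetter(conn: Connection, dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType => (ps, row, pos, offset) => ps.setInt(pos + 1, row.getInt(pos - offset))
    case LongType    => (ps, row, pos, offset) => ps.setLong(pos + 1, row.getLong(pos - offset))
    case DoubleType  => (ps, row, pos, offset) => ps.setDouble(pos + 1, row.getDouble(pos - offset))
    case BooleanType => (ps, row, pos, offset) => ps.setBoolean(pos + 1, row.getBoolean(pos - offset))
    case StringType  => (ps, row, pos, offset) => ps.setString(pos + 1, row.getString(pos - offset))
    case _           => (ps, row, pos, offset) => ps.setObject(pos + 1, row.get(pos - offset))
  }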
