Reading from and writing to a JDBC database with Spark

Saving a MySQL table with Spark's native save() runs into the following problem.
1. With the overwrite save mode, Spark drops the target MySQL table and creates a new one. Every string column then becomes LONGTEXT and the table has no primary key, which is clearly unacceptable.
So the JDBC save code below was written instead. The MySQL table is still read with Spark's native load(), which yields a DataFrame; the DataFrame is then saved back using plain JDBC methods.
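
For reference, a minimal sketch of the native read and of the problematic native overwrite write; the URL, table names, and credentials here are placeholders:

import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("jdbc-demo").getOrCreate()

// Native read: produces a DataFrame backed by the MySQL table.
val df = spark.read.format("jdbc")
  .option("url", "jdbc:mysql://192.168.10.1:3306/bigdata")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("dbtable", "source_table")
  .option("user", "root")
  .option("password", "***")
  .load()

// Native overwrite write: drops and recreates the target table, losing the
// primary key and mapping string columns to LONGTEXT.
df.write.mode(SaveMode.Overwrite).format("jdbc")
  .option("url", "jdbc:mysql://192.168.10.1:3306/bigdata")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("dbtable", "target_table")
  .option("user", "root")
  .option("password", "***")
  .save()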

Note: when pulling the values out of each DataFrame row, retrieve them as Object so that setObject can bind any column type without an explicit conversion; otherwise type-conversion errors may occur at insert time. The values must also follow the same column order used to build the SQL (getValuesMap(...).values does not guarantee that order), so each value is fetched by column name:

val v = cols.map(c => row.getAs[Object](c))
statement.setObject(i + 1, v(i))
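
To make the parameter indexing in the upsert code below easier to follow, here is a minimal standalone sketch (the demo table, its columns, and the connection settings are hypothetical): positions 1..n fill the VALUES placeholders and positions n+1..2n fill the ON DUPLICATE KEY UPDATE placeholders.

import java.sql.DriverManager

// Hypothetical table demo(name, age) with a primary key on name.
val conn = DriverManager.getConnection("jdbc:mysql://192.168.10.1:3306/bigdata", "root", "***")
val sql  = "insert into demo(name, age) values(?, ?) on duplicate key update name = ?, age = ?"
val stmt = conn.prepareStatement(sql)

val values: Seq[Object] = Seq("tom", Int.box(21))
for (i <- values.indices) {
  stmt.setObject(i + 1, values(i))                 // positions 1..n    -> values(?, ?)
  stmt.setObject(i + 1 + values.length, values(i)) // positions n+1..2n -> name = ?, age = ?
}
stmt.addBatch()
stmt.executeBatch()
conn.close()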

The save modes currently supported for MySQL are overwrite, append, and update (the last one performs an upsert); a usage sketch is given after the code.

Overall structure

package com.mas.bgdt.dmp.framework.handler.impl

import java.sql.DriverManager

import org.apache.spark.sql.{DataFrame, Row}
import com.mas.bgdt.dmp.framework.handler.model.Connection

object JdbcHandler {

  //update: execute an arbitrary DDL/DML statement
  def update(connection: Connection, sql: String): Int = {
    val conn = getConnection(connection)
    try conn.prepareStatement(sql).executeUpdate()
    finally conn.close()
  }
  //read: load the table named by source (passed as dbtable) over the JDBC URL in connection.url,
  //e.g. jdbc:mysql://192.168.10.1:3306/bigdata, and return it as a DataFrame
  def read(source: String, connection: Connection): DataFrame = {
    connection.spark.read.format("jdbc").options(Map(
      "url" -> connection.url,
      "driver" -> connection.driver,
      "dbtable" -> source,
      "user" -> connection.user,
      "password" -> connection.password)).load()
  }
  //save: dispatch on saveMode; tabFormat and partitionBy are reserved and currently unused
  def save(sourceDf: DataFrame, numPartitions: Integer, tabFormat: String = null,
           partitionBy: String = null, saveMode: String, targetTable: String,
           connection: Connection) = {
    saveMode match {
      case "overwrite" =>
        //truncate the target table, then insert
        truncate(connection, targetTable)
        insert(sourceDf, numPartitions, targetTable, connection)
      case "append" =>
        //insert only
        insert(sourceDf, numPartitions, targetTable, connection)
      case "update" =>
        //insert or update (upsert)
        upsert(sourceDf, numPartitions, targetTable, connection)
      case _ => throw new IllegalArgumentException(
        "saveMode must be one of overwrite, append, update")
    }
  }
  //open a plain JDBC connection from the connection settings
  def getConnection(connection: Connection) = {
    DriverManager.getConnection(connection.url, connection.user, connection.password)
  }
  //truncate the target table
  def truncate(connection: Connection, targetTable: String): Unit = {
    val conn = getConnection(connection)
    try conn.prepareStatement(s"truncate table $targetTable").executeUpdate()
    finally conn.close()
  }
  //upsert: insert each row, or update it when the primary key already exists
  def upsert(sourceDf: DataFrame, numPartitions: Integer, targetTable: String, connection: Connection) = {
    //build the SQL with placeholders, e.g. insert into aa(name,age) values(?,?) on duplicate key update name=?,age=?
    val upsertSqls = upsertSql(sourceDf, targetTable)
    //column names, in the same order used to build the SQL
    val cols = sourceDf.columns.toSeq
    //write the values partition by partition
    sourceDf.repartition(numPartitions).foreachPartition { it: Iterator[Row] =>
      //one connection and one batched statement per partition
      val conn = getConnection(connection)
      val statement = conn.prepareStatement(upsertSqls)
      //disable auto-commit so the partition is committed as a single transaction
      conn.setAutoCommit(false)
      var count = 0
      it.foreach { row =>
        //extract the row's values as Object, in column order
        val v = cols.map(c => row.getAs[Object](c))
        //bind each value twice: positions 1..n for VALUES, n+1..2n for ON DUPLICATE KEY UPDATE
        for (i <- v.indices) {
          statement.setObject(i + 1, v(i))
          statement.setObject(i + 1 + v.length, v(i))
        }
        statement.addBatch()
        count += 1
        //flush the batch periodically rather than once per row
        if (count % 1000 == 0) statement.executeBatch()
      }
      //flush any remaining rows, then commit and close
      statement.executeBatch()
      conn.commit()
      conn.close()
    }
  }
  //build the upsert SQL, e.g. insert into test(a,b,c) values(?,?,?) on duplicate key update a=?,b=?,c=?
  def upsertSql(sourceDf: DataFrame, targetTable: String): String = {
    val cols = sourceDf.columns
    val names = cols.mkString(",")
    val placeholders = cols.map(_ => "?").mkString(",")
    val assignments = cols.map(c => s"$c=?").mkString(",")
    s"insert into $targetTable($names) values($placeholders) on duplicate key update $assignments"
  }
  //insert: plain batched insert of every row
  def insert(sourceDf: DataFrame, numPartitions: Integer, targetTable: String, connection: Connection) = {
    //build the SQL with placeholders, e.g. insert into aa(name,age) values(?,?)
    val insertSqls = insertSql(sourceDf, targetTable)
    //column names, in the same order used to build the SQL
    val cols = sourceDf.columns.toSeq
    //write the values partition by partition
    sourceDf.repartition(numPartitions).foreachPartition { it: Iterator[Row] =>
      //one connection and one batched statement per partition
      val conn = getConnection(connection)
      val statement = conn.prepareStatement(insertSqls)
      //disable auto-commit so the partition is committed as a single transaction
      conn.setAutoCommit(false)
      var count = 0
      it.foreach { row =>
        //extract the row's values as Object, in column order
        val v = cols.map(c => row.getAs[Object](c))
        for (i <- v.indices) {
          statement.setObject(i + 1, v(i))
        }
        statement.addBatch()
        count += 1
        //flush the batch periodically rather than once per row
        if (count % 1000 == 0) statement.executeBatch()
      }
      //flush any remaining rows, then commit and close
      statement.executeBatch()
      conn.commit()
      conn.close()
    }
  }
  //build the insert SQL, e.g. insert into test(a,b,c) values(?,?,?)
  def insertSql(sourceDf: DataFrame, targetTable: String): String = {
    val cols = sourceDf.columns
    val names = cols.mkString(",")
    val placeholders = cols.map(_ => "?").mkString(",")
    s"insert into $targetTable($names) values($placeholders)"
  }
}
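
A usage sketch, assuming the Connection model in com.mas.bgdt.dmp.framework.handler.model simply bundles the SparkSession with the JDBC settings; the case class shape, table names, URL, and credentials below are assumptions for illustration:

import org.apache.spark.sql.SparkSession

// Assumed shape of the Connection model, inferred from the fields JdbcHandler uses.
case class Connection(spark: SparkSession,
                      url: String,
                      driver: String,
                      user: String,
                      password: String)

val spark = SparkSession.builder().appName("jdbc-handler-demo").getOrCreate()
val conn = Connection(spark, "jdbc:mysql://192.168.10.1:3306/bigdata",
                      "com.mysql.jdbc.Driver", "root", "***")

// Read the source table into a DataFrame, then write it back with one of the three modes.
val df = JdbcHandler.read("source_table", conn)
JdbcHandler.save(df, numPartitions = 4, saveMode = "update",
                 targetTable = "target_table", connection = conn)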
