Notes on several solutions for writing data back to Oracle

Because we are on Spark 1.5.1, we ran into quite a few unexpected bugs. I am recording them here for reference.

First, the requirement: write Hive tables back into Oracle, and it has to be done through Spark SQL, so Sqoop is not an option (the cluster's big-data platform has no Sqoop component). The output types must be exact: whatever Oracle type a column had when the data was originally extracted, it must come back to Oracle with the same type and the same precision.
The difficulty is that on the Hive side of the platform, DATE columns were stored as string, and Hive strings carry no declared length.

1. First solution:

Since we are not allowed to access Hive's metastore directly, the idea is: read the source table with sqlContext.sql and take its schema, convert the DataFrame to an RDD, query Oracle's system table user_tab_cols for the target column types and lengths, rebuild the schema, and recombine it with the RDD into a new DataFrame.
Then write it out through the DataFrameWriter's write.jdbc method, adding
option("createTableColumnTypes", "name varchar(200)")
so that the table created on the Oracle side gets the right column types. In testing, this option does not work on Spark 1.5.1; it was only added in Spark 2.2.0.
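For reference, on Spark 2.2.0 or later the same idea is directly supported. Here is a minimal sketch; the URL, table, and column names are placeholders, not from the original job:

import java.util.Properties

val props = new Properties()
props.put("user", "****")
props.put("password", "****")
props.put("driver", "oracle.jdbc.driver.OracleDriver")
// Spark 2.2.0+ only: createTableColumnTypes controls the column types of the table
// that write.jdbc creates; the values must be valid Spark SQL DDL types
df.write
  .option("createTableColumnTypes", "NAME VARCHAR(200), ACCT_DATE DATE")
  .jdbc("jdbc:oracle:thin:@//host:1521/service", "TARGET_TABLE", props)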
The Spark 1.5.1 test code is as follows:

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

object ojdbcTest {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    //read the Hive source table; its schema will be rebuilt below
    var df = sqlContext.sql("select * from  ****.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:***/*****@//*****/*****",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();

    var string = new ArrayList[String]

    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())
        string.add(column_name.toString() + " " + data_type.toString())
      }

      if (data_type.equals("NUMBER")) {
        if (data_precision != null) {
          string.add(column_name.toString() + " " + data_type.toString() + s"(${data_precision.toString().toDouble.intValue()},${data_scale.toString().toDouble.intValue()})")
        } else {
          string.add(column_name.toString() + " " + data_type.toString())
        }

      }
      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))
        string.add(column_name.toString() + " " + data_type.toString() + s"(${data_length.toString().toDouble.intValue()})")
      }

    }
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name " + dataName + " column type " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = DateType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(DateType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//*******/***"
    val table = "test2"
    val user = "***"
    val password = "***"

    val url1="jdbc:oracle:thin:***/***@//***/***"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")

    val a = string.toString()
    val option = a.substring(1, a.length() - 1)
    println(option)

    //note: createTableColumnTypes is only honored on Spark 2.2.0+, it has no effect on 1.5.1
    df2.write.option("createTableColumnTypes", s"${option}").jdbc(url, table, connectionProperties)

    sc.stop()
  }
} 

The code is fairly rough; it is only a test class.

2. Second solution:

Given the situation above, the first method does not work on 1.5.1, so a new approach was tried:
override the three methods of the JdbcDialect class used for reading and writing. These are the hooks Spark SQL uses to map JDBC database types, so overriding them is enough for simple type conversions.

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

object ojdbcTest {



    def oracleInit(){

      val dialect:JdbcDialect= new JdbcDialect() {
        override def canHandle(url:String)={
          url.startsWith("jdbc:oracle");
        }
        //type mapping used when reading from Oracle; returning None keeps the default mapping
        override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
          None
        }
        //type mapping used when writing to Oracle
        override def getJDBCType(dt: DataType): Option[org.apache.spark.sql.jdbc.JdbcType] =

         dt match{
            case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
            case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
            case LongType    => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
            case FloatType   => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
            case DoubleType  => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
            case ByteType    => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
            case ShortType   => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
           case StringType  => Some(JdbcType("VARCHAR2(250)", java.sql.Types.VARCHAR))
            case DateType    => Some(JdbcType("DATE", java.sql.Types.DATE))
            case DecimalType.Unlimited => Some(JdbcType("NUMBER",java.sql.Types.NUMERIC))
            case _ => None
          }

      }
      JdbcDialects.registerDialect(dialect);
    }


  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    //read the Hive source table; its schema will be rebuilt below
    var df = sqlContext.sql("select * from  ****.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:****/****@//********/claimamdb",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();



    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())

      }


      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))

      }

    }
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name " + dataName + " column type " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = DateType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(DateType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//********/claimamdb"
    val table = "test2"
    val user = "****"
    val password = "****"

    val url1="jdbc:oracle:thin:****/****@//********/claimamdb"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")




    oracleInit()
    df2.write.jdbc(url, table, connectionProperties)

    sc.stop()



  }
}

This approach only handles simple type conversions. It cannot cover my case of turning Hive columns that were originally DATE but stored as string back into Oracle DATE, because even an overridden getJDBCType only receives the Spark data type as a parameter, so there is no way to tell which string column is actually a date. One could extend the Logging class and rewrite JdbcUtils instead, but that requires understanding the source code and is fairly involved.

3. Third solution

The code is the same as in the first solution.
Because I could not get the created table's column types to be exact (every string written to Oracle got no length and defaulted to 255), I switched to the deprecated createJDBCTable and insertIntoJDBC(url1, table, true) DataFrame methods (see the sketch after the quoted documentation below). It turns out insertIntoJDBC is buggy in this version. The official documentation says:

Save this DataFrame to a JDBC database at url under the table name table. Assumes the table already exists and has a compatible schema. If you pass true for overwrite, it will TRUNCATE the table before performing the INSERTs. 

The table must already exist on the database. It must have a schema that is compatible with the schema of this RDD; inserting the rows of the RDD in order via the simple statement INSERT INTO table VALUES (?, ?, ..., ?) should not fail.
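A minimal sketch of what was attempted with these deprecated Spark 1.5 DataFrame methods; df2, url1, and the table name are the same placeholders as in the first solution's code:

// both methods are deprecated since 1.4 in favour of write.jdbc, but still exist in 1.5.1
// create the target table from the DataFrame schema (allowExisting = false)
df2.createJDBCTable(url1, "test2", false)
// overwrite = true is documented to TRUNCATE the existing table before the INSERTs
df2.insertIntoJDBC(url1, "test2", true)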

It still failed with a "table already exists" error. After searching overseas forums, it turned out to be a known bug in this version.

So, after going through all of that and ruling out the methods above, how do we get our data in with exact types?

4. Fourth solution

Oracle's maximum VARCHAR2 length is 4000, so the plan is: override the dialect's getJDBCType so that every string column becomes VARCHAR2(4000), which guarantees nothing gets truncated; use a plain Oracle JDBC connection to run the exact CREATE TABLE statement for the target table; write the DataFrame into a temporary Oracle table (where all VARCHAR2 columns are 4000); and finally use an INSERT ... SELECT to move the data from the temporary table into the target table.

For the date columns, once they are identified from the system table metadata, they are cast to TimestampType; the overridden getJDBCType then maps TimestampType to Oracle's underlying DATE type, so dates are not truncated.

The code is as follows:

package test1

import org.apache.spark.{ SparkContext, SparkConf }
import org.apache.spark.sql._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import oracle.jdbc.driver.OracleDriver
import org.apache.spark.sql.types.StringType
import java.util.ArrayList
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
import org.apache.spark.sql.jdbc._
import java.sql.Types

import java.sql.Connection
import java.sql.DriverManager
object ojdbcTest {



      def oracleInit(){

        val dialect:JdbcDialect= new JdbcDialect() {
          override def canHandle(url:String)={
            url.startsWith("jdbc:oracle");
          }

//       override def getCatalystType(sqlType, typeName, size, md):Option[DataType]={
    //
    //
    //      }
          override def getJDBCType(dt:DataType):Option[org.apache.spark.sql.jdbc.JdbcType]=

            dt match{
              case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
              case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
              case LongType    => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
              case FloatType   => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
              case DoubleType  => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
              case ByteType    => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
              case ShortType   => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
              case StringType  => Some(JdbcType("VARCHAR2(4000)", java.sql.Types.VARCHAR))
              case DateType    => Some(JdbcType("DATE", java.sql.Types.DATE))
              case DecimalType.Unlimited => Some(JdbcType("NUMBER",java.sql.Types.NUMERIC))
              case TimestampType=> Some(JdbcType("DATE",java.sql.Types.DATE))
              case _ => None
            }

        }
         JdbcDialects.registerDialect(dialect);
      }

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("firstTry").setMaster("local");
    val sc = new SparkContext(conf);
    val sqlContext = new HiveContext(sc);

    //read the Hive source table; its schema will be rebuilt below
    var df = sqlContext.sql("select * from  ******.BL_E01_REJECTACCOUNT")
    val df1 = df.schema.toArray

    //val customSchema = sparkTargetDF.dtypes.map(x => x._1+" "+x._2).mkString(",").toUpperCase()
    val theJdbcDF = sqlContext.load("jdbc", Map(
      "url" -> "jdbc:oracle:thin:********/********//********/********",
      "dbtable" -> "( select column_name ,data_type,data_length,data_precision,data_scale from user_tab_cols where table_name ='BL_E01_REJECTACCOUNT' order by COLUMN_ID ) a ",
      "driver" -> "oracle.jdbc.driver.OracleDriver",
      "numPartitions" -> "5",
      "lowerBound" -> "0",
      "upperBound" -> "80000000"))

    val str = theJdbcDF.collect().toArray
    var dateArray = new ArrayBuffer[String]
    var stringArray = new ArrayBuffer[(String, Int)]

    var list = new ArrayList[org.apache.spark.sql.types.StructField]();

    var string = new ArrayList[String]

    for (j <- 0 until str.length) {
      var st = str(j)
      var column_name = st.get(0)
      var data_type = st.get(1)
      var data_length = st.get(2)
      var data_precision = st.get(3)
      var data_scale = st.get(4)
      println(column_name + ":" + data_type + ":" + data_length + ":" + data_precision + data_scale)

      if (data_type.equals("DATE")) {
        dateArray += (column_name.toString())
        string.add(column_name.toString() + " " + data_type.toString())
      }

      if (data_type.equals("NUMBER")) {
        if (data_precision != null) {
          string.add(column_name.toString() + " " + data_type.toString() + s"(${data_precision.toString().toDouble.intValue()},${data_scale.toString().toDouble.intValue()})")
        } else {
          string.add(column_name.toString() + " " + data_type.toString())
        }

      }
      if (data_type.equals("VARCHAR2")) {
        stringArray += ((column_name.toString(), data_length.toString().toDouble.intValue()))
        string.add(column_name.toString() + " " + data_type.toString() + s"(${data_length.toString().toDouble.intValue()})")
      }

    }
    for (i <- 0 until df1.length) {
      var b = df1(i)
      var dataName = b.name
      var dataType = b.dataType
      //          println("column name " + dataName + " column type " + dataType)
      if (dateArray.exists(p => p.equalsIgnoreCase(s"${dataName}"))) {
        dataType = TimestampType

      }
      var structType = DataTypes.createStructField(dataName, dataType, true)

      list.add(structType)
    }

    val schema = DataTypes.createStructType(list)

    if (dateArray.length > 0) {

      for (m <- 0 until dateArray.length) {
        var mm = dateArray(m).toString()
        println("mm:" + mm)
        var df5 = df.withColumn(s"$mm", df(s"$mm").cast(TimestampType))
        df = df5
      }
    }

    val rdd = df.toJavaRDD
    val df2 = sqlContext.createDataFrame(rdd, schema);

    df2.printSchema()

    val url = "jdbc:oracle:thin:@//********/********"
    val table = "test2"
    val table1="test3"
    val user = "********"
    val password = "#EDC5tgb"

    val url1 = "jdbc:oracle:thin:********/********//********/********"
    val connectionProperties = new Properties()
    connectionProperties.put("user", user)
    connectionProperties.put("password", password)
    connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")

    val a = string.toString()
    val option = a.substring(1, a.length() - 1)
    println(option)

    oracleInit()

    createJdbcTable(option,table)

    println("create table is finish!")

    df2.write.jdbc(url, table1, connectionProperties)

    insertTable(table,table1)
    println("已导入目标表!")


    sc.stop()
    //option("createTableColumnTypes", "CLAIMNO VARCHAR2(300), comments VARCHAR(1024)")
    //df2.select(df2("POLICYNO")).write.option("createTableColumnTypes", "CLAIMNO VARCHAR2(200)")
    //.jdbc(url, table, connectionProperties)
  }

  def createJdbcTable(option:String,table:String) = {

    val url = "jdbc:oracle:thin:@//********/********"
    //driver class name
    val driver = "oracle.jdbc.driver.OracleDriver"
    //username
    val username = "********"
    //password
    val password = "****"
    //initialize the connection
    var connection: Connection = null
    try {
      //register the driver
      Class.forName(driver)
      //get a connection
      connection = DriverManager.getConnection(url, username, password)
      val statement = connection.createStatement
      //build and execute the CREATE TABLE statement
      val sql = s"""
        create table ${table}
(
 ${option}
)
        """
      statement.executeUpdate(sql)
    } catch { case e: Exception => e.printStackTrace }
    finally {
      //close the connection and release resources
      if (connection != null) connection.close()
    }
  }

  def insertTable(table:String,table1:String){
    val url = "jdbc:oracle:thin:@//********/********"
    //driver class name
    val driver = "oracle.jdbc.driver.OracleDriver"
    //username
    val username = "********"
    //password
    val password = "*********"
    //initialize the connection
    var connection: Connection = null
    try {
      //register the driver
      Class.forName(driver)
      //get a connection
      connection = DriverManager.getConnection(url, username, password)
      val statement = connection.createStatement
      //build and execute the INSERT ... SELECT statement
      val sql = s"""
        insert into ${table} select * from  ${table1}
        """
      statement.executeUpdate(sql)
    } catch { case e: Exception => e.printStackTrace }
    finally {
      //close the connection and release resources
      if (connection != null) connection.close()
    }

  }
}

There are also plenty of version-specific traps. For example, with
write.mode().jdbc()
no matter which mode you pass (Append, Ignore, ...), the target table still gets overwritten. A look at the source shows the save mode is effectively hard-coded to Overwrite (see the sketch after the link below).
For details on this issue, see:

https://www.2cto.com/net/201609/551130.html
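A minimal sketch of the pattern in question, using the same names as in the code above; the always-overwrite behaviour is the author's observation on 1.5.1, not the documented contract:

import org.apache.spark.sql.SaveMode

// Append is requested, but on this Spark version the target table was still overwritten
df2.write.mode(SaveMode.Append).jdbc(url, table, connectionProperties)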

Hope this saves you some detours!

Addendum: since creating tables is not allowed in the production environment, we ended up using

val connectionProperties = new Properties()
connectionProperties.put("user", user)
connectionProperties.put("password", password)
connectionProperties.put("driver", "oracle.jdbc.driver.OracleDriver")
JdbcUtils.saveTable(df, url, table, connectionProperties)

to insert the data into the already existing table; this has been verified to work in practice.
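For completeness, a minimal sketch of that call together with the import it needs, assuming JdbcUtils is reachable at this package path in your 1.5.x build (as it was for the author); the URL and table name are placeholders:

import java.util.Properties
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

val props = new Properties()
props.put("user", "****")
props.put("password", "****")
props.put("driver", "oracle.jdbc.driver.OracleDriver")
// inserts df's rows into the already-existing table; no CREATE TABLE is issued
JdbcUtils.saveTable(df, "jdbc:oracle:thin:@//host:1521/service", "TARGET_TABLE", props)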
