Performance optimization for inserting Spark DataFrame data into MySQL (source code analysis)

"Spark" here covers Spark Core, Spark SQL, and Spark Streaming; the operations are essentially the same across all three. The code shown below comes from real projects.

Method 1: Write the entire DataFrame to MySQL in one go (the DataFrame's schema must match the column names defined in the MySQL table)

 

            Dataset<Row> resultDF = spark.sql("select hphm,clpp,clys,tgsj,kkbh from t_cltgxx where id in (" + id.split("_")[0] + "," + id.split("_")[1] + ")");
            resultDF.show();
            Dataset<Row> resultDF2 = resultDF.withColumn("jsbh", functions.lit(new Date().getTime()))
                    .withColumn("create_time", functions.lit(new Timestamp(new Date().getTime())));
            resultDF2.show();
            resultDF2.write()
                    .format("jdbc")
                    .option("url","jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8")
                    .option("dbtable","t_tpc_result")
                    .option("user","root")
                    .option("password","123456")
                    .mode(SaveMode.Append)
                    .save();


MySQL table structure: (screenshot of the table definition omitted)

 

Method 2: Call foreach/foreachPartition on the RDD, then create a connection -> prepare the SQL -> execute -> free the connection. The advantage of this approach is that the data can be processed as needed before being written to the table, and the whole DataFrame is not necessarily required. The Spark Core example below uses foreach; a batched foreachPartition variant is sketched right after it.

RDD in Spark Core:

resultRDD.foreach(new VoidFunction<String>() {
            @Override
            public void call(String s) throws Exception {
                String kkbh = s.split("&")[0];
                String hphm = s.split("&")[1];
                long jsbh = System.currentTimeMillis();
                Connection conn = JdbcUtils.getConnection();
                String sql = "insert into t_txc_result (JSBH,HPHM,KKBH,CREATE_TIME) values(?,?,?,?)";
                PreparedStatement psmt = conn.prepareStatement(sql);
                psmt.setString(1,jsbh+"");
                psmt.setString(2,hphm);
                psmt.setString(3,kkbh);
                psmt.setTimestamp(4,new Timestamp(jsbh));
                psmt.executeUpdate();
                JdbcUtils.free(psmt,conn);
                System.out.println("mysql insert : kkbh = " + kkbh + ", hphm = "+ hphm);
            }
});
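
For larger RDDs, opening a connection for every record (as in the foreach above) dominates the runtime. Below is a minimal foreachPartition sketch of the same insert: it reuses resultRDD, the t_txc_result table, and the JdbcUtils helper from this example (imports omitted, matching the snippets above), opens one connection and one PreparedStatement per partition, and flushes a batch every 500 rows. The batch size of 500 is an arbitrary choice, and batching pays off most once rewriteBatchedStatements is enabled, as discussed later in this post.

resultRDD.foreachPartition(new VoidFunction<Iterator<String>>() {
            @Override
            public void call(Iterator<String> it) throws Exception {
                // One connection and one prepared statement per partition, not per record.
                Connection conn = JdbcUtils.getConnection();
                String sql = "insert into t_txc_result (JSBH,HPHM,KKBH,CREATE_TIME) values(?,?,?,?)";
                PreparedStatement psmt = conn.prepareStatement(sql);
                try {
                    int count = 0;
                    while (it.hasNext()) {
                        String s = it.next();
                        String kkbh = s.split("&")[0];
                        String hphm = s.split("&")[1];
                        long jsbh = System.currentTimeMillis();
                        psmt.setString(1, jsbh + "");
                        psmt.setString(2, hphm);
                        psmt.setString(3, kkbh);
                        psmt.setTimestamp(4, new Timestamp(jsbh));
                        psmt.addBatch();
                        if (++count % 500 == 0) {
                            psmt.executeBatch(); // flush a batch of 500 rows
                        }
                    }
                    psmt.executeBatch(); // flush the remaining rows
                } finally {
                    JdbcUtils.free(psmt, conn);
                }
            }
});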


DStream in Spark Streaming:

        resultDStream.foreachRDD(new VoidFunction<JavaRDD<String>>() {
            @Override
            public void call(JavaRDD<String> stringJavaRDD) throws Exception {
                stringJavaRDD.foreachPartition(new VoidFunction<Iterator<String>>() {
                    @Override
                    public void call(Iterator<String> stringIterator) throws Exception {
                        // One connection and one prepared statement per partition.
                        Connection conn = JdbcUtils.getConnection();
                        String sql = "insert into t_scrc_result (JSBH,HPHM,KKBH,TGSJ,CREATE_TIME) values(?,?,?,?,?)";
                        PreparedStatement pstmt = conn.prepareStatement(sql);
                        // HH = 24-hour clock; the original hh would misparse afternoon times.
                        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
                        while (stringIterator.hasNext()) {
                            String data = stringIterator.next();
                            String[] fields = data.split(",");
                            String hphm = fields[0];
                            String clpp = fields[1];
                            String clys = fields[2];
                            String tgsj = fields[3];
                            String kkbh = fields[4];
                            Date tgsj_date = sdf.parse(tgsj);
                            long jsbh = System.currentTimeMillis();
                            pstmt.setString(1, jsbh + "_streaming");
                            pstmt.setString(2, hphm);
                            pstmt.setString(3, kkbh);
                            pstmt.setTimestamp(4, new Timestamp(tgsj_date.getTime()));
                            pstmt.setTimestamp(5, new Timestamp(jsbh));
                            pstmt.executeUpdate();
                        }
                        JdbcUtils.free(pstmt, conn);
                    }
                });
            }
        });


The getConnection() and free() used here are wrapper functions; their source code is included below for reference:

package utils;
 
 
import java.sql.*;
 
public class JdbcUtils {
    private static String url = "jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8";
    private static String user = "root";
    private static String pwd = "123456";
 
    private JdbcUtils() {
    }
 
    // 1. Register the JDBC driver (com.mysql.jdbc.Driver)
 
    static {
        try {
            Class.forName("com.mysql.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            System.out.println("Failed to load the database driver!");
        }
    }
 
    // 2. Open a connection
    public static Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url, user, pwd);
    }
 
    // 3. Close resources
    public static void free(Statement stmt, Connection conn) {
        try {
            if (stmt != null) {
                stmt.close();
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (conn != null) {
                    conn.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
 
    // 4. Close resources, including a ResultSet
    public static void free2(ResultSet rs, Statement stmt, Connection conn) {
        try {
            if (rs != null)
                rs.close();
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (stmt != null)
                    stmt.close();
            } catch (SQLException e) {
                e.printStackTrace();
            } finally {
                try {
                    if (conn != null)
                        conn.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}


Spark's built-in API for writing DataFrame data to a database:

Inserting the results into MySQL
Spark already wraps this up for us, so the code for inserting into MySQL is very simple; just call Spark's API directly:

df.write.mode(SaveMode.Append).format("jdbc")
       .option("url",getValueOfPrefix(prefix,"url"))  // JDBC connection URL
       .option("isolationLevel","NONE")  // do not use a transaction
       .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE,150)  // set the JDBC batch size
       .option("dbtable", tableName)  // target table
       .option("user",getValueOfPrefix(prefix,"username"))  // database user
       .option("password",getValueOfPrefix(prefix,"password"))  // database password
       .save()


The code above ran rather slowly: inserting a few thousand records took roughly two minutes, so I looked for information online. The reason is simple: batched insertion was not actually taking effect. Although the code sets a batch size, the MySQL JDBC driver does not rewrite batched statements into efficient multi-row inserts by default; an extra parameter has to be appended to the connection URL:

rewriteBatchedStatements=true  // enable batch rewriting in the MySQL driver

db.url = "jdbc:mysql://localhost:3306/User?rewriteBatchedStatements=true";

With this parameter set, inserting a few thousand records completes almost instantly. A combined sketch is shown below.
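
For the DataFrame write from Method 1, the two settings can be combined directly in the write options. The following is a minimal sketch rather than a drop-in replacement: it reuses resultDF2 and the connection details from Method 1, appends rewriteBatchedStatements=true to the URL, and sets a batch size of 1000 (Spark's default for the JDBC data source) purely for illustration.

            resultDF2.write()
                    .format("jdbc")
                    // rewriteBatchedStatements=true lets the MySQL driver send each batch as multi-row inserts
                    .option("url", "jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8&rewriteBatchedStatements=true")
                    .option("dbtable", "t_tpc_result")
                    .option("user", "root")
                    .option("password", "123456")
                    .option("batchsize", 1000)        // rows per addBatch/executeBatch cycle
                    .option("isolationLevel", "NONE") // skip the per-partition transaction
                    .mode(SaveMode.Append)
                    .save();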


Source code analysis:

Summary of the source code walkthrough
First, write is called on the DataFrame; it returns an org.apache.spark.sql.DataFrameWriter object. All of this object's option-setting methods use method chaining (after setting a property, they return this).

  def write: DataFrameWriter[T] = {
    if (isStreaming) {
      logicalPlan.failAnalysis(
        "'write' can not be called on streaming Dataset/DataFrame")
    }
    new DataFrameWriter[T](this)
  }




After the write options are set, save() is called to carry out the save. Inside save(), an org.apache.spark.sql.execution.datasources.DataSource object is created, and the actual saving is done by calling the DataSource object's write(mode, df) method.

  def save(): Unit = {
    assertNotBucketed("save")
    val dataSource = DataSource(
      df.sparkSession,
      className = source,
      partitionColumns = partitioningColumns.getOrElse(Nil),
      bucketSpec = getBucketSpec,
      options = extraOptions.toMap)

    dataSource.write(mode, df)
  }




The write method does two things: it decides whether the result is saved to a database or to a file system. Here we follow the database path.
  def write(mode: SaveMode, data: DataFrame): Unit = {
    if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
      throw new AnalysisException("Cannot save interval data type into external storage.")
    }

    providingClass.newInstance() match {
      case dataSource: CreatableRelationProvider =>
        dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data)  // save to a database
      case format: FileFormat =>
        writeInFileFormat(format, mode, data)
      case _ =>
        sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
    }
  }





org.apache.spark.sql.sources.CreatableRelationProvider#createRelation is an interface method; for JDBC its implementation is org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider#createRelation.

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val url = jdbcOptions.url
    val table = jdbcOptions.table
    val createTableOptions = jdbcOptions.createTableOptions
    val isTruncate = jdbcOptions.isTruncate

    val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, url, table)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, table)
              saveTable(df, url, table, jdbcOptions)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, table)
              createTable(df.schema, url, table, createTableOptions, conn)
              saveTable(df, url, table, jdbcOptions)
            }

          case SaveMode.Append =>
            saveTable(df, url, table, jdbcOptions)

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(df.schema, url, table, createTableOptions, conn)
        saveTable(df, url, table, jdbcOptions)
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }





Finally, the data is inserted by the org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils#saveTable function.

  def saveTable(
      df: DataFrame,
      url: String,
      table: String,
      options: JDBCOptions) {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      getJdbcType(field.dataType, dialect).jdbcNullType
    }

    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    val isolationLevel = options.isolationLevel
    df.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
    )
  }





As you can see, the DataFrame calls foreachPartition to perform the insert partition by partition; the actual insertion is done inside savePartition.

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int],
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _)).toArray
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

 

Summary:

When using an open-source framework, reading the source code is the most direct and effective way to track down bugs, make optimizations, or understand how things work under the hood.

This post records the path I took to find this optimization.
