Method 1: Write the entire DataFrame to MySQL in one go (the DataFrame's schema must match the column names defined in the MySQL table)
import java.sql.Timestamp;
import java.util.Date;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.functions;

// Query the source rows, then add the jsbh and create_time columns before writing.
Dataset<Row> resultDF = spark.sql("select hphm,clpp,clys,tgsj,kkbh from t_cltgxx where id in ("
        + id.split("_")[0] + "," + id.split("_")[1] + ")");
resultDF.show();

Dataset<Row> resultDF2 = resultDF.withColumn("jsbh", functions.lit(new Date().getTime()))
        .withColumn("create_time", functions.lit(new Timestamp(new Date().getTime())));
resultDF2.show();

resultDF2.write()
        .format("jdbc")
        .option("url", "jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8")
        .option("dbtable", "t_tpc_result")
        .option("user", "root")
        .option("password", "123456")
        .mode(SaveMode.Append)
        .save();
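If the DataFrame's column names do not line up with the table's columns, they can be renamed before the write. A minimal sketch, assuming (hypothetically) that the table uses upper-case column names:
// Rename DataFrame columns to match the target table's column names (placeholder names).
Dataset<Row> aligned = resultDF2
        .withColumnRenamed("hphm", "HPHM")
        .withColumnRenamed("kkbh", "KKBH");
// aligned can then be written with the same .write().format("jdbc")... chain as above.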
Method 2: Call foreach/foreachPartition on the RDD, then open a connection -> prepare the SQL -> execute -> free the connection. The advantage of this approach is that the data can be processed as needed before being written to the table; the whole DataFrame is not necessarily required.
An RDD in Spark Core:
resultRDD.foreach(new VoidFunction<String>() {
    @Override
    public void call(String s) throws Exception {
        // Each record is "kkbh&hphm"; open a connection, insert one row, then free it.
        String kkbh = s.split("&")[0];
        String hphm = s.split("&")[1];
        long jsbh = System.currentTimeMillis();
        Connection conn = JdbcUtils.getConnection();
        String sql = "insert into t_txc_result (JSBH,HPHM,KKBH,CREATE_TIME) values(?,?,?,?)";
        PreparedStatement psmt = conn.prepareStatement(sql);
        psmt.setString(1, jsbh + "");
        psmt.setString(2, hphm);
        psmt.setString(3, kkbh);
        psmt.setTimestamp(4, new Timestamp(jsbh));
        psmt.executeUpdate();
        JdbcUtils.free(psmt, conn);
        System.out.println("mysql insert : kkbh = " + kkbh + ", hphm = " + hphm);
    }
});
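Opening a new connection for every record is expensive. As noted above, foreachPartition can be used on the RDD instead, so that a single connection and prepared statement serve a whole partition. A minimal sketch along the same lines as the code above (same assumed "kkbh&hphm" record format and JdbcUtils helper):
resultRDD.foreachPartition(new VoidFunction<Iterator<String>>() {
    @Override
    public void call(Iterator<String> it) throws Exception {
        // One connection and one prepared statement per partition, not per record.
        Connection conn = JdbcUtils.getConnection();
        PreparedStatement psmt = conn.prepareStatement(
                "insert into t_txc_result (JSBH,HPHM,KKBH,CREATE_TIME) values(?,?,?,?)");
        while (it.hasNext()) {
            String s = it.next();
            String kkbh = s.split("&")[0];
            String hphm = s.split("&")[1];
            long jsbh = System.currentTimeMillis();
            psmt.setString(1, String.valueOf(jsbh));
            psmt.setString(2, hphm);
            psmt.setString(3, kkbh);
            psmt.setTimestamp(4, new Timestamp(jsbh));
            psmt.executeUpdate();
        }
        JdbcUtils.free(psmt, conn);
    }
});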
A DStream in Spark Streaming:
resultDStream.foreachRDD(new VoidFunction<JavaRDD<String>>() {
    @Override
    public void call(JavaRDD<String> stringJavaRDD) throws Exception {
        stringJavaRDD.foreachPartition(new VoidFunction<Iterator<String>>() {
            @Override
            public void call(Iterator<String> stringIterator) throws Exception {
                // One connection and one prepared statement per partition.
                Connection conn = JdbcUtils.getConnection();
                String sql = "insert into t_scrc_result (JSBH,HPHM,KKBH,TGSJ,CREATE_TIME) values(?,?,?,?,?)";
                PreparedStatement pstmt = conn.prepareStatement(sql);
                // HH = 24-hour clock for the pass-time (tgsj) field.
                SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
                while (stringIterator.hasNext()) {
                    // Each record is "hphm,clpp,clys,tgsj,kkbh".
                    String data = stringIterator.next();
                    String[] fields = data.split(",");
                    String hphm = fields[0];
                    String clpp = fields[1];
                    String clys = fields[2];
                    String tgsj = fields[3];
                    String kkbh = fields[4];
                    Date tgsj_date = sdf.parse(tgsj);
                    long jsbh = System.currentTimeMillis();
                    pstmt.setString(1, jsbh + "_streaming");
                    pstmt.setString(2, hphm);
                    pstmt.setString(3, kkbh);
                    pstmt.setTimestamp(4, new Timestamp(tgsj_date.getTime()));
                    pstmt.setTimestamp(5, new Timestamp(jsbh));
                    pstmt.executeUpdate();
                }
                JdbcUtils.free(pstmt, conn);
            }
        });
    }
});
The getConnection() and free() used here are hand-rolled helper functions; their source is included below for reference:
package utils;

import java.sql.*;

public class JdbcUtils {

    private static String url = "jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8";
    private static String user = "root";
    private static String pwd = "123456";

    private JdbcUtils() {
    }

    // 1. Register the JDBC driver
    static {
        try {
            Class.forName("com.mysql.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            System.out.println("Failed to load the database driver!");
        }
    }

    // 2. Obtain a connection
    public static Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url, user, pwd);
    }

    // 3. Release resources
    public static void free(Statement stmt, Connection conn) {
        try {
            if (stmt != null) {
                stmt.close();
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (conn != null) {
                    conn.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    // 4. Release resources, including a ResultSet
    public static void free2(ResultSet rs, Statement stmt, Connection conn) {
        try {
            if (rs != null) {
                rs.close();
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                if (stmt != null) {
                    stmt.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            } finally {
                try {
                    if (conn != null) {
                        conn.close();
                    }
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}
Spark's built-in API for writing a DataFrame to a database:
Inserting the results into MySQL
Spark wraps all of this for us; the code to insert into MySQL is very simple, just call Spark's API directly:
df.write.mode(SaveMode.Append).format("jdbc")
  .option("url", getValueOfPrefix(prefix, "url"))             // JDBC connection URL
  .option("isolationLevel", "NONE")                           // do not open a transaction
  .option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, 150)            // batch size
  .option("dbtable", tableName)                               // target table
  .option("user", getValueOfPrefix(prefix, "username"))       // database user
  .option("password", getValueOfPrefix(prefix, "password"))   // database password
  .save()
The code above ran rather slowly: inserting a few thousand records took about two minutes. After digging around online, the reason turned out to be simple: batch inserts were never really used. Even though the code sets a batch size, batching is not enabled on the driver side unless an extra parameter is appended to the connection URL:
rewriteBatchedStatements=true   // enable batched statement rewriting
db.url = "jdbc:mysql://localhost:3306/User?rewriteBatchedStatements=true";
With this parameter in place, inserting a few thousand records finishes in a matter of seconds.
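Applied to the Spark JDBC write shown earlier, the parameter simply goes onto the connection URL passed to option("url", ...). A sketch in the Java style of method 1 (host, credentials and table reused from above; the batch size is an assumption):
resultDF2.write()
        .format("jdbc")
        // rewriteBatchedStatements=true lets the MySQL driver rewrite the JDBC batches
        // produced by Spark into real multi-row inserts.
        .option("url", "jdbc:mysql://lin01.cniao5.com:3306/traffic?characterEncoding=UTF-8&rewriteBatchedStatements=true")
        .option("batchsize", 1000)          // rows per JDBC batch (Spark's default is 1000)
        .option("isolationLevel", "NONE")   // skip the per-partition transaction
        .option("dbtable", "t_tpc_result")
        .option("user", "root")
        .option("password", "123456")
        .mode(SaveMode.Append)
        .save();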
Source code walkthrough:
The DataFrame first calls write, which returns an org.apache.spark.sql.DataFrameWriter object. All of the writer's option-setting methods use method chaining: each setter returns this after the property is set.
def write: DataFrameWriter[T] = {
  if (isStreaming) {
    logicalPlan.failAnalysis(
      "'write' can not be called on streaming Dataset/DataFrame")
  }
  new DataFrameWriter[T](this)
}
Once the write options have been set, save() is called to actually persist the result. Inside save, an org.apache.spark.sql.execution.datasources.DataSource object is created, and the data is written by calling its write(mode, df) method.
def save(): Unit = {
  assertNotBucketed("save")
  val dataSource = DataSource(
    df.sparkSession,
    className = source,
    partitionColumns = partitioningColumns.getOrElse(Nil),
    bucketSpec = getBucketSpec,
    options = extraOptions.toMap)
  dataSource.write(mode, df)
}
The write method does two things: it decides whether the result is saved to a database or to a file system. Here we follow the database path.
def write(mode: SaveMode, data: DataFrame): Unit = {
  if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
    throw new AnalysisException("Cannot save interval data type into external storage.")
  }
  providingClass.newInstance() match {
    case dataSource: CreatableRelationProvider =>
      dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) // save to the database
    case format: FileFormat =>
      writeInFileFormat(format, mode, data)
    case _ =>
      sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
  }
}
org.apache.spark.sql.sources.CreatableRelationProvider#createRelation is an interface method; its implementation is org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider#createRelation.
override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    df: DataFrame): BaseRelation = {
  val jdbcOptions = new JDBCOptions(parameters)
  val url = jdbcOptions.url
  val table = jdbcOptions.table
  val createTableOptions = jdbcOptions.createTableOptions
  val isTruncate = jdbcOptions.isTruncate

  val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
  try {
    val tableExists = JdbcUtils.tableExists(conn, url, table)
    if (tableExists) {
      mode match {
        case SaveMode.Overwrite =>
          if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
            // In this case, we should truncate table and then load.
            truncateTable(conn, table)
            saveTable(df, url, table, jdbcOptions)
          } else {
            // Otherwise, do not truncate the table, instead drop and recreate it
            dropTable(conn, table)
            createTable(df.schema, url, table, createTableOptions, conn)
            saveTable(df, url, table, jdbcOptions)
          }

        case SaveMode.Append =>
          saveTable(df, url, table, jdbcOptions)

        case SaveMode.ErrorIfExists =>
          throw new AnalysisException(
            s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

        case SaveMode.Ignore =>
          // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
          // to not save the contents of the DataFrame and to not change the existing data.
          // Therefore, it is okay to do nothing here and then just return the relation below.
      }
    } else {
      createTable(df.schema, url, table, createTableOptions, conn)
      saveTable(df, url, table, jdbcOptions)
    }
  } finally {
    conn.close()
  }

  createRelation(sqlContext, parameters)
}
Finally, org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils#saveTable performs the actual insert.
def saveTable(
    df: DataFrame,
    url: String,
    table: String,
    options: JDBCOptions) {
  val dialect = JdbcDialects.get(url)
  val nullTypes: Array[Int] = df.schema.fields.map { field =>
    getJdbcType(field.dataType, dialect).jdbcNullType
  }

  val rddSchema = df.schema
  val getConnection: () => Connection = createConnectionFactory(options)
  val batchSize = options.batchSize
  val isolationLevel = options.isolationLevel
  df.foreachPartition(iterator => savePartition(
    getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
  )
}
As you can see, the DataFrame calls foreachPartition so that data is inserted partition by partition; the real insertion happens inside savePartition.
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    nullTypes: Array[Int],
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int): Iterator[Byte] = {
  val conn = getConnection()
  var committed = false

  var finalIsolationLevel = Connection.TRANSACTION_NONE
  if (isolationLevel != Connection.TRANSACTION_NONE) {
    try {
      val metadata = conn.getMetaData
      if (metadata.supportsTransactions()) {
        // Update to at least use the default isolation, if any transaction level
        // has been chosen and transactions are supported
        val defaultIsolation = metadata.getDefaultTransactionIsolation
        finalIsolationLevel = defaultIsolation
        if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
          // Finally update to actually requested level if possible
          finalIsolationLevel = isolationLevel
        } else {
          logWarning(s"Requested isolation level $isolationLevel is not supported; " +
            s"falling back to default isolation level $defaultIsolation")
        }
      } else {
        logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
      }
    } catch {
      case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
    }
  }
  val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

  try {
    if (supportsTransactions) {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      conn.setTransactionIsolation(finalIsolationLevel)
    }
    val stmt = insertStatement(conn, table, rddSchema, dialect)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
      .map(makeSetter(conn, dialect, _)).toArray
    val numFields = rddSchema.fields.length

    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i)
          }
          i = i + 1
        }
        stmt.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) {
          stmt.executeBatch()
          rowCount = 0
        }
      }
      if (rowCount > 0) {
        stmt.executeBatch()
      }
    } finally {
      stmt.close()
    }
    if (supportsTransactions) {
      conn.commit()
    }
    committed = true
    Iterator.empty
  } catch {
    case e: SQLException =>
      val cause = e.getNextException
      if (cause != null && e.getCause != cause) {
        if (e.getCause == null) {
          e.initCause(cause)
        } else {
          e.addSuppressed(cause)
        }
      }
      throw e
  } finally {
    if (!committed) {
      // The stage must fail. We got here through an exception path, so
      // let the exception through unless rollback() or close() want to
      // tell the user about another problem.
      if (supportsTransactions) {
        conn.rollback()
      }
      conn.close()
    } else {
      // The stage must succeed. We cannot propagate any exception close() might throw.
      try {
        conn.close()
      } catch {
        case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
      }
    }
  }
}
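The key takeaway from savePartition is the addBatch()/executeBatch() pattern combined with a single commit per partition. The same idea can be applied to the hand-written foreachPartition code from method 2; a minimal Java sketch, assuming the same "kkbh&hphm" record format, table and JdbcUtils helper as above, with `it` being the partition iterator and the batch size an arbitrary choice:
Connection conn = JdbcUtils.getConnection();
conn.setAutoCommit(false);                   // commit once per partition, like savePartition
PreparedStatement pstmt = conn.prepareStatement(
        "insert into t_txc_result (JSBH,HPHM,KKBH,CREATE_TIME) values(?,?,?,?)");
int batchSize = 1000;                        // assumed batch size
int rowCount = 0;
while (it.hasNext()) {
    String[] f = it.next().split("&");
    long jsbh = System.currentTimeMillis();
    pstmt.setString(1, String.valueOf(jsbh));
    pstmt.setString(2, f[1]);                // hphm
    pstmt.setString(3, f[0]);                // kkbh
    pstmt.setTimestamp(4, new Timestamp(jsbh));
    pstmt.addBatch();                        // queue the row instead of executing it immediately
    if (++rowCount % batchSize == 0) {
        pstmt.executeBatch();                // flush a full batch
    }
}
pstmt.executeBatch();                        // flush any remaining rows
conn.commit();
JdbcUtils.free(pstmt, conn);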
Summary:
When working with an open-source framework, reading the source code is the most direct and effective way to troubleshoot a bug, look for an optimization, or understand the underlying principles.
This post records the path I took to find this optimization.