Background
In real-world deployments, for security and other reasons, the big-data platform sometimes does not expose Phoenix's ZooKeeper connection URL. Since Spark currently ships no JDBC implementation for reading or writing Phoenix, this article shows a JDBC-based way to write a Spark RDD/DataFrame to Phoenix in distributed batches.
`Spark version: 2.1.0
Phoenix version: 4.10.0`
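For reference, Phoenix exposes two JDBC URL styles. The package-conflict section later suggests the thin client (Phoenix Query Server) is used here, which only needs an HTTP endpoint rather than the ZooKeeper quorum; the hostnames and ports below are placeholders:

// Thick client: needs the ZooKeeper quorum, which the platform may not expose.
val zkUrl   = "jdbc:phoenix:zk1,zk2,zk3:2181:/hbase"
// Thin client via the Phoenix Query Server: a plain HTTP endpoint is enough.
val thinUrl = "jdbc:phoenix:thin:url=http://queryserver-host:8765;serialization=PROTOBUF"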
Implementation
Spark's JDBC interface provides no dialect for Phoenix (Phoenix SQL inserts data with UPSERT rather than INSERT). Reading the source shows that the write path can be reproduced by adapting org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils, whose saveTable looks like this:
/**
 * Saves the RDD to the database in a single transaction.
 */
def saveTable(
    df: DataFrame,
    url: String,
    table: String,
    options: JDBCOptions) {
  val dialect = JdbcDialects.get(url)
  val nullTypes: Array[Int] = df.schema.fields.map { field =>
    getJdbcType(field.dataType, dialect).jdbcNullType
  }
  val rddSchema = df.schema
  val getConnection: () => Connection = createConnectionFactory(options)
  val batchSize = options.batchSize
  val isolationLevel = options.isolationLevel
  df.foreachPartition(iterator => savePartition(
    getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
  )
}
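For completeness: a custom JdbcDialect can be registered for Phoenix URLs, but a dialect only controls type mapping and identifier quoting; the INSERT statement itself is hard-coded inside JdbcUtils, so a dialect alone cannot turn it into Phoenix's UPSERT. A minimal sketch (PhoenixDialect is a made-up name, not something Spark ships):

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Registering a dialect is easy, but it cannot change INSERT INTO to UPSERT.
object PhoenixDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:phoenix")
}

JdbcDialects.registerDialect(PhoenixDialect)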
Moreover, starting with Spark 2.0 this object is internal to Spark (effectively private), so its methods cannot be called directly from user code. The PhoenixJdbcUtils below is modeled on it:
import java.sql.{Connection, PreparedStatement, SQLException}

import org.apache.spark.sql.{DataFrame, Row}
// Still reachable in Spark 2.1.0, even though JdbcUtils is an internal API.
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.getCommonJDBCType
import org.apache.spark.sql.jdbc.JdbcType
import org.apache.spark.sql.types._
import org.slf4j.LoggerFactory

object PhoenixJdbcUtils {

  private val logger = LoggerFactory.getLogger(getClass)

  def saveTable(df: DataFrame, table: String, batchSize: Int): Unit = {
    try {
      val columns = df.columns.mkString(",")
      val rddSchema = df.schema
      // One "?" placeholder per column, regardless of the column type.
      val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
      df.foreachPartition(iterator => {
        // PhoenixConnection.conn supplies a JDBC connection on the executor (see the sketch after the object).
        savePartition(PhoenixConnection.conn, table, iterator, columns, placeholders, rddSchema, batchSize)
      })
    } catch {
      case e: Exception => logger.error(e.toString)
    }
  }
  def insertStatement(conn: Connection, table: String, columns: String, placeholders: String)
    : PreparedStatement = {
    // The SQL "dialect" for writing to Phoenix: UPSERT instead of INSERT.
    val sql = s"UPSERT INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  private def makeSetter(dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))
    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))
    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))
    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))
    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))
    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))
    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }

  private def getJdbcType(dt: DataType): JdbcType = {
    getCommonJDBCType(dt).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }
  def savePartition(
      conn: Connection,
      table: String,
      iterator: Iterator[Row],
      columns: String,
      placeholders: String,
      rddSchema: StructType,
      batchSize: Int): Iterator[Byte] = {
    try {
      val stmt = insertStatement(conn, table, columns, placeholders)
      val setters = rddSchema.fields.map(_.dataType).map(makeSetter(_)).toArray
      // JDBC SQL type to pass to setNull for each column (mirrors nullTypes in Spark's JdbcUtils).
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType).jdbcNullType)
      val numFields = rddSchema.fields.length
      try {
        var rowCount = 0
        var writeBatch = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            writeBatch = writeBatch + 1
            rowCount = 0
          }
        }
        // Flush the last, possibly partial, batch.
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      // Phoenix connections are not auto-commit by default; commit the pending upserts.
      conn.commit()
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      conn.close()
    }
  }
}
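The code above refers to PhoenixConnection.conn, which is not shown. A minimal sketch of what such a connection provider could look like, assuming the thin (Query Server) JDBC driver; the object name, URL and driver class here are illustrative, not part of the original code:

import java.sql.{Connection, DriverManager}

object PhoenixConnection {
  // Placeholder endpoint; point this at the real Phoenix Query Server.
  private val url = "jdbc:phoenix:thin:url=http://queryserver-host:8765;serialization=PROTOBUF"

  // A new connection per call, so every partition gets its own connection and
  // savePartition can safely close it in its finally block.
  def conn: Connection = {
    Class.forName("org.apache.phoenix.queryserver.client.Driver")
    DriverManager.getConnection(url)
  }
}

Handing out a fresh connection per partition keeps savePartition's finally { conn.close() } safe; a single connection shared per executor would be closed by the first partition that finishes.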
Usage:
PhoenixJdbcUtils.saveTable(df.repartition(repartitions), phoenixTable, 1000)
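Put together, a hypothetical driver program might look like this (the input path, table name and partition count are placeholders):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("phoenix-writer").getOrCreate()

// Any DataFrame whose column names and types match the target Phoenix table.
val df = spark.read.parquet("/data/events")

val repartitions = 16                  // write parallelism: one JDBC connection per partition
val phoenixTable = "MY_SCHEMA.EVENTS"  // target table, created in Phoenix beforehand

PhoenixJdbcUtils.saveTable(df.repartition(repartitions), phoenixTable, batchSize = 1000)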
Package conflict
Spark 2.1.0 and phoenix-queryserver 4.10.0 both depend on org.apache.calcite.avatica, and the versions conflict. This can be resolved with a relocation in the maven-shade-plugin:
<relocation>
  <pattern>org.apache.calcite.avatica</pattern>
  <shadedPattern>${myorg.prefix}.org.apache.calcite.avatica</shadedPattern>
  <excludes>
    <exclude>org.apache.calcite.avatica.proto.*</exclude>
  </excludes>
</relocation>