In my previous post I wrote about how to access a Kerberos-secured HBase directly from Spark. Now we need to run a distributed full-table scan over HBase. Without Kerberos, SparkContext's newAPIHadoopRDD gets the job done, but once Kerberos is involved that approach no longer works. You might think of solving it with the UGI doAs method I mentioned before, but for a distributed scan, putting newAPIHadoopRDD inside doAs only takes effect on the current node (the driver); it cannot grant credentials to the other nodes, so at runtime you still get invalid-ticket errors.
So how do we solve this? First, try to run Spark on YARN rather than in standalone mode. Next comes the credential problem itself. The approach most often suggested online is to use delegation tokens, but I never got that to work. Cloudera offers another way: override the RDD and implement your own newAPIHadoopRDD, so that the parts that actually touch HBase can be wrapped in the UGI's doAs, as sketched below.
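To make the idea concrete, here is a minimal sketch of the login-and-doAs pattern the rest of the post relies on. The principal and keytab path are placeholders, and the UserSecurityFunction helper used later is assumed to wrap exactly this kind of logic:
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.security.UserGroupInformation

// Hypothetical principal/keytab; in practice the keytab must be readable on every executor.
val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(
  "hbase-reader@EXAMPLE.COM", "/etc/security/keytabs/hbase-reader.keytab")

// Any HBase call made inside run() executes with the Kerberos credentials obtained above.
val value = ugi.doAs(new PrivilegedExceptionAction[String] {
  def run(): String = {
    // ... open the HBase connection and read here ...
    "done"
  }
})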
How do we override it? Subclassing RDD is not difficult: extend the RDD class and implement two methods, getPartitions and compute.
getPartitions is how Spark obtains the partition (split) information, while compute produces the actual data for a given partition, as the bare-bones skeleton below shows.
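A minimal skeleton of such an RDD looks roughly like this (the element type T and the empty bodies are placeholders, not part of the real implementation):
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Sketch only: a custom RDD must describe its partitions and how to read one of them.
class MyScanRDD[T: scala.reflect.ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {

  // Called on the driver: one Partition per input split (e.g. per HBase region).
  override def getPartitions: Array[Partition] = Array.empty[Partition]

  // Called on the executors: produce the records belonging to this partition.
  override def compute(split: Partition, context: TaskContext): Iterator[T] = Iterator.empty
}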
Now let's look at the code:
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{ SparkContext, TaskContext }
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.SerializableWritable
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.Credentials
import org.apache.spark.rdd.RDD
import org.apache.spark.Partition
import org.apache.spark.InterruptibleIterator
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil
import org.apache.hadoop.hbase.client.Scan
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.Logging
import org.apache.spark.SparkHadoopMapReduceUtilExtended
import org.apache.hadoop.mapreduce.JobID
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.InputSplit
import java.text.SimpleDateFormat
import java.util.Date
import java.util.ArrayList
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
import org.apache.hadoop.hbase.mapreduce.IdentityTableMapper
import org.apache.hadoop.hbase.CellUtil
import java.security.PrivilegedExceptionAction
import java.security.PrivilegedAction
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
class HBaseScanRDD(sc: SparkContext,
    usf: UserSecurityFunction,
    @transient tableName: String,
    @transient scan: Scan,
    configBroadcast: Broadcast[SerializableWritable[Configuration]])
  extends RDD[(Array[Byte], java.util.List[(Array[Byte], Array[Byte], Array[Byte])])](sc, Nil)
  with SparkHadoopMapReduceUtilExtended
  with Logging {

  // Driver-side setup: build the HBase scan job configuration. When Kerberos is
  // enabled, the HBase access inside initTableMapperJob must run inside doAs.
  @transient val jobTransient = new Job(configBroadcast.value.value, "ExampleRead")
  if (usf.isSecurityEnable()) {
    usf.login().doAs(new PrivilegedExceptionAction[Unit] {
      def run: Unit =
        TableMapReduceUtil.initTableMapperJob(
          tableName,                    // input HBase table name
          scan,                         // Scan instance to control CF and attribute selection
          classOf[IdentityTableMapper], // mapper
          null,                         // mapper output key
          null,                         // mapper output value
          jobTransient)
    })
  } else {
    TableMapReduceUtil.initTableMapperJob(
      tableName,                        // input HBase table name
      scan,                             // Scan instance to control CF and attribute selection
      classOf[IdentityTableMapper],     // mapper
      null,                             // mapper output key
      null,                             // mapper output value
      jobTransient)
  }

  // Ship the fully initialized job configuration to the executors as a broadcast variable.
  @transient val jobConfigurationTrans = jobTransient.getConfiguration()
  jobConfigurationTrans.set(TableInputFormat.INPUT_TABLE, tableName)
  val jobConfigBroadcast = sc.broadcast(new SerializableWritable(jobConfigurationTrans))

  private val jobTrackerId: String = {
    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
    formatter.format(new Date())
  }

  @transient protected val jobId = new JobID(jobTrackerId, id)
  override def getPartitions: Array[Partition] = {
    //addCreds
    val tableInputFormat = new TableInputFormat
    tableInputFormat.setConf(jobConfigBroadcast.value.value)
    val jobContext = newJobContext(jobConfigBroadcast.value.value, jobId)
    // Ask HBase for the region splits; this talks to HBase, so it needs doAs under Kerberos.
    var rawSplits: Array[Object] = null
    if (usf.isSecurityEnable()) {
      rawSplits = usf.login().doAs(new PrivilegedAction[Array[Object]] {
        def run: Array[Object] = tableInputFormat.getSplits(jobContext).toArray
      })
    } else {
      rawSplits = tableInputFormat.getSplits(jobContext).toArray
    }
    val result = new Array[Partition](rawSplits.size)
    for (i <- 0 until rawSplits.size) {
      result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
    }
    result
  }
  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(Array[Byte], java.util.List[(Array[Byte], Array[Byte], Array[Byte])])] = {
    //addCreds
    val iter = new Iterator[(Array[Byte], java.util.List[(Array[Byte], Array[Byte], Array[Byte])])] {
      //addCreds
      val split = theSplit.asInstanceOf[NewHadoopPartition]
      logInfo("Input split: " + split.serializableHadoopSplit)
      val conf = jobConfigBroadcast.value.value
      val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
      val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
      val format = new TableInputFormat
      format.setConf(conf)
      // Create and initialize the HBase record reader for this split. This runs on the
      // executor, so under Kerberos it must also be wrapped in doAs.
      var reader: RecordReader[ImmutableBytesWritable, Result] = null
      if (usf.isSecurityEnable()) {
        reader = usf.login().doAs(new PrivilegedAction[RecordReader[ImmutableBytesWritable, Result]] {
          def run: RecordReader[ImmutableBytesWritable, Result] = {
            val _reader = format.createRecordReader(
              split.serializableHadoopSplit.value, hadoopAttemptContext)
            _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
            _reader
          }
        })
      } else {
        reader = format.createRecordReader(
          split.serializableHadoopSplit.value, hadoopAttemptContext)
        reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
      }

      // Register an on-task-completion callback to close the input stream.
      context.addOnCompleteCallback(() => close())

      var havePair = false
      var finished = false

      override def hasNext: Boolean = {
        if (!finished && !havePair) {
          finished = !reader.nextKeyValue
          havePair = !finished
        }
        !finished
      }

      override def next(): (Array[Byte], java.util.List[(Array[Byte], Array[Byte], Array[Byte])]) = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        }
        havePair = false
        // Flatten the row's cells into (family, qualifier, value) triples.
        val it = reader.getCurrentValue.listCells().iterator()
        val list = new ArrayList[(Array[Byte], Array[Byte], Array[Byte])]()
        while (it.hasNext()) {
          val kv = it.next()
          list.add((CellUtil.cloneFamily(kv), CellUtil.cloneQualifier(kv), CellUtil.cloneValue(kv)))
        }
        (reader.getCurrentKey.copyBytes(), list)
      }

      private def close() {
        try {
          reader.close()
        } catch {
          case e: Exception => logWarning("Exception in RecordReader.close()", e)
        }
      }
    }
    new InterruptibleIterator(context, iter)
  }
  // Alternative token-based approach (the addCreds calls above are left commented out):
  // copy the driver's delegation tokens into the current UGI. Kept for reference.
  def addCreds {
    val creds = SparkHadoopUtil.get.getCurrentUserCredentials()
    val ugi = UserGroupInformation.getCurrentUser()
    ugi.addCredentials(creds)
    // specify that this is a proxy user
    ugi.setAuthenticationMethod(AuthenticationMethod.PROXY)
  }

  // Serializable wrapper around a Hadoop InputSplit, one per HBase region.
  private class NewHadoopPartition(
      rddId: Int,
      val index: Int,
      @transient rawSplit: InputSplit with Writable)
    extends Partition {

    val serializableHadoopSplit = new SerializableWritable(rawSplit)

    override def hashCode(): Int = 41 * (41 + rddId) + index
  }
}
The RDD above mixes in SparkHadoopMapReduceUtilExtended. This is just a thin trait that has to live in the org.apache.spark package, because Spark's SparkHadoopMapReduceUtil (which provides newJobContext, newTaskAttemptID and newTaskAttemptContext) is private[spark]:
package org.apache.spark
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
trait SparkHadoopMapReduceUtilExtended extends SparkHadoopMapReduceUtil
With that in place, you can use it directly: just create a new HBaseScanRDD(), passing in the SparkContext, as in the sketch below.
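For example (a sketch under the same assumptions; the table name, column family, principal and keytab path are placeholders):
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.Scan
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.{SparkConf, SparkContext, SerializableWritable}

val sc = new SparkContext(new SparkConf().setAppName("HBaseScanExample"))
val hbaseConf = HBaseConfiguration.create()

// Hypothetical security helper; see the UserSecurityFunction sketch above.
val usf = new UserSecurityFunction(securityEnabled = true,
  "hbase-reader@EXAMPLE.COM", "/etc/security/keytabs/hbase-reader.keytab")

val scan = new Scan()
scan.addFamily(Bytes.toBytes("cf"))

val rdd = new HBaseScanRDD(sc, usf, "my_table", scan,
  sc.broadcast(new SerializableWritable(hbaseConf)))

// Each record is (rowKey, list of (family, qualifier, value)).
println(rdd.count())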
If you want to access it from Java, you also need to wrap it in one more layer:
import org.apache.spark.api.java.JavaSparkContext
import org.apache.hadoop.conf.Configuration
import org.apache.spark.broadcast.Broadcast
import org.apache.hadoop.hbase.client.Scan
import org.apache.spark.SerializableWritable
import org.apache.spark.api.java.JavaPairRDD
class JavaHBaseContext(@transient jsc: JavaSparkContext, usf: UserSecurityFunction) {

  def hbaseApiRDD(tableName: String,
      scan: Scan,
      configBroadcast: Configuration) = {
    JavaPairRDD.fromRDD(new HBaseScanRDD(JavaSparkContext.toSparkContext(jsc), usf, tableName,
      scan,
      jsc.broadcast(new SerializableWritable(configBroadcast))))
  }
}