Hive on Spark Source Code Analysis (1): SparkTask
Hive on Spark Source Code Analysis (2): SparkSession and HiveSparkClient
Hive on Spark Source Code Analysis (3): SparkClient and SparkClientImpl (Part 1)
Hive on Spark Source Code Analysis (4): SparkClient and SparkClientImpl (Part 2)
Hive on Spark Source Code Analysis (5): RemoteDriver
Hive on Spark Source Code Analysis (6): RemoteSparkJobMonitor and JobHandle
RemoteDriver is the process that exchanges job-related messages with the SparkClient and submits jobs to the Spark cluster. As we saw earlier, SparkClientImpl launches RemoteDriver in a new process by invoking RemoteDriver.main.
The main method
public static void main(String[] args) throws Exception {
new RemoteDriver(args).run();
}
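For reference, the argument list passed by SparkClientImpl uses the flags parsed in the constructor below; a hypothetical example (all concrete values are made up for illustration):
// Hypothetical example of the command-line arguments RemoteDriver.main receives; the flag
// names come from the parsing loop in the constructor below, the values are made up.
String[] args = new String[] {
    "--remote-host", "10.1.1.10",      // host of the RpcServer on the Hive client side
    "--remote-port", "52134",          // port of that RpcServer
    "--client-id", "example-client-id",
    "--secret", "example-secret",
    "--conf", "hive.spark.client.rpc.threads=8"
};
RemoteDriver.main(args);               // equivalent to: new RemoteDriver(args).run()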
The run method simply waits until the driver is told to shut down, then shuts down the executor and deletes the local temporary directory:
private void run() throws InterruptedException {
synchronized (shutdownLock) {
while (running) {
shutdownLock.wait();
}
}
executor.shutdownNow();
try {
FileUtils.deleteDirectory(localTmpDir);
} catch (IOException e) {
LOG.warn("Failed to delete local tmp dir: " + localTmpDir, e);
}
}
Next let's look at RemoteDriver's private constructor. Parsing the arguments, initializing the runtime environment and submitting jobs are all handled there. The first step is to parse the arguments that the SparkClient passed to RemoteDriver and put them into the corresponding SparkConf:
SparkConf conf = new SparkConf();
String serverAddress = null;
int serverPort = -1;
for (int idx = 0; idx < args.length; idx += 2) {
String key = args[idx];
if (key.equals("--remote-host")) {
serverAddress = getArg(args, idx);
} else if (key.equals("--remote-port")) {
serverPort = Integer.parseInt(getArg(args, idx));
} else if (key.equals("--client-id")) {
conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
} else if (key.equals("--secret")) {
conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
} else if (key.equals("--conf")) {
String[] val = getArg(args, idx).split("[=]", 2);
conf.set(val[0], val[1]);
} else {
throw new IllegalArgumentException("Invalid command line: "
+ Joiner.on(" ").join(args));
}
}
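// getArg is a small helper not shown above; a sketch of its assumed behavior: return the
// value following the flag at keyIdx (the loop advances by 2), or reject a flag with no value.
private String getArg(String[] args, int keyIdx) {
  int valIdx = keyIdx + 1;
  if (valIdx >= args.length) {
    throw new IllegalArgumentException("Invalid command line: " + Joiner.on(" ").join(args));
  }
  return args[valIdx];
}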
// Thread pool used to run submitted jobs on separate threads
executor = Executors.newCachedThreadPool();
LOG.info("Connecting to: {}:{}", serverAddress, serverPort);
// Copy the configuration RemoteDriver was started with into mapConf
Map<String, String> mapConf = Maps.newHashMap();
for (Tuple2<String, String> e : conf.getAll()) {
mapConf.put(e._1(), e._2());
LOG.debug("Remote Driver configured with: " + e._1() + "=" + e._2());
}
// Get the clientId and secret used for authentication when connecting to the RpcServer.
// Authentication is SASL-based (details not covered here); SASL is implemented as a handler in the channel pipeline.
String clientId = mapConf.get(SparkClientFactory.CONF_CLIENT_ID);
Preconditions.checkArgument(clientId != null, "No client ID provided.");
String secret = mapConf.get(SparkClientFactory.CONF_KEY_SECRET);
Preconditions.checkArgument(secret != null, "No secret provided.");
// Read hive.spark.client.rpc.threads; defaults to 8 if not set
int threadCount = new RpcConfiguration(mapConf).getRpcThreadCount();
this.egroup = new NioEventLoopGroup(
threadCount,
new ThreadFactoryBuilder()
.setNameFormat("Driver-RPC-Handler-%d")
.setDaemon(true)
.build());
// protocol is effectively an RPC handler, analogous to ClientProtocol on the client side
this.protocol = new DriverProtocol();
// The RPC library takes care of timing out this.
// createClient returns a Promise<Rpc>; calling get() on it blocks until the connection is
// established and yields the Rpc instance. Internally it creates a Netty Bootstrap and
// connects to the RpcServer.
this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort,
clientId, secret, protocol).get();
this.running = true;
Rpc.createClient is implemented as follows:
public static Promise<Rpc> createClient(
Map<String, String> config,
final NioEventLoopGroup eloop,
String host,
int port,
final String clientId,
final String secret,
final RpcDispatcher dispatcher) throws Exception {
final RpcConfiguration rpcConf = new RpcConfiguration(config);
int connectTimeoutMs = (int) rpcConf.getConnectTimeoutMs();
final ChannelFuture cf = new Bootstrap()
.group(eloop)
.handler(new ChannelInboundHandlerAdapter() { })
.channel(NioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMs)
.connect(host, port);
// Set up a timeout to undo everything.
final Runnable timeoutTask = new Runnable() {
@Override
public void run() {
promise.setFailure(new TimeoutException("Timed out waiting for RPC server connection."));
}
};
final ScheduledFuture<?> timeoutFuture = eloop.schedule(timeoutTask,
rpcConf.getServerConnectTimeoutMs(), TimeUnit.MILLISECONDS);
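// (The promise referenced by timeoutTask above is a Promise<Rpc> created earlier in
// createClient and not shown in this excerpt. The remainder of the method, also omitted,
// registers the SASL handshake handler on the channel pipeline and completes the promise
// with the Rpc instance once authentication succeeds; that completed promise is what the
// .get() call in the RemoteDriver constructor waits for.)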
Back in the RemoteDriver constructor, a listener is registered on the newly created Rpc so that the driver shuts itself down when the RPC channel is closed:
this.clientRpc.addListener(new Rpc.Listener() {
@Override
public void rpcClosed(Rpc rpc) {
LOG.warn("Shutting down driver because RPC channel was closed.");
shutdown(null);
}
});
The constructor then creates the JavaSparkContext, registers a ClientListener on it, and initializes the JobContext:
try {
JavaSparkContext sc = new JavaSparkContext(conf);
sc.sc().addSparkListener(new ClientListener());
synchronized (jcLock) {
// Initialize the job context, which holds runtime information about the job execution context.
jc = new JobContextImpl(sc, localTmpDir);
jcLock.notifyAll();
}
} catch (Exception e) {
LOG.error("Failed to start SparkContext: " + e, e);
// If initializing the SparkContext/JobContext fails, send the error to the client and shut down all services
shutdown(e);
synchronized (jcLock) {
jcLock.notifyAll();
}
throw e;
}
synchronized (jcLock) {
// Submit all jobs that were queued while the SparkContext was coming up
for (Iterator<JobWrapper<?>> it = jobQueue.iterator(); it.hasNext();) {
it.next().submit();
}
}
}
The shutdown method cancels all active jobs, reports the error (if any) back to the client, stops the JobContext, closes the RPC connection, and finally wakes up the run method waiting on shutdownLock:
private synchronized void shutdown(Throwable error) {
if (running) {
if (error == null) {
LOG.info("Shutting down remote driver.");
} else {
LOG.error("Shutting down remote driver due to error: " + error, error);
}
running = false;
for (JobWrapper<?> job : activeJobs.values()) {
cancelJob(job);
}
if (error != null) {
protocol.sendError(error);
}
if (jc != null) {
jc.stop();
}
clientRpc.close();
egroup.shutdownGracefully();
synchronized (shutdownLock) {
shutdownLock.notifyAll();
}
}
}
The submit method either submits a job right away or, if the SparkContext is not up yet, queues it until the constructor finishes initializing the JobContext:
private void submit(JobWrapper<?> job) {
synchronized (jcLock) {
if (jc != null) {
job.submit();
} else {
LOG.info("SparkContext not yet up, queueing job request.");
jobQueue.add(job);
}
}
}
Besides a handful of simple private fields, RemoteDriver defines three important inner classes: JobWrapper, ClientListener and DriverProtocol. As the walkthrough of the constructor above shows, these three classes are used throughout the implementation of RemoteDriver, so let's look at what each of them does and how it is implemented.
1. JobWrapper
JobWrapper implements Callable<Void> and wraps a single client job request.
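Before the call method, it helps to know the state a JobWrapper carries. The following skeleton is reconstructed from how the methods below use it (field names as they appear in the code; the declarations themselves are assumptions), not quoted verbatim:
// Sketch of JobWrapper's state, reconstructed from how call(), jobDone() and releaseCache()
// use it below; not a verbatim copy of the source.
private class JobWrapper<T extends Serializable> implements Callable<Void> {
  private final BaseProtocol.JobRequest<T> req;   // the client's job request (req.id, req.job)
  private final List<JavaFutureAction<?>> jobs;   // Spark actions started while running req.job
  private final AtomicInteger jobEndReceived;     // number of JobEnd events received so far
  private int completed;                          // number of finished Spark jobs
  private SparkCounters sparkCounters;            // counters to snapshot for the client
  private Set<Integer> cachedRDDIds;              // RDDs to unpersist in releaseCache()
  private Integer sparkJobId;                     // id of the monitored Spark job
  private Future<?> future;                       // handle returned by executor.submit(this)
  // ... constructor and methods shown below ...
}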
As a Callable, its core method is call, so let's walk through JobWrapper's implementation of call.
The first step is to send a JobStarted message to the client through protocol:
@Override
public Void call() throws Exception {
protocol.jobStarted(req.id);
try {
// (the registration of the monitor callback via jc.setMonitorCb, which tracks the Spark jobs
// started by req.job, is omitted from this excerpt; see the finally block below)
T result = req.job.call(jc);
for (JavaFutureAction<?> future : jobs) {
future.get();
completed++;
LOG.debug("Client job {}: {} of {} Spark jobs finished.",
req.id, completed, jobs.size());
}
if (sparkJobId != null) {
SparkJobInfo sparkJobInfo = jc.sc().statusTracker().getJobInfo(sparkJobId);
if (sparkJobInfo != null && sparkJobInfo.stageIds() != null &&
sparkJobInfo.stageIds().length > 0) {
synchronized (jobEndReceived) {
while (jobEndReceived.get() < jobs.size()) {
jobEndReceived.wait();
}
}
}
}
SparkCounters counters = null;
if (sparkCounters != null) {
counters = sparkCounters.snapshot();
}
protocol.jobFinished(req.id, result, null, counters);
} catch (Throwable t) {
// Catch throwables in a best-effort to report job status back to the client. It's
// re-thrown so that the executor can destroy the affected thread (or the JVM can
// die or whatever would happen if the throwable bubbled up).
LOG.info("Failed to run job " + req.id, t);
protocol.jobFinished(req.id, null, t,
sparkCounters != null ? sparkCounters.snapshot() : null);
throw new ExecutionException(t);
} finally {
jc.setMonitorCb(null);
activeJobs.remove(req.id);
releaseCache();
}
return null;
}
The submit method is what external callers use to hand a job to the JobWrapper; it submits the wrapper itself to the executor so that call() runs on a pooled thread:
void submit() {
this.future = executor.submit(this);
}
jobDone is called from ClientListener.onJobEnd (see below) each time a Spark job finishes; it increments jobEndReceived and wakes up call(), which may be waiting for all job-end events:
void jobDone() {
synchronized (jobEndReceived) {
jobEndReceived.incrementAndGet();
jobEndReceived.notifyAll();
}
}
releaseCache unpersists any RDDs that were cached while the job was running:
void releaseCache() {
if (cachedRDDIds != null) {
for (Integer cachedRDDId: cachedRDDIds) {
jc.sc().sc().unpersistRDD(cachedRDDId, false);
}
}
}
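The fields jobs, sparkJobId, sparkCounters and cachedRDDIds used above are populated through the monitor callback that JobWrapper installs on the JobContext before running the job (its registration is omitted from the call() excerpt). A rough sketch of that callback's target, reconstructed from how the fields are used here; the actual body may differ from the source:
// Rough sketch of JobWrapper.monitorJob, invoked via the MonitorCallback installed in call();
// reconstructed from the field usages above, so details may differ from the actual source.
private void monitorJob(JavaFutureAction<?> job,
    SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
  jobs.add(job);                              // tracked so call() can wait on it
  this.sparkCounters = sparkCounters;         // snapshotted when the job finishes
  this.cachedRDDIds = cachedRDDIds;           // unpersisted in releaseCache()
  sparkJobId = job.jobIds().get(0);           // used to look up SparkJobInfo in call()
  protocol.jobSubmitted(req.id, sparkJobId);  // tell the client which Spark job was launched
}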
2. ClientListener
ClientListener extends JavaSparkListener and listens for events coming from the Spark scheduler. It overrides three of the parent class's event-handling methods.
When a job starts, onJobStart is triggered and records each of the job's stage ids against its job id in the stageToJobId map:
@Override
public void onJobStart(SparkListenerJobStart jobStart) {
synchronized (stageToJobId) {
for (int i = 0; i < jobStart.stageIds().length(); i++) {
stageToJobId.put((Integer) jobStart.stageIds().apply(i), jobStart.jobId());
}
}
}
When a job ends, onJobEnd removes that job's stages from stageToJobId and then notifies the corresponding JobWrapper through jobDone:
@Override
public void onJobEnd(SparkListenerJobEnd jobEnd) {
synchronized (stageToJobId) {
for (Iterator<Map.Entry<Integer, Integer>> it = stageToJobId.entrySet().iterator();
it.hasNext();) {
Map.Entry<Integer, Integer> e = it.next();
if (e.getValue() == jobEnd.jobId()) {
it.remove();
}
}
}
String clientId = getClientId(jobEnd.jobId());
if (clientId != null) {
activeJobs.get(clientId).jobDone();
}
}
When a task finishes successfully (and is not a speculative duplicate), onTaskEnd looks up which client job the task belongs to via stageToJobId and getClientId, and sends the task's metrics back to the client:
@Override
public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
if (taskEnd.reason() instanceof org.apache.spark.Success$
&& !taskEnd.taskInfo().speculative()) {
Metrics metrics = new Metrics(taskEnd.taskMetrics());
Integer jobId;
synchronized (stageToJobId) {
jobId = stageToJobId.get(taskEnd.stageId());
}
// TODO: implement implicit AsyncRDDActions conversion instead of jc.monitor()?
// TODO: how to handle stage failures?
String clientId = getClientId(jobId);
if (clientId != null) {
protocol.sendMetrics(clientId, jobId, taskEnd.stageId(),
taskEnd.taskInfo().taskId(), metrics);
}
}
}
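Both onJobEnd and onTaskEnd rely on getClientId, which is not shown above. A sketch of what it plausibly does, assuming each JobWrapper records the Spark job id it monitors (the sparkJobId field used in call()): scan activeJobs for the wrapper that matches the given Spark job id and return its key, i.e. the client job id.
// Sketch of ClientListener.getClientId, under the assumption that JobWrapper exposes the
// Spark job id it monitors (sparkJobId); not a verbatim copy of the source.
private String getClientId(Integer jobId) {
  for (Map.Entry<String, JobWrapper<?>> e : activeJobs.entrySet()) {
    if (jobId != null && jobId.equals(e.getValue().sparkJobId)) {
      return e.getKey();   // key is the client job id (req.id)
    }
  }
  return null;             // job was not started through a monitored JobWrapper
}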
3. DriverProtocol
RemoteDriver defines one main RPC handler: DriverProtocol extends BaseProtocol. Its implementation mirrors ClientProtocol, and the two are precisely the components through which SparkClient and RemoteDriver talk to each other. DriverProtocol defines a set of message-sending methods which, like on the client side, call Rpc.call with different message types, and those types are exactly the ones that ClientProtocol's handle methods know how to process.
private class DriverProtocol extends BaseProtocol {
// Sends an Error message
void sendError(Throwable error) {
LOG.debug("Send error to Client: {}", Throwables.getStackTraceAsString(error));
clientRpc.call(new Error(error));
}
// Sends a JobResult message
<T extends Serializable> void jobFinished(String jobId, T result,
Throwable error, SparkCounters counters) {
LOG.debug("Send job({}) result to Client.", jobId);
clientRpc.call(new JobResult(jobId, result, error, counters));
}
// Sends a JobStarted message
void jobStarted(String jobId) {
clientRpc.call(new JobStarted(jobId));
}
// Sends a JobSubmitted message
void jobSubmitted(String jobId, int sparkJobId) {
LOG.debug("Send job({}/{}) submitted to Client.", jobId, sparkJobId);
clientRpc.call(new JobSubmitted(jobId, sparkJobId));
}
// Sends a JobMetrics message
void sendMetrics(String jobId, int sparkJobId, int stageId, long taskId, Metrics metrics) {
LOG.debug("Send task({}/{}/{}/{}) metric to Client.", jobId, sparkJobId, stageId, taskId);
clientRpc.call(new JobMetrics(jobId, sparkJobId, stageId, taskId, metrics));
}
Conversely, the message types handled by DriverProtocol's handle methods are exactly the ones sent by the message-sending methods in SparkClientImpl's ClientProtocol.
When a CancelJob message is received, the corresponding active job is looked up and cancelled:
private void handle(ChannelHandlerContext ctx, CancelJob msg) {
JobWrapper<?> job = activeJobs.get(msg.id);
if (job == null || !cancelJob(job)) {
LOG.info("Requested to cancel an already finished job.");
}
}
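cancelJob (also used in shutdown above) is not shown here. A minimal sketch of its assumed behavior: cancel every Spark action the wrapper started, cancel the wrapper's own executor future, and report whether anything was actually cancelled.
// Sketch of RemoteDriver.cancelJob (assumed behavior, not a verbatim copy of the source).
private boolean cancelJob(JobWrapper<?> job) {
  boolean cancelled = false;
  for (JavaFutureAction<?> action : job.jobs) {
    cancelled |= action.cancel(true);
  }
  return cancelled | (job.future != null && job.future.cancel(true));
}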
When an EndSession message is received, the handler simply calls RemoteDriver's shutdown method to shut the driver down:
private void handle(ChannelHandlerContext ctx, EndSession msg) {
LOG.debug("Shutting down due to EndSession request.");
shutdown(null);
}
When a JobRequest arrives, it is wrapped into a JobWrapper, registered in activeJobs and submitted:
private void handle(ChannelHandlerContext ctx, JobRequest msg) {
LOG.info("Received job request {}", msg.id);
JobWrapper<?> wrapper = new JobWrapper<Serializable>(msg);
activeJobs.put(msg.id, wrapper);
submit(wrapper);
}
A SyncJobRequest is handled synchronously: the handler waits for the JobContext to come up, forbids JobContext.monitor() for synchronous jobs, and runs the job directly, returning its result over the RPC:
private Object handle(ChannelHandlerContext ctx, SyncJobRequest msg) throws Exception {
// In case the job context is not up yet, let's wait, since this is supposed to be a
// "synchronous" RPC.
LOG.debug("liban: DriverProtocol received SyncJobRequest msg. waiting for jc to be up");
if (jc == null) {
synchronized (jcLock) {
while (jc == null) {
jcLock.wait();
if (!running) {
throw new IllegalStateException("Remote context is shutting down.");
}
}
}
}
jc.setMonitorCb(new MonitorCallback() {
@Override
public void call(JavaFutureAction<?> future,
SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
throw new IllegalStateException(
"JobContext.monitor() is not available for synchronous jobs.");
}
});
try {
LOG.debug("liban: type of job in SyncJobRequest msg: " + msg.job.getClass().getSimpleName());
return msg.job.call(jc);
} finally {
jc.setMonitorCb(null);
}
}
Finally, recall that the Job carried by a JobRequest is ultimately executed through req.job.call(jc) in JobWrapper.call. As a concrete example, here is the call method of the JobStatusJob submitted by RemoteHiveSparkClient (see part 2 of this series), which deserializes the SparkWork, builds the Spark plan and launches it asynchronously:
@Override
public Serializable call(JobContext jc) throws Exception {
JobConf localJobConf = KryoSerializer.deserializeJobConf(jobConfBytes);
// Add jar to current thread class loader dynamically, and add jar paths to JobConf as Spark
// may need to load classes from this jar in other threads.
Map<String, Long> addedJars = jc.getAddedJars();
if (addedJars != null && !addedJars.isEmpty()) {
List<String> localAddedJars = SparkClientUtilities.addToClassPath(addedJars,
localJobConf, jc.getLocalTmpDir());
localJobConf.set(Utilities.HIVE_ADDED_JARS, StringUtils.join(localAddedJars, ";"));
}
// Deserialize the local scratch directory path and the SparkWork
Path localScratchDir = KryoSerializer.deserialize(scratchDirBytes, Path.class);
SparkWork localSparkWork = KryoSerializer.deserialize(sparkWorkBytes, SparkWork.class);
logConfigurations(localJobConf);
// Set up the SparkCounters
SparkCounters sparkCounters = new SparkCounters(jc.sc());
Map<String, List<String>> prefixes = localSparkWork.getRequiredCounterPrefix();
if (prefixes != null) {
for (String group : prefixes.keySet()) {
for (String counterName : prefixes.get(group)) {
sparkCounters.createCounter(group, counterName);
}
}
}
// Build a SparkReporter on top of the SparkCounters
SparkReporter sparkReporter = new SparkReporter(sparkCounters);
// Generate Spark plan
SparkPlanGenerator gen =
new SparkPlanGenerator(jc.sc(), null, localJobConf, localScratchDir, sparkReporter);
SparkPlan plan = gen.generate(localSparkWork);
// Execute generated plan.
JavaPairRDD<HiveKey, BytesWritable> finalRDD = plan.generateGraph();
// We use Spark RDD async action to submit job as it's the only way to get jobId now.
JavaFutureAction<Void> future = finalRDD.foreachAsync(HiveVoidFunction.getInstance());
jc.monitor(future, sparkCounters, plan.getCachedRDDIds());
return null;
}
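The jc.monitor call at the end is what closes the loop with RemoteDriver: it invokes the monitor callback that JobWrapper installed on the JobContext, which (as sketched in the monitorJob section above) records the JavaFutureAction, the counters and the cached RDD ids for the wrapper and sends the JobSubmitted message back to the client, so the client can start monitoring the Spark job.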