Driver端如何正确取消Spark中的job

1.      SparkContext提供了一个取消job的api

class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient {
/** Cancel a given job if it's scheduled or running */
private[spark] def cancelJob(jobId: Int) {
  dagScheduler.cancelJob(jobId)
}
}

2.      那么如何获取jobId呢?

Spark提供了一个叫SparkListener的对象,它提供了对spark事件的监听功能

trait SparkListener {
  /**
   * Called when a job starts
   */
  def onJobStart(jobStart: SparkListenerJobStart) { }

  /**
   * Called when a job ends
   */
  def onJobEnd(jobEnd: SparkListenerJobEnd) { }
}

因此需要自定义一个类,继承自SparkListener,即:

public class DHSparkListener implements SparkListener {
private static Logger logger = Logger.getLogger(DHSparkListener.class);
//存储了提交job的线程局部变量和job的映射关系
    private static ConcurrentHashMap<String, Integer> jobInfoMap;
    public DHSparkListener() {
        jobInfoMap = new ConcurrentHashMap<String, Integer>();
    }
    @Override
    public void onJobEnd(SparkListenerJobEnd jobEnd) {
        logger.info("DHSparkListener Job End:" + jobEnd.jobResult().getClass() + ",Id:" + jobEnd.jobId());
        for (String key : jobInfoMap.keySet()) {
            if (jobInfoMap.get(key) == jobEnd.jobId()) {
                jobInfoMap.remove(key);
                logger.info(key+" request has been returned. because "+jobEnd.jobResult().getClass());
            }
        }
    }
    @Override
    public void onJobStart(SparkListenerJobStart jobStart) {
        logger.info("DHSparkListener Job Start: JobId->" + jobStart.jobId());
//根据线程变量属性找到该job是哪个线程提交的
        logger.info("DHSparkListener Job Start: Thread->" + jobStart.properties().getProperty("thread", "default"));
        jobInfoMap.put(jobStart.properties().getProperty("thread", "default"), jobStart.jobId());
    }
……
}

那么用户如何知道该job是哪个线程提交的呢?需要在提交job的时候设置线程局部变量属性,即

SparkConf conf = new SparkConf().setAppName("SparkListenerTest application in Java");
        String sparkMaster = Configure.instance.get("SparkMaster");
        String sparkExecutorMemory = "16g";
        String sparkCoresMax = "4";
        String sparkJarAddress = "/tmp/cuckoo-core-1.0-SNAPSHOT-allinone.jar";
        conf.setMaster(sparkMaster);
        conf.set("spark.executor.memory", sparkExecutorMemory);
        conf.set("spark.cores.max", sparkCoresMax);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        jsc.addJar(sparkJarAddress);
        DHSparkListener dHSparkListener = new DHSparkListener();
        jsc.sc().addSparkListener(dHSparkListener);
        List<Integer> listData = new ArrayList<Integer>();
        listData = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
        JavaRDD<Integer> rdd1 = jsc.parallelize(listData, 1);
JavaRDD<Integer> rdd2 = rdd1.map(new Function<Integer, Integer>() {
            public Integer call(Integer v1) throws Exception {
              //do something then return
            }
        });
<pre name="code" class="plain">       //在触发action提交job之前设置提交线程的局部属性,供SparkListener获取
       jsc.setLocalProperty("thread", "client");
       rdd2.count();

 
 

这样在jobInfoMap中记录了job和job提交者的映射关系,当发现某个job迟迟没有结束的时候,可以调用SparkContext的cancelJob取消,但是仅仅到这里就够了吗?接着往下看,excutor取消job最终调用的是:

def kill(interruptThread: Boolean) {
  _killed = true
  if (context != null) {
    context.markInterrupted()
  }
  if (interruptThread && taskThread != null) {
    taskThread.interrupt()
  }
}

最终调用到Thread.interrupt函数,给启动task的线程设置interrupt标记位,因此在长时间允许的task中,需要针对Thread的interrupt标记位进行判断,当被置位的时候,需要退出,并且做一些清理,即存在类似的代码段:

if(Thread.interrupted()){
    //……线程被中断,清理资源
}
或者调用sleep,wait函数时会抛出InterruptedException异常,需要进行捕获,然后做对应的处理


3.      最后一步,配置job kill的动作

除了以上操作之外,还需要再配置针对每个job调用kill的动作,即spark.job.interruptOnCancel属性为true 

  //在触发action提交job之前设置提交线程的局部属性,供SparkListener获取
       jsc.setLocalProperty("thread", "client");
	   //配置该job接受到kill之后的动作,即task线程收到interrupt信号
	   jsc.setLocalProperty("spark.job.interruptOnCancel", "true");
       rdd2.count();

你可能感兴趣的:(源码,spark)