Modifying Spark Thrift Server to Return Query Log Information

      When SQL is submitted to Spark through the Spark Thrift Server, getQueryLog returns an empty result, because the Thrift server by default reads an operation log file on local disk that Spark operations never write. To return job progress information when running on Spark, the getQueryLog path is modified as follows.

       First, add a method to the org.apache.hive.service.cli.operation.Operation class:

public String getStatementId() {return StringUtils.EMPTY;}

       Then provide the implementation in org.apache.spark.sql.hive.thriftserver.SparkExecuteStatementOperation:

override def getStatementId: String = statementId

      This method returns the statementId, which doubles as the group ID of every Spark job launched for the statement.
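
      This works because SparkExecuteStatementOperation already registers its statementId as the Spark job group before the query runs, so every job the statement spawns carries that ID in the spark.jobGroup.id property read by the listener below. A minimal sketch of the relevant call as it appears inside SparkExecuteStatementOperation in Spark 2.x (surrounding code omitted):

// Inside SparkExecuteStatementOperation.execute(): the statementId becomes
// the job group ID, which onJobStart later reads back from spark.jobGroup.id.
sqlContext.sparkContext.setJobGroup(statementId, statement)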

      Next, create a new SparkHiveLog class:

import java.util

import scala.collection.mutable

import org.apache.commons.lang3.StringUtils
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler._

class SparkHiveLog extends SparkListener with Logging {

    private class SparkStageInfo(val stageId: Int, val totalTask: Int) {

        var completedTask = 0

        var status = "Running"
    }

    /**
     * Mapping from job group ID to the set of job IDs in that group.
     */
    private val jobGroupMap: mutable.Map[String, mutable.Set[Int]] = mutable.Map()

    /**
     * Mapping from job ID to the list of its stage IDs.
     */
    private val jobListMap: mutable.Map[Int, Seq[Int]] = mutable.Map()

    /**
     * Mapping from stage ID to its SparkStageInfo.
     */
    private val stageMap: mutable.Map[Int, SparkStageInfo] = mutable.Map()

    private var taskCPUs = 1

    override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        // Jobs submitted outside a job group may carry no properties.
        if (jobStart.properties == null) {
            return
        }
        val groupId = jobStart.properties.getProperty("spark.jobGroup.id")
        taskCPUs = jobStart.properties.getProperty("spark.task.cpus", "1").toInt

        jobListMap += (jobStart.jobId -> jobStart.stageIds)
        if (jobGroupMap.contains(groupId)) {
            jobGroupMap(groupId).add(jobStart.jobId)
        } else {
            jobGroupMap += (groupId -> mutable.Set(jobStart.jobId))
        }
    }

    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
        val stageId = stageSubmitted.stageInfo.stageId
        val numTasks = stageSubmitted.stageInfo.numTasks

        stageMap += (stageId -> new SparkStageInfo(stageId, numTasks))
    }


    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
        val stageId = stageCompleted.stageInfo.stageId
        if (stageMap.contains(stageId)) {
            stageMap(stageId).status = "Completed"
        }
    }


    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        val stageId = taskEnd.stageId

        if (stageMap.contains(stageId)) {
            val stageInfo = stageMap(stageId)
            if (taskEnd.taskInfo.successful) {
                stageInfo.completedTask += 1
            }
        }
    }

    /**
     * Builds the task progress string for a single stage.
     * @param stageId the stage to report on
     * @return "stage<id>(<completed>/<total>)", or "" if the stage is unknown
     */
    private def getStageInfo(stageId: Int): String = {
        if (!stageMap.contains(stageId)) {
            return ""
        }
        val stageInfo = stageMap(stageId)
        val sb = new StringBuffer("stage")
        sb.append(stageId).append("(").append(stageInfo.completedTask).append("/")
        sb.append(stageInfo.totalTask).append(")")
        sb.toString
    }

    /**
     * Builds the progress string covering every stage of a job.
     * @param jobId the job to report on
     * @return "Job<id>/" followed by the per-stage progress strings
     */
    private def getJobInfo(jobId: Int): String = {
        val sb = new StringBuffer("Job")
        val map = new util.TreeMap[Int, String]()
        if (jobListMap.contains(jobId)) {
            for (stageId <- jobListMap(jobId)) {
                val stageInfo = getStageInfo(stageId)
                if (StringUtils.isNotEmpty(stageInfo)) {
                    map.put(stageId, stageInfo)
                }
            }
        }
        sb.append(jobId).append("/").append(formatString(map, "/"))
        sb.toString
    }

    /**
     * Builds the overall progress string for every job in a job group.
     * @param groupId the job group ID (the statementId of the query)
     * @return per-job progress strings joined by "; "
     */
    def getLogInfo(groupId: String): String = {
        val sb = new StringBuffer()
        if (jobGroupMap.contains(groupId)) {
            val map = new util.TreeMap[Int, String]()
            val iter = jobGroupMap(groupId).iterator
            while (iter.hasNext) {
                val jobId = iter.next()
                map.put(jobId, getJobInfo(jobId))
            }
            sb.append(formatString(map, "; "))
        }
        sb.toString
    }

    /**
     * Joins the values of a TreeMap in ascending key order.
     */
    private def formatString(map: util.TreeMap[Int, String], split: String): String = {
        val list = new util.ArrayList[String]()
        val iter = map.keySet().iterator()
        while (iter.hasNext) {
            list.add(map.get(iter.next()))
        }
        StringUtils.join(list, split)
    }


    /**
     * Drops all cached state for a job group once its statement is closed.
     */
    def clearCaches(groupId: String): Unit = {
        if (jobGroupMap.contains(groupId)) {
            for (jobId <- jobGroupMap(groupId)) {
                jobListMap.getOrElse(jobId, Seq.empty).foreach(stageMap.remove)
                jobListMap.remove(jobId)
            }
            jobGroupMap.remove(groupId)
        }
    }

}

object SparkHiveLog {

    private var sparkHiveLog: SparkHiveLog = null

    def getSparkHiveLog(): SparkHiveLog = synchronized {
        if (sparkHiveLog == null) {
            sparkHiveLog = new SparkHiveLog()
        }
        sparkHiveLog
    }
}

     This class records task progress as the onJobStart, onStageSubmitted, and onTaskEnd callbacks fire: onJobStart maps each job group to its jobs and each job to its stages, onStageSubmitted registers each stage's task count, and onTaskEnd counts completed tasks.
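
     For example, the progress string assembled by getLogInfo for a group whose query spawned two jobs looks like the line below (a hypothetical group ID and made-up task counts, shown only to illustrate the format produced by getStageInfo and getJobInfo):

// Hypothetical polling call; "stmt-1" stands in for a real statementId.
val progress = SparkHiveLog.getSparkHiveLog().getLogInfo("stmt-1")
// progress == "Job0/stage0(200/200)/stage1(57/200); Job1/stage2(0/50)"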

     Next, add a method to the org.apache.hive.service.cli.operation.OperationManager class:

public RowSet getQueryLog(OperationHandle opHandle) { return null; }

     Add the implementation in org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager:

override def getQueryLog(opHandle: OperationHandle): RowSet = {
    val operation = handleToOperation.get(opHandle)
    val groupId = operation.getStatementId

    // Single-column schema (Schema and FieldSchema come from
    // org.apache.hadoop.hive.metastore.api) so the progress string
    // is returned as one log row.
    val schema = new Schema
    val fieldSchema = new FieldSchema
    fieldSchema.setName("operation_log")
    fieldSchema.setType("string")
    schema.addToFieldSchemas(fieldSchema)
    val tableSchema = new TableSchema(schema)

    val rowSet = RowSetFactory.create(tableSchema, operation.getProtocolVersion)
    rowSet.addRow(Array[AnyRef](SparkHiveLog.getSparkHiveLog().getLogInfo(groupId)))
    rowSet
}
    For getQueryLog to take effect, modify the fetchResults method in org.apache.hive.service.cli.session.HiveSessionImpl so that log fetches are routed to it:

@Override
public RowSet fetchResults(OperationHandle opHandle, FetchOrientation orientation,
    long maxRows, FetchType fetchType) throws HiveSQLException {
  acquire(true);
  try {
    if (fetchType == FetchType.QUERY_OUTPUT) {
      return operationManager.getOperationNextRowSet(opHandle, orientation, maxRows);
    }
    // FetchType.LOG: return the Spark progress string instead of the
    // (empty) local operation log.
    return operationManager.getQueryLog(opHandle);
  } finally {
    release(true);
  }
}
    Also override the closeOperation method in org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager; its job is to clear the task information cached in SparkHiveLog when the statement is closed:

override def closeOperation(opHandle: OperationHandle): Unit = {
    val operation = handleToOperation.get(opHandle)
    val groupId = operation.getStatementId

    // Drop the cached progress info for this statement's job group.
    SparkHiveLog.getSparkHiveLog().clearCaches(groupId)
    operation.close()
    handleToOperation.remove(opHandle)
}

    Finally, register the SparkHiveLog listener in org.apache.spark.sql.hive.thriftserver.HiveThriftServer2:

SparkSQLEnv.sparkContext.addSparkListener(SparkHiveLog.getSparkHiveLog)
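
    The listener can only be registered after SparkSQLEnv.init() has created the SparkContext. A sketch of the patch point in HiveThriftServer2.main (existing startup code abbreviated; the exact placement is an assumption, any point after init works):

// In HiveThriftServer2.main, once the SparkContext exists:
SparkSQLEnv.init()
// ... existing server startup code ...
SparkSQLEnv.sparkContext.addSparkListener(SparkHiveLog.getSparkHiveLog())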

      Recompile the Spark source, replace the spark-hive-thriftserver_2.11-2.x jar, and restart the Thrift server. When SQL is executed through HiveStatement, calling getQueryLog now returns the running-job progress instead of an empty result.
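
      A minimal client-side sketch of that flow, assuming the Hive JDBC driver that matches the patched server (the JDBC URL, credentials, table name, and polling interval are placeholders):

import java.sql.DriverManager

import scala.collection.JavaConverters._

import org.apache.hive.jdbc.HiveStatement

object QueryLogExample {
  def main(args: Array[String]): Unit = {
    // Placeholder connection details; adjust for your Thrift server.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
    val stmt = conn.createStatement().asInstanceOf[HiveStatement]

    // Run the query on a worker thread so the main thread can poll the log.
    val runner = new Thread(new Runnable {
      override def run(): Unit = stmt.execute("select count(*) from some_table")
    })
    runner.start()

    // While the query runs, each getQueryLog call returns the progress
    // string assembled by SparkHiveLog instead of an empty list.
    while (runner.isAlive) {
      for (line <- stmt.getQueryLog().asScala) println(line)
      Thread.sleep(1000)
    }
    runner.join()
    conn.close()
  }
}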

