Spark Source Code: Starting the Master


1 start-master.sh

-- spark/sbin/start-master.sh

CLASS="org.apache.spark.deploy.master.Master"

"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \
  --host $SPARK_MASTER_HOST --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \
  $ORIGINAL_ARGS


2 Invoking the main function

  • In org.apache.spark.deploy.master.Master.scala
  def main(argStrings: Array[String]) {
    Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler(
      exitOnUncaughtException = false))
    Utils.initDaemon(log)
    val conf = new SparkConf
    val args = new MasterArguments(argStrings, conf)
    val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
    rpcEnv.awaitTermination()
  }

  1. val conf = new SparkConf
    Configuration is kept in a ConcurrentHashMap[String, String]; every JVM system property whose key starts with "spark." is copied into that map (see the sketch after this list).

  2. val args = new MasterArguments(argStrings, conf)
    Parses the command-line arguments, loads the defaults, and produces the Master's parameters.

  3. val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
    Creates the RpcEnv and registers the RpcEndpoint (the key part).

  4. rpcEnv.awaitTermination()
    Blocks until the RpcEnv shuts down.
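
The following is a quick, hedged illustration of step 1 (it assumes Spark is on the classpath; the property names are made up for the example): new SparkConf with the default loadDefaults = true copies every JVM system property whose key starts with "spark." into its internal ConcurrentHashMap[String, String] and ignores everything else.

import org.apache.spark.SparkConf

object SparkConfSketch {
  def main(args: Array[String]): Unit = {
    sys.props("spark.master.demo.flag") = "42"    // hypothetical spark.* property
    sys.props("demo.not.spark") = "ignored"       // hypothetical non-spark property
    val conf = new SparkConf()                    // same call as in Master.main
    println(conf.get("spark.master.demo.flag"))   // prints 42
    println(conf.getOption("demo.not.spark"))     // None: only spark.* keys are loaded
  }
}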


3 A closer look at startRpcEnvAndEndpoint

  • In org.apache.spark.deploy.master.Master.scala
  /**
   * Start the Master and return a three tuple of:
   *   (1) The Master RpcEnv
   *   (2) The web UI bound port
   *   (3) The REST server bound port, if any
   */
  def startRpcEnvAndEndpoint(
      host: String,
      port: Int,
      webUiPort: Int,
      conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
    val securityMgr = new SecurityManager(conf)
    val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
    val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
      new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
    val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
    (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
  }
  1. val securityMgr = new SecurityManager(conf)
    Creates the SecurityManager, which manages accounts, permissions, and authentication.

  2. val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
    Creates the RpcEnv.

  3. val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
    Creates the RpcEndpoint and registers it with the RpcEnv, returning an RpcEndpointRef.

  4. val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
    The RpcEndpointRef (masterEndpoint) synchronously sends a BoundPortsRequest message to the corresponding RpcEndpoint's (the Master's) receiveAndReply and waits, with a timeout, for the reply (the relevant Master branch is sketched after this list).

  5. (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
    Returns (the Master RpcEnv, the web UI port, the REST server port if any).
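
Step 4 is answered inside Master.receiveAndReply. The relevant branch looks roughly like the following (paraphrased; check Master.scala in your Spark version for the exact code): the Master replies with its RPC port, the bound web UI port, and the REST server port, if one was started.

  case BoundPortsRequest =>
    context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))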


3.1 Creating the RpcEnv

  • In org.apache.spark.rpc.RpcEnv.scala
private[spark] object RpcEnv {

  def create(
      name: String,
      host: String,
      port: Int,
      conf: SparkConf,
      securityManager: SecurityManager,
      clientMode: Boolean = false): RpcEnv = {
    create(name, host, host, port, conf, securityManager, 0, clientMode)
  }

  def create(
      name: String,
      bindAddress: String,
      advertiseAddress: String,
      port: Int,
      conf: SparkConf,
      securityManager: SecurityManager,
      numUsableCores: Int,
      clientMode: Boolean): RpcEnv = {
    val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
      numUsableCores, clientMode)
    new NettyRpcEnvFactory().create(config)
  }
}

An RpcEnvConfig is constructed, and the NettyRpcEnvFactory factory class is then used to create a NettyRpcEnv.

  • In org.apache.spark.rpc.netty.NettyRpcEnvFactory.scala
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager, config.numUsableCores)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
  1. Create the NettyRpcEnv object;
  2. Read the clientMode flag from the config. If clientMode is false, the RpcEnv is being created on the server side, so Utils.startServiceOnPort() is called to start the service, which in turn invokes the function startNettyRpcEnv: Int => (NettyRpcEnv, Int);
  3. The startNettyRpcEnv function calls NettyRpcEnv.startServer(), which creates the TransportServer;
  4. Return the NettyRpcEnv.

3.1.1 Creating the NettyRpcEnv object

  • In org.apache.spark.rpc.netty.NettyRpcEnv.scala
private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager,
    numUsableCores: Int) extends RpcEnv(conf) with Logging {

  private[netty] val transportConf = SparkTransportConf.fromSparkConf(
    conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
    "rpc",
    conf.getInt("spark.rpc.io.threads", numUsableCores))

  private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)

  private val streamManager = new NettyStreamManager(this)

  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

  // omitted
}
  1. When a NettyRpcEnv is created, it internally creates a Dispatcher, a NettyStreamManager, a TransportContext, and so on;
  2. Creating the TransportContext also creates a NettyRpcHandler, which dispatches incoming RPC requests to the registered endpoints for processing.

3.1.2 Utils.startServiceOnPort()

  • In org.apache.spark.util.Utils.scala
  /**
   * Attempt to start a service on the given port, or fail after a number of attempts.
   * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
   *
   * @param startPort The initial port to start the service on.
   * @param startService Function to start service on a given port.
   *                     This is expected to throw java.net.BindException on port collision.
   * @param conf A SparkConf used to get the maximum number of retries when binding to a port.
   * @param serviceName Name of the service.
   * @return (service: T, port: Int)
   */
  def startServiceOnPort[T](
      startPort: Int,
      startService: Int => (T, Int),
      conf: SparkConf,
      serviceName: String = ""): (T, Int) = {

    require(startPort == 0 || (1024 <= startPort && startPort < 65536),
      "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")

    val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
    val maxRetries = portMaxRetries(conf)
    for (offset <- 0 to maxRetries) {
      // Do not increment port if startPort is 0, which is treated as a special port
      val tryPort = if (startPort == 0) {
        startPort
      } else {
        userPort(startPort, offset)
      }
      try {
        val (service, port) = startService(tryPort)
        logInfo(s"Successfully started service$serviceString on port $port.")
        return (service, port)
      } catch {
        case e: Exception if isBindCollision(e) =>
          if (offset >= maxRetries) {
            val exceptionMessage = if (startPort == 0) {
              s"${e.getMessage}: Service$serviceString failed after " +
                s"$maxRetries retries (on a random free port)! " +
                s"Consider explicitly setting the appropriate binding address for " +
                s"the service$serviceString (for example spark.driver.bindAddress " +
                s"for SparkDriver) to the correct binding address."
            } else {
              s"${e.getMessage}: Service$serviceString failed after " +
                s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " +
                s"the appropriate port for the service$serviceString (for example spark.ui.port " +
                s"for SparkUI) to an available port or increasing spark.port.maxRetries."
            }
            val exception = new BindException(exceptionMessage)
            // restore original stack trace
            exception.setStackTrace(e.getStackTrace)
            throw exception
          }
          if (startPort == 0) {
            // As startPort 0 is for a random free port, it is most possibly binding address is
            // not correct.
            logWarning(s"Service$serviceString could not bind on a random free port. " +
              "You may check whether configuring an appropriate binding address.")
          } else {
            logWarning(s"Service$serviceString could not bind on port $tryPort. " +
              s"Attempting port ${tryPort + 1}.")
          }
      }
    }
    // Should never happen
    throw new SparkException(s"Failed to start service$serviceString on port $startPort")
  }

Arguments passed in:
(1) startPort: the starting port, taken from the RpcEnvConfig built from the SparkConf
(2) startService: the function defined earlier, val startNettyRpcEnv: Int => (NettyRpcEnv, Int)
(3) conf: the SparkConf
(4) serviceName: the service name ("sparkMaster")

Logic:
(1) Validate startPort
(2) Try up to maxRetries + 1 times (offsets 0 to maxRetries); on each attempt compute a tryPort (using its own increment rule), pass it to startService, and try to start the service
(3) If the service starts successfully, return
(4) If it still has not started after maxRetries retries, throw an exception

Inside startService, nettyEnv.startServer(config.bindAddress, actualPort) is called to start the server.
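
To make the retry contract concrete, here is a toy, self-contained sketch (not Spark's Utils; startOnPort and the plain ServerSocket service are invented for illustration). Like startServiceOnPort, it expects the service function to throw BindException on a port collision and then retries on the next port.

import java.net.{BindException, ServerSocket}

object PortRetrySketch {
  // Toy version of the startServiceOnPort contract: startService takes a candidate
  // port and returns (service, boundPort); a BindException triggers a retry on port + 1.
  def startOnPort[T](startPort: Int, maxRetries: Int)(startService: Int => (T, Int)): (T, Int) = {
    for (offset <- 0 to maxRetries) {
      val tryPort = if (startPort == 0) 0 else startPort + offset
      try {
        return startService(tryPort)
      } catch {
        case _: BindException if offset < maxRetries =>
          println(s"Port $tryPort is busy, trying ${tryPort + 1}")
      }
    }
    throw new BindException(s"Could not bind to any port starting from $startPort")
  }

  def main(args: Array[String]): Unit = {
    // A plain ServerSocket stands in for startNettyRpcEnv: Int => (NettyRpcEnv, Int).
    val (socket, port) = startOnPort(8080, maxRetries = 3) { p =>
      val s = new ServerSocket(p)   // throws BindException if the port is already taken
      (s, s.getLocalPort)
    }
    println(s"Service bound on port $port")
    socket.close()
  }
}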

3.1.3 NettyRpcEnv.startServer()

  • In org.apache.spark.rpc.netty.NettyRpcEnv.scala
  def startServer(bindAddress: String, port: Int): Unit = {
    val bootstraps: java.util.List[TransportServerBootstrap] =
      if (securityManager.isAuthenticationEnabled()) {
        java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
      } else {
        java.util.Collections.emptyList()
      }
    server = transportContext.createServer(bindAddress, port, bootstraps)
    dispatcher.registerRpcEndpoint(
      RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
  }

  1. The TransportContext creates the TransportServer;
  2. An RpcEndpointVerifier is registered with the Dispatcher (the registration flow is the same for every RpcEndpoint; see later sections).
  • In org.apache.spark.network.TransportContext.java
  /** Create a server which will attempt to bind to a specific host and port. */
  public TransportServer createServer(
      String host, int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, host, port, rpcHandler, bootstraps);
  }
  • In org.apache.spark.network.server.TransportServer.java
  /**
   * Creates a TransportServer that binds to the given host and the given port, or to any available
   * if 0. If you don't want to bind to any special host, set "hostToBind" to null.
   * */
  public TransportServer(
      TransportContext context,
      String hostToBind,
      int portToBind,
      RpcHandler appRpcHandler,
      List<TransportServerBootstrap> bootstraps) {
    this.context = context;
    this.conf = context.getConf();
    this.appRpcHandler = appRpcHandler;
    this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));

    boolean shouldClose = true;
    try {
      init(hostToBind, portToBind);
      shouldClose = false;
    } finally {
      if (shouldClose) {
        JavaUtils.closeQuietly(this);
      }
    }
  }


  private void init(String hostToBind, int portToBind) {

    IOMode ioMode = IOMode.valueOf(conf.ioMode());
    EventLoopGroup bossGroup =
      NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server");
    EventLoopGroup workerGroup = bossGroup;

    PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
      conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());

    bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)
      .channel(NettyUtils.getServerChannelClass(ioMode))
      .option(ChannelOption.ALLOCATOR, allocator)
      .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
      .childOption(ChannelOption.ALLOCATOR, allocator);

    this.metrics = new NettyMemoryMetrics(
      allocator, conf.getModuleName() + "-server", conf);

    if (conf.backLog() > 0) {
      bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
    }

    if (conf.receiveBuf() > 0) {
      bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
    }

    if (conf.sendBuf() > 0) {
      bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
    }

    bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(SocketChannel ch) {
        logger.debug("New connection accepted for remote address {}.", ch.remoteAddress());

        RpcHandler rpcHandler = appRpcHandler;
        for (TransportServerBootstrap bootstrap : bootstraps) {
          rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
        }
        context.initializePipeline(ch, rpcHandler);
      }
    });

    InetSocketAddress address = hostToBind == null ?
        new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
    channelFuture = bootstrap.bind(address);
    channelFuture.syncUninterruptibly();

    port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
    logger.debug("Shuffle server started on port: {}", port);
  }

During TransportServer initialization:

  1. A ServerBootstrap is set up via the Netty API, and a ChannelInitializer is installed as the ServerBootstrap's child ChannelHandler;
  2. An InetSocketAddress is created and the ServerBootstrap binds to it (a stripped-down sketch of this bootstrap pattern follows the list).
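
A stripped-down sketch of that same bootstrap pattern, using the plain Netty 4 API (NIO mode only, no Spark handlers; everything here is illustrative and not TransportServer itself):

import java.net.InetSocketAddress

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.ChannelInitializer
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel

object NettyServerSketch {
  def main(args: Array[String]): Unit = {
    val bossGroup = new NioEventLoopGroup(1)  // accepts connections
    val workerGroup = bossGroup               // TransportServer reuses the boss group as the worker group

    val bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)
      .channel(classOf[NioServerSocketChannel])
      .childHandler(new ChannelInitializer[SocketChannel] {
        override def initChannel(ch: SocketChannel): Unit = {
          // The real code installs the encoder, frame decoder, idle handler and
          // TransportChannelHandler here (see initializePipeline below).
          println(s"New connection accepted from ${ch.remoteAddress()}")
        }
      })

    val channelFuture = bootstrap.bind(new InetSocketAddress(0)).syncUninterruptibly()
    val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
    println(s"Bound on port $port")

    channelFuture.channel().close().syncUninterruptibly()
    bossGroup.shutdownGracefully()
  }
}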

Initializing the pipeline:

  • In org.apache.spark.network.TransportContext.java
  /**
   * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
   * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
   * response messages.
   *
   * @param channel The channel to initialize.
   * @param channelRpcHandler The RPC handler to use for the channel.
   *
   * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
   * be used to communicate on this channel. The TransportClient is directly associated with a
   * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
   */
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", ENCODER)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", DECODER)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }
  1. Create the TransportChannelHandler;
  2. Add the TransportChannelHandler to the SocketChannel's pipeline.

3.2 Creating and registering the Master (an RpcEndpoint)

  • In org.apache.spark.deploy.master.Master.scala
private[deploy] class Master(
    override val rpcEnv: RpcEnv,
    address: RpcAddress,
    webUiPort: Int,
    val securityMgr: SecurityManager,
    val conf: SparkConf)
  extends ThreadSafeRpcEndpoint with Logging with LeaderElectable {
   
    // omitted

}

Master extends ThreadSafeRpcEndpoint, so it is an RpcEndpoint.

  • In org.apache.spark.rpc.netty.NettyRpcEnv.scala
  override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
    dispatcher.registerRpcEndpoint(name, endpoint)
  }
  • In org.apache.spark.rpc.netty.Dispatcher.scala
  def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
    val addr = RpcEndpointAddress(nettyEnv.address, name)
    val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
    synchronized {
      if (stopped) {
        throw new IllegalStateException("RpcEnv has been stopped")
      }
      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      endpointRefs.put(data.endpoint, data.ref)
      receivers.offer(data)  // for the OnStart message
    }
    endpointRef
  }
  1. Create the corresponding NettyRpcEndpointRef. Its constructor takes three arguments: the SparkConf, an RpcEndpointAddress, and the NettyRpcEnv; here the RpcEndpointAddress is built from the NettyRpcEnv's address plus the endpoint name;
  2. Build an EndpointData(name, endpoint, endpointRef) and store it in endpoints (a ConcurrentMap[String, EndpointData]);
  3. Also put the EndpointData into receivers (a LinkedBlockingQueue[EndpointData]). (The EndpointData has been queued; where is it taken out and processed? See "Message processing inside the Dispatcher" below.)
  4. Return the NettyRpcEndpointRef.

A look at EndpointData

  • In org.apache.spark.rpc.netty.Dispatcher.scala (inner class EndpointData)
  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

Constructing an EndpointData creates an Inbox; now a look at Inbox

  • In org.apache.spark.rpc.netty.Inbox.scala
private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  // OnStart should be the first message to process
  inbox.synchronized {
    messages.add(OnStart)
  }
}

When an Inbox is constructed, it declares a LinkedList[InboxMessage] (messages) and immediately adds OnStart to messages. Because OnStart is enqueued at Inbox construction time, it is guaranteed to be the first message processed.

OnStart has been queued; where is it taken out and processed? See "Message processing inside the Dispatcher" below.
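
Below is a hedged end-to-end sketch of this register-then-OnStart flow. RpcEnv, RpcEndpoint and SecurityManager are private[spark], so the sketch only compiles from inside an org.apache.spark.* package; EchoEndpoint and the "demo"/"echo" names are invented purely for illustration.

package org.apache.spark.rpcdemo  // hypothetical sub-package: the RPC API is private[spark]

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

// After setupEndpoint registers the endpoint, the Dispatcher's MessageLoop delivers
// the queued OnStart first, so onStart() runs before any other message.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def onStart(): Unit = println("EchoEndpoint started")

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg => context.reply(s"echo: $msg")
  }
}

object EchoDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("demo", "localhost", 0, conf, new SecurityManager(conf))
    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv)) // registers and queues OnStart
    println(ref.askSync[String]("hello"))                            // prints "echo: hello"
    rpcEnv.shutdown()
    rpcEnv.awaitTermination()
  }
}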


3.3 Sending a message synchronously

  • In org.apache.spark.rpc.RpcEndpointRef.scala
  /**
   * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
   * default timeout, throw an exception if this fails.
   *
   * Note: this is a blocking action which may cost a lot of time,  so don't call it in a message
   * loop of [[RpcEndpoint]].
   *
   * @param message the message to send
   * @tparam T type of the reply message
   * @return the reply message from the corresponding [[RpcEndpoint]]
   */
  def askSync[T: ClassTag](message: Any): T = askSync(message, defaultAskTimeout)


  /**
   * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
   * specified timeout, throw an exception if this fails.
   *
   * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
   * loop of [[RpcEndpoint]].
   *
   * @param message the message to send
   * @param timeout the timeout duration
   * @tparam T type of the reply message
   * @return the reply message from the corresponding [[RpcEndpoint]]
   */
  def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
    val future = ask[T](message, timeout)
    timeout.awaitResult(future)
  }

  • In org.apache.spark.rpc.netty.NettyRpcEndpointRef.scala
private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    private val endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {

  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
  }

}
  1. Construct a RequestMessage
  2. Call NettyRpcEnv to send the message with a timeout
  • In org.apache.spark.rpc.netty.RequestMessage.scala
private[netty] class RequestMessage(
    val senderAddress: RpcAddress,
    val receiver: NettyRpcEndpointRef,
    val content: Any) {

    // omitted

}

Parameters:
(1) val senderAddress: RpcAddress — the sender's address; the message is sent by a NettyRpcEnv, so the sender address is that NettyRpcEnv's address (an RpcAddress);
(2) val receiver: NettyRpcEndpointRef — the receiver; the message is sent to a NettyRpcEndpointRef, which is then used to find the corresponding RpcEndpoint to process it;
(3) val content: Any — the message content.

Concretely, in new RequestMessage(nettyEnv.address, this, message):
(1) senderAddress is nettyEnv.address, the address of the NettyRpcEnv; when the Master starts, a Master NettyRpcEnv is created on the Master node, so this address is the Master's address;
(2) receiver is this, the NettyRpcEndpointRef on which NettyRpcEnv.ask is being invoked, i.e. masterEndpoint;
(3) content is message, the message body.

  • In org.apache.spark.rpc.netty.NettyRpcEnv.scala
  private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
    val promise = Promise[Any]()
    val remoteAddr = message.receiver.address

    def onFailure(e: Throwable): Unit = {
      if (!promise.tryFailure(e)) {
        e match {
          case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e")
          case _ => logWarning(s"Ignored failure: $e")
        }
      }
    }

    def onSuccess(reply: Any): Unit = reply match {
      case RpcFailure(e) => onFailure(e)
      case rpcReply =>
        if (!promise.trySuccess(rpcReply)) {
          logWarning(s"Ignored message: $reply")
        }
    }

    try {
      if (remoteAddr == address) {
        val p = Promise[Any]()
        p.future.onComplete {
          case Success(response) => onSuccess(response)
          case Failure(e) => onFailure(e)
        }(ThreadUtils.sameThread)
        dispatcher.postLocalMessage(message, p)
      } else {
        val rpcMessage = RpcOutboxMessage(message.serialize(this),
          onFailure,
          (client, response) => onSuccess(deserialize[Any](client, response)))
        postToOutbox(message.receiver, rpcMessage)
        promise.future.failed.foreach {
          case _: TimeoutException => rpcMessage.onTimeout()
          case _ =>
        }(ThreadUtils.sameThread)
      }

      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
        override def run(): Unit = {
          onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
            s"in ${timeout.duration}"))
        }
      }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
      promise.future.onComplete { v =>
        timeoutCancelable.cancel(true)
      }(ThreadUtils.sameThread)
    } catch {
      case NonFatal(e) =>
        onFailure(e)
    }
    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }
  1. Take receiver.address from the RequestMessage as remoteAddr, the receiver's address;
  2. Compare the receiver's address with the address of the NettyRpcEnv that is sending the message;
  3. If they are the same, the RpcEndpoint that handles the message is registered in the current NettyRpcEnv (a locally created RpcEndpointRef refers to an RpcEndpoint registered in the same RpcEnv). A new Promise is created, a completion callback is attached to its future, and the message is posted locally via the NettyRpcEnv's internal Dispatcher.postLocalMessage;
  4. If they differ, the RpcEndpoint is registered in some other NettyRpcEnv; an RpcOutboxMessage is built and postToOutbox delivers the message to an Outbox;
  5. The NettyRpcEnv maintains an internal timeoutScheduler, which schedules a task on another thread that fails the promise with a TimeoutException if no reply arrives in time;
  6. If the message is handled within the timeout, the scheduled timeout task is cancelled;
  7. Return the resulting future.

Here, the receiver's address and the address of the sending NettyRpcEnv are the same, so the message is posted locally.
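
The promise-plus-timeout bookkeeping in steps 5 and 6 can be illustrated with a small standalone sketch (plain Scala futures, nothing Spark-specific; askWithTimeout and the fake reply are invented for illustration): the reply path completes the promise, a scheduled task fails it with a TimeoutException, and whichever side wins first cancels the other.

import java.util.concurrent.{Executors, TimeUnit, TimeoutException}

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.ExecutionContext.Implicits.global

object AskTimeoutSketch {
  def askWithTimeout[T](send: Promise[T] => Unit, timeoutMs: Long): Future[T] = {
    val promise = Promise[T]()
    val scheduler = Executors.newSingleThreadScheduledExecutor()

    // Like NettyRpcEnv's timeoutScheduler: fail the promise if no reply arrives in time.
    val timeoutTask = scheduler.schedule(new Runnable {
      override def run(): Unit =
        promise.tryFailure(new TimeoutException(s"No reply within $timeoutMs ms"))
    }, timeoutMs, TimeUnit.MILLISECONDS)

    send(promise)  // "post" the message; the reply completes the promise

    // Once the promise completes (reply or failure), cancel the pending timeout task.
    promise.future.onComplete { _ =>
      timeoutTask.cancel(true)
      scheduler.shutdown()
    }
    promise.future
  }

  def main(args: Array[String]): Unit = {
    // The fake "endpoint" replies after 100 ms, well inside the 1000 ms timeout.
    val reply = askWithTimeout[String](p => Future { Thread.sleep(100); p.trySuccess("pong") }, 1000L)
    println(Await.result(reply, Duration.Inf))  // prints "pong"
  }
}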

3.3.1 Posting a message locally

  • In org.apache.spark.rpc.netty.Dispatcher.scala
  /** Posts a message sent by a local endpoint. */
  def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
    val rpcCallContext =
      new LocalNettyRpcCallContext(message.senderAddress, p)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
  }
  1. Create a local RPC call context: LocalNettyRpcCallContext;
  2. Build an RpcMessage;
  3. Post the message.
  • In org.apache.spark.rpc.netty.Dispatcher.scala
  /**
   * Posts a message to a specific endpoint.
   *
   * @param endpointName name of the endpoint.
   * @param message the message to post
   * @param callbackIfStopped callback function if the endpoint is stopped.
   */
  private def postMessage(
      endpointName: String,
      message: InboxMessage,
      callbackIfStopped: (Exception) => Unit): Unit = {
    val error = synchronized {
      val data = endpoints.get(endpointName)
      if (stopped) {
        Some(new RpcEnvStoppedException())
      } else if (data == null) {
        Some(new SparkException(s"Could not find $endpointName."))
      } else {
        data.inbox.post(message)
        receivers.offer(data)
        None
      }
    }
    // We don't need to call `onStop` in the `synchronized` block
    error.foreach(callbackIfStopped)
  }
  1. Look up the EndpointData for endpointName in the Dispatcher's endpoints map (it was registered there earlier when RpcEnv.setupEndpoint was called);
  2. Add the message to that EndpointData's Inbox;
  3. Put the EndpointData into receivers, where Dispatcher.MessageLoop will pick it up.

Once again the message has been placed on a queue; where is it taken out and processed? See "Message processing inside the Dispatcher" below.


3.4 Message processing inside the Dispatcher

In the process above, creating an EndpointData also creates an Inbox, creating the Inbox enqueues OnStart into the Inbox's internal messages queue, and the finished EndpointData is placed into the Dispatcher's internal receivers queue. So where are these two queues actually drained?

The process is as follows:

  1. When the Dispatcher is constructed, it declares a thread pool
  • In org.apache.spark.rpc.netty.Dispatcher.scala
  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, availableCores))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }
  2. MessageLoop implements Runnable and keeps looping, taking entries off the Dispatcher's receivers queue (the EndpointData added to receivers above) and calling that EndpointData's Inbox.process method to handle the messages inside its Inbox (a Spark-free sketch of this poison-pill consumer pattern follows the code below)
  • In org.apache.spark.rpc.netty.Dispatcher.scala
  /** Message loop used for dispatching messages. */
  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case _: InterruptedException => // exit
        case t: Throwable =>
          try {
            // Re-submit a MessageLoop so that Dispatcher will still work if
            // UncaughtExceptionHandler decides to not kill JVM.
            threadpool.execute(new MessageLoop)
          } finally {
            throw t
          }
      }
    }
  }
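
A Spark-free sketch of this consumer pattern (all names invented): worker threads drain a shared LinkedBlockingQueue, and a single poison-pill sentinel, re-offered by whichever worker sees it, eventually stops every loop.

import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object MessageLoopSketch {
  sealed trait Work
  case class Task(id: Int) extends Work
  case object PoisonPill extends Work

  def main(args: Array[String]): Unit = {
    val receivers = new LinkedBlockingQueue[Work]()
    val numThreads = 2
    val pool = Executors.newFixedThreadPool(numThreads)

    for (_ <- 0 until numThreads) pool.execute(new Runnable {
      override def run(): Unit = {
        while (true) {
          receivers.take() match {
            case PoisonPill =>
              receivers.offer(PoisonPill)  // put it back so the other loops also see it
              return
            case Task(id) =>
              println(s"${Thread.currentThread().getName} processed task $id")
          }
        }
      }
    })

    (1 to 5).foreach(i => receivers.offer(Task(i)))
    receivers.offer(PoisonPill)            // like Dispatcher.stop(): one pill stops all loops
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}
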
  3. Messages are taken off the Inbox's internal messages queue and processed (for example, the OnStart enqueued when the Inbox was created above)
  • In org.apache.spark.rpc.netty.Inbox.scala
  /**
   * Process stored messages.
   */
  def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        return
      }
      message = messages.poll()
      if (message != null) {
        numActiveThreads += 1
      } else {
        return
      }
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch {
              case e: Throwable =>
                context.sendFailure(e)
                // Throw the exception -- this exception will be caught by the safelyCall function.
                // The endpoint's onError function will be called.
                throw e
            }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,
              s"There should be only a single active thread but found $activeThreads threads.")
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            assert(isEmpty, "OnStop should be the last message")

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
        // every time.
        if (!enableConcurrent && numActiveThreads != 1) {
          // If we are not the only one worker, exit
          numActiveThreads -= 1
          return
        }
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

Looking at the OnStart case: it calls endpoint.onStart().

This means that as soon as an RpcEndpoint is registered with an RpcEnv, an OnStart message is placed on the internal queue of that endpoint's Inbox in the Dispatcher; a background MessageLoop thread then takes the OnStart out and invokes the newly registered RpcEndpoint's onStart() method.

So, in this walkthrough, Master.onStart() is called:

  • In org.apache.spark.deploy.master.Master.scala
  override def onStart(): Unit = {
    logInfo("Starting Spark master at " + masterUrl)
    logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
    webUi = new MasterWebUI(this, webUiPort)
    webUi.bind()
    masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
    if (reverseProxy) {
      masterWebUiUrl = conf.get("spark.ui.reverseProxyUrl", masterWebUiUrl)
      webUi.addProxy()
      logInfo(s"Spark Master is acting as a reverse proxy. Master, Workers and " +
       s"Applications UIs are available at $masterWebUiUrl")
    }
    checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError {
        self.send(CheckForWorkerTimeOut)
      }
    }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)

    if (restServerEnabled) {
      val port = conf.getInt("spark.master.rest.port", 6066)
      restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl))
    }
    restServerBoundPort = restServer.map(_.start())

    masterMetricsSystem.registerSource(masterSource)
    masterMetricsSystem.start()
    applicationMetricsSystem.start()
    // Attach the master and app metrics servlet handler to the web ui after the metrics systems are
    // started.
    masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
    applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)

    val serializer = new JavaSerializer(conf)
    val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
      case "ZOOKEEPER" =>
        logInfo("Persisting recovery state to ZooKeeper")
        val zkFactory =
          new ZooKeeperRecoveryModeFactory(conf, serializer)
        (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
      case "FILESYSTEM" =>
        val fsFactory =
          new FileSystemRecoveryModeFactory(conf, serializer)
        (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
      case "CUSTOM" =>
        val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
        val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
          .newInstance(conf, serializer)
          .asInstanceOf[StandaloneRecoveryModeFactory]
        (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
      case _ =>
        (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
    }
    persistenceEngine = persistenceEngine_
    leaderElectionAgent = leaderElectionAgent_
  }

4 Summary

  1. When the Master starts, it creates a NettyRpcEnv;
    1.1 Creating the NettyRpcEnv creates a Dispatcher inside it;
    1.2 Creating the Dispatcher creates a ConcurrentMap[String, EndpointData] and a LinkedBlockingQueue[EndpointData] inside it;

  2. The RpcEndpoint (Master) is created;

  3. The RpcEndpoint (Master) is registered with the Dispatcher;
    3.1 Registration first creates the NettyRpcEndpointRef (masterEndpoint),
    3.2 then builds an EndpointData(name, endpoint, endpointRef); building the EndpointData creates an Inbox, which creates a LinkedList[InboxMessage] and immediately enqueues an OnStart message,
    3.3 then (name, EndpointData) is put into the ConcurrentMap[String, EndpointData],
    3.4 and the EndpointData is put into the LinkedBlockingQueue[EndpointData];

  4. The corresponding NettyRpcEndpointRef (masterEndpoint) is returned;

  5. The NettyRpcEndpointRef (masterEndpoint) sends a message;

  6. A RequestMessage(senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) is built and the NettyRpcEnv is asked to send it;

  7. When sending, the NettyRpcEnv checks whether the receiver's address matches the current NettyRpcEnv's address:
    7.1 if they match, the Dispatcher posts the message locally,
    7.2 if not, the message is delivered to the remote NettyRpcEnv through an Outbox;

  8. When posting locally, the endpointName is taken from the receiver and used to look up the EndpointData in the Dispatcher's ConcurrentMap,
    8.1 the message is added to the EndpointData's Inbox (its LinkedList[InboxMessage]),
    8.2 and the EndpointData is added to the Dispatcher's LinkedBlockingQueue[EndpointData];

  9. Dispatcher.MessageLoop keeps taking entries off these two queues and processing them.
