spark源码阅读——rpc部分

rpc可以说是一个分布式系统最基础的组件了。这里解析一下spark的内部rpc框架。

RpcEndpoint

RpcEndpoint 这个接口表示一个Rpc端点,只要继承了这个trait
就具备了收发Rpc消息的能力,主要包含以下方法

  • 接收信息类

    • def receive: PartialFunction[Any, Unit] 一个偏函数,用来接受其他RpcEndpoint发来的信息,其他类可以覆盖这个方法来重写接受信息的逻辑

    • def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] 方法和上面那个差不多,不过这个处理过逻辑之后可以返回一些信息

  • 回调类

    • def onConnected(remoteAddress: RpcAddress): Unit 当有远程主机连接到这个RpcEndpoint时的回调
    • onStart,onStop,onDisconnected等回调

RpcEndpointRef

RpcEndpointRef表示了一个远程RpcEndpoint和当前端点的一个连接,如果想发送RPC消息给其他主机,可以先通过远程地址RpcAddress(一个表示远程端点的case class)获取RpcEndpointRef对象。通过这个对象发送RPC消息给远程节点。主要包括以下方法

  • 异步发送请求 def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
    这个方法发送任意的消息给远程端点,并返回一个Future对象。当远端返回信息的时候可以从这个对象获取结果。

  • 同步发送请求 def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T 等待直到返回结果

  • 只发送信息 def send(message: Any): Unit

RpcEnv

这个接口可以说非常重要了,保存了所有的远程端点信息,而且负责RPC消息的分发。每一个RpcEndpoint都有一个RpcEnv对象。如果想要与其他RpcEndpoint连接并收发信息,需要向远端RpcEndpoint注册自己,远端RpcEndpoint收到注册信息之后,会将请求连接的信息保存在RpcEnv对象中,这样就算是两个RpcEndpoint彼此连接上了(可以双向收发信息了)

  • Endpoint的注册方法

    • def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
      用来一个Endpoint把自己注册到本地的RpcEnv中。一个进程可能有多个Endpoint 比如说一个接收心跳信息的,还有一个用来监听Job的运行状态的,用来监听Executor返回信息的等等。
      RpcEndpoint通过RpcEnv发送信息给RpcEndpointRef
      RpcEnv内部将接收到的信息分发给注册在RpcEnv中的RpcEndpoint

    • def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] 异步注册

    • def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef 同步注册

  • 生命周期方法

    • stop
    • shutdown
    • awaitTermination

RpcCallContext

下面分析时会说,先贴出方法

private[spark] trait RpcCallContext {

  /**
   * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
   * will be called.
   */
  def reply(response: Any): Unit

  /**
   * Report a failure to the sender.
   */
  def sendFailure(e: Throwable): Unit

  /**
   * The sender of this message.
   */
  def senderAddress: RpcAddress
}

spark 中使用了Netty实现了这些Rpc接口,下面看一看使用netty的实现。

NettyRpcEnvFactory

private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}

用来创建NettyRpcEnv对象一个工厂,创建了一个NettyRpcEnv对象。
并启动了一个Netty服务器(nettyEnv.startServer方法)

NettyRpcEnv

这个对象主要包含了一个Dispatcher

private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {

  ...
  private val dispatcher: Dispatcher = new Dispatcher(this)
  ...
  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))
  ...
  @volatile private var server: TransportServer = _
  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
  ... 

  def startServer(bindAddress: String, port: Int): Unit = {
        .....
        server = transportContext.createServer(bindAddress, port, bootstraps)
        dispatcher.registerRpcEndpoint(
        RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
  }
}

上面说到调用了startServer方法
而这个方法内部则向dispatcher对象注册了一个RpcEndpointVerifier,这个对象其实也是一个RpcEndpoint

private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
  extends RpcEndpoint {

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
  }
}

private[netty] object RpcEndpointVerifier {
  val NAME = "endpoint-verifier"

  /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */
  case class CheckExistence(name: String)
}

这里便是我们遇到的第一个RpcEndpoint 如果收到了CheckExistence这个类型的信息则调用dispatcherverify方法。

我们先看一下这个dispatcher对象。

Dispatcher

这个对象的职责便是将收到的Rpc信息分发给不同的Endpoint,可以看到内部有一个ConcurrentHashMap用来保存所有注册的RpcEndpoint

private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]

  private val receivers = new LinkedBlockingQueue[EndpointData]
  ....

}

上面说到的registerRpcEndpoint方法实际上将RpcEndpointVerifier放入了这两个容器中。
RpcEndpointVerifier则被其他Endpoint用来判断自己是否被成功注册到这个RpcEnv中。
远程Endpoint发送一个包含自己名字的信息给这个RpcEnv中的这个RpcEndpointVerifier随后会检查保存Endpoint信息的容器中是否包含注册信息,并将结果返回

NettyRpcEndpointRef

前面说过RpcEndpointRef代表远端的Endpoint,可以用来发送RPC信息


private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    private val endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
    ...

    override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
        nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
    }
}

让我们回到RpcEnv.ask方法

private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
    val promise = Promise[Any]()
    val remoteAddr = message.receiver.address

    def onFailure(e: Throwable): Unit = { ... }
    def onSuccess(reply: Any): Unit = reply match { ... }

    try {
      if (remoteAddr == address) {
        val p = Promise[Any]()
        p.future.onComplete {
          case Success(response) => onSuccess(response)
          case Failure(e) => onFailure(e)
        }(ThreadUtils.sameThread)
        dispatcher.postLocalMessage(message, p)
      } else {
        val rpcMessage = RpcOutboxMessage(message.serialize(this),
          onFailure,
          (client, response) => onSuccess(deserialize[Any](client, response)))
        postToOutbox(message.receiver, rpcMessage)
        promise.future.onFailure {
          case _: TimeoutException => rpcMessage.onTimeout()
          case _ =>
        }(ThreadUtils.sameThread)
      }

      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
        override def run(): Unit = {
          onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
            s"in ${timeout.duration}"))
        }
      }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
      promise.future.onComplete { v =>
        timeoutCancelable.cancel(true)
      }(ThreadUtils.sameThread)
    } catch { ... }
    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }

这个方法由3部分构成
第一部分:判断消息是否是发给本地注册的RpcEndpoint的,是则发送本地信息
第二部分:如果是发给远程Endpoint的,放到OutBox里面,等待处理
第三部分:超时处理,起了一个定时任务,如果超时则报异常。同时给声明的Promise对象增加了一个回调,当rpc调用在超时前完成则取消之前起的定时任务。

我们首先看dispatcher.postLocalMessage,这个方法封装了调用信息,

def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
    val rpcCallContext =
      new LocalNettyRpcCallContext(message.senderAddress, p)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
  }

实际上走了dispatcher.postMessage方法,实际做了3件事:

1.获取到EndpointData对象
2.往这个对象的inbox对象发信息
3.将EndpointData对象放入 receivers队列中

       
private def postMessage(
      endpointName: String,
      message: InboxMessage,
      callbackIfStopped: (Exception) => Unit): Unit ={
       ...
      val data = endpoints.get(endpointName)
      data.inbox.post(message)
      receivers.offer(data)
       ...
}

inbox对象实际就保存了发往Endpoint对象的信息。发到这里其实Endpoint 已经收到信息了。 但是post方法只是将消息放到队列里面,那么实际是怎么发送给Endpoint的呢?

private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()
  ...
 
  def post(message: InboxMessage): Unit = inbox.synchronized {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      messages.add(message)
      false
    }
  ...
  }

Dispatcher对象里面有一个线程池,每个线程会不断的从receivers队列中获取EndpointData并处理其中的inbox对象保存的信息

private val threadpool: ThreadPoolExecutor = {
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, Runtime.getRuntime.availableProcessors()))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

我们再回到inbox.process方法

def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      ... 
      message = messages.poll()
      ...
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch { ... }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            ...
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            ...

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        ... 
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

可以看到这个方法不停的从messages队列中获取对象直到队列里面没有信息
之前发送给本地的Endpoint的消息是InboxMessage这个对应的模式匹配中的哪个对象呢?

private[netty] sealed trait InboxMessage

private[netty] case class OneWayMessage(
    senderAddress: RpcAddress,
    content: Any) extends InboxMessage

private[netty] case class RpcMessage(
    senderAddress: RpcAddress,
    content: Any,
    context: NettyRpcCallContext) extends InboxMessage

private[netty] case object OnStart extends InboxMessage

private[netty] case object OnStop extends InboxMessage

之前发送的本地消息是RpcMessage类型的,InboxEndpoint是一一对应的,所以会直接调用endpoint.receiveAndReply方法进行相应的处理,也就是说这时候消息已经发送到Endpoint了。(可以参考RpcEndpointVerifier.receiveAndReply,这是其中一种RpcEndpoint,在这个流程中可以理解为,本地的RpcEndpoint向本地的RpcEnv确认是否成功注册)

那么我们看一下发送消息给远程的RpcEndpoint消息被封装成RpcOutboxMessage,并调用了postToOutbox方法

private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
    if (receiver.client != null) {
      message.sendWith(receiver.client)
    } else {
      ...
      val targetOutbox = {
        val outbox = outboxes.get(receiver.address)
        ...
      }
      if (stopped.get) { ... } else {
        targetOutbox.send(message)
      }
    }
  }

private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
outbox => // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  private val messages = new java.util.LinkedList[OutboxMessage]

  @GuardedBy("this")
  private var client: TransportClient = null

  @GuardedBy("this")
  private var connectFuture: java.util.concurrent.Future[Unit] = null

  def send(message: OutboxMessage): Unit = {
    val dropped = synchronized {
      if (stopped) { ... } else {
        messages.add(message)
        false
      }
    }
    if (dropped) { ... } else {
      drainOutbox()
    }
  }
 

每个Outbox里面包含

  • 一个保存消息的队列
  • 一个TransportClient 连接远程的RpcEndpoint并用来发送信息

drainOutbox方法实际做了2件事

  1. 检查是否和远端的 RpcEndpoint建立了连接,没有则起一个线程建立连接
  2. 遍历队列,发送信息给远端的RpcEnvTransportServer这个信息会被远端的 NettyRpcHandler处理
private[netty] class NettyRpcHandler(
    dispatcher: Dispatcher,
    nettyEnv: NettyRpcEnv,
    streamManager: StreamManager) extends RpcHandler with Logging {

  // A variable to track the remote RpcEnv addresses of all clients
  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postRemoteMessage(messageToDispatch, callback)
  }
}
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
    val rpcCallContext =
      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
  }

于是我们又看到了postMesage这个方法,而这次是调用的远端的RpcEnvDispatcherpostMessage,消息最后也会被发送给注册到远端的RpcEnv中的RpcEndpoint,这样远端的RpcEndpoint便收到了来自本地的信息。完成了RPC通信。

你可能感兴趣的:(spark源码阅读——rpc部分)