Apache Spark RPC (Part 1)

(Figure: SparkNetty.png)

0. RpcEnv

RpcEnv is the core of the whole communication layer. It builds the communication environment and starts the server; it hosts RpcEndpoints, and every RpcEndpoint (each providing some kind of service) must be registered with the RpcEnv; and it routes messages, so all communication between RpcEndpoints goes through the RpcEnv. It hides the difference between remote RPC calls and local calls, letting the upper layers focus on endpoint design while the transport details stay encapsulated inside RpcEnv. Currently the only implementation is NettyRpcEnv, which uses Netty as the RPC foundation.
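To make the endpoint/ref relationship concrete, here is a minimal, hypothetical example of registering an endpoint and calling it through its ref. PingEndpoint, PingDemo, the Ping message and the "ping"/"demo" names are all made up for illustration; the sketch also assumes it is compiled somewhere under the org.apache.spark package with spark-core on the classpath, since the rpc package is private[spark], and that the RpcEnv.create(name, host, port, conf, securityManager) overload is available.

package org.apache.spark.rpc.demo // hypothetical package, needed to reach private[spark] APIs

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

case object Ping

// A toy endpoint that answers every Ping with the string "pong".
class PingEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping => context.reply("pong")
  }
}

object PingDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Server-mode RpcEnv: starts a TransportServer on an ephemeral port.
    val env = RpcEnv.create("demo", "localhost", 0, conf, new SecurityManager(conf))
    // Register the endpoint under a name; a remote peer could reach it via
    // spark://ping@{host}:{port}.
    val ref = env.setupEndpoint("ping", new PingEndpoint(env))
    println(ref.askSync[String](Ping)) // prints "pong"
    env.shutdown()
    env.awaitTermination()
  }
}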

NettyRpcEnv

[spark-core] org.apache.spark.rpc.netty.NettyRpcEnv

class NettyRpcEnv extends RpcEnv with Logging {
  val role // driver or executor
  val transportConf: TransportConf // spark.rpc.* settings
  val dispatcher: Dispatcher // routes inbound messages to the registered endpoints
  val streamManager: NettyStreamManager // serves files/jars through the RPC environment
  val transportContext: TransportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))
  val clientFactory: TransportClientFactory // creates TransportClients for outbound connections
  // A separate client factory for file downloads. This avoids using the same RPC handler as
  // the main RPC context, so that events caused by these clients are kept isolated from the main RPC traffic.
  var fileDownloadFactory: TransportClientFactory
  val timeoutScheduler: ScheduledExecutorService
  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
  // to implement non-blocking send/ask.
  val clientConnectionExecutor: ThreadPoolExecutor
  var server: TransportServer
  val outboxes: ConcurrentHashMap[RpcAddress, Outbox] // one Outbox per remote RpcAddress
  lazy val address: RpcAddress
}

1. RpcEndpointRef

An RpcEndpointRef is the initiating side of a communication. Each ref is bound to a specific RpcEndpoint, and requests are made through the ref. A ref is expressed as a URI:

Remote: spark://{endpointName}@{ip}:{port}

Client: spark-client://{endpointName}

The dispatcher routes each message by endpoint name, handing it to the corresponding endpoint for processing.
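The remote form maps straightforwardly onto (endpoint name, host, port). The sketch below is not Spark's own parser (Spark does this in RpcEndpointAddress); it only illustrates the mapping, using java.net.URI and a made-up helper name:

import java.net.URI

// Minimal sketch: split a remote ref URI into (endpointName, host, port).
def parseSparkUrl(sparkUrl: String): (String, String, Int) = {
  val uri = new URI(sparkUrl) // e.g. spark://CoarseGrainedScheduler@10.0.0.1:7077
  require(uri.getScheme == "spark", s"unexpected scheme in $sparkUrl")
  (uri.getUserInfo, uri.getHost, uri.getPort)
}

// parseSparkUrl("spark://CoarseGrainedScheduler@10.0.0.1:7077")
// => ("CoarseGrainedScheduler", "10.0.0.1", 7077)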

2. Establishing the Client and Communicating

A NettyRpcEnv is created on both the driver and the executors. We will walk through the source code by following a single request.

Specifically, we look at how an executor talks to the DriverEndpoint to fetch the SparkAppConfig. In this exchange the TransportServer created on the driver side acts as the server, and the executor acts as the client that requests the configuration.

The DriverEndpoint is created while the SparkContext is initialized; concretely, it is constructed as a field of CoarseGrainedSchedulerBackend.

[spark-core] org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend

class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
  extends ExecutorAllocationClient with SchedulerBackend with Logging {
    ...
    // register the DriverEndpoint with the RpcEnv under ENDPOINT_NAME ("CoarseGrainedScheduler")
    val driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint())
    ...
  }

Initiating the Request

CoarseGrainedExecutorBackend is the main class a Spark executor is launched with. Right after startup it needs to fetch the SparkAppConfig from the driver.

[spark-core] org.apache.spark.executor.CoarseGrainedExecutorBackend

object CoarseGrainedExecutorBackend extends Logging {
  def main(args: Array[String]): Unit = {
    // Anonymous function that creates the backend instance
    val createFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) =>
      CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) =>
      new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId,
        arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath, env,
        arguments.resourcesFileOpt, resourceProfile)
    }
    run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn)
    System.exit(0)
  }

  def run(
      arguments: Arguments,
      backendCreateFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) =>
        CoarseGrainedExecutorBackend): Unit = {
    ...
    // A temporary NettyRpcEnv, used only to obtain the driver's RpcEndpointRef
    val fetcher = RpcEnv.create(
      "driverPropsFetcher",
      arguments.bindAddress,
      arguments.hostname,
      -1,
      executorConf,
      new SecurityManager(executorConf),
      numUsableCores = 0,
      clientMode = true)
    ...
    // Build an RpcEndpointRef pointing at the driver:
    // spark://CoarseGrainedScheduler@{ip}:{port}
    driver = fetcher.setupEndpointRefByURI(arguments.driverUrl)
    ...
    // Perform an RPC call through the ref to fetch the SparkAppConfig
    val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig(arguments.resourceProfileId))
    val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", arguments.appId))
    fetcher.shutdown()
    ...
  }
}

driver.askSync is a synchronous request: it blocks until the result comes back. Requests made through a ref are ultimately delegated to NettyRpcEnv.

private[netty] def askAbortable[T: ClassTag](
    message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
  ...
  // We are on the executor here:
  // remoteAddr is the driver's address (non-null), while address is null
  if (remoteAddr == address) {
    val p = Promise[Any]()
    p.future.onComplete {
      case Success(response) => onSuccess(response)
      case Failure(e) => onFailure(e)
    }(ThreadUtils.sameThread)
    dispatcher.postLocalMessage(message, p)
  } else {
    // Note how the message gets wrapped: different wrappers are used at different layers.
    // Here it becomes an RpcOutboxMessage.
    val rpcMessage = RpcOutboxMessage(message.serialize(this),
      onFailure,
      // Callback that handles the response; it is invoked from a Netty channel handler
      (client, response) => onSuccess(deserialize[Any](client, response)))
    rpcMsg = Option(rpcMessage)
    // The key step: enqueue the message into the Outbox
    postToOutbox(message.receiver, rpcMessage)
    ...
  }
}
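Stripped of the Spark types, the ask/askSync mechanics boil down to a promise that the response callback completes and that the caller blocks on. The following self-contained sketch is not Spark code; all names (AskPatternSketch, fakeSend, and so on) are made up purely to show the pattern:

import java.util.concurrent.Executors
import scala.concurrent.duration._
import scala.concurrent.{Await, Future, Promise}

object AskPatternSketch {
  // Pretend transport: "sends" the request and later invokes the callback
  // from another thread, the way a Netty worker would.
  def fakeSend(request: String)(onSuccess: String => Unit): Unit = {
    val pool = Executors.newSingleThreadExecutor()
    pool.submit(new Runnable {
      override def run(): Unit = { onSuccess(s"reply-to:$request"); pool.shutdown() }
    })
  }

  def ask(request: String): Future[String] = {
    val promise = Promise[String]()
    // The response callback completes the promise (NettyRpcEnv does this in onSuccess).
    fakeSend(request)(reply => promise.trySuccess(reply))
    promise.future
  }

  // askSync = ask + a blocking wait, just like RpcEndpointRef.askSync / RpcTimeout.awaitResult.
  def askSync(request: String): String = Await.result(ask(request), 10.seconds)

  def main(args: Array[String]): Unit = {
    println(askSync("RetrieveSparkAppConfig")) // prints "reply-to:RetrieveSparkAppConfig"
  }
}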

Sending the Request Message

An Outbox is a per-address message queue: every remote RpcAddress gets one, all messages destined for that address are queued in it, and they wait there for a TransportClient to deliver them to the corresponding server.

[spark-core] org.apache.spark.rpc.netty.Outbox

class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
  val messages = new java.util.LinkedList[OutboxMessage]
  var client: TransportClient = null
  ...
  // This method triggers the actual sending of messages.
  // It may be called from multiple threads, but only one of them does the real sending.
  private def drainOutbox(): Unit = {
    var message: OutboxMessage = null
    synchronized {
      if (stopped) {
        return
      }
      // Another thread is already establishing a connection; nothing to do here.
      if (connectFuture != null) {
        // We are connecting to the remote address, so just exit
        return
      }
      // No connection yet: submit a background task that creates the TransportClient.
      // launchConnectTask calls drainOutbox again once the client is ready, so the
      // current thread can simply return and let the connecting thread carry on.
      if (client == null) {
        // There is no connect task but client is null, so we need to launch the connect task.
        launchConnectTask()
        return
      }
      if (draining) {
        // There is some thread draining, so just exit
        return
      }
      message = messages.poll()
      if (message == null) {
        return
      }
      draining = true
    }
    // Keep polling messages until the queue is drained.
    while (true) {
      try {
        val _client = synchronized { client }
        if (_client != null) {
          message.sendWith(_client)
        } else {
          assert(stopped)
        }
      } catch {
        case NonFatal(e) =>
          handleNetworkFailure(e)
          return
      }
      synchronized {
        if (stopped) {
          return
        }
        message = messages.poll()
        if (message == null) {
          draining = false
          return
        }
      }
    }
  }
}
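For completeness, messages enter the queue through Outbox.send, which looks roughly like the paraphrased sketch below (abbreviated from memory; see the actual source for the exact details): enqueue under the lock, fail fast if the Outbox has already been stopped, otherwise try to drain.

// Paraphrased sketch of Outbox.send, not a verbatim copy of the Spark source.
def send(message: OutboxMessage): Unit = {
  val dropped = synchronized {
    if (stopped) true
    else { messages.add(message); false }
  }
  if (dropped) {
    message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
  } else {
    drainOutbox()
  }
}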

TransportClient is a wrapper around a Netty Channel, so once message.sendWith(_client) is called we are in Netty territory for actually writing the message out.

TransportClients are created by a TransportClientFactory, which keeps track of all TransportClients in the process and maintains a connection pool per RpcAddress.
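The pooling idea (controlled by spark.rpc.numConnectionsPerPeer) is simple: each remote address owns a small fixed-size array of clients, and a caller picks a random slot, creating the client lazily on first use. Before looking at the real factory, here is a self-contained toy version of that slot-selection logic; ToyClientPool and its names are invented for illustration and are not Spark's implementation:

import java.util.concurrent.{ConcurrentHashMap, ThreadLocalRandom}

// Toy per-peer pool: only illustrates the random-slot idea, not Spark's code.
class ToyClientPool[C](numConnectionsPerPeer: Int, connect: String => C) {
  private val pools = new ConcurrentHashMap[String, Array[Option[C]]]()

  def get(address: String): C = {
    val pool = pools.computeIfAbsent(address, _ => Array.fill[Option[C]](numConnectionsPerPeer)(None))
    val idx = ThreadLocalRandom.current().nextInt(numConnectionsPerPeer)
    pool.synchronized {
      pool(idx) match {
        case Some(client) => client       // reuse the cached connection
        case None =>
          val client = connect(address)   // create lazily on first use
          pool(idx) = Some(client)
          client
      }
    }
  }
}

// val pool = new ToyClientPool[String](2, addr => s"client-to-$addr")
// pool.get("10.0.0.1:7077")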

[common/network-common] org.apache.spark.network.client.TransportClientFactory

public class TransportClientFactory implements Closeable {
  // The public entry point: reuse a cached connection if there is one, otherwise create one.
  public TransportClient createClient(String remoteHost, int remotePort) {
    ...
    // Create the ClientPool if we don't have it yet.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }
    ... // omitted: if the randomly chosen slot already holds a live client, it is returned directly
    // No client at the chosen slot, so build a new one.
    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];
      ... // double check
      // create a new client
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }

  // Build the TransportClient, i.e. the Netty client.
  private TransportClient createClient(InetSocketAddress address) {
    // The familiar Netty style
    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);
    bootstrap.handler(new ChannelInitializer() {
      @Override
      public void initChannel(SocketChannel ch) {
        // The processing logic is wrapped in a TransportChannelHandler.
        // context is the TransportContext, which configures the channel handlers for both
        // the server and the client side. All we need to know here is that once the new
        // socket is established, message handling is delegated to the TransportChannelHandler.
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
        clientRef.set(clientHandler.getClient());
        channelRef.set(ch);
      }
    });
    ChannelFuture cf = bootstrap.connect(address);
    ... // wait for the connection to be established
    return client;
  }
}

As the code above shows, the Message here is an RpcOutboxMessage, a class defined in the same file as Outbox.

// The message is both the payload carrier and a callback: it is invoked when the response comes back.
case class RpcOutboxMessage(
    content: ByteBuffer,
    _onFailure: (Throwable) => Unit,
    _onSuccess: (TransportClient, ByteBuffer) => Unit)
  extends OutboxMessage with RpcResponseCallback with Logging {

  private var client: TransportClient = _
  private var requestId: Long = _

  // Send the message through the TransportClient
  override def sendWith(client: TransportClient): Unit = {
    this.client = client
    this.requestId = client.sendRpc(content, this)
  }
}

Now we move on to TransportClient.

[common/network-common] org.apache.spark.network.client.TransportClient

public class TransportClient implements Closeable {
  private final Channel channel;
  private final TransportResponseHandler handler;
  @Nullable private String clientId;
  ...
  // The only constructor; Channel is the Netty channel.
  public TransportClient(Channel channel, TransportResponseHandler handler) {
    this.channel = Preconditions.checkNotNull(channel);
    this.handler = Preconditions.checkNotNull(handler);
    this.timedOut = false;
  }
  ...
  public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
    if (logger.isTraceEnabled()) {
      logger.trace("Sending RPC to {}", getRemoteAddress(channel));
    }

    long requestId = requestId();
    // Key step: the callback handles the response, and the handler
    // was created together with this client.
    handler.addRpcRequest(requestId, callback);

    // The listener only uses the callback for the failure path (if the write fails).
    RpcChannelListener listener = new RpcChannelListener(requestId, callback);
    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
      .addListener(listener);

    return requestId;
  }
}

At this point the request message has been sent out through Netty. Next, let's see how the response is handled.

Receiving the Response

On the Netty client, responses are handled by the handler added to the Bootstrap, and that wiring happens while the TransportClient is being created:

TransportClientFactory#createClient

bootstrap.handler(new ChannelInitializer() {
  @Override
  public void initChannel(SocketChannel ch) {
    TransportChannelHandler clientHandler = context.initializePipeline(ch);
    clientRef.set(clientHandler.getClient());
    channelRef.set(ch);
  }
});

The context here is the transportContext member of NettyRpcEnv:

val transportContext: TransportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

This is where TransportContext deserves a closer look.

TransportContext is used to create the TransportClientFactory and the TransportServer, and to configure Netty's ChannelHandlers.

[common/network-common] org.apache.spark.network.TransportContext

public class TransportContext implements Closeable {
  ...
  private final TransportConf conf;
  // The RpcHandler that processes incoming request messages
  private final RpcHandler rpcHandler;
  // Creates the client side
  public TransportClientFactory createClientFactory(...);
  // Creates the server side
  public TransportServer createServer(...)

  // Configures the channel pipeline for a new channel
  public TransportChannelHandler initializePipeline(SocketChannel channel) {
    return initializePipeline(channel, rpcHandler);
  }
  
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      // This handler wraps the channel handlers for both the server and the client side
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      ChannelPipeline pipeline = channel.pipeline()
        .addLast("encoder", ENCODER)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", DECODER)
        .addLast("idleStateHandler",
          new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      // Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs.
      if (chunkFetchWorkers != null) {
        ChunkFetchRequestHandler chunkFetchHandler = new ChunkFetchRequestHandler(
          channelHandler.getClient(), rpcHandler.getStreamManager(),
          conf.maxChunksBeingTransferred(), true /* syncModeEnabled */);
        pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler);
      }
      return channelHandler;
    } catch (RuntimeException e) {...}
  }
   /**
   * Creates the server- and client-side handler which is used to handle both RequestMessages and
   * ResponseMessages. The channel is expected to have been successfully created, though certain
   * properties (such as the remoteAddress()) may not be available yet.
   */
  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    boolean separateChunkFetchRequest = conf.separateChunkFetchRequest();
    ChunkFetchRequestHandler chunkFetchRequestHandler = null;
    if (!separateChunkFetchRequest) {
      chunkFetchRequestHandler = new ChunkFetchRequestHandler(
        client, rpcHandler.getStreamManager(),
        conf.maxChunksBeingTransferred(), false /* syncModeEnabled */);
    }
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler, conf.maxChunksBeingTransferred(), chunkFetchRequestHandler);
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), separateChunkFetchRequest, closeIdleConnections, this);
  }
}

With the channel handler in place, we can see how data read from the Channel is processed.

[common/network-common] org.apache.spark.network.server.TransportChannelHandler

public class TransportChannelHandler extends SimpleChannelInboundHandler {
  private final TransportClient client;
  private final TransportResponseHandler responseHandler;
  private final TransportRequestHandler requestHandler; 
  private final TransportContext transportContext;
  ...
  // Handles messages for both the server and the client side
  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
    if (request instanceof RequestMessage) {
      requestHandler.handle((RequestMessage) request);
    } else if (request instanceof ResponseMessage) {
      responseHandler.handle((ResponseMessage) request);
    } else {
      ctx.fireChannelRead(request);
    }
  }
}

Since we are following the response path, the message goes into TransportResponseHandler.

[common/network-common] org.apache.spark.network.server.TransportResponseHandler#handle

public void handle(ResponseMessage message) throws Exception {
  ... // here we only care about the executor's request for SparkAppConfig
  else if (message instanceof RpcResponse) {
    RpcResponse resp = (RpcResponse) message;
    // TransportClient#sendRpc stored the requestId -> callback mapping;
    // this is where that callback is finally used.
    RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
    if (listener == null) {
      logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
        resp.requestId, getRemoteAddress(channel), resp.body().size());
    } else {
      outstandingRpcs.remove(resp.requestId);
      try {
        // Notify that the response has arrived. We are on a Netty worker thread here,
        // and the listener is the RpcOutboxMessage itself.
        listener.onSuccess(resp.body().nioByteBuffer());
      } finally {
        resp.body().release();
      }
    }
  }
}
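The requestId bookkeeping is worth calling out: it is just a concurrent map from id to callback, written on send and consumed on response. A stripped-down, self-contained sketch of that pattern follows; ToyRpcRegistry and its members are made-up names, not Spark code:

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

// Toy version of the outstandingRpcs idea in TransportClient / TransportResponseHandler:
// register the callback on send, pop it when the matching response arrives.
class ToyRpcRegistry {
  private val nextId = new AtomicLong(0)
  private val outstanding = new ConcurrentHashMap[Long, String => Unit]()

  def send(payload: String)(onSuccess: String => Unit): Long = {
    val requestId = nextId.incrementAndGet()
    outstanding.put(requestId, onSuccess) // remember who is waiting for this id
    // ... write (requestId, payload) to the channel here ...
    requestId
  }

  def onResponse(requestId: Long, body: String): Unit = {
    val callback = outstanding.remove(requestId)
    if (callback == null) println(s"Ignoring response for unknown request $requestId")
    else callback(body) // runs on the I/O thread, like a Netty worker
  }
}

// val reg = new ToyRpcRegistry
// val id = reg.send("RetrieveSparkAppConfig")(reply => println(s"got: $reply"))
// reg.onResponse(id, "SparkAppConfig(...)")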

With that, the response lands back where the RpcOutboxMessage was created, i.e. in NettyRpcEnv#askAbortable, so let's see how the callback is handled there.

private[netty] def askAbortable[T: ClassTag](
    message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
  ...
  def onSuccess(reply: Any): Unit = reply match {
    case RpcFailure(e) => onFailure(e)
    case rpcReply =>
      // This flips the Future into the success state.
      if (!promise.trySuccess(rpcReply)) {
        logWarning(s"Ignored message: $reply")
      }
  }
  ...
  // The callbacks are created and wired up here
  val rpcMessage = RpcOutboxMessage(message.serialize(this),
    onFailure,
    (client, response) => onSuccess(deserialize[Any](client, response)))
}

Consequently, the awaitResult inside driverEndpointRef.askSync can return from its blocking wait:

def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
    val future = ask[T](message, timeout)
    timeout.awaitResult(future)
  }

On the executor side we now have the SparkAppConfig, and a full client-side request/response round trip is complete. In the next part we will look at the server side: how a received message is routed to the correct RpcEndpoint, and how the response is sent back after the request is processed.

Communication is the engine that drives the whole application: information exchange, data transfer, and signal propagation all depend on it, which is why Spark's RPC layer is the opening topic of this source-code series.

As a newcomer to the big data field, I am writing this to learn from more experienced readers; if anything here is inaccurate, corrections are very welcome!

Note: the source code is based on Apache Spark 3.0.

Author: pokerwu
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
