Spark 2.1 源码分析（3）：spark-rpc 如何将 Netty 的 Channel 隐藏在 Inbox 中

class  TransportServer

 // Register a ChannelInitializer for accepted child channels; initChannel runs once for each
 // newly connected client and builds that connection's pipeline.
 bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(SocketChannel ch) throws Exception {
        // Start from the application's RpcHandler and pass it through every server bootstrap in
        // order; each doBootstrap call returns the handler to use from then on.
        RpcHandler rpcHandler = appRpcHandler;
        for (TransportServerBootstrap bootstrap : bootstraps) {
          rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
        }
        // Install encoder/decoders and the per-channel TransportChannelHandler for this connection.
        context.initializePipeline(ch, rpcHandler);
      }
    });

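每当有新的 client 连接建立，server 都会调用上面匿名 ChannelInitializer 的 initChannel 方法：先让每个 TransportServerBootstrap 依次处理 rpcHandler，再调用 context.initializePipeline(ch, rpcHandler)。initializePipeline 通过 createChannelHandler 方法创建三个对象：TransportResponseHandler，用于处理响应消息（Channel 直接保存在 TransportResponseHandler 中）；TransportClient，封装了 Channel 和 responseHandler；TransportRequestHandler，用于处理请求消息（Channel 被封装在 TransportClient 中，TransportClient 又被隐藏在 TransportRequestHandler 里，在其中以 reverseClient 的名字出现；此外 Channel 也被直接传给了 TransportRequestHandler）。最后返回的 TransportChannelHandler 包含 client、responseHandler、requestHandler 三者，并作为名为 "handler" 的最后一个节点加入该 Channel 的 pipeline。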

/**
 * Builds the Netty pipeline for one connection and returns the TransportChannelHandler that is
 * installed as the last ("handler") stage.
 *
 * Stage order: "encoder", the frame decoder, "decoder", "idleStateHandler", then "handler".
 *
 * @param channel the newly connected socket channel
 * @param channelRpcHandler the RpcHandler that will serve requests arriving on this channel
 * @return the TransportChannelHandler created for this channel by createChannelHandler
 * @throws RuntimeException any RuntimeException raised during setup, logged and then rethrown
 */
public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", encoder)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", decoder)
        // Only the all-idle timer is enabled (reader/writer idle = 0 means disabled), in seconds.
        // NOTE(review): if connectionTimeoutMs() is an integral type, a timeout below 1000 ms
        // truncates to 0 here, which disables idle detection entirely — confirm intended.
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }
/**
 * Creates the per-channel handlers and bundles them into one TransportChannelHandler.
 *
 * The Channel is captured in three places: directly by TransportResponseHandler, by the
 * TransportClient (which also references the response handler), and directly by
 * TransportRequestHandler, which additionally receives that same client.
 *
 * @param channel the connection's Netty channel
 * @param rpcHandler the handler that will process incoming RPC requests
 * @return a TransportChannelHandler holding the client, response handler and request handler
 */
private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler);
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }
class TransportChannelHandler
// An incoming Message arrives from the earlier pipeline stages; dispatch it by message type.
// This walkthrough follows the RequestMessage path.
  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
    if (request instanceof RequestMessage) {
      requestHandler.handle((RequestMessage) request);  // the TransportRequestHandler created in createChannelHandler
    } else {
      // Anything that is not a RequestMessage is assumed to be a ResponseMessage; the cast
      // would throw ClassCastException otherwise.
      responseHandler.handle((ResponseMessage) request);
    }
  }
class  TransportRequestHandler 

  /**
   * Routes a request to the processor for its concrete type.
   *
   * @param request the decoded request message
   * @throws IllegalArgumentException if the request type is not one of the four handled below
   */
  @Override
  public void handle(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
      processRpcRequest((RpcRequest) request);   // the RPC request path is the one followed in this article
    } else if (request instanceof OneWayMessage) {
      processOneWayMessage((OneWayMessage) request);
    } else if (request instanceof StreamRequest) {
      processStreamRequest((StreamRequest) request);
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

  /**
   * Hands an RPC request body to the RpcHandler, together with the connection's TransportClient.
   *
   * The callback answers with an RpcResponse on success or an RpcFailure (the stack trace as a
   * string) on failure. If receive() itself throws, the exception is logged and an RpcFailure is
   * sent as well.
   *
   * @param req the RPC request; its body is always released in the finally block
   */
  private void processRpcRequest(final RpcRequest req) {
    try {
      // reverseClient is the TransportClient passed in when this handler was built in createChannelHandler
      rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
        @Override
        public void onSuccess(ByteBuffer response) {
          respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
        }

        @Override
        public void onFailure(Throwable e) {
          respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
        }
      });
    } catch (Exception e) {
      logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
    } finally {
      // Released as soon as receive() returns, so the handler presumably must have consumed or
      // copied the ByteBuffer synchronously by then — TODO confirm for custom RpcHandlers.
      req.body().release();
    }
  }

class NettyRpcHandler

  /**
   * Entry point for a remote RPC: turns the raw bytes into a message for the dispatcher and
   * posts it together with the response callback.
   *
   * @param client   the TransportClient wrapping the connection's Channel
   * @param message  the serialized request body
   * @param callback used to send the reply back over the same connection
   */
  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    // Deserializes `message` with `client` bound, so the client ends up inside messageToDispatch.
    val messageToDispatch = internalReceive(client, message)
    // Per this article, this wraps the message in the inbox's form and enqueues it on the
    // inbox message queue (dispatcher implementation not shown here).
    dispatcher.postRemoteMessage(messageToDispatch, callback)
  }

internalReceive 中会调用：
// Deserialize the RequestMessage while `client` is bound as the current TransportClient.
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)

class NettyRpcEnv
  /**
   * Deserializes `bytes` into a `T` while `NettyRpcEnv.currentClient` is bound to `client` for
   * the duration of the block.
   *
   * Objects restored during deserialization (e.g. NettyRpcEndpointRef.readObject) can therefore
   * read the connection's client.
   */
  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
    NettyRpcEnv.currentClient.withValue(client) {
      // The inner deserialize overload (not shown) presumably binds currentEnv as well — confirm.
      deserialize { () =>
        javaSerializerInstance.deserialize[T](bytes)  // Java deserialization runs NettyRpcEndpointRef.readObject, which reads the DynamicVariables bound here
      }
    }
  }

object NettyRpcEnv 
  // DynamicVariable lets a value be rebound for the extent of a `withValue` block.
  // Both default to null when read outside any such binding.
  private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
  private[netty] val currentClient = new DynamicVariable[TransportClient](null)

class NettyRpcEndpointRef

 // Java serialization hook for NettyRpcEndpointRef. After the default field restore, it fills
 // nettyEnv and client from the current DynamicVariable bindings instead of keeping the (empty)
 // values that came over the network.
 private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    // Outside a withValue scope these read the DynamicVariables' default value (null).
    nettyEnv = NettyRpcEnv.currentEnv.value
    client = NettyRpcEnv.currentClient.value
  }

你可能感兴趣的:(Spark,spark-rpc,源码)