Netty is an NIO-based framework for building client- and server-side network applications.
Spark's Netty-based RPC can be divided into two parts:
- First, Spark's wrapper around Netty's low-level communication APIs, implementing the server and the client. This code is mostly Java and lives in the network-common subproject, which provides TransportServer, TransportClientFactory (the connection pool), TransportClient (the client), and related components.
- Second, in the next article, we will walk through the full life cycle of an RPC message at the application layer: creation, sending, dispatching (dispatcher), handling (Endpoint), and callbacks.
In this article we focus on Spark's wrapper around the Netty communication layer, and on how we applied the same design in our own project.
In this communication foundation, NettyRpcEnv initializes a TransportContext, whose core methods createServer() and createClientFactory() are the starting points for creating the server and the client connection pool.
The TransportContext overview diagram below is annotated with submodules ① through ④; read each of them against the corresponding analysis.
TransportContext is created inside NettyRpcEnv:
private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    ...) extends RpcEnv(conf) with Logging {
  ...
  val transportConf = SparkTransportConf.fromSparkConf(
    conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
    "rpc",
    conf.getInt("spark.rpc.io.threads", numUsableCores))

  val streamManager = new NettyStreamManager(this)

  // Create the TransportContext
  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))
  ...
}
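To make the relationship concrete, here is a minimal, hedged sketch of driving these pieces directly, outside of NettyRpcEnv: an echo server and a client created from one TransportContext. It assumes Spark's network-common is on the classpath; MapConfigProvider, sendRpcSync() and JavaUtils are network-common utilities, but verify the exact signatures against your Spark version.

import java.nio.ByteBuffer;
import java.util.ArrayList;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class EchoRpcDemo {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY);

    // Server-side RpcHandler: echo every RPC message back to the caller.
    RpcHandler rpcHandler = new RpcHandler() {
      @Override
      public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
        callback.onSuccess(message);
      }
      @Override
      public StreamManager getStreamManager() {
        return new OneForOneStreamManager();
      }
    };

    // One TransportContext is the origin of both the server and the client factory.
    TransportContext context = new TransportContext(conf, rpcHandler);
    TransportServer server = context.createServer("localhost", 0, new ArrayList<>());
    TransportClientFactory factory = context.createClientFactory();

    TransportClient client = factory.createClient("localhost", server.getPort());
    ByteBuffer reply = client.sendRpcSync(JavaUtils.stringToBytes("ping"), 5000);
    System.out.println(JavaUtils.bytesToString(reply)); // prints "ping"

    client.close();
    factory.close();
    server.close();
  }
}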
As the creation of transportConf in NettyRpcEnv shows, it is backed by a clone of the global SparkConf, and exposes getters for the Netty-related settings, including IO mode, timeouts, buffer sizes, and thread counts:
public class TransportConf {
  private final ConfigProvider conf;
  ...
  public int numConnectionsPerPeer() {
    return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
  }
  public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); }
  public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); }
  public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); }
  public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); }
}
From the construction of TransportContext we can see that the RpcHandler implementation is new NettyRpcHandler(dispatcher, this, streamManager).
NettyRpcHandler has two important members, dispatcher and streamManager, introduced in turn below.
The dispatcher's job on the server side is to route each RPC message to an inbox and on to the application-level Endpoint; the full dispatch flow is covered in the next article, so only a brief overview is given here. Within NettyRpcHandler, the dispatcher is used in two ways:
A. Receiving data:
  - receive(client, message, callback): invoked by the server's TransportRequestHandler#processRpcRequest() for RPC messages;
  - receive(client, message): invoked by the server's TransportRequestHandler#processOneWayMessage() for one-way messages;
  - internalReceive(): wraps the raw bytes into a RequestMessage.
B. Proxying TransportRequestHandler's channelActive(), channelInactive(), and exceptionCaught() methods: when the channel connects, disconnects, or throws, NettyRpcHandler uses postToAll() to send the corresponding remote-client state message through the dispatcher.
- NettyStreamManager handles file transfer. When we call addJar() on SparkContext to add a dependency jar, env.rpcEnv.fileServer.addJar(file) is invoked, which registers the file with NettyStreamManager.
- OneForOneStreamManager handles point-to-point chunk transfer; its core method is getChunk(). It is created along the chain NettyBlockTransferService -> NettyBlockRpcServer -> OneForOneStreamManager, and the NettyBlockTransferService is held by the BlockManager in SparkEnv to provide block transfer services (for space reasons, interested readers can explore this branch on their own).
Taking NettyStreamManager as an example, the overall file download flow is as follows.
NettyStreamManager's core method is openStream(streamId), which opens the requested file as a stream.
The Executor initiates a StreamRequest message for the file, and eventually receives the file stream wrapped in a StreamResponse:
Points to note in the figure above:
- TransportRequestHandler runs on the server (i.e. the Spark driver); it opens the file, wraps the response message, and writes it out through the channel.
- All other modules run inside the Executor, on the client side.
Starting from handle(responseMessage) in TransportResponseHandler, the Executor processes a StreamResponse message as shown in the figure below:
it looks up the registered callback, installs a StreamInterceptor, and TransportFrameDecoder's channelRead() then invokes the interceptor's handle() method on the incoming bytes. A simplified sketch of this interceptor mechanism follows.
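The sketch below is not Spark's actual TransportFrameDecoder; InterceptingFrameDecoder and its Interceptor interface are illustrative stand-ins for the mechanism: while an interceptor is installed, raw bytes bypass frame decoding and are fed straight to the stream callback until the stream is fully consumed.

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

// Simplified stand-in for TransportFrameDecoder (illustration only).
public class InterceptingFrameDecoder extends ChannelInboundHandlerAdapter {

  /** Mirrors TransportFrameDecoder.Interceptor: return false once the stream is done. */
  public interface Interceptor {
    boolean handle(ByteBuf data) throws Exception;
  }

  private Interceptor interceptor;

  public void setInterceptor(Interceptor interceptor) {
    this.interceptor = interceptor;
  }

  @Override
  public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
    ByteBuf buf = (ByteBuf) data;
    try {
      if (interceptor != null) {
        // Stream mode: hand raw bytes to the callback chain, skipping frame decoding.
        if (!interceptor.handle(buf)) {
          interceptor = null; // stream fully consumed; resume normal frame decoding
        }
      } else {
        // Normal mode: frame decoding would happen here; pass the message along.
        ctx.fireChannelRead(buf.retain());
      }
    } finally {
      buf.release();
    }
  }
}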
TransportServer is the logical server-side endpoint. Its init() method calls context.initializePipeline(), which creates a TransportChannelHandler and a TransportRequestHandler and binds them to the given channel. On channelRead() the handler reads incoming messages and passes them to the inner NettyRpcHandler's dispatcher for distribution; finally, the application-level Endpoints receive and respond to the messages.
public class TransportContext {
  public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, null, port, rpcHandler, bootstraps);
  }

  public TransportServer createServer(
      String host, int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, host, port, rpcHandler, bootstraps);
  }
}
public class TransportServer implements Closeable {
  private void init(String hostToBind, int portToBind) {
    IOMode ioMode = IOMode.valueOf(conf.ioMode());
    // A Netty server needs both a bossGroup and a workerGroup
    EventLoopGroup bossGroup =
      NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server");
    EventLoopGroup workerGroup = bossGroup;
    PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
      conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
    // Create and configure the server bootstrap
    bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)
      .channel(NettyUtils.getServerChannelClass(ioMode))
      .option(ChannelOption.ALLOCATOR, allocator)
      .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
      .childOption(ChannelOption.ALLOCATOR, allocator);
    this.metrics = new NettyMemoryMetrics(
      allocator, conf.getModuleName() + "-server", conf);
    ...
    // Register the child-channel initializer callback on the bootstrap
    bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(SocketChannel ch) {
        RpcHandler rpcHandler = appRpcHandler;
        for (TransportServerBootstrap bootstrap : bootstraps) {
          rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
        }
        // Note: this calls TransportContext#initializePipeline, which creates the
        // TransportChannelHandler and installs the encoder/decoder handlers for this channel
        context.initializePipeline(ch, rpcHandler);
      }
    });
    // Bind the bootstrap to a port
    InetSocketAddress address = hostToBind == null ?
      new InetSocketAddress(portToBind) : new InetSocketAddress(hostToBind, portToBind);
    channelFuture = bootstrap.bind(address);
    channelFuture.syncUninterruptibly();
    port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
    logger.debug("Shuffle server started on port: {}", port);
  }
}
For the server side, refer back to the first overview diagram in this article.
The dispatcher's message distribution mechanism is covered in detail in the next article; this article focuses on the wrapping of Netty's communication interfaces.
public class TransportContext {
  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
    return new TransportClientFactory(this, bootstraps);
  }

  public TransportClientFactory createClientFactory() {
    return createClientFactory(new ArrayList<>());
  }
}
TransportClientFactory maintains a pool of connection pools: each ClientPool corresponds to one unresolved remote address, stored in private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool.
TransportClientFactory defines two createClient() methods with different purposes:
- createClient(String remoteHost, int remotePort): builds an unresolved InetSocketAddress from the remote host and port, looks up the ClientPool for that address in connectionPool, and picks a TransportClient from the pool at a random index.
- createClient(InetSocketAddress address): if the previous step does not yield a usable cached TransportClient, the unresolved InetSocketAddress is resolved via DNS, and this method is called to create a new TransportClient, which is then placed into the ClientPool.
public class TransportClientFactory implements Closeable {
  private static class ClientPool {
    TransportClient[] clients;
    Object[] locks;

    ClientPool(int size) {
      clients = new TransportClient[size];
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
        locks[i] = new Object();
      }
    }
  }

  private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
}
To summarize first:
- As the definition of connectionPool shows, each remote (server) SocketAddress maps to one ClientPool.
- A ClientPool holds an array of size TransportClients, plus a parallel array of size lock objects, so each slot can be created and replaced under its own lock without serializing the whole pool.
The steps are annotated in the code comments below; the method has two main parts:
- Try to fetch the ClientPool and a cached client at a random index; if the cached client is in the Active state, return it.
- If no cached client is found or it is not Active, resolve remoteHost/remotePort via DNS and finally call createClient(InetSocketAddress address) to create a new TransportClient.
public class TransportClientFactory implements Closeable {
  public TransportClient createClient(String remoteHost, int remotePort)
      throws IOException, InterruptedException {
    // Build an unresolved SocketAddress
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);

    // Look up the connection pool for this server
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    // If no ClientPool exists yet, create one sized by numConnectionsPerPeer
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }

    // Pick a random slot from the ClientPool (i.e. try a cached TransportClient)
    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];

    // If the cached client exists and is Active, refresh its handler's last-request time
    if (cachedClient != null && cachedClient.isActive()) {
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }
      // Still Active? Then return the cached client
      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }

    // Resolve DNS and build the resolved address
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
    if (hostResolveTimeMs > 2000) {
      logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    } else {
      logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    }

    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];
      // Re-check: another thread may have created the client while we were resolving
      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      // Create a new TransportClient via createClient(resolvedAddress) and return it
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }
}
TransportClient holds a Channel and a TransportResponseHandler:
- data is sent via channel.writeAndFlush();
- RPC responses from the server are handled by the TransportResponseHandler, whose private final Map<Long, RpcResponseCallback> outstandingRpcs stores the pending callbacks, keyed by each message's unique, UUID-based request id.
The definition of TransportClient:
public class TransportClient implements Closeable {
  private final Channel channel;
  private final TransportResponseHandler handler;
  ...
  // Send an RPC request; note that requestId() generates the message's unique id
  public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
    ...
    long requestId = requestId();
    handler.addRpcRequest(requestId, callback);
    RpcChannelListener listener = new RpcChannelListener(requestId, callback);
    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
      .addListener(listener);
    return requestId;
  }

  // Send a one-way message; no response is expected, so this is simpler
  public void send(ByteBuffer message) {
    channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
  }

  // Send the stream request used for file download, as described above
  public void stream(String streamId, StreamCallback callback) {
    StdChannelListener listener = new StdChannelListener(streamId) {
      @Override
      void handleFailure(String errorMsg, Throwable cause) throws Exception {
        callback.onFailure(streamId, new IOException(errorMsg, cause));
      }
    };
    ...
    synchronized (this) {
      handler.addStreamCallback(streamId, callback);
      channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
    }
  }
}
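A hedged usage sketch of sendRpc(): the callback is registered under the requestId before the write, and later fires on a Netty IO thread when TransportResponseHandler matches the RpcResponse (or RpcFailure) by that id. JavaUtils is a network-common helper; the rest is the API shown above.

import java.nio.ByteBuffer;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.util.JavaUtils;

public class SendRpcExample {
  /** Fire an RPC and log the asynchronous reply; returns the requestId. */
  static long askAsync(TransportClient client, String payload) {
    return client.sendRpc(JavaUtils.stringToBytes(payload), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        // Invoked on a Netty IO thread; keep this short or hand off to an executor.
        System.out.println("reply: " + JavaUtils.bytesToString(response));
      }

      @Override
      public void onFailure(Throwable e) {
        System.err.println("rpc failed: " + e);
      }
    });
  }
}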
Creating a TransportClient via createClient(InetSocketAddress):
public class TransportClientFactory implements Closeable {
  private TransportClient createClient(InetSocketAddress address)
      throws IOException, InterruptedException {
    // Configure a client Bootstrap on the shared workerGroup
    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);
    ...
    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
    final AtomicReference<Channel> channelRef = new AtomicReference<>();

    // Register the channel initializer callback on the bootstrap
    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      public void initChannel(SocketChannel ch) {
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
        clientRef.set(clientHandler.getClient());
        channelRef.set(ch);
      }
    });

    // Connect to the remote server
    long preConnect = System.nanoTime();
    ChannelFuture cf = bootstrap.connect(address);
    if (!cf.await(conf.connectionTimeoutMs())) {
      throw new IOException(
        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
    } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
    }

    // Grab the TransportClient to be returned
    TransportClient client = clientRef.get();
    Channel channel = channelRef.get();
    assert client != null : "Channel future completed successfully with null client";

    long preBootstrap = System.nanoTime();
    try {
      // Run the client bootstraps (e.g. authentication)
      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
        clientBootstrap.doBootstrap(client, channel);
      }
    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
      ...
    }
    long postBootstrap = System.nanoTime();

    return client;
  }
}
In the previous two sections, both TransportServer's init() and TransportClientFactory's createClient(InetSocketAddress address) call context.initializePipeline(ch) when setting up the channel's callbacks, which shows how central it is. Its logic splits into two steps:
A. createChannelHandler(channel, channelRpcHandler) creates the TransportChannelHandler; this class is crucial, since its channelRead(ChannelHandlerContext ctx, Object request) method is where incoming messages are actually read and processed;
B. the ENCODER, DECODER, and the channel handler are installed into the pipeline.
public class TransportContext {
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", ENCODER)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", DECODER)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }
}
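Note the traversal order: Netty runs inbound handlers in the order they were added and outbound handlers in reverse, so an inbound buffer passes through TransportFrameDecoder (which reassembles the byte stream into complete frames), then the DECODER and the idleStateHandler, before reaching the TransportChannelHandler, while an outbound write only traverses the ENCODER. This is why a single addLast() sequence can serve both directions.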
Before the TransportChannelHandler is created, two application-layer handlers are constructed:
A. TransportRequestHandler: on the server side, handles the various messages sent by the client;
B. TransportResponseHandler: on the client side, handles the various responses from the server.
public class TransportContext {
  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler, conf.maxChunksBeingTransferred());
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }
}
TransportChannelHandler dispatches on the message type:
A. RequestMessages go to the requestHandler;
B. ResponseMessages go to the responseHandler.
public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
    if (request instanceof RequestMessage) {
      requestHandler.handle((RequestMessage) request);
    } else if (request instanceof ResponseMessage) {
      responseHandler.handle((ResponseMessage) request);
    } else {
      ctx.fireChannelRead(request);
    }
  }
}
TransportRequestHandler and TransportResponseHandler process messages as shown in the figure below, each starting from its handle(request) method.
The handling covers three message families (chunk, RPC, and stream), plus the callback responses.
public class TransportRequestHandler extends MessageHandler<RequestMessage> {
  @Override
  public void handle(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
      processRpcRequest((RpcRequest) request);
    } else if (request instanceof OneWayMessage) {
      processOneWayMessage((OneWayMessage) request);
    } else if (request instanceof StreamRequest) {
      processStreamRequest((StreamRequest) request);
    } else if (request instanceof UploadStream) {
      processStreamUpload((UploadStream) request);
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }
}
public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
  @Override
  public void handle(ResponseMessage message) throws Exception {
    if (message instanceof ChunkFetchSuccess) {
      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        resp.body().release();
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
        resp.body().release();
      }
    } else if (message instanceof ChunkFetchFailure) {
      ChunkFetchFailure resp = (ChunkFetchFailure) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        ...
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
          "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
      }
    } else if (message instanceof RpcResponse) {
      RpcResponse resp = (RpcResponse) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        ...
      } else {
        outstandingRpcs.remove(resp.requestId);
        try {
          listener.onSuccess(resp.body().nioByteBuffer());
        } finally {
          resp.body().release();
        }
      }
    } else if (message instanceof RpcFailure) {
      RpcFailure resp = (RpcFailure) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        ...
      } else {
        outstandingRpcs.remove(resp.requestId);
        listener.onFailure(new RuntimeException(resp.errorString));
      }
    } else if (message instanceof StreamResponse) {
      StreamResponse resp = (StreamResponse) message;
      Pair<String, StreamCallback> entry = streamCallbacks.poll();
      if (entry != null) {
        StreamCallback callback = entry.getValue();
        if (resp.byteCount > 0) {
          StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
            this, resp.streamId, resp.byteCount, callback);
          try {
            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
            frameDecoder.setInterceptor(interceptor);
            streamActive = true;
          } catch (Exception e) {
            deactivateStream();
          }
        } else {
          callback.onComplete(resp.streamId);
        }
      }
    } else if (message instanceof StreamFailure) {
      StreamFailure resp = (StreamFailure) message;
      Pair<String, StreamCallback> entry = streamCallbacks.poll();
      if (entry != null) {
        StreamCallback callback = entry.getValue();
        try {
          callback.onFailure(resp.streamId, new RuntimeException(resp.error));
        } catch (IOException ioe) {
          logger.warn("Error in stream failure handler.", ioe);
        }
      } else {
        logger.warn("Stream failure with unknown callback: {}", resp.error);
      }
    }
  }
}
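One design detail worth noting in the code above: chunk and RPC responses echo back an id (streamChunkId / requestId), so their callbacks live in maps, but a StreamResponse carries no request id; stream requests over one channel are answered in order, so a FIFO queue suffices. The illustration below reduces the types to a plain Consumer just to show the shape; the real fields in TransportResponseHandler are concurrent maps plus a ConcurrentLinkedQueue of Pair<String, StreamCallback>.

import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;

/** Simplified view of the client-side callback registries. */
class ResponseRegistries {
  // RPC responses echo back their requestId, so a map lookup works:
  final Map<Long, Consumer<byte[]>> outstandingRpcs = new ConcurrentHashMap<>();

  // Stream responses carry no id; they arrive in request order,
  // so poll() pairs each response with the oldest waiting callback:
  final Queue<Consumer<byte[]>> streamCallbacks = new ConcurrentLinkedQueue<>();
}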
In our CC-attack (HTTP flood) detection project, we built a watchdog subproject that uses remote communication to check whether each task's computation progress (a UTC timestamp) on the server side is being updated in real time; if the progress lags the current system time by more than 10 minutes, the application is restarted.
To implement the Netty RPC, we followed Spark's approach: a ServerBootstrap on the server side and a Bootstrap on the client side, each with its worker group configured, and in initChannel() we add a channelHandler for our protocol, which completes the server and client setup.
Sending and receiving ride on the handler context's writeAndFlush() and on channelRead(); here too we follow Spark's pattern:
- Sending: encode each message, tag it with a unique UUID, and register a callback for it.
- Receiving: decode the message, look up the callback by the returned UUID, and invoke it, completing the RPC. A minimal sketch of this pattern follows.
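The sketch below uses hypothetical names, not our production code or Spark's, and assumes an upstream frame decoder (e.g. Netty's LengthFieldBasedFrameDecoder) so that each ByteBuf is one complete message. The UUID travels inside the frame, and a concurrent map pairs it with the pending callback.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

public class RpcClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
  // UUID -> pending callback, analogous to Spark's outstandingRpcs map.
  private final Map<String, CompletableFuture<String>> pending = new ConcurrentHashMap<>();

  /** Encode "uuid|payload", register the callback, then writeAndFlush. */
  public CompletableFuture<String> ask(Channel channel, String payload) {
    String id = UUID.randomUUID().toString();
    CompletableFuture<String> future = new CompletableFuture<>();
    pending.put(id, future);
    byte[] bytes = (id + "|" + payload).getBytes(StandardCharsets.UTF_8);
    channel.writeAndFlush(Unpooled.wrappedBuffer(bytes));
    return future;
  }

  /** Decode "uuid|payload", look up the callback by UUID, and complete it. */
  @Override
  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
    String frame = msg.toString(StandardCharsets.UTF_8);
    int sep = frame.indexOf('|');
    String id = frame.substring(0, sep);
    CompletableFuture<String> future = pending.remove(id);
    if (future != null) {
      future.complete(frame.substring(sep + 1));
    }
  }

  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    // Fail all pending requests on a channel error, then close the connection.
    pending.values().forEach(f -> f.completeExceptionally(cause));
    pending.clear();
    ctx.close();
  }
}

On top of this, the watchdog simply asks each server-side task for its progress timestamp, compares it with the current UTC time, and triggers a restart when the gap exceeds 10 minutes.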