Following 5 - RpcEnv (the RPC abstraction layer), let's now look at the implementation layer of the RPC framework.
In the previous section we saw that RpcEnv's create function delegates to NettyRpcEnvFactory's create function.
NettyRpcEnvFactory
The NettyRpcEnvFactory class lives in the file NettyRpcEnv.scala; its create function is implemented as follows:
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
After NettyRpcEnvFactory creates the NettyRpcEnv, if clientMode is false (server mode, i.e. Driver-side RPC), it defines a function value startNettyRpcEnv that wraps the new NettyRpcEnv's startServer method (its return value is the tuple (nettyEnv, nettyEnv.address.port)) and passes that function to Utils.startServiceOnPort, which actually starts the service on the Driver side.
It is worth stepping into Utils.startServiceOnPort to see why startServer is not called directly but is instead wrapped and handed to a utility function: binding a service to a port may fail, so the utility retries, up to a maximum number of attempts, until the service starts, and then returns the port it actually started on.
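As a rough illustration of that retry pattern, here is a minimal sketch, not the actual Utils.startServiceOnPort: the real utility also derives the retry count from the configuration, wraps non-fatal errors, and treats the ephemeral port 0 specially.

import java.net.BindException

// Sketch of the port-retry idea behind Utils.startServiceOnPort (assumed fixed maxRetries).
def startServiceOnPortSketch[T](
    startPort: Int,
    startService: Int => (T, Int),
    maxRetries: Int = 16): (T, Int) = {
  var lastError: Exception = null
  for (offset <- 0 to maxRetries) {
    val tryPort = startPort + offset
    try {
      return startService(tryPort) // (service, port actually bound)
    } catch {
      case e: BindException =>
        lastError = e // port in use; try the next one
    }
  }
  throw new BindException(
    s"Failed to bind starting at port $startPort after $maxRetries retries: $lastError")
}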
Let's now examine the NettyRpcEnv class itself. It extends RpcEnv and has a companion object. The companion object maintains only two values, currentEnv and currentClient, both used while deserializing a NettyRpcEndpointRef: a deserialized ref must be re-attached to the local NettyRpcEnv, and currentClient records which TransportClient the bytes arrived on so that the remote address can be recovered:
private[netty] object NettyRpcEnv extends Logging {
  /**
   * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
   * Use `currentEnv` to wrap the deserialization codes. E.g.,
   *
   * {{{
   *   NettyRpcEnv.currentEnv.withValue(this) {
   *     your deserialization codes
   *   }
   * }}}
   */
  private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)

  /**
   * Similar to `currentEnv`, this variable references the client instance associated with an
   * RPC, in case it's needed to find out the remote address during deserialization.
   */
  private[netty] val currentClient = new DynamicVariable[TransportClient](null)
}
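currentEnv is a scala.util.DynamicVariable: withValue binds a value for the dynamic extent of a block (per thread), so code deep inside the deserializer can read it without the value being threaded through every call. A minimal sketch of that mechanism, independent of Spark:

import scala.util.DynamicVariable

// Sketch only: shows how DynamicVariable scopes a value to a block.
object DynVarDemo {
  val current = new DynamicVariable[String](null)

  def deepInsideDeserialization(): Unit = {
    // Reads whatever the enclosing withValue bound on this thread.
    println(s"current = ${current.value}")
  }

  def main(args: Array[String]): Unit = {
    current.withValue("env-A") {
      deepInsideDeserialization() // prints "current = env-A"
    }
    println(current.value) // back to the default: null
  }
}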
Now for the companion class NettyRpcEnv itself. Its constructor initializes a number of private members:
package org.apache.spark.rpc.netty

......

private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {

  private[netty] val transportConf = SparkTransportConf.fromSparkConf(
    conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
    "rpc",
    conf.getInt("spark.rpc.io.threads", 0))

  private val dispatcher: Dispatcher = new Dispatcher(this)

  private val streamManager = new NettyStreamManager(this)

  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

  private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
    if (securityManager.isAuthenticationEnabled()) {
      java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
        securityManager.isSaslEncryptionEnabled()))
    } else {
      java.util.Collections.emptyList[TransportClientBootstrap]
    }
  }

  private val clientFactory = transportContext.createClientFactory(createClientBootstraps())

  /**
   * A separate client factory for file downloads. This avoids using the same RPC handler as
   * the main RPC context, so that events caused by these clients are kept isolated from the
   * main RPC traffic.
   *
   * It also allows for different configuration of certain properties, such as the number of
   * connections per peer.
   */
  @volatile private var fileDownloadFactory: TransportClientFactory = _

  val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")

  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
  // to implement non-blocking send/ask.
  // TODO: a non-blocking TransportClientFactory.createClient in future
  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
    "netty-rpc-connection",
    conf.getInt("spark.rpc.connect.threads", 64))

  @volatile private var server: TransportServer = _

  private val stopped = new AtomicBoolean(false)

  /**
   * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
   * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
   */
  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()

  ......
  ......
}
Let's go through these members one by one: dispatcher, streamManager, transportContext, clientFactory, fileDownloadFactory, clientConnectionExecutor, and server.
private val dispatcher: Dispatcher = new Dispatcher(this)
The Dispatcher class is a message dispatcher responsible for routing RPC messages to the appropriate endpoint. It has an inner class, EndpointData, which bundles an endpoint, its endpoint reference, and its inbox (Inbox). Dispatcher keeps three endpoint-related private fields: endpoints, endpointRefs, and receivers. Its member function registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef registers an endpoint with the Dispatcher (adding it to the three structures above) and returns the newly created endpoint reference. Note that the address inside every returned reference is RpcEndpointAddress(nettyEnv.address, name): all endpoints registered in the same NettyRpcEnv share the env's listen address, and their references differ only by endpoint name. This fits what NettyRpcEnv's send and ask show later: the dispatcher handles messages addressed to local endpoints, while messages bound for remote endpoints are placed into an Outbox instead.
Dispatcher also provides functions to get, remove, and unregister RpcEndpointRefs.
The private function postMessage delivers a message to one specific endpoint: it puts the InboxMessage instance into that endpoint's Inbox and offers the endpoint's EndpointData to receivers, so the message loops can track it.
postMessage is called by the public functions postToAll, postRemoteMessage, postLocalMessage, and postOneWayMessage. postToAll sends a message to every registered endpoint; postRemoteMessage wraps its RequestMessage and RpcResponseCallback parameters into an RpcMessage and puts it into the receiver's inbox; postLocalMessage does the same except that the flavor of RpcCallContext differs; postOneWayMessage wraps the RequestMessage into a OneWayMessage, which carries no RpcCallContext, and puts it into the receiver's inbox.
Dispatcher also contains an inner class MessageLoop that implements Runnable (the Java interface with the single abstract method run(), used for thread execution) and does the message processing: in a loop it takes an EndpointData with pending messages from receivers and has its inbox process them, until it takes PoisonPill, an EndpointData whose members are all null that serves as a sentinel telling the loop to exit (Dispatcher's stop function enqueues it). After taking PoisonPill, a loop puts it back so the remaining message loops can exit as well.
Dispatcher maintains a thread pool, threadpool: ThreadPoolExecutor, built through ThreadUtils.newDaemonFixedThreadPool with a configured number of threads, and each thread runs a fresh MessageLoop instance.
Finally, Dispatcher has a public stop function: it unregisters every endpoint, offers PoisonPill to the receivers queue so the MessageLoops exit, and then shuts down the thread pool.
package org.apache.spark.rpc.netty

......

/**
 * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
 */
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]
  private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
    new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

  // Track the receivers whose inboxes may contain messages.
  private val receivers = new LinkedBlockingQueue[EndpointData]

  /**
   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
   * immediately.
   */
  @GuardedBy("this")
  private var stopped = false

  def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
    val addr = RpcEndpointAddress(nettyEnv.address, name)
    val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
    synchronized {
      if (stopped) {
        throw new IllegalStateException("RpcEnv has been stopped")
      }
      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      endpointRefs.put(data.endpoint, data.ref)
      receivers.offer(data)  // for the OnStart message
    }
    endpointRef
  }

  def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)

  def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)

  // Should be idempotent
  private def unregisterRpcEndpoint(name: String): Unit = {
    val data = endpoints.remove(name)
    if (data != null) {
      data.inbox.stop()
      receivers.offer(data)  // for the OnStop message
    }
    // Don't clean `endpointRefs` here because it's possible that some messages are being processed
    // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
    // `removeRpcEndpointRef`.
  }

  def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
    synchronized {
      if (stopped) {
        // This endpoint will be stopped by Dispatcher.stop() method.
        return
      }
      unregisterRpcEndpoint(rpcEndpointRef.name)
    }
  }

  /**
   * Send a message to all registered [[RpcEndpoint]]s in this process.
   *
   * This can be used to make network events known to all end points (e.g. "a new node connected").
   */
  def postToAll(message: InboxMessage): Unit = {
    val iter = endpoints.keySet().iterator()
    while (iter.hasNext) {
      val name = iter.next
      postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}"))
    }
  }

  /** Posts a message sent by a remote endpoint. */
  def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
    val rpcCallContext =
      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
  }

  /** Posts a message sent by a local endpoint. */
  def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
    val rpcCallContext =
      new LocalNettyRpcCallContext(message.senderAddress, p)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
  }

  /** Posts a one-way message. */
  def postOneWayMessage(message: RequestMessage): Unit = {
    postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
      (e) => throw e)
  }

  /**
   * Posts a message to a specific endpoint.
   *
   * @param endpointName name of the endpoint.
   * @param message the message to post
   * @param callbackIfStopped callback function if the endpoint is stopped.
   */
  private def postMessage(
      endpointName: String,
      message: InboxMessage,
      callbackIfStopped: (Exception) => Unit): Unit = {
    val error = synchronized {
      val data = endpoints.get(endpointName)
      if (stopped) {
        Some(new RpcEnvStoppedException())
      } else if (data == null) {
        Some(new SparkException(s"Could not find $endpointName."))
      } else {
        data.inbox.post(message)
        receivers.offer(data)
        None
      }
    }
    // We don't need to call `onStop` in the `synchronized` block
    error.foreach(callbackIfStopped)
  }

  def stop(): Unit = {
    synchronized {
      if (stopped) {
        return
      }
      stopped = true
    }
    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
    // Enqueue a message that tells the message loops to stop.
    receivers.offer(PoisonPill)
    threadpool.shutdown()
  }

  def awaitTermination(): Unit = {
    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
  }

  /**
   * Return if the endpoint exists
   */
  def verify(name: String): Boolean = {
    endpoints.containsKey(name)
  }

  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, Runtime.getRuntime.availableProcessors()))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

  /** Message loop used for dispatching messages. */
  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
  private val PoisonPill = new EndpointData(null, null, null)
}
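To see the dispatch path end to end, here is a minimal sketch against the public RpcEnv / RpcEndpoint API. EchoEndpoint and all the names in it are made up for illustration, and the exact RpcEnv.create signature varies a little between Spark versions:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// A hypothetical endpoint, for illustration only.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // One-way messages (OneWayMessage in the Inbox) land here.
  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => println(s"one-way: $msg")
  }
  // ask() requests (RpcMessage in the Inbox) land here.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg.reverse)
  }
}

object DispatcherDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val env = RpcEnv.create("demo", "localhost", 52345, conf, new SecurityManager(conf))
    // setupEndpoint delegates to Dispatcher.registerRpcEndpoint: the endpoint is added
    // to endpoints/endpointRefs and its EndpointData is offered to receivers, so a
    // MessageLoop thread immediately processes its OnStart message.
    val ref = env.setupEndpoint("echo", new EchoEndpoint(env))
    ref.send("hello") // goes through Dispatcher.postOneWayMessage into the Inbox
    env.shutdown()
    env.awaitTermination()
  }
}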
Next let's look at the Inbox class used by Dispatcher's EndpointData. An Inbox stores an endpoint's messages and provides the function that processes them; its source file is Inbox.scala.
Inbox.scala first defines the message hierarchy under sealed trait InboxMessage: the case classes OneWayMessage and RpcMessage, plus several special InboxMessages: OnStart, OnStop, RemoteProcessConnected, RemoteProcessDisconnected, and RemoteProcessConnectionError.
Internally, Inbox stores messages in a LinkedList[InboxMessage] and maintains a few other fields: stopped indicates whether the Inbox has stopped; enableConcurrent indicates whether messages may be processed concurrently (the Dispatcher uses multiple threads, so the same EndpointData's Inbox can have process called from several threads), set to true when the Inbox starts and back to false when it stops; numActiveThreads counts the threads currently processing this Inbox's messages.
The main function, process(dispatcher: Dispatcher): Unit, handles the messages:
OnStart: placed into the message list by the Inbox's constructor. Handling it calls the RpcEndpoint's onStart(), and if the endpoint is not a ThreadSafeRpcEndpoint, sets the concurrency flag to true.
OnStop: placed into the list when the Inbox stops. When the Dispatcher stops, OnStop is the last message enqueued for each EndpointData; the Dispatcher's stopped flag is set to true beforehand, so postMessage can no longer add anything. OnStop should therefore be the Inbox's final message, and the code asserts assert(activeThreads == 1, ...), i.e. the thread handling it is the last one working on this Inbox. Handling it removes the RpcEndpointRef from the Dispatcher and calls the RpcEndpoint's onStop().
RemoteProcessConnected, RemoteProcessDisconnected, and RemoteProcessConnectionError: each simply invokes the corresponding RpcEndpoint callback.
RpcMessage: calls the RpcEndpoint's receiveAndReply function; what happens depends on the concrete RpcEndpoint implementation.
OneWayMessage: calls the RpcEndpoint's receive function; again, this depends on the concrete implementation.
package org.apache.spark.rpc.netty

......

private[netty] sealed trait InboxMessage

private[netty] case class OneWayMessage(
    senderAddress: RpcAddress,
    content: Any) extends InboxMessage

private[netty] case class RpcMessage(
    senderAddress: RpcAddress,
    content: Any,
    context: NettyRpcCallContext) extends InboxMessage

private[netty] case object OnStart extends InboxMessage

private[netty] case object OnStop extends InboxMessage

/** A message to tell all endpoints that a remote process has connected. */
private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage

/** A message to tell all endpoints that a remote process has disconnected. */
private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage

/** A message to tell all endpoints that a network error has happened. */
private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress)
  extends InboxMessage

/**
 * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
 */
private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  /** True if the inbox (and its associated endpoint) is stopped. */
  @GuardedBy("this")
  private var stopped = false

  /** Allow multiple threads to process messages at the same time. */
  @GuardedBy("this")
  private var enableConcurrent = false

  /** The number of threads processing messages for this inbox. */
  @GuardedBy("this")
  private var numActiveThreads = 0

  // OnStart should be the first message to process
  inbox.synchronized {
    messages.add(OnStart)
  }

  /**
   * Process stored messages.
   */
  def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        return
      }
      message = messages.poll()
      if (message != null) {
        numActiveThreads += 1
      } else {
        return
      }
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch {
              case NonFatal(e) =>
                context.sendFailure(e)
                // Throw the exception -- this exception will be caught by the safelyCall function.
                // The endpoint's onError function will be called.
                throw e
            }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,
              s"There should be only a single active thread but found $activeThreads threads.")
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            assert(isEmpty, "OnStop should be the last message")

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
        // every time.
        if (!enableConcurrent && numActiveThreads != 1) {
          // If we are not the only one worker, exit
          numActiveThreads -= 1
          return
        }
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

  def post(message: InboxMessage): Unit = inbox.synchronized {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      messages.add(message)
      false
    }
  }

  def stop(): Unit = inbox.synchronized {
    // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
    // message
    if (!stopped) {
      // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
      // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
      // safely.
      enableConcurrent = false
      stopped = true
      messages.add(OnStop)
      // Note: The concurrent events in messages will be processed one by one.
    }
  }

  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }

  /**
   * Called when we are dropping a message. Test cases override this to test message dropping.
   * Exposed for testing.
   */
  protected def onDrop(message: InboxMessage): Unit = {
    logWarning(s"Drop $message because $endpointRef is stopped")
  }

  /**
   * Calls action closure, and calls the endpoint's onError function in the case of exceptions.
   */
  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
    try action catch {
      case NonFatal(e) =>
        try endpoint.onError(e) catch {
          case NonFatal(ee) => logError(s"Ignoring error", ee)
        }
    }
  }
}
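The heart of process() is a small hand-rolled concurrency gate: under the lock, a thread either claims a message (bumping numActiveThreads) or leaves, and while enableConcurrent is false only a single thread may stay in the loop. A stripped-down model of just that gate, outside Spark (a sketch, not Spark code):

import java.util.{LinkedList => JLinkedList}

// Minimal model of Inbox's claim/release protocol.
class Gate[T] {
  private val messages = new JLinkedList[T]()
  private var enableConcurrent = false
  private var numActiveThreads = 0

  def post(m: T): Unit = synchronized { messages.add(m) }
  def setConcurrent(on: Boolean): Unit = synchronized { enableConcurrent = on }

  def process(handle: T => Unit): Unit = {
    var message: T = null.asInstanceOf[T]
    synchronized {
      // A second thread bails out while concurrency is off and someone is active.
      if (!enableConcurrent && numActiveThreads != 0) return
      message = messages.poll()
      if (message == null) return
      numActiveThreads += 1
    }
    while (true) {
      handle(message)
      synchronized {
        // Re-checked every round: enableConcurrent flips back to false on stop.
        if (!enableConcurrent && numActiveThreads != 1) { numActiveThreads -= 1; return }
        message = messages.poll()
        if (message == null) { numActiveThreads -= 1; return }
      }
    }
  }
}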
private val streamManager = new NettyStreamManager(this)
NettyStreamManager mixes in the trait RpcEnvFileServer (the functional side: file management and file serving for a NettyRpcEnv) and extends the abstract class StreamManager (the implementation side). Its source file is NettyStreamManager.scala:
package org.apache.spark.rpc.netty

......

/**
 * StreamManager implementation for serving files from a NettyRpcEnv.
 *
 * Three kinds of resources can be registered in this manager, all backed by actual files:
 *
 * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]].
 * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]].
 * - arbitrary directories; all files under the directory become available through the manager,
 *   respecting the directory's hierarchy.
 *
 * Only streaming (openStream) is supported.
 */
private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
  extends StreamManager with RpcEnvFileServer {

  private val files = new ConcurrentHashMap[String, File]()
  private val jars = new ConcurrentHashMap[String, File]()
  private val dirs = new ConcurrentHashMap[String, File]()

  override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
    throw new UnsupportedOperationException()
  }

  override def openStream(streamId: String): ManagedBuffer = {
    val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
    val file = ftype match {
      case "files" => files.get(fname)
      case "jars" => jars.get(fname)
      case other =>
        val dir = dirs.get(ftype)
        require(dir != null, s"Invalid stream URI: $ftype not found.")
        new File(dir, fname)
    }

    if (file != null && file.isFile()) {
      new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
    } else {
      null
    }
  }

  override def addFile(file: File): String = {
    val existingPath = files.putIfAbsent(file.getName, file)
    ......
  }

  override def addJar(file: File): String = {
    val existingPath = jars.putIfAbsent(file.getName, file)
    ......
  }

  override def addDirectory(baseUri: String, path: File): String = {
    val fixedBaseUri = validateDirectoryUri(baseUri)
    require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null,
      ......
  }
}
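A streamId is a path-like string; openStream splits it into a resource type and a file name with split("/", 2), so paths under a registered directory keep their hierarchy. A tiny standalone sketch of that parsing (the ids below are made up):

// Sketch: how openStream interprets a streamId.
def parseStreamId(streamId: String): (String, String) = {
  val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
  (ftype, fname)
}

// parseStreamId("/files/app.conf")     == ("files", "app.conf")      -> looked up in `files`
// parseStreamId("/jars/dep.jar")       == ("jars", "dep.jar")        -> looked up in `jars`
// parseStreamId("/myDir/sub/data.txt") == ("myDir", "sub/data.txt")  -> looked up in `dirs`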
The TransportContext class lives in TransportContext.java. It is responsible for transporting RPC messages and is where the Netty-specific wiring happens (essentially, it sets up the message transport channels for servers and clients). Its main functions are createServer, createClientFactory, and initializePipeline. It in turn involves TransportClientFactory, TransportClientBootstrap, TransportServer, TransportServerBootstrap, and TransportChannelHandler; these classes are the concrete Netty communication layer, implemented in Java, and deserve a study of their own later.
package org.apache.spark.network;

......

import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
......
import org.apache.spark.network.client.TransportResponseHandler;
......
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
......

/**
 * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
 * setup Netty Channel pipelines with a
 * {@link org.apache.spark.network.server.TransportChannelHandler}.
 *
 * There are two communication protocols that the TransportClient provides, control-plane RPCs and
 * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the
 * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams
 * which can be streamed through the data plane in chunks using zero-copy IO.
 *
 * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
 * channel. As each TransportChannelHandler contains a TransportClient, this enables server
 * processes to send messages back to the client on an existing channel.
 */
public class TransportContext {
  private static final Logger logger = LoggerFactory.getLogger(TransportContext.class);

  private final TransportConf conf;
  private final RpcHandler rpcHandler;
  private final boolean closeIdleConnections;

  private final MessageEncoder encoder;
  private final MessageDecoder decoder;

  public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
    this(conf, rpcHandler, false);
  }

  public TransportContext(
      TransportConf conf,
      RpcHandler rpcHandler,
      boolean closeIdleConnections) {
    this.conf = conf;
    this.rpcHandler = rpcHandler;
    this.encoder = new MessageEncoder();
    this.decoder = new MessageDecoder();
    this.closeIdleConnections = closeIdleConnections;
  }

  /**
   * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
   * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
   * to create a Client.
   */
  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
    return new TransportClientFactory(this, bootstraps);
  }

  public TransportClientFactory createClientFactory() {
    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
  }

  /** Create a server which will attempt to bind to a specific port. */
  public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, null, port, rpcHandler, bootstraps);
  }

  /** Create a server which will attempt to bind to a specific host and port. */
  public TransportServer createServer(
      String host, int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, host, port, rpcHandler, bootstraps);
  }

  /** Creates a new server, binding to any available ephemeral port. */
  public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
    return createServer(0, bootstraps);
  }

  public TransportServer createServer() {
    return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
  }

  public TransportChannelHandler initializePipeline(SocketChannel channel) {
    return initializePipeline(channel, rpcHandler);
  }

  /**
   * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
   * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
   * response messages.
   *
   * @param channel The channel to initialize.
   * @param channelRpcHandler The RPC handler to use for the channel.
   *
   * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
   * be used to communicate on this channel. The TransportClient is directly associated with a
   * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
   */
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      channel.pipeline()
        .addLast("encoder", encoder)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", decoder)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }

  /**
   * Creates the server- and client-side handler which is used to handle both RequestMessages and
   * ResponseMessages. The channel is expected to have been successfully created, though certain
   * properties (such as the remoteAddress()) may not be available yet.
   */
  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler);
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }

  public TransportConf getConf() { return conf; }
}
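As a hedged, standalone sketch of how these pieces fit together, here is an echo RpcHandler plugged into a TransportContext. MapConfigProvider and OneForOneStreamManager do exist in spark-network-common, but the exact constructors differ across Spark versions, so treat this as an assumption-laden sketch rather than version-accurate code:

import java.nio.ByteBuffer
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.util.{MapConfigProvider, TransportConf}

object TransportDemo {
  def main(args: Array[String]): Unit = {
    val conf = new TransportConf("demo", MapConfigProvider.EMPTY)
    val handler = new RpcHandler {
      // Echo every RPC straight back to the caller.
      override def receive(client: TransportClient, message: ByteBuffer,
          callback: RpcResponseCallback): Unit = callback.onSuccess(message)
      override def getStreamManager: StreamManager = new OneForOneStreamManager
    }
    val context = new TransportContext(conf, handler)
    val server = context.createServer() // binds an ephemeral port
    val client = context.createClientFactory().createClient("localhost", server.getPort)
    // client.sendRpc(...) would now round-trip through the echo handler;
    // both ends got their pipeline from TransportContext.initializePipeline.
    client.close()
    server.close()
  }
}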
NettyRpcHandler is instantiated inside NettyRpcEnv and passed into the TransportContext constructor:
private val transportContext = new TransportContext(transportConf,
  new NettyRpcHandler(dispatcher, this, streamManager))
NettyRpcHandler is defined in NettyRpcEnv.scala and extends the abstract class RpcHandler. RpcHandler itself lives in RpcHandler.java, under the server namespace of the network package, so it is the server-side class for handling RPC messages. An RpcHandler processes RPC messages sent by a TransportClient in its receive function, and also has channelActive and channelInactive functions that track the connection state of the channel to the client.
package org.apache.spark.network.server;

......

/**
 * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
 */
public abstract class RpcHandler {

  private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();

  /**
   * Receive a single RPC message. Any exception thrown while in this method will be sent back to
   * the client in string form as a standard RPC failure.
   *
   * This method will not be called in parallel for a single TransportClient (i.e., channel).
   *
   * @param client A channel client which enables the handler to make requests back to the sender
   *               of this RPC. This will always be the exact same object for a particular channel.
   * @param message The serialized bytes of the RPC.
   * @param callback Callback which should be invoked exactly once upon success or failure of the
   *                 RPC.
   */
  public abstract void receive(
      TransportClient client,
      ByteBuffer message,
      RpcResponseCallback callback);

  /**
   * Returns the StreamManager which contains the state about which streams are currently being
   * fetched by a TransportClient.
   */
  public abstract StreamManager getStreamManager();

  /**
   * Receives an RPC message that does not expect a reply.
   * ......
   */
  public void receive(TransportClient client, ByteBuffer message) {
    receive(client, message, ONE_WAY_CALLBACK);
  }

  /**
   * Invoked when the channel associated with the given client is active.
   */
  public void channelActive(TransportClient client) { }

  /**
   * Invoked when the channel associated with the given client is inactive.
   * No further requests will come from this client.
   */
  public void channelInactive(TransportClient client) { }

  public void exceptionCaught(Throwable cause, TransportClient client) { }

  private static class OneWayRpcCallback implements RpcResponseCallback {
    ......
  }
}
NettyRpcHandler implements the RpcHandler interface. Since more than one client may send RPC messages to a server, NettyRpcHandler maintains remoteAddresses: ConcurrentHashMap[RpcAddress, RpcAddress] to track the clients that have sent it messages.
In the receive functions, the bytes are deserialized into a RequestMessage and posted to the dispatcher (as a remote message, or as a one-way message for the reply-less overload); if this client is contacting the server for the first time, its socket address is added to remoteAddresses and a RemoteProcessConnected(remoteEnvAddress) message is posted to all endpoints in the dispatcher.
channelActive posts RemoteProcessConnected(clientAddr) to all endpoints in the dispatcher. channelInactive removes the client's outbox, drops the client from remoteAddresses, posts RemoteProcessDisconnected(clientAddr) to all endpoints, and, if remoteAddresses held an entry for this client, also posts RemoteProcessDisconnected(remoteEnvAddress). The difference between the two addresses: clientAddr is the socket address of this particular connection (typically an ephemeral port on the client side), while remoteEnvAddress is the address the remote RpcEnv listens on, taken from the message's senderAddress; events are fired for both so endpoints can react at either granularity.
/**
 * Dispatches incoming RPCs to registered endpoints.
 *
 * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
 * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
 * one that is not listening for incoming connections, but rather needs to be contacted via the
 * client socket).
 *
 * Events are sent on a per-connection basis, so if a client opens multiple connections to the
 * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
 * with different `RpcAddress` information).
 */
private[netty] class NettyRpcHandler(
    dispatcher: Dispatcher,
    nettyEnv: NettyRpcEnv,
    streamManager: StreamManager) extends RpcHandler with Logging {

  // A variable to track the remote RpcEnv addresses of all clients
  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postRemoteMessage(messageToDispatch, callback)
  }

  override def receive(
      client: TransportClient,
      message: ByteBuffer): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postOneWayMessage(messageToDispatch)
  }

  private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
    assert(addr != null)
    val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
    val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
    if (requestMessage.senderAddress == null) {
      // Create a new message with the socket address of the client as the sender.
      RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
    } else {
      // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
      // the listening address
      val remoteEnvAddress = requestMessage.senderAddress
      if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
        dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
      }
      requestMessage
    }
  }

  override def getStreamManager: StreamManager = streamManager

  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
    if (addr != null) {
      val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
      dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
      // If the remove RpcEnv listens to some address, we should also fire a
      // RemoteProcessConnectionError for the remote RpcEnv listening address
      val remoteEnvAddress = remoteAddresses.get(clientAddr)
      if (remoteEnvAddress != null) {
        dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress))
      }
    } else {
      // If the channel is closed before connecting, its remoteAddress will be null.
      // See java.net.Socket.getRemoteSocketAddress
      // Because we cannot get a RpcAddress, just log it
      logError("Exception before connecting to the client", cause)
    }
  }

  override def channelActive(client: TransportClient): Unit = {
    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
    assert(addr != null)
    val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
    dispatcher.postToAll(RemoteProcessConnected(clientAddr))
  }

  override def channelInactive(client: TransportClient): Unit = {
    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
    if (addr != null) {
      val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
      nettyEnv.removeOutbox(clientAddr)
      dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
      val remoteEnvAddress = remoteAddresses.remove(clientAddr)
      // If the remove RpcEnv listens to some address, we should also fire a
      // RemoteProcessDisconnected for the remote RpcEnv listening address
      if (remoteEnvAddress != null) {
        dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress))
      }
    } else {
      // If the channel is closed before connecting, its remoteAddress will be null. In this case,
      // we can ignore it since we don't fire "Associated".
      // See java.net.Socket.getRemoteSocketAddress
    }
  }
}
  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
  // to implement non-blocking send/ask.
  // TODO: a non-blocking TransportClientFactory.createClient in future
  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
    "netty-rpc-connection",
    conf.getInt("spark.rpc.connect.threads", 64))

  /**
   * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
   * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
   */
  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()

  private[netty] def send(message: RequestMessage): Unit = {
    val remoteAddr = message.receiver.address
    if (remoteAddr == address) {
      // Message to a local RPC endpoint.
      try {
        dispatcher.postOneWayMessage(message)
      } catch {
        case e: RpcEnvStoppedException => logWarning(e.getMessage)
      }
    } else {
      // Message to a remote RPC endpoint.
      postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
    }
  }

  private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
    val promise = Promise[Any]()
    val remoteAddr = message.receiver.address

    def onFailure(e: Throwable): Unit = {
      if (!promise.tryFailure(e)) {
        logWarning(s"Ignored failure: $e")
      }
    }

    def onSuccess(reply: Any): Unit = reply match {
      case RpcFailure(e) => onFailure(e)
      case rpcReply =>
        if (!promise.trySuccess(rpcReply)) {
          logWarning(s"Ignored message: $reply")
        }
    }

    try {
      if (remoteAddr == address) {
        val p = Promise[Any]()
        p.future.onComplete {
          case Success(response) => onSuccess(response)
          case Failure(e) => onFailure(e)
        }(ThreadUtils.sameThread)
        dispatcher.postLocalMessage(message, p)
      } else {
        val rpcMessage = RpcOutboxMessage(serialize(message),
          onFailure,
          (client, response) => onSuccess(deserialize[Any](client, response)))
        postToOutbox(message.receiver, rpcMessage)
        promise.future.onFailure {
          case _: TimeoutException => rpcMessage.onTimeout()
          case _ =>
        }(ThreadUtils.sameThread)
      }

      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
        override def run(): Unit = {
          onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}"))
        }
      }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
      promise.future.onComplete { v =>
        timeoutCancelable.cancel(true)
      }(ThreadUtils.sameThread)
    } catch {
      case NonFatal(e) =>
        onFailure(e)
    }
    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }
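Note the pattern inside ask: the reply future races a scheduled timeout task, whichever completes the promise first wins, and the loser becomes a no-op thanks to trySuccess/tryFailure. A self-contained sketch of that race with the same ingredients (not Spark code):

import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
import scala.concurrent.{ExecutionContext, Future, Promise}

object AskTimeoutDemo {
  implicit val ec: ExecutionContext = ExecutionContext.global
  private val scheduler = Executors.newSingleThreadScheduledExecutor()

  def withTimeout[T](reply: Future[T], timeoutMs: Long): Future[T] = {
    val promise = Promise[T]()
    // Arm the timer: the first completion wins, later ones are ignored.
    val timeoutTask = scheduler.schedule(new Runnable {
      override def run(): Unit =
        promise.tryFailure(new TimeoutException(s"no reply in ${timeoutMs}ms"))
    }, timeoutMs, TimeUnit.MILLISECONDS)
    reply.onComplete(promise.tryComplete) // forward the real reply
    promise.future.onComplete(_ => timeoutTask.cancel(true)) // disarm the timer
    promise.future
  }
}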
Outbox is structurally similar to Inbox, but message delivery works differently. Outbox likewise keeps a message list, a LinkedList[OutboxMessage]. Messages posted to an Inbox are not sent by the poster; the dispatcher's thread pool drains them. With an Outbox, after send adds a message to the list it calls drainOutbox() itself, looping until every pending message has been sent, so sending through an Outbox happens on the calling thread. Outbox.send is invoked from NettyRpcEnv's postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit, the function used to send messages to remote endpoints; a message goes through the Outbox only when the receiver reference does not already hold a TransportClient. If it does, the message is sent on that client directly, bypassing the Outbox.
In Outbox's drain function drainOutbox(), if this is the first message for that remote address, a connect task is submitted to NettyRpcEnv's clientConnectionExecutor thread pool to establish the connection.
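For reference, here is a sketch of postToOutbox in NettyRpcEnv, reconstructed from memory of the Spark source (shown in its class context, so it refers to the outboxes map and stopped flag quoted earlier; details may differ between versions):

// Sketch of NettyRpcEnv.postToOutbox (reconstruction, not verbatim Spark code).
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
  if (receiver.client != null) {
    // The ref was created from an existing channel: send on that client directly.
    message.sendWith(receiver.client)
  } else {
    require(receiver.address != null,
      "Cannot send message to client endpoint with no listen address.")
    // One Outbox per remote address, created lazily and race-safely.
    val targetOutbox = {
      val outbox = outboxes.get(receiver.address)
      if (outbox == null) {
        val newOutbox = new Outbox(this, receiver.address)
        val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
        if (oldOutbox == null) newOutbox else oldOutbox
      } else {
        outbox
      }
    }
    if (stopped.get) {
      // The env may have stopped after we inserted the Outbox, so clean it up.
      outboxes.remove(receiver.address)
      targetOutbox.stop()
    } else {
      targetOutbox.send(message)
    }
  }
}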
package org.apache.spark.rpc.netty

......

private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {

  outbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  private val messages = new java.util.LinkedList[OutboxMessage]

  @GuardedBy("this")
  private var client: TransportClient = null

  /**
   * connectFuture points to the connect task. If there is no connect task, connectFuture will be
   * null.
   */
  @GuardedBy("this")
  private var connectFuture: java.util.concurrent.Future[Unit] = null

  @GuardedBy("this")
  private var stopped = false

  /**
   * If there is any thread draining the message queue
   */
  @GuardedBy("this")
  private var draining = false

  /**
   * Send a message. If there is no active connection, cache it and launch a new connection. If
   * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
   */
  def send(message: OutboxMessage): Unit = {
    val dropped = synchronized {
      if (stopped) {
        true
      } else {
        messages.add(message)
        false
      }
    }
    if (dropped) {
      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
    } else {
      drainOutbox()
    }
  }

  /**
   * Drain the message queue. If there is other draining thread, just exit. If the connection has
   * not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to setup the
   * connection.
   */
  private def drainOutbox(): Unit = {
    var message: OutboxMessage = null
    synchronized {
      if (stopped) {
        return
      }
      if (connectFuture != null) {
        // We are connecting to the remote address, so just exit
        return
      }
      if (client == null) {
        // There is no connect task but client is null, so we need to launch the connect task.
        launchConnectTask()
        return
      }
      if (draining) {
        // There is some thread draining, so just exit
        return
      }
      message = messages.poll()
      if (message == null) {
        return
      }
      draining = true
    }
    while (true) {
      try {
        val _client = synchronized { client }
        if (_client != null) {
          message.sendWith(_client)
        } else {
          assert(stopped == true)
        }
      } catch {
        case NonFatal(e) =>
          handleNetworkFailure(e)
          return
      }
      synchronized {
        if (stopped) {
          return
        }
        message = messages.poll()
        if (message == null) {
          draining = false
          return
        }
      }
    }
  }

  private def launchConnectTask(): Unit = {
    connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {

      override def call(): Unit = {
        try {
          val _client = nettyEnv.createClient(address)
          outbox.synchronized {
            client = _client
            if (stopped) {
              closeClient()
            }
          }
        } catch {
          case ie: InterruptedException =>
            // exit
            return
          case NonFatal(e) =>
            outbox.synchronized { connectFuture = null }
            handleNetworkFailure(e)
            return
        }
        outbox.synchronized { connectFuture = null }
        // It's possible that no thread is draining now. If we don't drain here, we cannot send the
        // messages until the next message arrives.
        drainOutbox()
      }
    })
  }

  /**
   * Stop [[Inbox]] and notify the waiting messages with the cause.
   */
  private def handleNetworkFailure(e: Throwable): Unit = {
    synchronized {
      assert(connectFuture == null)
      if (stopped) {
        return
      }
      stopped = true
      closeClient()
    }
    // Remove this Outbox from nettyEnv so that the further messages will create a new Outbox along
    // with a new connection
    nettyEnv.removeOutbox(address)

    // Notify the connection failure for the remaining messages
    //
    // We always check `stopped` before updating messages, so here we can make sure no thread will
    // update messages and it's safe to just drain the queue.
    var message = messages.poll()
    while (message != null) {
      message.onFailure(e)
      message = messages.poll()
    }
    assert(messages.isEmpty)
  }

  private def closeClient(): Unit = synchronized {
    // Just set client to null. Don't close it in order to reuse the connection.
    client = null
  }

  /**
   * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
   * [[SparkException]].
   */
  def stop(): Unit = {
    synchronized {
      if (stopped) {
        return
      }
      stopped = true
      if (connectFuture != null) {
        connectFuture.cancel(true)
      }
      closeClient()
    }

    // We always check `stopped` before updating messages, so here we can make sure no thread will
    // update messages and it's safe to just drain the queue.
    var message = messages.poll()
    while (message != null) {
      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
      message = messages.poll()
    }
  }
}
With that, we have walked once through the main structure of NettyRpcEnv, the implementation class of RpcEnv, and gained a basic sense of each class inside it and its responsibilities. A class-relationship diagram would make those relationships more direct and intuitive, and is worth drawing when time permits. The Netty transport classes that NettyRpcEnv builds on go deeper still and deserve further study.