First, an RpcEnv is created to hold the RPC runtime state:
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
securityManager, numUsableCores, !isDriver)
The create call ultimately delegates to NettyRpcEnvFactory's create method, constructing the RpcEnv through the factory pattern:
private[spark] object RpcEnv {
def create(
name: String,
host: String,
port: Int,
conf: SparkConf,
securityManager: SecurityManager,
clientMode: Boolean = false): RpcEnv = {
create(name, host, host, port, conf, securityManager, 0, clientMode)
}
def create(
name: String,
bindAddress: String,
advertiseAddress: String,
port: Int,
conf: SparkConf,
securityManager: SecurityManager,
numUsableCores: Int,
clientMode: Boolean): RpcEnv = {
val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
numUsableCores, clientMode)
new NettyRpcEnvFactory().create(config)
}
}
Looking at NettyRpcEnvFactory's create method: it creates a Java serializer instance, constructs a new NettyRpcEnv, and, unless running in client mode, starts the server via the utility function Utils.startServiceOnPort. That utility is a higher-order function: it takes a start function as a parameter and invokes it internally, and here that function is the closure calling nettyEnv.startServer(config.bindAddress, actualPort). (A simplified sketch of the retry idea behind startServiceOnPort follows the create method below.)
def create(config: RpcEnvConfig): RpcEnv = {
val sparkConf = config.conf
// Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
// KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
val javaSerializerInstance =
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
val nettyEnv =
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
config.securityManager, config.numUsableCores)
if (!config.clientMode) {
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
nettyEnv.startServer(config.bindAddress, actualPort)
(nettyEnv, nettyEnv.address.port)
}
try {
Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
} catch {
case NonFatal(e) =>
nettyEnv.shutdown()
throw e
}
}
nettyEnv
}
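The Utils.startServiceOnPort helper used above deserves a closer look. Below is a minimal sketch of the idea, assuming a simple sequential-retry policy; the name, signature, and retry details (port wrapping, the configurable spark.port.maxRetries) are simplified assumptions, not the exact Spark implementation.
import java.net.BindException

object StartServiceSketch {
  // Higher-order helper: startService receives the port to try and returns (service, boundPort).
  def startServiceOnPortSketch[T](startPort: Int,
                                  startService: Int => (T, Int),
                                  maxRetries: Int = 16): (T, Int) = {
    var offset = 0
    while (offset <= maxRetries) {
      val tryPort = if (startPort == 0) 0 else startPort + offset
      try {
        return startService(tryPort)              // the caller's function performs the actual bind
      } catch {
        case _: BindException => offset += 1      // port already in use, try the next one
      }
    }
    throw new BindException(s"Could not bind a service starting from port $startPort")
  }
}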
startServer then calls createServer on the transportContext, which is declared as
private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager))
and creates a TransportServer.
def startServer(bindAddress: String, port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
} else {
java.util.Collections.emptyList()
}
server = transportContext.createServer(bindAddress, port, bootstraps)
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
It then registers the RpcEndpointVerifier endpoint with the dispatcher:
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
createServer itself simply constructs a TransportServer instance:
/** Create a server which will attempt to bind to a specific host and port. */
public TransportServer createServer(
String host, int port, List<TransportServerBootstrap> bootstraps) {
return new TransportServer(this, host, port, rpcHandler, bootstraps);
}
TransportServer is mostly standard Netty plumbing, which its init method makes clear: it creates the boss and worker EventLoopGroups, builds a ServerBootstrap, adds a child handler that initializes each accepted channel's pipeline, and binds the port. (A standalone Netty sketch of this same pattern follows the init method below.)
public TransportServer(
TransportContext context,
String hostToBind,
int portToBind,
RpcHandler appRpcHandler,
List<TransportServerBootstrap> bootstraps) {
this.context = context;
this.conf = context.getConf();
this.appRpcHandler = appRpcHandler;
if (conf.sharedByteBufAllocators()) {
this.pooledAllocator = NettyUtils.getSharedPooledByteBufAllocator(
conf.preferDirectBufsForSharedByteBufAllocators(), true /* allowCache */);
} else {
this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
}
this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
boolean shouldClose = true;
try {
init(hostToBind, portToBind);
shouldClose = false;
} finally {
if (shouldClose) {
JavaUtils.closeQuietly(this);
}
}
}
private void init(String hostToBind, int portToBind) {
IOMode ioMode = IOMode.valueOf(conf.ioMode());
EventLoopGroup bossGroup = NettyUtils.createEventLoop(ioMode, 1,
conf.getModuleName() + "-boss");
EventLoopGroup workerGroup = NettyUtils.createEventLoop(ioMode, conf.serverThreads(),
conf.getModuleName() + "-server");
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(NettyUtils.getServerChannelClass(ioMode))
.option(ChannelOption.ALLOCATOR, pooledAllocator)
.option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
.childOption(ChannelOption.ALLOCATOR, pooledAllocator);
this.metrics = new NettyMemoryMetrics(
pooledAllocator, conf.getModuleName() + "-server", conf);
if (conf.backLog() > 0) {
bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
}
if (conf.receiveBuf() > 0) {
bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
}
if (conf.sendBuf() > 0) {
bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
}
if (conf.enableTcpKeepAlive()) {
bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true);
}
bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
logger.debug("New connection accepted for remote address {}.", ch.remoteAddress());
RpcHandler rpcHandler = appRpcHandler;
for (TransportServerBootstrap bootstrap : bootstraps) {
rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
}
context.initializePipeline(ch, rpcHandler);
}
});
InetSocketAddress address = hostToBind == null ?
new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
channelFuture = bootstrap.bind(address);
channelFuture.syncUninterruptibly();
port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
logger.debug("Shuffle server started on port: {}", port);
}
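For readers less familiar with Netty, here is a standalone sketch (not Spark code) of the same boss/worker, ServerBootstrap, childHandler, and bind pattern that init uses; a simple LoggingHandler stands in for what context.initializePipeline installs in the real server, and the port 9999 is arbitrary.
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.ChannelInitializer
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.{LogLevel, LoggingHandler}

object MiniNettyServer {
  def main(args: Array[String]): Unit = {
    val bossGroup = new NioEventLoopGroup(1)        // accepts incoming connections
    val workerGroup = new NioEventLoopGroup()       // handles established channels
    try {
      val bootstrap = new ServerBootstrap()
        .group(bossGroup, workerGroup)
        .channel(classOf[NioServerSocketChannel])
        .childHandler(new ChannelInitializer[SocketChannel] {
          override def initChannel(ch: SocketChannel): Unit = {
            // the real TransportServer delegates to context.initializePipeline(ch, rpcHandler) here
            ch.pipeline().addLast(new LoggingHandler(LogLevel.INFO))
          }
        })
      val channel = bootstrap.bind(9999).sync().channel()
      channel.closeFuture().sync()                  // block until the server socket closes
    } finally {
      bossGroup.shutdownGracefully()
      workerGroup.shutdownGracefully()
    }
  }
}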
At this point the NettyRpcEnv has been fully created. As we know, to actually use RPC in Spark there must be an endpoint, which is registered on the driver or looked up elsewhere:
def registerOrLookupEndpoint(
name: String, endpointCreator: => RpcEndpoint):
RpcEndpointRef = {
if (isDriver) {
logInfo("Registering " + name)
rpcEnv.setupEndpoint(name, endpointCreator)
} else {
RpcUtils.makeDriverRef(name, conf, rpcEnv)
}
}
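As a rough illustration of what an endpoint looks like, here is a minimal sketch; EchoEndpoint and the Echo message are made-up names, and since RpcEndpoint is private[spark], such code can only live inside Spark's own packages.
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

case class Echo(text: String)

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(text) => context.reply(s"echo: $text")   // the reply flows back through the caller's Promise
  }
}

// Driver side: register under a name and get back an RpcEndpointRef (illustrative only).
// val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
// val answer = ref.askSync[String](Echo("hi"))
// An executor would instead look the driver endpoint up via RpcUtils.makeDriverRef.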
For example, when the block manager service is created, two endpoints are registered:
val blockManagerMaster = new BlockManagerMaster(
registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_ENDPOINT_NAME,
new BlockManagerMasterEndpoint(
rpcEnv,
isLocal,
conf,
listenerBus,
if (conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)) {
externalShuffleClient
} else {
None
}, blockManagerInfo)),
registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
conf,
isDriver)
The BlockManagerMaster constructor simply takes the resulting endpoint references:
class BlockManagerMaster(
var driverEndpoint: RpcEndpointRef,
var driverHeartbeatEndPoint: RpcEndpointRef,
conf: SparkConf,
isDriver: Boolean)
Pick any method that sends a message to see how the references are used, e.g. getStorageStatus:
def getStorageStatus: Array[StorageStatus] = {
if (driverEndpoint == null) return Array.empty
driverEndpoint.askSync[Array[StorageStatus]](GetStorageStatus)
}
It calls askSync, a method of the abstract class RpcEndpointRef, which in turn calls the abstract ask method and then waits for the result within the timeout. ask is implemented by NettyRpcEndpointRef.
def askSync[T: ClassTag](message: Any): T = askSync(message, defaultAskTimeout)
def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
val future = ask[T](message, timeout)
timeout.awaitResult(future)
}
NettyRpcEndpointRef's ask delegates to askAbortable, which is likewise declared on the RpcEndpointRef abstract class and overridden here:
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
  askAbortable(message, timeout).future
}
That override in turn calls nettyEnv.askAbortable, wrapping the payload in a RequestMessage:
override def askAbortable[T: ClassTag](
message: Any, timeout: RpcTimeout): AbortableRpcFuture[T] = {
nettyEnv.askAbortable(new RequestMessage(nettyEnv.address, this, message), timeout)
}
Inside nettyEnv.askAbortable a Promise is created as a write-once placeholder for the reply, together with callbacks for the different outcomes: onFailure, onSuccess, and onAbort. The method then checks who the receiver is. If the message targets the local endpoint, it calls dispatcher.postLocalMessage(message, p); if it targets a remote endpoint, it wraps the message in an RpcOutboxMessage and calls postToOutbox(message.receiver, rpcMessage), delivering it through an Outbox.
private[netty] def askAbortable[T: ClassTag](
message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
var rpcMsg: Option[RpcOutboxMessage] = None
def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
e match {
case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e")
case _ => logWarning(s"Ignored failure: $e")
}
}
}
def onSuccess(reply: Any): Unit = reply match {
case RpcFailure(e) => onFailure(e)
case rpcReply =>
if (!promise.trySuccess(rpcReply)) {
logWarning(s"Ignored message: $reply")
}
}
def onAbort(t: Throwable): Unit = {
onFailure(t)
rpcMsg.foreach(_.onAbort())
}
try {
if (remoteAddr == address) {
val p = Promise[Any]()
p.future.onComplete {
case Success(response) => onSuccess(response)
case Failure(e) => onFailure(e)
}(ThreadUtils.sameThread)
dispatcher.postLocalMessage(message, p)
} else {
val rpcMessage = RpcOutboxMessage(message.serialize(this),
onFailure,
(client, response) => onSuccess(deserialize[Any](client, response)))
rpcMsg = Option(rpcMessage)
postToOutbox(message.receiver, rpcMessage)
promise.future.failed.foreach {
case _: TimeoutException => rpcMessage.onTimeout()
case _ =>
}(ThreadUtils.sameThread)
}
val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
override def run(): Unit = {
val remoteReceAddr = if (remoteAddr == null) {
Try {
message.receiver.client.getChannel.remoteAddress()
}.toOption.orNull
} else {
remoteAddr
}
onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteReceAddr} " +
s"in ${timeout.duration}"))
}
}, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
promise.future.onComplete { v =>
timeoutCancelable.cancel(true)
}(ThreadUtils.sameThread)
} catch {
case NonFatal(e) =>
onFailure(e)
}
new AbortableRpcFuture[T](
promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread),
onAbort)
}
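The core of askAbortable is the Promise used as a write-once slot that the reply, the timeout, or an abort completes first, whichever happens earliest. A standalone sketch of that pattern in plain Scala (no Spark types):
import scala.concurrent.Promise
import scala.util.{Failure, Success}

object PromiseSketch {
  def main(args: Array[String]): Unit = {
    val promise = Promise[String]()          // write-once slot for the reply

    // Only the first completion wins: here the reply arrives before the "timeout".
    promise.trySuccess("pong")                                                   // analogous to onSuccess
    promise.tryFailure(new java.util.concurrent.TimeoutException("too late"))    // ignored, slot already filled

    promise.future.value match {             // the future is already completed at this point
      case Some(Success(v)) => println(s"reply: $v")
      case Some(Failure(e)) => println(s"failed: $e")
      case None             => println("still pending")
    }
  }
}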
When sending to a remote endpoint via postToOutbox, the code first checks whether it already holds a TransportClient for the remote end; if so, the message is sent directly. Otherwise it looks up the Outbox for the receiver's RpcAddress; if no Outbox exists yet, a new one is created and cached so it can be reused next time. It then checks whether that Outbox has been stopped, and only if it is still running is the message placed into it. Note that at this point the message has not actually been sent over the wire; it has merely been enqueued in the Outbox.
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
if (receiver.client != null) {
message.sendWith(receiver.client)
} else {
require(receiver.address != null,
"Cannot send message to client endpoint with no listen address.")
val targetOutbox = {
val outbox = outboxes.get(receiver.address)
if (outbox == null) {
val newOutbox = new Outbox(this, receiver.address)
val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
if (oldOutbox == null) {
newOutbox
} else {
oldOutbox
}
} else {
outbox
}
}
if (stopped.get) {
// It's possible that we put `targetOutbox` after stopping. So we need to clean it.
outboxes.remove(receiver.address)
targetOutbox.stop()
} else {
targetOutbox.send(message)
}
}
}
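The outbox lookup above is a race-safe get-or-create on a ConcurrentHashMap. A minimal sketch of that pattern in isolation (Box is a placeholder type, not the Spark Outbox):
import java.util.concurrent.ConcurrentHashMap

object OutboxCacheSketch {
  final class Box(val address: String)
  private val boxes = new ConcurrentHashMap[String, Box]()

  def getOrCreate(address: String): Box = {
    val existing = boxes.get(address)
    if (existing != null) return existing
    val fresh = new Box(address)
    val raced = boxes.putIfAbsent(address, fresh)   // null means our fresh box won the race
    if (raced == null) fresh else raced
  }
}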
Outbox.send first appends the message to the outbox's queue and then calls drainOutbox() to push it out:
def send(message: OutboxMessage): Unit = {
val dropped = synchronized {
if (stopped) {
true
} else {
messages.add(message)
false
}
}
if (dropped) {
message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
}
drainOutbox checks the connection (launching a connect task if there is none yet), then keeps taking messages from the queue, ultimately calling message.sendWith(_client) for each:
private def drainOutbox(): Unit = {
var message: OutboxMessage = null
synchronized {
if (stopped) {
return
}
if (connectFuture != null) {
// We are connecting to the remote address, so just exit
return
}
if (client == null) {
// There is no connect task but client is null, so we need to launch the connect task.
launchConnectTask()
return
}
if (draining) {
// There is some thread draining, so just exit
return
}
message = messages.poll()
if (message == null) {
return
}
draining = true
}
while (true) {
try {
val _client = synchronized { client }
if (_client != null) {
message.sendWith(_client)
} else {
assert(stopped)
}
} catch {
case NonFatal(e) =>
handleNetworkFailure(e)
return
}
synchronized {
if (stopped) {
return
}
message = messages.poll()
if (message == null) {
draining = false
return
}
}
}
}
The sender's work is now done; next the receiver picks the message up. A message loop started during initialization receives messages continuously: it takes inboxes from the active queue until it takes the PoisonPill sentinel, at which point it puts the PoisonPill back (so other loop threads also see it) and returns.
private def receiveLoop(): Unit = {
try {
while (true) {
try {
val inbox = active.take()
if (inbox == MessageLoop.PoisonPill) {
// Put PoisonPill back so that other threads can see it.
setActive(MessageLoop.PoisonPill)
return
}
inbox.process(dispatcher)
} catch {
case NonFatal(e) => logError(e.getMessage, e)
}
}
} catch {
case _: InterruptedException => // exit
case t: Throwable =>
try {
// Re-submit a receive task so that message delivery will still work if
// UncaughtExceptionHandler decides to not kill JVM.
threadpool.execute(receiveLoopRunnable)
} finally {
throw t
}
}
}
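The PoisonPill trick is a standard way to stop a pool of worker threads blocked on a queue: a single sentinel is enqueued, and every worker that takes it re-enqueues it before exiting so its siblings stop too. A standalone sketch, assuming a plain BlockingQueue of strings:
import java.util.concurrent.LinkedBlockingQueue

object PoisonPillSketch {
  private val POISON = "POISON"
  private val queue = new LinkedBlockingQueue[String]()

  def workerLoop(): Unit = {
    while (true) {
      val item = queue.take()
      if (item == POISON) {
        queue.put(POISON)     // put it back so sibling worker threads also stop
        return
      }
      println(s"processing $item")
    }
  }

  def main(args: Array[String]): Unit = {
    val t = new Thread(new Runnable { override def run(): Unit = workerLoop() })
    t.start()
    queue.put("msg-1")
    queue.put("msg-2")
    queue.put(POISON)         // one pill stops every worker because each re-enqueues it
    t.join()
  }
}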
The loop then processes the messages held in that inbox, pattern matching on the message type. There are seven cases: RpcMessage, OneWayMessage, OnStart, OnStop, RemoteProcessConnected, RemoteProcessDisconnected, and RemoteProcessConnectionError:
def process(dispatcher: Dispatcher): Unit = {
var message: InboxMessage = null
inbox.synchronized {
if (!enableConcurrent && numActiveThreads != 0) {
return
}
message = messages.poll()
if (message != null) {
numActiveThreads += 1
} else {
return
}
}
while (true) {
safelyCall(endpoint) {
message match {
case RpcMessage(_sender, content, context) =>
try {
endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
} catch {
case e: Throwable =>
context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
throw e
}
case OneWayMessage(_sender, content) =>
endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
case OnStart =>
endpoint.onStart()
if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
inbox.synchronized {
if (!stopped) {
enableConcurrent = true
}
}
}
case OnStop =>
val activeThreads = inbox.synchronized { inbox.numActiveThreads }
assert(activeThreads == 1,
s"There should be only a single active thread but found $activeThreads threads.")
dispatcher.removeRpcEndpointRef(endpoint)
endpoint.onStop()
assert(isEmpty, "OnStop should be the last message")
case RemoteProcessConnected(remoteAddress) =>
endpoint.onConnected(remoteAddress)
case RemoteProcessDisconnected(remoteAddress) =>
endpoint.onDisconnected(remoteAddress)
case RemoteProcessConnectionError(cause, remoteAddress) =>
endpoint.onNetworkError(cause, remoteAddress)
}
}
inbox.synchronized {
// "enableConcurrent" will be set to false after `onStop` is called, so we should check it
// every time.
if (!enableConcurrent && numActiveThreads != 1) {
// If we are not the only one worker, exit
numActiveThreads -= 1
return
}
message = messages.poll()
if (message == null) {
numActiveThreads -= 1
return
}
}
}
}
When an RpcMessage arrives it is dispatched to the endpoint's receiveAndReply, which is a partial function: messages it can handle are processed, and anything else causes an exception to be thrown.
// This is the default on the RpcEndpoint trait; the concrete handling lives in the overriding methods of subclasses — the variety of handlers is polymorphism at work
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}
// The HeartbeatReceiver override, which replies by calling context.reply
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
// Messages sent and received locally
case ExecutorRegistered(executorId) =>
executorLastSeen(executorId) = clock.getTimeMillis()
context.reply(true)
case ExecutorRemoved(executorId) =>
executorLastSeen.remove(executorId)
context.reply(true)
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
context.reply(true)
case ExpireDeadHosts =>
expireDeadHosts()
context.reply(true)
// Messages received from executors
case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId, executorUpdates) =>
if (scheduler != null) {
if (executorLastSeen.contains(executorId)) {
executorLastSeen(executorId) = clock.getTimeMillis()
eventLoopThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, accumUpdates, blockManagerId, executorUpdates)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
context.reply(response)
}
})
} else {
// This may happen if we get an executor's in-flight heartbeat immediately
// after we just removed it. It's not really an error condition so we should
// not log warning here. Otherwise there may be a lot of noise especially if
// we explicitly remove executors (SPARK-4134).
logDebug(s"Received heartbeat from unknown executor $executorId")
context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
// register itself again.
logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet")
context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
}
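Both snippets above rely on PartialFunction.applyOrElse: if the handler's cases cover the message it is handled, otherwise the supplied default runs (and in the Inbox it throws). A standalone sketch of that dispatch, with a made-up Ping message:
object PartialFunctionSketch {
  case class Ping(id: Int)

  val handler: PartialFunction[Any, Unit] = {
    case Ping(id) => println(s"pong $id")
  }

  def main(args: Array[String]): Unit = {
    handler.applyOrElse[Any, Unit](Ping(1), msg =>
      throw new IllegalArgumentException(s"Unsupported message $msg"))   // handled: prints "pong 1"
    try {
      handler.applyOrElse[Any, Unit]("unexpected", msg =>
        throw new IllegalArgumentException(s"Unsupported message $msg")) // not covered: default fires
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}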
Replying goes through the RpcCallContext:
override def reply(response: Any): Unit = {
send(response)
}
// send serializes the response and then invokes RpcResponseCallback.onSuccess
override protected def send(message: Any): Unit = {
val reply = nettyEnv.serialize(message)
callback.onSuccess(reply)
}
Looking at onSuccess, it is implemented in RpcOutboxMessage:
override def onSuccess(response: ByteBuffer): Unit = {
_onSuccess(client, response)
}
And this _onSuccess is precisely the callback that was supplied when the RpcOutboxMessage was constructed:
private[netty] case class RpcOutboxMessage(
content: ByteBuffer,
_onFailure: (Throwable) => Unit,
_onSuccess: (TransportClient, ByteBuffer) => Unit)
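RpcOutboxMessage is essentially a case class carrying the request bytes plus two callbacks, which the transport layer invokes when the response or a failure comes back. A minimal sketch of that pattern with purely illustrative names:
object CallbackMessageSketch {
  case class OutMessage(content: String,
                        onFailure: Throwable => Unit,
                        onSuccess: String => Unit)

  def deliver(msg: OutMessage): Unit =
    try msg.onSuccess(msg.content.toUpperCase)    // the "transport layer" invokes the success callback
    catch { case e: Exception => msg.onFailure(e) }

  def main(args: Array[String]): Unit =
    deliver(OutMessage("pong", e => println(s"failed: $e"), r => println(s"reply: $r")))
}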
In summary: every outgoing message goes through an Outbox and every incoming message goes through an Inbox; the Dispatcher is the scheduler in the middle, and NettyRpcEnv is the environment that ties it all together.