NCCL源码解析: proxy 线程

文章目录

  • 前言
  • 概括
  • 详解
      • 1. 用到的变量
      • 2. proxy 线程创建
        • 2.1 ncclProxyService()
        • 2.2 proxyServiceInitOp()
        • 2.3 proxyProgressAsync()
      • 4. ncclProxyConnect()
        • 4.1 ncclProxyCallBlocking()
        • 4.2 ncclPollProxyResponse()

前言

NCCL 源码解析总目录

我尽量在每个函数之前介绍每个函数的作用,建议先不要投入到函数内部实现,先把函数作用搞清楚,有了整体框架,再回归到细节。

习惯:为了便于快速理解,我的笔记里函数调用关系通过缩进表示;被调用的函数有时也会在调用处直接展开,视情况而定。

如下

// 调用 proxyConnInit
NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
// 对函数 proxyConnInit 进行展开,可方便看参数
static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct 

如有问题,请留言指正。

图后面再补;
有些遗漏之处,还没涉及,后面补;
闲话后面再补。

概括

每个GPU对应一个管理线程或者进程,在卡与卡之间建立通信的时候,会额外创建一个代理线程去完成这件事,代理线程是被动的,该做什么事还是由GPU对应的管理线程去通过TCP下发。
代理线程的主要工作有:

  1. 监听TCP端口
  2. 调用 ncclTransportComm 的 proxySharedInit, proxySetup,proxyConnect
  3. 关闭TCP连接

详解

1. 用到的变量

主要关注 comm->proxyState 的初始化,它后面会作为代理线程的参数使用,用到的时候再来看也行。

// Initialization outline (an indented line is a callee or an inline expansion of the line above it)
commAlloc()
	NCCLCHECK(ncclCalloc(&sharedRes, 1));
bootstrapInit()
	// proxy is aborted through a message; don't set abortFlag
	// NOTE(review): despite the comment above, comm->abortFlag is passed to ncclSocketInit below — confirm against upstream.
	// Allocate the proxy listening socket object.
	NCCLCHECK(ncclCalloc(&proxySocket, 1));
	// Initialize proxySocket as a proxy-type socket on the bootstrap interface address.
	NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag));
	// Put the socket into the listening state.
	NCCLCHECK(ncclSocketListen(proxySocket));
	// Store this rank's proxy address (IP + port) in state->peerProxyAddresses[rank].
	NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank));
		// Where `state` comes from: the bootstrap state, also stored in comm->bootstrap.
		struct bootstrapState* state;
		comm->bootstrap = state;
	// All-gather across ranks so that state->peerProxyAddresses holds every rank's proxy address.
	NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)));
	// Allocate and initialize comm->proxyState (expanded below).
	NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses));
		NCCLCHECK(ncclCalloc(&comm->sharedRes->proxyState, 1));
		comm->proxyState = comm->sharedRes->proxyState;
		comm->proxyState->refCount = 1;
		comm->proxyState->listenSock = proxySocket;
		comm->proxyState->peerAddresses = state->peerProxyAddresses;

2. proxy 线程创建

主要通过 ncclProxyCreate() 进行 proxyState 对象属性初始化,NCCL 初始化时会创建线程ncclProxyService

// Excerpt of ncclProxyCreate(): copy communicator settings into comm->proxyState and start the proxy service thread.
ncclProxyCreate(comm) 
{
	// proxyState is comm->proxyState (allocated earlier by ncclProxyInit)
	struct ncclProxyState* proxyState = comm->proxyState;
	// Copy the communicator fields the proxy thread needs; each is explained where it is used.
	proxyState->tpRank = comm->rank;
	proxyState->tpnRanks = comm->nRanks;
	proxyState->tpLocalnRanks = comm->localRanks;
	proxyState->cudaDev = comm->cudaDev;
	proxyState->abortFlag = comm->abortFlag;
	proxyState->p2pnChannels = comm->p2pnChannels;
	proxyState->p2pChunkSize = comm->p2pChunkSize;
	proxyState->nChannels = comm->nChannels;
	proxyState->allocP2pNetLLBuffers = comm->allocP2pNetLLBuffers;
	proxyState->dmaBufSupport = comm->dmaBufSupport;
	proxyState->ncclNet = comm->ncclNet;
	proxyState->ncclCollNet = comm->ncclCollNet;
	memcpy(proxyState->buffSizes, comm->buffSizes, sizeof(comm->buffSizes));
	// Start the proxy service thread; proxyState is passed as its argument.
	// NOTE(review): pthread_create's return value is not checked in this excerpt.
	pthread_create(&comm->proxyState->thread, NULL, ncclProxyService, comm->proxyState);
}
2.1 ncclProxyService()

proxy 服务线程代码, 一个设备起一个 proxy 线程,线程名为 NCCL Service %rank
线程主要做三件事:

  1. 建立TCP连接
  2. 根据各卡管理线程(客户端)发来的命令 type 执行相应操作
  3. 关闭TCP连接

type 定义如下:

/* Request types a local rank sends to the proxy service thread.
 * The type is the first int of every request read by the service loop, so these
 * values travel over the proxy socket; keep them explicit. */
enum ncclProxyMsgType {
  ncclProxyMsgInit = 1,       // Create a proxy connection object (proxyConnInit); the TCP socket is already connected at this point
  ncclProxyMsgSharedInit = 2, // Proxy calls the transport's proxySharedInit (if any)
  ncclProxyMsgSetup = 3,      // Proxy calls the transport's proxySetup
  ncclProxyMsgConnect = 4,    // Proxy calls the transport's proxyConnect
  ncclProxyMsgStart = 5,      // Not handled by proxyProgressAsync (it returns ncclInternalError for this type)
  ncclProxyMsgClose = 6,      // Close this peer's TCP connection to the proxy
  ncclProxyMsgAbort = 7,      // Not handled by proxyProgressAsync (it returns ncclInternalError for this type)
  ncclProxyMsgStop = 8,       // Close this connection; the service thread exits once every peer has disconnected
  ncclProxyMsgConvertFd = 9,  // cuMem API support (UDS): handled by proxyConvertFd
};

线程中主要的处理函数是 proxyServiceInitOp()

// Argument: _args is comm->proxyState (ncclProxyCreate passes it to pthread_create).
args = comm->proxyState
// Proxy service thread: accepts TCP connections from local ranks on proxyState->listenSock,
// reads request messages from them and drives the resulting async operations.
// It leaves its loop once stop is set (ncclProxyMsgStop or *abortFlag) and no peers remain,
// then tears down the proxy resources. Always returns NULL.
void* ncclProxyService(void* _args) {
	struct ncclProxyState* proxyState =  (struct ncclProxyState*) _args;

	// Prepare poll descriptor
	// offset starts at NCCL_PROXY_CONN_POOL_SIZE, presumably so the first ncclProxyNewConnection allocates a bank — confirm.
	struct ncclProxyConnectionPool connectionPool;
	connectionPool.pools = NULL;
	connectionPool.banks = 0;
	connectionPool.offset = NCCL_PROXY_CONN_POOL_SIZE;

	// Slots 0..NCCL_MAX_LOCAL_RANKS-1 hold peer sockets (fd -1 = free slot); the last slot is the listening socket.
	struct pollfd pollfds[NCCL_MAX_LOCAL_RANKS+1];
	struct ncclProxyLocalPeer peers[NCCL_MAX_LOCAL_RANKS];
	memset(&peers, 0, sizeof(struct ncclProxyLocalPeer)*NCCL_MAX_LOCAL_RANKS);
	for (int s=0; s<NCCL_MAX_LOCAL_RANKS; s++) {
		pollfds[s].fd = -1;
		pollfds[s].events = POLLHUP|POLLIN;
	}
	if (ncclSocketGetFd(proxyState->listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) {
		WARN("[Proxy Service] Get listenSock fd fails");
		return NULL;
	};
	// The listening socket only needs POLLIN (incoming connections).
	pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN;

	int maxnpeers = 0;
	int npeers = 0;
	int stop = 0;
	int asyncOpCount = 0;
	while (stop == 0 || (stop == 1 && npeers > 0)) {
		/* Even if local comm aborts, we cannot let proxy thread exit if we still have peer
			* connections. Need to wait until all other related comms call abort and safely exit
			* together, or we could face segmentation fault. */
		// So a local abort only sets stop; the loop keeps serving until npeers reaches 0.
		if (*proxyState->abortFlag != 0) stop = 1;
		/* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */
		// Don't wait at all while async ops are pending, so they keep progressing; otherwise wait up to 500 ms.
		int ret;
		do {
			ret = poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500);
		} while (ret < 0 && errno == EINTR);
		if (ret < 0) {
			WARN("[Proxy Service] Poll failed: %s", strerror(errno));
			return NULL;
		}
		// Listening socket is readable: accept a new local-rank connection into the first free slot.
		if (pollfds[NCCL_MAX_LOCAL_RANKS].revents) {
			int s = 0;
			while (s < NCCL_MAX_LOCAL_RANKS && pollfds[s].fd >= 0) s++;
			if (s == NCCL_MAX_LOCAL_RANKS) {
				WARN("[Proxy service] Too many connections (%d max)", NCCL_MAX_LOCAL_RANKS);
				return NULL;
			}
			if (maxnpeers < s+1) maxnpeers = s+1;
			// Initialize the socket object for slot s.
			if (ncclSocketInit(&peers[s].sock) != ncclSuccess) {
				WARN("[Service thread] Initialize peers[%d].sock fails", s);
				return NULL;
			}
			// Accept the pending connection on the listening socket.
			if (ncclSocketAccept(&peers[s].sock, proxyState->listenSock) != ncclSuccess) {
				WARN("[Service thread] Accept failed %s", strerror(errno));
			} else {
				// Register the accepted socket's fd in pollfds so it is polled from now on.
				if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) {
					WARN("[Service thread] Get peers[%d].sock fd fails", s);
					return NULL;
				}
				npeers++;
				peers[s].tpLocalRank = -1;
			}
		}
		for (int s=0; s<maxnpeers; s++) {
			struct ncclProxyLocalPeer* peer = peers+s;
			struct ncclSocket* sock = &peer->sock;
			int closeConn = 0;
			int type = 0;
			ncclResult_t res = ncclSuccess;
			if (pollfds[s].fd == -1) 
				continue;

			// Progress all ops for this ncclProxyLocalPeer
			ncclProxyAsyncOp* op = peer->asyncOps;
			while (op != nullptr) {
				ncclProxyAsyncOp* opnext = op->next; /* in case op is freed in proxyProgressAsync */
				type = op->type;
				res = proxyProgressAsync(op, proxyState, &asyncOpCount, peer, &connectionPool);
				if (res == ncclSuccess || res == ncclInProgress) {
					op = opnext;
				} else {
					// Res is a bad result
					closeConn = 1;
					WARN("[Service thread] Error encountered progressing operation=%s, res=%d, closing connection", ncclProxyMsgTypeStr[type], res);
					break;
				}
			}

			// Check for additional ops coming in
			if (pollfds[s].revents & POLLIN) {
				int closed;
				// Every request starts with its int message type; try a non-blocking receive of it.
				res = ncclSocketTryRecv(sock, &type, sizeof(int), &closed, false /*blocking*/);
				if (res != ncclSuccess && res != ncclInProgress) {
					WARN("[Service thread] Could not receive type from localRank %d, res=%u, closed=%d", peer->tpLocalRank, res, closed);
					closeConn = 1;
				} else if (closed) {
					INFO(NCCL_INIT|NCCL_NET|NCCL_PROXY, "[Service thread] Connection closed by localRank %d", peer->tpLocalRank);
					closeConn = 1;
				} else if (res == ncclSuccess) { // We received something from the sock
					// Dispatch on the received message type.
					if (type == ncclProxyMsgStop) {
					// Stop: close this connection and leave the loop once no peers remain.
						stop = 1;
					closeConn = 1;
					} else if (type == ncclProxyMsgClose) {
					// Close: close this connection only.
						closeConn = 1;
					} else if (proxyMatchOpType(type)) {
						// Receive the rest of the request and start processing it.
						res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
					} else {
						// Unknown command: close the connection.
						WARN("[Service thread] Unknown command %d from localRank %d", type, peer->tpLocalRank);
						closeConn = 1;
					}

					// NOTE(review): for an unknown command, type may lie outside ncclProxyMsgTypeStr's range,
					// so this lookup can read out of bounds; consider bounds-checking type before indexing.
					INFO(NCCL_PROXY, "Received and initiated operation=%s res=%d", ncclProxyMsgTypeStr[type], res);
				}
			} else if (pollfds[s].revents & POLLHUP) {
				// Peer hung up: close the connection.
				closeConn = 1;
			}

			if (res != ncclSuccess && res != ncclInProgress) {
				// Any hard error closes the connection.
				WARN("[Proxy Service %d] Failed to execute operation %s from rank %d, retcode %d", proxyState->tpRank, ncclProxyMsgTypeStr[type], peer->tpRank, res);
				closeConn = 1;
			}

			if (closeConn) {
				// Close the socket, drop the op that failed (if any) and free the pollfd slot.
				// NOTE(review): only `op` is dequeued here; other ops still on peer->asyncOps stay queued
				// and asyncOpCount is not adjusted for them — confirm this cannot stall or leak.
				ncclSocketClose(sock);

				if (op != nullptr) {
					asyncProxyOpDequeue(peer, op);
					asyncOpCount--;
				}
				pollfds[s].fd = -1;
				npeers--;
			}
		}
	}

	// Shutdown path.
	// Wait for all operations to complete and stop progress thread before freeing any resource
	if (ncclProxyProgressDestroy(proxyState) != ncclSuccess) {
		WARN("[Proxy Service] proxyDestroy failed");
	}
	for (int s=0; s<maxnpeers; s++) {
		ncclSocketClose(&peers[s].sock);
	}
	ncclProxyFreeConnections(&connectionPool, proxyState);
	ncclSocketClose(proxyState->listenSock);
	free(proxyState->listenSock);
	proxyOpsFree(proxyState);
	return NULL;
}
2.2 proxyServiceInitOp()

线程中的主要处理函数。客户端按固定顺序依次发送各字段,所以服务端也按相同顺序接收,然后调用 proxyProgressAsync 进行处理;


// Call site in ncclProxyService:
//   proxyState is the local rank's proxy state;
//   peers[] lives on the service (proxy) side and stores information about each client;
//   a peer represents one client connection.
res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
/* Receive the body of one proxy request from `peer` and start processing it.
 *
 * The message type has already been read by the service loop and is passed in as `type`.
 * The remaining fields arrive in the order the client sends them: the proxy-side connection
 * pointer, reqSize, respSize, reqSize bytes of request payload (only if reqSize > 0), and
 * finally the opId that tags the reply. The op is then queued on peer->asyncOps,
 * *asyncOpCount is incremented and a first proxyProgressAsync() attempt is made.
 *
 * If receiving any field fails before the op is queued, every buffer allocated here is
 * released and the error is returned.
 */
static ncclResult_t proxyServiceInitOp(int type, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, int* asyncOpCount) {
	ncclResult_t ret = ncclSuccess;
	// Socket connected to this client.
	struct ncclSocket* sock = &peer->sock;
	struct ncclProxyAsyncOp* asyncOp;
	// calloc zeroes reqBuff/respBuff, so the cleanup path can free them unconditionally.
	NCCLCHECK(ncclCalloc(&asyncOp, 1));

	asyncOp->type = type;
	// Address of the proxy-side connection object the client received at init time.
	NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->connection, sizeof(void*)), ret, fail);
	// Size of the request payload that follows.
	NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->reqSize, sizeof(int)), ret, fail);
	// Size of the response the client expects back.
	NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->respSize, sizeof(int)), ret, fail);
	if (asyncOp->reqSize) {
		// The client only sends a payload when reqSize > 0.
		NCCLCHECKGOTO(ncclCalloc(&asyncOp->reqBuff, asyncOp->reqSize), ret, fail);
		NCCLCHECKGOTO(ncclSocketRecv(sock, asyncOp->reqBuff, asyncOp->reqSize), ret, fail);
	}

	// Store opId for completion response
	NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->opId, sizeof(asyncOp->opId)), ret, fail);

	// Buffer for the response sent back once the operation is done.
	if (asyncOp->respSize)
		NCCLCHECKGOTO(ncclCalloc(&asyncOp->respBuff, asyncOp->respSize), ret, fail);

	// From here on the op belongs to peer->asyncOps and is released by the progress path.
	asyncProxyOpEnqueue(peer, asyncOp);

	(*asyncOpCount)++;
	NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
	return ncclSuccess;

fail:
	// The op was never queued: release everything allocated above.
	free(asyncOp->reqBuff);
	free(asyncOp->respBuff);
	free(asyncOp);
	return ret;
}
2.3 proxyProgressAsync()

处理请求函数,根据参数 type 进行不同的逻辑处理,然后按照一定的顺序返回数据

// Call site in proxyServiceInitOp:
NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
// Make progress on one queued proxy request, dispatching on op->type.
// When the operation is done, send the client opId, respSize and (if respSize > 0) respBuff,
// then dequeue the op and decrement *asyncOpCount.
// Returns ncclSuccess when done, ncclInProgress when not done yet, ncclInternalError for an
// unsupported type or when *abortFlag is set before completion; callback errors propagate via NCCLCHECK.
static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclProxyState* proxyState, int* asyncOpCount, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool) {
	int done = 1;
	if (op->type == ncclProxyMsgSetup) {
		// Transport-specific proxySetup; done is passed by pointer so it can report whether it finished.
		TRACE(NCCL_PROXY, "proxyProgressAsync::proxySetup() opId=%p", op->opId);
		NCCLCHECK(op->connection->tcomm->proxySetup(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
	} else if (op->type == ncclProxyMsgConnect) {
		// Transport-specific proxyConnect; may also report "not finished" through done.
		TRACE(NCCL_PROXY, "proxyProgressAsync::proxyConnect() opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
		NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
	} else if (op->type == ncclProxyMsgSharedInit) {
		// NOTE(review): if reqBuff is a char*, this reads only the first byte of the request
		// (fine only for small nChannels on little-endian hosts) — confirm the reqBuff type.
		int nChannels = (int) *op->reqBuff;
		// Call the transport's proxySharedInit (if it has one), then mark the connection shared-initialized.
		TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgSharedInit opId=%p op.reqBuff=%p nChannels=%d", op->opId, op->reqBuff, nChannels);
		if (op->connection->tcomm->proxySharedInit) NCCLCHECK(op->connection->tcomm->proxySharedInit(op->connection, proxyState, nChannels));
		__atomic_store_n(&op->connection->state, connSharedInitialized, __ATOMIC_RELEASE);
	} else if (op->type == ncclProxyMsgConvertFd) {
		int fd = *(int *)op->reqBuff;
		TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgConvertFd opId=%p op.reqBuff=%p fd=%d", op->opId, op->reqBuff, fd);
		NCCLCHECK(proxyConvertFd(peer, op->opId, proxyState, fd)); // cuMem API support
	} else if (op->type == ncclProxyMsgInit) {
		// Create the proxy connection object for this peer (proxyConnInit, expanded inline below).
		TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgInit opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
		NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
		static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct ncclProxyConnection** connection) 
		{
			int id;
			// Allocate a connection slot from connectionPool (adds a bank to pool->pools when needed and advances offset).
			// id = ((pool->banks-1) << NCCL_PROXY_CONN_POOL_SIZE_POW2) + pool->offset;
			// NOTE: the author's notes say a bank holds (1 << 7) slots — confirm NCCL_PROXY_CONN_POOL_SIZE_POW2.
			NCCLCHECK(ncclProxyNewConnection(connectionPool, &id));
			// Split id back into (bank, offset) and get the address of that ncclProxyConnection.
			NCCLCHECK(ncclProxyGetConnection(connectionPool, id, connection));
			// Fill in the connection from the client's init request.
			(*connection)->sock = &peer->sock;
			(*connection)->transport = req->transport;
			(*connection)->send = req->send;
			(*connection)->tpLocalRank = req->tpLocalRank;
			(*connection)->sameProcess = req->sameProcess;
			peer->tpLocalRank = req->tpLocalRank;
			peer->tpRank = req->tpRank;
			// Return the connection's address to the client in the response.
			resp->connection = *connection;

			(*connection)->tcomm = (*connection)->send ? &ncclTransports[(*connection)->transport]->send : &ncclTransports[(*connection)->transport]->recv;
			// If we need proxy progress, let's allocate ops and start the thread
			if ((*connection)->tcomm->proxyProgress) {
				NCCLCHECK(proxyProgressInit(proxyState));
				struct ncclProxyProgressState* state = &proxyState->progressState;
				// NOTE(review): strncpy leaves devShmPath unterminated if the suffix fills the whole buffer.
				strncpy(resp->devShmPath, state->opsPoolShmSuffix, sizeof(resp->devShmPath));
			}
			INFO(NCCL_NET|NCCL_PROXY, "New proxy %s connection %d from local rank %d, transport %d", (*connection)->send ? "send":"recv", id, (*connection)->tpLocalRank, (*connection)->transport);
			__atomic_store_n(&(*connection)->state, connInitialized, __ATOMIC_RELEASE);
			return ncclSuccess;
		}
	} else 
		// Unsupported type (e.g. ncclProxyMsgStart / ncclProxyMsgAbort).
		return ncclInternalError;

	if (done) {
		INFO(NCCL_PROXY, "proxyProgressAsync opId=%p op.type=%d op.reqBuff=%p op.respSize=%d done", op->opId, op->type, op->reqBuff, op->respSize);
		if (op->type == ncclProxyMsgSetup)
			__atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE);
		else if (op->type == ncclProxyMsgConnect)
			__atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE);
		/* if setup or connect is done, we should not return any error at this point since
			* ncclSocketSend might already send the respBuff to the requester. If we still choose
			* to abort and close the connection, it can cause segfault if the requester is using
			* the respBuff. */

		// Send the opId for referencing async operation
		NCCLCHECK(ncclSocketSend(op->connection->sock, &op->opId, sizeof(op->opId)));

		// Send the response size
		NCCLCHECK(ncclSocketSend(op->connection->sock, &op->respSize, sizeof(op->respSize)));

		if (op->respSize) {
			// Send the response
			NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize));
		}
		// Remove the completed op from peer->asyncOps.
		asyncProxyOpDequeue(peer, op);
		(*asyncOpCount)--;
		return ncclSuccess;

	} else if (*proxyState->abortFlag != 0) {
		return ncclInternalError;
	}

	return ncclInProgress;
}

4. ncclProxyConnect()

以建立连接为例:要使用代理,首先要连接代理。客户端先通过 TCP 连接代理线程监听的端口(代理线程 accept 后得到对应的 socket),再发送 type 为 ncclProxyMsgInit 的请求,告诉代理要建立连接;代理线程创建 ncclProxyConnection 对象,并把该 connection 对象的首地址返回给客户端。

连接流程如下,主要关注数据传输:有的传数据,有的传首地址:


// p2p send connector
// Example call site: the rank connects to the proxy's TCP service; the proxy records the
// connection and allocates a connection entry for it.
struct ncclConnector* send
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &send->proxyConn));
/* Open (or reuse) this rank's socket to the proxy service of local rank tpProxyRank and ask
 * that proxy to create a connection object for one connector.
 *
 * comm:        communicator issuing the request.
 * transport:   index into ncclTransports (e.g. TRANSPORT_P2P).
 * send:        non-zero for a send connector, zero for a receive connector.
 * tpProxyRank: top-parent rank of the local rank whose proxy serves this connection.
 * proxyConn:   output; receives the proxy-side connection pointer and routing information.
 *
 * One socket per top-parent local rank is kept in comm->proxyState->peerSocks and reused by
 * later calls. Returns ncclInternalError if no local rank of comm has top-parent rank tpProxyRank.
 */
ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int tpProxyRank, struct ncclProxyConnector* proxyConn) {
  struct ncclSocket* sock;
  int ready, proxyRank = -1;
  struct ncclProxyState* sharedProxyState = comm->proxyState;

  // Find the comm rank of the local rank whose top-parent rank is tpProxyRank.
  for (int i = 0; i < comm->localRanks; ++i) {
    if (comm->topParentRanks[comm->localRankToRank[i]] == tpProxyRank) {
      proxyRank = comm->localRankToRank[i];
      break;
    }
  }
  // Without a match, peerInfo[proxyRank] below would read peerInfo[-1].
  if (proxyRank == -1) {
    WARN("ncclProxyConnect: no local rank has top-parent rank %d", tpProxyRank);
    return ncclInternalError;
  }
  proxyConn->sameProcess = comm->peerInfo[proxyRank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
  // Keep one connection per local rank
  proxyConn->connection = NULL;
  proxyConn->tpRank = tpProxyRank;
  // First use: allocate peerSocks, proxyOps and sharedDevMems, one entry per top-parent local rank.
  if (sharedProxyState->peerSocks == NULL) {
    NCCLCHECK(ncclCalloc(&sharedProxyState->peerSocks, comm->sharedRes->tpNLocalRanks));
    NCCLCHECK(ncclCalloc(&sharedProxyState->proxyOps, comm->sharedRes->tpNLocalRanks));
    NCCLCHECK(ncclCalloc(&sharedProxyState->sharedDevMems, comm->sharedRes->tpNLocalRanks));
    for (int i = 0; i < comm->sharedRes->tpNLocalRanks; ++i) {
      NCCLCHECK(ncclSocketSetFd(-1, &sharedProxyState->peerSocks[i]));
    }
  }

  proxyConn->tpLocalRank = comm->sharedRes->tpRankToLocalRank[proxyConn->tpRank];
  sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
  NCCLCHECK(ncclSocketReady(sock, &ready));
  if (!ready) {
    // Initialize the socket and connect to the port the proxy service thread listens on.
    NCCLCHECK(ncclSocketInit(sock, sharedProxyState->peerAddresses+proxyConn->tpRank, comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag));
    NCCLCHECK(ncclSocketConnect(sock));
  }

  struct ncclProxyInitReq req = {0};
  req.transport = transport;
  req.send = send;
  req.tpLocalRank = comm->topParentLocalRanks[comm->localRank];
  req.tpRank = comm->topParentRanks[comm->rank];
  req.sameProcess = proxyConn->sameProcess;

  struct ncclProxyInitResp resp = {0};
  // This usually sends proxyConn->connection to identify which connection this is.
  // However, this is part of the response and therefore is ignored
  // Ask the proxy to create the connection object (proxyConnInit on the proxy side).
  NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
  // resp.connection is the address of the connection object inside the proxy.
  proxyConn->connection = resp.connection;

  // If we need proxy progress, map progress ops
  struct ncclTransportComm* tcomm = send ? &ncclTransports[transport]->send : &ncclTransports[transport]->recv;
  if (tcomm->proxyProgress) {
    // The proxy returned the suffix of its ops-pool shared-memory file; splice it into the /dev/shm path.
    char poolPath[] = "/dev/shm/nccl-XXXXXX";
    strncpy(poolPath+sizeof("/dev/shm/nccl-")-1, resp.devShmPath, sizeof("XXXXXX")-1);
    struct ncclProxyOps* proxyOps = sharedProxyState->proxyOps + proxyConn->tpLocalRank;
    if (proxyOps->pool == NULL) {
      NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, 0, &proxyOps->handle));
      proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1;
    }
  }
  INFO(NCCL_NET|NCCL_PROXY, "Connection to proxy localRank %d -> connection %p", proxyConn->tpLocalRank, proxyConn->connection);
  return ncclSuccess;
}
4.1 ncclProxyCallBlocking()

调用代理线程的接口:发送命令,然后阻塞等待并接收返回结果。

// Client side: ask the proxy service to run operation `type` and block until it replies.
// Example: ncclProxyMsgInit asks the proxy to create a connection object.
NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
ncclResult_t ncclProxyCallBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) {
	// Alloc some memory to act as a handle
	// (only its unique address matters: it tags the request and the matching reply)
	// NOTE(review): the malloc result below is not checked for NULL.
	ncclResult_t res = ncclSuccess;
	void* opId = malloc(1);
	// ncclProxyCallAsync() sends, in this order: type, the proxyConn->connection pointer,
	// reqSize, respSize, the request payload (only if reqSize > 0), and the opId pointer value.
	// It also registers opId as an expected response (partially expanded below).
	NCCLCHECKGOTO(ncclProxyCallAsync(comm, proxyConn, type, reqBuff, reqSize, respSize, opId), res, fail);
		struct ncclProxyState* sharedProxyState = comm->proxyState;
		sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
		// Append an expected-response record for opId to state->expectedResponses.
		NCCLCHECK(expectedProxyResponseEnqueue(sharedProxyState, opId, respSize));
		{
			struct ncclExpectedProxyResponse* ex;
			NCCLCHECK(ncclCalloc(&ex, 1));
			ex->opId = opId;

			// Pre-alloc response buffer
			// NOTE(review): this malloc is not checked for NULL either.
			ex->respBuff = malloc(respSize);
			ex->respSize = respSize;
			ex->done     = false;
			struct ncclExpectedProxyResponse* list = state->expectedResponses;
			if (list == NULL) {
				state->expectedResponses = ex;
				return ncclSuccess;
			}
			while (list->next) list = list->next;
			list->next = ex;
		}
		    

	// Poll until the reply tagged opId has been received.
	do {
		res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
		{
			int found = 0;
			// Presumably: if opId is in the list and its done flag is set, copy the buffered reply into respBuff and set found = 1.
  			NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
		}
	} while (res == ncclInProgress);

exit:
	free(opId);
	return res;
fail:
	goto exit;
}
4.2 ncclPollProxyResponse()

发送的时候会带上 opId 作为此次通信的标识,代理线程返回数据时也会把这个 opId 带回来。
所以接收的时候要比较 opId:如果与本次发送的 opId 一样,那么就接收成功;
如果不一样,就把接收到的数据放入缓冲区(留给对应的 opId),然后继续接收。

// Call site in ncclProxyCallBlocking: poll for the reply data tagged opId.
res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
/* Poll once for the proxy's reply to the request tagged `opId`.
 *
 * Returns:
 *   ncclSuccess       - the reply for opId was copied into respBuff, either from the
 *                       expected-response queue or straight off the socket;
 *   ncclInProgress    - nothing has arrived yet, or a reply for a different opId was read
 *                       and queued for its owner (ncclProxyCallBlocking keeps polling);
 *   ncclInternalError - the communicator is aborted, no proxy sockets exist, the socket
 *                       failed, or a payload arrived with no buffer to receive it;
 *   ncclSystemError   - a buffer for an unexpected reply could not be allocated;
 *   other codes propagate from the socket / queue helpers.
 */
ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) {
	struct ncclProxyState* sharedProxyState = comm->proxyState;
	// Stop polling as soon as the communicator is aborted.
	if (*comm->abortFlag) {
		WARN("Comm %p is in abort state", comm);
		return ncclInternalError;
	}

	// Proxy sockets are created lazily by ncclProxyConnect(); nothing to poll without them.
	if (sharedProxyState->peerSocks == NULL)
		return ncclInternalError;

	// Check response queue: an earlier poll may already have buffered our reply.
	int found = 0;
	NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
	if (found) {
		INFO(NCCL_PROXY, "ncclPollProxyResponse Dequeued cached opId=%p", opId);
		return ncclSuccess;
	}

	// Attempt to read in a new response header from the proxy thread.
	// Replies arrive on the socket to the proxy of local rank proxyConn->tpLocalRank.
	struct ncclSocket* sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;

	void* recvOpId;
	int offset = 0;
	// Each reply starts with the opId of the request it answers.
	if (ncclSuccess != ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)) {
		WARN("Socket recv failed while polling for opId=%p", opId);
		return ncclInternalError;
	}

	// Nothing available yet.
	if (offset == 0) {
		return ncclInProgress;
	}
	// If we've returned a partial response, block to receive the rest of it
	while (offset < (int)sizeof(recvOpId)) {
		NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset));
	}

	INFO(NCCL_PROXY, "ncclPollProxyResponse Received new opId=%p", recvOpId);

	// Now do a blocking recv of the response size
	int respSize = 0;
	NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(respSize)));

	const int expected = (recvOpId == opId);
	// If there's a respSize to recv
	if (respSize > 0) {
		if (!expected) {
			// Unexpected response, need to buffer the socket data
			respBuff = malloc(respSize);
			if (respBuff == NULL) {
				WARN("Failed to allocate %d bytes to buffer proxy response for opId=%p", respSize, recvOpId);
				return ncclSystemError;
			}
		} else if (respBuff == NULL) {
			// Checked explicitly rather than with assert(), which is compiled out under NDEBUG.
			WARN("No response buffer provided for %d-byte proxy response to opId=%p", respSize, opId);
			return ncclInternalError;
		}
		ncclResult_t res = ncclSocketRecv(sock, respBuff, respSize);
		if (res != ncclSuccess && res != ncclInProgress) {
			// Don't leak the temporary buffer allocated for an unexpected response.
			if (!expected) free(respBuff);
			return res;
		}
	}

	if (expected) {
		// Our reply: drop its record from state->expectedResponses.
		INFO(NCCL_PROXY, "recvOpId=%p matches expected opId=%p", recvOpId, opId);
		NCCLCHECK(expectedProxyResponseRemove(sharedProxyState, recvOpId));
		return ncclSuccess;
	}

	INFO(NCCL_PROXY, "Queuing opId=%p respBuff=%p respSize=%d", recvOpId, respBuff, respSize);
	// Store the result and mark response as completed
	// NOTE(review): confirm expectedProxyResponseStore takes ownership of (frees) the malloc'd buffer;
	// when respSize == 0 the caller's respBuff pointer is passed through unchanged.
	NCCLCHECK(expectedProxyResponseStore(sharedProxyState, recvOpId, respBuff, respSize));
	// Someone else's reply: keep polling for ours.
	return ncclInProgress;
}

你可能感兴趣的:(NCCL,NCCL,Linux,nvidia)