NCCL Source Code Analysis: Series Index
I try to describe what each function does right before it appears. My suggestion: do not dive into the function bodies at first; understand what each function is for, build the overall picture, and then come back to the details.
Convention used in these notes: to make them easy to follow, call relationships are shown by indentation, or sometimes by expanding a function inline, depending on the situation.
For example:
// call proxyConnInit
NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
// proxyConnInit expanded inline, which makes its parameters easy to see
static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct ncclProxyConnection** connection)
If anything is wrong, please leave a comment and point it out.
Diagrams will be added later;
some topics are not covered yet and will be filled in later;
side remarks will be added later as well.
Each GPU has its own management thread (or process). When communication is set up between cards, an extra proxy thread is created to do that work. The proxy thread is passive: what it should do is dispatched to it over TCP by the GPU's management thread.
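To make the "passive thread driven by commands over a socket" pattern concrete, here is a minimal, self-contained sketch (plain pthreads plus a socketpair standing in for the TCP connection; this is not NCCL code, and the command names are made up):
#include <pthread.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

enum cmd { CMD_PING = 1, CMD_STOP = 2 };          // hypothetical command set

// The "proxy": passive, it only reacts to commands read from its socket.
static void* service(void* arg) {
  int fd = *(int*)arg;
  int c;
  while (read(fd, &c, sizeof(c)) == (ssize_t)sizeof(c)) {
    if (c == CMD_PING) printf("proxy: ping\n");
    else if (c == CMD_STOP) break;                // the management thread asked us to stop
  }
  return NULL;
}

int main(void) {
  int sv[2];
  socketpair(AF_UNIX, SOCK_STREAM, 0, sv);        // stands in for the proxy's TCP socket
  pthread_t t;
  pthread_create(&t, NULL, service, &sv[1]);      // one service thread, like NCCL's proxy thread
  int c = CMD_PING; write(sv[0], &c, sizeof(c));  // the management thread drives the proxy
  c = CMD_STOP;     write(sv[0], &c, sizeof(c));
  pthread_join(t, NULL);
  return 0;
}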
The proxy thread's main tasks are walked through step by step below.
Focus first on the initialization of comm->proxyState; it will later be passed to the proxy thread as its argument, so it is also fine to come back to it when it is used.
// Initialization
commAlloc()
    NCCLCHECK(ncclCalloc(&sharedRes, 1));
bootstrapInit()
    // proxy is aborted through a message; don't set abortFlag
    // Allocate the socket object
    NCCLCHECK(ncclCalloc(&proxySocket, 1));
    // Create the socket -> proxySocket
    NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag));
    // Put it into the listening state
    NCCLCHECK(ncclSocketListen(proxySocket));
    // Get its address (IP + port) and store it at state->peerProxyAddresses + rank
    NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank));
    struct bootstrapState* state;
    comm->bootstrap = state;
    // All-gather across all ranks; state->peerProxyAddresses then holds every rank's address
    NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)));
    // Allocate and initialize comm->proxyState
    NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses));
        NCCLCHECK(ncclCalloc(&comm->sharedRes->proxyState, 1));
        comm->proxyState = comm->sharedRes->proxyState;
        comm->proxyState->refCount = 1;
        comm->proxyState->listenSock = proxySocket;
        comm->proxyState->peerAddresses = state->peerProxyAddresses;
The proxyState fields are initialized mainly in ncclProxyCreate(), which is also where NCCL initialization creates the ncclProxyService thread.
ncclProxyCreate(comm)
{
// proxyState comes from comm->proxyState
struct ncclProxyState* proxyState = comm->proxyState;
// Field initialization; what each field is for will be explained when it is used
proxyState->tpRank = comm->rank;
proxyState->tpnRanks = comm->nRanks;
proxyState->tpLocalnRanks = comm->localRanks;
proxyState->cudaDev = comm->cudaDev;
proxyState->abortFlag = comm->abortFlag;
proxyState->p2pnChannels = comm->p2pnChannels;
proxyState->p2pChunkSize = comm->p2pChunkSize;
proxyState->nChannels = comm->nChannels;
proxyState->allocP2pNetLLBuffers = comm->allocP2pNetLLBuffers;
proxyState->dmaBufSupport = comm->dmaBufSupport;
proxyState->ncclNet = comm->ncclNet;
proxyState->ncclCollNet = comm->ncclCollNet;
memcpy(proxyState->buffSizes, comm->buffSizes, sizeof(comm->buffSizes));
// Create the proxy service thread
pthread_create(&comm->proxyState->thread, NULL, ncclProxyService, comm->proxyState);
}
The proxy service thread code. One proxy thread is started per device; the thread is named NCCL Service %rank.
Inside its main loop the thread does three things: it accepts new connections from local ranks on the listening socket, makes progress on each peer's pending async operations, and receives new commands and acts on them according to their type. The type values are defined as follows:
enum ncclProxyMsgType {
  ncclProxyMsgInit = 1,       // establish the TCP connection
  ncclProxyMsgSharedInit = 2, // proxy thread calls the transport's (ncclTransportComm) proxySharedInit
  ncclProxyMsgSetup = 3,      // proxy thread calls the transport's proxySetup
  ncclProxyMsgConnect = 4,    // proxy thread calls the transport's proxyConnect
  ncclProxyMsgStart = 5,      // not used yet
  ncclProxyMsgClose = 6,      // close the TCP connection
  ncclProxyMsgAbort = 7,      // not used yet
  ncclProxyMsgStop = 8,       // stop this connection; the proxy thread exits only once every connection has been stopped
  ncclProxyMsgConvertFd = 9,  // cuMem API support (UDS)
};
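The log messages in the code below index into ncclProxyMsgTypeStr[type]. That table is not shown in this post; reconstructed from the enum above it would look roughly like this (my reconstruction, check proxy.cc of your NCCL version for the exact strings):
// Reconstruction for readability only; index 0 is unused because the enum starts at 1.
static const char* ncclProxyMsgTypeStr[] = {
  "Unknown", "Init", "SharedInit", "Setup", "Connect",
  "Start", "Close", "Abort", "Stop", "ConvertFd"
};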
The main handler inside the thread is proxyServiceInitOp().
// Thread argument:
args = comm->proxyState
void* ncclProxyService(void* _args) {
struct ncclProxyState* proxyState = (struct ncclProxyState*) _args;
// Prepare poll descriptor
struct ncclProxyConnectionPool connectionPool;
connectionPool.pools = NULL;
connectionPool.banks = 0;
connectionPool.offset = NCCL_PROXY_CONN_POOL_SIZE;
struct pollfd pollfds[NCCL_MAX_LOCAL_RANKS+1];
struct ncclProxyLocalPeer peers[NCCL_MAX_LOCAL_RANKS];
memset(&peers, 0, sizeof(struct ncclProxyLocalPeer)*NCCL_MAX_LOCAL_RANKS);
for (int s=0; s<NCCL_MAX_LOCAL_RANKS; s++) {
pollfds[s].fd = -1;
pollfds[s].events = POLLHUP|POLLIN;
}
if (ncclSocketGetFd(proxyState->listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) {
WARN("[Proxy Service] Get listenSock fd fails");
return NULL;
};
// Watch the listening socket for input (new connections)
pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN;
int maxnpeers = 0;
int npeers = 0;
int stop = 0;
int asyncOpCount = 0;
while (stop == 0 || (stop == 1 && npeers > 0)) {
/* Even if local comm aborts, we cannot let proxy thread exit if we still have peer
* connections. Need to wait until all other related comms call abort and safely exit
* together, or we could face segmentation fault. */
// A local abort only sets stop = 1; the thread cannot exit until the other comms have stopped too
if (*proxyState->abortFlag != 0) stop = 1;
/* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */
int ret;
do {
ret = poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500);
} while (ret < 0 && errno == EINTR);
if (ret < 0) {
WARN("[Proxy Service] Poll failed: %s", strerror(errno));
return NULL;
}
if (pollfds[NCCL_MAX_LOCAL_RANKS].revents) {
int s = 0;
while (s < NCCL_MAX_LOCAL_RANKS && pollfds[s].fd >= 0) s++;
if (s == NCCL_MAX_LOCAL_RANKS) {
WARN("[Proxy service] Too many connections (%d max)", NCCL_MAX_LOCAL_RANKS);
return NULL;
}
if (maxnpeers < s+1) maxnpeers = s+1;
// Initialize the socket for the new peer
if (ncclSocketInit(&peers[s].sock) != ncclSuccess) {
WARN("[Service thread] Initialize peers[%d].sock fails", s);
return NULL;
}
// accept
if (ncclSocketAccept(&peers[s].sock, proxyState->listenSock) != ncclSuccess) {
WARN("[Service thread] Accept failed %s", strerror(errno));
} else {
// Add the accepted fd to pollfds so it gets polled
if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) {
WARN("[Service thread] Get peers[%d].sock fd fails", s);
return NULL;
}
npeers++;
peers[s].tpLocalRank = -1;
}
}
for (int s=0; s<maxnpeers; s++) {
struct ncclProxyLocalPeer* peer = peers+s;
struct ncclSocket* sock = &peer->sock;
int closeConn = 0;
int type = 0;
ncclResult_t res = ncclSuccess;
if (pollfds[s].fd == -1)
continue;
// Progress all ops for this ncclProxyLocalPeer
ncclProxyAsyncOp* op = peer->asyncOps;
while (op != nullptr) {
ncclProxyAsyncOp* opnext = op->next; /* in case op is freed in proxyProgressAsync */
type = op->type;
res = proxyProgressAsync(op, proxyState, &asyncOpCount, peer, &connectionPool);
if (res == ncclSuccess || res == ncclInProgress) {
op = opnext;
} else {
// Res is a bad result
closeConn = 1;
WARN("[Service thread] Error encountered progressing operation=%s, res=%d, closing connection", ncclProxyMsgTypeStr[type], res);
break;
}
}
// Check for additional ops coming in on this peer's socket
if (pollfds[s].revents & POLLIN) {
int closed;
// Receive the message type first
res = ncclSocketTryRecv(sock, &type, sizeof(int), &closed, false /*blocking*/);
if (res != ncclSuccess && res != ncclInProgress) {
WARN("[Service thread] Could not receive type from localRank %d, res=%u, closed=%d", peer->tpLocalRank, res, closed);
closeConn = 1;
} else if (closed) {
INFO(NCCL_INIT|NCCL_NET|NCCL_PROXY, "[Service thread] Connection closed by localRank %d", peer->tpLocalRank);
closeConn = 1;
} else if (res == ncclSuccess) { // We received something from the sock
// We received data; act according to type
if (type == ncclProxyMsgStop) {
// Stop: close this connection and begin shutting down
stop = 1;
closeConn = 1;
} else if (type == ncclProxyMsgClose) {
// Close this connection
closeConn = 1;
} else if (proxyMatchOpType(type)) {
// Handle the client's (i.e. the rank's) request; dispatch on type
res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
} else {
// Unknown command: close the connection
WARN("[Service thread] Unknown command %d from localRank %d", type, peer->tpLocalRank);
closeConn = 1;
}
INFO(NCCL_PROXY, "Received and initiated operation=%s res=%d", ncclProxyMsgTypeStr[type], res);
}
} else if (pollfds[s].revents & POLLHUP) {
// Peer hung up: close the connection
closeConn = 1;
}
if (res != ncclSuccess && res != ncclInProgress) {
// An error occurred: close the connection
WARN("[Proxy Service %d] Failed to execute operation %s from rank %d, retcode %d", proxyState->tpRank, ncclProxyMsgTypeStr[type], peer->tpRank, res);
closeConn = 1;
}
if (closeConn) {
// Close the connection and clean up
ncclSocketClose(sock);
if (op != nullptr) {
asyncProxyOpDequeue(peer, op);
asyncOpCount--;
}
pollfds[s].fd = -1;
npeers--;
}
}
}
// Shutdown path
// Wait for all operations to complete and stop progress thread before freeing any resource
if (ncclProxyProgressDestroy(proxyState) != ncclSuccess) {
WARN("[Proxy Service] proxyDestroy failed");
}
for (int s=0; s<maxnpeers; s++) {
ncclSocketClose(&peers[s].sock);
}
ncclProxyFreeConnections(&connectionPool, proxyState);
ncclSocketClose(proxyState->listenSock);
free(proxyState->listenSock);
proxyOpsFree(proxyState);
return NULL;
}
The main handler in the service thread. The client sends its fields in a fixed order, so the server receives them in the same fixed order and then calls proxyProgressAsync to do the actual work (a client-side sketch of this send order follows the function below).
// proxyState belongs to the local rank
// peers is server-side data describing the clients
// peer is the abstraction of one client
res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
static ncclResult_t proxyServiceInitOp(int type, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, int* asyncOpCount) {
// Server-side socket for this peer
struct ncclSocket* sock = &peer->sock;
// Allocate the async op
struct ncclProxyAsyncOp* asyncOp;
NCCLCHECK(ncclCalloc(&asyncOp, 1));
asyncOp->type = type;
// Receive the fields in the same order the client sent them
// Receive connection: the pointer identifying the server-side connection object this request targets
NCCLCHECK(ncclSocketRecv(sock, &asyncOp->connection, sizeof(void*)));
// Receive the request size
NCCLCHECK(ncclSocketRecv(sock, &asyncOp->reqSize, sizeof(int)));
// Receive the expected response size
NCCLCHECK(ncclSocketRecv(sock, &asyncOp->respSize, sizeof(int)));
if (asyncOp->reqSize) {
// If the request size is > 0, the client also sends request data,
// so allocate a buffer first and then receive it
NCCLCHECK(ncclCalloc(&asyncOp->reqBuff, asyncOp->reqSize));
NCCLCHECK(ncclSocketRecv(sock, asyncOp->reqBuff, asyncOp->reqSize));
}
// Store opId for completion response
// Receive the client's opId (a pointer value used purely as an identifier)
NCCLCHECK(ncclSocketRecv(sock, &asyncOp->opId, sizeof(asyncOp->opId)));
// If the client expects a response (respSize > 0), allocate the response buffer on the server side
if (asyncOp->respSize)
NCCLCHECK(ncclCalloc(&asyncOp->respBuff, asyncOp->respSize));
// Enqueue asyncOp into this peer's list, peer->asyncOps
asyncProxyOpEnqueue(peer, asyncOp);
(*asyncOpCount)++;
// Process the request
NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
return ncclSuccess;
}
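For reference, the client side sends these fields in exactly this order; the sketch below mirrors what ncclProxyCallAsync does, simplified to blocking ncclSocketSend calls and given a hypothetical name (it is not the real non-blocking implementation):
// Sketch of the client-side send order that proxyServiceInitOp expects.
static ncclResult_t proxyRequestSendSketch(struct ncclSocket* sock, int type,
    void* connection, void* reqBuff, int reqSize, int respSize, void* opId) {
  NCCLCHECK(ncclSocketSend(sock, &type, sizeof(int)));            // 1. message type
  NCCLCHECK(ncclSocketSend(sock, &connection, sizeof(void*)));    // 2. server-side connection pointer
  NCCLCHECK(ncclSocketSend(sock, &reqSize, sizeof(int)));         // 3. request size
  NCCLCHECK(ncclSocketSend(sock, &respSize, sizeof(int)));        // 4. expected response size
  if (reqSize) NCCLCHECK(ncclSocketSend(sock, reqBuff, reqSize)); // 5. request payload, if any
  NCCLCHECK(ncclSocketSend(sock, &opId, sizeof(void*)));          // 6. opId used to match the reply
  return ncclSuccess;
}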
The request handler: it branches on the type argument and, once the operation is done, sends the response back in a fixed order.
NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclProxyState* proxyState, int* asyncOpCount, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool) {
int done = 1;
if (op->type == ncclProxyMsgSetup) {
// Call the transport's proxySetup handler
TRACE(NCCL_PROXY, "proxyProgressAsync::proxySetup() opId=%p", op->opId);
NCCLCHECK(op->connection->tcomm->proxySetup(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
} else if (op->type == ncclProxyMsgConnect) {
// Call the transport's proxyConnect handler
TRACE(NCCL_PROXY, "proxyProgressAsync::proxyConnect() opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
} else if (op->type == ncclProxyMsgSharedInit) {
int nChannels = (int) *op->reqBuff;
// Call the transport's proxySharedInit handler
TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgSharedInit opId=%p op.reqBuff=%p nChannels=%d", op->opId, op->reqBuff, nChannels);
if (op->connection->tcomm->proxySharedInit) NCCLCHECK(op->connection->tcomm->proxySharedInit(op->connection, proxyState, nChannels));
__atomic_store_n(&op->connection->state, connSharedInitialized, __ATOMIC_RELEASE);
} else if (op->type == ncclProxyMsgConvertFd) {
int fd = *(int *)op->reqBuff;
TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgConvertFd opId=%p op.reqBuff=%p fd=%d", op->opId, op->reqBuff, fd);
NCCLCHECK(proxyConvertFd(peer, op->opId, proxyState, fd)); // cuMem API support
} else if (op->type == ncclProxyMsgInit) {
//
TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgInit opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct ncclProxyConnection** connection)
{
int id;
// Allocate space in connectionPool->pools if needed,
// then connectionPool->offset++
// id = ((pool->banks-1) << NCCL_PROXY_CONN_POOL_SIZE_POW2) + pool->offset;
// a bank holds (1 << 7) connection slots (offsets)
NCCLCHECK(ncclProxyNewConnection(connectionPool, &id));
// From the id, recover the bank and offset,
// and from those get the address of the ncclProxyConnection, returned in connection
NCCLCHECK(ncclProxyGetConnection(connectionPool, id, connection));
// Fill in the connection
(*connection)->sock = &peer->sock;
(*connection)->transport = req->transport;
(*connection)->send = req->send;
(*connection)->tpLocalRank = req->tpLocalRank;
(*connection)->sameProcess = req->sameProcess;
peer->tpLocalRank = req->tpLocalRank;
peer->tpRank = req->tpRank;
// Hand the connection pointer back to the client via resp->connection
resp->connection = *connection;
(*connection)->tcomm = (*connection)->send ? &ncclTransports[(*connection)->transport]->send : &ncclTransports[(*connection)->transport]->recv;
// If we need proxy progress, let's allocate ops and start the thread
if ((*connection)->tcomm->proxyProgress) {
NCCLCHECK(proxyProgressInit(proxyState));
struct ncclProxyProgressState* state = &proxyState->progressState;
strncpy(resp->devShmPath, state->opsPoolShmSuffix, sizeof(resp->devShmPath));
}
INFO(NCCL_NET|NCCL_PROXY, "New proxy %s connection %d from local rank %d, transport %d", (*connection)->send ? "send":"recv", id, (*connection)->tpLocalRank, (*connection)->transport);
__atomic_store_n(&(*connection)->state, connInitialized, __ATOMIC_RELEASE);
return ncclSuccess;
}
} else
return ncclInternalError;
if (done) {
INFO(NCCL_PROXY, "proxyProgressAsync opId=%p op.type=%d op.reqBuff=%p op.respSize=%d done", op->opId, op->type, op->reqBuff, op->respSize);
if (op->type == ncclProxyMsgSetup)
__atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE);
else if (op->type == ncclProxyMsgConnect)
__atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE);
/* if setup or connect is done, we should not return any error at this point since
* ncclSocketSend might already send the respBuff to the requester. If we still choose
* to abort and close the connection, it can cause segfault if the requester is using
* the respBuff. */
// Send the opId back so the client can match this response to its request
NCCLCHECK(ncclSocketSend(op->connection->sock, &op->opId, sizeof(op->opId)));
// Send the response size
NCCLCHECK(ncclSocketSend(op->connection->sock, &op->respSize, sizeof(op->respSize)));
if (op->respSize) {
// Send the response payload
NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize));
}
// Remove the op from the peer's list
asyncProxyOpDequeue(peer, op);
(*asyncOpCount)--;
return ncclSuccess;
} else if (*proxyState->abortFlag != 0) {
return ncclInternalError;
}
return ncclInProgress;
}
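The __atomic_store_n calls above advance a per-connection state machine (connInitialized, connSharedInitialized, connSetupDone, connConnected). Summarized as a sketch below; the ordering and the initial state are my assumption, not the exact NCCL definition:
// Reconstructed progression of connection->state as driven by the message types above.
enum ncclProxyConnectionStateSketch {
  connUninitialized = 0,  // assumed initial state after ncclProxyNewConnection
  connInitialized,        // after ncclProxyMsgInit       (proxyConnInit)
  connSharedInitialized,  // after ncclProxyMsgSharedInit (transport proxySharedInit)
  connSetupDone,          // after ncclProxyMsgSetup      (transport proxySetup)
  connConnected           // after ncclProxyMsgConnect    (transport proxyConnect); ready for progress ops
};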
Take the connection step as an example: to use the proxy at all, the client must connect first. It sends type ncclProxyMsgInit to tell the proxy "I want to connect"; the proxy thread accepts the socket, sets up the connection, and returns the address of the server-side ncclProxyConnection object.
The connection flow is shown below. Pay attention to what actually travels over the socket: sometimes data, sometimes just a pointer value:
// p2p send connector
// The rank's GPU-side thread connects to the proxy's TCP server; the server records the connection and allocates the resources needed for communication
struct ncclConnector* send
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &send->proxyConn));
ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int tpProxyRank, struct ncclProxyConnector* proxyConn) {
struct ncclSocket* sock;
int ready, proxyRank = -1;
struct ncclProxyState* sharedProxyState = comm->proxyState;
// Keep one connection per local rank
for (int i = 0; i < comm->localRanks; ++i) {
/* find the proxy rank in comm. */
if (comm->topParentRanks[comm->localRankToRank[i]] == tpProxyRank) {
proxyRank = comm->localRankToRank[i];
break;
}
}
proxyConn->sameProcess = comm->peerInfo[proxyRank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
// Keep one connection per local rank
proxyConn->connection = NULL;
proxyConn->tpRank = tpProxyRank;
// Initialize peerSocks on first use
if (sharedProxyState->peerSocks == NULL) {
NCCLCHECK(ncclCalloc(&sharedProxyState->peerSocks, comm->sharedRes->tpNLocalRanks));
NCCLCHECK(ncclCalloc(&sharedProxyState->proxyOps, comm->sharedRes->tpNLocalRanks));
NCCLCHECK(ncclCalloc(&sharedProxyState->sharedDevMems, comm->sharedRes->tpNLocalRanks));
for (int i = 0; i < comm->sharedRes->tpNLocalRanks; ++i) {
NCCLCHECK(ncclSocketSetFd(-1, &sharedProxyState->peerSocks[i]));
}
}
proxyConn->tpLocalRank = comm->sharedRes->tpRankToLocalRank[proxyConn->tpRank];
sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
NCCLCHECK(ncclSocketReady(sock, &ready));
if (!ready) {
// Initialize the socket
NCCLCHECK(ncclSocketInit(sock, sharedProxyState->peerAddresses+proxyConn->tpRank, comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag));
// Connect to the port that the proxy service thread is listening on
NCCLCHECK(ncclSocketConnect(sock));
}
struct ncclProxyInitReq req = {0};
req.transport = transport;
req.send = send;
req.tpLocalRank = comm->topParentLocalRanks[comm->localRank];
req.tpRank = comm->topParentRanks[comm->rank];
req.sameProcess = proxyConn->sameProcess;
struct ncclProxyInitResp resp = {0};
// This usually sends proxyConn->connection to identify which connection this is.
// However, this is part of the response and therefore is ignored
// Init handshake over the socket; the proxy server allocates memory and establishes the connection
NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
// resp.connection is the address of the connection object on the server side
proxyConn->connection = resp.connection;
// If we need proxy progress, map progress ops
struct ncclTransportComm* tcomm = send ? &ncclTransports[transport]->send : &ncclTransports[transport]->recv;
if (tcomm->proxyProgress) {
char poolPath[] = "/dev/shm/nccl-XXXXXX";
strncpy(poolPath+sizeof("/dev/shm/nccl-")-1, resp.devShmPath, sizeof("XXXXXX")-1);
struct ncclProxyOps* proxyOps = sharedProxyState->proxyOps + proxyConn->tpLocalRank;
if (proxyOps->pool == NULL) {
NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, 0, &proxyOps->handle));
proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1;
}
}
INFO(NCCL_NET|NCCL_PROXY, "Connection to proxy localRank %d -> connection %p", proxyConn->tpLocalRank, proxyConn->connection);
return ncclSuccess;
}
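For reference, the Init request/response payloads exchanged above carry exactly the fields filled in by ncclProxyConnect and proxyConnInit. Reconstructed roughly below; the field layout is my assumption, see proxy.cc for the real definitions:
// Reconstructed from the fields used above, not copied from the NCCL sources.
typedef struct {
  int transport;    // req.transport   = transport
  int send;         // req.send        = send (1 = send side, 0 = recv side)
  int tpLocalRank;  // req.tpLocalRank = comm->topParentLocalRanks[comm->localRank]
  int tpRank;       // req.tpRank      = comm->topParentRanks[comm->rank]
  int sameProcess;  // req.sameProcess = proxyConn->sameProcess
} ncclProxyInitReq;

typedef struct {
  struct ncclProxyConnection* connection; // server-side connection pointer, stored into proxyConn->connection
  char devShmPath[6];                     // the XXXXXX suffix of the /dev/shm/nccl-XXXXXX ops pool
} ncclProxyInitResp;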
Calling the proxy thread's interface: send a command, then receive the reply.
// The client asks the proxy server to invoke the corresponding handler; the server dispatches on type
// ncclProxyMsgInit tells the server to initialize the communication
NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
ncclResult_t ncclProxyCallBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) {
// Alloc some memory to act as a handle
ncclResult_t res = ncclSuccess;
void* opId = malloc(1);
// ncclProxyCallAsync()
// sends the type first,
// then the value of proxyConn->connection,
// then reqSize,
// then respSize,
// then, if reqSize > 0, the request data itself,
// and finally the value of opId
NCCLCHECKGOTO(ncclProxyCallAsync(comm, proxyConn, type, reqBuff, reqSize, respSize, opId), res, fail);
struct ncclProxyState* sharedProxyState = comm->proxyState;
sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
// Append the current request to the state->expectedResponses list;
NCCLCHECK(expectedProxyResponseEnqueue(sharedProxyState, opId, respSize));
{
struct ncclExpectedProxyResponse* ex;
NCCLCHECK(ncclCalloc(&ex, 1));
ex->opId = opId;
// Pre-alloc response buffer
ex->respBuff = malloc(respSize);
ex->respSize = respSize;
ex->done = false;
struct ncclExpectedProxyResponse* list = state->expectedResponses;
if (list == NULL) {
state->expectedResponses = ex;
return ncclSuccess;
}
while (list->next) list = list->next;
list->next = ex;
}
do {
res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
{
int found = 0;
// If opId is found in the list and its done flag is already true, copy the data into respBuff and set found to 1
NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
}
} while (res == ncclInProgress);
exit:
free(opId);
return res;
fail:
goto exit;
}
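From a transport's point of view the API is simple: fill a request struct, pick a message type, and block until the proxy thread's handler has produced the response. A hedged usage sketch follows; the request/response structs here are placeholders, each transport defines its own:
struct myTransportSetupReq  { int channelId; };  // hypothetical request payload
struct myTransportSetupResp { void* buffer; };   // hypothetical response payload

static ncclResult_t exampleSetup(struct ncclComm* comm, struct ncclConnector* send) {
  struct myTransportSetupReq req = { .channelId = 0 };
  struct myTransportSetupResp resp;
  // Blocks until the proxy thread has run this transport's proxySetup handler
  // and sent the response back (matched internally by opId).
  NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgSetup,
                                  &req, sizeof(req), &resp, sizeof(resp)));
  return ncclSuccess;
}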
Each request carries an opId that identifies the exchange, and the proxy thread sends that opId back with its response.
So when receiving, the client compares opIds: if the received opId matches the one it sent, the response is for this request;
if not, the received data is stashed in that other request's buffer and the client keeps receiving.
// Poll until the response for opId arrives
res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) {
struct ncclProxyState* sharedProxyState = comm->proxyState;
// Receive the connection pointer from the Proxy
// Check the abort flag
if (*comm->abortFlag) {
WARN("Comm %p is in abort state", comm);
return ncclInternalError;
}
if (sharedProxyState->peerSocks == NULL)
return ncclInternalError;
// Check response queue
int found = 0;
// If opId is found in the list and its done flag is already true, copy the data into respBuff and set found to 1
NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
if (found == 0) {
// We have sent the request but no reply has arrived yet: the entry for opId exists but its done flag is still false, hence found == 0
// Attempt to read in a new response header from the proxy thread
// For a comm without a parent, tpLocalRank is simply comm->localRank
// Get the socket connected to the proxy
struct ncclSocket* sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
void* recvOpId;
int offset = 0;
// Receive the reply; the opId comes first
if (ncclSuccess != ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)) {
WARN("Socket recv failed while polling for opId=%p", opId);
return ncclInternalError;
}
// Make sure the whole field is received; offset == 0 means nothing has arrived yet, so return ncclInProgress and keep receiving later
if (offset == 0) {
return ncclInProgress;
// If we've returned a partial response, block to receive the rest of it
} else if (offset < sizeof(recvOpId)) {
while (offset < sizeof(recvOpId))
NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset));
}
INFO(NCCL_PROXY, "ncclPollProxyResponse Received new opId=%p", recvOpId);
// Now do a blocking recv of the response size
int respSize = 0;
// Receive the size of the response data
NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(respSize)));
// If there's a respSize to recv
if (respSize > 0) {
// There is response data to receive
if (recvOpId != opId) {
// Unexpected response, need to buffer the socket data
// For an unexpected opId, allocate a buffer to hold its data
respBuff = malloc(respSize);
}
assert(respBuff != NULL);
// Receive the response data
NCCLCHECK(ncclSocketRecv(sock, respBuff, respSize));
}
if (recvOpId == opId) {
// We received the data for our own opId: remove the matching entry from state->expectedResponses
INFO(NCCL_PROXY, "recvOpId=%p matches expected opId=%p", recvOpId, opId);
NCCLCHECK(expectedProxyResponseRemove(sharedProxyState, recvOpId));
// Success
return ncclSuccess;
} else {
INFO(NCCL_PROXY, "Queuing opId=%p respBuff=%p respSize=%d", recvOpId, respBuff, respSize);
// Store the result and mark response as completed
// The data belongs to a different opId: copy it into that entry's buffer and set its done flag to true
NCCLCHECK(expectedProxyResponseStore(sharedProxyState, recvOpId, respBuff, respSize));
// Return ncclInProgress and keep polling for our own response
return ncclInProgress;
}
} else {
INFO(NCCL_PROXY, "ncclPollProxyResponse Dequeued cached opId=%p", opId);
}
return ncclSuccess;
}
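The expectedResponses bookkeeping used above (enqueue on send, store when someone else's reply arrives, dequeue on a hit) is a plain singly linked list. Below is a sketch of the dequeue side, reconstructed from the struct shown in ncclProxyCallBlocking; the real helper lives in proxy.cc and may differ in details:
// Walk state->expectedResponses; if the entry for opId is already marked done,
// copy its buffered response out, unlink it and free it.
static ncclResult_t expectedProxyResponseDequeueSketch(
    struct ncclProxyState* state, void* opId, void* respBuff, int* found) {
  struct ncclExpectedProxyResponse* elem = state->expectedResponses;
  struct ncclExpectedProxyResponse* prev = NULL;
  *found = 0;
  while (elem) {
    if (elem->opId == opId && elem->done) {
      if (prev == NULL) state->expectedResponses = elem->next;
      else prev->next = elem->next;
      memcpy(respBuff, elem->respBuff, elem->respSize); // hand the buffered response to the caller
      free(elem->respBuff);
      free(elem);
      *found = 1;
      return ncclSuccess;
    }
    prev = elem;
    elem = elem->next;
  }
  return ncclSuccess;
}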