/*************************************************************************
 * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

// NOTE(review): the text this file was reviewed from had lost everything that
// sat between a '<' and the following '>' (system header names, the
// placeholders inside two WARN messages, and three loop regions). Those spots
// are rebuilt below from the surrounding code and comments and are each marked
// with a NOTE(review); confirm them against the authoritative copy.

#include "nccl.h"
#include "core.h"
#include "utils.h"
#include "bootstrap.h"
#include "net.h"
#include <unistd.h>  // close()      -- NOTE(review): header name chosen by usage
#include <time.h>    // nanosleep()  -- NOTE(review): header name chosen by usage
#include "proxy.h"

/* Init functions */
static char bootstrapNetIfName[MAX_IF_NAME_SIZE+1];
static union ncclSocketAddress bootstrapNetIfAddr;
static int bootstrapNetInitDone = 0;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;

/*
 * Fill bootstrapNetIfName / bootstrapNetIfAddr with the interface to use.
 * When NCCL_COMM_ID is set, an interface on the same subnet as that address
 * is looked up; otherwise the first interface reported by ncclFindInterfaces
 * is used. Must be called with bootstrapNetLock held.
 */
static ncclResult_t bootstrapNetSelectInterface() {
  char* env = getenv("NCCL_COMM_ID");
  if (env) {
    union ncclSocketAddress remoteAddr;
    if (ncclGetSocketAddrFromString(&remoteAddr, env) != ncclSuccess) {
      // NOTE(review): placeholder names in this message were restored.
      WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
      return ncclInvalidArgument;
    }
    if (ncclFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
      WARN("NET/Socket : No usable listening interface found");
      return ncclSystemError;
    }
    return ncclSuccess;
  }

  int nIfs = ncclFindInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
  if (nIfs <= 0) {
    WARN("Bootstrap : no socket interface found");
    return ncclInternalError;
  }
  return ncclSuccess;
}

/*
 * One-time selection of the bootstrap network interface.
 *
 * Returns ncclSuccess once an interface has been chosen (by this call or an
 * earlier one), or the error reported by the selection step. The done flag is
 * only set on success, so a failed attempt can be retried.
 *
 * The mutex is released on every path: previously the error returns left
 * bootstrapNetLock held, which blocked every later caller forever.
 */
ncclResult_t bootstrapNetInit() {
  if (bootstrapNetInitDone != 0) return ncclSuccess;

  ncclResult_t res = ncclSuccess;
  pthread_mutex_lock(&bootstrapNetLock);
  if (bootstrapNetInitDone == 0) {
    res = bootstrapNetSelectInterface();
    if (res == ncclSuccess) {
      char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2];
      // Bounded formatting: the name can never overrun the log buffer.
      snprintf(line, sizeof(line), " %s:", bootstrapNetIfName);
      ncclSocketToString(&bootstrapNetIfAddr, line+strlen(line));
      INFO(NCCL_INIT, "Bootstrap : Using%s", line);
      bootstrapNetInitDone = 1;
    }
  }
  pthread_mutex_unlock(&bootstrapNetLock);
  return res;
}

/* Socket Interface Selection type */
enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 };

// Additional sync functions

/*
 * Send one length-prefixed message: the int `size` first, then `size` bytes
 * of `data`.
 */
static ncclResult_t bootstrapNetSend(struct ncclSocket* sock, void* data, int size) {
  NCCLCHECK(ncclSocketSend(sock, &size, sizeof(int)));
  NCCLCHECK(ncclSocketSend(sock, data, size));
  return ncclSuccess;
}

/*
 * Receive one length-prefixed message into `data` (capacity `size` bytes).
 * Fails with ncclInternalError when the announced length is negative or
 * larger than the buffer; a shorter message is accepted and fills only the
 * first bytes of `data`.
 */
static ncclResult_t bootstrapNetRecv(struct ncclSocket* sock, void* data, int size) {
  int recvSize;
  NCCLCHECK(ncclSocketRecv(sock, &recvSize, sizeof(int)));
  // The length comes off the wire; never hand a negative count to the socket layer.
  if (recvSize < 0) {
    WARN("Invalid message size : received %d bytes", recvSize);
    return ncclInternalError;
  }
  if (recvSize > size) {
    WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
    return ncclInternalError;
  }
  NCCLCHECK(ncclSocketRecv(sock, data, recvSize));
  return ncclSuccess;
}

/* Check-in record each rank sends to the root. */
struct extInfo {
  int rank;
  int nranks;
  union ncclSocketAddress extAddressListenRoot;  // where the root should call this rank back
  union ncclSocketAddress extAddressListen;      // where peer ranks should contact this rank
};

#include <sys/resource.h>  // getrlimit()/setrlimit() -- NOTE(review): header name chosen by usage

/* Raise the soft open-file limit of the process to its hard limit. */
static ncclResult_t setFilesLimit() {
  struct rlimit filesLimit;
  SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit");
  filesLimit.rlim_cur = filesLimit.rlim_max;
  SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit");
  return ncclSuccess;
}

/*
 * Root thread body. `args` is a heap-allocated listening ncclSocket which this
 * function closes and frees before returning.
 *
 * Phase 1: accept one check-in (struct extInfo) per rank and record both of
 *          its addresses.
 * Phase 2: tell every rank the peer-facing address of the next rank.
 *
 * Always returns NULL; failures are only logged.
 */
static void *bootstrapRoot(void* args) {
  struct ncclSocket* listenSock = (struct ncclSocket*)args;
  ncclResult_t res = ncclSuccess;
  int nranks = 0, c = 0;
  struct extInfo info;
  union ncclSocketAddress *rankAddresses = NULL;
  union ncclSocketAddress *rankAddressesRoot = NULL; // for initial rank <-> root information exchange
  union ncclSocketAddress *zero = NULL;              // all-zero address used as the "slot is empty" marker
  NCCLCHECKGOTO(ncclCalloc(&zero, 1), res, out);
  setFilesLimit();

  TRACE(NCCL_INIT, "BEGIN");
  /* Receive addresses from all ranks */
  do {
    struct ncclSocket sock;
    sock.abortFlag = NULL;
    NCCLCHECKGOTO(ncclSocketAccept(&sock, listenSock), res, out);
    // Close the accepted socket before looking at the result so it is not
    // leaked when the receive fails.
    ncclResult_t recvRes = bootstrapNetRecv(&sock, &info, sizeof(info));
    close(sock.fd);
    NCCLCHECKGOTO(recvRes, res, out);

    if (c == 0) {
      // The first check-in fixes the job size; it is untrusted network input.
      if (info.nranks <= 0) {
        WARN("Bootstrap Root : invalid rank count %d", info.nranks);
        res = ncclInternalError;
        goto out;
      }
      nranks = info.nranks;
      NCCLCHECKGOTO(ncclCalloc(&rankAddresses, nranks), res, out);
      NCCLCHECKGOTO(ncclCalloc(&rankAddressesRoot, nranks), res, out);
    }

    if (nranks != info.nranks) {
      WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nranks);
      res = ncclInternalError;
      goto out;
    }

    // info.rank indexes both tables below; reject anything outside them.
    if (info.rank < 0 || info.rank >= nranks) {
      WARN("Bootstrap Root : rank %d is out of range for %d ranks", info.rank, nranks);
      res = ncclInternalError;
      goto out;
    }

    if (memcmp(zero, &rankAddressesRoot[info.rank], sizeof(union ncclSocketAddress)) != 0) {
      WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
      res = ncclInternalError;
      goto out;
    }

    // Save the connection handle for that rank
    memcpy(rankAddressesRoot+info.rank, &info.extAddressListenRoot, sizeof(union ncclSocketAddress));
    memcpy(rankAddresses+info.rank, &info.extAddressListen, sizeof(union ncclSocketAddress));

    ++c;
    TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d", info.rank, c, nranks);
  } while (c < nranks);
  TRACE(NCCL_INIT, "COLLECTED ALL %d HANDLES", nranks);

  // Send the connect handle for the next rank in the AllGather ring
  // NOTE(review): this loop and the `out:` label were missing from the reviewed
  // text. Rebuilt from the comment above, the `goto out` uses, and the
  // receiving side in bootstrapInit (which accepts on its root-facing socket
  // and reads one ncclSocketAddress) -- confirm.
  for (int r=0; r<nranks; ++r) {
    int next = (r+1) % nranks;
    struct ncclSocket sock;
    sock.abortFlag = NULL;
    sock.asyncFlag = 0;
    memcpy(&sock.addr, rankAddressesRoot+r, sizeof(union ncclSocketAddress));
    NCCLCHECKGOTO(ncclSocketConnect(&sock), res, out);
    ncclResult_t sendRes = bootstrapNetSend(&sock, rankAddresses+next, sizeof(union ncclSocketAddress));
    close(sock.fd);
    NCCLCHECKGOTO(sendRes, res, out);
  }

out:
  close(listenSock->fd);
  free(listenSock);
  free(rankAddresses);
  free(rankAddressesRoot);
  free(zero);

  TRACE(NCCL_INIT, "DONE");
  return NULL;
}

/*
 * Start the root service: listen on the address stored in `id`, write the
 * actual listening address back into `id`, and run bootstrapRoot on a
 * detached thread that owns (and eventually frees) the listening socket.
 *
 * `idFromEnv` is accepted for interface compatibility and is not used here.
 * Returns ncclSystemError when the thread cannot be created; in every failure
 * case the socket allocation is released.
 */
ncclResult_t bootstrapCreateRoot(ncclUniqueId* id, bool idFromEnv) {
  (void)idFromEnv;
  ncclResult_t res = ncclSuccess;
  pthread_t thread;
  struct ncclSocket* listenSock = NULL;

  NCCLCHECK(ncclCalloc(&listenSock, 1));
  memcpy(&listenSock->addr, id, sizeof(union ncclSocketAddress));
  NCCLCHECKGOTO(ncclSocketListen(listenSock), res, fail);
  memcpy(id, &listenSock->addr, sizeof(union ncclSocketAddress));

  if (pthread_create(&thread, NULL, bootstrapRoot, (void*)listenSock) != 0) {
    WARN("Bootstrap : failed to create root thread");
    close(listenSock->fd);
    res = ncclSystemError;
    goto fail;
  }
  // Name the thread while its handle is still guaranteed valid, then detach.
  ncclSetThreadName(thread, "NCCL BootstrapR");
  pthread_detach(thread); // will not be pthread_join()'d
  return ncclSuccess;

fail:
  free(listenSock);
  return res;
}

/*
 * Produce the unique id ranks use to find the root. The id is zeroed and its
 * leading bytes hold a socket address: parsed from NCCL_COMM_ID when that
 * variable is set (no root is started here in that case), otherwise the
 * address of a root service started by this call.
 */
ncclResult_t bootstrapGetUniqueId(ncclUniqueId* id) {
  static_assert(sizeof(union ncclSocketAddress) < sizeof(ncclUniqueId), "NetId does not fit inside ncclUniqueId");
  memset(id, 0, sizeof(ncclUniqueId));
  union ncclSocketAddress* connectAddr = (union ncclSocketAddress*) id;

  char* env = getenv("NCCL_COMM_ID");
  if (env == NULL) {
    memcpy(id, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress));
    NCCLCHECK(bootstrapCreateRoot(id, false));
    return ncclSuccess;
  }

  INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", env);
  if (ncclGetSocketAddrFromString(connectAddr, env) != ncclSuccess) {
    // NOTE(review): placeholder names in this message were restored.
    WARN("Invalid NCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
    return ncclInvalidArgument;
  }
  return ncclSuccess;
}

/* A connection accepted while waiting for a different (peer, tag) pair. */
struct unexConn {
  int peer;
  int tag;
  struct ncclSocket sock;
  struct unexConn* next;
};

/* Per-communicator bootstrap state, stored in comm->bootstrap. */
struct bootstrapState {
  struct ncclSocket listenSock;      // peers connect here (bootstrapSend -> bootstrapRecv)
  struct ncclSocket ringRecvSocket;  // accepted from the previous rank in the ring
  struct ncclSocket ringSendSocket;  // connected to the next rank in the ring
  union ncclSocketAddress* peerCommAddresses;   // listenSock address of every rank
  union ncclSocketAddress* peerProxyAddresses;  // proxy listen address of every rank
  struct unexConn* unexpectedConnections;       // singly linked FIFO, see unexpectedEnqueue
  int cudaDev;
  int rank;
  int nranks;
  volatile uint32_t *abortFlag;
};

/*
 * Join the bootstrap network for `comm`:
 *   1. open a peer-facing and a root-facing listening socket,
 *   2. check in with the root at the address stored in `id`,
 *   3. receive the next rank's address from the root and form the ring,
 *   4. all-gather every rank's peer-facing address,
 *   5. open the proxy listening socket, all-gather those addresses and hand
 *      them to ncclProxyInit.
 *
 * The state is attached to comm->bootstrap before anything can fail, so it is
 * reachable by bootstrapAbort after an error return.
 */
ncclResult_t bootstrapInit(ncclUniqueId * id, struct ncclComm* comm) {
  int rank = comm->rank;
  int nranks = comm->nRanks;
  struct bootstrapState* state;
  NCCLCHECK(ncclCalloc(&state, 1));
  state->rank = rank;
  state->nranks = nranks;
  state->abortFlag = comm->abortFlag;
  comm->bootstrap = state;

  TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);

  struct extInfo info = { 0 };
  info.rank = rank;
  info.nranks = nranks;
  struct ncclSocket sock, listenSockRoot;
  sock.abortFlag = listenSockRoot.abortFlag = comm->abortFlag;
  sock.asyncFlag = listenSockRoot.asyncFlag = 0;

  // Create socket for other ranks to contact me
  memcpy(&state->listenSock.addr, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress));
  NCCLCHECK(ncclSocketListen(&state->listenSock));
  memcpy(&info.extAddressListen, &state->listenSock.addr, sizeof(union ncclSocketAddress));

  // Create socket for root to contact me
  memcpy(&listenSockRoot.addr, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress));
  NCCLCHECK(ncclSocketListen(&listenSockRoot));
  memcpy(&info.extAddressListenRoot, &listenSockRoot.addr, sizeof(union ncclSocketAddress));

  // stagger connection times to avoid an overload of the root
  if (nranks > 128) {
    long msec = rank;  // one millisecond of delay per rank index
    struct timespec tv;
    tv.tv_sec = msec / 1000;
    tv.tv_nsec = 1000000 * (msec % 1000);
    TRACE(NCCL_INIT, "rank %d delaying connection to root by %ld msec", rank, msec);
    (void) nanosleep(&tv, NULL);
  }

  // send info on my listening socket to root
  memcpy(&sock.addr, id, sizeof(union ncclSocketAddress));
  NCCLCHECK(ncclSocketConnect(&sock));
  NCCLCHECK(bootstrapNetSend(&sock, &info, sizeof(info)));
  close(sock.fd);

  // get info on my "next" rank in the bootstrap ring from root
  NCCLCHECK(ncclSocketAccept(&sock, &listenSockRoot));
  NCCLCHECK(bootstrapNetRecv(&sock, &state->ringSendSocket.addr, sizeof(union ncclSocketAddress)));
  close(sock.fd);
  close(listenSockRoot.fd);

  NCCLCHECK(ncclSocketConnect(&state->ringSendSocket));
  // Accept the connect request from the previous rank in the AllGather ring
  NCCLCHECK(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock));

  // AllGather all listen handlers
  NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks));
  memcpy(state->peerCommAddresses+rank, &state->listenSock.addr, sizeof(union ncclSocketAddress));
  NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress)));

  // Create the service proxy
  NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks));
  struct ncclSocket* proxySocket;
  NCCLCHECK(ncclCalloc(&proxySocket, 1));
  proxySocket->abortFlag = NULL; // proxy is aborted through a message
  memcpy(&proxySocket->addr, &bootstrapNetIfAddr, sizeof(union ncclSocketAddress));
  NCCLCHECK(ncclSocketListen(proxySocket));
  memcpy(state->peerProxyAddresses+rank, &proxySocket->addr, sizeof(union ncclSocketAddress));
  NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)));
  NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses));

  TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);

  return ncclSuccess;
}

/*
 * Ring all-gather over the bootstrap ring sockets.
 * `allData` holds nranks slices of `size` bytes; on entry the slice at index
 * `rank` is filled in by the caller, on success every slice is filled.
 */
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
  struct bootstrapState* state = (struct bootstrapState*)commState;
  char* data = (char*)allData;
  int rank = state->rank;
  int nranks = state->nranks;

  TRACE(NCCL_INIT, "rank %d nranks %d size %d", rank, nranks, size);

  /* Simple ring based AllGather
   * At each step i receive data from (rank-i-1) from left
   * and send previous step's data from (rank-i) to right
   */
  // NOTE(review): the loop header and the two slice computations were missing
  // from the reviewed text; rebuilt from the comment above and from the
  // sslice/rslice uses that survived -- confirm.
  for (int i=0; i<nranks-1; i++) {
    size_t rslice = (rank - i - 1 + nranks) % nranks;
    size_t sslice = (rank - i + nranks) % nranks;

    // Send slice to the right
    NCCLCHECK(bootstrapNetSend(&state->ringSendSocket, data+sslice*size, size));
    // Recv slice from the left
    NCCLCHECK(bootstrapNetRecv(&state->ringRecvSocket, data+rslice*size, size));
  }

  TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
  return ncclSuccess;
}

/*
 * Send `size` bytes to `peer` over a fresh connection to that peer's listening
 * address. Three length-prefixed messages are written: this rank, `tag`, then
 * the payload. The connection is closed on every path once it has been opened.
 */
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) {
  ncclResult_t res = ncclSuccess;
  struct bootstrapState* state = (struct bootstrapState*)commState;
  struct ncclSocket sock;
  sock.abortFlag = state->abortFlag;
  sock.asyncFlag = 0;  // same explicit setting bootstrapInit uses before connecting
  memcpy(&sock.addr, state->peerCommAddresses+peer, sizeof(union ncclSocketAddress));
  NCCLCHECK(ncclSocketConnect(&sock));

  NCCLCHECKGOTO(bootstrapNetSend(&sock, &state->rank, sizeof(int)), res, done);
  NCCLCHECKGOTO(bootstrapNetSend(&sock, &tag, sizeof(int)), res, done);
  NCCLCHECKGOTO(bootstrapNetSend(&sock, data, size), res, done);

done:
  close(sock.fd);
  return res;
}

/*
 * Barrier across `nranks` participants; a no-op for a single participant.
 */
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag) {
  if (nranks == 1) return ncclSuccess;
  TRACE(NCCL_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);

  /* Simple intra process barrier
   *
   * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
   * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
   */
  // NOTE(review): everything from this loop header to the first statements of
  // unexpectedEnqueue was missing from the reviewed text. The loop is rebuilt
  // from the algorithm named above, assuming `rank` is an index into `ranks`
  // and `ranks[i]` is the peer id to pass to bootstrapSend/bootstrapRecv --
  // confirm. Any other definition that lived in the gap is NOT restored here.
  int data[1] = { 0 };
  for (int mask=1; mask<nranks; mask<<=1) {
    int src = (rank - mask + nranks) % nranks;
    int dst = (rank + mask) % nranks;
    NCCLCHECK(bootstrapSend(commState, ranks[dst], tag, data, sizeof(data)));
    NCCLCHECK(bootstrapRecv(commState, ranks[src], tag, data, sizeof(data)));
  }

  return ncclSuccess;
}

/*
 * Append a connection that arrived for a different (peer, tag) to the tail of
 * state->unexpectedConnections. The socket is copied by value.
 */
// NOTE(review): signature and allocation rebuilt from the call in
// bootstrapRecv and from the surviving body -- confirm.
ncclResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock) {
  // New unex
  struct unexConn* unex;
  NCCLCHECK(ncclCalloc(&unex, 1));  // zeroed, so unex->next starts out NULL
  unex->peer = peer;
  unex->tag = tag;
  memcpy(&unex->sock, sock, sizeof(struct ncclSocket));

  // Enqueue
  struct unexConn* list = state->unexpectedConnections;
  if (list == NULL) {
    state->unexpectedConnections = unex;
    return ncclSuccess;
  }
  while (list->next) list = list->next;
  list->next = unex;
  return ncclSuccess;
}

/*
 * Remove the first queued connection matching (peer, tag) and copy its socket
 * into `sock`. When nothing matches, sock->fd is set to -1. Always returns
 * ncclSuccess.
 */
ncclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, struct ncclSocket* sock) {
  struct unexConn* prev = NULL;
  for (struct unexConn* elem = state->unexpectedConnections; elem != NULL; prev = elem, elem = elem->next) {
    if (elem->peer != peer || elem->tag != tag) continue;

    // Unlink, hand the socket to the caller, release the node.
    if (prev == NULL) {
      state->unexpectedConnections = elem->next;
    } else {
      prev->next = elem->next;
    }
    memcpy(sock, &elem->sock, sizeof(struct ncclSocket));
    free(elem);
    return ncclSuccess;
  }
  sock->fd = -1;
  return ncclSuccess;
}

// We can't know who we'll receive from, so we need to receive everything at once
/*
 * Receive up to `size` bytes sent by `peer` with `tag`.
 * A queued connection for that pair is used when there is one (its peer/tag
 * header was already consumed when it was queued). Otherwise connections are
 * accepted one by one; those for other pairs are queued for later calls.
 * The socket being read is closed on every exit path, success or failure.
 */
ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) {
  ncclResult_t res = ncclSuccess;
  struct bootstrapState* state = (struct bootstrapState*)commState;
  struct ncclSocket sock;
  sock.abortFlag = state->abortFlag;

  // Search unexpected connections first
  NCCLCHECK(unexpectedDequeue(state, peer, tag, &sock));
  if (sock.fd != -1) {
    NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), res, closeSock);
    goto closeSock;
  }

  // Then look for new connections
  while (1) {
    NCCLCHECK(ncclSocketAccept(&sock, &state->listenSock));
    int newPeer, newTag;
    NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newPeer, sizeof(int)), res, closeSock);
    NCCLCHECKGOTO(bootstrapNetRecv(&sock, &newTag, sizeof(int)), res, closeSock);
    if (newPeer == peer && newTag == tag) {
      NCCLCHECKGOTO(bootstrapNetRecv(&sock, ((char*)data), size), res, closeSock);
      goto closeSock;
    }
    // Unexpected connection. Save for later.
    // On success the queue holds a copy of the socket, so it must stay open.
    NCCLCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), res, closeSock);
  }

closeSock:
  close(sock.fd);
  return res;
}

/*
 * Orderly teardown. Fails with ncclInternalError, releasing nothing, when
 * queued unexpected connections remain.
 */
ncclResult_t bootstrapClose(void* commState) {
  struct bootstrapState* state = (struct bootstrapState*)commState;
  if (state->unexpectedConnections != NULL) {
    WARN("Unexpected connections are not empty");
    return ncclInternalError;
  }
  close(state->listenSock.fd);
  close(state->ringSendSocket.fd);
  close(state->ringRecvSocket.fd);

  // NOTE(review): peerProxyAddresses is not freed here although bootstrapAbort
  // frees it; presumably ownership moved to the proxy via ncclProxyInit on the
  // normal path -- confirm before changing.
  free(state->peerCommAddresses);
  free(state);

  return ncclSuccess;
}

/*
 * Teardown after an abort. Accepts a NULL state. Queued unexpected
 * connections are not examined.
 */
ncclResult_t bootstrapAbort(void* commState) {
  struct bootstrapState* state = (struct bootstrapState*)commState;
  if (commState == NULL) return ncclSuccess;
  // NOTE(review): these tests treat fd 0 as "not open" (the state is
  // zero-allocated in bootstrapInit) and would still close an fd of -1.
  if (state->listenSock.fd) close(state->listenSock.fd);
  if (state->ringSendSocket.fd) close(state->ringSendSocket.fd);
  if (state->ringRecvSocket.fd) close(state->ringRecvSocket.fd);
  free(state->peerCommAddresses);
  free(state->peerProxyAddresses);
  free(state);
  return ncclSuccess;
}