From 6dcae8a459064049df8a03ecd20fb85e08bf8911 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Thu, 10 Jun 2021 17:51:04 -0700 Subject: [PATCH] Select sendrecv path based on collective data size (#391) * Select sendrecv path based on collective data size * Add comments on packing and unpacking group field * Toggling RCCL_P2P_NET_DISABLE in combined calls unit tests --- src/collectives/device/all_gather.h | 2 +- src/collectives/device/all_reduce.h | 2 +- src/collectives/device/broadcast.h | 2 +- src/collectives/device/primitives.h | 17 +++++---- src/collectives/device/reduce.h | 2 +- src/collectives/device/reduce_scatter.h | 2 +- src/collectives/device/sendrecv.h | 4 +-- src/enqueue.cc | 22 +++++++++--- src/group.cc | 47 +++++++++++++++++++++---- src/include/comm.h | 4 +-- src/include/devcomm.h | 13 ++++--- src/include/info.h | 2 ++ src/include/proxy.h | 2 ++ src/init.cc | 10 ++---- src/proxy.cc | 4 +-- src/transport.cc | 10 +++--- test/test_CombinedCalls.cpp | 2 +- 17 files changed, 102 insertions(+), 45 deletions(-) diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 22132b6d78..dd279b33ab 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -31,7 +31,7 @@ class ncclFunctionrecvbuff; ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex)); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 6f19717dd3..a77b187870 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -37,7 +37,7 @@ class 
ncclFunctionrecvbuff; ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex)); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) { ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels)); diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h index f9307fcea3..2822da189c 100644 --- a/src/collectives/device/broadcast.h +++ b/src/collectives/device/broadcast.h @@ -33,7 +33,7 @@ class ncclFunctionrecvbuff; ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex)); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 0f71629b75..3fc3f45f0f 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -49,6 +49,13 @@ #define ROLE_POST_SEND 0x10 #define ROLE_POST_RECV 0x20 +// The connection index is used to select between P2P and NET and needs to be passed into the ncclPrimitives constructor. +// To avoid adding another parameter, which would require changing every place where ncclPrimitives is constructed, +// we pack the group (max 7) and the connection index (max 2) into the original 32-bit group argument (group in the low 16 bits, index in the high 16 bits). 
+#define PACK_GROUP(gr, idx) (gr | (idx<<16)) +#define TO_GR(group) (group&0xffff) +#define TO_IDX(group) (group>>16) + // Implementation of primitive types template class ncclPrimitives { @@ -75,7 +82,7 @@ class ncclPrimitives { T* direct = NULL; T* buff; struct ncclDevComm* comm; - const int p2pNet; + const int connIndex; const T** srcs; T** dsts; @@ -247,8 +254,6 @@ class ncclPrimitives { __device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) { if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) { // For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1 - const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 : - ((p2pNet && (NSEND+NRECV) == 1 ? NCCL_CONN_IDX_P2P_NET : ((NSEND+NRECV) == 1 ? 0 : group))); conn = &channel->devPeers[peer].recv[connIndex].conn; step = conn->step; step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS); @@ -273,8 +278,6 @@ class ncclPrimitives { __device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) { if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) { // For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1 - const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 : - ((p2pNet && (NSEND+NRECV) == 1 ? NCCL_CONN_IDX_P2P_NET : ((NSEND+NRECV) == 1 ? 
0 : group))); conn = &channel->devPeers[peer].send[connIndex].conn; step = conn->step; step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS); @@ -308,7 +311,9 @@ class ncclPrimitives { public: __device__ __forceinline__ ncclPrimitives(const int tid, const int nworkers, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, struct ncclShmemPtrs* ptrs, int group) - : comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[group].srcs), dsts((T**)ptrs[group].dsts), group(group), barriers(&ptrs[group].barrier), barrier_next(ptrs[group].barrier_next), p2pNet(*comm->p2pNet) { + : comm(comm), tid(tid), nworkers(nworkers), stepSize(stepSize), srcs((const T**)ptrs[TO_GR(group)].srcs), dsts((T**)ptrs[TO_GR(group)].dsts), + group(TO_GR(group)), barriers(&ptrs[TO_GR(group)].barrier), barrier_next(ptrs[TO_GR(group)].barrier_next), + connIndex((NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? TO_GR(group)/2 : TO_IDX(group)) { nthreads = nworkers; // For send operations, we need an extra warp to overlap the threadfence and the copy // int postThreads = NSEND && nworkers >= 64 ? 
WARP_SIZE : 0; diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h index 90ef8294c1..871d14263a 100644 --- a/src/collectives/device/reduce.h +++ b/src/collectives/device/reduce.h @@ -34,7 +34,7 @@ class ncclFunctionrecvbuff; ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex)); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 7b3e7a7c1a..6f24c27d03 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -31,7 +31,7 @@ class ncclFunctionrecvbuff; ncclPrimitives - prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(0, args->coll.connIndex)); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels)); diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index 2471a43501..ab1d2c0bfe 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -57,7 +57,7 @@ class ncclFunctionrank-delta+comm->nRanks)%comm->nRanks; int nt = nThreadsSplit; ncclPrimitives - prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupRecv); + prims(tid, nt, &peer, NULL, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(groupRecv, args->p2p.recvIdx)); if (recvCount == 0) { prims.recv(recvbuff, 0); @@ -73,7 +73,7 @@ class ncclFunctionrank+delta)%comm->nRanks; int nt = 
nThreads-nThreadsSplit; ncclPrimitives - prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, groupSend); + prims(tid-nThreadsSplit, nt, NULL, &peer, recvbuff, stepSize, channel, comm, ncclShmem->ptrs, PACK_GROUP(groupSend, args->p2p.sendIdx)); if (sendCount == 0) { prims.send(sendbuff, 0); diff --git a/src/enqueue.cc b/src/enqueue.cc index 1c09c93b3c..723f043f92 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -701,9 +701,13 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) { int delta = (comm->nRanks - (comm->rank-peer)) % comm->nRanks; for (int c=0; cp2pnChannelsPerPeer; c++) { int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels; - if (comm->channels[channelId].peers[peer].send[NCCL_CONN_IDX_P2P].connected == 0) { // P2P uses only 1 connector + if (comm->channels[channelId].peers[peer].send[0].connected == 0) { comm->connectSend[peer] |= (1<connect = 1; + comm->connect[0] = 1; + } + if (comm->p2pNet && comm->channels[channelId].peers[peer].send[NCCL_CONN_IDX_P2P_NET].connected == 0) { + comm->connectSend[peer+comm->nRanks*NCCL_CONN_IDX_P2P_NET] |= (1<connect[NCCL_CONN_IDX_P2P_NET] = 1; } } } @@ -714,9 +718,13 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) { int delta = (comm->nRanks + (comm->rank-peer)) % comm->nRanks; for (int c=0; cp2pnChannelsPerPeer; c++) { int channelId = (delta+comm->p2pChannels[c]) % comm->p2pnChannels; - if (comm->channels[channelId].peers[peer].recv[NCCL_CONN_IDX_P2P].connected == 0) { // P2P uses only 1 connector + if (comm->channels[channelId].peers[peer].recv[0].connected == 0) { comm->connectRecv[peer] |= (1<connect = 1; + comm->connect[0] = 1; + } + if (comm->p2pNet && comm->channels[channelId].peers[peer].recv[NCCL_CONN_IDX_P2P_NET].connected == 0) { + comm->connectRecv[peer+comm->nRanks*NCCL_CONN_IDX_P2P_NET] |= (1<connect[NCCL_CONN_IDX_P2P_NET] = 1; } } } @@ -792,10 +800,16 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) { // Compute cuda 
kernel arg and proxy arg templates struct ncclQueueElem* eqElem; NCCLCHECK(ncclAddQueueElem(comm->enqueueInfo, &eqElem)); + // The proxy code will set and tune the send/recv chunk size, make sure to run it first. NCCLCHECK(ncclProxyComputeP2p(info, &eqElem->proxyArgs)); NCCLCHECK(computeP2pWorkElem(info, &eqElem->work)); + eqElem->proxyArgs.sendIdx = info->sendIdx; + eqElem->proxyArgs.recvIdx = info->recvIdx; + eqElem->work.p2p.sendIdx = info->sendIdx; + eqElem->work.p2p.recvIdx = info->recvIdx; + int channelId = info->channelId; hipLaunchParams* params = comm->myParams; params->gridDim.x = std::max(params->gridDim.x, channelId+1); diff --git a/src/group.cc b/src/group.cc index 61ff36fa9f..89d9d38926 100644 --- a/src/group.cc +++ b/src/group.cc @@ -36,6 +36,7 @@ struct ncclInitArgs { }; struct ncclCollArgs { ncclComm_t comm; + uint16_t connIndex; }; enum ncclAsyncFuncType { @@ -118,7 +119,8 @@ ncclResult_t ncclGroupStart() { return ncclSuccess; } -static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int channelId, ssize_t recvbytes, void* recvbuff, ssize_t sendbytes, const void* sendbuff) { +static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int channelId, ssize_t recvbytes, + void* recvbuff, ssize_t sendbytes, const void* sendbuff, uint16_t sendIdx, uint16_t recvIdx) { struct ncclInfo info = { ncclFuncSendRecv, "SendRecv", sendbuff, recvbuff, (size_t)std::max(sendbytes,recvbytes), ncclInt8, ncclSum, -1, comm, comm->userStream, /* Args */ 1, 1 }; @@ -126,6 +128,8 @@ static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int chann info.channelId = channelId; info.sendbytes = sendbytes; info.recvbytes = recvbytes; + info.sendIdx = sendIdx; + info.recvIdx = recvIdx; if (delta == 0 && sendbytes != recvbytes) return ncclInvalidUsage; NCCLCHECK(ncclSetupP2pKernel(&info)); return ncclSuccess; @@ -135,7 +139,7 @@ void* ncclAsyncThreadPreconnect(void* args_) { struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_; 
struct ncclComm* comm = args->coll.comm; CUDACHECKTHREAD(hipSetDevice(comm->cudaDev)); - NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, NCCL_CONN_IDX_P2P)); + NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, args->coll.connIndex)); return args; } @@ -150,6 +154,8 @@ static size_t getP2pChunkSize(size_t totalSize, int minChannels, int maxChannels return size; } +RCCL_PARAM(P2pNetThreshold, "RCCL_P2P_NET_THRESHOLD", 131072); + NCCL_API(ncclResult_t, ncclGroupEnd); ncclResult_t ncclGroupEnd() { NVTX3_FUNC_RANGE_IN(nccl_domain); @@ -195,14 +201,15 @@ ncclResult_t ncclGroupEnd() { for (int i=0; ifuncType == ASYNC_FUNC_COLL && args->coll.comm->connect) { + if (args->funcType == ASYNC_FUNC_COLL && args->coll.comm->connect[0]) { + args->coll.connIndex = 0; pthread_create(ncclGroupThreads+i, NULL, ncclAsyncThreadPreconnect, args); } } for (int i=0; ifuncType == ASYNC_FUNC_COLL && args->coll.comm->connect) { + if (args->funcType == ASYNC_FUNC_COLL && args->coll.comm->connect[0]) { int err = pthread_join(ncclGroupThreads[i], NULL); if (err != 0) { WARN("Error waiting for pthread_join : %s", strerror(errno)); @@ -210,7 +217,29 @@ ncclResult_t ncclGroupEnd() { } INFO(NCCL_INIT, "comm %p rank %d total %ld bytes - P2P preconnect COMPLETE", args->coll.comm, args->coll.comm->rank, allocTracker[args->coll.comm->cudaDev].totalAllocSize); NCCLCHECKGOTO(args->ret, ret, end); - args->coll.comm->connect = 0; + args->coll.comm->connect[0] = 0; + } + } + + for (int i=0; ifuncType == ASYNC_FUNC_COLL && args->coll.comm->connect[NCCL_CONN_IDX_P2P_NET]) { + args->coll.connIndex = NCCL_CONN_IDX_P2P_NET; + pthread_create(ncclGroupThreads+i, NULL, ncclAsyncThreadPreconnect, args); + } + } + + for (int i=0; ifuncType == ASYNC_FUNC_COLL && args->coll.comm->connect[NCCL_CONN_IDX_P2P_NET]) { + int err = pthread_join(ncclGroupThreads[i], NULL); + if (err != 0) { + WARN("Error waiting for pthread_join : %s", strerror(errno)); + return ncclSystemError; + } + INFO(NCCL_INIT, "comm %p rank %d total 
%ld bytes - P2P NET preconnect COMPLETE", args->coll.comm, args->coll.comm->rank, allocTracker[args->coll.comm->cudaDev].totalAllocSize); + NCCLCHECKGOTO(args->ret, ret, end); + args->coll.comm->connect[NCCL_CONN_IDX_P2P_NET] = 0; } } @@ -253,6 +282,12 @@ sched_delta: ssize_t recvChunkSize = getP2pChunkSize(totRecvBytes, nChannelsMin, nChannelsMax, stepSize, SENDRECV_SLICEFACTOR*stepSize); ssize_t sendChunkSize = getP2pChunkSize(totSendBytes, nChannelsMin, nChannelsMax, stepSize, SENDRECV_SLICEFACTOR*stepSize); + uint16_t sendIdx = 0, recvIdx = 0; + if(comm->p2pNet && totSendBytes > rcclParamP2pNetThreshold()) + sendIdx = NCCL_CONN_IDX_P2P_NET; + if(comm->p2pNet && totRecvBytes > rcclParamP2pNetThreshold()) + recvIdx = NCCL_CONN_IDX_P2P_NET; + ssize_t sendOffset = 0; ssize_t recvOffset = 0; int sendRemaining = 1, recvRemaining = 1; @@ -270,7 +305,7 @@ sched_delta: if (sendbytes >= 0 || recvbytes >= 0) { NCCLCHECKGOTO(scheduleSendRecv(comm, delta, channelId, recvbytes, recv ? ((char*)(recv->buff)) + recvOffset : NULL, - sendbytes, send ? ((const char*)(send->buff)) + sendOffset : NULL), ret, group_cleanup); + sendbytes, send ? 
((const char*)(send->buff)) + sendOffset : NULL, sendIdx, recvIdx), ret, group_cleanup); } recvOffset += recvChunkSize; sendOffset += sendChunkSize; diff --git a/src/include/comm.h b/src/include/comm.h index 3bc5608332..a0774321e5 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -74,7 +74,7 @@ struct ncclComm { void* bootstrap; // Bitmasks for ncclTransportP2pSetup - int connect; + int connect[NCCL_MAX_CONNS]; uint32_t* connectSend; uint32_t* connectRecv; @@ -128,7 +128,7 @@ struct ncclComm { volatile uint32_t *abortFlag; // Flags for enable P2P NET - uint32_t *p2pNet; + uint32_t p2pNet; uint32_t useIntraNet; // Device side of the communicator diff --git a/src/include/devcomm.h b/src/include/devcomm.h index d9b3884de6..51b57ec790 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -163,7 +163,6 @@ struct ncclDirect { int down[NCCL_MAX_DIRECT_ARITY]; }; -#define NCCL_CONN_IDX_P2P (*(comm->p2pNet)*2) #define NCCL_CONN_IDX_P2P_NET 2 #define NCCL_MAX_CONNS 3 struct ncclPeer { @@ -208,7 +207,14 @@ struct ncclWorkElem { int sendChunkSize; int recvChunkSize; int32_t delta; - uint16_t nThreads; + union { + struct { + uint16_t nThreads:12; + uint16_t sendIdx:2; + uint16_t recvIdx:2; + }; + uint16_t padding; + }; } p2p; struct { uint16_t padding[15]; @@ -357,9 +363,6 @@ struct ncclDevComm { // Channels, device side struct ncclChannel* channels; - // Flags for enable P2P NET - uint32_t *p2pNet; - #ifdef ENABLE_PROFILING // Profiling counters struct ncclProf* devProf; diff --git a/src/include/info.h b/src/include/info.h index a4856893bf..d35eb99b8e 100644 --- a/src/include/info.h +++ b/src/include/info.h @@ -53,6 +53,8 @@ struct ncclInfo { int sendChunkSize; uint32_t delta; int channelId; + uint16_t sendIdx; + uint16_t recvIdx; }; #endif diff --git a/src/include/proxy.h b/src/include/proxy.h index 2b7b66931d..1cae10e533 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -61,6 +61,8 @@ struct ncclProxyArgs { int idle; uint64_t hdp_flushed; 
uint8_t connIndex; + uint8_t sendIdx; + uint8_t recvIdx; // Element linking pthread_mutex_t mutex; diff --git a/src/init.cc b/src/init.cc index f6e6dd4e22..0a4c364970 100644 --- a/src/init.cc +++ b/src/init.cc @@ -377,7 +377,6 @@ static ncclResult_t commFree(ncclComm_t comm) { free(comm->intraCC); } NCCLCHECK(ncclCudaHostFree((void *)comm->abortFlag)); - NCCLCHECK(ncclCudaHostFree((void *)comm->p2pNet)); // Poison comm to try and catch a double free commPoison(comm); @@ -430,9 +429,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { comm->hostDevComm.abortFlag = comm->abortFlag; STORE(comm->abortFlag, 0); - NCCLCHECK(ncclCudaHostCalloc((uint32_t**)&comm->p2pNet, 1)); - comm->hostDevComm.p2pNet = comm->p2pNet; - STORE(comm->p2pNet, 0); comm->collOpCount = 0; comm->p2pOpCount = 0x8000; @@ -466,8 +462,8 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { static_assert(MAXCHANNELS <= sizeof(*comm->connectSend)*8, "comm->connectSend must have enough bits for all channels"); static_assert(MAXCHANNELS <= sizeof(*comm->connectRecv)*8, "comm->connectRecv must have enough bits for all channels"); - NCCLCHECK(ncclCalloc(&comm->connectSend, comm->nRanks)); - NCCLCHECK(ncclCalloc(&comm->connectRecv, comm->nRanks)); + NCCLCHECK(ncclCalloc(&comm->connectSend, comm->nRanks*NCCL_MAX_CONNS)); + NCCLCHECK(ncclCalloc(&comm->connectRecv, comm->nRanks*NCCL_MAX_CONNS)); comm->p2pSendCount = comm->p2pRecvCount = 0; NCCLCHECK(ncclCalloc(&comm->p2pSends, comm->nRanks)); @@ -837,7 +833,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm if ((comm->topo->type & RCCL_TOPO_4P2H_ROME) && (comm->topo->type & RCCL_TOPO_GDR_ALL)) { if (rcclParamP2pNetDisable() == 0) { - STORE(comm->p2pNet, 1); + comm->p2pNet = 1; INFO(NCCL_INIT, "RCCL enabled same node P2P over network"); } else diff --git a/src/proxy.cc b/src/proxy.cc index dbddaee536..71b6894028 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -287,7 +287,7 @@ 
ncclResult_t ncclProxySaveP2p(struct ncclComm* comm, struct ncclProxyArgs* args) sub->sendbytes = 0; sub->nsteps = DIVUP(sub->recvbytes, sub->recvChunkSize); if (sub->nsteps == 0) sub->nsteps = 1; - NCCLCHECK(SaveProxy(proxyRecv, peerrecv, args, NCCL_CONN_IDX_P2P)); + NCCLCHECK(SaveProxy(proxyRecv, peerrecv, args, args->recvIdx)); } if (sub->delta > 0 && sendbytesOrig >= ssize_t(0)) { int peersend = (comm->rank+sub->delta)%comm->nRanks; @@ -295,7 +295,7 @@ ncclResult_t ncclProxySaveP2p(struct ncclComm* comm, struct ncclProxyArgs* args) sub->recvbytes = 0; sub->nsteps = DIVUP(sub->sendbytes, sub->sendChunkSize); if (sub->nsteps == 0) sub->nsteps = 1; - NCCLCHECK(SaveProxy(proxySend, peersend, args, NCCL_CONN_IDX_P2P)); + NCCLCHECK(SaveProxy(proxySend, peersend, args, args->sendIdx)); } // Reset proxy args for potentially multiple cuda graph launches // It is safe as long as SaveProxy copies contents of args to op diff --git a/src/transport.cc b/src/transport.cc index b1237fde93..ef6253fd68 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -69,12 +69,12 @@ ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* for (int i=0; i= comm->nRanks || peer == comm->rank || channel->peers[peer].recv[connIndex].connected) continue; - comm->connectRecv[peer] |= mask; + comm->connectRecv[peer+comm->nRanks*connIndex] |= mask; } for (int i=0; i= comm->nRanks || peer == comm->rank || channel->peers[peer].send[connIndex].connected) continue; - comm->connectSend[peer] |= mask; + comm->connectSend[peer+comm->nRanks*connIndex] |= mask; } return ncclSuccess; } @@ -98,8 +98,8 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* int bootstrapTag = (i<<8) + (graph ? 
graph->id+1 : 0); int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; int sendPeer = (comm->rank + i) % comm->nRanks; - uint32_t recvMask = comm->connectRecv[recvPeer]; - uint32_t sendMask = comm->connectSend[sendPeer]; + uint32_t recvMask = comm->connectRecv[recvPeer+comm->nRanks*connIndex]; + uint32_t sendMask = comm->connectSend[sendPeer+comm->nRanks*connIndex]; struct ncclConnect* recvData = data; int sendChannels = 0, recvChannels = 0; @@ -145,7 +145,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* CUDACHECK(hipMemcpyAsync(comm->channels[c].devPeers[recvPeer].recv+connIndex, conn, sizeof(struct ncclConnector), hipMemcpyHostToDevice, transportSetupStream)); } } - comm->connectRecv[recvPeer] = comm->connectSend[sendPeer] = 0; + comm->connectRecv[recvPeer+comm->nRanks*connIndex] = comm->connectSend[sendPeer+comm->nRanks*connIndex] = 0; } CUDACHECK(hipStreamSynchronize(transportSetupStream)); CUDACHECK(hipStreamDestroy(transportSetupStream)); diff --git a/test/test_CombinedCalls.cpp b/test/test_CombinedCalls.cpp index 851bc5acc8..278c1bf067 100644 --- a/test/test_CombinedCalls.cpp +++ b/test/test_CombinedCalls.cpp @@ -122,6 +122,6 @@ namespace CorrectnessTests testing::Values(2,3,4,5,6,7,8), // In-place or not testing::Values(false), - testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), + testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1", "RCCL_P2P_NET_DISABLE=0", "RCCL_P2P_NET_DISABLE=1")), CorrectnessTest::PrintToStringParamName()); } // namespace