From 47dfed5f01d5ddc81d272e9aa4ef3bd13376be11 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Wed, 8 Jul 2020 11:06:50 -0700 Subject: [PATCH] Revert "Split primitive class to smaller structures" (#230) This reverts commit 622b49e80a00fe6292d930073fedf529101a2a41. [ROCm/rccl commit: 52151301682ed5ab5c6490f9da8fdb933ae3e36e] --- .../rccl/src/collectives/device/all_reduce.h | 12 +- .../rccl/src/collectives/device/primitives.h | 206 ++++++++---------- .../rccl/src/collectives/device/prims_ll.h | 143 +++++------- 3 files changed, 151 insertions(+), 210 deletions(-) diff --git a/projects/rccl/src/collectives/device/all_reduce.h b/projects/rccl/src/collectives/device/all_reduce.h index 0ea08cdaa7..fe36524dbd 100644 --- a/projects/rccl/src/collectives/device/all_reduce.h +++ b/projects/rccl/src/collectives/device/all_reduce.h @@ -129,8 +129,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitivesRecvData recvData; - ncclPrimitives prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData); + ncclPrimitives prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -148,8 +147,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclPrimitivesSendData sendData; - ncclPrimitives prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount, sendData); + ncclPrimitives prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize; @@ -323,8 +321,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclLLPrimitivesRecvData recvData; - ncclLLPrimitives LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount, recvData); + ncclLLPrimitives LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -342,8 +339,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) { do { struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclLLPrimitivesSendData sendData; - ncclLLPrimitives LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount, sendData); + ncclLLPrimitives LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize; diff --git a/projects/rccl/src/collectives/device/primitives.h b/projects/rccl/src/collectives/device/primitives.h index b5c6d24bfb..d6d047470a 100644 --- a/projects/rccl/src/collectives/device/primitives.h +++ b/projects/rccl/src/collectives/device/primitives.h @@ -39,41 +39,6 @@ while (LOAD(barriers+id) < barrier_next[id*MAXWARPS+w]) /* spin */; \ } while (0) -template -class ncclPrimitivesRecvData { -public: - struct ncclConnInfo* recvConn = NULL; - volatile uint64_t* recvConnHeadPtr = NULL; - uint64_t recvConnHead; - volatile uint64_t* recvConnTailPtr = NULL; - uint64_t recvConnTail; - uint64_t recvConnTailCache; // Cache last seen value - - uint64_t recvStep[NRECV]; -#if defined(RCCL_USE_DIRECT_BUFFER) - const T* recvDirectBuff[NRECV]; -#endif - const T* recvBuff[NRECV]; -}; - -template -class ncclPrimitivesSendData { -public: - struct ncclConnInfo* sendConn = NULL; - volatile int* sendConnFifoPtr = NULL; - volatile uint64_t* sendConnTailPtr = NULL; - uint64_t sendConnTail; - volatile uint64_t* sendConnHeadPtr = NULL; - uint64_t sendConnHead; - uint64_t sendConnHeadCache; // Cache last seen value - - uint64_t sendStep[NSEND]; -#if defined(RCCL_USE_DIRECT_BUFFER) - const T* sendDirectBuff[NSEND]; -#endif - T* sendBuff[NSEND]; -}; - // Implementation of primitive types template class ncclPrimitives { @@ -84,18 +49,35 @@ class ncclPrimitives { const int stepSize; int nrecv = 0; int nsend = 0; + struct ncclConnInfo* recvConn = NULL; + volatile uint64_t* recvConnHeadPtr = NULL; + uint64_t recvConnHead; + volatile uint64_t* recvConnTailPtr = NULL; + uint64_t recvConnTail; + uint64_t recvConnTailCache; // Cache last seen value - typename std::conditional&, ncclPrimitivesRecvData>::type r; - typename std::conditional&, ncclPrimitivesSendData>::type s; + struct ncclConnInfo* sendConn = NULL; + volatile int* sendConnFifoPtr = NULL; + volatile uint64_t* sendConnTailPtr = NULL; + uint64_t sendConnTail; + volatile uint64_t* sendConnHeadPtr = NULL; + uint64_t sendConnHead; + uint64_t sendConnHeadCache; // Cache last seen value - const struct ncclDevComm* comm; + uint64_t recvStep[NRECV]; + uint64_t sendStep[NSEND]; +#if defined(RCCL_USE_DIRECT_BUFFER) + const T* recvDirectBuff[NRECV]; + T* sendDirectBuff[NSEND]; +#endif + const T* recvBuff[NRECV]; + T* sendBuff[NSEND]; + struct ncclDevComm* comm; - inline __device__ int recvOffset(int i) { return (r.recvStep[i]%NCCL_STEPS)*stepSize; } - inline __device__ int sendOffset(int i) { return (s.sendStep[i]%NCCL_STEPS)*stepSize; } - inline __device__ const T* recvPtr(int i) { return ((const T*)r.recvBuff[i])+recvOffset(i); } - inline __device__ T* sendPtr(int i) { return ((T*)s.sendBuff[i])+sendOffset(i); } + inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; } + inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; } + inline __device__ const T* recvPtr(int i) { return ((const T*)recvBuff[i])+recvOffset(i); } + inline __device__ T* sendPtr(int i) { return ((T*)sendBuff[i])+sendOffset(i); } uint64_t* barriers; uint64_t* barrier_next; @@ -109,7 +91,11 @@ class ncclPrimitives { else barrier_by_id(1); } #else - asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE)); + if (NSEND>NRECV) { + asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE)); + } else { + asm volatile ("bar.sync 2, %0;" :: "r"(nthreads+WARP_SIZE)); + } #endif } @@ -117,7 +103,11 @@ class ncclPrimitives { #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) __syncthreads(); #else - asm volatile ("bar.sync 2, %0;" :: "r"(nthreads)); + if (NSEND>NRECV) { + asm volatile ("bar.sync 3, %0;" :: "r"(nthreads)); + } else { + asm volatile ("bar.sync 4, %0;" :: "r"(nthreads)); + } #endif } @@ -140,7 +130,7 @@ class ncclPrimitives { spins++; if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { abort = LOAD(comm->abortFlag); - if (wid == i) checkMismatch(send ? s.sendConn : r.recvConn); + if (wid == i) checkMismatch(send ? sendConn : recvConn); spins = 0; } return abort; @@ -149,57 +139,57 @@ class ncclPrimitives { inline __device__ void waitSend(int nbytes) { spins = 0; mismatch = 0; - if (s.sendConnHeadPtr) { - while (s.sendConnHeadCache + NCCL_STEPS < s.sendConnHead + SLICESTEPS) { - s.sendConnHeadCache = LOAD(s.sendConnHeadPtr); + if (sendConnHeadPtr) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) { + sendConnHeadCache = LOAD(sendConnHeadPtr); if (checkAbort(wid, 1)) break; } - if (s.sendConnFifoPtr) { - STORE(s.sendConnFifoPtr+s.sendConnHead%NCCL_STEPS, nbytes); + if (sendConnFifoPtr) { + STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, nbytes); } - s.sendConnHead += SLICESTEPS; + sendConnHead += SLICESTEPS; } } inline __device__ void waitRecv() { spins = 0; mismatch = 0; - if (r.recvConnTailPtr) { + if (recvConnTailPtr) { #ifdef ENABLE_PROFILING uint64_t t0 = __rtc64(); #endif - while (r.recvConnTailCache < r.recvConnTail + SLICESTEPS) { - r.recvConnTailCache = LOAD(r.recvConnTailPtr); + while (recvConnTailCache < recvConnTail + SLICESTEPS) { + recvConnTailCache = LOAD(recvConnTailPtr); if (checkAbort(wid, 0)) break; } #ifdef ENABLE_PROFILING if (opCount > 0) __atomic_fetch_add(&comm->devProf->wait_recv_cycle[blockIdx.x], __rtc64() - t0, __ATOMIC_SEQ_CST); #endif - r.recvConnTail += SLICESTEPS; + recvConnTail += SLICESTEPS; } } inline __device__ void incRecv(int i) { - r.recvStep[i] += SLICESTEPS; + recvStep[i] += SLICESTEPS; } inline __device__ void postRecv() { - if (r.recvConnHeadPtr) STORE(r.recvConnHeadPtr, r.recvConnHead += SLICESTEPS); + if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += SLICESTEPS); } inline __device__ void incSend(int i) { - s.sendStep[i] += SLICESTEPS; + sendStep[i] += SLICESTEPS; } inline __device__ void postSend() { - if (s.sendConnTailPtr) { - if (s.sendConn->next_hdp_reg) STORE(s.sendConn->next_hdp_reg, 0x1); - STORE(s.sendConnTailPtr, s.sendConnTail += SLICESTEPS); + if (sendConnTailPtr) { + if (sendConn->next_hdp_reg) STORE(sendConn->next_hdp_reg, 0x1); + STORE(sendConnTailPtr, sendConnTail += SLICESTEPS); } } template inline __device__ const T* directRecvPtr(int i, ssize_t directOffset) { #if defined(RCCL_USE_DIRECT_BUFFER) - return DIRECTRECV && r.recvDirectBuff[i] ? r.recvDirectBuff[i]+directOffset : recvPtr(i); + return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i); #else return recvPtr(i); #endif @@ -208,7 +198,7 @@ class ncclPrimitives { template inline __device__ T* directSendPtr(int i, ssize_t directOffset) { #if defined(RCCL_USE_DIRECT_BUFFER) - return DIRECTSEND && s.sendDirectBuff[i] ? s.sendDirectBuff[i]+directOffset : sendPtr(i); + return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i); #else return sendPtr(i); #endif @@ -217,7 +207,7 @@ class ncclPrimitives { template inline __device__ int directRecvInc(int i, int directInc, int sliceInc) { #if defined(RCCL_USE_DIRECT_BUFFER) - return DIRECTRECV && r.recvDirectBuff[i] ? directInc : sliceInc; + return DIRECTRECV && recvDirectBuff[i] ? directInc : sliceInc; #else return sliceInc; #endif @@ -226,7 +216,7 @@ inline __device__ int directRecvInc(int i, int directInc, int sliceInc) { template inline __device__ int directSendInc(int i, int directInc, int sliceInc) { #if defined(RCCL_USE_DIRECT_BUFFER) - return DIRECTSEND && s.sendDirectBuff[i] ? directInc : sliceInc; + return DIRECTSEND && sendDirectBuff[i] ? directInc : sliceInc; #else return sliceInc; #endif @@ -267,7 +257,7 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) { if (tid == 0 && opCount > 0) __atomic_fetch_add(&comm->devProf->wait_cycle[blockIdx.x], __rtc64() - t0, __ATOMIC_SEQ_CST); #endif #if defined(RCCL_USE_DIRECT_BUFFER) - if (DIRECTRECV && r.recvDirectBuff[0]) { + if (DIRECTRECV && recvDirectBuff[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (SEND) { ReduceOrCopyMulti(tid, nthreads, 1, srcs, nsend, dsts+1, realSize); @@ -299,81 +289,84 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) { } __device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) { - r.recvBuff[i] = (const T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE); - r.recvStep[i] = LOAD(&conn->step); - r.recvStep[i] = ROUNDUP(r.recvStep[i], SLICESPERCHUNK*SLICESTEPS); + recvBuff[i] = (const T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE); + recvStep[i] = LOAD(&conn->step); + recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS); #if defined(RCCL_USE_DIRECT_BUFFER) - r.recvDirectBuff[i] = NULL; + recvDirectBuff[i] = NULL; if (DIRECT && LOAD((&conn->direct) & NCCL_DIRECT_GPU)) { - r.recvDirectBuff[i] = directBuff; + recvDirectBuff[i] = directBuff; if (tid == 0) STORE(conn->ptrExchange, directBuff); } #endif - if (wid == i) r.recvConn = conn; - if (wid == i) r.recvConnTail = r.recvConnHead = r.recvStep[i]; // Make sure we set this after rounding up + if (wid == i) recvConn = conn; + if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up nrecv++; } __device__ __forceinline__ void loadRecvSync() { if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && widtail); - r.recvConnTailCache = LOAD(r.recvConnTailPtr); + recvConnTailPtr = LOAD(&recvConn->tail); + recvConnTailCache = LOAD(recvConnTailPtr); } if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - r.recvConnHeadPtr = LOAD(&r.recvConn->head); + recvConnHeadPtr = LOAD(&recvConn->head); // Return credits in case we rounded up. - STORE(r.recvConnHeadPtr, r.recvConnHead); + STORE(recvConnHeadPtr, recvConnHead); // Update opCount in case we skipped some operations - STORE(r.recvConn->opCountLoc, opCount); + STORE(recvConn->opCountLoc, opCount); } } __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { - s.sendBuff[i] = (T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE); - s.sendStep[i] = LOAD(&conn->step); - s.sendStep[i] = ROUNDUP(s.sendStep[i], SLICESPERCHUNK*SLICESTEPS); + sendBuff[i] = (T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE); + sendStep[i] = LOAD(&conn->step); + sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS); #if defined(RCCL_USE_DIRECT_BUFFER) - s.sendDirectBuff[i] = NULL; + sendDirectBuff[i] = NULL; if (DIRECT && LOAD((&conn->direct) & NCCL_DIRECT_GPU)) { void* volatile* ptr = LOAD(&conn->ptrExchange); - while ((s.sendDirectBuff[i] = (T*)(LOAD(ptr))) == NULL); + while ((sendDirectBuff[i] = (T*)(LOAD(ptr))) == NULL); barrier(); if (tid == 0) STORE(ptr, NULL); } #endif - if (wid == i) s.sendConn = conn; - if (wid == i) s.sendConnTail = s.sendConnHead = s.sendStep[i]; // Make sure we set this after rounding up + if (wid == i) sendConn = conn; + if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up nsend++; } __device__ __forceinline__ void loadSendSync() { if (tid < nsend) { - s.sendConnHeadPtr = LOAD(&s.sendConn->head); - s.sendConnHeadCache = LOAD(s.sendConnHeadPtr); - s.sendConnFifoPtr = LOAD(&s.sendConn->fifo); - STORE(s.sendConn->opCountLoc, opCount); + sendConnHeadPtr = LOAD(&sendConn->head); + sendConnHeadCache = LOAD(sendConnHeadPtr); + sendConnFifoPtr = LOAD(&sendConn->fifo); + STORE(sendConn->opCountLoc, opCount); } if (tid >= nthreads-WARP_SIZE && wid < nsend) { - s.sendConnTailPtr = LOAD(&s.sendConn->tail); + sendConnTailPtr = LOAD(&sendConn->tail); } } __device__ __forceinline__ void saveRecvSync() { if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - STORE(&r.recvConn->step, r.recvConnHead); - STORE(r.recvConn->opCountLoc, opCount+1); + STORE(&recvConn->step, recvConnHead); + STORE(recvConn->opCountLoc, opCount+1); __threadfence_system(); } } __device__ __forceinline__ void saveSendSync() { if (tid < nsend) { - STORE(&s.sendConn->step, s.sendConnHead); - STORE(s.sendConn->opCountLoc, opCount+1); + STORE(&sendConn->step, sendConnHead); + STORE(sendConn->opCountLoc, opCount+1); __threadfence_system(); } } - inline __device__ void init(int* recvPeers, int* sendPeers, struct ncclChannel* channel) { + public: + __device__ __forceinline__ + ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) { barriers = channel->barrier; barrier_next = channel->barrier_next; // Make sure step is updated before we read it. @@ -385,25 +378,6 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) { loadSendSync(); } - public: - __device__ __forceinline__ - ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) { - init(recvPeers, sendPeers, channel); - } - - __device__ __forceinline__ - ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclPrimitivesRecvData& r) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount), r(r) { - init(recvPeers, sendPeers, channel); - } - - __device__ __forceinline__ - ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclPrimitivesSendData& s) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount), s(s) { - init(recvPeers, sendPeers, channel); - } - __device__ __forceinline__ void send(const T* src, int nelem) { GenericOp<0, 0, 0, 1, 1, 0>(src, NULL, nelem, 0); diff --git a/projects/rccl/src/collectives/device/prims_ll.h b/projects/rccl/src/collectives/device/prims_ll.h index 277dd2b191..9ed03a1d48 100644 --- a/projects/rccl/src/collectives/device/prims_ll.h +++ b/projects/rccl/src/collectives/device/prims_ll.h @@ -5,28 +5,6 @@ * See LICENSE.txt for license information ************************************************************************/ -template -class ncclLLPrimitivesRecvData { -public: - struct ncclConnInfo* recvConn = NULL; - volatile uint64_t* recvConnHeadPtr = NULL; - uint64_t recvConnHead; - uint64_t recvStep[NRECV]; - union ncclLLFifoLine* recvBuff[NRECV]; -}; - -template -class ncclLLPrimitivesSendData { -public: - struct ncclConnInfo* sendConn = NULL; - volatile int* sendConnFifoPtr = NULL; - volatile uint64_t* sendConnHeadPtr = NULL; - uint64_t sendConnHead; - uint64_t sendConnHeadCache; // Cache last seen value - uint64_t sendStep[NSEND]; - union ncclLLFifoLine* sendBuff[NSEND]; -}; - template class ncclLLPrimitives { private: @@ -36,25 +14,34 @@ class ncclLLPrimitives { const int stepLines; int nrecv = 0; int nsend = 0; + struct ncclConnInfo* recvConn = NULL; + volatile uint64_t* recvConnHeadPtr = NULL; + uint64_t recvConnHead; + + struct ncclConnInfo* sendConn = NULL; + volatile int* sendConnFifoPtr = NULL; + volatile uint64_t* sendConnHeadPtr = NULL; + uint64_t sendConnHead; + uint64_t sendConnHeadCache; // Cache last seen value + + uint64_t recvStep[NRECV]; + uint64_t sendStep[NSEND]; + union ncclLLFifoLine* recvBuff[NRECV]; + union ncclLLFifoLine* sendBuff[NSEND]; struct ncclDevComm* comm; - typename std::conditional&, ncclLLPrimitivesRecvData>::type r; - typename std::conditional&, ncclLLPrimitivesSendData>::type s; - - inline __device__ int recvOffset(int i) { return (r.recvStep[i]%NCCL_STEPS)*stepLines; } - inline __device__ int sendOffset(int i) { return (s.sendStep[i]%NCCL_STEPS)*stepLines; } - inline __device__ union ncclLLFifoLine* recvPtr(int i) { return r.recvBuff[i]+recvOffset(i); } - inline __device__ union ncclLLFifoLine* sendPtr(int i) { return s.sendBuff[i]+sendOffset(i); } - inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(r.recvStep[i]+1); } - inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(s.sendStep[i]+1); } + inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; } + inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; } + inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); } + inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); } + inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); } + inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); } inline __device__ void barrier() { #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) __syncthreads(); #else - asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); + asm volatile ("basync 1, %0;" :: "r"(nthreads)); #endif } @@ -78,7 +65,7 @@ class ncclLLPrimitives { spins++; if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { abort = LOAD(comm->abortFlag); - if (wid == i) checkMismatch(send ? s.sendConn : r.recvConn); + if (wid == i) checkMismatch(send ? sendConn : recvConn); spins = 0; } return abort; @@ -87,35 +74,35 @@ class ncclLLPrimitives { inline __device__ void waitSend(int nbytes) { spins = 0; mismatch = 0; - if (s.sendConnHeadPtr) { - while (s.sendConnHeadCache + NCCL_STEPS < s.sendConnHead + 1) { - s.sendConnHeadCache = LOAD(s.sendConnHeadPtr); + if (sendConnHeadPtr) { + while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { + sendConnHeadCache = LOAD(sendConnHeadPtr); if (checkAbort(wid, 1)) break; } - if (s.sendConnFifoPtr) { - int size = ((s.sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; - STORE(s.sendConnFifoPtr+s.sendConnHead%NCCL_STEPS, size); + if (sendConnFifoPtr) { + int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; + STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, size); } - s.sendConnHead += 1; + sendConnHead += 1; } barrier(); } inline __device__ void incRecv(int i) { - r.recvStep[i] += 1; + recvStep[i] += 1; } inline __device__ void postRecv() { barrier(); - if (r.recvConnHeadPtr) STORE(r.recvConnHeadPtr, r.recvConnHead += 1); + if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1); } inline __device__ void incSend(int i, int offset) { // LL Cleanup : write all flags in the slice to make sure we don't have - // data corruption when flag loops over. - if ((s.sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { + // data corruption when flag loops ove + if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) { for (int o = offset; obuffs+NCCL_PROTO_LL); - r.recvStep[i] = LOAD(&conn->step); - if (wid == i) r.recvConn = conn; + recvBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL); + recvStep[i] = LOAD(&conn->step); + if (wid == i) recvConn = conn; nrecv++; } __device__ __forceinline__ void loadRecvSync() { if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - r.recvConnHeadPtr = LOAD(&r.recvConn->head); - r.recvConnHead = LOAD(&r.recvConn->step); + recvConnHeadPtr = LOAD(&recvConn->head); + recvConnHead = LOAD(&recvConn->step); // Update opCount in case we skipped some operations - STORE(r.recvConn->opCountLoc, opCount); + STORE(recvConn->opCountLoc, opCount); } } __device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) { - s.sendBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL); - s.sendStep[i] = LOAD(&conn->step); - if (wid == i) s.sendConn = conn; + sendBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL); + sendStep[i] = LOAD(&conn->step); + if (wid == i) sendConn = conn; nsend++; } __device__ __forceinline__ void loadSendSync() { if (tid < nsend) { - s.sendConnHeadPtr = LOAD(&s.sendConn->head); - s.sendConnHeadCache = LOAD(s.sendConnHeadPtr); - s.sendConnHead = LOAD(&s.sendConn->step); - s.sendConnFifoPtr = LOAD(&s.sendConn->fifo); - STORE(s.sendConn->opCountLoc, opCount); + sendConnHeadPtr = LOAD(&sendConn->head); + sendConnHeadCache = LOAD(sendConnHeadPtr); + sendConnHead = LOAD(&sendConn->step); + sendConnFifoPtr = LOAD(&sendConn->fifo); + STORE(sendConn->opCountLoc, opCount); } } __device__ __forceinline__ void saveRecvSync() { if (tid >= nthreads-WARP_SIZE && wid < nrecv) { - STORE(&r.recvConn->step, r.recvConnHead); - STORE(r.recvConn->opCountLoc, opCount+1); + STORE(&recvConn->step, recvConnHead); + STORE(recvConn->opCountLoc, opCount+1); __threadfence_block(); } } __device__ __forceinline__ void saveSendSync() { if (tid < nsend) { - STORE(&s.sendConn->step, s.sendConnHead); - STORE(s.sendConn->opCountLoc, opCount+1); + STORE(&sendConn->step, sendConnHead); + STORE(sendConn->opCountLoc, opCount+1); __threadfence_block(); } } - inline __device__ void init(int* recvPeers, int* sendPeers, struct ncclChannel* channel) { + public: + __device__ __forceinline__ + ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) { // Make sure step is updated before we read it. barrier(); @@ -269,25 +259,6 @@ class ncclLLPrimitives { loadSendSync(); } - public: - __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) { - init(recvPeers, sendPeers, channel); - } - - __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclLLPrimitivesRecvData& r) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount), r(r) { - init(recvPeers, sendPeers, channel); - } - - __device__ __forceinline__ - ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclLLPrimitivesSendData& s) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount), s(s) { - init(recvPeers, sendPeers, channel); - } - __device__ void send(const T* src, int nelem) { return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem); }