/*
 * Reviewer summary of the patch below (a unified diff touching five headers
 * under src/device/). Everything after this comment is the patch, unchanged.
 *
 * msccl_kernel_impl.h (mscclRunInterpreter hunks)
 *   - The `prims` declaration gains two trailing template arguments:
 *     `0, RCCL_METADATA_MSCCL`.
 *   - The call `prims.mscclSend(srcOffset, thisNelem)` becomes
 *     `prims.send(srcOffset, thisNelem)`.
 *
 * prims_ll.h
 *   - Removes mscclStoreData; its three call sites now call storeData with the
 *     same arguments. storeData itself is not part of this patch.
 *   - Adds a block compiled only when __gfx950__ is defined: if
 *     `isMsccl(Metadata) && DST` holds at compile time it issues
 *     `__builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4))`, which the patch
 *     annotates as `s_waitcnt vmcnt(0)` (wait for pending vector loads/stores).
 *   - Removes mscclStoreLL (two __builtin_nontemporal_store writes per FIFO
 *     line) and mscclSend. The removed mscclSend had a __gfx950__-only body
 *     built on mscclStoreLL and otherwise forwarded to
 *     LLGenericOp<0, 1, Input, -1>.
 *
 * prims_ll128.h / prims_simple.h
 *   - Remove their mscclSend wrappers, which forwarded to
 *     GenericOp<0, 1, Input, -1> and genericOp<0, 0, 0, 1, Input, -1>.
 *
 * rccl_metadata.h
 *   - Adds RCCL_METADATA_MSCCL (value 1), a constexpr isMsccl(int) that tests
 *     that bit, and two static_asserts covering RCCL_METADATA_MSCCL and
 *     RCCL_METADATA_EMPTY.
 *
 * NOTE(review): some template argument lists look truncated in this copy
 *   (`Primitives, 1, Proto, 0>` and `template __device__ void LLGenericOp`);
 *   presumably lost during extraction - confirm against the upstream patch
 *   before applying.
 * NOTE(review): the new wait is guarded by `DST`. If DST is derived from DstBuf
 *   the same way SRC is derived from SrcBuf, an operation with DstBuf == -1
 *   (such as the removed LLGenericOp<0, 1, Input, -1> send path) would not
 *   execute it - confirm this is intended.
 * NOTE(review): isMsccl compares the masked value with `> 0`. That is correct
 *   for bit 0; `!= 0` would stay correct if a flag ever used the sign bit.
 */
diff --git a/src/device/msccl_kernel_impl.h b/src/device/msccl_kernel_impl.h index 9eabe54aa4..d2419b4f15 100644 --- a/src/device/msccl_kernel_impl.h +++ b/src/device/msccl_kernel_impl.h @@ -194,7 +194,7 @@ __device__ __forceinline__ void mscclRunInterpreter( } RedOp redFn(mscclShmem.work.redOpArg); - Primitives, 1, Proto, 0> prims + Primitives, 1, Proto, 0, 0, RCCL_METADATA_MSCCL> prims (tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg); #if defined(ENABLE_NPKIT) @@ -278,7 +278,7 @@ __device__ __forceinline__ void mscclRunInterpreter( } #endif - prims.mscclSend(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end. + prims.send(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end. #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT) if (tid == 0) { diff --git a/src/device/prims_ll.h b/src/device/prims_ll.h index 620cd9e75f..703860a642 100644 --- a/src/device/prims_ll.h +++ b/src/device/prims_ll.h @@ -413,20 +413,6 @@ private: } } - __device__ void mscclStoreData(T *dst, uint64_t val, int eltN) { - union { - uint64_t u8; - T elt[EltPerLine]; - }; - u8 = val; - #pragma unroll - for(int i=0; i < EltPerLine; i++) { - if (i==0 || i < eltN) - store(dst+i, elt[i]); - // dst[i] = elt[i]; - } - } - template __device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { constexpr int SRC = SrcBuf != -1 ? 
1 : 0; @@ -514,6 +500,12 @@ private: nelem -= eltPerTrip; offset += nthreads; } + #ifdef __gfx950__ + if constexpr (isMsccl(Metadata) && DST){ + // Wait for pending vector loads and stores + __builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4)); // s_waitcnt vmcnt(0) + } + #endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) if (tid == 0) { @@ -590,18 +582,18 @@ private: dataD = applyReduce(redOp, dataD, data); } } - mscclStoreData(dstElts, dataD, eltInLine); + storeData(dstElts, dataD, eltInLine); dstElts += eltPerTrip; } if (COPY){ - mscclStoreData(dstElts, data, eltInLine); + storeData(dstElts, data, eltInLine); dstElts += eltPerTrip; if (MULTIDSTS){ for (int i = 1; i < ndsts; i++){ dl.loadBegin(srcs[i], eltInLine); srcs[i] += eltPerTrip; data = dl.loadFinish(); - mscclStoreData(dsts[i], data, eltInLine); + storeData(dsts[i], data, eltInLine); dsts[i] += eltPerTrip; } } @@ -835,51 +827,4 @@ public: __device__ void localCopy(T* srcs, T* dsts, int eltN) { return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN); } - - __device__ void mscclStoreLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) { - union ncclLLFifoLine i4; - i4.data1 = val & 0xffffffff; - i4.flag1 = flag; - i4.data2 = (val >> 32); - i4.flag2 = flag; - __builtin_nontemporal_store(i4.v[0], dst->v); - __builtin_nontemporal_store(i4.v[1], dst->v+1); - } - - __device__ void mscclSend(intptr_t srcIx, int nelem) { -#if defined(__gfx950__) - T *srcElts = userBufs[0] + srcIx; - - // Always waitSend in case of cleanup - nelem = nelem < 0 ? 0 : nelem; - waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine)); - - nelem -= tid*EltPerLine; - srcElts += tid*EltPerLine; - int offset = tid; - int eltPerTrip = nthreads*EltPerLine; - while (nelem > 0) { - int eltInLine = EltPerLine < nelem ? 
EltPerLine : nelem; - - DataLoader dl; - // ncclLLFifoLine line[MaxRecv];//unused variable - compiler warning - uint64_t data /*peerData*/; //unused variable - compiler warning - dl.loadBegin(srcElts, eltInLine); - srcElts += eltPerTrip; - data = dl.loadFinish(); - - for (int i=1; i < MaxSend && i < fan.nsend(); i++) - mscclStoreLL(sendPtr(i)+offset, data, sendFlag(i)); - mscclStoreLL(sendPtr(0)+offset, data, sendFlag(0)); - nelem -= eltPerTrip; - offset += nthreads; - } - - for (int i=1; i < MaxSend && i < fan.nsend(); i++) - incSend(i, offset); - incSend(0, offset); -#else - LLGenericOp<0, 1, Input, -1>(srcIx, -1, nelem, false); -#endif - } }; diff --git a/src/device/prims_ll128.h b/src/device/prims_ll128.h index 3d26fc58e5..e1a8943658 100644 --- a/src/device/prims_ll128.h +++ b/src/device/prims_ll128.h @@ -752,7 +752,4 @@ public: __device__ void localCopy(T* srcs, T* dsts, int eltN) { return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN); } - __device__ void mscclSend(intptr_t inpIx, int eltN) { - return GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false); - } }; diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index 1aceac87de..fb00bd8e72 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -1337,7 +1337,4 @@ public: __device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) { return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN); } - __device__ __forceinline__ void mscclSend(intptr_t inpIx, int eltN) { - genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false); - } }; diff --git a/src/device/rccl_metadata.h b/src/device/rccl_metadata.h index 4875137a61..be278f617d 100644 --- a/src/device/rccl_metadata.h +++ b/src/device/rccl_metadata.h @@ -24,3 +24,12 @@ THE SOFTWARE. /* This file implements methods to extract metadata from an integer Metadata field passed in as a template parameter. 
Feel free to add additional fields below.*/ #define RCCL_METADATA_EMPTY 0 +#define RCCL_METADATA_MSCCL 1 + +constexpr bool isMsccl(int metadata){ + return (metadata & RCCL_METADATA_MSCCL) > 0; +} + +static_assert(isMsccl(RCCL_METADATA_MSCCL), "RCCL metadata value error"); +static_assert(!isMsccl(RCCL_METADATA_EMPTY), "RCCL metadata value error"); +