2
0

MSCCL: Unland PR1788 + Fix for MSCCL Data Corruption (#1960)

- Earlier fix PR1788 is no longer necessary after ROCr fix and pre-ROCr fix workaround
- Inserts an s_waitcnt vmcnt(0), which fixes a data corruption issue in MSCCL
Este cometimento está contido em:
alex-breslow-amd
2025-10-15 10:32:25 -07:00
cometido por GitHub
ascendente fedddb452c
cometimento 154350baaf
5 ficheiros modificados com 20 adições e 72 eliminações
+2 -2
Ver ficheiro
@@ -194,7 +194,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
RedOp redFn(mscclShmem.work.redOpArg);
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0> prims
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0, 0, RCCL_METADATA_MSCCL> prims
(tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg);
#if defined(ENABLE_NPKIT)
@@ -278,7 +278,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
#endif
prims.mscclSend(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
prims.send(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT)
if (tid == 0) {
+9 -64
Ver ficheiro
@@ -413,20 +413,6 @@ private:
}
}
// Unpacks the 64-bit word `val` into EltPerLine elements of type T and writes
// them to consecutive slots starting at `dst` through store().
//   dst  - destination of the first element
//   val  - packed 64-bit payload, reinterpreted as T[EltPerLine]
//   eltN - number of elements to write; slot 0 is written regardless of eltN
__device__ void mscclStoreData(T *dst, uint64_t val, int eltN) {
  // Anonymous union: view the same 8 bytes either as one word or as elements.
  union {
    uint64_t packed;
    T lanes[EltPerLine];
  };
  packed = val;
  #pragma unroll
  for (int k = 0; k < EltPerLine; ++k) {
    // Slot 0 is unconditional; later slots only while inside the eltN range.
    const bool wanted = (k == 0) || (k < eltN);
    if (wanted) {
      store(dst + k, lanes[k]);
    }
  }
}
template <int RECV, int SEND, int SrcBuf, int DstBuf>
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
@@ -514,6 +500,12 @@ private:
nelem -= eltPerTrip;
offset += nthreads;
}
#ifdef __gfx950__
if constexpr (isMsccl(Metadata) && DST){
// Wait for pending vector loads and stores
__builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4)); // s_waitcnt vmcnt(0)
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
@@ -590,18 +582,18 @@ private:
dataD = applyReduce(redOp, dataD, data);
}
}
mscclStoreData(dstElts, dataD, eltInLine);
storeData(dstElts, dataD, eltInLine);
dstElts += eltPerTrip;
}
if (COPY){
mscclStoreData(dstElts, data, eltInLine);
storeData(dstElts, data, eltInLine);
dstElts += eltPerTrip;
if (MULTIDSTS){
for (int i = 1; i < ndsts; i++){
dl.loadBegin(srcs[i], eltInLine);
srcs[i] += eltPerTrip;
data = dl.loadFinish();
mscclStoreData(dsts[i], data, eltInLine);
storeData(dsts[i], data, eltInLine);
dsts[i] += eltPerTrip;
}
}
@@ -835,51 +827,4 @@ public:
// Forwards to mscclGenericOp<0,1,0,0> with exactly one source pointer and one
// destination pointer, for eltN elements.
// NOTE(review): the <0,1,0,0> template arguments presumably select a plain
// copy (no reduction) — confirm against mscclGenericOp's declaration.
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
  T** srcList = &srcs;
  T** dstList = &dsts;
  mscclGenericOp<0,1,0,0>(srcList, 1, dstList, 1, eltN);
}
// Builds one ncclLLFifoLine from a 64-bit payload and a flag, then writes it
// to `dst` as two non-temporal stores (v[0] first, v[1] second).
//   dst  - destination FIFO line
//   val  - payload; low 32 bits go to data1, high 32 bits to data2
//   flag - value written to both flag1 and flag2
__device__ void mscclStoreLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
  union ncclLLFifoLine line;
  // Payload halves.
  line.data1 = val & 0xffffffff;
  line.data2 = (val >> 32);
  // The same flag accompanies each half.
  line.flag1 = flag;
  line.flag2 = flag;
  // Store order is kept: first word, then second word.
  __builtin_nontemporal_store(line.v[0], &dst->v[0]);
  __builtin_nontemporal_store(line.v[1], &dst->v[1]);
}
// MSCCL send of `nelem` elements taken from userBufs[0] + srcIx.
//   srcIx - element offset into userBufs[0]
//   nelem - element count; negative values are clamped to 0 on the gfx950 path
// When __gfx950__ is defined, a dedicated loop loads the data and writes it to
// every send peer with mscclStoreLL; on all other targets the call is forwarded
// to LLGenericOp<0, 1, Input, -1>.
__device__ void mscclSend(intptr_t srcIx, int nelem) {
#if defined(__gfx950__)
T *srcElts = userBufs[0] + srcIx;
// Always waitSend in case of cleanup
nelem = nelem < 0 ? 0 : nelem;
// Size passed to waitSend: one ncclLLFifoLine per EltPerLine elements, rounded up.
waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine));
// Each thread starts at its own line (index tid) and advances by nthreads
// lines, i.e. nthreads*EltPerLine elements, per loop trip.
nelem -= tid*EltPerLine;
srcElts += tid*EltPerLine;
int offset = tid;
int eltPerTrip = nthreads*EltPerLine;
while (nelem > 0) {
// The last line handled by this thread may hold fewer than EltPerLine elements.
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
DataLoader dl;
// ncclLLFifoLine line[MaxRecv];//unused variable - compiler warning
uint64_t data /*peerData*/; //unused variable - compiler warning
dl.loadBegin(srcElts, eltInLine);
srcElts += eltPerTrip;
data = dl.loadFinish();
// Peers 1..nsend-1 are written first, peer 0 last.
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
mscclStoreLL(sendPtr(i)+offset, data, sendFlag(i));
mscclStoreLL(sendPtr(0)+offset, data, sendFlag(0));
nelem -= eltPerTrip;
offset += nthreads;
}
// Called for every send peer in the same order as the stores (1..nsend-1, then 0).
// NOTE(review): incSend presumably advances the per-peer send step — confirm.
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
incSend(i, offset);
incSend(0, offset);
#else
// Non-gfx950 targets: forward to LLGenericOp with dstIx = -1 and postOp = false.
LLGenericOp<0, 1, Input, -1>(srcIx, -1, nelem, false);
#endif
}
};
-3
Ver ficheiro
@@ -752,7 +752,4 @@ public:
// Forwards to mscclGenericOp<0,1,0,0> with exactly one source pointer and one
// destination pointer, for eltN elements.
// NOTE(review): the <0,1,0,0> template arguments presumably select a plain
// copy (no reduction) — confirm against mscclGenericOp's declaration.
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
  T** srcList = &srcs;
  T** dstList = &dsts;
  mscclGenericOp<0,1,0,0>(srcList, 1, dstList, 1, eltN);
}
// MSCCL send entry point: forwards to GenericOp<0, 1, Input, -1>, passing
// inpIx as the input index, -1 as the second index argument, eltN as the
// element count, and false as the last argument.
__device__ void mscclSend(intptr_t inpIx, int eltN) {
  GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
}
};
-3
Ver ficheiro
@@ -1337,7 +1337,4 @@ public:
// Forwards to mscclGenericOp<0,1,0,0> with exactly one source pointer and one
// destination pointer, for eltN elements.
// NOTE(review): the <0,1,0,0> template arguments presumably select a plain
// copy (no reduction) — confirm against mscclGenericOp's declaration.
__device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) {
  T** srcList = &srcs;
  T** dstList = &dsts;
  mscclGenericOp<0,1,0,0>(srcList, 1, dstList, 1, eltN);
}
// MSCCL send entry point: forwards to genericOp<0, 0, 0, 1, Input, -1>,
// passing inpIx, -1, eltN and false in that order.
// NOTE(review): the template arguments presumably select a send-only operation
// reading from the input buffer with no destination buffer — confirm against
// genericOp's declaration.
__device__ __forceinline__ void mscclSend(intptr_t inpIx, int eltN) {
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false);
}
};
+9
Ver ficheiro
@@ -24,3 +24,12 @@ THE SOFTWARE.
/* This file implements methods to extract metadata from an integer Metadata field passed in as a template parameter. Feel free to add additional fields below.*/
#define RCCL_METADATA_EMPTY 0
#define RCCL_METADATA_MSCCL 1

// Reports whether the RCCL_METADATA_MSCCL bit is set in `metadata`.
//   metadata - integer bit field, typically supplied as a template parameter
//   returns  - true when the MSCCL bit is present, false otherwise
constexpr bool isMsccl(int metadata){
  return (metadata & RCCL_METADATA_MSCCL) != 0;
}

// Compile-time sanity checks: the MSCCL value must be detected as MSCCL and
// the empty value must not be.
static_assert(isMsccl(RCCL_METADATA_MSCCL), "RCCL metadata value error");
static_assert(!isMsccl(RCCL_METADATA_EMPTY), "RCCL metadata value error");