MSCCL: Unland PR1788 + Fix for MSCCL Data Corruption (#1960)
- Earlier fix PR1788 is no longer necessary after ROCr fix and pre-ROCr fix workaround - Inserts an s_waitcnt vmcnt(0), which fixes a data corruption issue in MSCCL
Este cometimento está contido em:
cometido por
GitHub
ascendente
fedddb452c
cometimento
154350baaf
@@ -194,7 +194,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
|
||||
RedOp redFn(mscclShmem.work.redOpArg);
|
||||
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0, 0, RCCL_METADATA_MSCCL> prims
|
||||
(tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -278,7 +278,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
|
||||
#endif
|
||||
prims.mscclSend(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
|
||||
prims.send(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT)
|
||||
if (tid == 0) {
|
||||
|
||||
@@ -413,20 +413,6 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void mscclStoreData(T *dst, uint64_t val, int eltN) {
|
||||
union {
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
};
|
||||
u8 = val;
|
||||
#pragma unroll
|
||||
for(int i=0; i < EltPerLine; i++) {
|
||||
if (i==0 || i < eltN)
|
||||
store(dst+i, elt[i]);
|
||||
// dst[i] = elt[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
@@ -514,6 +500,12 @@ private:
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
#ifdef __gfx950__
|
||||
if constexpr (isMsccl(Metadata) && DST){
|
||||
// Wait for pending vector loads and stores
|
||||
__builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4)); // s_waitcnt vmcnt(0)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
|
||||
if (tid == 0) {
|
||||
@@ -590,18 +582,18 @@ private:
|
||||
dataD = applyReduce(redOp, dataD, data);
|
||||
}
|
||||
}
|
||||
mscclStoreData(dstElts, dataD, eltInLine);
|
||||
storeData(dstElts, dataD, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
}
|
||||
if (COPY){
|
||||
mscclStoreData(dstElts, data, eltInLine);
|
||||
storeData(dstElts, data, eltInLine);
|
||||
dstElts += eltPerTrip;
|
||||
if (MULTIDSTS){
|
||||
for (int i = 1; i < ndsts; i++){
|
||||
dl.loadBegin(srcs[i], eltInLine);
|
||||
srcs[i] += eltPerTrip;
|
||||
data = dl.loadFinish();
|
||||
mscclStoreData(dsts[i], data, eltInLine);
|
||||
storeData(dsts[i], data, eltInLine);
|
||||
dsts[i] += eltPerTrip;
|
||||
}
|
||||
}
|
||||
@@ -835,51 +827,4 @@ public:
|
||||
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
|
||||
return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
|
||||
}
|
||||
|
||||
__device__ void mscclStoreLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) {
|
||||
union ncclLLFifoLine i4;
|
||||
i4.data1 = val & 0xffffffff;
|
||||
i4.flag1 = flag;
|
||||
i4.data2 = (val >> 32);
|
||||
i4.flag2 = flag;
|
||||
__builtin_nontemporal_store(i4.v[0], dst->v);
|
||||
__builtin_nontemporal_store(i4.v[1], dst->v+1);
|
||||
}
|
||||
|
||||
__device__ void mscclSend(intptr_t srcIx, int nelem) {
|
||||
#if defined(__gfx950__)
|
||||
T *srcElts = userBufs[0] + srcIx;
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
waitSend(divUp(nelem, EltPerLine)*sizeof(ncclLLFifoLine));
|
||||
|
||||
nelem -= tid*EltPerLine;
|
||||
srcElts += tid*EltPerLine;
|
||||
int offset = tid;
|
||||
int eltPerTrip = nthreads*EltPerLine;
|
||||
while (nelem > 0) {
|
||||
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
|
||||
|
||||
DataLoader dl;
|
||||
// ncclLLFifoLine line[MaxRecv];//unused variable - compiler warning
|
||||
uint64_t data /*peerData*/; //unused variable - compiler warning
|
||||
dl.loadBegin(srcElts, eltInLine);
|
||||
srcElts += eltPerTrip;
|
||||
data = dl.loadFinish();
|
||||
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
mscclStoreLL(sendPtr(i)+offset, data, sendFlag(i));
|
||||
mscclStoreLL(sendPtr(0)+offset, data, sendFlag(0));
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
|
||||
for (int i=1; i < MaxSend && i < fan.nsend(); i++)
|
||||
incSend(i, offset);
|
||||
incSend(0, offset);
|
||||
#else
|
||||
LLGenericOp<0, 1, Input, -1>(srcIx, -1, nelem, false);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -752,7 +752,4 @@ public:
|
||||
__device__ void localCopy(T* srcs, T* dsts, int eltN) {
|
||||
return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
|
||||
}
|
||||
__device__ void mscclSend(intptr_t inpIx, int eltN) {
|
||||
return GenericOp<0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1337,7 +1337,4 @@ public:
|
||||
__device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) {
|
||||
return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
|
||||
}
|
||||
__device__ __forceinline__ void mscclSend(intptr_t inpIx, int eltN) {
|
||||
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -24,3 +24,12 @@ THE SOFTWARE.
|
||||
/* This file implements methods to extract metadata from an integer Metadata field passed in as a template parameter. Feel free to add additional fields below.*/
|
||||
|
||||
#define RCCL_METADATA_EMPTY 0
|
||||
#define RCCL_METADATA_MSCCL 1
|
||||
|
||||
constexpr bool isMsccl(int metadata){
|
||||
return (metadata & RCCL_METADATA_MSCCL) > 0;
|
||||
}
|
||||
|
||||
static_assert(isMsccl(RCCL_METADATA_MSCCL), "RCCL metadata value error");
|
||||
static_assert(!isMsccl(RCCL_METADATA_EMPTY), "RCCL metadata value error");
|
||||
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador