msccl: build same number of kernels as in ROCm 5.7 (#1005)

Removed fullOps kernels from build
Tento commit je obsažen v:
Wenkai Du
2023-12-07 11:36:04 -08:00
odevzdal GitHub
rodič 9c3fea1751
revize 12c08fc52a
5 změnil soubory, kde provedl 18 přidání a 29 odebrání
+1 -2
Zobrazit soubor
@@ -105,8 +105,7 @@ function(expand_collectives FILE FUNC)
#include \"primitives.h\"
#include \"collectives.h\"
#include \"devcomm.h\"
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);")
else()
file(WRITE ${FILE_NAME}
"#include \"${FILE}.h\"
+5 -9
Zobrazit soubor
@@ -492,7 +492,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>, fullOps>(comm, algo, work); \
}
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
@@ -504,7 +504,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
@@ -515,13 +515,9 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true)
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false)
#endif
+3 -7
Zobrazit soubor
@@ -39,14 +39,10 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true)
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false)
MSCCL_DECL_KERNEL_ENTRY_FUNC()
+1 -1
Zobrazit soubor
@@ -262,7 +262,7 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam*
mscclStatus& status = mscclGetStatus();
param->scheduled = false;
/*// Current MSCCL doesn't support pre/post op
if (param->op >= ncclAvg) {
return ncclSuccess;
+8 -10
Zobrazit soubor
@@ -317,17 +317,13 @@ static ncclResult_t hostToDevRedOp(
#define MSCCL_KERNEL_ENTRY() \
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, true)
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, false)
// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = {
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
#ifdef COMPILE_MSCCL_KERNEL
MSCCL_KERNEL_ENTRY()
#endif
@@ -415,8 +411,10 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
(1<<MSCCL_RECV_REDUCE_COPY) |
(1<<MSCCL_LOCAL_COPY);
//check if need full ops msccl kernel
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps()) {
WARN("MSCCL: this version of MSCCL build doesn't support fill Ops");
return ncclInternalError;
}
mscclWork work;
work.syncFlags = status.syncFlags;