msccl: build same number of kernels as in ROCm 5.7 (#1005)
Removed fullOps kernels from build
Tento commit je obsažen v:
@@ -105,8 +105,7 @@ function(expand_collectives FILE FUNC)
|
||||
#include \"primitives.h\"
|
||||
#include \"collectives.h\"
|
||||
#include \"devcomm.h\"
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);")
|
||||
else()
|
||||
file(WRITE ${FILE_NAME}
|
||||
"#include \"${FILE}.h\"
|
||||
|
||||
@@ -492,7 +492,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>, fullOps>(comm, algo, work); \
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
||||
@@ -504,7 +504,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
||||
@@ -515,13 +515,9 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true)
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -39,14 +39,10 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true)
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false)
|
||||
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC()
|
||||
|
||||
|
||||
@@ -262,7 +262,7 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam*
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
param->scheduled = false;
|
||||
|
||||
|
||||
|
||||
/*// Current MSCCL doesn't support pre/post op
|
||||
if (param->op >= ncclAvg) {
|
||||
return ncclSuccess;
|
||||
|
||||
@@ -317,17 +317,13 @@ static ncclResult_t hostToDevRedOp(
|
||||
#define MSCCL_KERNEL_ENTRY() \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, true)
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, false)
|
||||
|
||||
// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg
|
||||
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = {
|
||||
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
|
||||
#ifdef COMPILE_MSCCL_KERNEL
|
||||
MSCCL_KERNEL_ENTRY()
|
||||
#endif
|
||||
@@ -415,8 +411,10 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
(1<<MSCCL_RECV_REDUCE_COPY) |
|
||||
(1<<MSCCL_LOCAL_COPY);
|
||||
//check if need full ops msccl kernel
|
||||
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
|
||||
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
|
||||
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps()) {
|
||||
WARN("MSCCL: this version of MSCCL build doesn't support fill Ops");
|
||||
return ncclInternalError;
|
||||
}
|
||||
|
||||
mscclWork work;
|
||||
work.syncFlags = status.syncFlags;
|
||||
|
||||
Odkázat v novém úkolu
Zablokovat Uživatele