From 12c08fc52a588b35cbb715cdd8a47be119e021a9 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:36:04 -0800 Subject: [PATCH] msccl: build same number of kernels as in ROCm 5.7 (#1005) Removed fullOps kernels from build --- cmake/Dependencies.cmake | 3 +-- src/collectives/device/msccl_kernel_impl.h | 14 +++++--------- src/include/msccl/msccl_kernel.h | 10 +++------- src/misc/msccl/msccl_lifecycle.cc | 2 +- src/misc/msccl/msccl_setup.cc | 18 ++++++++---------- 5 files changed, 18 insertions(+), 29 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 25862b3acd..3d50c63615 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -105,8 +105,7 @@ function(expand_collectives FILE FUNC) #include \"primitives.h\" #include \"collectives.h\" #include \"devcomm.h\" - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false); - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);") + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);") else() file(WRITE ${FILE_NAME} "#include \"${FILE}.h\" diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index cd21e70965..0c537349ba 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -492,7 +492,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct mscclRunInterpreter, ProtoSimple, fullOps>(comm, algo, work); \ } -#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \ +#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \ @@ -504,7 +504,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) -#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \ +#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \ @@ -515,13 +515,9 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct #define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true) + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false) #endif diff --git a/src/include/msccl/msccl_kernel.h b/src/include/msccl/msccl_kernel.h index d8519b7db4..49b0e23def 100644 --- a/src/include/msccl/msccl_kernel.h +++ b/src/include/msccl/msccl_kernel.h @@ -39,14 +39,10 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n #define MSCCL_DECL_KERNEL_ENTRY_FUNC() \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, true) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, true) + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv, false) MSCCL_DECL_KERNEL_ENTRY_FUNC() diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index 932fa2efed..3085403249 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -262,7 +262,7 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* mscclStatus& status = mscclGetStatus(); param->scheduled = false; - + /*// Current MSCCL doesn't support pre/post op if (param->op >= ncclAvg) { return ncclSuccess; diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index d999cb1804..3e49aeac2e 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -317,17 +317,13 @@ static ncclResult_t hostToDevRedOp( #define MSCCL_KERNEL_ENTRY() \ MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, true), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, true) + MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv, false) // Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg -void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = { +void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = { #ifdef COMPILE_MSCCL_KERNEL MSCCL_KERNEL_ENTRY() #endif @@ -415,8 +411,10 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count (1<typeMask & fullOpMask) || rcclParamMscclForceFullOps()) - fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2; + if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps()) { + WARN("MSCCL: this version of MSCCL build doesn't support fill Ops"); + return ncclInternalError; + } mscclWork work; work.syncFlags = status.syncFlags;