diff --git a/projects/rccl/src/collectives/device/msccl_kernel.cu b/projects/rccl/src/collectives/device/msccl_kernel.cu index 7a78865983..4feae3e7e9 100644 --- a/projects/rccl/src/collectives/device/msccl_kernel.cu +++ b/projects/rccl/src/collectives/device/msccl_kernel.cu @@ -474,20 +474,10 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevCo MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16) -#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) - #define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv) + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) MSCCL_IMPL_KERNEL_ENTRY_FUNC() diff --git a/projects/rccl/src/include/msccl/msccl_kernel.h b/projects/rccl/src/include/msccl/msccl_kernel.h index 25e8b75989..0f634e0f89 100644 --- a/projects/rccl/src/include/msccl/msccl_kernel.h +++ b/projects/rccl/src/include/msccl/msccl_kernel.h @@ -28,21 +28,11 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevCom MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16) -#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) - #define MSCCL_DECL_KERNEL_ENTRY_FUNC() \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv) + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) MSCCL_DECL_KERNEL_ENTRY_FUNC() diff --git a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc index d71c1310d2..a7ebd7328c 100644 --- a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc +++ b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc @@ -189,6 +189,11 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* mscclStatus& status = mscclGetStatus(); param->scheduled = false; + // Current MSCCL doesn't support pre/post op + if (param->op >= ncclAvg) { + return ncclSuccess; + } + // Whether the algorithm is in-place bool isInPlace = false; if (param->func == mscclFuncReduce || diff --git a/projects/rccl/src/misc/msccl/msccl_setup.cc b/projects/rccl/src/misc/msccl/msccl_setup.cc index 37c2aca7ba..b815d96fde 100644 --- a/projects/rccl/src/misc/msccl/msccl_setup.cc +++ b/projects/rccl/src/misc/msccl/msccl_setup.cc @@ -235,27 +235,14 @@ static ncclResult_t hostToDevRedOp( MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \ MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16) -#define MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(devredop) \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NULL() - #define MSCCL_KERNEL_ENTRY() \ MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Min), \ MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv) + MSCCL_KERNEL_ENTRY_DEVREDOP(Min) -void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = { +// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg +void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = { MSCCL_KERNEL_ENTRY() };