@@ -474,20 +474,10 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevCo
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv)
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
|
||||
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC()
|
||||
|
||||
@@ -28,21 +28,11 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevCom
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(PreMulSum) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(SumPostDiv)
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
|
||||
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC()
|
||||
|
||||
|
||||
@@ -189,6 +189,11 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam*
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
param->scheduled = false;
|
||||
|
||||
// Current MSCCL doesn't support pre/post op
|
||||
if (param->op >= ncclAvg) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Whether the algorithm is in-place
|
||||
bool isInPlace = false;
|
||||
if (param->func == mscclFuncReduce ||
|
||||
|
||||
@@ -235,27 +235,14 @@ static ncclResult_t hostToDevRedOp(
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(devredop) \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NULL(), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NULL()
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY() \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(PreMulSum), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_NOFLOAT(SumPostDiv)
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min)
|
||||
|
||||
void* mscclKernelEntries[ncclNumDevRedOps * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
|
||||
// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg
|
||||
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
|
||||
MSCCL_KERNEL_ENTRY()
|
||||
};
|
||||
|
||||
|
||||
Αναφορά σε νέο ζήτημα
Block a user