diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 109191953e..5b4c76a25c 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -105,7 +105,8 @@ function(expand_collectives FILE FUNC) #include \"primitives.h\" #include \"collectives.h\" #include \"devcomm.h\" - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE});") + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false); + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);") else() file(WRITE ${FILE_NAME} "#include \"${FILE}.h\" diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index 86fc92d93a..9699b1d029 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -128,7 +128,7 @@ for (int r = 0; r < numloops; r++) { \ srcs[r] = srcPointer + srcOffset; \ } -template +template __device__ __forceinline__ void mscclRunInterpreter( struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { const int tid = threadIdx.x; @@ -411,13 +411,13 @@ __device__ __forceinline__ void mscclRunInterpreter( prims.reduce(srcs, numReductions, &dst, 1, thisNelem); } if (c == 0) step += (numReductions-1); // only advance step once! - } else if (t->type == MSCCL_RECV_COPY_SEND) + } else if (fullOps && t->type == MSCCL_RECV_COPY_SEND) prims.recvCopySend(dstOffset, thisNelem); - else if (t->type == MSCCL_RECV_REDUCE_SEND) + else if (fullOps && t->type == MSCCL_RECV_REDUCE_SEND) prims.recvReduceSend(srcOffset, thisNelem); - else if (t->type == MSCCL_RECV_REDUCE_COPY_SEND) + else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY_SEND) prims.recvReduceCopySend(srcOffset, dstOffset, thisNelem); - else if (t->type == MSCCL_RECV_REDUCE_COPY) { + else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY) if (tid == 0) { NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); @@ -430,7 +430,7 @@ __device__ __forceinline__ void mscclRunInterpreter( } #endif } - else if (t->type == MSCCL_LOCAL_COPY) + else if (fullOps && t->type == MSCCL_LOCAL_COPY) prims.localCopy(srcPointer+srcOffset, dstPointer+dstOffset, thisNelem); else return; @@ -458,33 +458,37 @@ __device__ __forceinline__ void mscclRunInterpreter( #endif } -#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \ -__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ - mscclRunInterpreter, ProtoLL>(comm, algo, work); \ +#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \ +__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ + mscclRunInterpreter, ProtoLL, fullOps>(comm, algo, work); \ } \ -__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ - mscclRunInterpreter, ProtoLL128>(comm, algo, work); \ +__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ + mscclRunInterpreter, ProtoLL128, fullOps>(comm, algo, work); \ } \ -__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ - mscclRunInterpreter, ProtoSimple>(comm, algo, work); \ +__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ + mscclRunInterpreter, ProtoSimple, fullOps>(comm, algo, work); \ } #define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16) + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) #define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) #endif diff --git a/src/include/msccl/msccl_kernel.h b/src/include/msccl/msccl_kernel.h index 647183431b..5e65bcf103 100644 --- a/src/include/msccl/msccl_kernel.h +++ b/src/include/msccl/msccl_kernel.h @@ -6,33 +6,37 @@ #ifndef MSCCL_KERNEL_H_ #define MSCCL_KERNEL_H_ -#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto) mscclKernel_##devredop##_##type##_##proto +#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps) mscclKernel_##devredop##_##type##_##proto##_##fullOps -#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto) \ -__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work); +#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto, fullOps) \ +__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work); -#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple) +#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple, fullOps) -#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16) +#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) #define MSCCL_DECL_KERNEL_ENTRY_FUNC() \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min) + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true) MSCCL_DECL_KERNEL_ENTRY_FUNC() diff --git a/src/include/msccl/msccl_struct.h b/src/include/msccl/msccl_struct.h index 5daaaa883e..73fc5c5f85 100644 --- a/src/include/msccl/msccl_struct.h +++ b/src/include/msccl/msccl_struct.h @@ -144,6 +144,8 @@ struct mscclAlgo { bool inPlace; // Whether this algorithm is suitable for out-of-place. bool outOfPlace; + // Keep a bit mask of used types (max 8 at present) + uint8_t typeMask; }; enum mscclGroupStatus { diff --git a/src/misc/msccl/msccl_parser.cc b/src/misc/msccl/msccl_parser.cc index e810682a12..317ee74792 100644 --- a/src/misc/msccl/msccl_parser.cc +++ b/src/misc/msccl/msccl_parser.cc @@ -573,7 +573,7 @@ ncclResult_t mscclGetAlgoFromXmlFile(const char* str, struct mscclAlgo* algo, in mscclTran->srcOffset = srcOffset; mscclTran->dstBuffer = dstBufferInt; mscclTran->dstOffset = dstOffset; - + algo->typeMask |= (1<= MSCCL_MAX_COUNT){ WARN("MSCCL: count (%d) must be positive and less than %d", count, MSCCL_MAX_COUNT); return ncclInternalError; diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index b683805a98..0b6c997ef2 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -283,31 +283,35 @@ static ncclResult_t hostToDevRedOp( nullptr, \ nullptr -#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type) \ - (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL), \ - (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128), \ - (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple) +#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type, fullOps) \ + (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps), \ + (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps), \ + (void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps) -#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop) \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16) +#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop, fullOps) \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) #define MSCCL_KERNEL_ENTRY() \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \ - MSCCL_KERNEL_ENTRY_DEVREDOP(Min) + MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \ + MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true) // Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg -void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = { +void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = { #ifdef COMPILE_MSCCL_KERNEL MSCCL_KERNEL_ENTRY() #endif @@ -398,7 +402,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count work.maxAllowedCount = status.maxAllowedCount; work.hasReduce = hostAlgo->hasReduce; work.redOpArgIsPtr = opFull.scalarArgIsPtr; - INFO(NCCL_COLL, "MSCCL: Setup Kernel finished"); + INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask); uint32_t workFifoIdxMask = status.workFifoDepth - 1; uint32_t workFifoSent = status.workFifoSent; @@ -423,7 +427,15 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask); void *args[3] = {&comm->devComm, &devAlgo, &workPtr}; - void *func = mscclKernelEntries[(opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol]; + uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol; + uint8_t fullOpMask = (1<typeMask & fullOpMask) fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2; + void *func = mscclKernelEntries[fnIndex]; if (enableDoneEvent) { CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0)); } else {