* msccl: add templated kernel

* Use defines to improve code readability

* Fix kernel indexing and review feedback
Этот коммит содержится в:
Wenkai Du
2023-11-02 17:21:53 -07:00
коммит произвёл GitHub
родитель 61aed56ca7
Коммит f484ff17b9
6 изменённых файлов: 96 добавлений и 73 удалений
+2 -1
Просмотреть файл
@@ -105,7 +105,8 @@ function(expand_collectives FILE FUNC)
#include \"primitives.h\"
#include \"collectives.h\"
#include \"devcomm.h\"
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE});")
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
else()
file(WRITE ${FILE_NAME}
"#include \"${FILE}.h\"
+31 -27
Просмотреть файл
@@ -128,7 +128,7 @@ for (int r = 0; r < numloops; r++) { \
srcs[r] = srcPointer + srcOffset; \
}
template<typename T, typename RedOp, typename Proto>
template<typename T, typename RedOp, typename Proto, bool fullOps>
__device__ __forceinline__ void mscclRunInterpreter(
struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) {
const int tid = threadIdx.x;
@@ -411,13 +411,13 @@ __device__ __forceinline__ void mscclRunInterpreter(
prims.reduce(srcs, numReductions, &dst, 1, thisNelem);
}
if (c == 0) step += (numReductions-1); // only advance step once!
} else if (t->type == MSCCL_RECV_COPY_SEND)
} else if (fullOps && t->type == MSCCL_RECV_COPY_SEND)
prims.recvCopySend(dstOffset, thisNelem);
else if (t->type == MSCCL_RECV_REDUCE_SEND)
else if (fullOps && t->type == MSCCL_RECV_REDUCE_SEND)
prims.recvReduceSend(srcOffset, thisNelem);
else if (t->type == MSCCL_RECV_REDUCE_COPY_SEND)
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY_SEND)
prims.recvReduceCopySend(srcOffset, dstOffset, thisNelem);
else if (t->type == MSCCL_RECV_REDUCE_COPY) {
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
@@ -430,7 +430,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
#endif
}
else if (t->type == MSCCL_LOCAL_COPY)
else if (fullOps && t->type == MSCCL_LOCAL_COPY)
prims.localCopy(srcPointer+srcOffset, dstPointer+dstOffset, thisNelem);
else
return;
@@ -458,33 +458,37 @@ __device__ __forceinline__ void mscclRunInterpreter(
#endif
}
// Defines the three __global__ MSCCL kernel entry points for one (devredop, type)
// pair: one each for the LL, LL128 and Simple protocols. Each entry point only
// forwards (comm, algo, work) to mscclRunInterpreter instantiated with
// Func##devredop<type> and the matching Proto type.
// This listing carries two revisions of the macro interleaved: the two-parameter
// form, and the three-parameter form whose extra fullOps argument is handed to
// MSCCL_KERNEL_ENTRY_NAME and used as the last mscclRunInterpreter template argument.
// NOTE(review): only one revision's lines should survive in a compilable file --
// presumably the fullOps one; confirm.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL>(comm, algo, work); \
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL, fullOps>(comm, algo, work); \
} \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128>(comm, algo, work); \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128, fullOps>(comm, algo, work); \
} \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>>(comm, algo, work); \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>, fullOps>(comm, algo, work); \
}
// Expands MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE once per element type
// (int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t, half, float, double,
// rccl_bfloat16) for a single reduction op. The first ten invocations are the
// two-argument revision, the last ten the revision that also forwards fullOps.
// NOTE(review): the parameter list below names only `devredop`, yet the
// three-argument invocations reference `fullOps` and MSCCL_IMPL_KERNEL_ENTRY_FUNC
// invokes this macro with two arguments. It looks like the signature should be
// (devredop, fullOps), as in the MSCCL_DECL_* counterpart; confirm before using
// this macro.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
// Top-level convenience macro: expands MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP for
// each of the four reduction ops Sum, Prod, Max and Min.
// The one-argument invocations are the earlier revision; the two-argument
// invocations are the later one, which emits every op first with fullOps=false
// and then again with fullOps=true.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true)
#endif
+26 -22
Просмотреть файл
@@ -6,33 +6,37 @@
#ifndef MSCCL_KERNEL_H_
#define MSCCL_KERNEL_H_
// Builds the symbol name of an MSCCL kernel entry point by token pasting.
// First definition (earlier revision): mscclKernel_<devredop>_<type>_<proto>.
// Second definition (later revision) appends the fullOps token as a suffix,
// e.g. (Sum, float, LL, true) -> mscclKernel_Sum_float_LL_true.
#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto) mscclKernel_##devredop##_##type##_##proto
#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps) mscclKernel_##devredop##_##type##_##proto##_##fullOps
// Forward-declares one __global__ kernel entry point taking (comm, algo, work).
// The kernel name comes from MSCCL_KERNEL_ENTRY_NAME. Two revisions are listed:
// the three-parameter form and the later form that also threads fullOps through
// to the name macro.
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto) \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work);
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto, fullOps) \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work);
// Declares the kernel entry points of one (devredop, type) pair for all three
// protocols, in the order LL, LL128, Simple. Listed in two revisions: without
// and with the trailing fullOps argument.
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple)
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple, fullOps)
// Declares the kernel entry points of one reduction op for every element type:
// int8_t, uint8_t, int32_t, uint32_t, int64_t, uint64_t, half, float, double
// and rccl_bfloat16. Listed in two revisions: the one-parameter form and the
// later (devredop, fullOps) form that forwards fullOps to each per-type
// declaration.
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
// Declares every MSCCL kernel entry point: expands the per-op declaration macro
// for Sum, Prod, Max and Min. The one-argument invocations are the earlier
// revision; the two-argument invocations are the later one, which declares all
// ops with fullOps=false followed by all ops with fullOps=true.
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true)
// Emit the declarations at this point in the header.
MSCCL_DECL_KERNEL_ENTRY_FUNC()
+2
Просмотреть файл
@@ -144,6 +144,8 @@ struct mscclAlgo {
bool inPlace;
// Whether this algorithm is suitable for out-of-place.
bool outOfPlace;
// Keep a bit mask of used types (max 8 at present)
uint8_t typeMask;
};
enum mscclGroupStatus {
+1 -1
Просмотреть файл
@@ -573,7 +573,7 @@ ncclResult_t mscclGetAlgoFromXmlFile(const char* str, struct mscclAlgo* algo, in
mscclTran->srcOffset = srcOffset;
mscclTran->dstBuffer = dstBufferInt;
mscclTran->dstOffset = dstOffset;
algo->typeMask |= (1<<transferType);
if (count < 0 || count >= MSCCL_MAX_COUNT){
WARN("MSCCL: count (%d) must be positive and less than %d", count, MSCCL_MAX_COUNT);
return ncclInternalError;
+34 -22
Просмотреть файл
@@ -283,31 +283,35 @@ static ncclResult_t hostToDevRedOp(
nullptr, \
nullptr
// Expands to a comma-separated list of three kernel entry points cast to void*,
// in the order LL, LL128, Simple, for one (devredop, type) pair. Listed in two
// revisions: without and with the fullOps argument that selects the
// correspondingly named kernel variant.
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type) \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type, fullOps) \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)
// Expands to the comma-separated kernel pointers of one reduction op across all
// element types, in the order int8_t, uint8_t, int32_t, uint32_t, int64_t,
// uint64_t, half, float, double, rccl_bfloat16. Listed in two revisions: the
// one-parameter form and the later (devredop, fullOps) form.
// NOTE(review): the type order presumably has to match the data-type index used
// when the resulting table is indexed; verify against the lookup code.
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop) \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16)
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop, fullOps) \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
// Expands to the full comma-separated list of kernel pointers for the ops Sum,
// Prod, Max and Min. The one-argument invocations are the earlier revision.
// In the later revision all fullOps=false entries come first and all
// fullOps=true entries second, so the fullOps=true kernels occupy the second
// half of the list.
#define MSCCL_KERNEL_ENTRY() \
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min)
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true)
// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = {
#ifdef COMPILE_MSCCL_KERNEL
MSCCL_KERNEL_ENTRY()
#endif
@@ -398,7 +402,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
work.maxAllowedCount = status.maxAllowedCount;
work.hasReduce = hostAlgo->hasReduce;
work.redOpArgIsPtr = opFull.scalarArgIsPtr;
INFO(NCCL_COLL, "MSCCL: Setup Kernel finished");
INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask);
uint32_t workFifoIdxMask = status.workFifoDepth - 1;
uint32_t workFifoSent = status.workFifoSent;
@@ -423,7 +427,15 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask);
void *args[3] = {&comm->devComm, &devAlgo, &workPtr};
void *func = mscclKernelEntries[(opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol];
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY) |
(1<<MSCCL_LOCAL_COPY);
//check if need full ops msccl kernel
if (hostAlgo->typeMask & fullOpMask) fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
void *func = mscclKernelEntries[fnIndex];
if (enableDoneEvent) {
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
} else {