msccl: add templated kernel (#945)
* msccl: add templated kernel * Use defines to improve code readability * Fix kernel indexing and review feedback
This commit is contained in:
@@ -105,7 +105,8 @@ function(expand_collectives FILE FUNC)
|
||||
#include \"primitives.h\"
|
||||
#include \"collectives.h\"
|
||||
#include \"devcomm.h\"
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE});")
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
|
||||
else()
|
||||
file(WRITE ${FILE_NAME}
|
||||
"#include \"${FILE}.h\"
|
||||
|
||||
@@ -128,7 +128,7 @@ for (int r = 0; r < numloops; r++) { \
|
||||
srcs[r] = srcPointer + srcOffset; \
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
template<typename T, typename RedOp, typename Proto, bool fullOps>
|
||||
__device__ __forceinline__ void mscclRunInterpreter(
|
||||
struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) {
|
||||
const int tid = threadIdx.x;
|
||||
@@ -411,13 +411,13 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
prims.reduce(srcs, numReductions, &dst, 1, thisNelem);
|
||||
}
|
||||
if (c == 0) step += (numReductions-1); // only advance step once!
|
||||
} else if (t->type == MSCCL_RECV_COPY_SEND)
|
||||
} else if (fullOps && t->type == MSCCL_RECV_COPY_SEND)
|
||||
prims.recvCopySend(dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_SEND)
|
||||
else if (fullOps && t->type == MSCCL_RECV_REDUCE_SEND)
|
||||
prims.recvReduceSend(srcOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_COPY_SEND)
|
||||
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY_SEND)
|
||||
prims.recvReduceCopySend(srcOffset, dstOffset, thisNelem);
|
||||
else if (t->type == MSCCL_RECV_REDUCE_COPY) {
|
||||
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
@@ -430,7 +430,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if (t->type == MSCCL_LOCAL_COPY)
|
||||
else if (fullOps && t->type == MSCCL_LOCAL_COPY)
|
||||
prims.localCopy(srcPointer+srcOffset, dstPointer+dstOffset, thisNelem);
|
||||
else
|
||||
return;
|
||||
@@ -458,33 +458,37 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL>(comm, algo, work); \
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL, fullOps>(comm, algo, work); \
|
||||
} \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128>(comm, algo, work); \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128, fullOps>(comm, algo, work); \
|
||||
} \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>>(comm, algo, work); \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>, fullOps>(comm, algo, work); \
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,33 +6,37 @@
|
||||
#ifndef MSCCL_KERNEL_H_
|
||||
#define MSCCL_KERNEL_H_
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto) mscclKernel_##devredop##_##type##_##proto
|
||||
#define MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps) mscclKernel_##devredop##_##type##_##proto##_##fullOps
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work);
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, proto, fullOps) \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work);
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple)
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, LL128, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE_PROTO(devredop, type, Simple, fullOps)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
|
||||
|
||||
#define MSCCL_DECL_KERNEL_ENTRY_FUNC() \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min)
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, false) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Max, true) \
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP(Min, true)
|
||||
|
||||
MSCCL_DECL_KERNEL_ENTRY_FUNC()
|
||||
|
||||
|
||||
@@ -144,6 +144,8 @@ struct mscclAlgo {
|
||||
bool inPlace;
|
||||
// Whether this algorithm is suitable for out-of-place.
|
||||
bool outOfPlace;
|
||||
// Bit mask of the transfer types used by this algorithm: bit N is set when transfer type N appears (uint8_t, so at most 8 types at present)
|
||||
uint8_t typeMask;
|
||||
};
|
||||
|
||||
enum mscclGroupStatus {
|
||||
|
||||
@@ -573,7 +573,7 @@ ncclResult_t mscclGetAlgoFromXmlFile(const char* str, struct mscclAlgo* algo, in
|
||||
mscclTran->srcOffset = srcOffset;
|
||||
mscclTran->dstBuffer = dstBufferInt;
|
||||
mscclTran->dstOffset = dstOffset;
|
||||
|
||||
algo->typeMask |= (1<<transferType);
|
||||
if (count < 0 || count >= MSCCL_MAX_COUNT){
|
||||
WARN("MSCCL: count (%d) must be positive and less than %d", count, MSCCL_MAX_COUNT);
|
||||
return ncclInternalError;
|
||||
|
||||
@@ -283,31 +283,35 @@ static ncclResult_t hostToDevRedOp(
|
||||
nullptr, \
|
||||
nullptr
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type) \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL), \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128), \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)
|
||||
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type, fullOps) \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps), \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps), \
|
||||
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop) \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16)
|
||||
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop, fullOps) \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double, fullOps), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps)
|
||||
|
||||
#define MSCCL_KERNEL_ENTRY() \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min)
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, false), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Max, true), \
|
||||
MSCCL_KERNEL_ENTRY_DEVREDOP(Min, true)
|
||||
|
||||
// Except for ncclDevPreMulSum and ncclDevSumPostDiv required by ncclAvg
|
||||
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
|
||||
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS * 2] = {
|
||||
#ifdef COMPILE_MSCCL_KERNEL
|
||||
MSCCL_KERNEL_ENTRY()
|
||||
#endif
|
||||
@@ -398,7 +402,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
work.maxAllowedCount = status.maxAllowedCount;
|
||||
work.hasReduce = hostAlgo->hasReduce;
|
||||
work.redOpArgIsPtr = opFull.scalarArgIsPtr;
|
||||
INFO(NCCL_COLL, "MSCCL: Setup Kernel finished");
|
||||
INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask);
|
||||
|
||||
uint32_t workFifoIdxMask = status.workFifoDepth - 1;
|
||||
uint32_t workFifoSent = status.workFifoSent;
|
||||
@@ -423,7 +427,15 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
|
||||
struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask);
|
||||
void *args[3] = {&comm->devComm, &devAlgo, &workPtr};
|
||||
void *func = mscclKernelEntries[(opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol];
|
||||
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
|
||||
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY) |
|
||||
(1<<MSCCL_LOCAL_COPY);
|
||||
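// If the algorithm uses any of the ops in fullOpMask, select the full-ops kernel variant, whose entries occupy the second half of mscclKernelEntries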
|
||||
if (hostAlgo->typeMask & fullOpMask) fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
|
||||
void *func = mscclKernelEntries[fnIndex];
|
||||
if (enableDoneEvent) {
|
||||
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
|
||||
} else {
|
||||
|
||||
Reference in new issue
Block a user