diff --git a/src/device/common.cu b/src/device/common.cu index 0022b6c233..ae15e06598 100644 --- a/src/device/common.cu +++ b/src/device/common.cu @@ -18,11 +18,17 @@ struct RunWorkNop { }; __launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) { - ncclKernelMain<-1, RunWorkNop, false>(comm, channelMask, workHead); + ncclKernelMain<-1, RunWorkNop, false, 2>(comm, channelMask, workHead); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) { + ncclKernelMain<-1, RunWorkNop, false, 4>(comm, channelMask, workHead); } #ifdef ENABLE_COLLTRACE __launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) { - ncclKernelMain<-1, RunWorkNop, true>(comm, channelMask, workHead); + ncclKernelMain<-1, RunWorkNop, true, 2>(comm, channelMask, workHead); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) { + ncclKernelMain<-1, RunWorkNop, true, 4>(comm, channelMask, workHead); } #endif diff --git a/src/device/common.h b/src/device/common.h index e9f0bed8b3..44e628b57c 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -227,7 +227,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we } } -template +template __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) { const int tid = threadIdx.x; int x = tid; @@ -346,9 +346,15 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct SpecializedRunWork().run(&ncclShmem.work); } else { #ifdef USE_INDIRECT_FUNCTION_CALL - 
ncclDevFuncTable[ncclShmem.work.header.funcIndex](); + if (COLL_UNROLL == 4) + ncclDevFuncTable_4[ncclShmem.work.header.funcIndex](); + else + ncclDevFuncTable[ncclShmem.work.header.funcIndex](); #else - NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex); + if (COLL_UNROLL == 4) + NCCL_CALL_FUNCTIONS_4(ncclShmem.work.header.funcIndex); + else + NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex); #endif } @@ -379,8 +385,10 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct } __global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead); +__global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead); #ifdef ENABLE_COLLTRACE __global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead); +__global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead); #endif #ifdef USE_INDIRECT_FUNCTION_CALL diff --git a/src/device/generate.py b/src/device/generate.py index 892e4b2a7f..6fa02c44dd 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -162,7 +162,7 @@ def calc_unroll_for_local_arch(): # Homogeneous system is required to build for only 1 varient of unroll factor if len(gfx_targets) == 1: gfx_name, cu_count = gfx_targets[0] - if ("gfx908" == gfx_name or "gfx94" in gfx_name) and cu_count > 80: + if "gfx908" == gfx_name or ("gfx94" in gfx_name and cu_count > 80): return 2 else: return 4 @@ -258,12 +258,12 @@ def equivalent_primary(coll, algo, proto, redop, ty, unroll): # Order rows are enumerated must match formula of `ncclDevFuncId()`: def enumerate_func_rows(): - for coll in all_colls: - for algo in all_algos: - for proto in all_protos: - for redop in all_redops: - for ty in all_tys: - for unroll in all_unroll: + for unroll in all_unroll: + for coll in all_colls: + 
for algo in all_algos: + for proto in all_protos: + for redop in all_redops: + for ty in all_tys: if func_validate(coll, algo, proto, redop, ty): yield (coll, algo, proto, redop, ty, unroll) @@ -272,12 +272,12 @@ def custom_sort_key(fn): coll, algo, proto, redop, ty, unroll = fn return ( + all_unroll.index(unroll), all_colls.index(coll), all_algos.index(algo), all_protos.index(proto), all_redops.index(redop), - all_tys.index(ty), - all_unroll.index(unroll) + all_tys.index(ty) ) ################################################################################ @@ -319,6 +319,8 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n") index = 0 for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if unroll != "2": continue sym = paste("_", "ncclDevFunc", *fn) if fn[2] == "LL128": out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") @@ -331,6 +333,23 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: index += 1 out("nullptr};\n") out("\n") + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") + index4 = 0 + for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if unroll != "4": continue + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (index4, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (index4, sym)) + index4 += 1 + out("nullptr};\n") + out("\n") if not is_ifc: out("template\n" @@ -351,6 +370,24 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n") out(f" Caller<0, {index}>::call(funcIndex);\n") out("}\n\n") + out("template\n" + "struct Caller4 {\n" + " static __forceinline__ __device__ __host__\n" + 
" void call4(unsigned short funcIndex) noexcept\n" + " {\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" + " }\n" + "};\n" + "\n" + "template\n" + "struct Caller4{\n" + " static __forceinline__ __device__ __host__\n" + " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") + out(f" Caller4<0, {index4}>::call4(funcIndex);\n") + out("}\n\n") # Generate /device_table.cpp if is_colltrace: @@ -363,7 +400,8 @@ if is_colltrace: out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") for fn in primary_funcs: - out(' "%s",\n' % paste("_", "ncclDevFunc", *fn)) + if fn[5] == "4": continue + out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1])) for ty in all_tys: out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n') out("};\n") @@ -379,11 +417,11 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # The mapping from function rows to valid primary function ids. out("extern int const ncclDevFuncRowToId[] = {\n") index = 0 - for fn in func_rows: + for fn in func_rows[:len(func_rows)//2]: fn_id, comment = -1, "" if fn is not None: fn_id = primary_to_index[equivalent_primary(*fn)] - comment = " // " + paste(" ", *fn) + comment = " // " + paste(" ", *fn[:-1]) out("/*%4d*/ %d,%s\n" % (index, fn_id, comment)) index += 1 out(f"{index}") diff --git a/src/enqueue.cc b/src/enqueue.cc index b917acb892..769389b813 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -31,15 +31,18 @@ struct ncclKernelMatch { }; #ifdef ENABLE_COLLTRACE -#define ncclGetKernelIndex(p_comm) ((p_comm)->collTraceThread ? 1 : 0) -static ncclKernelMatch const ncclKerns[2] = { +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceThread ? 
2 : 0)) /* unroll picks the kernel within a pair; a non-zero collTraceThread offsets by 2 into the Debug pair */ +static ncclKernelMatch const ncclKerns[4] = { {(void *)ncclDevKernel_Generic, true}, + {(void *)ncclDevKernel_Generic_4, true}, {(void *)ncclDevKernelDebug_Generic, true}, + {(void *)ncclDevKernelDebug_Generic_4, true} }; #else -#define ncclGetKernelIndex(p_comm) (0) -static ncclKernelMatch const ncclKerns[1] = { - {(void*)ncclDevKernel_Generic, true} +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll) +static ncclKernelMatch const ncclKerns[2] = { + {(void*)ncclDevKernel_Generic, true}, + {(void*)ncclDevKernel_Generic_4, true} }; #endif @@ -53,21 +56,11 @@ static ncclResult_t initCollProxyOp(struct ncclInfo* collInfo, int channelId, ui static ncclResult_t getTunerInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps); static ncclResult_t topoGetAlgoInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps); static ncclResult_t getChannnelThreadInfo(struct ncclInfo* collInfo); -static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll); +static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo); static ncclResult_t getPatternInfo(struct ncclInfo* collInfo); static ncclResult_t getLoopInfo(struct ncclInfo* collInfo); static ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport); -int getUnrollFactor(struct ncclComm* comm) { - hipDeviceProp_t devProp; - CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); - if(IsArchMatch(devProp.gcnArchName, "gfx908") || (IsArchMatch(devProp.gcnArchName, "gfx94") - && devProp.multiProcessorCount > 80)) - return NCCL_UNROLL_2; - else - return NCCL_UNROLL_4; -} - // Returns maximum kernel stack size of all CUDA kernels ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) { constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]); @@ -186,7 +179,7 @@ static ncclResult_t appendWorkElemP2p( struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId, struct ncclWorkElemP2p const 
*elem, bool fuseOk ) { - int funcIndex = ncclDevFuncId_P2p(plan->unroll); + int funcIndex = ncclDevFuncId_P2p(); if (funcIndex < 0) { WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__); return ncclInvalidUsage; @@ -843,7 +836,7 @@ static ncclResult_t scheduleCollTasksToPlan( NCCLCHECK(getTunerInfo(aggInfo, collNetSupport, nvlsSupport, 1)); NCCLCHECK(topoGetAlgoInfo(aggInfo, collNetSupport, nvlsSupport, 1)); NCCLCHECK(getChannnelThreadInfo(aggInfo)); - NCCLCHECK(computeCollWorkFunc(aggInfo, plan->unroll)); + NCCLCHECK(computeCollWorkFunc(aggInfo)); NCCLCHECK(getPatternInfo(aggInfo)); // Try to assign algo and proto to all possible collectives @@ -1322,7 +1315,6 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { plan->comm = comm; plan->reclaimer.fn = reclaimPlan; plan->persistent = persistent; - plan->unroll = getUnrollFactor(comm); // Non-persistent kernels fill up at most half of our fifo per kernel. int nWorkBudget = plan->persistent ? INT_MAX : comm->workFifoDepth/2; @@ -1755,8 +1747,8 @@ static ncclResult_t getPatternInfo(struct ncclInfo* collInfo) { RCCL_PARAM(IntraNetThreshold, "INTRANET_THRESHOLD", 8388608); -static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll) { - collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol, unroll); +static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo) { + collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol); if (collInfo->workFuncIndex < 0) { WARN("%s: unsupported collective. 
Please ensure the collective has been enabled in build.", __func__); return ncclInvalidUsage; diff --git a/src/include/comm.h b/src/include/comm.h index ccc5572a6f..be3422ec47 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -224,9 +224,6 @@ struct ncclKernelPlan { struct ncclIntruQueue proxyOpQueue; } channels[MAXCHANNELS]; size_t maxBytesPerChannel; - - // Unroll factor for plan [RCCL] - int unroll; }; #define NCCL_MAGIC 0x0280028002800280 // Nickel atomic number is 28. @@ -434,6 +431,9 @@ struct ncclComm { // buffer registration cache struct ncclRegCache regCache; uint64_t endMagic; + + // Unroll factor for plan [RCCL] + int unroll; }; enum ncclLaunchMode { diff --git a/src/include/device.h b/src/include/device.h index f808146b5a..b30a4c23c4 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -552,64 +552,58 @@ inline bool ncclNvlsSupported(int devRedOp, int type) { // Map the rowIdx to funcIdx extern int const ncclDevFuncRowToId[]; -// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py -inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int unroll) { +// `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py +inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) { int row = 0; do { // RING / / Sum / int8_t if (coll == ncclFuncAllGather) { - row += proto * NCCL_NUM_UNROLLS + unroll; + row += proto; break; } - row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS; + row += NCCL_NUM_PROTOCOLS; // / / / if (coll == ncclFuncAllReduce) { - row += ((((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto) * NCCL_NUM_UNROLLS; + row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto); break; } - row += NCCL_NUM_UNROLLS * (NCCL_NUM_ALGORITHMS - 4) * 
NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); + row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); // RING / SIMPLE / Sum / int8_t - if (coll == ncclFuncAllToAllPivot) { - row += unroll; - break; - } - row += NCCL_NUM_UNROLLS; + if (coll == ncclFuncAllToAllPivot) break; + row += 1; // RING / / Sum / int8_t if (coll == ncclFuncBroadcast) { - row += proto * NCCL_NUM_UNROLLS + unroll; + row += proto; break; } - row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS; + row += NCCL_NUM_PROTOCOLS; // RING / / / if (coll == ncclFuncReduce) { - row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS; + row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto; break; } - row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); + row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); // RING / / / if (coll == ncclFuncReduceScatter) { - row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS; + row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto; break; } - row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); + row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); // RING / SIMPLE / Sum / int8_t - if (coll == ncclFuncSendRecv) { - row += unroll; - break; - } - row += NCCL_NUM_UNROLLS; + if (coll == ncclFuncSendRecv) break; + row += 1; } while (false); return ncclDevFuncRowToId[row]; } -inline int ncclDevFuncId_P2p(int unroll) { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - (unroll > 0 ? 
0 : 1) - 1]; } +inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; } #endif diff --git a/src/include/nccl_common.h b/src/include/nccl_common.h index 67ff0959ee..ee2563fac3 100644 --- a/src/include/nccl_common.h +++ b/src/include/nccl_common.h @@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); #define NCCL_NUM_ONERANK 12 -#define FUNC_INDEX_TOTAL 1312 + NCCL_NUM_ONERANK +#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK #define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now typedef enum { diff --git a/src/init.cc b/src/init.cc index 9c93d8ac51..738f756beb 100644 --- a/src/init.cc +++ b/src/init.cc @@ -93,6 +93,17 @@ static uint64_t hashUniqueId(ncclUniqueId const &id) { return h; } +ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { + hipDeviceProp_t devProp; + CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); + if(IsArchMatch(devProp.gcnArchName, "gfx908") || (IsArchMatch(devProp.gcnArchName, "gfx94") + && devProp.multiProcessorCount > 80)) + comm->unroll = NCCL_UNROLL_2; + else + comm->unroll = NCCL_UNROLL_4; + return ncclSuccess; +} + #ifdef ENABLE_MSCCLPP size_t std::hash::operator ()(const ncclUniqueId& uniqueId) const noexcept { return (size_t)hashUniqueId(uniqueId); @@ -559,6 +570,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in // RCCL: create persistent stream for calloc CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking)); + // RCCL: determine and set unroll factor for comm + NCCLCHECK(commSetUnrollFactor(comm)); comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;