Template generic kernel for unroll factor (#1419)

* Template generic kernel for unroll factor
This commit is contained in:
Bertan Dogancay
2024-11-12 18:27:29 -05:00
zatwierdzone przez GitHub
rodzic 2d07f18696
commit cb175fb0b3
8 zmienionych plików z 116 dodań i 65 usunięć
+8 -2
Wyświetl plik
@@ -18,11 +18,17 @@ struct RunWorkNop {
};
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
ncclKernelMain<-1, RunWorkNop, false>(comm, channelMask, workHead);
ncclKernelMain<-1, RunWorkNop, false, 2>(comm, channelMask, workHead);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
ncclKernelMain<-1, RunWorkNop, false, 4>(comm, channelMask, workHead);
}
#ifdef ENABLE_COLLTRACE
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
ncclKernelMain<-1, RunWorkNop, true>(comm, channelMask, workHead);
ncclKernelMain<-1, RunWorkNop, true, 2>(comm, channelMask, workHead);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
ncclKernelMain<-1, RunWorkNop, true, 4>(comm, channelMask, workHead);
}
#endif
+11 -3
Wyświetl plik
@@ -227,7 +227,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we
}
}
template<int SpecializedFnId, typename SpecializedRunWork, bool COLLTRACE>
template<int SpecializedFnId, typename SpecializedRunWork, bool COLLTRACE, int COLL_UNROLL>
__forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
const int tid = threadIdx.x;
int x = tid;
@@ -346,9 +346,15 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
SpecializedRunWork().run(&ncclShmem.work);
} else {
#ifdef USE_INDIRECT_FUNCTION_CALL
ncclDevFuncTable[ncclShmem.work.header.funcIndex]();
if (COLL_UNROLL == 4)
ncclDevFuncTable_4[ncclShmem.work.header.funcIndex]();
else
ncclDevFuncTable[ncclShmem.work.header.funcIndex]();
#else
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
if (COLL_UNROLL == 4)
NCCL_CALL_FUNCTIONS_4(ncclShmem.work.header.funcIndex);
else
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
#endif
}
@@ -379,8 +385,10 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
}
__global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
__global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
#ifdef ENABLE_COLLTRACE
__global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
__global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
#endif
#ifdef USE_INDIRECT_FUNCTION_CALL
+50 -12
Wyświetl plik
@@ -162,7 +162,7 @@ def calc_unroll_for_local_arch():
# Homogeneous system is required to build for only 1 varient of unroll factor
if len(gfx_targets) == 1:
gfx_name, cu_count = gfx_targets[0]
if ("gfx908" == gfx_name or "gfx94" in gfx_name) and cu_count > 80:
if "gfx908" == gfx_name or ("gfx94" in gfx_name and cu_count > 80):
return 2
else:
return 4
@@ -258,12 +258,12 @@ def equivalent_primary(coll, algo, proto, redop, ty, unroll):
# Order rows are enumerated must match formula of `ncclDevFuncId()`:
def enumerate_func_rows():
for coll in all_colls:
for algo in all_algos:
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
for unroll in all_unroll:
for unroll in all_unroll:
for coll in all_colls:
for algo in all_algos:
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty):
yield (coll, algo, proto, redop, ty, unroll)
@@ -272,12 +272,12 @@ def custom_sort_key(fn):
coll, algo, proto, redop, ty, unroll = fn
return (
all_unroll.index(unroll),
all_colls.index(coll),
all_algos.index(algo),
all_protos.index(proto),
all_redops.index(redop),
all_tys.index(ty),
all_unroll.index(unroll)
all_tys.index(ty)
)
################################################################################
@@ -319,6 +319,8 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
index = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if unroll != "2": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
@@ -331,6 +333,23 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
index += 1
out("nullptr};\n")
out("\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
index4 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if unroll != "4": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
else:
out("/*%4d*/ %s,\n" % (index4, sym))
index4 += 1
out("nullptr};\n")
out("\n")
if not is_ifc:
out("template<unsigned short f, unsigned short l>\n"
@@ -351,6 +370,24 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
out(f" Caller<0, {index}>::call(funcIndex);\n")
out("}\n\n")
out("template<unsigned short f, unsigned short l>\n"
"struct Caller4 {\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
" }\n"
"};\n"
"\n"
"template<unsigned short f>\n"
"struct Caller4<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
out("}\n\n")
# Generate <gensrc>/device_table.cpp
if is_colltrace:
@@ -363,7 +400,8 @@ if is_colltrace:
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
for fn in primary_funcs:
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn))
if fn[5] == "4": continue
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1]))
for ty in all_tys:
out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n')
out("};\n")
@@ -379,11 +417,11 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
# The mapping from function rows to valid primary function ids.
out("extern int const ncclDevFuncRowToId[] = {\n")
index = 0
for fn in func_rows:
for fn in func_rows[:len(func_rows)//2]:
fn_id, comment = -1, ""
if fn is not None:
fn_id = primary_to_index[equivalent_primary(*fn)]
comment = " // " + paste(" ", *fn)
comment = " // " + paste(" ", *fn[:-1])
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
index += 1
out(f"{index}")
+13 -21
Wyświetl plik
@@ -31,15 +31,18 @@ struct ncclKernelMatch {
};
#ifdef ENABLE_COLLTRACE
#define ncclGetKernelIndex(p_comm) ((p_comm)->collTraceThread ? 1 : 0)
static ncclKernelMatch const ncclKerns[2] = {
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + (p_comm)->collTraceThread ? 2 : 0)
static ncclKernelMatch const ncclKerns[4] = {
{(void *)ncclDevKernel_Generic, true},
{(void *)ncclDevKernel_Generic_4, true},
{(void *)ncclDevKernelDebug_Generic, true},
{(void *)ncclDevKernelDebug_Generic_4, true}
};
#else
#define ncclGetKernelIndex(p_comm) (0)
static ncclKernelMatch const ncclKerns[1] = {
{(void*)ncclDevKernel_Generic, true}
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
static ncclKernelMatch const ncclKerns[2] = {
{(void*)ncclDevKernel_Generic, true},
{(void*)ncclDevKernel_Generic_4, true}
};
#endif
@@ -53,21 +56,11 @@ static ncclResult_t initCollProxyOp(struct ncclInfo* collInfo, int channelId, ui
static ncclResult_t getTunerInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps);
static ncclResult_t topoGetAlgoInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps);
static ncclResult_t getChannnelThreadInfo(struct ncclInfo* collInfo);
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll);
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo);
static ncclResult_t getPatternInfo(struct ncclInfo* collInfo);
static ncclResult_t getLoopInfo(struct ncclInfo* collInfo);
static ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport);
int getUnrollFactor(struct ncclComm* comm) {
hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
if(IsArchMatch(devProp.gcnArchName, "gfx908") || (IsArchMatch(devProp.gcnArchName, "gfx94")
&& devProp.multiProcessorCount > 80))
return NCCL_UNROLL_2;
else
return NCCL_UNROLL_4;
}
// Returns maximum kernel stack size of all CUDA kernels
ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) {
constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
@@ -186,7 +179,7 @@ static ncclResult_t appendWorkElemP2p(
struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
struct ncclWorkElemP2p const *elem, bool fuseOk
) {
int funcIndex = ncclDevFuncId_P2p(plan->unroll);
int funcIndex = ncclDevFuncId_P2p();
if (funcIndex < 0) {
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
return ncclInvalidUsage;
@@ -843,7 +836,7 @@ static ncclResult_t scheduleCollTasksToPlan(
NCCLCHECK(getTunerInfo(aggInfo, collNetSupport, nvlsSupport, 1));
NCCLCHECK(topoGetAlgoInfo(aggInfo, collNetSupport, nvlsSupport, 1));
NCCLCHECK(getChannnelThreadInfo(aggInfo));
NCCLCHECK(computeCollWorkFunc(aggInfo, plan->unroll));
NCCLCHECK(computeCollWorkFunc(aggInfo));
NCCLCHECK(getPatternInfo(aggInfo));
// Try to assign algo and proto to all possible collectives
@@ -1322,7 +1315,6 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
plan->comm = comm;
plan->reclaimer.fn = reclaimPlan;
plan->persistent = persistent;
plan->unroll = getUnrollFactor(comm);
// Non-persistent kernels fill up at most half of our fifo per kernel.
int nWorkBudget = plan->persistent ? INT_MAX : comm->workFifoDepth/2;
@@ -1755,8 +1747,8 @@ static ncclResult_t getPatternInfo(struct ncclInfo* collInfo) {
RCCL_PARAM(IntraNetThreshold, "INTRANET_THRESHOLD", 8388608);
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll) {
collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol, unroll);
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo) {
collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol);
if (collInfo->workFuncIndex < 0) {
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
return ncclInvalidUsage;
+3 -3
Wyświetl plik
@@ -224,9 +224,6 @@ struct ncclKernelPlan {
struct ncclIntruQueue<struct ncclProxyOp, &ncclProxyOp::enqNext> proxyOpQueue;
} channels[MAXCHANNELS];
size_t maxBytesPerChannel;
// Unroll factor for plan [RCCL]
int unroll;
};
#define NCCL_MAGIC 0x0280028002800280 // Nickel atomic number is 28.
@@ -434,6 +431,9 @@ struct ncclComm {
// buffer registration cache
struct ncclRegCache regCache;
uint64_t endMagic;
// Unroll factor for plan [RCCL]
int unroll;
};
enum ncclLaunchMode {
+17 -23
Wyświetl plik
@@ -552,64 +552,58 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
// Map the rowIdx to funcIdx
extern int const ncclDevFuncRowToId[];
// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int unroll) {
// `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) {
int row = 0;
do {
// RING / <all_protos> / Sum / int8_t
if (coll == ncclFuncAllGather) {
row += proto * NCCL_NUM_UNROLLS + unroll;
row += proto;
break;
}
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS;
row += NCCL_NUM_PROTOCOLS;
// <all_algos> / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncAllReduce) {
row += ((((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto) * NCCL_NUM_UNROLLS;
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto);
break;
}
row += NCCL_NUM_UNROLLS * (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
// RING / SIMPLE / Sum / int8_t
if (coll == ncclFuncAllToAllPivot) {
row += unroll;
break;
}
row += NCCL_NUM_UNROLLS;
if (coll == ncclFuncAllToAllPivot) break;
row += 1;
// RING / <all_protos> / Sum / int8_t
if (coll == ncclFuncBroadcast) {
row += proto * NCCL_NUM_UNROLLS + unroll;
row += proto;
break;
}
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS;
row += NCCL_NUM_PROTOCOLS;
// RING / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncReduce) {
row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS;
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
break;
}
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
// RING / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncReduceScatter) {
row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS;
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
break;
}
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
// RING / SIMPLE / Sum / int8_t
if (coll == ncclFuncSendRecv) {
row += unroll;
break;
}
row += NCCL_NUM_UNROLLS;
if (coll == ncclFuncSendRecv) break;
row += 1;
} while (false);
return ncclDevFuncRowToId[row];
}
inline int ncclDevFuncId_P2p(int unroll) { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - (unroll > 0 ? 0 : 1) - 1]; }
inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; }
#endif
+1 -1
Wyświetl plik
@@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
#define NCCL_NUM_ONERANK 12
#define FUNC_INDEX_TOTAL 1312 + NCCL_NUM_ONERANK
#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
typedef enum {
+13
Wyświetl plik
@@ -93,6 +93,17 @@ static uint64_t hashUniqueId(ncclUniqueId const &id) {
return h;
}
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
if(IsArchMatch(devProp.gcnArchName, "gfx908") || (IsArchMatch(devProp.gcnArchName, "gfx94")
&& devProp.multiProcessorCount > 80))
comm->unroll = NCCL_UNROLL_2;
else
comm->unroll = NCCL_UNROLL_4;
return ncclSuccess;
}
#ifdef ENABLE_MSCCLPP
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
return (size_t)hashUniqueId(uniqueId);
@@ -559,6 +570,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
// RCCL: create persistent stream for calloc
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;