Template generic kernel for unroll factor (#1419)
* Template generic kernel for unroll factor
This commit is contained in:
zatwierdzone przez
GitHub
rodzic
2d07f18696
commit
cb175fb0b3
@@ -18,11 +18,17 @@ struct RunWorkNop {
|
||||
};
|
||||
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
|
||||
ncclKernelMain<-1, RunWorkNop, false>(comm, channelMask, workHead);
|
||||
ncclKernelMain<-1, RunWorkNop, false, 2>(comm, channelMask, workHead);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
|
||||
ncclKernelMain<-1, RunWorkNop, false, 4>(comm, channelMask, workHead);
|
||||
}
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
|
||||
ncclKernelMain<-1, RunWorkNop, true>(comm, channelMask, workHead);
|
||||
ncclKernelMain<-1, RunWorkNop, true, 2>(comm, channelMask, workHead);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
|
||||
ncclKernelMain<-1, RunWorkNop, true, 4>(comm, channelMask, workHead);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
+11
-3
@@ -227,7 +227,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we
|
||||
}
|
||||
}
|
||||
|
||||
template<int SpecializedFnId, typename SpecializedRunWork, bool COLLTRACE>
|
||||
template<int SpecializedFnId, typename SpecializedRunWork, bool COLLTRACE, int COLL_UNROLL>
|
||||
__forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead) {
|
||||
const int tid = threadIdx.x;
|
||||
int x = tid;
|
||||
@@ -346,9 +346,15 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
|
||||
SpecializedRunWork().run(&ncclShmem.work);
|
||||
} else {
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
ncclDevFuncTable[ncclShmem.work.header.funcIndex]();
|
||||
if (COLL_UNROLL == 4)
|
||||
ncclDevFuncTable_4[ncclShmem.work.header.funcIndex]();
|
||||
else
|
||||
ncclDevFuncTable[ncclShmem.work.header.funcIndex]();
|
||||
#else
|
||||
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
|
||||
if (COLL_UNROLL == 4)
|
||||
NCCL_CALL_FUNCTIONS_4(ncclShmem.work.header.funcIndex);
|
||||
else
|
||||
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -379,8 +385,10 @@ __forceinline__ __device__ void ncclKernelMain(struct ncclDevComm* comm, struct
|
||||
}
|
||||
|
||||
__global__ void ncclDevKernel_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
|
||||
__global__ void ncclDevKernel_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__global__ void ncclDevKernelDebug_Generic(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
|
||||
__global__ void ncclDevKernelDebug_Generic_4(struct ncclDevComm* comm, struct channelMasks channelMask, struct ncclWork* workHead);
|
||||
#endif
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
|
||||
+50
-12
@@ -162,7 +162,7 @@ def calc_unroll_for_local_arch():
|
||||
# A homogeneous system is required to build for only one variant of the unroll factor
|
||||
if len(gfx_targets) == 1:
|
||||
gfx_name, cu_count = gfx_targets[0]
|
||||
if ("gfx908" == gfx_name or "gfx94" in gfx_name) and cu_count > 80:
|
||||
if "gfx908" == gfx_name or ("gfx94" in gfx_name and cu_count > 80):
|
||||
return 2
|
||||
else:
|
||||
return 4
|
||||
@@ -258,12 +258,12 @@ def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
||||
|
||||
# The order in which rows are enumerated must match the formula in `ncclDevFuncId()`:
|
||||
def enumerate_func_rows():
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
for unroll in all_unroll:
|
||||
for unroll in all_unroll:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
yield (coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
@@ -272,12 +272,12 @@ def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
all_colls.index(coll),
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
all_redops.index(redop),
|
||||
all_tys.index(ty),
|
||||
all_unroll.index(unroll)
|
||||
all_tys.index(ty)
|
||||
)
|
||||
|
||||
################################################################################
|
||||
@@ -319,6 +319,8 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
|
||||
index = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
@@ -331,6 +333,23 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
index += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index4, sym))
|
||||
index4 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
if not is_ifc:
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
@@ -351,6 +370,24 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller<0, {index}>::call(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller4 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller4<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Generate <gensrc>/device_table.cpp
|
||||
if is_colltrace:
|
||||
@@ -363,7 +400,8 @@ if is_colltrace:
|
||||
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
for fn in primary_funcs:
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn))
|
||||
if fn[5] == "4": continue
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1]))
|
||||
for ty in all_tys:
|
||||
out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n')
|
||||
out("};\n")
|
||||
@@ -379,11 +417,11 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# The mapping from function rows to valid primary function ids.
|
||||
out("extern int const ncclDevFuncRowToId[] = {\n")
|
||||
index = 0
|
||||
for fn in func_rows:
|
||||
for fn in func_rows[:len(func_rows)//2]:
|
||||
fn_id, comment = -1, ""
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
comment = " // " + paste(" ", *fn)
|
||||
comment = " // " + paste(" ", *fn[:-1])
|
||||
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
|
||||
index += 1
|
||||
out(f"{index}")
|
||||
|
||||
+13
-21
@@ -31,15 +31,18 @@ struct ncclKernelMatch {
|
||||
};
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->collTraceThread ? 1 : 0)
|
||||
static ncclKernelMatch const ncclKerns[2] = {
|
||||
// Index into ncclKerns: the comm's unroll slot, offset by 2 into the debug kernels when a
// collective-trace thread is present. The conditional must be parenthesized because `+`
// binds tighter than `?:`; without the parentheses the expression is evaluated as
// ((unroll + collTraceThread) ? 2 : 0) and can only ever produce 0 or 2.
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceThread ? 2 : 0))
|
||||
static ncclKernelMatch const ncclKerns[4] = {
|
||||
{(void *)ncclDevKernel_Generic, true},
|
||||
{(void *)ncclDevKernel_Generic_4, true},
|
||||
{(void *)ncclDevKernelDebug_Generic, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_4, true}
|
||||
};
|
||||
#else
|
||||
#define ncclGetKernelIndex(p_comm) (0)
|
||||
static ncclKernelMatch const ncclKerns[1] = {
|
||||
{(void*)ncclDevKernel_Generic, true}
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
|
||||
static ncclKernelMatch const ncclKerns[2] = {
|
||||
{(void*)ncclDevKernel_Generic, true},
|
||||
{(void*)ncclDevKernel_Generic_4, true}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -53,21 +56,11 @@ static ncclResult_t initCollProxyOp(struct ncclInfo* collInfo, int channelId, ui
|
||||
static ncclResult_t getTunerInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps);
|
||||
static ncclResult_t topoGetAlgoInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps);
|
||||
static ncclResult_t getChannnelThreadInfo(struct ncclInfo* collInfo);
|
||||
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll);
|
||||
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo);
|
||||
static ncclResult_t getPatternInfo(struct ncclInfo* collInfo);
|
||||
static ncclResult_t getLoopInfo(struct ncclInfo* collInfo);
|
||||
static ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport);
|
||||
|
||||
int getUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
if(IsArchMatch(devProp.gcnArchName, "gfx908") || (IsArchMatch(devProp.gcnArchName, "gfx94")
|
||||
&& devProp.multiProcessorCount > 80))
|
||||
return NCCL_UNROLL_2;
|
||||
else
|
||||
return NCCL_UNROLL_4;
|
||||
}
|
||||
|
||||
// Returns maximum kernel stack size of all CUDA kernels
|
||||
ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) {
|
||||
constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
|
||||
@@ -186,7 +179,7 @@ static ncclResult_t appendWorkElemP2p(
|
||||
struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
|
||||
struct ncclWorkElemP2p const *elem, bool fuseOk
|
||||
) {
|
||||
int funcIndex = ncclDevFuncId_P2p(plan->unroll);
|
||||
int funcIndex = ncclDevFuncId_P2p();
|
||||
if (funcIndex < 0) {
|
||||
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
|
||||
return ncclInvalidUsage;
|
||||
@@ -843,7 +836,7 @@ static ncclResult_t scheduleCollTasksToPlan(
|
||||
NCCLCHECK(getTunerInfo(aggInfo, collNetSupport, nvlsSupport, 1));
|
||||
NCCLCHECK(topoGetAlgoInfo(aggInfo, collNetSupport, nvlsSupport, 1));
|
||||
NCCLCHECK(getChannnelThreadInfo(aggInfo));
|
||||
NCCLCHECK(computeCollWorkFunc(aggInfo, plan->unroll));
|
||||
NCCLCHECK(computeCollWorkFunc(aggInfo));
|
||||
NCCLCHECK(getPatternInfo(aggInfo));
|
||||
|
||||
// Try to assign algo and proto to all possible collectives
|
||||
@@ -1322,7 +1315,6 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
|
||||
plan->comm = comm;
|
||||
plan->reclaimer.fn = reclaimPlan;
|
||||
plan->persistent = persistent;
|
||||
plan->unroll = getUnrollFactor(comm);
|
||||
|
||||
// Non-persistent kernels fill up at most half of our fifo per kernel.
|
||||
int nWorkBudget = plan->persistent ? INT_MAX : comm->workFifoDepth/2;
|
||||
@@ -1755,8 +1747,8 @@ static ncclResult_t getPatternInfo(struct ncclInfo* collInfo) {
|
||||
|
||||
RCCL_PARAM(IntraNetThreshold, "INTRANET_THRESHOLD", 8388608);
|
||||
|
||||
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo, int unroll) {
|
||||
collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol, unroll);
|
||||
static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo) {
|
||||
collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol);
|
||||
if (collInfo->workFuncIndex < 0) {
|
||||
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
|
||||
return ncclInvalidUsage;
|
||||
|
||||
@@ -224,9 +224,6 @@ struct ncclKernelPlan {
|
||||
struct ncclIntruQueue<struct ncclProxyOp, &ncclProxyOp::enqNext> proxyOpQueue;
|
||||
} channels[MAXCHANNELS];
|
||||
size_t maxBytesPerChannel;
|
||||
|
||||
// Unroll factor for plan [RCCL]
|
||||
int unroll;
|
||||
};
|
||||
|
||||
#define NCCL_MAGIC 0x0280028002800280 // Nickel atomic number is 28.
|
||||
@@ -434,6 +431,9 @@ struct ncclComm {
|
||||
// buffer registration cache
|
||||
struct ncclRegCache regCache;
|
||||
uint64_t endMagic;
|
||||
|
||||
// Unroll factor for plan [RCCL]
|
||||
int unroll;
|
||||
};
|
||||
|
||||
enum ncclLaunchMode {
|
||||
|
||||
+17
-23
@@ -552,64 +552,58 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
|
||||
// Map the rowIdx to funcIdx
|
||||
extern int const ncclDevFuncRowToId[];
|
||||
|
||||
// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int unroll) {
|
||||
// `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) {
|
||||
int row = 0;
|
||||
do {
|
||||
// RING / <all_protos> / Sum / int8_t
|
||||
if (coll == ncclFuncAllGather) {
|
||||
row += proto * NCCL_NUM_UNROLLS + unroll;
|
||||
row += proto;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS;
|
||||
row += NCCL_NUM_PROTOCOLS;
|
||||
|
||||
// <all_algos> / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncAllReduce) {
|
||||
row += ((((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto) * NCCL_NUM_UNROLLS;
|
||||
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto);
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS * (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
|
||||
// RING / SIMPLE / Sum / int8_t
|
||||
if (coll == ncclFuncAllToAllPivot) {
|
||||
row += unroll;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS;
|
||||
if (coll == ncclFuncAllToAllPivot) break;
|
||||
row += 1;
|
||||
|
||||
// RING / <all_protos> / Sum / int8_t
|
||||
if (coll == ncclFuncBroadcast) {
|
||||
row += proto * NCCL_NUM_UNROLLS + unroll;
|
||||
row += proto;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS;
|
||||
row += NCCL_NUM_PROTOCOLS;
|
||||
|
||||
// RING / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncReduce) {
|
||||
row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS;
|
||||
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
|
||||
// RING / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncReduceScatter) {
|
||||
row += (((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) * NCCL_NUM_UNROLLS + unroll) - NCCL_NUM_FLOATS * proto * NCCL_NUM_UNROLLS;
|
||||
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
|
||||
// RING / SIMPLE / Sum / int8_t
|
||||
if (coll == ncclFuncSendRecv) {
|
||||
row += unroll;
|
||||
break;
|
||||
}
|
||||
row += NCCL_NUM_UNROLLS;
|
||||
if (coll == ncclFuncSendRecv) break;
|
||||
row += 1;
|
||||
|
||||
} while (false);
|
||||
|
||||
return ncclDevFuncRowToId[row];
|
||||
}
|
||||
|
||||
inline int ncclDevFuncId_P2p(int unroll) { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - (unroll > 0 ? 0 : 1) - 1]; }
|
||||
inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; }
|
||||
|
||||
#endif
|
||||
|
||||
@@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
|
||||
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
|
||||
|
||||
#define NCCL_NUM_ONERANK 12
|
||||
#define FUNC_INDEX_TOTAL 1312 + NCCL_NUM_ONERANK
|
||||
#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
|
||||
typedef enum {
|
||||
|
||||
@@ -93,6 +93,17 @@ static uint64_t hashUniqueId(ncclUniqueId const &id) {
|
||||
return h;
|
||||
}
|
||||
|
||||
// Queries the properties of the communicator's device (comm->cudaDev) and stores the
// matching unroll factor in comm->unroll.
//   - NCCL_UNROLL_2 when the arch name matches "gfx908", or matches "gfx94" and the
//     device reports more than 80 multiprocessors;
//   - NCCL_UNROLL_4 otherwise.
// Returns ncclSuccess once comm->unroll has been written; a failing device-property
// query is handled by CUDACHECK.
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
  hipDeviceProp_t devProp;
  CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));

  // Short-circuit order is kept: the gfx94 check only runs when gfx908 did not match.
  const bool useUnroll2 =
      IsArchMatch(devProp.gcnArchName, "gfx908") ||
      (IsArchMatch(devProp.gcnArchName, "gfx94") && devProp.multiProcessorCount > 80);

  comm->unroll = useUnroll2 ? NCCL_UNROLL_2 : NCCL_UNROLL_4;
  return ncclSuccess;
}
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
|
||||
return (size_t)hashUniqueId(uniqueId);
|
||||
@@ -559,6 +570,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
|
||||
// RCCL: create persistent stream for calloc
|
||||
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
||||
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user