diff --git a/src/device/generate.py b/src/device/generate.py index 4a646dc283..b1aca93bef 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -52,19 +52,21 @@ else: # make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" # # # Only AllReduce Sum int32_t (but all algos, protos) -# make ONLY_FUNCS="AllReduce * * Sum int32_t" +# make ONLY_FUNCS="AllReduce * * Sum i32" # # # Only AllReduce RING Max float (but all protos and unrolls) -# make ONLY_FUNCS="AllReduce RING * Max float" +# make ONLY_FUNCS="AllReduce RING * Max f32" # # # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1 -# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16 1" +# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1" # # # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter) -# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float *" +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *" # --- or --- -# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float *" -# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|AllGather RING LL/SIMPLE Sum int8_t 1/2/4|AllToAllPivot RING SIMPLE Sum int8_t 1/2/4|Broadcast RING LL/SIMPLE Sum int8_t 1/2/4|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|SendRecv RING SIMPLE Sum int8_t 1/2/4" +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *" +# +# +# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax 
i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4" # # # ONLY_FUNCS can be used together for debugging @@ -159,9 +161,6 @@ def func_filter(function_params, current_idx, item_list=None): # If the paramter is equal to '*', include all possible cases for it if current_element == "*": - if current_idx == 0: - raise ValueError("Error: Paramter 'COLL' can not be type all '*'.") - # all_params list must be in the same order as function_params --> # Get the current list from all_params current_list = all_params[current_idx] diff --git a/src/enqueue.cc b/src/enqueue.cc index 3946b32e37..fe3a8d6334 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -29,7 +29,7 @@ using namespace rccl; /* [RCCL] Determine which GPU kernel to execute */ -void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL) +void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task) { // At this time, unroll factor is controlled only by passed in unroll argument // After more investigation, this may be further tuned by the actual task being processed @@ -48,9 +48,9 @@ void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* tas return rcclKernelTable[firstKernel + kernelIdx].funcPtr; } } - // Fall back to default unroll - WARN("Requested RCCL_UNROLL_FACTOR: %d does not exist in `rcclKernelTable`. 
Falling back to default unroll: %d", unroll, rcclKernelTable[firstKernel].unroll); - return rcclKernelTable[firstKernel].funcPtr; + + // No kernel entry matches the requested unroll factor: return null + return nullptr; } static int rcclProtoGrainSize(int proto, ncclComm *comm){ diff --git a/src/include/enqueue.h b/src/include/enqueue.h index a381846ad0..b4a2707475 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -17,6 +17,8 @@ #define NCCL_SIMPLE_ALIGNMENT (WARP_SIZE * 8LL * 16LL) #define NCCL_BYTES_ALIGNMENT 16 +void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL); + ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize); ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); ncclResult_t ncclLaunchPrepare(struct ncclComm* comm); diff --git a/src/include/rccl_common.h b/src/include/rccl_common.h index 9090924ccb..f29096a618 100644 --- a/src/include/rccl_common.h +++ b/src/include/rccl_common.h @@ -23,6 +23,8 @@ THE SOFTWARE. #define RCCL_COMMON_H_ #include "nccl_common.h" #include "nccl.h" +#include "param.h" + typedef enum RcclTunableColls { RCCL_UNSUPPORTED_TUNABLE = -1, RCCL_RS_TUNABLE = 0, // reduce_scatter index @@ -78,4 +80,5 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co int* algo, int* protocol, int* maxChannels); ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount); -#endif \ No newline at end of file +ncclResult_t commSetUnrollFactor(struct ncclComm* comm); +#endif diff --git a/src/init.cc b/src/init.cc index bce482c16a..dc48ed566a 100644 --- a/src/init.cc +++ b/src/init.cc @@ -51,6 +51,7 @@ #include "mscclpp/mscclpp_nccl.h" #endif #include "rocm_smi_wrap.h" +#include "rccl_common.h" // [/RCCL] #include "msccl/msccl_lifecycle.h" @@ -87,37 +88,6 @@ NCCL_PARAM(RuntimeConnect, "RUNTIME_CONNECT", 1); struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {}; static ncclResult_t commReclaim(ncclComm_t comm); 
-//RCCL runtime param to set Unroll Factor -RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0); - -ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { - hipDeviceProp_t devProp; - CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); - - //If RCCL runtime param is set, it will override defaults - if (rcclParamUnrollFactor() != 0) { - comm->unroll = rcclParamUnrollFactor(); - INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll); - } - else { - if (IsArchMatch(devProp.gcnArchName, "gfx950")) { - //on gfx950, use unroll=1 for single-node and unroll=2 for multi-node - if (comm->nNodes == 1) - comm->unroll = 1; - else - comm->unroll = 2; - } - else if((IsArchMatch(devProp.gcnArchName, "gfx908")) || - (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) - //on MI300X and gfx908, use unroll=2 - comm->unroll = 2; - else - comm->unroll = 4; - INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll); - } - return ncclSuccess; -} - #ifdef ENABLE_MSCCLPP size_t std::hash::operator ()(const ncclUniqueId& uniqueId) const noexcept { return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES); diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index 7fba83dbf8..643b21f3d0 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -24,6 +24,7 @@ THE SOFTWARE. 
#include "comm.h" #include "graph/topo.h" #include "enqueue.h" + void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) { // Honor user input for protocol choice static int userProtocolInput = -2; @@ -125,3 +126,44 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, maxCount = ncclFuncMaxSendRecvCount(func, nRanks, count); return ncclSuccess; } + +//RCCL runtime param to set Unroll Factor +RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0); + +ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { + hipDeviceProp_t devProp; + CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); + + //If the RCCL runtime param is set and a matching kernel exists, it overrides the pre-set defaults + if (rcclParamUnrollFactor() != 0) { +#if ENABLE_COLLTRACE + if(rcclGetKernelIndex(rcclParamUnrollFactor(), comm->collTraceEnabled)) { +#else + if(rcclGetKernelIndex(rcclParamUnrollFactor(), false)) { +#endif + comm->unroll = rcclParamUnrollFactor(); + INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll); + return ncclSuccess; + } + else { + // No matching kernel: warn, then fall through to the pre-set unroll selected below + WARN("Requested RCCL_UNROLL_FACTOR: %ld is invalid and does not exist in `rcclKernelTable`. Falling back to pre-set unroll.", rcclParamUnrollFactor()); + } + } + + if (IsArchMatch(devProp.gcnArchName, "gfx950")) { + //on gfx950, use unroll=1 for single-node and unroll=2 for multi-node + if (comm->nNodes == 1) + comm->unroll = 1; + else + comm->unroll = 2; + } + else if((IsArchMatch(devProp.gcnArchName, "gfx908")) || + (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) + //on MI300X and gfx908, use unroll=2 + comm->unroll = 2; + else + comm->unroll = 4; + INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll); + return ncclSuccess; +}