[INIT] Fix fallback for unsupported user-specified runtime unroll factor (#1780)

* [INIT] Fix fallback for unsupported user-specified runtime unroll factor
* Add CollTrace guard
* Move `commSetUnrollFactor()` to rccl_wrap.cc
* Modify comments in the device-code generator script
This commit is contained in:
Nilesh M Negi
2025-07-10 10:56:18 -05:00
committed by GitHub
parent 68d6f99e0f
commit 2c099fe29a
6 changed files with 61 additions and 45 deletions
+8 -9
View File
@@ -52,19 +52,21 @@ else:
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
#
# # Only AllReduce Sum int32_t (but all algos, protos)
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
# make ONLY_FUNCS="AllReduce * * Sum i32"
#
# # Only AllReduce RING Max float (but all protos and unrolls)
# make ONLY_FUNCS="AllReduce RING * Max float"
# make ONLY_FUNCS="AllReduce RING * Max f32"
#
# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16 1"
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float *"
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float *"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|AllGather RING LL/SIMPLE Sum int8_t 1/2/4|AllToAllPivot RING SIMPLE Sum int8_t 1/2/4|Broadcast RING LL/SIMPLE Sum int8_t 1/2/4|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|SendRecv RING SIMPLE Sum int8_t 1/2/4"
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *"
#
#
# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4"
#
# # ONLY_FUNCS can be used together for debugging
@@ -159,9 +161,6 @@ def func_filter(function_params, current_idx, item_list=None):
# If the parameter is equal to '*', include all possible cases for it
if current_element == "*":
if current_idx == 0:
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type> <unroll>
# Get the current list from all_params
current_list = all_params[current_idx]
+4 -4
View File
@@ -29,7 +29,7 @@
using namespace rccl;
/* [RCCL] Determine which GPU kernel to execute */
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL)
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task)
{
// At this time, unroll factor is controlled only by passed in unroll argument
// After more investigation, this may be further tuned by the actual task being processed
@@ -48,9 +48,9 @@ void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* tas
return rcclKernelTable[firstKernel + kernelIdx].funcPtr;
}
}
// Fall back to default unroll
WARN("Requested RCCL_UNROLL_FACTOR: %d does not exist in `rcclKernelTable`. Falling back to default unroll: %d", unroll, rcclKernelTable[firstKernel].unroll);
return rcclKernelTable[firstKernel].funcPtr;
// If no kernel matches the requested unroll factor, return null
return nullptr;
}
static int rcclProtoGrainSize(int proto, ncclComm *comm){
+2
View File
@@ -17,6 +17,8 @@
#define NCCL_SIMPLE_ALIGNMENT (WARP_SIZE * 8LL * 16LL)
#define NCCL_BYTES_ALIGNMENT 16
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL);
ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize);
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
ncclResult_t ncclLaunchPrepare(struct ncclComm* comm);
+4 -1
View File
@@ -23,6 +23,8 @@ THE SOFTWARE.
#define RCCL_COMMON_H_
#include "nccl_common.h"
#include "nccl.h"
#include "param.h"
typedef enum RcclTunableColls {
RCCL_UNSUPPORTED_TUNABLE = -1,
RCCL_RS_TUNABLE = 0, // reduce_scatter index
@@ -78,4 +80,5 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
int* algo, int* protocol, int* maxChannels);
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
#endif
ncclResult_t commSetUnrollFactor(struct ncclComm* comm);
#endif
+1 -31
View File
@@ -51,6 +51,7 @@
#include "mscclpp/mscclpp_nccl.h"
#endif
#include "rocm_smi_wrap.h"
#include "rccl_common.h"
// [/RCCL]
#include "msccl/msccl_lifecycle.h"
@@ -87,37 +88,6 @@ NCCL_PARAM(RuntimeConnect, "RUNTIME_CONNECT", 1);
struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {};
static ncclResult_t commReclaim(ncclComm_t comm);
//RCCL runtime param to set Unroll Factor
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
/**
 * Choose the kernel unroll factor for a communicator and store it in comm->unroll.
 *
 * A non-zero UnrollFactor runtime parameter is taken as-is. Otherwise the value
 * is derived from the GPU architecture name, the node count and the
 * multiprocessor count reported for comm->cudaDev.
 *
 * @param comm communicator; cudaDev and nNodes are read, unroll is written
 * @return ncclSuccess once comm->unroll has been set (a failing device query
 *         is handled by CUDACHECK)
 */
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
  hipDeviceProp_t devProp;
  CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));

  // A user-supplied runtime parameter overrides every architecture default.
  const auto userUnroll = rcclParamUnrollFactor();
  if (userUnroll != 0) {
    comm->unroll = userUnroll;
    INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
    return ncclSuccess;
  }

  // Architecture defaults; anything not listed below uses unroll=4.
  int presetUnroll = 4;
  if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
    // gfx950: unroll=1 on a single node, unroll=2 across nodes.
    presetUnroll = (comm->nNodes == 1) ? 1 : 2;
  } else if (IsArchMatch(devProp.gcnArchName, "gfx908") ||
             (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) {
    // gfx908 and MI300X (gfx942 with more than 80 multiprocessors): unroll=2.
    presetUnroll = 2;
  }

  comm->unroll = presetUnroll;
  INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
  return ncclSuccess;
}
#ifdef ENABLE_MSCCLPP
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES);
+42
View File
@@ -24,6 +24,7 @@ THE SOFTWARE.
#include "comm.h"
#include "graph/topo.h"
#include "enqueue.h"
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
// Honor user input for protocol choice
static int userProtocolInput = -2;
@@ -125,3 +126,44 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
maxCount = ncclFuncMaxSendRecvCount(func, nRanks, count);
return ncclSuccess;
}
//RCCL runtime param to set Unroll Factor
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
/**
 * Choose the kernel unroll factor for a communicator and store it in comm->unroll.
 *
 * A non-zero UnrollFactor runtime parameter wins, but only when
 * rcclGetKernelIndex() returns a kernel for it. If no kernel is found, a
 * warning is logged and the architecture-based pre-set value is used instead.
 *
 * @param comm communicator; cudaDev and nNodes are read (plus collTraceEnabled
 *             when ENABLE_COLLTRACE is set), unroll is written
 * @return ncclSuccess once comm->unroll has been set (a failing device query
 *         is handled by CUDACHECK)
 */
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
  hipDeviceProp_t devProp;
  CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));

  // The kernel lookup needs to know whether CollTrace kernels are in use;
  // that field is only consulted when the build enables CollTrace.
#if ENABLE_COLLTRACE
  const bool useCollTrace = comm->collTraceEnabled;
#else
  const bool useCollTrace = false;
#endif

  // A user-supplied runtime parameter overrides the defaults, if supported.
  const auto requestedUnroll = rcclParamUnrollFactor();
  if (requestedUnroll != 0) {
    if (rcclGetKernelIndex(requestedUnroll, useCollTrace)) {
      comm->unroll = requestedUnroll;
      INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
      return ncclSuccess;
    }
    // No kernel exists for the requested factor: warn and use the pre-set value below.
    WARN("Requested RCCL_UNROLL_FACTOR: %ld is invalid and does not exist in `rcclKernelTable`. Falling back to pre-set unroll.", requestedUnroll);
  }

  // Architecture defaults; anything not listed below uses unroll=4.
  int presetUnroll = 4;
  if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
    // gfx950: unroll=1 on a single node, unroll=2 across nodes.
    presetUnroll = (comm->nNodes == 1) ? 1 : 2;
  } else if (IsArchMatch(devProp.gcnArchName, "gfx908") ||
             (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) {
    // gfx908 and MI300X (gfx942 with more than 80 multiprocessors): unroll=2.
    presetUnroll = 2;
  }

  comm->unroll = presetUnroll;
  INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
  return ncclSuccess;
}