From 93fdcb160c09f8a4a2693006729f3fe24000dc52 Mon Sep 17 00:00:00 2001 From: Mustafa Abduljabbar Date: Tue, 6 Jan 2026 10:21:49 -0500 Subject: [PATCH] [WarpSpeed] Improve handling for auto and manual modes (#2125) * Force ring in WarpSpeed manual mode and log event * Skip usage for non-ring in WarpSpeed auto mode * Enable WarpSpeed when its CU count is set --- src/init.cc | 3 ++- src/rccl_wrap.cc | 27 +++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/init.cc b/src/init.cc index 0fce2ddb42..28de2f70c7 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1078,6 +1078,7 @@ NCCL_PARAM(AllocP2pNetLLBuffers, "ALLOC_P2P_NET_LL_BUFFERS", 0); #ifdef ENABLE_WARP_SPEED extern int64_t rcclParamWarpSpeedEnable(); extern int64_t rcclParamWarpSpeedAutoMode(); +extern int64_t rcclParamWarpSpeedCuCount(); #endif // MNNVL: Flag to indicate whether to enable Multi-Node NVLink NCCL_PARAM(MNNVLEnable, "MNNVL_ENABLE", 2); @@ -1458,7 +1459,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p } } #ifdef ENABLE_WARP_SPEED - comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0); + comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0 || rcclParamWarpSpeedCuCount() > 0); #endif // For single node communicators that do not uses the full xgmi links per gpu, i.e., nranks < 8 diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index c3a3ce243c..2bce6327aa 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -51,11 +51,19 @@ void rcclRestrictMaxChannels(struct ncclComm* comm, int& nc ) { } } +static inline bool rcclCollSupportsRing(ncclFunc_t func) { + return (func == ncclFuncAllReduce || + func == ncclFuncAllGather || + func == ncclFuncReduceScatter || + func == ncclFuncBroadcast || + func == ncclFuncReduce); +} + int32_t rcclGetProtoForGfx12(ncclFunc_t collectiveFunc, size_t sizePerRank){ int returnVal = NCCL_PROTO_SIMPLE; - int SingleNodeLLCutoffs[] = { + int SingleNodeLLCutoffs[] = { /*ncclFuncBroadcast*/ 1536, - /*ncclFuncReduce*/ 8192, + /*ncclFuncReduce*/ 8192, /*ncclFuncAllGather*/ 98304, /*ncclFuncReduceScatter*/ 98304, /*ncclFuncAllReduce*/ 913532, @@ -513,13 +521,20 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) { info->useWarpSpeed = false; - if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode + if(!rcclCollSupportsRing(info->func)) return; + if(rcclParamWarpSpeedAutoMode() != 0) { // Auto performance mode if(!IsArchMatch(comm->archName, "gfx950")) { // Auto mode only available for gfx950 currently, keep it to false return; } size_t minBytes = 0; commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm + // No early return based on the algorithm at the start of the function + // to allow unroll factor to be reverted to default. + // This can be changed once per-task unroll factor setting is implemented. + if(info->algorithm != NCCL_ALGO_RING) { + return; // If Ring is not selected, assume it is suboptimal and return + } if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES; else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed if(comm->nNodes == 1) { @@ -529,7 +544,11 @@ void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size info->useWarpSpeed = true; } } - } else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) { + } else if (comm->topo->warpSpeedEnabled) { + if(info->algorithm != NCCL_ALGO_RING) { + INFO(NCCL_TUNING, "Overriding %s algorithm with RING for nccl%s at %zu bytes as WarpSpeed is requested and only supports RING", ncclAlgoToString(info->algorithm), ncclFuncToString(info->func), nBytes); + info->algorithm = NCCL_ALGO_RING; // Force Ring when WarpSpeed is enabled in manual mode as it only supports Ring + } info->useWarpSpeed = true; } }