From 1a7ab8dfc858e408a69877e45bf0195b9af197e4 Mon Sep 17 00:00:00 2001 From: Mustafa Abduljabbar Date: Wed, 3 Sep 2025 08:54:13 -0400 Subject: [PATCH] Force enable proto and/or algo after model selection (#1799) * Force enable proto or algo * Remove inc nccl_common.h * Move logic and add error checks * Fix topo_expl compatibility * Allow algo/proto overrides * Remove extra function decl * Clarify warning message * Move algo/proto overrides into separate functions * Update CHANGELOG.md [ROCm/rccl commit: 7ccc6f268f89d2459555a207bc42a61e0c98dffc] --- projects/rccl/CHANGELOG.md | 2 +- projects/rccl/src/enqueue.cc | 2 + projects/rccl/src/include/rccl_common.h | 3 ++ projects/rccl/src/rccl_wrap.cc | 68 +++++++++++++++++++++++++ projects/rccl/tools/topo_expl/utils.cpp | 2 +- 5 files changed, 75 insertions(+), 2 deletions(-) diff --git a/projects/rccl/CHANGELOG.md b/projects/rccl/CHANGELOG.md index 40c807895e..3d31ae5be0 100644 --- a/projects/rccl/CHANGELOG.md +++ b/projects/rccl/CHANGELOG.md @@ -35,7 +35,7 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: * Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled. * Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap `bf16` arithmetic and bridge the gap between `fp32` performance and `bf16` for both `gfx942` and `gfx950`. Pipelining has been made tunable via `rcclSetPipelining`, similar to algorithms/protocols so that regression is avoided in certain message sizes. * Added a direct allgather algorithm. This is enabled by default for multi-node if there are 16 nodes or fewer. The message size threshold is 4MB. - +* Added `RCCL_OVERRIDE_PROTO` and `RCCL_OVERRIDE_ALGO` to allow direct replacement of protocol and algorithm choices. Unlike `NCCL_PROTO` and `NCCL_ALGO`, which re-run the model across enabled combinations and may not guarantee the intended override, these new options enforce the specified selections explicitly. ### Changed diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index d0d20d554b..8a91642596 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -2103,6 +2103,8 @@ static ncclResult_t topoGetAlgoInfo( if (info->algorithm == NCCL_ALGO_TREE) nt = NCCL_MAX_NTHREADS; // Tree now uses all threads always. if (info->algorithm == NCCL_ALGO_PAT) nt = NCCL_MAX_NTHREADS; info->nWarps = nt/comm->WarpSize; + rcclOverrideAlgorithm(ncclAlgoStr, table, info); + rcclOverrideProtocol(ncclProtoStr, table, info); return ncclSuccess; } diff --git a/projects/rccl/src/include/rccl_common.h b/projects/rccl/src/include/rccl_common.h index f7110690dc..5c124b6c31 100644 --- a/projects/rccl/src/include/rccl_common.h +++ b/projects/rccl/src/include/rccl_common.h @@ -81,6 +81,9 @@ inline size_t rcclGetSizePerRank(ncclFunc_t const& func, size_t const& nBytes, i // For AR, this is the send/recv size per rank return (func == ncclFuncReduceScatter || func == ncclFuncAllGather || func == ncclFuncBroadcast) ? nBytes / nRanks : nBytes; } +ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result); +ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info); +ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info); void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info); void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold); void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info); diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index ece9f62f86..83988dfe74 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -85,6 +85,74 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s } } +ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result) { + if(envStr) { + for (int i = 0; i < nEntries; ++i) { + if (strcasecmp(envStr, algoProtoString[i]) == 0) { + result = i; + return ncclSuccess; + } + } + static bool failedProtoWarn = false; + if (!failedProtoWarn) { + WARN("Invalid algo or protocol string passed %s", envStr); + failedProtoWarn = true; + return ncclInvalidUsage; + } + } + return ncclInvalidUsage; +} + +ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) { + static const char* protoOverrideEnv = ncclGetEnv("RCCL_OVERRIDE_PROTO"); + static bool validInput = true; + if (!validInput) return ncclInvalidUsage; + + if (protoOverrideEnv) { + static int protoVal = NCCL_PROTO_UNDEF; + if (protoVal == NCCL_PROTO_UNDEF) { + if (rcclGetAlgoProtoIndex(protoOverrideEnv, ncclProtoStr, NCCL_NUM_PROTOCOLS, protoVal) != ncclSuccess) { + validInput = false; + return ncclInvalidUsage; + } + } + if (protoVal > NCCL_PROTO_UNDEF) { + if (table[info->algorithm][protoVal] == NCCL_ALGO_PROTO_IGNORE) { + WARN("Failed to force unsupported protocol %s for function %s with datatype %s", protoOverrideEnv, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype)); + return ncclInternalError; + } else { + info->protocol = protoVal; + } + } + } + return ncclSuccess; +} + +ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) { + static const char* algoOverrideEnv = ncclGetEnv("RCCL_OVERRIDE_ALGO"); + static bool validInput = true; + if (!validInput) return ncclInvalidUsage; + + if (algoOverrideEnv) { + static int algoVal = NCCL_ALGO_UNDEF; + if (algoVal == NCCL_ALGO_UNDEF) { + if (rcclGetAlgoProtoIndex(algoOverrideEnv, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoVal) != ncclSuccess) { + validInput = false; + return ncclInvalidUsage; + } + } + if (algoVal > NCCL_ALGO_UNDEF) { + if (table[algoVal][info->protocol] == NCCL_ALGO_PROTO_IGNORE) { + WARN("Failed to force unsupported algorithm %s for function %s with datatype %s", algoOverrideEnv, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype)); + return ncclInternalError; + } else { + info->algorithm = algoVal; + } + } + } + return ncclSuccess; +} + void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold) { // Honor user input for thread thresholds static int userChannelControlInput = -2; diff --git a/projects/rccl/tools/topo_expl/utils.cpp b/projects/rccl/tools/topo_expl/utils.cpp index 1185accd82..312655578d 100644 --- a/projects/rccl/tools/topo_expl/utils.cpp +++ b/projects/rccl/tools/topo_expl/utils.cpp @@ -33,7 +33,7 @@ #include "rocm_smi/rocm_smi.h" const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" }; -const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain" }; +const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; extern NodeModel *node_model;