Force enable proto and/or algo after model selection (#1799)

* Force enable proto or algo

* Remove inc nccl_common.h

* Move logic and add error checks

* Fix topo_expl compatibility

* Allow algo/proto overrides

* Remove extra function decl

* Clarify warning message

* Move algo/proto overrides into separate functions

* Update CHANGELOG.md

[ROCm/rccl commit: 7ccc6f268f]
Этот коммит содержится в:
Mustafa Abduljabbar
2025-09-03 08:54:13 -04:00
коммит произвёл GitHub
родитель 1999f2eba8
Коммит 1a7ab8dfc8
5 изменённых файлов: 75 добавлений и 2 удалений
+1 -1
Просмотреть файл
@@ -35,7 +35,7 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
* Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled.
* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap `bf16` arithmetic and bridge the gap between `fp32` performance and `bf16` for both `gfx942` and `gfx950`. Pipelining has been made tunable via `rcclSetPipelining`, similar to algorithms/protocols so that regression is avoided in certain message sizes.
* Added a direct allgather algorithm. This is enabled by default for multi-node if there are 16 nodes or fewer. The message size threshold is 4MB.
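* Added `RCCL_OVERRIDE_PROTO` and `RCCL_OVERRIDE_ALGO` to allow direct replacement of protocol and algorithm choices. Unlike `NCCL_PROTO` and `NCCL_ALGO`, which re-run the model across enabled combinations and may not guarantee the intended override, these new options enforce the specified selections explicitly. Values are matched case-insensitively against the protocol and algorithm names; an unrecognized name, or a selection that is unsupported for the current collective, is reported with a warning and the model's choice is left unchanged.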
### Changed
+2
Просмотреть файл
@@ -2103,6 +2103,8 @@ static ncclResult_t topoGetAlgoInfo(
if (info->algorithm == NCCL_ALGO_TREE) nt = NCCL_MAX_NTHREADS; // Tree now uses all threads always.
if (info->algorithm == NCCL_ALGO_PAT) nt = NCCL_MAX_NTHREADS;
info->nWarps = nt/comm->WarpSize;
rcclOverrideAlgorithm(ncclAlgoStr, table, info);
rcclOverrideProtocol(ncclProtoStr, table, info);
return ncclSuccess;
}
+3
Просмотреть файл
@@ -81,6 +81,9 @@ inline size_t rcclGetSizePerRank(ncclFunc_t const& func, size_t const& nBytes, i
// For AR, this is the send/recv size per rank
return (func == ncclFuncReduceScatter || func == ncclFuncAllGather || func == ncclFuncBroadcast) ? nBytes / nRanks : nBytes;
}
// Looks up envStr (case-insensitive comparison) among the first nEntries strings of
// algoProtoString and, on a match, writes the matching index to result.
// Returns ncclSuccess on a match; ncclInvalidUsage when envStr is null or matches no entry
// (result is not written in that case).
ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result);
// Replaces info->protocol with the protocol named by the RCCL_OVERRIDE_PROTO environment variable.
// Returns ncclSuccess when the variable is unset or the override was applied, ncclInvalidUsage when
// its value is not found in ncclProtoStr, and ncclInternalError (info left unchanged) when
// table[info->algorithm][protocol] is NCCL_ALGO_PROTO_IGNORE.
ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info);
// Replaces info->algorithm with the algorithm named by the RCCL_OVERRIDE_ALGO environment variable.
// Returns ncclSuccess when the variable is unset or the override was applied, ncclInvalidUsage when
// its value is not found in ncclAlgoStr, and ncclInternalError (info left unchanged) when
// table[algorithm][info->protocol] is NCCL_ALGO_PROTO_IGNORE.
ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info);
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
+68
Просмотреть файл
@@ -85,6 +85,74 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
}
}
// Resolve a user-supplied algorithm or protocol name to its index in a string table.
//
//   envStr          - name to look up (may be null, e.g. when the env variable is unset)
//   algoProtoString - table of candidate names; null slots are tolerated and skipped
//   nEntries        - number of slots in algoProtoString to examine
//   result          - receives the index of the matching entry; untouched on failure
//
// The comparison is case-insensitive. Returns ncclSuccess on a match and
// ncclInvalidUsage when envStr is null or matches no entry.
ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result) {
  if (!envStr) return ncclInvalidUsage;

  for (int i = 0; i < nEntries; ++i) {
    // A table declared with more slots than initializers has trailing null entries;
    // handing those to strcasecmp would be undefined behavior, so skip them.
    if (!algoProtoString[i]) continue;
    if (strcasecmp(envStr, algoProtoString[i]) == 0) {
      result = i;
      return ncclSuccess;
    }
  }

  // Warn only on the first failed lookup. The flag is function-wide, so it is shared by
  // every caller of this helper (algorithm and protocol lookups alike).
  static bool failedProtoWarn = false;
  if (!failedProtoWarn) {
    WARN("Invalid algo or protocol string passed %s", envStr);
    failedProtoWarn = true;
  }
  return ncclInvalidUsage;
}
// Force the protocol named by RCCL_OVERRIDE_PROTO onto info->protocol.
//
//   ncclProtoStr - protocol name table used to translate the env value into an index
//   table        - per-[algorithm][protocol] values; NCCL_ALGO_PROTO_IGNORE marks an unusable pair
//   info         - task whose protocol is replaced; its current algorithm selects the table row
//
// The environment variable is read once and the resolved index is cached in function-local
// statics. Returns ncclSuccess when no override is requested or it was applied,
// ncclInvalidUsage when the env value is not a known protocol name (remembered, so every
// later call returns the same code without a new lookup), and ncclInternalError when the
// requested protocol is marked unusable for info->algorithm (info is left unchanged).
ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) {
  static const char* envValue = ncclGetEnv("RCCL_OVERRIDE_PROTO");
  static bool envUsable = true;
  static int forcedProto = NCCL_PROTO_UNDEF;

  // A previous call already rejected the env value.
  if (!envUsable) return ncclInvalidUsage;
  // Nothing requested: keep whatever was selected before this call.
  if (!envValue) return ncclSuccess;

  // First use: translate the name into a protocol index.
  if (forcedProto == NCCL_PROTO_UNDEF) {
    const bool resolved =
        rcclGetAlgoProtoIndex(envValue, ncclProtoStr, NCCL_NUM_PROTOCOLS, forcedProto) == ncclSuccess;
    if (!resolved) {
      envUsable = false;
      return ncclInvalidUsage;
    }
  }

  if (forcedProto <= NCCL_PROTO_UNDEF) return ncclSuccess;

  if (table[info->algorithm][forcedProto] == NCCL_ALGO_PROTO_IGNORE) {
    WARN("Failed to force unsupported protocol %s for function %s with datatype %s", envValue, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype));
    return ncclInternalError;
  }

  info->protocol = forcedProto;
  return ncclSuccess;
}
// Force the algorithm named by RCCL_OVERRIDE_ALGO onto info->algorithm.
//
//   ncclAlgoStr - algorithm name table used to translate the env value into an index
//   table       - per-[algorithm][protocol] values; NCCL_ALGO_PROTO_IGNORE marks an unusable pair
//   info        - task whose algorithm is replaced; its current protocol selects the table column
//
// The environment variable is read once and the resolved index is cached in function-local
// statics. Returns ncclSuccess when no override is requested or it was applied,
// ncclInvalidUsage when the env value is not a known algorithm name (remembered, so every
// later call returns the same code without a new lookup), and ncclInternalError when the
// requested algorithm is marked unusable for info->protocol (info is left unchanged).
ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) {
  static const char* envValue = ncclGetEnv("RCCL_OVERRIDE_ALGO");
  static bool envUsable = true;
  static int forcedAlgo = NCCL_ALGO_UNDEF;

  // A previous call already rejected the env value.
  if (!envUsable) return ncclInvalidUsage;
  // Nothing requested: keep whatever was selected before this call.
  if (!envValue) return ncclSuccess;

  // First use: translate the name into an algorithm index.
  if (forcedAlgo == NCCL_ALGO_UNDEF) {
    const bool resolved =
        rcclGetAlgoProtoIndex(envValue, ncclAlgoStr, NCCL_NUM_ALGORITHMS, forcedAlgo) == ncclSuccess;
    if (!resolved) {
      envUsable = false;
      return ncclInvalidUsage;
    }
  }

  if (forcedAlgo <= NCCL_ALGO_UNDEF) return ncclSuccess;

  if (table[forcedAlgo][info->protocol] == NCCL_ALGO_PROTO_IGNORE) {
    WARN("Failed to force unsupported algorithm %s for function %s with datatype %s", envValue, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype));
    return ncclInternalError;
  }

  info->algorithm = forcedAlgo;
  return ncclSuccess;
}
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold) {
// Honor user input for thread thresholds
static int userChannelControlInput = -2;
+1 -1
Просмотреть файл
@@ -33,7 +33,7 @@
#include "rocm_smi/rocm_smi.h"
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" };
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain" };
// Display names for algorithms and protocols, one string per table slot.
// NOTE(review): presumably the entry order has to follow the NCCL_ALGO_* / NCCL_PROTO_*
// numbering, and each table should carry one initializer per slot so no entry is left
// null — confirm against the enum definitions.
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" };
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
extern NodeModel *node_model;