Force enable proto and/or algo after model selection (#1799)
* Force enable proto or algo
* Remove inc nccl_common.h
* Move logic and add error checks
* Fix topo_expl compatibility
* Allow algo/proto overrides
* Remove extra function decl
* Clarify warning message
* Move algo/proto overrides into separate functions
* Update CHANGELOG.md
[ROCm/rccl commit: 7ccc6f268f]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
1999f2eba8
Коммит
1a7ab8dfc8
@@ -35,7 +35,7 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
* Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled.
|
||||
* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap `bf16` arithmetic and bridge the gap between `fp32` performance and `bf16` for both `gfx942` and `gfx950`. Pipelining has been made tunable via `rcclSetPipelining`, similar to algorithms/protocols so that regression is avoided in certain message sizes.
|
||||
* Added a direct allgather algorithm. This is enabled by default for multi-node if there are 16 nodes or fewer. The message size threshold is 4MB.
|
||||
|
||||
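* Added `RCCL_OVERRIDE_PROTO` and `RCCL_OVERRIDE_ALGO` to allow direct replacement of protocol and algorithm choices. Unlike `NCCL_PROTO` and `NCCL_ALGO`, which re-run the model across the enabled combinations and may not guarantee the intended override, these new options are applied after model selection and replace the chosen protocol or algorithm directly. A combination that is not supported for the collective is rejected with a warning, and an unrecognized name is reported as invalid usage.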
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -2103,6 +2103,8 @@ static ncclResult_t topoGetAlgoInfo(
|
||||
if (info->algorithm == NCCL_ALGO_TREE) nt = NCCL_MAX_NTHREADS; // Tree now uses all threads always.
|
||||
if (info->algorithm == NCCL_ALGO_PAT) nt = NCCL_MAX_NTHREADS;
|
||||
info->nWarps = nt/comm->WarpSize;
|
||||
rcclOverrideAlgorithm(ncclAlgoStr, table, info);
|
||||
rcclOverrideProtocol(ncclProtoStr, table, info);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,9 @@ inline size_t rcclGetSizePerRank(ncclFunc_t const& func, size_t const& nBytes, i
|
||||
// For AR, this is the send/recv size per rank
|
||||
return (func == ncclFuncReduceScatter || func == ncclFuncAllGather || func == ncclFuncBroadcast) ? nBytes / nRanks : nBytes;
|
||||
}
|
||||
// Case-insensitively looks up envStr in the name table algoProtoString (nEntries entries).
// On a match, stores the matching index in result and returns ncclSuccess; otherwise
// (including a null envStr) returns ncclInvalidUsage and leaves result untouched.
ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result);
|
||||
// Replaces info->protocol with the protocol named by the RCCL_OVERRIDE_PROTO environment
// variable, if set. Returns ncclInvalidUsage for an unrecognized name and ncclInternalError
// when the table entry for (info->algorithm, protocol) is NCCL_ALGO_PROTO_IGNORE.
ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info);
|
||||
// Replaces info->algorithm with the algorithm named by the RCCL_OVERRIDE_ALGO environment
// variable, if set. Returns ncclInvalidUsage for an unrecognized name and ncclInternalError
// when the table entry for (algorithm, info->protocol) is NCCL_ALGO_PROTO_IGNORE.
ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info);
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
|
||||
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
|
||||
@@ -85,6 +85,74 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Find the index of an algorithm or protocol name in a name table.
 *
 * The comparison is case-insensitive (strcasecmp).
 *
 * @param envStr           Name to look up; may be null.
 * @param algoProtoString  Table of candidate names, indexed by algorithm/protocol id.
 * @param nEntries         Number of slots in algoProtoString.
 * @param result           Receives the matching index on success; untouched otherwise.
 * @return ncclSuccess when a match is found; ncclInvalidUsage when envStr is null
 *         or matches no entry.
 */
ncclResult_t rcclGetAlgoProtoIndex(const char *envStr, const char* algoProtoString[], int nEntries, int& result) {
  if (!envStr) return ncclInvalidUsage;

  for (int i = 0; i < nEntries; ++i) {
    // A table declared with more slots than initializers has null trailing
    // entries; passing those to strcasecmp would be undefined behaviour.
    if (!algoProtoString[i]) continue;
    if (strcasecmp(envStr, algoProtoString[i]) == 0) {
      result = i;
      return ncclSuccess;
    }
  }

  // Report the bad string only on the first failed lookup to avoid repeating the warning.
  static bool failedProtoWarn = false;
  if (!failedProtoWarn) {
    WARN("Invalid algo or protocol string passed %s", envStr);
    failedProtoWarn = true;
  }
  return ncclInvalidUsage;
}
|
||||
|
||||
/**
 * Force the protocol named by the RCCL_OVERRIDE_PROTO environment variable.
 *
 * The environment variable is read once and the resolved protocol index is
 * cached in function-local statics. An unrecognized name is remembered, so
 * every later call also returns ncclInvalidUsage.
 *
 * @param ncclProtoStr  Protocol name table used to resolve the override string.
 * @param table         Per-(algorithm, protocol) table; an entry equal to
 *                      NCCL_ALGO_PROTO_IGNORE causes the override to be refused.
 * @param info          Task whose protocol field is overwritten on success.
 * @return ncclSuccess when no override is requested or it was applied;
 *         ncclInvalidUsage for an unrecognized name; ncclInternalError when the
 *         requested protocol is flagged as ignored for info->algorithm.
 */
ncclResult_t rcclOverrideProtocol(const char* ncclProtoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) {
  static const char* overrideStr = ncclGetEnv("RCCL_OVERRIDE_PROTO");
  static bool inputOk = true;
  static int forcedProto = NCCL_PROTO_UNDEF;

  // A previous call already rejected the environment string.
  if (!inputOk) return ncclInvalidUsage;

  // No override requested: keep the current selection.
  if (!overrideStr) return ncclSuccess;

  // Resolve the name to an index the first time through.
  if (forcedProto == NCCL_PROTO_UNDEF &&
      rcclGetAlgoProtoIndex(overrideStr, ncclProtoStr, NCCL_NUM_PROTOCOLS, forcedProto) != ncclSuccess) {
    inputOk = false;
    return ncclInvalidUsage;
  }

  if (forcedProto <= NCCL_PROTO_UNDEF) return ncclSuccess;

  if (table[info->algorithm][forcedProto] == NCCL_ALGO_PROTO_IGNORE) {
    WARN("Failed to force unsupported protocol %s for function %s with datatype %s", overrideStr, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype));
    return ncclInternalError;
  }

  info->protocol = forcedProto;
  return ncclSuccess;
}
|
||||
|
||||
/**
 * Force the algorithm named by the RCCL_OVERRIDE_ALGO environment variable.
 *
 * The environment variable is read once and the resolved algorithm index is
 * cached in function-local statics. An unrecognized name is remembered, so
 * every later call also returns ncclInvalidUsage.
 *
 * @param ncclAlgoStr  Algorithm name table used to resolve the override string.
 * @param table        Per-(algorithm, protocol) table; an entry equal to
 *                     NCCL_ALGO_PROTO_IGNORE causes the override to be refused.
 * @param info         Task whose algorithm field is overwritten on success.
 * @return ncclSuccess when no override is requested or it was applied;
 *         ncclInvalidUsage for an unrecognized name; ncclInternalError when the
 *         requested algorithm is flagged as ignored for info->protocol.
 */
ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL_NUM_PROTOCOLS], struct ncclTaskColl* info) {
  static const char* overrideStr = ncclGetEnv("RCCL_OVERRIDE_ALGO");
  static bool inputOk = true;
  static int forcedAlgo = NCCL_ALGO_UNDEF;

  // A previous call already rejected the environment string.
  if (!inputOk) return ncclInvalidUsage;

  // No override requested: keep the current selection.
  if (!overrideStr) return ncclSuccess;

  // Resolve the name to an index the first time through.
  if (forcedAlgo == NCCL_ALGO_UNDEF &&
      rcclGetAlgoProtoIndex(overrideStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, forcedAlgo) != ncclSuccess) {
    inputOk = false;
    return ncclInvalidUsage;
  }

  if (forcedAlgo <= NCCL_ALGO_UNDEF) return ncclSuccess;

  if (table[forcedAlgo][info->protocol] == NCCL_ALGO_PROTO_IGNORE) {
    WARN("Failed to force unsupported algorithm %s for function %s with datatype %s", overrideStr, ncclFuncToString(info->func), ncclDatatypeToString(info->datatype));
    return ncclInternalError;
  }

  info->algorithm = forcedAlgo;
  return ncclSuccess;
}
|
||||
|
||||
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold) {
|
||||
// Honor user input for thread thresholds
|
||||
static int userChannelControlInput = -2;
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
#include "rocm_smi/rocm_smi.h"
|
||||
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" };
|
||||
// Algorithm names indexed by algorithm id. Only one definition is kept: a second,
// four-entry definition of the same array was a redefinition and would presumably
// leave the trailing slots null for name lookups (NOTE(review): confirm
// NCCL_NUM_ALGORITHMS matches the seven names listed here).
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" };
|
||||
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
|
||||
|
||||
extern NodeModel *node_model;
|
||||
|
||||
Ссылка в новой задаче
Block a user