From e40ff4f84a723b7a95d3c824207f0719d779ee70 Mon Sep 17 00:00:00 2001 From: Pedram Alizadeh Date: Thu, 10 Apr 2025 11:43:54 -0400 Subject: [PATCH] all_reduce LL/LL128 and Ring/Tree multi-node tuning for MI300 (#1627) * Enabling LL128 by default on MI300 * Add missing CUDACHECK * Adjust BW correction factors to fix the Tree->Ring switching point * Refactor and add ll128 AR logarithmic factor to tuning models * Move RCCL tuning changes to a separate file * Use enum for tunable indexing * Use explicit indexing in tuning models to avoid mismatch issues * Place rcclGetSizePerRank in a function * Remove HIP ifdef for rccl-only call --------- Co-authored-by: Mustafa Abduljabbar --- CMakeLists.txt | 2 ++ src/enqueue.cc | 47 +------------------------ src/graph/rome_models.cc | 2 +- src/graph/topo.h | 6 ---- src/graph/tuning.cc | 23 +++++------- src/include/comm.h | 5 +-- src/include/rccl_common.h | 64 ++++++++++++++++++++++++++++++++++ src/rccl_wrap.cc | 73 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 153 insertions(+), 69 deletions(-) create mode 100644 src/include/rccl_common.h create mode 100644 src/rccl_wrap.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 27dcb70a46..cff5c5b14b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -422,6 +422,7 @@ set(SRC_FILES src/net.cc src/msccl.cc src/proxy.cc + src/rccl_wrap.cc src/register.cc src/transport.cc src/device/all_gather.h @@ -498,6 +499,7 @@ set(SRC_FILES src/include/param.h src/include/profiler.h src/include/proxy.h + src/include/rccl_common.h src/include/rccl_vars.h src/include/register.h src/include/rccl_float8.h diff --git a/src/enqueue.cc b/src/enqueue.cc index d2d80fc60c..40bf736e8a 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1938,55 +1938,10 @@ static ncclResult_t topoGetAlgoInfo( info->protocol = backupProto; time = backupTime; } -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - // Honor user input for protocol choice - static int userProtocolInput = -2; - if (userProtocolInput == -2) { - const char *protoStr = getenv("NCCL_PROTO"); - userProtocolInput = !protoStr ? 0 : 1; - } - - if(!userProtocolInput && comm->nNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather)) { - auto llMin = comm->minMaxLLRange[info->func][NCCL_PROTO_LL][0]; - auto llMax = comm->minMaxLLRange[info->func][NCCL_PROTO_LL][1]; - - auto ll128Min = comm->minMaxLLRange[info->func][NCCL_PROTO_LL128][0]; - auto ll128Max = comm->minMaxLLRange[info->func][NCCL_PROTO_LL128][1]; - - // Only override model choices if min/max cutoff points are set in the tuning models - if((ll128Max != RCCL_LL_LIMITS_UNDEFINED) || (llMax != RCCL_LL_LIMITS_UNDEFINED)) { - // Keep it simple unless otherwise required - info->protocol = NCCL_PROTO_SIMPLE; - // Normalize the comparison to sizePerRank as this is essentially what matters in determining protocol choice - size_t sizePerRank = nBytes / comm->nRanks; - - if(sizePerRank <= llMax && sizePerRank > llMin) { - info->protocol = NCCL_PROTO_LL; - } -#if defined(ENABLE_LL128) - // When applicable, LL128 RS performance is better than LL, so the next condition overrides the previous LL choice - if(comm->topo->ll128Enabled) { - if(sizePerRank <= ll128Max && sizePerRank > ll128Min) { - info->protocol = NCCL_PROTO_LL128; - } - } -#endif - } else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) { - // Warn that model detection for MI300 (or future others) did not work as expected - // Add supported archs to this condition as they come (e.g. gfx950) - // Also make sure the tuning_model and model detection are updated for new archs - static bool failedWarn = false; - if (!failedWarn) { - WARN("LL cutoff points not detected for a supported arch %s", comm->topo->nodes[GPU].nodes[0].gpu.gcn); - failedWarn = true; - } - } - } -#endif + rcclUpdateCollectiveProtocol(comm, nBytes, info); if (comm->rank == 0) INFO(NCCL_TUNING, "%s: %ld Bytes -> Algo %d proto %d time %f", ncclFuncToString(info->func), nBytes, info->algorithm, info->protocol, time); if (simInfo) simInfo->estimatedTime = time; TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", nBytes, info->algorithm, info->protocol, time); - int nc = comm->nChannels; int nt = comm->maxThreads[info->algorithm][info->protocol]; int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol]; diff --git a/src/graph/rome_models.cc b/src/graph/rome_models.cc index 1a686ba2f6..3b15482e39 100644 --- a/src/graph/rome_models.cc +++ b/src/graph/rome_models.cc @@ -814,7 +814,7 @@ static struct rcclRomeModel rome_model_81 = { "N7 7 3 2 6 0 4 1 5 N5|" "N1 1 0 2 4 3 5 7 6 N6|", - .options = "noCpuCheck=1,tuning=5,disableNumaMatching=1", + .options = "noCpuCheck=1,tuning=5,ll128Enabled=1,disableNumaMatching=1", .treeRail = "N0 0 1 2 4 3 6 5 7 N1|" "N1 1 0 4 7 3 5 2 6 N0|" diff --git a/src/graph/topo.h b/src/graph/topo.h index e488bd1e0a..1fb6af0641 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -123,12 +123,6 @@ struct ncclTopoLinkList { #define RCCL_TOPO_FORCE_INTRA 16 #define RCCL_TOPO_XGMI_ALL 32 -#define RCCL_LL_TUNABLE_COLLS 4 // LL/LL64/LL128 tunable Collectives -#define RCCL_RS_TUNABLE 0 // reduce_scatter index -#define RCCL_AG_TUNABLE 1 // all_gather index -#define RCCL_AR_TUNABLE 2 // all_reduce index -#define RCCL_RE_TUNABLE 3 // reduce index -#define RCCL_LL_LIMITS_UNDEFINED 0 #define GCN_ARCH_NAME_LEN 16 diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index c550e7ae05..593c1e5e25 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -71,7 +71,7 @@ struct tuningModel { float bwRatio [2][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][27]; float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][27]; - uint64_t llProtoRanges[RCCL_LL_TUNABLE_COLLS][NCCL_NUM_PROTOCOLS - 1][2]; + uint64_t llProtoRanges[RCCL_TUNABLE_COLLS][NCCL_NUM_PROTOCOLS - 1][RCCL_PROTOCOL_ENTRY_SIZE]; }; static struct tuningModel tuning_model_0 { @@ -254,19 +254,18 @@ static struct tuningModel tuning_model_5 { .treeCorrectionFactor = { { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.7, 0.5, 0.6, 0.6, 0.6, }, }, .ringCorrectionFactor = { { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.8, 1.0, 1.0, 1.0, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 1.0, 1.0, 1.0, }, }, - .llProtoRanges = { - /*ReduceScatter*/ {/* LL (Min/Max) */ {0, 655360} , /* LL128 (Min/Max) */ {131072, 3211264}}, - /*AllGather*/ {/* LL (Min/Max) */ {0, 98304} , /* LL128 (Min/Max) */ {98304, 5046272}}, - }, + .llProtoRanges[RCCL_RS_TUNABLE] = /*ReduceScatter*/ {/* LL (Min/Max) */ {0, 655360, 1} , /* LL128 (Min/Max) */ {131072, 3211264, 1}}, + .llProtoRanges[RCCL_AG_TUNABLE] = /*AllGather*/ {/* LL (Min/Max) */ {0, 98304, 1} , /* LL128 (Min/Max) */ {98304, 5046272, 1}}, + .llProtoRanges[RCCL_AR_TUNABLE] = /*AllReduce*/ {/* LL (Min/Max) */ {0, 1048576, 1} , /* LL128 (Min/Max) */ {1048576, 9437184, 3145728}}, }; static struct tuningModel rcclTuningModel[] = { @@ -372,13 +371,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom for (int a=0; atypeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI; for (int a=0; aminMaxLLRange[ncclFuncReduceScatter], - rcclTuningModel[comm->topo->tuning].llProtoRanges[RCCL_RS_TUNABLE], - sizeof(rcclTuningModel[comm->topo->tuning].llProtoRanges[RCCL_RS_TUNABLE])); - - memcpy(comm->minMaxLLRange[ncclFuncAllGather], - rcclTuningModel[comm->topo->tuning].llProtoRanges[RCCL_AG_TUNABLE], - sizeof(rcclTuningModel[comm->topo->tuning].llProtoRanges[RCCL_AG_TUNABLE])); + memcpy(comm->minMaxLLRange, + rcclTuningModel[comm->topo->tuning].llProtoRanges, + sizeof(rcclTuningModel[comm->topo->tuning].llProtoRanges)); for (int coll=0; collnNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather || info->func == ncclFuncAllReduce)) { + auto tunableIndex = rcclGetTunableIndex(info->func); + auto llMin = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL][RCCL_PROTOCOL_MIN_IDX]; + auto llMax = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL][RCCL_PROTOCOL_MAX_IDX]; + + auto ll128Min = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL128][RCCL_PROTOCOL_MIN_IDX]; + auto ll128Max = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL128][RCCL_PROTOCOL_MAX_IDX]; + + // Only override model choices if min/max cutoff points are set in the tuning models + if ((ll128Max != RCCL_LL_LIMITS_UNDEFINED) || (llMax != RCCL_LL_LIMITS_UNDEFINED)) { + // Keep it simple unless otherwise required + info->protocol = NCCL_PROTO_SIMPLE; + size_t sizePerRank = rcclGetSizePerRank(info->func, nBytes, comm->nRanks); + if (sizePerRank <= llMax && sizePerRank > llMin) { + info->protocol = NCCL_PROTO_LL; + } +#if defined(ENABLE_LL128) + // When LL128 is performant, the next condition overrides the previous LL choice + if (comm->topo->ll128Enabled) { + if (info->func == ncclFuncAllReduce) { + ll128Max += (log2i(comm->nNodes) - 1) * comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL128][RCCL_PROTOCOL_FACTOR_IDX]; + } + if (sizePerRank <= ll128Max && sizePerRank > ll128Min) { + info->protocol = NCCL_PROTO_LL128; + } + } +#endif + } else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) { + // Warn that model detection for MI300 (or future others) did not work as expected + // Add supported archs to this condition as they come (e.g. gfx950) + // Also make sure the tuning_model and model detection are updated for new archs + static bool failedWarn = false; + if (!failedWarn) { + WARN("LL cutoff points not detected for a supported arch %s", comm->topo->nodes[GPU].nodes[0].gpu.gcn); + failedWarn = true; + } + } + } +}