diff --git a/src/enqueue.cc b/src/enqueue.cc index beda2feaf2..6f1d5c2277 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -2218,6 +2218,15 @@ rccl_static ncclResult_t getAlgoInfo( NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, simInfo)); } else { NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, simInfo)); + //override algo, tree doesn't work with fewer than 64 bytes + static int userAlgoInput = -2; + const char *algoStr = getenv("NCCL_ALGO"); + userAlgoInput = !algoStr ? 0 : 1; + size_t sizePerRank = rcclGetSizePerRank(info->func, nBytes, comm->nRanks); + if (!userAlgoInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllReduce) && sizePerRank >= 64 && sizePerRank <= 262144){ + info->algorithm = NCCL_ALGO_TREE; + info->protocol = NCCL_PROTO_LL; + } // NCCL_CTA_POLICY_EFFICIENCY requires user (non-symmetric) buffer registration (currently unsupported with MNNVL) if (comm->config.CTAPolicy == NCCL_CTA_POLICY_EFFICIENCY && ncclGetEnv("NCCL_ALGO") == NULL && ncclGetEnv("NCCL_PROTO") == NULL && !comm->MNNVL) { // make algorithm selection based on buffer registration diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index 973b61b458..6f9117304b 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -336,7 +336,7 @@ static struct tuningModel tuning_model_5 { { /* Tree (LL/LL128/Simple)*/ { 0.06, 0.06, 0.59 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 0.08, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 0} }, }, - .treeCorrectionFactor = { + .treeCorrectionFactor = { /*16M 32M 64M 128M 256M 512M 1G 2G 4G */ { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.7, 0.5, 0.6, 0.6, 0.6, }, @@ -345,7 +345,7 @@ static struct tuningModel tuning_model_5 { .ringCorrectionFactor = { { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 1.0, 1.0, 1.0, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 1.0, 1.0, 1.0, }, }, // Follow order in RcclTunableColls .llProtoRanges = { @@ -354,7 +354,7 @@ static struct tuningModel tuning_model_5 { /*AllGather*/ {/*LL (min/max/factor/thread_threshold)*/ {0, 98304, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {98304, 5046272, 1, 64}}, /*AllReduce*/ - {/*LL (min/max/factor/thread_threshold)*/ {0, 1048576, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {1048576, 9437184, 3145728, 0}}, + {/*LL (min/max/factor/thread_threshold)*/ {0, 524288, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {524288, 4415057, 3145728, 0}}, /*Reduce*/ {/*LL (min/max/factor/thread_threshold)*/ {0, 4096, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {4096, 16777216, 1, 0}}, /*Broadcast*/ @@ -382,25 +382,25 @@ static struct tuningModel tuning_model_6 { { /* Tree (LL/LL128/Simple)*/ { 0.06, 0.06, 0.59 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 0.08, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 } }, }, - .treeCorrectionFactor = { /*16 32 64M 128M 1G*/ + .treeCorrectionFactor = { /*16M 32M 64M 128M 256M 512M 1G 2G 4G */ { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 0.1, 0.1, 0.1, 0.9, 0.9, 0.7, 0.8, 0.6, 0.6, 0.6, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.1, 0.9, 0.9, 0.1, 0.1, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 0.15, 0.6, 0.1, 0.6, }, }, .ringCorrectionFactor = { { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 0.8, 1.0, 1.0, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 5.5, 0.1, 0.1, 1.0, 1.0, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 1.0, 0.1, 0.6, 1.0, 1.0, }, }, // Follow order in RcclTunableColls .llProtoRanges = { /*ReduceScatter*/ - {/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 4194304, 1, 64}}, + {/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}}, /*AllGather*/ - {/*LL (min/max/factor/thread_threshold)*/ {0, 32768, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {32768, 8388608, 1, 64}}, + {/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}}, /*AllReduce*/ - {/*LL (min/max/factor/thread_threshold)*/ {0, 262144, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {262144, 17660227, 3145728, 0}}, + {/*LL (min/max/factor/thread_threshold)*/ {0, 262144, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {262144, 70640910, 3145728, 0}}, /*Reduce*/ {/*LL (min/max/factor/thread_threshold)*/ {0, 16383, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {16383, 16777216, 1, 0}}, /*Broadcast*/ diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index 4031ae352d..5d700ace59 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -43,10 +43,11 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s const char *protoStr = getenv("NCCL_PROTO"); userProtocolInput = !protoStr ? 0 : 1; } + if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllGather) && sizePerRank <= 88448) { // Change LL protocol threshold info->protocol = NCCL_PROTO_LL; - } else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 175488) { + } else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 1048576) { // Change LL protocol threshold info->protocol = NCCL_PROTO_LL; } else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 352128) { @@ -352,13 +353,13 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) { bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) { size_t threshold = rcclParamDirectAllGatherThreshold(); - if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) { - if (comm->nNodes == 1 && threshold != -1) { + if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && threshold != -1) { + if (comm->nNodes == 1) { threshold = 8388608; - } else if (comm->nNodes < 64 && threshold != -1) { + } else if (comm->nNodes < 64) { threshold = comm->nNodes * 2097152; } - } else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) { + } else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && threshold != -1) { threshold = 4194304; }