Post thread-block size increase tuning (#2042)

* for multinode gfx950, extend AR LL128 up to 256MB, extend RS LL128 up to 8MB per rank, extend AG LL up to 64KB per rank

* dont override direct allgather threshold if set to -1

* restore 2-node AR simple at earlier message sizes than higher multi-node AR

* extend range of LL for single-node RS on gfx950

* update algo/proto for multi-node allreduce on gfx942

* set single-node AR on gfx950 to Tree LL for KB message sizes

* decrease threshold for single node Tree for gfx950 AR
Bu işleme şunda yer alıyor:
isaki001
2025-11-13 14:51:04 -06:00
işlemeyi yapan: GitHub
ebeveyn 83ffc82fa7
işleme 0d09f86608
3 değiştirilmiş dosya ile 26 ekleme ve 16 silme
+9
Dosyayı Görüntüle
@@ -2218,6 +2218,15 @@ rccl_static ncclResult_t getAlgoInfo(
NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, simInfo));
} else {
NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, simInfo));
//override algo, tree doesn't work with fewer than 64 bytes
static int userAlgoInput = -2;
const char *algoStr = getenv("NCCL_ALGO");
userAlgoInput = !algoStr ? 0 : 1;
size_t sizePerRank = rcclGetSizePerRank(info->func, nBytes, comm->nRanks);
if (!userAlgoInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllReduce) && sizePerRank >= 64 && sizePerRank <= 262144){
info->algorithm = NCCL_ALGO_TREE;
info->protocol = NCCL_PROTO_LL;
}
// NCCL_CTA_POLICY_EFFICIENCY requires user (non-symmetric) buffer registration (currently unsupported with MNNVL)
if (comm->config.CTAPolicy == NCCL_CTA_POLICY_EFFICIENCY && ncclGetEnv("NCCL_ALGO") == NULL && ncclGetEnv("NCCL_PROTO") == NULL && !comm->MNNVL) {
// make algorithm selection based on buffer registration
+11 -11
Dosyayı Görüntüle
@@ -336,7 +336,7 @@ static struct tuningModel tuning_model_5 {
{ /* Tree (LL/LL128/Simple)*/ { 0.06, 0.06, 0.59 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 0.08, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 0} },
},
.treeCorrectionFactor = {
.treeCorrectionFactor = { /*16M 32M 64M 128M 256M 512M 1G 2G 4G */
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.7, 0.5, 0.6, 0.6, 0.6, },
@@ -345,7 +345,7 @@ static struct tuningModel tuning_model_5 {
.ringCorrectionFactor = {
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 1.0, 1.0, 1.0, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 1.0, 1.0, 1.0, },
},
// Follow order in RcclTunableColls
.llProtoRanges = {
@@ -354,7 +354,7 @@ static struct tuningModel tuning_model_5 {
/*AllGather*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 98304, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {98304, 5046272, 1, 64}},
/*AllReduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 1048576, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {1048576, 9437184, 3145728, 0}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 524288, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {524288, 4415057, 3145728, 0}},
/*Reduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 4096, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {4096, 16777216, 1, 0}},
/*Broadcast*/
@@ -382,25 +382,25 @@ static struct tuningModel tuning_model_6 {
{ /* Tree (LL/LL128/Simple)*/ { 0.06, 0.06, 0.59 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 0.08, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 } },
},
.treeCorrectionFactor = { /*16 32 64M 128M 1G*/
.treeCorrectionFactor = { /*16M 32M 64M 128M 256M 512M 1G 2G 4G */
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 0.1, 0.1, 0.1, 0.9, 0.9, 0.7, 0.8, 0.6, 0.6, 0.6, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.1, 0.9, 0.9, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.7, 0.15, 0.6, 0.1, 0.6, },
},
.ringCorrectionFactor = {
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 0.8, 1.0, 1.0, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 1.0, 0.4, 0.4, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 5.5, 0.1, 0.1, 1.0, 1.0, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 1.0, 0.1, 0.6, 1.0, 1.0, },
},
// Follow order in RcclTunableColls
.llProtoRanges = {
/*ReduceScatter*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 4194304, 1, 64}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}},
/*AllGather*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 32768, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {32768, 8388608, 1, 64}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}},
/*AllReduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 262144, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {262144, 17660227, 3145728, 0}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 262144, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {262144, 70640910, 3145728, 0}},
/*Reduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 16383, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {16383, 16777216, 1, 0}},
/*Broadcast*/
+6 -5
Dosyayı Görüntüle
@@ -43,10 +43,11 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
const char *protoStr = getenv("NCCL_PROTO");
userProtocolInput = !protoStr ? 0 : 1;
}
if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllGather) && sizePerRank <= 88448) {
// Change LL protocol threshold
info->protocol = NCCL_PROTO_LL;
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 175488) {
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 1048576) {
// Change LL protocol threshold
info->protocol = NCCL_PROTO_LL;
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 352128) {
@@ -352,13 +353,13 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
size_t threshold = rcclParamDirectAllGatherThreshold();
if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) {
if (comm->nNodes == 1 && threshold != -1) {
if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && threshold != -1) {
if (comm->nNodes == 1) {
threshold = 8388608;
} else if (comm->nNodes < 64 && threshold != -1) {
} else if (comm->nNodes < 64) {
threshold = comm->nNodes * 2097152;
}
} else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) {
} else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && threshold != -1) {
threshold = 4194304;
}