Navi4 LL enablement and tuning (#2095)

* LL enablement for gfx1201

* Single node LL/Simple tuning

* multinode algo/proto default choice

* First iteration of Table tuning

* gfx942 tuning table correction

* Addressing PR comments and prefix match fix


[ROCm/rccl commit: 9545ae04b2]
Этот коммит содержится в:
Avinash
2026-01-05 10:17:12 -06:00
коммит произвёл GitHub
родитель 57f81914d8
Коммит de23e1db6d
3 изменённых файлов: 83 добавлений и 6 удалений
+10 -3
Просмотреть файл
@@ -2060,6 +2060,14 @@ static ncclResult_t topoGetAlgoInfo(
}
}
}
if(algorithm == NCCL_ALGO_UNDEF){
INFO(NCCL_INIT,"Optimal algorithm is not found in collCostTable, Setting it a default value NCCL_ALGO_RING");
algorithm = NCCL_ALGO_RING;
}
if(protocol == NCCL_PROTO_UNDEF){
INFO(NCCL_INIT,"Optimal protocol is not found in collCostTable, Setting it a default value NCCL_PROTO_SIMPLE");
protocol = NCCL_PROTO_SIMPLE;
}
info->algorithm = algorithm;
info->protocol = protocol;
@@ -2119,9 +2127,8 @@ static ncclResult_t topoGetAlgoInfo(
int minNChannels = ncclParamMinNchannels();
// Ring/Tree channel tuning
INFO(NCCL_INIT, "minNChannels:%i", minNChannels);
while (nBytes < nc * nt * threadThreshold && nc > minNChannels) {
if (nc >= 2) nc--;
else break;
if(nBytes < nc * nt * threadThreshold && nc > minNChannels){
nc = std::max(1,std::max(minNChannels,(int)(nBytes/std::max(1,nt * threadThreshold))));
}
INFO(NCCL_INIT, "post-adjustment based on threadThreshold:%i nBytes:%lu nc:%i", threadThreshold, nBytes, nc);
rcclOverrideChannels(comm, info->func, nBytes, nc);
+53 -3
Просмотреть файл
@@ -423,6 +423,56 @@ static struct tuningModel tuning_model_6 {
},
};
// Tuning model #7: one entry of the rcclTuningModel[] table, which ncclTopoTuneModel
// indexes with comm->topo->tuning.
// NOTE(review): presumably the gfx12 (Navi4) model — confirm where topo->tuning is assigned.
static struct tuningModel tuning_model_7 {
// Per-link-type latency terms. Rows: NVLINK, PCI, NET. Each row holds one
// {LL, LL128, Simple} triple per algorithm, in the order given by the inline labels.
// NOTE(review): units are not visible here (presumably microseconds) — confirm.
.hwLat = {
/* NVLINK */
{ /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 0.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 1.4 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 3.6} },
/* PCI */
{ /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 5.7} },
/* NET */
{ /* Tree (LL/LL128/Simple)*/ { 11.8, 18.2, 20.8 }, /* Ring (LL/LL128/Simple)*/ { 9.5, 19.8, 15.1 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 11.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 18.2 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 15.1} },
},
// Bandwidth ratios per algorithm/protocol, with one row for exactly 2 nodes and one
// for more than 2 nodes. The two rows are identical except for the PAT Simple entry
// (0 vs 1.00).
// NOTE(review): confirm the 0 for PAT in the 2-node row is intentional.
.bwRatio = {
/* 2 nodes */
{ /* Tree (LL/LL128/Simple)*/ { 0.051, 0.22, 0.64 }, /* Ring (LL/LL128/Simple)*/ { 0.74, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 0} },
/* more than 2 nodes */
{ /* Tree (LL/LL128/Simple)*/ { 0.051, 0.22, 0.64 }, /* Ring (LL/LL128/Simple)*/ { 0.74, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 }, /* NVLS Tree */ { 0, 0, 0 }, /* PAT */ { 0, 0, 1.00} },
},
// Tree correction factors: 3 rows of 27 values.
// NOTE(review): presumably one row per protocol (LL, LL128, Simple) and one column per
// message-size bucket — confirm against the code that reads treeCorrectionFactor.
.treeCorrectionFactor = {
{ 0.1, 0.2, 0.1, 0.1, 0.9, 0.3, 0.4, 0.1, 0.2, 0.4, 0.2, 0.1, 0.3, 0.3, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.3, 1.0, 0.1, 0.5, 1.0, 0.9, 1.0, 1.0, 1.0, 0.3, 0.1, 0.4, 0.5, 0.5, 0.4, 0.4, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, },
{ 0.2, 1.0, 0.1, 0.1, 0.7, 0.2, 0.4, 0.1, 0.1, 0.3, 0.4, 0.3, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, },
},
// Ring correction factors: same 3 x 27 layout as treeCorrectionFactor.
.ringCorrectionFactor = {
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.4, 0.2, 0.3, 0.5, 0.3, 0.1, 0.5, 0.5, 0.3, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.7, 0.5, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, },
{ 1.0, 0.8, 0.2, 1.0, 1.0, 0.3, 1.0, 0.1, 0.1, 0.2, 0.2, 0.1, 0.5, 1.0, 0.8, 0.8, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, },
},
// Per-collective size ranges for the LL and LL64/128 protocols, each entry being
// {min, max, factor, thread_threshold}. ncclTopoTuneModel copies this table into
// comm->minMaxLLRange.
// Follow order in RcclTunableColls
.llProtoRanges = {
/*ReduceScatter*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}},
/*AllGather*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 65536, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {65536, 8388608, 1, 64}},
/*AllReduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 1048576, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {1048576, 70640910, 3145728, 0}},
/*Reduce*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 16383, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {16383, 16777216, 1, 0}},
/*Broadcast*/
{/*LL (min/max/factor/thread_threshold)*/ {0, 2048, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {2048, 16777216, 1, 0}},
},
.channelThresholds = {
// For each collective, a list of {min, max, channels} entries giving the per-rank size
// range for 2, 4, 8, 16, 32, 40, 48, 56 and 64 channels (third value of each entry).
// NOTE(review): the {1,1, 48} entries have min == max, which presumably leaves the
// 48-channel step unused, and the all-zero AllReduce row presumably means no
// size-based channel thresholds for AllReduce — confirm against the consumer.
/*ReduceScatter*/ {{512, 1024, 2},{1024, 2048, 4},{2048, 4096, 8},{4096, 65536, 16}, {65536, 262144, 32}, {262144, 524288, 40}, {1,1, 48}, {524288, 1048576, 56}, {1048576, 268435457, 64}},
/*AllGather*/ {{2048, 4096, 2},{4096, 8192, 4},{8192, 16384, 8},{16384, 262144, 16},{262144, 524288, 32}, {524288, 1048576, 40}, {1,1, 48}, {1048576, 4194304, 56}, {4194304, 268435457, 64}},
/*AllReduce*/ {{0,0,0},{0,0,0},{0,0,0},{0,0,0},{0,0,0}, {0,0,0}, {0,0,0}, {0,0,0}, {0,0,0}},
},
};
static struct tuningModel rcclTuningModel[] = {
tuning_model_0,
tuning_model_1,
@@ -431,6 +481,7 @@ static struct tuningModel rcclTuningModel[] = {
tuning_model_4,
tuning_model_5,
tuning_model_6,
tuning_model_7,
};
/* Array indexes used below */
@@ -542,7 +593,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS];
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) intraHw[a] = graphs[a]->typeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI;
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) hw[a] = nNodes == 1 ? intraHw[a] : NCCL_HW_NET;
INFO(NCCL_INIT,"RCCL Tuning index:%d",comm->topo->tuning);
memcpy(comm->minMaxLLRange,
rcclTuningModel[comm->topo->tuning].llProtoRanges,
sizeof(rcclTuningModel[comm->topo->tuning].llProtoRanges));
@@ -766,8 +817,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
}
for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
// Disable LL protocol on gfx12xx
int pEnable = (p == NCCL_PROTO_LL && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx12")) ? 0 : protoEnable[c*NCCL_NUM_PROTOCOLS+p];
int pEnable = protoEnable[c*NCCL_NUM_PROTOCOLS+p];
if (pEnable != 0 && p == NCCL_PROTO_LL128) {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#if defined(ENABLE_LL128)
+20
Просмотреть файл
@@ -51,6 +51,24 @@ void rcclRestrictMaxChannels(struct ncclComm* comm, int& nc ) {
}
}
/**
 * Choose between the LL and Simple protocols from a per-collective size cutoff.
 *
 * The only visible caller uses this for single-node communicators whose first GPU
 * matches the "gfx12" arch prefix, and only when the user did not force a protocol.
 *
 * @param collectiveFunc  Collective being launched; its integer value indexes the
 *                        cutoff table below.
 * @param sizePerRank     Per-rank message size as computed by the caller
 *                        (presumably bytes — confirm against the caller).
 * @return NCCL_PROTO_LL when sizePerRank is at or below the cutoff for the collective,
 *         otherwise NCCL_PROTO_SIMPLE. A collectiveFunc outside the table yields
 *         NCCL_PROTO_SIMPLE.
 */
int32_t rcclGetProtoForGfx12(ncclFunc_t collectiveFunc, size_t sizePerRank){
  // Inclusive upper bounds on sizePerRank for choosing LL, indexed by the integer value
  // of ncclFunc_t. The table is static constexpr so it is not rebuilt on every call, and
  // it is typed as size_t so the comparison with sizePerRank does not mix signedness.
  // NOTE(review): the entry order must mirror the ncclFunc_t enumerator order — confirm
  // whenever that enum changes.
  // A cutoff of 0 selects LL only for a zero-sized message.
  static constexpr size_t kSingleNodeLLCutoffs[] = {
    /*ncclFuncBroadcast*/     1536,
    /*ncclFuncReduce*/        8192,
    /*ncclFuncAllGather*/     98304,
    /*ncclFuncReduceScatter*/ 98304,
    /*ncclFuncAllReduce*/     913532,
    /*ncclFuncSendRecv*/      0,
    /*ncclFuncSend*/          0,
    /*ncclFuncRecv*/          0
  };
  // Element count derived from the array itself, so it stays correct if the element
  // type is ever changed.
  constexpr size_t kNumCutoffs = sizeof(kSingleNodeLLCutoffs) / sizeof(kSingleNodeLLCutoffs[0]);

  // A negative enumerator value wraps to a large unsigned index and is rejected by the
  // bounds check, so any out-of-table value falls back to Simple.
  const size_t funcIndex = static_cast<size_t>(collectiveFunc);
  if (funcIndex >= kNumCutoffs) {
    return NCCL_PROTO_SIMPLE;
  }

  return (sizePerRank <= kSingleNodeLLCutoffs[funcIndex]) ? NCCL_PROTO_LL : NCCL_PROTO_SIMPLE;
}
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
// Honor user input for protocol choice
static int userProtocolInput = -2;
@@ -69,6 +87,8 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && comm->nNodes == 1 && (info->func == ncclFuncReduceScatter) && sizePerRank <= 352128) {
// Change LL protocol threshold
info->protocol = NCCL_PROTO_LL;
} else if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx12") && comm->nNodes == 1){
info->protocol = rcclGetProtoForGfx12( info->func,sizePerRank);
} else if(!userProtocolInput && comm->nNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather || info->func == ncclFuncAllReduce || info->func == ncclFuncBroadcast || info->func == ncclFuncReduce)) {
auto tunableIndex = rcclGetTunableIndex(info->func);
auto llMin = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL][RCCL_PROTOCOL_MIN_IDX];