diff --git a/projects/rccl/src/graph/tuning.cc b/projects/rccl/src/graph/tuning.cc index 6f9117304b..1a6b136499 100644 --- a/projects/rccl/src/graph/tuning.cc +++ b/projects/rccl/src/graph/tuning.cc @@ -888,14 +888,16 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclComm* comm, int coll, int algorithm, int logSize = log2i(nBytes>>6); #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + if (logSize < 0) logSize = 0; + if (logSize > 26) logSize = 26; + if (algorithm == NCCL_ALGO_TREE) { - if (logSize < 27) bw *= rcclTuningModel[comm->topo->tuning].treeCorrectionFactor[protocol][logSize]; - else bw *= rcclTuningModel[comm->topo->tuning].treeCorrectionFactor[protocol][26]; + bw *= rcclTuningModel[comm->topo->tuning].treeCorrectionFactor[protocol][logSize]; } else if (algorithm == NCCL_ALGO_RING && comm->nNodes > 1) { - if(logSize < 27) bw *= rcclTuningModel[comm->topo->tuning].ringCorrectionFactor[protocol][logSize]; - else bw *= rcclTuningModel[comm->topo->tuning].ringCorrectionFactor[protocol][26]; + bw *= rcclTuningModel[comm->topo->tuning].ringCorrectionFactor[protocol][logSize]; } + #else if (algorithm == NCCL_ALGO_TREE && coll == ncclFuncAllReduce && logSize >= 0 && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize]; if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && comm->nNodes > 1