diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index e501300fcc..9ac1832e3d 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -235,6 +235,12 @@ static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = { { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .41, .27, .25, .39, .46, .72, .76, .87, .92, .97, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 } }; +static float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = { + { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .25, .41, .55, .56, .78, .94, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 }, + { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .25, .41, .55, .56, .78, .94, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 }, + { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .04, .08, .09, .09, .11, .13, .25, .40, .59, .76, .86, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 } +}; + static ncclResult_t getAlgoInfo(struct ncclInfo* info) { struct ncclComm* comm = info->comm; float minTime = 3600000.0; // Hopefully no operation will take an hour to complete. @@ -247,6 +253,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) { if (bw == 0) continue; int logSize = log2i(info->nBytes>>6); if (a == NCCL_ALGO_TREE && logSize < 22) bw *= treeCorrectionFactor[p][logSize]; + else if (a == NCCL_ALGO_RING && logSize < 22) bw *= ringCorrectionFactor[p][logSize]; float time = comm->latencies[info->coll][a][p] + (info->nBytes) / (1000 * bw); if (time < minTime) { info->algorithm = a;