@@ -235,6 +235,12 @@ static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = {
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .41, .27, .25, .39, .46, .72, .76, .87, .92, .97, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 }
|
||||
};
|
||||
|
||||
static float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][22] = {
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .25, .41, .55, .56, .78, .94, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 },
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .25, .41, .55, .56, .78, .94, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 },
|
||||
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .04, .08, .09, .09, .11, .13, .25, .40, .59, .76, .86, 1.0 , 1.0 , 1.0 , 1.0 , 1.0 }
|
||||
};
|
||||
|
||||
static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
|
||||
struct ncclComm* comm = info->comm;
|
||||
float minTime = 3600000.0; // Hopefully no operation will take an hour to complete.
|
||||
@@ -247,6 +253,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info) {
|
||||
if (bw == 0) continue;
|
||||
int logSize = log2i(info->nBytes>>6);
|
||||
if (a == NCCL_ALGO_TREE && logSize < 22) bw *= treeCorrectionFactor[p][logSize];
|
||||
else if (a == NCCL_ALGO_RING && logSize < 22) bw *= ringCorrectionFactor[p][logSize];
|
||||
float time = comm->latencies[info->coll][a][p] + (info->nBytes) / (1000 * bw);
|
||||
if (time < minTime) {
|
||||
info->algorithm = a;
|
||||
|
||||
Reference in New Issue
Block a user