Update tuning parameters (#518)
* Update tuning parameters * Respect user algo and topo selections
This commit is contained in:
gecommit door
GitHub
bovenliggende
2d558c9abc
commit
7cbbca4da1
@@ -426,6 +426,26 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
|
||||
if (info->coll == ncclFuncAllToAllPivot) {
|
||||
int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2;
|
||||
info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings;
|
||||
} else if (comm->topo->nodes[GPU].nodes[0].gpu.gcn == 910 && comm->nChannels == 32 && comm->nRanks/comm->nNodes == 16 && info->nBytes >= 268435456
|
||||
&& ((comm->nNodes > 2 && info->nBytes <= 2147483648) || (comm->nNodes == 2 && info->nBytes <= 1073741824))) {
|
||||
static int userTuneInput = -2;
|
||||
if (userTuneInput == -2) {
|
||||
const char *protoStr = getenv("NCCL_PROTO");
|
||||
const char *algoStr = getenv("NCCL_ALGO");
|
||||
if (!protoStr && !algoStr)
|
||||
userTuneInput = 0;
|
||||
else
|
||||
userTuneInput = 1;
|
||||
}
|
||||
if (userTuneInput) {
|
||||
// always respect user settings
|
||||
info->nChannels = nc;
|
||||
} else {
|
||||
// use ring simple with reduced channels on gfx90a for specific data sizes
|
||||
info->protocol = NCCL_PROTO_SIMPLE;
|
||||
info->algorithm = NCCL_ALGO_RING;
|
||||
info->nChannels = nc/2;
|
||||
}
|
||||
} else {
|
||||
info->nChannels = nc;
|
||||
}
|
||||
|
||||
+18
-18
@@ -63,11 +63,11 @@ static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 12.0
|
||||
// Tree/Simple is the latency of a 256kB chunk, which is ~ base lat + 256k/12GB/s (+ 256k/12GB/s for the network).
|
||||
static float hwLat [3][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] =
|
||||
{ /* NVLINK */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNet (LL/LL128/Simple)*/ { 1.2, 1.2, 3.8 } },
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNet (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 } },
|
||||
/* PCI */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 } },
|
||||
/* NET */
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 40.0, 40.0, 50.0 }, /* Ring (LL/LL128/Simple)*/ { 4.0, 4.0, 25.0 }, /* CollNet (LL/LL128/Simple)*/ { 9.8, 9.8, 19.5 } }
|
||||
{ /* Tree (LL/LL128/Simple)*/ { 33.0, 33.0, 15.8 }, /* Ring (LL/LL128/Simple)*/ { 5.1, 5.1, 68.8 }, /* CollNet (LL/LL128/Simple)*/ { 33.0, 33.0, 15.8 } }
|
||||
};
|
||||
|
||||
// LL128 max BW per channel
|
||||
@@ -134,10 +134,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
|
||||
// Various model refinements
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if (a == NCCL_ALGO_RING && (p == NCCL_PROTO_LL || p == NCCL_PROTO_LL128)) busBw *= 0.05;
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_SIMPLE) (nNodes == 2) ? busBw *= 0.33 : busBw *= 0.11;
|
||||
if (a == NCCL_ALGO_TREE && (p == NCCL_PROTO_LL || p == NCCL_PROTO_LL128)) busBw *= 0.04;
|
||||
if (gcn == 910 && a == NCCL_ALGO_TREE && p == NCCL_PROTO_SIMPLE && nNodes == 2 && nRanks == 32) busBw *= 3.61;
|
||||
if (a == NCCL_ALGO_RING && (p == NCCL_PROTO_LL || p == NCCL_PROTO_LL128)) busBw *= 0.20;
|
||||
if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_SIMPLE) (nNodes == 2) ? busBw *= 0.99 : busBw *= 0.42;
|
||||
if (a == NCCL_ALGO_TREE && (p == NCCL_PROTO_LL || p == NCCL_PROTO_LL128)) busBw *= 0.15;
|
||||
if (gcn == 910 && a == NCCL_ALGO_TREE && p == NCCL_PROTO_SIMPLE && nNodes == 2 && nRanks == 16) busBw *= 6.5;
|
||||
#else
|
||||
if (compCap80) busBw = std::min(busBw, 235.0f);
|
||||
@@ -282,18 +281,19 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
|
||||
|
||||
// Trees are not perfectly sticking to the model for medium sizes. Applying a static correction
|
||||
// factor is not ideal but works quite well. Powers of two, 64 B to 128MB.
|
||||
static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][25] = {
|
||||
{ 0.7, 0.7, 0.7, 0.6, 0.6, 0.3, 0.9, 0.5, 0.5, 0.6, 0.5, 0.5, 0.8, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.9, 0.9, 1.0, 1.0, 1.0, 1.0, },
|
||||
{ 0.7, 0.7, 0.7, 0.6, 0.6, 0.3, 0.9, 0.5, 0.5, 0.6, 0.5, 0.5, 0.8, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.9, 0.9, 1.0, 1.0, 1.0, 1.0, },
|
||||
{ 0.4, 0.4, 0.3, 0.3, 0.2, 0.5, 0.5, 0.7, 0.2, 0.2, 0.3, 0.6, 0.7, 1.0, 1.3, 1.0, 1.2, 1.2, 1.1, 1.1, 1.2, 1.2, 1.5, 1.7, 2.4, },
|
||||
static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][27] = {
|
||||
{ 0.5, 0.4, 0.7, 0.6, 1.0, 1.0, 0.5, 0.4, 0.1, 0.5, 0.4, 0.6, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.6, 0.5, 0.4, 0.4, 0.3, 0.2, 0.1, 0.1, 0.1, },
|
||||
{ 0.5, 0.4, 0.7, 0.6, 1.0, 1.0, 0.5, 0.4, 0.1, 0.5, 0.4, 0.6, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.6, 0.5, 0.4, 0.4, 0.3, 0.2, 0.1, 0.1, 0.1, },
|
||||
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.4, 0.5, 0.1, 0.6, 1.0, 1.0, 1.0, 0.6, 0.5, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.5, 0.3, 0.3, },
|
||||
};
|
||||
|
||||
static float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][25] = {
|
||||
{ 0.4, 0.6, 0.6, 0.3, 0.2, 0.2, 0.2, 0.2, 0.4, 0.6, 0.7, 0.9, 1.4, 1.5, 1.0, 0.8, 0.7, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, },
|
||||
{ 0.4, 0.6, 0.6, 0.3, 0.2, 0.2, 0.2, 0.2, 0.4, 0.6, 0.7, 0.9, 1.4, 1.5, 1.0, 0.8, 0.7, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, },
|
||||
{ 0.6, 0.4, 0.4, 0.4, 0.2, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.6, 0.8, 0.9, 1.1, 2.0, 2.9, },
|
||||
static float ringCorrectionFactor[NCCL_NUM_PROTOCOLS][27] = {
|
||||
{ 1.0, 0.5, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0, 0.2, 1.0, 0.9, 0.7, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.7, 0.6, 0.5, 0.5, 0.3, 0.2, 0.1, 0.1, },
|
||||
{ 1.0, 0.5, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0, 0.2, 1.0, 0.9, 0.7, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.7, 0.6, 0.5, 0.5, 0.3, 0.2, 0.1, 0.1, },
|
||||
{ 0.3, 1.0, 0.3, 0.1, 0.1, 0.1, 0.3, 0.7, 1.0, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.5, 0.9, 1.0, 1.0, 1.0, 1.0, },
|
||||
};
|
||||
|
||||
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time) {
|
||||
float bw = info->comm->bandwidths[info->coll][algorithm][protocol];
|
||||
float lat = info->comm->latencies[info->coll][algorithm][protocol];
|
||||
@@ -304,12 +304,12 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
if (algorithm == NCCL_ALGO_TREE) {
|
||||
if (logSize < 25) bw *= treeCorrectionFactor[protocol][logSize];
|
||||
else bw *= treeCorrectionFactor[protocol][24];
|
||||
if (logSize < 27) bw *= treeCorrectionFactor[protocol][logSize];
|
||||
else bw *= treeCorrectionFactor[protocol][26];
|
||||
}
|
||||
else if (algorithm == NCCL_ALGO_RING) {
|
||||
if(logSize < 25) bw *= ringCorrectionFactor[protocol][logSize];
|
||||
else bw *= ringCorrectionFactor[protocol][24];
|
||||
if(logSize < 27) bw *= ringCorrectionFactor[protocol][logSize];
|
||||
else bw *= ringCorrectionFactor[protocol][26];
|
||||
}
|
||||
#else
|
||||
if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize];
|
||||
|
||||
Verwijs in nieuw issue
Block a user