diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index c92ce89d33..2ee9dc0f22 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -205,7 +205,6 @@ namespace { /* LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))/8); const ssize_t loopSize = int(nChannels*chunkSize); const ssize_t size = args->count; - int nthreadsSplit; if (Proto::Id == NCCL_PROTO_SIMPLE) { nthreadsSplit = nthreads/2; @@ -400,7 +399,8 @@ struct RunWorkElement struct RunWorkElement { __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { - runTreeUpDown(args); + if (args->pad_0 == 0) runTreeUpDown(args); + else runTreeSplit(args); } }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 7e8ffb8b64..ef76aac9cc 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -451,6 +451,31 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i if (info->coll == ncclFuncAllToAllPivot) { int pivotA2ANumUniRings = comm->topo->pivotA2ANumBiRings * 2; info->nChannels = comm->nChannels / pivotA2ANumUniRings * pivotA2ANumUniRings; + } else if (info->coll == ncclFuncAllReduce && comm->topo->pivotA2ANumBiRings == 3) { + static int userTuneInput = -2; + if (userTuneInput == -2) { + const char *protoStr = getenv("NCCL_PROTO"); + const char *algoStr = getenv("NCCL_ALGO"); + if (!protoStr && !algoStr) + userTuneInput = 0; + else + userTuneInput = 1; + } + info->nChannels = nc; + if (!userTuneInput) { + // always respect user settings + if (info->nBytes <= 196608) { + info->protocol = NCCL_PROTO_LL; + info->algorithm = NCCL_ALGO_TREE; + info->nChannels = std::min(comm->nChannels, info->nBytes <= 65536? 4 : 12); + } else if (info->nBytes <= 1048576) { + info->protocol = NCCL_PROTO_LL; + info->algorithm = NCCL_ALGO_RING; + } else { + info->protocol = NCCL_PROTO_SIMPLE; + info->algorithm = NCCL_ALGO_RING; + } + } } else if (comm->topo->nodes[GPU].nodes[0].gpu.gcn == 910 && comm->topo->tuning == 4 && ((comm->nNodes == 2 && info->nBytes == 33554432) || (comm->nNodes <= 4 && info->nBytes == 67108864))) { static int userTuneInput = -2; @@ -532,7 +557,7 @@ comp_next: // Set nstepsPerLoop and nchunksPerLoop NCCLCHECK(getPatternInfo(info)); NCCLCHECK(getLoopInfo(info)); - + if (info->comm->topo->pivotA2ANumBiRings == 3 ) work->pad_0 = 1; work->opCount = info->opCount; work->header.type = ncclWorkTypeColl; work->sendbuff = info->sendbuff; @@ -543,7 +568,6 @@ comp_next: work->header.nWarps = info->nThreads / info->comm->WarpSize; work->redOpArg = info->opFull.scalarArg; work->redOpArgIsPtr = info->opFull.scalarArgIsPtr; - if (info->comm->nRanks == 1) { // one-rank reduce index work->header.funcIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype); diff --git a/src/graph/connect.cc b/src/graph/connect.cc index 0c3ba56629..2c211df88c 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -66,6 +66,88 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, return ncclSuccess; } +ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, + struct ncclTopoGraph* treeGraph) { + // int rank = comm->rank; + // int localRanks = comm->topo->nodes[GPU].count; + int nChannels = comm->nChannels; + int ring[NCCL_TOPO_MAX_NODES][NCCL_TOPO_MAX_NODES]; + int xLimit = 0, yLimit = 0; + for (int j=0; j < NCCL_TOPO_MAX_NODES; j++) { + if (treeGraph->treeBase[0][j] == -1) { + xLimit = j; + break; + } + } + for (int k=0; k < NCCL_TOPO_MAX_NODES; k++) { + if (treeGraph->treeBase[k][0] == -1) { + yLimit = k; + break; + } + } + for (int j=0; j < xLimit; j++) { + for (int k=0; k < yLimit; k++) + ring[k][j] = treeGraph->treeBase[k][j]; + } + + //new tree + for (int c=0; cchannels+c; + int ringPrev, ringNext; + int treeRoot = c%comm->nRanks; + int curRank = comm->rank; + int curRankNeighborUp, curRankNeighborDown; + int rootRing, nextRing; + int rootIndex = 0; + int arrayIndex; + for (int j=0; j < xLimit; j++) { + for (int k=0; k < yLimit; k++) { + if (treeRoot == ring[k][j]) { + rootRing = k; + rootIndex = j; + } + if (curRank == ring[k][j]) { + arrayIndex = j; + if (k > 0) curRankNeighborUp=ring[k-1][j]; + else curRankNeighborUp=ring[yLimit-1][j]; + if (k < yLimit-1) { + curRankNeighborDown=ring[k+1][j]; + nextRing = k+1; + } + else { + curRankNeighborDown = ring[0][j]; + nextRing = 0; + } + } + } + } + + if ((curRank != ring[rootRing][arrayIndex])) { + channel->tree.up = curRankNeighborUp; + channel->tree.down[0] = nextRing==rootRing ? -1 : curRankNeighborDown; + channel->tree.down[1] = -1; + } + else { + if (arrayIndex > 0) ringPrev=ring[rootRing][arrayIndex-1]; + else ringPrev=ring[rootRing][xLimit-1]; + if (arrayIndex < xLimit-1) ringNext=ring[rootRing][arrayIndex+1]; + else ringNext=ring[rootRing][0]; + if ((c/2)%2 == 1) { + int temp = ringPrev; + ringPrev = ringNext; + ringNext = temp; + } + channel->tree.up = treeRoot == curRank ? -1 : ringPrev; + channel->tree.down[0] = curRankNeighborDown; + channel->tree.down[1] = ringNext == treeRoot ? -1 : ringNext; + } + channel->tree.down[2] = -1; // cleanup + } + + + return ncclSuccess; +} + static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ringSend, int* ringPrev, int* ringNext, int* firstRanks) { int nChannels = comm->nChannels; int nNodes = comm->nNodes; diff --git a/src/graph/rome_models.cc b/src/graph/rome_models.cc index e1c98f6a0e..7c02dd4800 100644 --- a/src/graph/rome_models.cc +++ b/src/graph/rome_models.cc @@ -43,6 +43,7 @@ struct rcclRomeModel { const char *pattern; const char *ringBase; const char *options; + const char *treeBase; }; static struct rcclRomeModel rome_model_22 = { @@ -56,6 +57,7 @@ static struct rcclRomeModel rome_model_22 = { .pattern = "10302120", .ringBase = "7 4 5 3 1 0 6 2|4 7 3 5 0 1 2 6", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_25 = { @@ -69,6 +71,7 @@ static struct rcclRomeModel rome_model_25 = { .pattern = "11303011", .ringBase = "2 1 0 3 6 7 5 4|7 6 4 5 1 2 3 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_27 = { @@ -82,6 +85,7 @@ static struct rcclRomeModel rome_model_27 = { .pattern = "11303011", .ringBase = "0 6 2 3 1 7 5 4|7 1 4 5 6 0 3 2", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_29 = { @@ -95,6 +99,7 @@ static struct rcclRomeModel rome_model_29 = { .pattern = "10302120", .ringBase = "6 5 7 4 0 1 3 2|6 4 7 5 2 3 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_31 = { @@ -108,6 +113,7 @@ static struct rcclRomeModel rome_model_31 = { .pattern = "0110201010200110", .ringBase = "1 2 3 0 6 4 5 7|4 6 7 5 2 1 0 3", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_33 = { @@ -121,6 +127,7 @@ static struct rcclRomeModel rome_model_33 = { .pattern = "0110201010200110", .ringBase = "1 4 5 7 0 3 2 6|4 1 7 5 6 2 3 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_30 = { @@ -134,6 +141,7 @@ static struct rcclRomeModel rome_model_30 = { .pattern = "0010201010200010", .ringBase = "3 0 1 2 6 7 5 4|2 1 0 3 7 6 4 5", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_32 = { @@ -147,6 +155,7 @@ static struct rcclRomeModel rome_model_32 = { .pattern = "0010201010200010", .ringBase = "0 6 2 3 4 5 7 1|3 2 6 0 1 7 5 4", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_24 = { @@ -160,6 +169,7 @@ static struct rcclRomeModel rome_model_24 = { .pattern = "10303010", .ringBase = "0 1 2 3 5 7 6 4|1 0 3 2 7 5 4 6", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_26 = { @@ -173,6 +183,7 @@ static struct rcclRomeModel rome_model_26 = { .pattern = "10303010", .ringBase = "4 5 7 1 0 3 2 6|3 0 6 2 1 7 5 4", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_23 = { @@ -186,6 +197,7 @@ static struct rcclRomeModel rome_model_23 = { .pattern = "10302020", .ringBase = "1 7 6 4 5 2 0 3|2 5 3 0 4 6 7 1", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_38 = { @@ -199,6 +211,7 @@ static struct rcclRomeModel rome_model_38 = { .pattern = "10201000201010", .ringBase = "6 7 1 4 3 5 2 0|0 2 5 3 4 1 7 6", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_28 = { @@ -212,6 +225,7 @@ static struct rcclRomeModel rome_model_28 = { .pattern = "10302020", .ringBase = "0 3 2 1 4 5 6 7|7 6 5 4 1 2 3 0|0 2 5 7 4 6 3 1|1 3 6 4 7 5 2 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_40 = { @@ -225,6 +239,7 @@ static struct rcclRomeModel rome_model_40 = { .pattern = "10302120", .ringBase = "6 7 1 4 0 5 3 2|7 6 4 1 0 2 3 5", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_42 = { @@ -238,6 +253,7 @@ static struct rcclRomeModel rome_model_42 = { .pattern = "10201001201010", .ringBase = "7 4 6 1 3 0 2 5|6 4 7 1 3 2 5 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_44 = { @@ -251,6 +267,7 @@ static struct rcclRomeModel rome_model_44 = { .pattern = "20202120", .ringBase = "5 4 7 6 2 1 3 0|5 6 7 4 1 0 2 3", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_45 = { @@ -264,6 +281,7 @@ static struct rcclRomeModel rome_model_45 = { .pattern = "10201000201010", .ringBase = "0 1 2 3 4 5 6 7|0 2 5 7 4 6 1 3|0 3 1 6 4 7 5 2|0 7 6 5 4 3 2 1", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_46 = { @@ -277,6 +295,7 @@ static struct rcclRomeModel rome_model_46 = { .pattern = "10201001201010", .ringBase = "6 5 7 4 1 2 3 0|7 4 6 5 1 0 3 2", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_48 = { @@ -290,6 +309,7 @@ static struct rcclRomeModel rome_model_48 = { .pattern = "20202020", .ringBase = "0 1 2 3 4 5 6 7|7 6 5 4 3 2 1 0|0 1 2 3 4 5 6 7|7 6 5 4 3 2 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_49 = { @@ -303,6 +323,7 @@ static struct rcclRomeModel rome_model_49 = { .pattern = "21212121", .ringBase = "N0 0 1 2 3 4 5 6 7 N3|N3 7 6 5 4 3 2 1 0 N0|N1 2 3 0 1 6 7 4 5 N2|N2 5 4 7 6 1 0 3 2 N1", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_52 = { @@ -316,6 +337,7 @@ static struct rcclRomeModel rome_model_52 = { .pattern = "80", .ringBase = "0 1 3 2 4 5 7 6|6 7 5 4 2 3 1 0|0 1 5 4 6 7 3 2|2 3 7 6 4 5 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_53 = { @@ -329,6 +351,7 @@ static struct rcclRomeModel rome_model_53 = { .pattern = "21212121", .ringBase = "N0 0 1 2 3 4 5 6 7 N3|N3 7 6 5 4 3 2 1 0 N0|N1 2 3 0 1 6 7 4 5 N2|N2 5 4 7 6 1 0 3 2 N1", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_43 = { @@ -342,6 +365,7 @@ static struct rcclRomeModel rome_model_43 = { .pattern = "20202020", .ringBase = "0 1 2 3 4 5 6 7|0 2 5 7 4 6 1 3|0 3 1 6 4 7 5 2|0 7 6 5 4 3 2 1", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_55 = { @@ -355,6 +379,7 @@ static struct rcclRomeModel rome_model_55 = { .pattern = "20202020", .ringBase = "0 1 2 3 4 5 6 7|7 6 5 4 3 2 1 0|2 3 0 1 6 7 4 5|5 4 7 6 1 0 3 2", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_56 = { @@ -368,6 +393,7 @@ static struct rcclRomeModel rome_model_56 = { .pattern = "40404040", .ringBase = "0 1 3 2 6 7 15 14 10 11 9 8 12 13 5 4|0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4|0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1|4 5 13 12 8 9 11 10 14 15 7 6 2 3 1 0|4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0|1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0", .options = "pivotA2AEnabled=1,pivotA2ANumBiRings=3,tuning=1", + .treeBase = "10 11|14 15|6 7|2 3|0 1|4 5|12 13|8 9", }; static struct rcclRomeModel rome_model_58 = { @@ -381,6 +407,7 @@ static struct rcclRomeModel rome_model_58 = { .pattern = "402020", .ringBase = "0 1 3 2 4 5 7 6|6 7 5 4 2 3 1 0|0 1 5 4 6 7 3 2|2 3 7 6 4 5 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_59 = { @@ -394,6 +421,7 @@ static struct rcclRomeModel rome_model_59 = { .pattern = "42424242", .ringBase = "N4 9 8 12 13 5 4 0 1 3 2 6 7 15 14 10 11 N5|N1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 0 1 N0|N3 7 6 2 3 1 0 4 5 13 12 8 9 11 10 14 15 N7|N7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 6 7 N3|N5 11 10 14 15 7 6 2 3 1 0 4 5 13 12 8 9 N4|N0 1 0 4 5 13 12 8 9 11 10 14 15 7 6 2 3 N1|N3 6 7 3 2 1 0 4 5 14 15 11 10 9 8 12 13 N6|N7 14 15 11 10 9 8 12 13 6 7 3 2 1 0 4 5 N2|N2 5 4 0 1 2 3 7 6 13 12 8 9 10 11 15 14 N7|N6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 7 6 N3|N4 8 9 13 12 4 5 1 0 2 3 7 6 14 15 11 10 N5|N5 10 11 15 14 6 7 3 2 0 1 5 4 12 13 9 8 N4|N6 12 13 9 8 10 11 15 14 6 7 3 2 0 1 5 4 N2|N2 4 5 1 0 2 3 7 6 14 15 11 10 8 9 13 12 N6|N1 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1 0 N0|N0 0 1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 N1|N5 10 11 9 8 12 13 5 4 0 1 3 2 6 7 15 14 N7|N3 6 7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 N1|N1 2 3 1 0 4 5 13 12 8 9 11 10 14 15 7 6 N3|N7 14 15 7 6 2 3 1 0 4 5 13 12 8 9 11 10 N5|N0 0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 N2|N4 8 9 10 11 15 14 5 4 0 1 2 3 7 6 13 12 N6|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N1 3 2 1 0 4 5 14 15 11 10 9 8 12 13 6 7 N3|N6 12 13 6 7 3 2 1 0 4 5 14 15 11 10 9 8 N4|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N0 1 0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 N2|N6 13 12 4 5 1 0 2 3 7 6 14 15 11 10 8 9 N4|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N2 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0 1 N0|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N4 9 8 10 11 15 14 6 7 3 2 0 1 5 4 12 13 N6", .options = "tuning=1", + .treeBase = "", }; static struct rcclRomeModel rome_model_62 = { @@ -407,6 +435,7 @@ static struct rcclRomeModel rome_model_62 = { .pattern = "20202020", .ringBase = "0 1 3 2 4 5 7 6|6 7 5 4 2 3 1 0|0 1 5 4 6 7 3 2|2 3 7 6 4 5 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_63 = { @@ -420,6 +449,7 @@ static struct rcclRomeModel rome_model_63 = { .pattern = "21212121", .ringBase = "N0 0 1 5 4 6 7 3 2 N1|N1 2 3 7 6 4 5 1 0 N0|N3 7 6 0 1 3 2 4 5 N2|N2 5 4 2 3 1 0 6 7 N3|N0 0 1 5 4 6 7 3 2 N1|N1 2 3 7 6 4 5 1 0 N0|N3 7 6 0 1 3 2 4 5 N2|N2 5 4 2 3 1 0 6 7 N3", .options = "tuning=3", + .treeBase = "", }; static struct rcclRomeModel rome_model_65 = { @@ -433,6 +463,7 @@ static struct rcclRomeModel rome_model_65 = { .pattern = "42424242", .ringBase = "N4 9 8 12 13 5 4 0 1 3 2 6 7 15 14 10 11 N5|N1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 0 1 N0|N3 7 6 2 3 1 0 4 5 13 12 8 9 11 10 14 15 N7|N7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 6 7 N3|N5 11 10 14 15 7 6 2 3 1 0 4 5 13 12 8 9 N4|N0 1 0 4 5 13 12 8 9 11 10 14 15 7 6 2 3 N1|N3 6 7 3 2 1 0 4 5 14 15 11 10 9 8 12 13 N6|N7 14 15 11 10 9 8 12 13 6 7 3 2 1 0 4 5 N2|N2 5 4 0 1 2 3 7 6 13 12 8 9 10 11 15 14 N7|N6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 7 6 N3|N4 8 9 13 12 4 5 1 0 2 3 7 6 14 15 11 10 N5|N5 10 11 15 14 6 7 3 2 0 1 5 4 12 13 9 8 N4|N6 12 13 9 8 10 11 15 14 6 7 3 2 0 1 5 4 N2|N2 4 5 1 0 2 3 7 6 14 15 11 10 8 9 13 12 N6|N1 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1 0 N0|N0 0 1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 N1|N5 10 11 9 8 12 13 5 4 0 1 3 2 6 7 15 14 N7|N3 6 7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 N1|N1 2 3 1 0 4 5 13 12 8 9 11 10 14 15 7 6 N3|N7 14 15 7 6 2 3 1 0 4 5 13 12 8 9 11 10 N5|N0 0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 N2|N4 8 9 10 11 15 14 5 4 0 1 2 3 7 6 13 12 N6|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N1 3 2 1 0 4 5 14 15 11 10 9 8 12 13 6 7 N3|N6 12 13 6 7 3 2 1 0 4 5 14 15 11 10 9 8 N4|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N0 1 0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 N2|N6 13 12 4 5 1 0 2 3 7 6 14 15 11 10 8 9 N4|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N2 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0 1 N0|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N4 9 8 10 11 15 14 6 7 3 2 0 1 5 4 12 13 N6", .options = "netGdrLevel=PHB,tuning=4", + .treeBase = "", }; static struct rcclRomeModel rome_model_66 = { @@ -446,6 +477,7 @@ static struct rcclRomeModel rome_model_66 = { .pattern = "4040", .ringBase = "0 6 7 5 4 2 3 1|1 3 2 4 5 7 6 0|0 1 7 6 2 3 5 4|4 5 3 2 6 7 1 0", .options = "disableNumaMatching=1,tuning=2", + .treeBase = "", }; static struct rcclRomeModel rome_model_67 = { @@ -459,6 +491,7 @@ static struct rcclRomeModel rome_model_67 = { .pattern = "4242", .ringBase = "N3 7 6 0 1 3 2 4 5 N2|N2 5 4 2 3 1 0 6 7 N3|N1 2 3 5 4 0 1 7 6 N3|N2 4 5 3 2 6 7 1 0 N0|N1 3 2 4 5 7 6 0 1 N0|N0 1 0 6 7 5 4 2 3 N1|N0 0 1 7 6 2 3 5 4 N2|N3 6 7 1 0 4 5 3 2 N1", .options = "disableNumaMatching=1,tuning=2", + .treeBase = "", }; static struct rcclRomeModel rome_model_68 = { @@ -472,6 +505,7 @@ static struct rcclRomeModel rome_model_68 = { .pattern = "@@", .ringBase = "N0 0 1 2 3 N3 N4 4 5 6 7 N7 N8 8 9 10 11 N11 N12 12 13 14 15 N15|N15 15 14 13 12 N12 N11 11 10 9 8 N8 N7 7 6 5 4 N4 N3 3 2 1 0 N0|N1 1 3 0 2 N2 N5 5 7 4 6 N6 N9 9 11 8 10 N10 N13 13 15 12 14 N14|N14 14 12 15 13 N13 N10 10 8 11 9 N9 N6 6 4 7 5 N5 N2 2 0 3 1 N1|N0 0 1 2 3 N3 N4 4 5 6 7 N7 N8 8 9 10 11 N11 N12 12 13 14 15 N15|N15 15 14 13 12 N12 N11 11 10 9 8 N8 N7 7 6 5 4 N4 N3 3 2 1 0 N0|N1 1 3 0 2 N2 N5 5 7 4 6 N6 N9 9 11 8 10 N10 N13 13 15 12 14 N14|N14 14 12 15 13 N13 N10 10 8 11 9 N9 N6 6 4 7 5 N5 N2 2 0 3 1 N1", .options = "netGdrLevel=PIX", + .treeBase = "", }; static struct rcclRomeModel rome_model_71 = { @@ -485,6 +519,7 @@ static struct rcclRomeModel rome_model_71 = { .pattern = "4040", .ringBase = "0 1 3 2 4 5 7 6|6 7 5 4 2 3 1 0|0 1 5 4 2 3 7 6|6 7 3 2 4 5 1 0", .options = "disableNumaMatching=1,tuning=2", + .treeBase = "", }; static struct rcclRomeModel rome_model_72 = { @@ -498,6 +533,7 @@ static struct rcclRomeModel rome_model_72 = { .pattern = "4242", .ringBase = "N0 0 1 3 2 4 5 7 6 N3|N1 2 3 1 0 6 7 5 4 N2|N3 7 6 0 1 5 4 2 3 N1|N0 1 0 6 7 3 2 4 5 N2|N2 4 5 7 6 0 1 3 2 N1|N3 6 7 5 4 2 3 1 0 N0|N2 5 4 2 3 7 6 0 1 N0|N1 3 2 4 5 1 0 6 7 N3", .options = "disableNumaMatching=1,tuning=2", + .treeBase = "", }; static struct rcclRomeModel rome_model_73 = { @@ -511,6 +547,7 @@ static struct rcclRomeModel rome_model_73 = { .pattern = "20202020", .ringBase = "0 1 3 2 4 5 7 6|6 7 5 4 2 3 1 0|0 1 5 4 6 7 3 2|2 3 7 6 4 5 1 0", .options = "", + .treeBase = "", }; static struct rcclRomeModel rome_model_74 = { @@ -524,6 +561,7 @@ static struct rcclRomeModel rome_model_74 = { .pattern = "21212121", .ringBase = "N0 0 1 5 4 6 7 3 2 N1|N1 2 3 7 6 4 5 1 0 N0|N3 7 6 0 1 3 2 4 5 N2|N2 5 4 2 3 1 0 6 7 N3|N0 0 1 5 4 6 7 3 2 N1|N1 2 3 7 6 4 5 1 0 N0|N3 7 6 0 1 3 2 4 5 N2|N2 5 4 2 3 1 0 6 7 N3", .options = "tuning=3", + .treeBase = "", }; static struct rcclRomeModel romeTopoModels[] = { @@ -681,6 +719,75 @@ end: return ncclSuccess; } + +/* Parse user defined treeBase for complicated trees. Format is like : + * "10 11|14 15|6 7|2 3|0 1|4 5|12 13|8 9" + * + * Rings with a non-matching number of gpus are ignored so we can provide + * rings for multiple cases. + */ +ncclResult_t parseGraphLight(const char* str, struct ncclTopoSystem* system, struct ncclTopoGraph* graph, int* gpu_map) { + int gpus[NCCL_TOPO_MAX_NODES]; + int nChannels = 0; + int gpu = 0; + int offset = 0; + if (str[0] == 0) return ncclSuccess; + int status = 0; // 0 : between numbers, 1 : inside number, 2: start NET, 3: inside NET + int nets[NCCL_TOPO_MAX_NODES*2]; + int net_offset = 0, net_count = 0; + int ngpus = system->nodes[GPU].count; + // int nnets = system->nodes[NET].count; + int x=0, y=0; + // | + // from | to | x to y + do { + int digit = str[offset] - '0'; + if (digit >= 0 && digit <= 9) { + // offset++; // all the offsets + switch (status) { + case 0: + gpus[gpu] = digit; + status = 1; + break; + case 1: + gpus[gpu] = gpus[gpu]*10+digit; + break; + } + } else { + if (status == 1) { + gpu++; + } + status = 0; + if (str[offset] == '|' || str[offset] == 0) { // bump y, and x and make x 0 after + for (int r=0; rnodes[GPU].nodes[j].gpu.dev) + break; + if (j < ngpus) + { + graph->treeBase[r][x] = system->nodes[GPU].nodes[j].gpu.rank; + y=r; + } + else + return ncclInternalError; + } + y++; + graph->treeBase[y][x] = -1; + // y=0; + x++; + gpu=0; + } + } + } while (str[offset++] != 0); + graph->treeBase[0][x] = -1; + return ncclSuccess; +} + #define MAX_OPT_TOKENS 10 extern const char* topoPathTypeStr[]; @@ -1274,6 +1381,8 @@ ncclResult_t parse1H16P(struct ncclTopoSystem* system, struct ncclTopoGraph* gra // create 16P1H based on reference and remapped ids NCCLCHECK(parseGraph(romeTopoModels[i].ringBase, system, graph, g16, nnets > 1 ? n : NULL)); + + NCCLCHECK(parseGraphLight(romeTopoModels[i].treeBase, system, graph, g16)); // clean up free(all_gpu_permutations); return ncclSuccess; diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index 137386197a..259bb8a4bd 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -281,7 +281,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom nNodes; for (int a=0; aspeedIntra : graphs[a]->speedInter; diff --git a/src/include/graph.h b/src/include/graph.h index 4cfe9539a6..57727eef84 100644 --- a/src/include/graph.h +++ b/src/include/graph.h @@ -89,6 +89,7 @@ struct ncclTopoGraph { int inter[MAXCHANNELS*2]; int nIntraChannels; int intraNets[MAXCHANNELS*NCCL_TOPO_MAX_NODES*2]; + int treeBase[NCCL_TOPO_MAX_NODES][NCCL_TOPO_MAX_NODES]; }; ncclResult_t ncclTopoCompute(struct ncclTopoSystem* system, struct ncclTopoGraph* graph); @@ -112,6 +113,8 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph* collNetGraph, int nc); +ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph); + ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph); #include "info.h" ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time); diff --git a/src/init.cc b/src/init.cc index 557c440afc..fbbcf2ae34 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1002,6 +1002,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS)); NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc)); + if (comm->topo->pivotA2ANumBiRings == 3) NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph)); + free(allTopoRanks); free(nodesTreePatterns); free(nodesFirstRank); @@ -1017,7 +1019,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm struct ncclTree* tree = &comm->channels[c].tree; snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d", c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up); - INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev, + INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev, comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId); } line[1023] = '\0'; diff --git a/tools/topo_expl/utils.cpp b/tools/topo_expl/utils.cpp index 42f93cad02..d70d7a7021 100644 --- a/tools/topo_expl/utils.cpp +++ b/tools/topo_expl/utils.cpp @@ -789,6 +789,8 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS)); NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc)); + if (comm->topo->pivotA2ANumBiRings == 3) NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph)); + free(allTopoRanks); free(nodesTreePatterns); free(nodesFirstRank);