diff --git a/src/graph/connect.cc b/src/graph/connect.cc index 4ea634483e..5197a89db3 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -60,8 +60,8 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, } // Duplicate channels rings/trees struct ncclChannel* channel0 = comm->channels; - struct ncclChannel* channel1 = channel0+nChannels; - memcpy(channel1, channel0, nChannels*sizeof(struct ncclChannel)); + struct ncclChannel* channel1 = (nChannels > MAXCHANNELS/2) ? 0 : channel0+nChannels; + if (channel1) memcpy(channel1, channel0, nChannels*sizeof(struct ncclChannel)); return ncclSuccess; } @@ -74,25 +74,25 @@ static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ring int* prev = ringPrev+c*comm->nRanks; int* next = ringNext+c*comm->nRanks; struct ncclChannel* channel0 = comm->channels+c; - struct ncclChannel* channel1 = channel0+nChannels; + struct ncclChannel* channel1 = (nChannels > MAXCHANNELS/2) ? 0 : channel0+nChannels; for (int n=0; nrank == recvRank) { channel0->ring.prev = prevSendRank; - channel1->ring.prev = prevSendRank; + if (channel1) channel1->ring.prev = prevSendRank; } int sendRank = send[firstRanks[n]]; int nextRecvRank = recv[firstRanks[(n+1)%nNodes]]; next[sendRank] = nextRecvRank; if (comm->rank == sendRank) { channel0->ring.next = nextRecvRank; - channel1->ring.next = nextRecvRank; + if (channel1) channel1->ring.next = nextRecvRank; } } TRACE(NCCL_GRAPH, "Ring %d : %d -> %d -> %d", c, channel0->ring.prev, comm->rank, channel0->ring.next); - TRACE(NCCL_GRAPH, "Ring %d : %d -> %d -> %d", c+nChannels, channel1->ring.prev, comm->rank, channel1->ring.next); + if (channel1) TRACE(NCCL_GRAPH, "Ring %d : %d -> %d -> %d", c+nChannels, channel1->ring.prev, comm->rank, channel1->ring.next); } return ncclSuccess; } @@ -135,29 +135,30 @@ static ncclResult_t connectTrees(struct ncclComm* comm, int* treeToParent, int* NCCLCHECK(ncclGetDtree(nNodes, node, &t0u, &t0d0, &t0d1, &t0ChildType, &t1u, &t1d0, &t1d1, &t1ChildType)); for (int c=0; cchannels+c; - struct ncclChannel* channel1 = channel0+nChannels; + struct ncclChannel* channel1 = (nChannels > MAXCHANNELS/2) ? 0 : channel0+nChannels; NCCLCHECK(getIndexes(treeToParent+c*comm->nRanks, ranksToParent, nNodes, firstRanks)); NCCLCHECK(getIndexes(treeToChild0+c*comm->nRanks, ranksToChild0, nNodes, firstRanks)); NCCLCHECK(getIndexes(treeToChild1+c*comm->nRanks, ranksToChild1, nNodes, firstRanks)); if (comm->rank == ranksToParent[node]) { NCCLCHECK(setTreeUp(&channel0->tree, t0ChildType == 0 ? ranksToChild0 : ranksToChild1, t0u)); - NCCLCHECK(setTreeUp(&channel1->tree, t1ChildType == 0 ? ranksToChild0 : ranksToChild1, t1u)); + if (channel1) NCCLCHECK(setTreeUp(&channel1->tree, t1ChildType == 0 ? ranksToChild0 : ranksToChild1, t1u)); } if (comm->rank == ranksToChild0[node]) { NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d0)); - NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d0)); + if (channel1) NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d0)); } if (comm->rank == ranksToChild1[node]) { NCCLCHECK(setTreeDown(&channel0->tree, ranksToParent, t0d1)); - NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d1)); + if (channel1) NCCLCHECK(setTreeDown(&channel1->tree, ranksToParent, t1d1)); } if (comm->rank == ranksToParent[node] || comm->rank == ranksToChild0[node] || comm->rank == ranksToChild1[node]) { INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c, channel0->tree.up, comm->rank, channel0->tree.down[0], channel0->tree.down[1], channel0->tree.down[2]); - INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c+nChannels, channel1->tree.up, comm->rank, channel1->tree.down[0], channel1->tree.down[1], channel1->tree.down[2]); + if (channel1) INFO(NCCL_GRAPH, "Tree %d : %d -> %d -> %d/%d/%d", c+nChannels, channel1->tree.up, comm->rank, channel1->tree.down[0], channel1->tree.down[1], channel1->tree.down[2]); } - channel0->tree.depth = channel1->tree.depth = depth; + channel0->tree.depth = depth; + if (channel1) channel1->tree.depth = depth; } free(ranksToParent); free(ranksToChild0); @@ -287,14 +288,14 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa NCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, firstRanks, treePatterns)); // Duplicate ringPrev/ringNext for ncclBuildRing - memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int)); - memcpy(ringNext+nChannels*nranks, ringNext, nChannels*nranks*sizeof(int)); + if (nChannels <= MAXCHANNELS/2) memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int)); + if (nChannels <= MAXCHANNELS/2) memcpy(ringNext+nChannels*nranks, ringNext, nChannels*nranks*sizeof(int)); // Get number of channels after duplication nc *= comm->nChannels; nc = std::min((int)ncclMaxNchannels(), nc); // Duplication should be complete now - nChannels = comm->nChannels = std::min(MAXCHANNELS,nChannels*2); + nChannels = comm->nChannels = std::min(MAXCHANNELS, (nChannels <= MAXCHANNELS/2) ? nChannels*2 : nChannels); // Setup CollNet if (comm->collNetSupport == 1) { diff --git a/src/graph/rome_models.cc b/src/graph/rome_models.cc index 8c2fb81e65..0fed76e544 100644 --- a/src/graph/rome_models.cc +++ b/src/graph/rome_models.cc @@ -391,7 +391,7 @@ static struct rcclRomeModel rome_model_59 = { .connMatrix = { 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, }, .gdrLevel = { 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, }, .pattern = "42424242", - .ringBase = "N0 1 0 4 5 14 15 11 10 9 8 12 13 6 7 3 2 N1|N1 3 2 0 1 5 4 12 13 9 8 10 11 15 14 6 7 N3|N2 5 4 0 1 3 2 6 7 15 14 10 11 9 8 12 13 N6|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N6 13 12 8 9 11 10 14 15 7 6 2 3 1 0 4 5 N2|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N0 0 1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 N2|N1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 N0|N7 14 15 11 10 8 9 13 12 4 5 1 0 2 3 7 6 N3|N4 8 9 11 10 14 15 7 6 2 3 1 0 4 5 13 12 N6|N4 9 8 12 13 6 7 3 2 1 0 4 5 14 15 11 10 N5|N3 6 7 3 2 0 1 5 4 12 13 9 8 10 11 15 14 N7|N6 12 13 5 4 0 1 3 2 6 7 15 14 10 11 9 8 N4|N5 10 11 15 14 5 4 0 1 2 3 7 6 13 12 8 9 N4|", + .ringBase = "N4 9 8 12 13 5 4 0 1 3 2 6 7 15 14 10 11 N5|N5 10 11 9 8 12 13 5 4 0 1 3 2 6 7 15 14 N7|N7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 6 7 N3|N3 6 7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 N1|N1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 0 1 N0|N1 2 3 1 0 4 5 13 12 8 9 11 10 14 15 7 6 N3|N3 7 6 2 3 1 0 4 5 13 12 8 9 11 10 14 15 N7|N5 11 10 14 15 7 6 2 3 1 0 4 5 13 12 8 9 N4|N7 14 15 7 6 2 3 1 0 4 5 13 12 8 9 11 10 N5|N0 1 0 4 5 13 12 8 9 11 10 14 15 7 6 2 3 N1|N0 0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 N2|N2 5 4 0 1 2 3 7 6 13 12 8 9 10 11 15 14 N7|N4 8 9 10 11 15 14 5 4 0 1 2 3 7 6 13 12 N6|N6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 7 6 N3|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N1 3 2 1 0 4 5 14 15 11 10 9 8 12 13 6 7 N3|N3 6 7 3 2 1 0 4 5 14 15 11 10 9 8 12 13 N6|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N7 14 15 11 10 9 8 12 13 6 7 3 2 1 0 4 5 N2|N6 12 13 6 7 3 2 1 0 4 5 14 15 11 10 9 8 N4|N0 1 0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 N2|N2 4 5 1 0 2 3 7 6 14 15 11 10 8 9 13 12 N6|N4 8 9 13 12 4 5 1 0 2 3 7 6 14 15 11 10 N5|N6 13 12 4 5 1 0 2 3 7 6 14 15 11 10 8 9 N4|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N1 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1 0 N0|N5 10 11 15 14 6 7 3 2 0 1 5 4 12 13 9 8 N4|N2 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0 1 N0|N0 0 1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 N1|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N4 9 8 10 11 15 14 6 7 3 2 0 1 5 4 12 13 N6|N6 12 13 9 8 10 11 15 14 6 7 3 2 0 1 5 4 N2|", .netGdrLevel = -2, }; @@ -430,7 +430,7 @@ static struct rcclRomeModel rome_model_65 = { .connMatrix = { 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, }, .gdrLevel = { 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, }, .pattern = "42424242", - .ringBase = "N0 1 0 4 5 14 15 11 10 9 8 12 13 6 7 3 2 N1|N1 3 2 0 1 5 4 12 13 9 8 10 11 15 14 6 7 N3|N2 5 4 0 1 3 2 6 7 15 14 10 11 9 8 12 13 N6|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N6 13 12 8 9 11 10 14 15 7 6 2 3 1 0 4 5 N2|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N0 0 1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 N2|N1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 N0|N7 14 15 11 10 8 9 13 12 4 5 1 0 2 3 7 6 N3|N4 8 9 11 10 14 15 7 6 2 3 1 0 4 5 13 12 N6|N4 9 8 12 13 6 7 3 2 1 0 4 5 14 15 11 10 N5|N3 6 7 3 2 0 1 5 4 12 13 9 8 10 11 15 14 N7|N6 12 13 5 4 0 1 3 2 6 7 15 14 10 11 9 8 N4|N5 10 11 15 14 5 4 0 1 2 3 7 6 13 12 8 9 N4|", + .ringBase = "N4 9 8 12 13 5 4 0 1 3 2 6 7 15 14 10 11 N5|N5 10 11 9 8 12 13 5 4 0 1 3 2 6 7 15 14 N7|N7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 6 7 N3|N3 6 7 15 14 10 11 9 8 12 13 5 4 0 1 3 2 N1|N1 3 2 6 7 15 14 10 11 9 8 12 13 5 4 0 1 N0|N1 2 3 1 0 4 5 13 12 8 9 11 10 14 15 7 6 N3|N3 7 6 2 3 1 0 4 5 13 12 8 9 11 10 14 15 N7|N5 11 10 14 15 7 6 2 3 1 0 4 5 13 12 8 9 N4|N7 14 15 7 6 2 3 1 0 4 5 13 12 8 9 11 10 N5|N0 1 0 4 5 13 12 8 9 11 10 14 15 7 6 2 3 N1|N0 0 1 2 3 7 6 13 12 8 9 10 11 15 14 5 4 N2|N2 5 4 0 1 2 3 7 6 13 12 8 9 10 11 15 14 N7|N4 8 9 10 11 15 14 5 4 0 1 2 3 7 6 13 12 N6|N6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 7 6 N3|N3 7 6 13 12 8 9 10 11 15 14 5 4 0 1 2 3 N1|N1 3 2 1 0 4 5 14 15 11 10 9 8 12 13 6 7 N3|N3 6 7 3 2 1 0 4 5 14 15 11 10 9 8 12 13 N6|N2 4 5 14 15 11 10 9 8 12 13 6 7 3 2 1 0 N0|N7 14 15 11 10 9 8 12 13 6 7 3 2 1 0 4 5 N2|N6 12 13 6 7 3 2 1 0 4 5 14 15 11 10 9 8 N4|N0 1 0 2 3 7 6 14 15 11 10 8 9 13 12 4 5 N2|N2 4 5 1 0 2 3 7 6 14 15 11 10 8 9 13 12 N6|N4 8 9 13 12 4 5 1 0 2 3 7 6 14 15 11 10 N5|N6 13 12 4 5 1 0 2 3 7 6 14 15 11 10 8 9 N4|N5 11 10 8 9 13 12 4 5 1 0 2 3 7 6 14 15 N7|N1 2 3 7 6 14 15 11 10 8 9 13 12 4 5 1 0 N0|N5 10 11 15 14 6 7 3 2 0 1 5 4 12 13 9 8 N4|N2 5 4 12 13 9 8 10 11 15 14 6 7 3 2 0 1 N0|N0 0 1 5 4 12 13 9 8 10 11 15 14 6 7 3 2 N1|N7 15 14 6 7 3 2 0 1 5 4 12 13 9 8 10 11 N5|N4 9 8 10 11 15 14 6 7 3 2 0 1 5 4 12 13 N6|N6 12 13 9 8 10 11 15 14 6 7 3 2 0 1 5 4 N2|", .netGdrLevel = 5, }; diff --git a/src/init.cc b/src/init.cc index 83a2cc6ce2..23cf33035e 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1012,6 +1012,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].nc = 4; if (comm->topo->nodes[GPU].nodes[idx].gpu.gcn == 910) allGather3Data[rank].nc = std::max(allGather3Data[rank].nc, 4/ringGraph.nChannels); + if (ringGraph.nChannels > MAXCHANNELS/2) + allGather3Data[rank].nc = 1; allGather3Data[rank].tree.pattern = treeGraph.pattern; allGather3Data[rank].tree.nChannels = treeGraph.nChannels; allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels; diff --git a/tools/topo_expl/utils.cpp b/tools/topo_expl/utils.cpp index eb7d1e331c..938a3f9354 100644 --- a/tools/topo_expl/utils.cpp +++ b/tools/topo_expl/utils.cpp @@ -682,6 +682,8 @@ ncclResult_t initTransportsRank_1(struct ncclComm* comm, struct allGather1Data_t allGather3Data[rank].nc = 4; if (comm->topo->nodes[GPU].nodes[idx].gpu.gcn == 910) allGather3Data[rank].nc = std::max(allGather3Data[rank].nc, 4/ringGraph.nChannels); + if (ringGraph.nChannels > MAXCHANNELS/2) + allGather3Data[rank].nc = 1; allGather3Data[rank].tree.pattern = treeGraph.pattern; allGather3Data[rank].tree.nChannels = treeGraph.nChannels; allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels;