diff --git a/src/graph/search.cc b/src/graph/search.cc index fd30ee88f1..d6a1628ca2 100644 --- a/src/graph/search.cc +++ b/src/graph/search.cc @@ -286,7 +286,36 @@ ncclResult_t ncclTopoSearchTryGpu(struct ncclTopoSystem* system, struct ncclTopo return ncclSuccess; } -ncclResult_t ncclTopoCompareGraphs(struct ncclTopoGraph* graph, struct ncclTopoGraph* refGraph, int* copy) { +static int ncclTopoCountXGMI(struct ncclTopoSystem* system, struct ncclTopoGraph* graph) { + int ngpus = system->nodes[GPU].count; + int count = 0; + for (int c=0; c<graph->nChannels; c++) { + for (int i=0; i<ngpus; i++) { + int g = graph->intra[ngpus*c+i]; + int n = graph->intra[ngpus*c+((i+1)%ngpus)]; + struct ncclTopoNode *node; + int j; + for (j=0; j<ngpus; j++) + if (system->nodes[GPU].nodes[j].gpu.rank == g) break; + if (j<ngpus) { + node = system->nodes[GPU].nodes+j; + for (int k = 0; k<system->nodes[GPU].count; k++) { + if (node->paths[GPU][k].count == 1) { + struct ncclTopoLink* link = node->paths[GPU][k].list[0]; + struct ncclTopoNode* remNode = link->remNode; + if (remNode->gpu.rank == n) { + if (link->type == LINK_NVL) + count ++; + } + } + } + } + } + } + return count; +} + +ncclResult_t ncclTopoCompareGraphs(struct ncclTopoSystem* system, struct ncclTopoGraph* graph, struct ncclTopoGraph* refGraph, int* copy) { // 1. Constraint to get the same nChannels between Rings and Trees if (graph->nChannels < graph->minChannels) return ncclSuccess; @@ -298,6 +327,10 @@ ncclResult_t ncclTopoCompareGraphs(struct ncclTopoGraph* graph, struct ncclTopoG } // 3. Less hops (but not at the price of going cross NICs) if (graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1; + + // 4. 
Prefer graph with more XGMI connections + if (graph->nChannels == refGraph->nChannels + && ncclTopoCountXGMI(system, refGraph) < ncclTopoCountXGMI(system, graph)) *copy = 1; return ncclSuccess; } @@ -310,7 +343,7 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo // Determine whether we found a better solution or not int copy = 0; graph->nChannels++; - NCCLCHECK(ncclTopoCompareGraphs(graph, saveGraph, &copy)); + NCCLCHECK(ncclTopoCompareGraphs(system, graph, saveGraph, &copy)); if (copy) { memcpy(saveGraph, graph, sizeof(struct ncclTopoGraph)); if (graph->nChannels == graph->maxChannels) *time = -1; diff --git a/src/graph/topo.h b/src/graph/topo.h index f07e6a1310..8d9c82ccca 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -21,10 +21,7 @@ #define P9_WIDTH 32.0 #define ARM_WIDTH 6.0 #define NET_WIDTH 12.0 // 100Gbit -#define VEGA_XGMI_WIDTH 20.0 -#define ROME_QPI_WIDTH 18.0 -#define ROME_PCI_WIDTH 18.0 -#define ROME_CPUPCI_WIDTH 18.0 +#define VEGA_XGMI_WIDTH 24.0 // Intel CPU convert GPU P2P traffic into 64B PCI TLPs, so GPU // to GPU traffic consumes more PCI bandwidth. 
diff --git a/tools/scripts/topo_val.sh b/tools/scripts/topo_val.sh index 111c8f7bf7..02273a3d1d 100755 --- a/tools/scripts/topo_val.sh +++ b/tools/scripts/topo_val.sh @@ -21,7 +21,7 @@ DIR="$(cd -P "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -for i in {0..12} +for i in {0..13} do $DIR/../topo_expl/topo_expl -m $i > "topo_m$i.log" $DIR/../TopoVisual/topo_visual.sh -i "topo_m$i.log" diff --git a/tools/topo_expl/models/topo_8p_rome.xml b/tools/topo_expl/models/topo_8p_rome.xml new file mode 100644 index 0000000000..ea996dc260 --- /dev/null +++ b/tools/topo_expl/models/topo_8p_rome.xml @@ -0,0 +1,75 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tools/topo_expl/topo_expl.cpp b/tools/topo_expl/topo_expl.cpp index 4f2df2ef4e..ac384a042d 100644 --- a/tools/topo_expl/topo_expl.cpp +++ b/tools/topo_expl/topo_expl.cpp @@ -76,6 +76,7 @@ const char *model_descriptions[] = { "4 nodes with 8 VEGA20 GPUs XGMI 4P2H 1 NIC", "4 nodes with 8 VEGA20 GPUs XGMI 4P2H 1 NIC 2nd Hive", "4 nodes with 8 VEGA20 GPUs XGMI 4P2H 2 NIC", + "single node 8 VEGA20 Rome", NULL, }; @@ -166,6 +167,10 @@ int main(int argc,char* argv[]) network.AddNode(node); } break; + case 13: + node = new NodeModel("topo_8p_rome.xml"); + network.AddNode(node); + break; default: printf("Invalid model_id %d\n", model_id); exit(0);