diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index cefe38cad6..69008fc52a 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -823,9 +823,39 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Print final topology NCCLCHECK(ncclTopoPrint(comm->topo)); + // Get rings and trees + struct ncclTopoGraph ringGraph; + ringGraph.id = 0; + ringGraph.pattern = NCCL_TOPO_PATTERN_RING; + ringGraph.crossNic = ncclParamCrossNic(); + ringGraph.collNet = 0; + ringGraph.minChannels = 1; + ringGraph.maxChannels = MAXCHANNELS/2; + NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph)); + NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph)); + + struct ncclTopoGraph treeGraph; + treeGraph.id = 1; + treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE; + treeGraph.crossNic = ncclParamCrossNic(); + treeGraph.collNet = 0; + treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels; + treeGraph.maxChannels = ringGraph.nChannels; + NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph)); + NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph)); + + struct ncclTopoGraph collNetGraph; + collNetGraph.id = 2; + collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE; + collNetGraph.collNet = 1; + collNetGraph.crossNic = ncclParamCrossNic(); + collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels; + NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph)); + NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph)); + { // [RCCL] Check if clique-based kernels can be enabled and initialize CliqueManager CliqueManager::cliqueMode_t cliqueMode = CliqueManager::CLIQUE_DISABLED; - if (intraRanks == nranks) + if (comm->localRanks == comm->nRanks) { // Check that all the GPUs have peer access to one another bool hasPeerAccess = true; @@ -867,36 +897,6 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECK(comm->cliqueManager->Init(commId, rootPid)); } // [/RCCL] - // Get rings and trees - struct ncclTopoGraph ringGraph; - ringGraph.id = 0; - ringGraph.pattern = NCCL_TOPO_PATTERN_RING; - ringGraph.crossNic = ncclParamCrossNic(); - ringGraph.collNet = 0; - ringGraph.minChannels = 1; - ringGraph.maxChannels = MAXCHANNELS/2; - NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph)); - - struct ncclTopoGraph treeGraph; - treeGraph.id = 1; - treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE; - treeGraph.crossNic = ncclParamCrossNic(); - treeGraph.collNet = 0; - treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels; - treeGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph)); - - struct ncclTopoGraph collNetGraph; - collNetGraph.id = 2; - collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE; - collNetGraph.collNet = 1; - collNetGraph.crossNic = ncclParamCrossNic(); - collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels; - NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph)); - NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph)); - if (comm->rank == ncclParamGraphDumpFileRank()) { struct ncclTopoGraph* graphs[3] = { &ringGraph, &treeGraph, &collNetGraph }; NCCLCHECK(ncclTopoDumpGraphs(comm->topo, 3, graphs));