Fixing clique-topology detection (#342)

* Fixing clique-topology detection
* Fix to enable multi-process clique-based kernels

[ROCm/rccl commit: caba0a63d2]
Dieser Commit ist enthalten in:
gilbertlee-amd
2021-04-07 11:29:44 -06:00
committet von GitHub
Ursprung b4a7fa7011
Commit a7e99734e8
+31 -31
Datei anzeigen
@@ -823,9 +823,39 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
// Print final topology
NCCLCHECK(ncclTopoPrint(comm->topo));
// Get rings and trees
struct ncclTopoGraph ringGraph;
ringGraph.id = 0;
ringGraph.pattern = NCCL_TOPO_PATTERN_RING;
ringGraph.crossNic = ncclParamCrossNic();
ringGraph.collNet = 0;
ringGraph.minChannels = 1;
ringGraph.maxChannels = MAXCHANNELS/2;
NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph));
struct ncclTopoGraph treeGraph;
treeGraph.id = 1;
treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
treeGraph.crossNic = ncclParamCrossNic();
treeGraph.collNet = 0;
treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels;
treeGraph.maxChannels = ringGraph.nChannels;
NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph));
struct ncclTopoGraph collNetGraph;
collNetGraph.id = 2;
collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE;
collNetGraph.collNet = 1;
collNetGraph.crossNic = ncclParamCrossNic();
collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels;
NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph));
{ // [RCCL] Check if clique-based kernels can be enabled and initialize CliqueManager
CliqueManager::cliqueMode_t cliqueMode = CliqueManager::CLIQUE_DISABLED;
if (intraRanks == nranks)
if (comm->localRanks == comm->nRanks)
{
// Check that all the GPUs have peer access to one another
bool hasPeerAccess = true;
@@ -867,36 +897,6 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
NCCLCHECK(comm->cliqueManager->Init(commId, rootPid));
} // [/RCCL]
// Get rings and trees
struct ncclTopoGraph ringGraph;
ringGraph.id = 0;
ringGraph.pattern = NCCL_TOPO_PATTERN_RING;
ringGraph.crossNic = ncclParamCrossNic();
ringGraph.collNet = 0;
ringGraph.minChannels = 1;
ringGraph.maxChannels = MAXCHANNELS/2;
NCCLCHECK(ncclTopoCompute(comm->topo, &ringGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &ringGraph));
struct ncclTopoGraph treeGraph;
treeGraph.id = 1;
treeGraph.pattern = tmpNnodes <= 2 ? NCCL_TOPO_PATTERN_TREE : NCCL_TOPO_PATTERN_BALANCED_TREE;
treeGraph.crossNic = ncclParamCrossNic();
treeGraph.collNet = 0;
treeGraph.minChannels = comm->topo->nodes[NET].count != 0 ? 1 : ringGraph.nChannels;
treeGraph.maxChannels = ringGraph.nChannels;
NCCLCHECK(ncclTopoCompute(comm->topo, &treeGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &treeGraph));
struct ncclTopoGraph collNetGraph;
collNetGraph.id = 2;
collNetGraph.pattern = NCCL_TOPO_PATTERN_TREE;
collNetGraph.collNet = 1;
collNetGraph.crossNic = ncclParamCrossNic();
collNetGraph.minChannels = collNetGraph.maxChannels = ringGraph.nChannels;
NCCLCHECK(ncclTopoCompute(comm->topo, &collNetGraph));
NCCLCHECK(ncclTopoPrintGraph(comm->topo, &collNetGraph));
if (comm->rank == ncclParamGraphDumpFileRank()) {
struct ncclTopoGraph* graphs[3] = { &ringGraph, &treeGraph, &collNetGraph };
NCCLCHECK(ncclTopoDumpGraphs(comm->topo, 3, graphs));