diff --git a/projects/rccl/src/graph/search.cc b/projects/rccl/src/graph/search.cc index 03e8ecf35f..3f5bab618b 100644 --- a/projects/rccl/src/graph/search.cc +++ b/projects/rccl/src/graph/search.cc @@ -13,7 +13,6 @@ static ncclResult_t ncclTopoFollowPath(struct ncclTopoGraph* graph, struct ncclT if (path->count == 0) return ncclSuccess; *node = NULL; - width /= 2; if (width > 0) { if (path->type > graph->type) return ncclSuccess; graph->type = std::max(graph->type, path->type); @@ -206,7 +205,7 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo NCCLCHECK(ncclTopoCompareGraphs(graph, saveGraph, &copy)); if (copy) { memcpy(saveGraph, graph, sizeof(struct ncclTopoGraph)); - if (graph->nChannels*graph->speedIntra/2 == maxSpeed) *time = -1; + if (graph->nChannels*graph->speedIntra == maxSpeed) *time = -1; } if (graph->nChannels < MAXCHANNELS/2) { NCCLCHECK(ncclTopoSearchRec(system, graph, saveGraph, maxSpeed, time)); @@ -513,7 +512,7 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph } // TODO : let user specify NICs graph->inter[0] = graph->inter[1] = 0; - graph->speedIntra = graph->speedInter = PCI_WIDTH+2; + graph->speedIntra = graph->speedInter = system->maxWidth; graph->nvlink = 0; if (graph->pattern == NCCL_TOPO_PATTERN_RING) { // Reverse the loop @@ -541,6 +540,7 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph search: int time = NCCL_SEARCH_TIMEOUT; + int stepSpeed = system->maxWidth/4; tmpGraph.nvlink = 1; tmpGraph.nChannels = 0; tmpGraph.sameChannels = 1; @@ -588,8 +588,8 @@ search: tmpGraph.crossNic = graph->crossNic; // Try to reduce speed per channel - tmpGraph.speedIntra = tmpGraph.speedInter -= 3; - if (tmpGraph.speedIntra >= bestSpeed/2 && tmpGraph.speedIntra >= 3) goto search; + tmpGraph.speedIntra = tmpGraph.speedInter -= stepSpeed; + if (tmpGraph.speedIntra >= bestSpeed/2 && tmpGraph.speedIntra >= stepSpeed) goto search; } done: @@ -600,7 +600,7 @@ done: 
} if (time != 0 && tmpGraph.pattern != NCCL_TOPO_PATTERN_RING && tmpGraph.speedIntra == graph->speedIntra) { // Try to increase the intra speed only but keeping nChannels the same - tmpGraph.speedIntra += 3; + tmpGraph.speedIntra += stepSpeed; maxSpeed = tmpGraph.speedIntra * graph->nChannels; if (tmpGraph.speedIntra <= tmpGraph.speedInter*2) goto search; } @@ -609,7 +609,7 @@ done: WARN("Could not find a path for pattern %d, falling back to simple order\n", graph->pattern); for (int i=0; i<ngpus; i++) graph->intra[i] = system->nodes[GPU].nodes[i].rank; graph->inter[0] = graph->inter[1] = 0; - graph->speedIntra = graph->speedInter = 3; + graph->speedIntra = graph->speedInter = stepSpeed; graph->nvlink = 0; graph->nChannels = 1; }