Merge pull request #615 from edgargabriel/topic/two-trees

add binary tree

[ROCm/rccl commit: ea8120a346]
This commit is contained in:
Edgar Gabriel
2022-09-13 16:50:45 -05:00
committed by GitHub
6 changed files with 122 additions and 4 deletions
@@ -378,7 +378,7 @@ namespace {
const int nthreads = args->header.nWarps*WARP_SIZE;
const int bid = args->bid;
const int nChannels = args->nChannels;
ncclTree *tree = &ncclShmem->channel.tree;
ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
ssize_t chunkSize = int(
Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize
: Proto::calcBytePerStep()/sizeof(T));
+7 -1
View File
@@ -567,7 +567,13 @@ comp_next:
// Set nstepsPerLoop and nchunksPerLoop
NCCLCHECK(getPatternInfo(info));
NCCLCHECK(getLoopInfo(info));
if (info->comm->topo->pivotA2ANumBiRings == 3 ) work->pad_0 = 1;
if (info->comm->topo->pivotA2ANumBiRings == 3 ) {
if (ncclTypeSize(info->datatype)*info->count > 65536) {
work->pad_0 = 1;
} else {
work->pad_0 = 2;
}
}
work->opCount = info->opCount;
work->header.type = ncclWorkTypeColl;
work->sendbuff = info->sendbuff;
+95
View File
@@ -5,6 +5,25 @@
* See LICENSE.txt for license information
************************************************************************/
/*
* Code for binary tree based on the same function available in Open MPI
* File: ompi/mca/coll/base/coll_base_topo.c
*
* Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
* University Research and Technology
* Corporation. All rights reserved.
* Copyright (c) 2004-2015 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
* University of Stuttgart. All rights reserved.
* Copyright (c) 2004-2005 The Regents of the University of California.
* All rights reserved.
* Copyright (c) 2015 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
*/
#include "comm.h"
#include "graph.h"
#include "trees.h"
@@ -69,6 +88,82 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
return ncclSuccess;
}
static int calculate_level (int rank)
{
int level, num;
if( rank < 0 ) return -1;
for( level = 0, num = 0; num <= rank; level++ ) {
num += 1<<level;
}
return level-1;
}
static int calculate_num_nodes_up_to_level (int level)
{
return ((1<<level) - 1);
}
ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph) {
int nChannels = comm->nChannels;
int localRanks = 0;
for (int i=0; i<comm->topo->nodes[GPU].count; i++) {
localRanks += comm->topo->nodes[GPU].nodes[i].gpu.nRanksPerGpu;
}
for (int c=0; c<nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
// Only the first rank on a GPU can be a treeRoot
int treeRoot = comm->topo->nodes[GPU].nodes[c%comm->topo->nodes[GPU].count].gpu.rank[0];
channel->binTree.up = -1;
channel->binTree.down[0] = -1;
channel->binTree.down[1] = -1;
channel->binTree.down[2] = -1;
/*
* Shift all ranks by root, so that the algorithm can be
* designed as if root would be always 0
* shiftedrank should be used in calculating distances
* and position in tree
*/
int shiftedrank = comm->rank - treeRoot;
if (shiftedrank < 0 ) {
shiftedrank += localRanks;
}
/* calculate my level */
int level = calculate_level (shiftedrank);
int delta = 1<<level;
/* find my children */
for (int i = 0; i < 2; i++) {
int schild = shiftedrank + delta * (i+1);
if (schild < localRanks) {
channel->binTree.down[i] = (schild+treeRoot)%localRanks;
}
}
/* find my parent */
int slimit = calculate_num_nodes_up_to_level (level);
int sparent = shiftedrank;
if (sparent < 2) {
sparent = 0;
}
else {
while (sparent >= slimit) {
sparent -= delta/2;
}
}
if (comm->rank != treeRoot) {
channel->binTree.up = (sparent+treeRoot)%localRanks;
}
}
return ncclSuccess;
}
ncclResult_t ncclTreeBasePostset(struct ncclComm* comm,
struct ncclTopoGraph* treeGraph) {
int nChannels = comm->nChannels;
+1
View File
@@ -277,6 +277,7 @@ struct ncclChannel {
struct {
struct ncclRing ring;
struct ncclTree tree;
struct ncclTree binTree;
struct ncclDirect collTree;
int id;
+2
View File
@@ -115,6 +115,8 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
#include "info.h"
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);
+16 -2
View File
@@ -1021,7 +1021,10 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS));
NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc));
if (comm->topo->pivotA2ANumBiRings == 3) NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph));
if (comm->topo->pivotA2ANumBiRings == 3) {
NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph));
NCCLCHECK(ncclBinaryTreePostset(comm, &treeGraph));
}
free(allTopoRanks);
free(nodesTreePatterns);
@@ -1032,17 +1035,23 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels);
char line[1024];
char line[1024], binline[1024];
line[0]='\0';
binline[0]='\0';
for (int c=0; c<comm->nChannels; c++) {
struct ncclTree* tree = &comm->channels[c].tree;
struct ncclTree* binTree = &comm->channels[c].binTree;
snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev,
comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId);
}
line[1023] = '\0';
binline[1023] = '\0';
INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId);
INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
NCCLCHECK(computeBuffSizes(comm));
@@ -1073,6 +1082,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
if (comm->nRanks == 1) continue;
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, affinity_restore);
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, affinity_restore);
// RCCL: need to connect binTree as well
if (comm->topo->pivotA2ANumBiRings == 3) {
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, NCCL_MAX_TREE_ARITY, channel->binTree.down, 1, &channel->binTree.up, 0), ret, affinity_restore);
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->binTree.up, NCCL_MAX_TREE_ARITY, channel->binTree.down, 0), ret, affinity_restore);
}
}
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, affinity_restore);
INFO(NCCL_INIT, "Connected all trees comm %p nRanks %02d busId %lx", comm, comm->nRanks, comm->busId);