Merge pull request #615 from edgargabriel/topic/two-trees
add binary tree
[ROCm/rccl commit: ea8120a346]
This commit is contained in:
@@ -378,7 +378,7 @@ namespace {
|
||||
const int nthreads = args->header.nWarps*WARP_SIZE;
|
||||
const int bid = args->bid;
|
||||
const int nChannels = args->nChannels;
|
||||
ncclTree *tree = &ncclShmem->channel.tree;
|
||||
ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize
|
||||
: Proto::calcBytePerStep()/sizeof(T));
|
||||
|
||||
@@ -567,7 +567,13 @@ comp_next:
|
||||
// Set nstepsPerLoop and nchunksPerLoop
|
||||
NCCLCHECK(getPatternInfo(info));
|
||||
NCCLCHECK(getLoopInfo(info));
|
||||
if (info->comm->topo->pivotA2ANumBiRings == 3 ) work->pad_0 = 1;
|
||||
if (info->comm->topo->pivotA2ANumBiRings == 3 ) {
|
||||
if (ncclTypeSize(info->datatype)*info->count > 65536) {
|
||||
work->pad_0 = 1;
|
||||
} else {
|
||||
work->pad_0 = 2;
|
||||
}
|
||||
}
|
||||
work->opCount = info->opCount;
|
||||
work->header.type = ncclWorkTypeColl;
|
||||
work->sendbuff = info->sendbuff;
|
||||
|
||||
@@ -5,6 +5,25 @@
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/*
|
||||
* Code for binary tree based on the same function available in Open MPI
|
||||
* File: ompi/mca/coll/base/coll_base_topo.c
|
||||
*
|
||||
* Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
|
||||
* University Research and Technology
|
||||
* Corporation. All rights reserved.
|
||||
* Copyright (c) 2004-2015 The University of Tennessee and The University
|
||||
* of Tennessee Research Foundation. All rights
|
||||
* reserved.
|
||||
* Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
|
||||
* University of Stuttgart. All rights reserved.
|
||||
* Copyright (c) 2004-2005 The Regents of the University of California.
|
||||
* All rights reserved.
|
||||
* Copyright (c) 2015 Research Organization for Information Science
|
||||
* and Technology (RIST). All rights reserved.
|
||||
*/
|
||||
|
||||
|
||||
#include "comm.h"
|
||||
#include "graph.h"
|
||||
#include "trees.h"
|
||||
@@ -69,6 +88,82 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm,
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static int calculate_level (int rank)
|
||||
{
|
||||
int level, num;
|
||||
if( rank < 0 ) return -1;
|
||||
for( level = 0, num = 0; num <= rank; level++ ) {
|
||||
num += 1<<level;
|
||||
}
|
||||
return level-1;
|
||||
}
|
||||
|
||||
static int calculate_num_nodes_up_to_level (int level)
|
||||
{
|
||||
return ((1<<level) - 1);
|
||||
}
|
||||
|
||||
ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm,
|
||||
struct ncclTopoGraph* treeGraph) {
|
||||
int nChannels = comm->nChannels;
|
||||
int localRanks = 0;
|
||||
for (int i=0; i<comm->topo->nodes[GPU].count; i++) {
|
||||
localRanks += comm->topo->nodes[GPU].nodes[i].gpu.nRanksPerGpu;
|
||||
}
|
||||
|
||||
for (int c=0; c<nChannels; c++) {
|
||||
struct ncclChannel* channel = comm->channels+c;
|
||||
// Only the first rank on a GPU can be a treeRoot
|
||||
int treeRoot = comm->topo->nodes[GPU].nodes[c%comm->topo->nodes[GPU].count].gpu.rank[0];
|
||||
|
||||
channel->binTree.up = -1;
|
||||
channel->binTree.down[0] = -1;
|
||||
channel->binTree.down[1] = -1;
|
||||
channel->binTree.down[2] = -1;
|
||||
|
||||
/*
|
||||
* Shift all ranks by root, so that the algorithm can be
|
||||
* designed as if root would be always 0
|
||||
* shiftedrank should be used in calculating distances
|
||||
* and position in tree
|
||||
*/
|
||||
int shiftedrank = comm->rank - treeRoot;
|
||||
if (shiftedrank < 0 ) {
|
||||
shiftedrank += localRanks;
|
||||
}
|
||||
|
||||
/* calculate my level */
|
||||
int level = calculate_level (shiftedrank);
|
||||
int delta = 1<<level;
|
||||
|
||||
/* find my children */
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int schild = shiftedrank + delta * (i+1);
|
||||
if (schild < localRanks) {
|
||||
channel->binTree.down[i] = (schild+treeRoot)%localRanks;
|
||||
}
|
||||
}
|
||||
|
||||
/* find my parent */
|
||||
int slimit = calculate_num_nodes_up_to_level (level);
|
||||
int sparent = shiftedrank;
|
||||
if (sparent < 2) {
|
||||
sparent = 0;
|
||||
}
|
||||
else {
|
||||
while (sparent >= slimit) {
|
||||
sparent -= delta/2;
|
||||
}
|
||||
}
|
||||
if (comm->rank != treeRoot) {
|
||||
channel->binTree.up = (sparent+treeRoot)%localRanks;
|
||||
}
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
ncclResult_t ncclTreeBasePostset(struct ncclComm* comm,
|
||||
struct ncclTopoGraph* treeGraph) {
|
||||
int nChannels = comm->nChannels;
|
||||
|
||||
@@ -277,6 +277,7 @@ struct ncclChannel {
|
||||
struct {
|
||||
struct ncclRing ring;
|
||||
struct ncclTree tree;
|
||||
struct ncclTree binTree;
|
||||
struct ncclDirect collTree;
|
||||
|
||||
int id;
|
||||
|
||||
@@ -115,6 +115,8 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
|
||||
|
||||
ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
|
||||
|
||||
ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
|
||||
|
||||
ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
|
||||
#include "info.h"
|
||||
ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time);
|
||||
|
||||
@@ -1021,7 +1021,10 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
NCCLCHECK(ncclCalloc(&rings, nranks*MAXCHANNELS));
|
||||
NCCLCHECK(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc));
|
||||
|
||||
if (comm->topo->pivotA2ANumBiRings == 3) NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph));
|
||||
if (comm->topo->pivotA2ANumBiRings == 3) {
|
||||
NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph));
|
||||
NCCLCHECK(ncclBinaryTreePostset(comm, &treeGraph));
|
||||
}
|
||||
|
||||
free(allTopoRanks);
|
||||
free(nodesTreePatterns);
|
||||
@@ -1032,17 +1035,23 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels);
|
||||
|
||||
char line[1024];
|
||||
char line[1024], binline[1024];
|
||||
line[0]='\0';
|
||||
binline[0]='\0';
|
||||
for (int c=0; c<comm->nChannels; c++) {
|
||||
struct ncclTree* tree = &comm->channels[c].tree;
|
||||
struct ncclTree* binTree = &comm->channels[c].binTree;
|
||||
snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
|
||||
c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
|
||||
snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
|
||||
c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
|
||||
INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev,
|
||||
comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId);
|
||||
}
|
||||
line[1023] = '\0';
|
||||
binline[1023] = '\0';
|
||||
INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId);
|
||||
INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
|
||||
|
||||
NCCLCHECK(computeBuffSizes(comm));
|
||||
|
||||
@@ -1073,6 +1082,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
if (comm->nRanks == 1) continue;
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, affinity_restore);
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, affinity_restore);
|
||||
// RCCL: need to connect binTree as well
|
||||
if (comm->topo->pivotA2ANumBiRings == 3) {
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, NCCL_MAX_TREE_ARITY, channel->binTree.down, 1, &channel->binTree.up, 0), ret, affinity_restore);
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->binTree.up, NCCL_MAX_TREE_ARITY, channel->binTree.down, 0), ret, affinity_restore);
|
||||
}
|
||||
}
|
||||
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, affinity_restore);
|
||||
INFO(NCCL_INIT, "Connected all trees comm %p nRanks %02d busId %lx", comm, comm->nRanks, comm->busId);
|
||||
|
||||
Reference in New Issue
Block a user