Merge pull request #617 from edgargabriel/binary-tree-2.13.4

make binary tree work on 2.13.4
Cette révision appartient à :
Edgar Gabriel
2022-09-14 20:30:11 -05:00
révisé par GitHub
révision 05cc7bd850
3 fichiers modifiés avec 10 ajouts et 7 suppressions
+1 -2
Voir le fichier
@@ -376,8 +376,7 @@ namespace {
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
const int nChannels = args->nChannels;
-ncclTree *tree = &ncclShmem->channel.tree;
-//ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
+ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
ssize_t chunkSize = int(
Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize
: Proto::calcBytePerStep()/sizeof(T));
+1 -1
Voir le fichier
@@ -1156,7 +1156,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
if (info->nBytes <= 2200008) {
info->protocol = NCCL_PROTO_LL;
info->algorithm = NCCL_ALGO_TREE;
-info->nChannels = 24;
+info->nChannels = std::min(24, comm->nChannels);
} else {
info->protocol = NCCL_PROTO_SIMPLE;
info->algorithm = NCCL_ALGO_RING;
+8 -4
Voir le fichier
@@ -572,6 +572,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
tmpCommAndChans.channels[c].ring = comm->channels[c].ring;
tmpCommAndChans.channels[c].ring.userRanks = comm->channels[c].devRingUserRanks;
tmpCommAndChans.channels[c].tree = comm->channels[c].tree;
+tmpCommAndChans.channels[c].binTree = comm->channels[c].binTree;
tmpCommAndChans.channels[c].collTree = comm->channels[c].collTree;
tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c];
@@ -1061,15 +1062,18 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
struct ncclTree* binTree = &comm->channels[c].binTree;
snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
-snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
-c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
+if (comm->topo->pivotA2ANumBiRings == 3)
+snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
+c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev,
comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId);
}
line[1023] = '\0';
-binline[1023] = '\0';
INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId);
-INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
+if (comm->topo->pivotA2ANumBiRings == 3) {
+binline[1023] = '\0';
+INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
+}
NCCLCHECK(computeBuffSizes(comm));