Merge pull request #617 from edgargabriel/binary-tree-2.13.4
make binary tree work on 2.13.4
Cette révision appartient à :
@@ -376,8 +376,7 @@ namespace {
|
||||
const int nthreads = args->nWarps*WARP_SIZE;
|
||||
const int bid = args->bid;
|
||||
const int nChannels = args->nChannels;
|
||||
ncclTree *tree = &ncclShmem->channel.tree;
|
||||
//ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
|
||||
ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree;
|
||||
ssize_t chunkSize = int(
|
||||
Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize
|
||||
: Proto::calcBytePerStep()/sizeof(T));
|
||||
|
||||
@@ -1156,7 +1156,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i
|
||||
if (info->nBytes <= 2200008) {
|
||||
info->protocol = NCCL_PROTO_LL;
|
||||
info->algorithm = NCCL_ALGO_TREE;
|
||||
info->nChannels = 24;
|
||||
info->nChannels = std::min(24, comm->nChannels);
|
||||
} else {
|
||||
info->protocol = NCCL_PROTO_SIMPLE;
|
||||
info->algorithm = NCCL_ALGO_RING;
|
||||
|
||||
+8
-4
@@ -572,6 +572,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
|
||||
tmpCommAndChans.channels[c].ring = comm->channels[c].ring;
|
||||
tmpCommAndChans.channels[c].ring.userRanks = comm->channels[c].devRingUserRanks;
|
||||
tmpCommAndChans.channels[c].tree = comm->channels[c].tree;
|
||||
tmpCommAndChans.channels[c].binTree = comm->channels[c].binTree;
|
||||
tmpCommAndChans.channels[c].collTree = comm->channels[c].collTree;
|
||||
tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c];
|
||||
|
||||
@@ -1061,15 +1062,18 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
struct ncclTree* binTree = &comm->channels[c].binTree;
|
||||
snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d",
|
||||
c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up);
|
||||
snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
|
||||
c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
|
||||
if (comm->topo->pivotA2ANumBiRings == 3)
|
||||
snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d",
|
||||
c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up);
|
||||
INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev,
|
||||
comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId);
|
||||
}
|
||||
line[1023] = '\0';
|
||||
binline[1023] = '\0';
|
||||
INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId);
|
||||
INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
|
||||
if (comm->topo->pivotA2ANumBiRings == 3) {
|
||||
binline[1023] = '\0';
|
||||
INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId);
|
||||
}
|
||||
|
||||
NCCLCHECK(computeBuffSizes(comm));
|
||||
|
||||
|
||||
Référencer dans un nouveau ticket
Bloquer un utilisateur