From 8f3219dbd4d707183775617a74112dbd1af957a9 Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Wed, 14 Sep 2022 15:29:30 +0000 Subject: [PATCH] make binary tree work on 2.13.4 --- src/collectives/device/all_reduce.h | 3 +-- src/enqueue.cc | 2 +- src/init.cc | 12 ++++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index cdf1eaba4c..ea90cbd37f 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -376,8 +376,7 @@ namespace { const int nthreads = args->nWarps*WARP_SIZE; const int bid = args->bid; const int nChannels = args->nChannels; - ncclTree *tree = &ncclShmem->channel.tree; - //ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree; + ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem->channel.binTree : &ncclShmem->channel.tree; ssize_t chunkSize = int( Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize : Proto::calcBytePerStep()/sizeof(T)); diff --git a/src/enqueue.cc b/src/enqueue.cc index 8603c15e88..404a339e6d 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1156,7 +1156,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i if (info->nBytes <= 2200008) { info->protocol = NCCL_PROTO_LL; info->algorithm = NCCL_ALGO_TREE; - info->nChannels = 24; + info->nChannels = std::min(24, comm->nChannels); } else { info->protocol = NCCL_PROTO_SIMPLE; info->algorithm = NCCL_ALGO_RING; diff --git a/src/init.cc b/src/init.cc index 833597ab3c..fee43665ce 100644 --- a/src/init.cc +++ b/src/init.cc @@ -572,6 +572,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.channels[c].ring = comm->channels[c].ring; tmpCommAndChans.channels[c].ring.userRanks = comm->channels[c].devRingUserRanks; tmpCommAndChans.channels[c].tree = comm->channels[c].tree; + tmpCommAndChans.channels[c].binTree = comm->channels[c].binTree; tmpCommAndChans.channels[c].collTree = comm->channels[c].collTree; tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; @@ -1061,15 +1062,18 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm struct ncclTree* binTree = &comm->channels[c].binTree; snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d", c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up); - snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d", - c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up); + if (comm->topo->pivotA2ANumBiRings == 3) + snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d", + c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up); INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev, comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId); } line[1023] = '\0'; - binline[1023] = '\0'; INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId); - INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId); + if (comm->topo->pivotA2ANumBiRings == 3) { + binline[1023] = '\0'; + INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId); + } NCCLCHECK(computeBuffSizes(comm));