From a9d09c65511c0a4bc355e26cde4e29691b13152a Mon Sep 17 00:00:00 2001 From: Bertan Dogancay <111835151+BertanDogancay@users.noreply.github.com> Date: Wed, 26 Feb 2025 09:48:03 -0500 Subject: [PATCH] Use bit reversal based mapping for multi-node (#1572) [ROCm/rccl commit: 85eb1f16bcdab1fa7abc754ad6a3606a9405fb12] --- projects/rccl/src/device/sendrecv.h | 2 +- projects/rccl/src/enqueue.cc | 6 +++--- projects/rccl/src/include/device.h | 22 ++++++++++++++++++---- projects/rccl/src/init.cc | 4 ++-- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/projects/rccl/src/device/sendrecv.h b/projects/rccl/src/device/sendrecv.h index 72ef207a4c..91ed574e02 100644 --- a/projects/rccl/src/device/sendrecv.h +++ b/projects/rccl/src/device/sendrecv.h @@ -163,7 +163,7 @@ struct RunWorkBatchsendBytes : work->recvBytes; int nParts = isSend ? work->nSendChannels : work->nRecvChannels; - int part = ncclP2pChannelToPart(work->nP2pChannels, work->channelBase, ncclShmem.channelId, ncclShmem.comm.p2pnChannelsPerPeer); + int part = ncclP2pChannelToPart(work->nP2pChannels, work->channelBase, ncclShmem.channelId, ncclShmem.comm.p2pnChannelsPerPeer, ncclShmem.comm.nNodes); hasWork = (part < nParts); if (nParts != 0) { size_t partBeg, partEnd; diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index ac58ac59c2..9dc9a7cd44 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -890,7 +890,7 @@ static ncclResult_t addP2pToPlan( if (!selfSend) { for (int part=0; part < nChannelsMax; part++) { - int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax); + int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax, comm->nNodes); struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers; for (int dir=0; dir <= 1; dir++) { int peerRank = dir ? sendRank : recvRank; @@ -1006,7 +1006,7 @@ static ncclResult_t addP2pToPlan( nChannelsMax = std::max(nChannels[0], nChannels[1]); for (int part=0; part < nChannelsMax; part++) { - int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, comm->p2pnChannelsPerPeer); + int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, comm->p2pnChannelsPerPeer, comm->nNodes); plan->channelMask.masks[channelId/64] |= uint64_t(1)<<(channelId%64); // Add batch first. int funcIdx = ncclDevFuncId_P2p(); @@ -2133,7 +2133,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { } uint8_t base = ncclP2pChannelBaseForRound(comm, round); for (int c=0; c < comm->p2pnChannelsPerPeer; c++) { - int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, c, comm->p2pnChannelsPerPeer); + int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, c, comm->p2pnChannelsPerPeer, comm->nNodes); if (isSendNotRecv) { if (comm->channels[channelId].peers[peer]->send[1].connected == 0) { // P2P uses only 1 connector //comm->connectSend[peer] |= (1UL< 1) { + // Only works because nP2pChannels is pow2 + int nChannelsLog2 = countOneBits(nP2pChannels-1); + int delta = reverseBits(part, nChannelsLog2); + return (base + delta) & (nP2pChannels-1); + } else { + return (base * nParts + part) & (nP2pChannels-1); + } } -inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int channel, int nParts) { - return (channel - base * nParts) & (nParts-1); +inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int channel, int nParts, int nNodes) { + if (nNodes > 1) { + // Only works because nP2pChannels is pow2 + int nChannelsLog2 = countOneBits(nP2pChannels-1); + int delta = (channel-base) & (nP2pChannels-1); + return reverseBits(delta, nChannelsLog2); + } else { + return (channel - base * nParts) & (nParts-1); + } } struct alignas(16) ncclDevWorkColl { diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 1bcd4e2620..f990efc04b 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -1677,11 +1677,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p uint8_t recvBase = ncclP2pChannelBaseForRound(comm, recvRound); for (int c=0; cp2pnChannelsPerPeer; c++) { int channelId; - channelId = ncclP2pChannelForPart(comm->p2pnChannels, sendBase, c, comm->p2pnChannelsPerPeer); + channelId = ncclP2pChannelForPart(comm->p2pnChannels, sendBase, c, comm->p2pnChannelsPerPeer, comm->nNodes); if (comm->channels[channelId].peers[peer]->send[1].connected == 0) { comm->connectSend[peer].masks[channelId/64] |= (1UL<<(channelId%64)); } - channelId = ncclP2pChannelForPart(comm->p2pnChannels, recvBase, c, comm->p2pnChannelsPerPeer); + channelId = ncclP2pChannelForPart(comm->p2pnChannels, recvBase, c, comm->p2pnChannelsPerPeer, comm->nNodes); if (comm->channels[channelId].peers[peer]->recv[1].connected == 0) { comm->connectRecv[peer].masks[channelId/64] |= (1UL<<(channelId%64)); }