diff --git a/src/include/device.h b/src/include/device.h index a9f26141b0..ccab7c7f63 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -280,7 +280,7 @@ inline __host__ uint8_t ncclP2pChannelBaseForRound(struct ncclComm* comm, int p2 // ncclP2pChannelToPart and ncclP2pChannelForPart are inverses. The device code // uses ncclP2pChannelToPart to determine which part "this" channel is responsible for. inline __host__ int ncclP2pChannelForPart(int nP2pChannels, int base, int part, int nParts, int nNodes) { - if (nNodes > 1) { + if (nNodes > 2) { // Only works because nP2pChannels is pow2 int nChannelsLog2 = countOneBits(nP2pChannels-1); int delta = reverseBits(part, nChannelsLog2); @@ -290,7 +290,7 @@ inline __host__ int ncclP2pChannelForPart(int nP2pChannels, int base, int part, } } inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int channel, int nParts, int nNodes) { - if (nNodes > 1) { + if (nNodes > 2) { // Only works because nP2pChannels is pow2 int nChannelsLog2 = countOneBits(nP2pChannels-1); int delta = (channel-base) & (nP2pChannels-1);