diff --git a/src/init.cc b/src/init.cc index ed20796f51..e6f2c05d54 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1485,12 +1485,21 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p if (ringGraph.nChannels > MAXCHANNELS/2) allGather3Data[rank].nc = 1; if (IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx94")) { - if (nranks == 2) - // NCCL_MIN_NCHANNELS=32 - allGather3Data[rank].nc = 16; - else if (nranks == 4) - // NCCL_MIN_NCHANNELS=24 - allGather3Data[rank].nc = 4; + // Multi-node MI300A + int managed = 0; + CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0)); + if (managed && nNodes > 1) { + // This forces the minimum channels to 24 + allGather3Data[rank].nc = 6; + } else { + // MI300X + if (nranks == 2) + // NCCL_MIN_NCHANNELS=32 + allGather3Data[rank].nc = 16; + else if (nranks == 4) + // NCCL_MIN_NCHANNELS=24 + allGather3Data[rank].nc = 4; + } } allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable();