diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 47df7f06a3..9e4ff00688 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -1338,7 +1338,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p allGather3Data[rank].nc = std::max(allGather3Data[rank].nc, 4/ringGraph->nChannels); if (ringGraph->nChannels > MAXCHANNELS/2) allGather3Data[rank].nc = 1; - if (IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx94") || IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx950")) { + if (IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx94")) { // Multi-node MI300A int managed = 0; CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0)); @@ -1355,6 +1355,9 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p allGather3Data[rank].nc = 4; } } + if (IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx950")) { + allGather3Data[rank].nc = 4; + } allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable(); comm->topo->ll128Enabled = comm->topo->ll128Enabled || rcclParamLL128ForceEnable();