diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 317b58bcd0..557c440afc 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -356,6 +356,7 @@ static ncclResult_t commFree(ncclComm_t comm) { RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0); RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0); +RCCL_PARAM(PivotAlltoallEnable, "PIVOT_ALLTOALL_ENABLE", 0); NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2); NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0); NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 0); @@ -883,7 +884,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].collNet.typeIntra = collNetGraph.typeIntra; allGather3Data[rank].collNet.typeInter = collNetGraph.typeInter; allGather3Data[rank].collNetSupport = comm->collNetSupport; - allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled; + allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable(); comm->nChannels = (comm->topo->nodes[GPU].count != comm->topo->nRanks && comm->topo->nodes[NET].count) ? std::min(treeGraph.nChannels, ringGraph.nChannels) : ringGraph.nChannels;