diff --git a/projects/rccl/src/clique/CliqueManager.cc b/projects/rccl/src/clique/CliqueManager.cc index dde0cf0085..91990d2f3b 100644 --- a/projects/rccl/src/clique/CliqueManager.cc +++ b/projects/rccl/src/clique/CliqueManager.cc @@ -46,8 +46,8 @@ int* CliqueManager::m_staticGpuBarrierMem = NULL; // Define some environment variables that affect clique-based kernels RCCL_PARAM(EnableClique, "ENABLE_CLIQUE", 0); // Opt-in environment variable for clique-based kernels -RCCL_PARAM(AllReduceCliqueByteLimit, "CLIQUE_ALLREDUCE_BYTE_LIMIT", 2097152); // Max number of bytes to use clique-based kernels for all reduce -RCCL_PARAM(AllReduceNumChannels, "CLIQUE_ALLREDUCE_NCHANNELS", 4); // Number of channels to use for all-reduce +RCCL_PARAM(AllReduceCliqueByteLimit, "CLIQUE_ALLREDUCE_BYTE_LIMIT", 16777216); // Max number of bytes to use clique-based kernels for all reduce +RCCL_PARAM(AllReduceNumChannels, "CLIQUE_ALLREDUCE_NCHANNELS", 0); // Number of channels to use for all-reduce. (0 for auto-select) RCCL_PARAM(CliqueDebug, "CLIQUE_DEBUG", 0); // Emit debug messages CliqueManager::CliqueManager(int const rank, @@ -321,7 +321,22 @@ ncclResult_t CliqueManager::GetNumChannelsToUse(ncclFunc_t const coll, *numChannelstoUse = 1; if (coll == ncclCollAllReduce) { - *numChannelstoUse = std::min((int)rcclParamAllReduceNumChannels(), totalNumChannels); + if (rcclParamAllReduceNumChannels() == 0) + { + // NOTE: These are currently based on collected data and not necessarily ideal for all hardware + int numChannels; + if (totalBytes <= 65536) numChannels = 1; + else if (totalBytes <= 262144) numChannels = 2; + else if (totalBytes <= 524288) numChannels = 4; + else if (totalBytes <= 2097152) numChannels = 8; + else numChannels = 11; + + *numChannelstoUse = std::min(numChannels, totalNumChannels); + } + else + { + *numChannelstoUse = std::min((int)rcclParamAllReduceNumChannels(), totalNumChannels); + } } return ncclSuccess; @@ -344,9 +359,6 @@ ncclResult_t CliqueManager::SetCliqueCollectiveArgs(CollectiveArgs* args) args->clique.ptrs = &m_pinnedCliquePtrs[opIndex]; args->clique.verbose = rcclParamCliqueDebug(); - // Determine number of channels to use for this collective - args->clique.nChannels = rcclParamAllReduceNumChannels(); - return ncclSuccess; }