From 7a329bbd9425264007257122bb7ed77f02078ffb Mon Sep 17 00:00:00 2001
From: Mustafa Abduljabbar
Date: Thu, 25 Sep 2025 18:58:30 -0400
Subject: [PATCH] Expose symbols for RCCL algo/proto/channels selection functions (#1923)

* Unhide symbols for algo/proto functions
* Add all_gather direct usage detection
---
Review notes (not part of the commit message):
- rccl_vars.h: dropped the stray backtick the original change appended to
  `#endif`; that line is now untouched, so the hunk and the diffstat shrank.
- rccl_common.h: the rcclGetProtocolName declaration now names its output
  parameter protocolName, matching the definition in rccl_wrap.cc.
- rccl_wrap.cc: added purpose comments above the three new functions and
  normalized `if(` / `switch(` spacing.
- Line breaks and indentation were re-created from a whitespace-collapsed
  copy; -/+ pairs that look identical differed only in whitespace. Apply with
  `git apply --ignore-whitespace`. The blob ids on the index lines predate
  these edits.
- Follow-up: rcclGetAlgoName / rcclGetProtocolName write through their output
  pointer without a null check.

 src/collectives.cc        | 13 ++++------
 src/include/rccl_common.h | 16 +++++++++----
 src/include/rccl_vars.h   |  1 +
 src/rccl_wrap.cc          | 55 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 73 insertions(+), 12 deletions(-)

diff --git a/src/collectives.cc b/src/collectives.cc
index 9eabf0d0d8..120734280e 100644
--- a/src/collectives.cc
+++ b/src/collectives.cc
@@ -79,8 +79,6 @@ const char* ncclProtoToString(int proto) {
   }
 }
 
-RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
-
 NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
     ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
 
@@ -110,9 +108,8 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
       sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
   }
 
-  if (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
-    rcclParamDirectAllGatherThreshold() > -1)) {
-    // use direct allgather
+  if (rcclUseAllGatherDirect(comm, msgSize)) {
+    // use direct allgather
     if (sendcount == 0) return ncclSuccess;
     size_t rankOffset = sendcount * ncclTypeSize(datatype);
     if (((char*)recvbuff) != (((char*)sendbuff) + comm->rank * rankOffset)) {
@@ -123,7 +120,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
     }
     NCCLCHECK(ncclGroupStart());
     for (int r = 0; r < nRanks; r++) {
-      int peer = (comm->rank + r) % nRanks;
+      int peer = (comm->rank + r) % nRanks;
       if (in_place && (peer == comm->rank)) {
         continue;
       }
@@ -132,7 +129,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
     }
     NCCLCHECK(ncclGroupEnd());
     return ncclSuccess;
-  } else {
+  } else {
     // use ring allgather
     return ncclEnqueueCheck(&info);
   }
@@ -248,7 +245,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[],
     void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
     ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
   NVTX3_FUNC_WITH_PARAMS(AllToAllv, NcclNvtxParamsAllToAllv,
-    NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
+    NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
       recvcounts[comm->rank] * ncclTypeSize(datatype), datatype));
 
   if (!mscclIsCaller()) // when msccl falls back to
diff --git a/src/include/rccl_common.h b/src/include/rccl_common.h
index 029180df6e..d464ea3566 100644
--- a/src/include/rccl_common.h
+++ b/src/include/rccl_common.h
@@ -24,7 +24,7 @@ THE SOFTWARE.
 #include "nccl_common.h"
 #include "nccl.h"
 #include "param.h"
-
+#include "core.h"
 typedef enum RcclTunableColls {
   RCCL_UNSUPPORTED_TUNABLE = -1,
   RCCL_RS_TUNABLE = 0, // reduce_scatter index
@@ -47,6 +47,13 @@ typedef enum {
   RCCL_VALUE_INVALID = -1
 } rcclValueState_t;
 
+typedef enum {
+  RCCL_DIRECT_ALLGATHER = NCCL_NUM_ALGORITHMS, // Direct AllGather
+  RCCL_MSCCL,
+  RCCL_MSCCLPP,
+  RCCL_ALGO_COUNT
+} rcclAddonAlgos_t;
+
 #ifdef RCCL_EXPOSE_STATIC
 #define RCCL_STATIC_EXPOSE_CHECK()
 #else
@@ -87,9 +94,10 @@ ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL
 void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
 void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
 void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
-ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType,
-                             int collNetSupport, int nvlsSupport, int numPipeOps,
-                             int* algo, int* protocol, int* maxChannels);
+NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType, int collNetSupport, int nvlsSupport, int numPipeOps, int* algo, int* protocol, int* maxChannels);
+NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName);
+NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** protocolName);
+bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize);
 void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
 void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
 ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
diff --git a/src/include/rccl_vars.h b/src/include/rccl_vars.h
index d24cf55880..ffdd0143e3 100644
--- a/src/include/rccl_vars.h
+++ b/src/include/rccl_vars.h
@@ -27,6 +27,7 @@ THE SOFTWARE.
 
 RCCL_PARAM_DECLARE(EnableHipGraph); // Opt-in environment variable for enabling hipGraph
 
+#define RCCL_EXPOSE_STATIC // Expose needed static functions for rccl-tests (or unit-testing in future)
 #ifdef RCCL_EXPOSE_STATIC
 #define rccl_static
 #define rccl_static_inline inline
diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc
index 05a9cb3aed..1a622fdb71 100644
--- a/src/rccl_wrap.cc
+++ b/src/rccl_wrap.cc
@@ -33,6 +33,7 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0);
 // Use this to assess impact of pipelining on performance.
 // Otherwise, it is automatically set for certain archs, datatypes and reduction collectives
 RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0);
+RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
 
 void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
   // Honor user input for protocol choice
@@ -234,6 +235,15 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
                              int collNetSupport, int nvlsSupport, int numPipeOps,
                              int* algo, int* protocol, int* maxChannels) {
   RCCL_STATIC_EXPOSE_CHECK();
+  int nRanks;
+  NCCLCHECK(ncclCommCount(comm, &nRanks));
+  size_t msgSize = count * ncclTypeSize(dataType) * nRanks;
+  if (coll == ncclFuncAllGather && rcclUseAllGatherDirect(comm, msgSize)) {
+    *algo = rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER;
+    *protocol = NCCL_PROTO_SIMPLE; // TODO: consider LL for small messages
+    *maxChannels = comm->nChannels;
+    return ncclSuccess;
+  }
   struct ncclTaskColl task;
   task.func = coll;
   task.count = count;
@@ -245,6 +255,51 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
   return ncclSuccess;
 }
 
+// Translate an algorithm index into a printable name. Indices below NCCL_NUM_ALGORITHMS go through
+// ncclAlgoToString(); RCCL add-on indices (rcclAddonAlgos_t) are named here.
+ncclResult_t rcclGetAlgoName(int algo, const char** algoName) {
+  if (algo < 0 || algo >= RCCL_ALGO_COUNT) {
+    WARN("Invalid algorithm value: %d", algo);
+    return ncclInvalidArgument;
+  }
+  if (algo >= NCCL_NUM_ALGORITHMS) {
+    switch (algo) {
+      case rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER:
+        *algoName = "Direct";
+        break;
+      case rcclAddonAlgos_t::RCCL_MSCCL:
+        *algoName = "MSCCL";
+        break;
+      case rcclAddonAlgos_t::RCCL_MSCCLPP:
+        *algoName = "MSCCLPP";
+        break;
+      default:
+        WARN("Invalid algorithm value: %d", algo);
+        return ncclInvalidArgument;
+    }
+    return ncclSuccess;
+  }
+  *algoName = ncclAlgoToString(algo);
+  return ncclSuccess;
+}
+
+// Translate a protocol index in [0, NCCL_NUM_PROTOCOLS) into a printable name via ncclProtoToString().
+ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
+  if (protocol < 0 || protocol >= NCCL_NUM_PROTOCOLS) {
+    WARN("Invalid protocol value: %d", protocol);
+    return ncclInvalidArgument;
+  }
+  *protocolName = ncclProtoToString(protocol);
+  return ncclSuccess;
+}
+
+// True when the direct AllGather path applies: custom collectives are enabled, the communicator spans
+// 2 to 16 nodes, and the DirectAllGatherThreshold parameter is not negative and msgSize does not exceed it.
+bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
+  return (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
+    rcclParamDirectAllGatherThreshold() > -1));
+}
+
 void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) {
   static int pxnDisable = RCCL_VALUE_UNSET;
   comm->enableCustColl = false;