Expose symbols for RCCL algo/proto/channels selection functions (#1923)

* Unhide symbols for algo/proto functions

* Add all_gather direct usage detection
This commit is contained in:
Mustafa Abduljabbar
2025-09-25 18:58:30 -04:00
committed by GitHub
parent cb14fccdcc
commit 7a329bbd94
4 changed files with 69 additions and 13 deletions
+5 -8
View File
@@ -79,8 +79,6 @@ const char* ncclProtoToString(int proto) {
}
}
RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
@@ -110,9 +108,8 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
}
if (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
rcclParamDirectAllGatherThreshold() > -1)) {
// use direct allgather
if (rcclUseAllGatherDirect(comm, msgSize)) {
// use direct allgather
if (sendcount == 0) return ncclSuccess;
size_t rankOffset = sendcount * ncclTypeSize(datatype);
if (((char*)recvbuff) != (((char*)sendbuff) + comm->rank * rankOffset)) {
@@ -123,7 +120,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
}
NCCLCHECK(ncclGroupStart());
for (int r = 0; r < nRanks; r++) {
int peer = (comm->rank + r) % nRanks;
int peer = (comm->rank + r) % nRanks;
if (in_place && (peer == comm->rank)) {
continue;
}
@@ -132,7 +129,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
}
NCCLCHECK(ncclGroupEnd());
return ncclSuccess;
} else {
} else {
// use ring allgather
return ncclEnqueueCheck(&info);
}
@@ -248,7 +245,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[],
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(AllToAllv, NcclNvtxParamsAllToAllv,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
recvcounts[comm->rank] * ncclTypeSize(datatype), datatype));
if (!mscclIsCaller()) // when msccl falls back to
+12 -4
View File
@@ -24,7 +24,7 @@ THE SOFTWARE.
#include "nccl_common.h"
#include "nccl.h"
#include "param.h"
#include "core.h"
typedef enum RcclTunableColls {
RCCL_UNSUPPORTED_TUNABLE = -1,
RCCL_RS_TUNABLE = 0, // reduce_scatter index
@@ -47,6 +47,13 @@ typedef enum {
RCCL_VALUE_INVALID = -1
} rcclValueState_t;
typedef enum {
RCCL_DIRECT_ALLGATHER = NCCL_NUM_ALGORITHMS, // Direct AllGather
RCCL_MSCCL,
RCCL_MSCCLPP,
RCCL_ALGO_COUNT
} rcclAddonAlgos_t;
#ifdef RCCL_EXPOSE_STATIC
#define RCCL_STATIC_EXPOSE_CHECK()
#else
@@ -87,9 +94,10 @@ ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType,
int collNetSupport, int nvlsSupport, int numPipeOps,
int* algo, int* protocol, int* maxChannels);
NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType, int collNetSupport, int nvlsSupport, int numPipeOps, int* algo, int* protocol, int* maxChannels);
NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName);
NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName);
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize);
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
+2 -1
View File
@@ -27,12 +27,13 @@ THE SOFTWARE.
RCCL_PARAM_DECLARE(EnableHipGraph); // Opt-in environment variable for enabling hipGraph
#define RCCL_EXPOSE_STATIC // Expose needed static functions for rccl-tests (or unit-testing in future)
#ifdef RCCL_EXPOSE_STATIC
#define rccl_static
#define rccl_static_inline inline
#else
#define rccl_static static
#define rccl_static_inline static inline
#endif
#endif`
#endif
+50
View File
@@ -33,6 +33,7 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0);
// Use this to assess impact of pipelining on performance.
// Otherwise, it is automatically set for certain archs, datatypes and reduction collectives
RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0);
RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
// Honor user input for protocol choice
@@ -234,6 +235,15 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
int collNetSupport, int nvlsSupport, int numPipeOps,
int* algo, int* protocol, int* maxChannels) {
RCCL_STATIC_EXPOSE_CHECK();
int nRanks;
NCCLCHECK(ncclCommCount(comm, &nRanks));
size_t msgSize = count * ncclTypeSize(dataType) * nRanks;
if (coll == ncclFuncAllGather && rcclUseAllGatherDirect(comm, msgSize)) {
*algo = rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER;
*protocol = NCCL_PROTO_SIMPLE; // TODO: consider LL for small messages
*maxChannels = comm->nChannels;
return ncclSuccess;
}
struct ncclTaskColl task;
task.func = coll;
task.count = count;
@@ -245,6 +255,46 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
return ncclSuccess;
}
ncclResult_t rcclGetAlgoName(int algo, const char** algoName) {
if (algo < 0 || algo >= RCCL_ALGO_COUNT) {
WARN("Invalid algorithm value: %d", algo);
return ncclInvalidArgument;
}
if(algo >= NCCL_NUM_ALGORITHMS) {
switch(algo) {
case rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER:
*algoName = "Direct";
break;
case rcclAddonAlgos_t::RCCL_MSCCL:
*algoName = "MSCCL";
break;
case rcclAddonAlgos_t::RCCL_MSCCLPP:
*algoName = "MSCCLPP";
break;
default:
WARN("Invalid algorithm value: %d", algo);
return ncclInvalidArgument;
}
return ncclSuccess;
}
*algoName = ncclAlgoToString(algo);
return ncclSuccess;
}
ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
if (protocol < 0 || protocol >= NCCL_NUM_PROTOCOLS) {
WARN("Invalid protocol value: %d", protocol);
return ncclInvalidArgument;
}
*protocolName = ncclProtoToString(protocol);
return ncclSuccess;
}
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
return (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
rcclParamDirectAllGatherThreshold() > -1));
}
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) {
static int pxnDisable = RCCL_VALUE_UNSET;
comm->enableCustColl = false;