Expose symbols for RCCL algo/proto/channels selection functions (#1923)
* Unhide symbols for algo/proto functions * Add all_gather direct usage detection
This commit is contained in:
committed by
GitHub
parent
cb14fccdcc
commit
7a329bbd94
+5
-8
@@ -79,8 +79,6 @@ const char* ncclProtoToString(int proto) {
|
||||
}
|
||||
}
|
||||
|
||||
RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
|
||||
|
||||
@@ -110,9 +108,8 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
|
||||
}
|
||||
|
||||
if (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
|
||||
rcclParamDirectAllGatherThreshold() > -1)) {
|
||||
// use direct allgather
|
||||
if (rcclUseAllGatherDirect(comm, msgSize)) {
|
||||
// use direct allgather
|
||||
if (sendcount == 0) return ncclSuccess;
|
||||
size_t rankOffset = sendcount * ncclTypeSize(datatype);
|
||||
if (((char*)recvbuff) != (((char*)sendbuff) + comm->rank * rankOffset)) {
|
||||
@@ -123,7 +120,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
}
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
for (int r = 0; r < nRanks; r++) {
|
||||
int peer = (comm->rank + r) % nRanks;
|
||||
int peer = (comm->rank + r) % nRanks;
|
||||
if (in_place && (peer == comm->rank)) {
|
||||
continue;
|
||||
}
|
||||
@@ -132,7 +129,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
return ncclSuccess;
|
||||
} else {
|
||||
} else {
|
||||
// use ring allgather
|
||||
return ncclEnqueueCheck(&info);
|
||||
}
|
||||
@@ -248,7 +245,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[],
|
||||
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
|
||||
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
|
||||
NVTX3_FUNC_WITH_PARAMS(AllToAllv, NcclNvtxParamsAllToAllv,
|
||||
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
|
||||
NVTX3_PAYLOAD(comm ? comm->commHash : 0, sendcounts[comm->rank] * ncclTypeSize(datatype),
|
||||
recvcounts[comm->rank] * ncclTypeSize(datatype), datatype));
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
|
||||
@@ -24,7 +24,7 @@ THE SOFTWARE.
|
||||
#include "nccl_common.h"
|
||||
#include "nccl.h"
|
||||
#include "param.h"
|
||||
|
||||
#include "core.h"
|
||||
typedef enum RcclTunableColls {
|
||||
RCCL_UNSUPPORTED_TUNABLE = -1,
|
||||
RCCL_RS_TUNABLE = 0, // reduce_scatter index
|
||||
@@ -47,6 +47,13 @@ typedef enum {
|
||||
RCCL_VALUE_INVALID = -1
|
||||
} rcclValueState_t;
|
||||
|
||||
typedef enum {
|
||||
RCCL_DIRECT_ALLGATHER = NCCL_NUM_ALGORITHMS, // Direct AllGather
|
||||
RCCL_MSCCL,
|
||||
RCCL_MSCCLPP,
|
||||
RCCL_ALGO_COUNT
|
||||
} rcclAddonAlgos_t;
|
||||
|
||||
#ifdef RCCL_EXPOSE_STATIC
|
||||
#define RCCL_STATIC_EXPOSE_CHECK()
|
||||
#else
|
||||
@@ -87,9 +94,10 @@ ncclResult_t rcclOverrideAlgorithm(const char* ncclAlgoStr[], float table[][NCCL
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
|
||||
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType,
|
||||
int collNetSupport, int nvlsSupport, int numPipeOps,
|
||||
int* algo, int* protocol, int* maxChannels);
|
||||
NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType, int collNetSupport, int nvlsSupport, int numPipeOps, int* algo, int* protocol, int* maxChannels);
|
||||
NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName);
|
||||
NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName);
|
||||
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize);
|
||||
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
|
||||
void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
|
||||
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
|
||||
|
||||
@@ -27,12 +27,13 @@ THE SOFTWARE.
|
||||
|
||||
RCCL_PARAM_DECLARE(EnableHipGraph); // Opt-in environment variable for enabling hipGraph
|
||||
|
||||
#define RCCL_EXPOSE_STATIC // Expose needed static functions for rccl-tests (or unit-testing in future)
|
||||
#ifdef RCCL_EXPOSE_STATIC
|
||||
#define rccl_static
|
||||
#define rccl_static_inline inline
|
||||
#else
|
||||
#define rccl_static static
|
||||
#define rccl_static_inline static inline
|
||||
#endif
|
||||
#endif`
|
||||
|
||||
#endif
|
||||
|
||||
@@ -33,6 +33,7 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0);
|
||||
// Use this to assess impact of pipelining on performance.
|
||||
// Otherwise, it is automatically set for certain archs, datatypes and reduction collectives
|
||||
RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0);
|
||||
RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304);
|
||||
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
|
||||
// Honor user input for protocol choice
|
||||
@@ -234,6 +235,15 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
|
||||
int collNetSupport, int nvlsSupport, int numPipeOps,
|
||||
int* algo, int* protocol, int* maxChannels) {
|
||||
RCCL_STATIC_EXPOSE_CHECK();
|
||||
int nRanks;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
size_t msgSize = count * ncclTypeSize(dataType) * nRanks;
|
||||
if (coll == ncclFuncAllGather && rcclUseAllGatherDirect(comm, msgSize)) {
|
||||
*algo = rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER;
|
||||
*protocol = NCCL_PROTO_SIMPLE; // TODO: consider LL for small messages
|
||||
*maxChannels = comm->nChannels;
|
||||
return ncclSuccess;
|
||||
}
|
||||
struct ncclTaskColl task;
|
||||
task.func = coll;
|
||||
task.count = count;
|
||||
@@ -245,6 +255,46 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t rcclGetAlgoName(int algo, const char** algoName) {
|
||||
if (algo < 0 || algo >= RCCL_ALGO_COUNT) {
|
||||
WARN("Invalid algorithm value: %d", algo);
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
if(algo >= NCCL_NUM_ALGORITHMS) {
|
||||
switch(algo) {
|
||||
case rcclAddonAlgos_t::RCCL_DIRECT_ALLGATHER:
|
||||
*algoName = "Direct";
|
||||
break;
|
||||
case rcclAddonAlgos_t::RCCL_MSCCL:
|
||||
*algoName = "MSCCL";
|
||||
break;
|
||||
case rcclAddonAlgos_t::RCCL_MSCCLPP:
|
||||
*algoName = "MSCCLPP";
|
||||
break;
|
||||
default:
|
||||
WARN("Invalid algorithm value: %d", algo);
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
*algoName = ncclAlgoToString(algo);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
|
||||
if (protocol < 0 || protocol >= NCCL_NUM_PROTOCOLS) {
|
||||
WARN("Invalid protocol value: %d", protocol);
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
*protocolName = ncclProtoToString(protocol);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
|
||||
return (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() &&
|
||||
rcclParamDirectAllGatherThreshold() > -1));
|
||||
}
|
||||
|
||||
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) {
|
||||
static int pxnDisable = RCCL_VALUE_UNSET;
|
||||
comm->enableCustColl = false;
|
||||
|
||||
Reference in New Issue
Block a user