Reapplying PR #1641 [AG and RS channel tuning] Add thread work threshold to tuning models and precompute reg index in LL128 (#1713)

* Reapply "[AG and RS channel tuning] Add thread work threshold to tuning models and precompute reg index in LL128 (#1641)"

This reverts commit 943ad6f7820739385a0b54e81f823d0df1dbf71c.

* Decreasing NCCL_LL128_SHMEM_ELEMS_PER_THREAD from 16 to 8
This commit is contained in:
Pedram Alizadeh
2025-06-04 13:22:11 -04:00
committato da GitHub
parent e94b360246
commit 3f7c08648f
7 ha cambiato i file con 50 aggiunte e 17 eliminazioni
+6 -2
Vedi File
@@ -8,16 +8,20 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
* Resolved an issue when using more than 64 channels when multiple collectives are used in the same `ncclGroup()` call.
* Fixed unit test failures in tests ending with `ManagedMem` and `ManagedMemGraph` suffixes.
* Suboptimal algorithmic switching point for AllReduce on MI300x
* Fixed the known issue "When splitting a communicator using `ncclCommSplit` in some GPU configurations, MSCCL initialization can cause a segmentation fault." with a design change to use `comm` instead of `rank` for `mscclStatus`. The Global map for `comm` to `mscclStatus` is still not thread safe but should be explicitly handled by mutexes for read writes. This is tested for correctness, but there is a plan to use a thread-safe map data structure in upcoming changes.
### Added
* Added new GPU target `gfx950`.
* Added support for `unroll=1` in device-code generation to improve performance
* Set a default of 112 channels for a single node with `8 * gfx950`
* Added MSCCL support for multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the
* Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the
environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`.
* Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. Channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS', or 'NCCL_MAX_NCHANNELS` are set.
* Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
* LL/LL128 usage ranges for AR, AG, and RS are part of the tuning models, which enable architecture-specific tuning in conjunction with the existing Rome Models scheme in RCCL.
* Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled.
### Changed
+11 -6
Vedi File
@@ -126,6 +126,11 @@ private:
template<int WordPerThread>
__device__ __forceinline__ void loadRegsBegin(uint64_t(&regs)[WordPerThread], T const *src, int eltN) {
constexpr int EltPer16B = 16/sizeof(T);
int ix[WordPerThread/2];
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
ix[g] = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
}
if(reinterpret_cast<uintptr_t>(src)%16 == 0) {
/* We are aligned to 16 bytes, so load directly to registers no shmem.
* Flag threads load half as much data which gets shuffled to the even
@@ -135,10 +140,9 @@ private:
*/
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
if(!flagThread || g%2==0) {
if(ix*EltPer16B < eltN)
load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
if(ix[g]*EltPer16B < eltN)
load128((uint64_t*)(src + ix[g]*EltPer16B), regs[2*g+0], regs[2*g+1]);
}
}
}
@@ -163,10 +167,10 @@ private:
T *shm = (T*)shm8 + misalignment/sizeof(T);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
// int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
if(!flagThread || g%2==0) {
if(ix*EltPer16B < eltN)
loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]);
if(ix[g]*EltPer16B < eltN)
loadShmemMisaligned128(shm + ix[g]*EltPer16B, regs[2*g+0], regs[2*g+1]);
}
}
}
@@ -189,6 +193,7 @@ private:
for (int g=1; g < WordPerThread/2; g+=2) {
if (flagThread) regs[2*g-1] = regs[2*g];
}
// Write to dst if 4-byte aligned, shmem otherwise.
int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
uint64_t *shm8 = shmemCvtPtr((uint64_t*)ncclScratchForWarp(warpInBlock));
+1
Vedi File
@@ -1773,6 +1773,7 @@ static ncclResult_t topoGetAlgoInfo(
// NVLS should not need more than 16 channels to get peak BW.
nc = comm->nvlsChannels;
} else {
rcclUpdateThreadThreshold(comm, nBytes, info, threadThreshold);
// Ring/Tree channel tuning
while (nBytes < nc * nt * threadThreshold) {
if (nc >= 2) nc--;
+4 -4
Vedi File
@@ -329,11 +329,11 @@ static struct tuningModel tuning_model_5 {
// Follow order in RcclTunableColls
.llProtoRanges = {
/*ReduceScatter*/
{/*LL (min/max/factor)*/ {0, 655360, 1}, /*LL64/128 (min/max/factor)*/ {131072, 3211264, 1}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 655360, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {131072, 4793500, 1, 64}},
/*AllGather*/
{/*LL (min/max/factor)*/ {0, 98304, 1}, /*LL64/128 (min/max/factor)*/ {98304, 5046272, 1}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 98304, 1, 16}, /*LL64/128 (min/max/factor/thread_threshold)*/ {98304, 5592500, 1, 64}},
/*AllReduce*/
{/*LL (min/max/factor)*/ {0, 1048576, 1},/*LL64/128 (min/max/factor)*/ {1048576, 9437184, 3145728}},
{/*LL (min/max/factor/thread_threshold)*/ {0, 1048576, 1, 0},/*LL64/128 (min/max/factor/thread_threshold)*/ {1048576, 144217728, 3145728, 0}},
},
};
@@ -722,7 +722,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
}
}
}
// Set per-thread amount of work before we increase nThreads and nChannels
for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) {
comm->threadThresholds[a][NCCL_PROTO_LL] = NCCL_LL_THREAD_THRESHOLD;
+2 -2
Vedi File
@@ -95,7 +95,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
#define NCCL_LL128_MAX_NTHREADS 256
#define NCCL_LL128_ELEMS_PER_THREAD 28
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 4
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 8
#define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
#define NCCL_P2P_WRITE 0x01
@@ -698,7 +698,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
// RING / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncReduce) {
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * proto;
break;
}
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
+3 -3
Vedi File
@@ -33,10 +33,11 @@ typedef enum RcclTunableColls {
} rcclTunableIndex_t;
#define RCCL_LL_LIMITS_UNDEFINED 0
#define RCCL_PROTOCOL_ENTRY_SIZE 3
#define RCCL_PROTOCOL_ENTRY_SIZE 4
#define RCCL_PROTOCOL_MIN_IDX 0
#define RCCL_PROTOCOL_MAX_IDX 1
#define RCCL_PROTOCOL_FACTOR_IDX 2
#define RCCL_PROTOCOL_THREAD_THRESHOLD_IDX 3
#ifdef RCCL_EXPOSE_STATIC
#define RCCL_STATIC_EXPOSE_CHECK()
@@ -71,8 +72,7 @@ inline size_t rcclGetSizePerRank(ncclFunc_t const& func, size_t const& nBytes, i
return (func == ncclFuncReduceScatter || func == ncclFuncAllGather) ? nBytes / nRanks : nBytes;
}
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType,
int collNetSupport, int nvlsSupport, int numPipeOps,
int* algo, int* protocol, int* maxChannels);
+23
Vedi File
@@ -72,6 +72,29 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s
}
}
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold) {
// Honor user input for thread thresholds
static int userChannelControlInput = -2;
if (userChannelControlInput == -2) {
const char *inputStr = getenv("NCCL_THREAD_THRESHOLDS");
if (!inputStr) {
inputStr = getenv("NCCL_MAX_NCHANNELS");
}
if (!inputStr) {
inputStr = getenv("NCCL_MIN_NCHANNELS");
}
userChannelControlInput = !inputStr ? 0 : 1;
}
if(!userChannelControlInput && comm->nNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather)) {
auto tunableIndex = rcclGetTunableIndex(info->func);
auto tunedThreshold = comm->minMaxLLRange[tunableIndex][info->protocol][RCCL_PROTOCOL_THREAD_THRESHOLD_IDX];
if(tunedThreshold != RCCL_LL_LIMITS_UNDEFINED) {
threadThreshold = tunedThreshold * comm->nRanks;
}
}
}
extern ncclResult_t getAlgoInfo(
struct ncclComm* comm, struct ncclTaskColl* task,
int collNetSupport, int nvlsSupport, int numPipeOps, ncclSimInfo_t* simInfo = NULL