diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index 390045eccc..74410533a8 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -375,6 +375,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count, return ncclEnqueueCheck(&info); } + NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream); ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t recvcount, @@ -386,6 +387,10 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr }; + int nRanks; + NCCLCHECK(ncclCommCount(comm, &nRanks)); + size_t msgSize = recvcount * ncclTypeSize(datatype) * nRanks; + if (!mscclIsCaller()) // when msccl falls back to { NCCLCHECK(Recorder::instance().record(rrReduceScatter, info)); @@ -396,7 +401,42 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream); } + + // Reset value forcing direct reduce scatter algorithm + comm->enableDirectReduceScatter = 0; + if (rcclUseReduceScatterDirect(comm, msgSize)) { + INFO(NCCL_INIT, "RCCL DIRECT REDUCE-SCATTER recvcount=%zu msgSize=%zu rank=%d nRanks=%d nNodes=%d comm=%p stream=%p sendbuff=%p recvbuff=%p", + recvcount, msgSize, comm->rank, nRanks, comm->nNodes, comm, stream, sendbuff, recvbuff); + + // Temporary Buffer to store data from each rank + void* tempbuff = comm->tempBuff; + + // Use Direct Reduce Scatter Algorithm + comm->enableDirectReduceScatter = 1; + + if (recvcount == 0) return ncclSuccess; + + // Calculate offset into buffers + size_t offset = recvcount * ncclTypeSize(datatype); + + // Copy Current ranks data to tempbuff + // Enqueue the copy on the user stream so it is correctly ordered w.r.t. the subsequent + // ncclSend/ncclRecv and the rest of the ReduceScatter work on the same stream. + NCCLCHECK(ncclCudaMemcpyAsync((char*)tempbuff + comm->rank * offset, (char*)sendbuff + comm->rank * offset, offset, stream)); + + NCCLCHECK(ncclGroupStart()); + for (int i = 0; i < nRanks; i++) { + int peer = (comm->rank + i) % nRanks; + if (peer == comm->rank) { + continue; + } + NCCLCHECK(ncclSend((void*)((char*)sendbuff + peer * offset), recvcount, datatype, peer, comm, stream)); + NCCLCHECK(ncclRecv((void*)((char*)tempbuff + peer * offset), recvcount, datatype, peer, comm, stream)); + } + NCCLCHECK(ncclGroupEnd()); + } + return ncclEnqueueCheck(&info); } diff --git a/projects/rccl/src/device/reduce_scatter.h b/projects/rccl/src/device/reduce_scatter.h index 8ce151b66f..d8a9455203 100644 --- a/projects/rccl/src/device/reduce_scatter.h +++ b/projects/rccl/src/device/reduce_scatter.h @@ -16,127 +16,164 @@ namespace { #else __device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #endif -#ifdef ENABLE_WARP_SPEED - int warp = threadIdx.x / WARP_SIZE; - ncclRing *ring = &ncclShmem.warpChannel[warp].ring; -#else - ncclRing *ring = &ncclShmem.channel.ring; -#endif - int const *ringRanks = ring->userRanks; - const int nranks = ncclShmem.comm.nRanks; - size_t count; - size_t gridOffset; - size_t channelCount; - size_t chunkCount; -#ifdef ENABLE_WARP_SPEED - ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); -#else - ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); -#endif - size_t offset; - size_t dataOffset; - uint32_t nelem; - int rankDest; + //TODO: move Direct Reduce Scatter path to a separate kernel + size_t msgSize = work->count * sizeof(T) * ncclShmem.comm.nRanks; + if (work->enableDirectReduceScatter && msgSize <= (size_t)work->directReduceScatterLimitBytes) { + const int nRanks = ncclShmem.comm.nRanks; + const ssize_t numElements = work->count; -#if defined(ENABLE_NPKIT) - int npKitCtxIdx = ncclShmem.channelId; -#endif + // Calculate Offset to utilize multiple channels + ssize_t elementsPerBlock = numElements / gridDim.x; + ssize_t remainderElements = numElements % gridDim.x; + // Calculate the number of elements per block for each block + // The first n blocks get 1 extra element to account for the remainder (n = remainderElements) + ssize_t numElementsPerBlock = elementsPerBlock + (blockIdx.x < remainderElements ? 1 : 0); + ssize_t channelOffset = blockIdx.x * elementsPerBlock + min((ssize_t)blockIdx.x, remainderElements); -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, NPKIT_GET_CPU_TIMESTAMP_FROM_BLOCK, - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } -#endif - -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } -#endif - -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_ENTRY) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_ENTRY, count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } -#endif - // Coverity reports that the callee treats &ring->next as an array. However, due to the use of - // FanSymmetric<1>, only the first element is ever accessed, so it's fine. - // coverity[callee_ptr_arith:FALSE] - Primitives, 0, Proto, 0, false, 0, Pipeline> - prims(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex); - -#if defined(ENABLE_NPKIT) - if (tid == 0) { - prims.npKitCtxIdx = npKitCtxIdx; - } -#endif - - for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { - nelem = min(chunkCount, channelCount - elemOffset); - - dataOffset = gridOffset + elemOffset; - /////////////// begin ReduceScatter steps /////////////// - // step 0: push data to next GPU -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_ENTRY) + // Array of src pointers pointing to rank offsets in tempBuff + void** srcPtrs = (void**)ncclScratchForWarp(0); if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + for (int i = 0; i < nRanks; i++) { + // Define offset into tempbuff for each rank's data + const ssize_t srcOffset = i * numElements + channelOffset; + srcPtrs[i] = (void*)((T*)work->tempBuff + srcOffset); + } + } + // Sync threads to ensure all srcPtrs are set before reduction + __syncthreads(); + + T* recvbuff = (T*)work->recvbuff; + // Array for destination pointer to recvbuff + void* dstPtrs[1]; + dstPtrs[0] = (void*)(recvbuff + channelOffset); + if (tid < nthreads) { + // Call reduction across all rank offsets in tempbuff and store in recvbuff + reduceCopy + (tid, nthreads, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nRanks, srcPtrs, 1, dstPtrs, numElementsPerBlock); + } + } else { + #ifdef ENABLE_WARP_SPEED + int warp = threadIdx.x / WARP_SIZE; + ncclRing *ring = &ncclShmem.warpChannel[warp].ring; + #else + ncclRing *ring = &ncclShmem.channel.ring; + #endif + int const *ringRanks = ring->userRanks; + const int nranks = ncclShmem.comm.nRanks; + size_t count; + size_t gridOffset; + size_t channelCount; + size_t chunkCount; + #ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); + #else + ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); + #endif + size_t offset; + size_t dataOffset; + uint32_t nelem; + int rankDest; + + #if defined(ENABLE_NPKIT) + int npKitCtxIdx = ncclShmem.channelId; + #endif + + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, NPKIT_GET_CPU_TIMESTAMP_FROM_BLOCK, ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); } -#endif - rankDest = ringRanks[nranks-1]; - offset = dataOffset + rankDest * count; - prims.send(offset, nelem); -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_EXIT) + #endif + + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); } -#endif - // k-2 steps: reduce and copy to next GPU -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_ENTRY) + #endif + + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_ENTRY) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_ENTRY, nelem*(nranks-2)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_ENTRY, count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); } -#endif - for (int j=2; jnext as an array. However, due to the use of + // FanSymmetric<1>, only the first element is ever accessed, so it's fine. + // coverity[callee_ptr_arith:FALSE] + Primitives, 0, Proto, 0, false, 0, Pipeline> + prims(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex); + + #if defined(ENABLE_NPKIT) + if (tid == 0) { + prims.npKitCtxIdx = npKitCtxIdx; + } + #endif + + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + nelem = min(chunkCount, channelCount - elemOffset); + + dataOffset = gridOffset + elemOffset; + /////////////// begin ReduceScatter steps /////////////// + // step 0: push data to next GPU + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } + #endif + rankDest = ringRanks[nranks-1]; offset = dataOffset + rankDest * count; - prims.recvReduceSend(offset, nelem); - } -#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_EXIT) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_EXIT, nelem*(nranks-2)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } -#endif + prims.send(offset, nelem); + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_EXIT) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_SEND_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } + #endif + // k-2 steps: reduce and copy to next GPU + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_REDUCE_SCATTER_RING_RECV_REDUCE_SEND_ENTRY, nelem*(nranks-2)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } + #endif + for (int j=2; j // std::memcpy #include // PRIx64 @@ -161,6 +162,7 @@ static inline int ncclFuncTrafficPerByte(ncclFunc_t func, int nRanks) { } RCCL_PARAM_DECLARE(EnableProxyTrace); +RCCL_PARAM_DECLARE(DirectReduceScatterThreshold); /*****************************************************************************/ /* Launch system : synchronization and CUDA kernel launch */ /*****************************************************************************/ @@ -412,7 +414,22 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.size = task->count; } #endif - + // Direct Reduce Scatter + if (task->func == ncclFuncReduceScatter && comm->enableDirectReduceScatter) { + devWork.enableDirectReduceScatter = comm->enableDirectReduceScatter; + int64_t directReduceScatterLimit = rcclParamDirectReduceScatterThreshold(); + if (directReduceScatterLimit >= 0) { + // set threshold to 2MiB hard limit + directReduceScatterLimit = std::min(directReduceScatterLimit, (int64_t)2097152); + devWork.directReduceScatterLimitBytes = (uint32_t) directReduceScatterLimit; + } else { + devWork.directReduceScatterLimitBytes = (uint32_t)0; + } + devWork.tempBuff = (void*)comm->tempBuff; + devWork.currentRank = comm->rank; + devWork.count = task->count; + } + devWork.isOneRPN = comm->isOneRPN; devWork.netRegUsed = devWork.regUsed = 0; devWork.gfx9CheapFenceOff = gfx9CheapFenceOff(devWork, comm->gfx9CheapFenceOff); @@ -725,10 +742,12 @@ static ncclResult_t scheduleCollTasksToPlan( proxyOp.incWorkCounter = true; addWorkBatchToPlan(comm, plan, c, workNode->workType, task->devFuncId, plan->workBytes); // Set pattern to profiler to add a proxy profiler for kernel events - if (task->func != ncclFuncAllToAllGda) { + // for Direct Reduce Scatter (DRS), we don't need to add proxy op + bool isDRS = task->func == ncclFuncReduceScatter && comm->enableDirectReduceScatter; + if (!isDRS && task->func != ncclFuncAllToAllGda) { NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp)); NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp)); - } + } } } else { // not task->isCollnet int trafficPerByte = ncclFuncTrafficPerByte(task->func, comm->nRanks); @@ -875,7 +894,9 @@ static ncclResult_t scheduleCollTasksToPlan( // Coverity reports "proxyOp->connection" as being possibly uninitialized. It's hard to // determine if that's actually true but it's also not clear if that would be an issue. // coverity[uninit_use_in_call:FALSE] - if (task->func != ncclFuncAllToAllGda) { + // for Direct Reduce Scatter (DRS), we don't need to add proxy op + bool isDRS = task->func == ncclFuncReduceScatter && comm->enableDirectReduceScatter; + if (!isDRS && task->func != ncclFuncAllToAllGda) { NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp)); NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp)); } diff --git a/projects/rccl/src/include/comm.h b/projects/rccl/src/include/comm.h index 694fee1083..aa11eed498 100644 --- a/projects/rccl/src/include/comm.h +++ b/projects/rccl/src/include/comm.h @@ -763,6 +763,11 @@ struct ncclComm { int symId; #endif + // Direct Reduce Scatter [RCCL] + bool enableDirectReduceScatter; + // Temporary Buffer [RCCL] + void* tempBuff; + uint64_t endMagic; }; diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index d8119c6b08..512f144147 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -385,6 +385,15 @@ struct alignas(16) ncclDevWorkColl { uintptr_t recvbuffOffset; uintptr_t* sendbuffRmtAddrs; uintptr_t* recvbuffRmtAddrs; + + bool enableDirectReduceScatter; + // Per-work (per kernel launch) limit for Direct ReduceScatter in bytes. + // This is set by the host and used as a device-side safety gate. + uint32_t directReduceScatterLimitBytes; + void* tempBuff; + int currentRank; + size_t count; + union { // Continuous-byte-distribution scheduling. The lo and hi channels are of // different size than the channels in the middle. diff --git a/projects/rccl/src/include/rccl_common.h b/projects/rccl/src/include/rccl_common.h index dd3a4b396e..5628d2365a 100644 --- a/projects/rccl/src/include/rccl_common.h +++ b/projects/rccl/src/include/rccl_common.h @@ -25,6 +25,7 @@ THE SOFTWARE. #include "nccl.h" #include "param.h" #include "core.h" + typedef enum RcclTunableColls { RCCL_UNSUPPORTED_TUNABLE = -1, RCCL_RS_TUNABLE = 0, // reduce_scatter index @@ -114,12 +115,16 @@ NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll, NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName); NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName); bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize); +bool rcclUseReduceScatterDirect(struct ncclComm* comm, size_t& msgSize); bool rcclUseAllToAllGda(struct ncclComm* comm); void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable); void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize); ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount); ncclResult_t commSetUnrollFactor(struct ncclComm* comm); bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, const char* archName); + +// Direct ReduceScatter Limit +RCCL_PARAM_DECLARE(DirectReduceScatterThreshold); int getFirmwareVersion(); bool rcclIsArchSupportedForFunc(struct ncclTaskColl* info, char const* archName); #ifdef ENABLE_WARP_SPEED diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index a5e1ea83ee..693b78c7a3 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -90,6 +90,8 @@ #define NCCL_GROUP_CUDA_STREAM 1 // CGMD: CUDA 9.0,9.1 Need to use an internal CUDA stream #endif +#define TEMP_BUFF_SIZE (4 * 1024 * 1024) // Define Size for Temporary Buffer for Direct RS + using namespace rccl; const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3] = { "AllGather", "AllReduce", "AlltoAllPivot", "AllToAllGda", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"}; @@ -484,6 +486,13 @@ static ncclResult_t commFree(ncclComm_t comm) { NCCLCHECK(ncclCeFinalize(comm)); + // tempBuff is allocated per-communicator for direct ReduceScatter on gfx950. + // It is owned by the communicator; free it during communicator teardown. + if (comm->tempBuff) { + NCCLCHECK(ncclCudaFree(comm->tempBuff)); + comm->tempBuff = nullptr; + } + if (comm->symmetricSupport) { NCCLCHECK(ncclSymkFinalize(comm)); NCCLCHECK(ncclDevrFinalize(comm)); @@ -2285,6 +2294,10 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { } #endif + // Allocate Temp Buffer for Direct Reduce Scatter + if (IsArchMatch(archName,"gfx950")) { + NCCLCHECK(ncclCudaMalloc(&(comm->tempBuff), TEMP_BUFF_SIZE)); + } #ifdef ENABLE_MSCCLPP if (job->parent) { diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index 1293fcfe66..d15853d08c 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -41,6 +41,7 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0); // Otherwise, it is automatically set for certain archs, datatypes and reduction collectives RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0); RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 75497472); +RCCL_PARAM(DirectReduceScatterThreshold, "DIRECT_REDUCE_SCATTER_THRESHOLD", 2097152); RCCL_PARAM(ThreadsPerBlock, "THREADS_PER_BLOCK", -1); RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", -1); #ifdef ENABLE_WARP_SPEED @@ -465,6 +466,41 @@ bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) { ; } +bool rcclUseReduceScatterDirect(struct ncclComm* comm, size_t& msgSize) { + // Direct ReduceScatter is supported for MI350 (gfx950): + // - 2 nodes: enable for 128KiB .. 2MiB + // - 4 and 8 nodes: enable up to 2MiB + static int userDirectReduceScatterInput = -2; + if (userDirectReduceScatterInput == -2) { + const char *inputStr = getenv("RCCL_DIRECT_REDUCE_SCATTER_DISABLE"); + userDirectReduceScatterInput = !inputStr ? 0 : 1; + } + if (userDirectReduceScatterInput == 1) { + INFO(NCCL_INIT, "RCCL DIRECT REDUCE-SCATTER has been disabled."); + return false; + } + const bool archGfx950 = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950"); + if (!archGfx950) return false; + + size_t threshold = rcclParamDirectReduceScatterThreshold(); + if (threshold > -1) { + // Set threshold to 2MiB hard limit + // NOTE: If the DirectReduceScatterThreshold / hard-limit is increased, ensure TEMP_BUFF_SIZE (init.cc) + // is increased accordingly -> TEMP_BUFF_SIZE >= 2 * (max enabled msgSize) for headroom. + threshold = std::min(threshold, (size_t)2097152); + } else { + threshold = 2097152; + } + INFO(NCCL_INIT, "RCCL DIRECT REDUCE-SCATTER threshold set to: %zu", threshold); + + if (msgSize > threshold) return false; + // for 2 nodes, enable if msgSize is in 128KiB .. 2MiB range + if (comm->nNodes == 2) return msgSize >= (size_t)131072; + if (comm->nNodes == 8 || comm->nNodes == 4) return true; + return false; +} + + void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) { static int pxnDisable = RCCL_VALUE_UNSET; comm->enableCustColl = false;