From 2f6b20c00a5c480dcdc8d5a9dc2655f9fde42cc2 Mon Sep 17 00:00:00 2001 From: alex-breslow-amd Date: Thu, 29 May 2025 16:17:35 -0700 Subject: [PATCH] Use One Slice per Basic Primitive for AllReduce, ReduceScatter, AllGather (#1681) for Single Node on Some GFX9 Systems Using a single slice rather than the typical two provides about 5% speedup (sometimes more or less) on some GFX9 systems for single node. --- src/collectives.cc | 7 ++++--- src/device/all_gather.h | 17 +++++++++++++++-- src/device/all_reduce.h | 18 ++++++++++++++++-- src/device/reduce_scatter.h | 18 ++++++++++++++++-- src/enqueue.cc | 1 + src/include/collectives.h | 10 ++++++++++ src/include/comm.h | 1 + src/include/device.h | 2 +- src/init.cc | 2 ++ 9 files changed, 66 insertions(+), 10 deletions(-) diff --git a/src/collectives.cc b/src/collectives.cc index 82e81983e4..efe502b431 100644 --- a/src/collectives.cc +++ b/src/collectives.cc @@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen struct ncclInfo info = { ncclFuncAllGather, "AllGather", sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ - ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS }; + ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS }; if (!mscclIsCaller()) // when msccl falls back to { @@ -114,9 +114,10 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce, NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype)); + // RCCL update slice steps for AllReduce if single node struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ - ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS }; + ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? 
ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS }; if (!mscclIsCaller()) // when msccl falls back to { @@ -329,7 +330,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ - REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS }; + REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS }; if (!mscclIsCaller()) // when msccl falls back to { diff --git a/src/device/all_gather.h b/src/device/all_gather.h index c54c90d1f3..297b72d3f8 100644 --- a/src/device/all_gather.h +++ b/src/device/all_gather.h @@ -175,6 +175,18 @@ namespace { } } +#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices. +#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \ + if(work->rcclUseOneSlice){ \ + runRing, false>(tid, nthreads, work); \ + } else{ \ + runRing, false>(tid, nthreads, work); \ + } +#else +#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \ + runRing, false>(tid, nthreads, work); +#endif + template struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { @@ -185,8 +197,9 @@ struct RunWorkColl, true>(tid, nthreads, work); - else - runRing, false>(tid, nthreads, work); + else{ + rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work); + } } }; diff --git a/src/device/all_reduce.h b/src/device/all_reduce.h index 6c58d72ddb..44f7dddd8f 100644 --- a/src/device/all_reduce.h +++ b/src/device/all_reduce.h @@ -558,11 +558,25 @@ namespace { } } +#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices. 
+#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \ + if(work->rcclUseOneSlice){ \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); \ + } else{ \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); \ + } +#else +#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); +#endif + template struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - using Proto = ProtoSimple; - runRing(tid, nthreads, work); + rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work); } }; diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h index 5c2085bd6d..81e6976f53 100644 --- a/src/device/reduce_scatter.h +++ b/src/device/reduce_scatter.h @@ -131,11 +131,25 @@ namespace { } } +#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices. +#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \ + if(work->rcclUseOneSlice){ \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); \ + } else{ \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); \ + } +#else +#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \ + using Proto = ProtoSimple; \ + runRing(tid, nthreads, work); +#endif + template struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - using Proto = ProtoSimple; - runRing(tid, nthreads, work); + rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work); } }; diff --git a/src/enqueue.cc b/src/enqueue.cc index 98c862ba7f..a8a3536253 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -332,6 +332,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.redOpArg = task->opDev.scalarArg; devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr; devWork.oneNode = (comm->nNodes == 1); + devWork.rcclUseOneSlice 
= comm->rcclUseOneSlice; devWork.isOneRPN = comm->isOneRPN; devWork.netRegUsed = devWork.regUsed = 0; if (task->regBufType & NCCL_NET_REG_BUFFER) diff --git a/src/include/collectives.h b/src/include/collectives.h index d9e653c760..6734d6d7d3 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -15,11 +15,17 @@ #define NCCL_MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. // CHUNKSIZE must be a multiple of SLICESIZE +// RCCL: Benchmarking on a single MI300X node showed improved throughput when always using +// a single slice, so we have separate configurations for single node and multi-node. Single node configs +// are suffixed with _SINGLE_NODE. #define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4) +#define ALLREDUCE_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2) #define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2) #define ALLGATHER_SLICESTEPS (NCCL_STEPS/4) +#define ALLGATHER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2) #define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2) #define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4) +#define REDUCESCATTER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2) #define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2) #define BROADCAST_SLICESTEPS 1 #define BROADCAST_CHUNKSTEPS 1 @@ -30,6 +36,10 @@ #define ALLTOALL_PIVOT_SLICESTEPS 2 #define ALLTOALL_PIVOT_CHUNKSTEPS 4 +static_assert(ALLREDUCE_CHUNKSTEPS == ALLREDUCE_SLICESTEPS_SINGLE_NODE, "ALLREDUCE_CHUNKSTEPS must be equal to ALLREDUCE_SLICESTEPS_SINGLE_NODE"); +static_assert(ALLGATHER_CHUNKSTEPS == ALLGATHER_SLICESTEPS_SINGLE_NODE, "ALLGATHER_CHUNKSTEPS must be equal to ALLGATHER_SLICESTEPS_SINGLE_NODE"); +static_assert(REDUCESCATTER_CHUNKSTEPS == REDUCESCATTER_SLICESTEPS_SINGLE_NODE, "REDUCESCATTER_CHUNKSTEPS must be equal to REDUCESCATTER_SLICESTEPS_SINGLE_NODE"); + const char* ncclFuncToString(ncclFunc_t op); const char* ncclDevRedOpToString(ncclDevRedOp_t op); const char* ncclDatatypeToString(ncclDataType_t type); diff --git a/src/include/comm.h 
b/src/include/comm.h index 2bb73c7170..c0e1e0ce37 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -481,6 +481,7 @@ struct ncclComm { int node; int nNodes; + int rcclUseOneSlice; // RCCL: true if this comm is using one slice per primitive int localRank; int localRanks; int maxLocalRanks; diff --git a/src/include/device.h b/src/include/device.h index fe6f94c1ee..434098548c 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -288,7 +288,7 @@ struct alignas(16) ncclDevWorkColl { // nChannels == (channelHi - channelLo) + 1 uint32_t channelLo:8, channelHi:8; uint32_t nWarps:8; - uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1; + uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1; uint32_t root:30, connIndex:2; uint16_t pivotA2ANumBiRings; void* recvbuff; diff --git a/src/init.cc b/src/init.cc index cda7f1359f..99d5d3b173 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1353,6 +1353,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p // Multi-node MI300A int managed = 0; CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0)); + // RCCL: Only use one slice per primitive on some single node gfx9xx systems + comm->rcclUseOneSlice = !managed && nNodes == 1; if (managed && nNodes > 1) { // This forces the minimum channels to 24 allGather3Data[rank].nc = 6;