Use One Slice per Basic Primitive for AllReduce, ReduceScatter, AllGather (#1681) for Single Node on Some GFX9 Systems

Using a single slice rather than the typical two provides about 5% speedup (sometimes more or less) on some GFX9 systems for single node.
This commit is contained in:
alex-breslow-amd
2025-05-29 16:17:35 -07:00
committed by GitHub
parent 12517a957e
commit 2f6b20c00a
9 changed files with 66 additions and 10 deletions
+4 -3
View File
@@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -114,9 +114,10 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));
// RCCL update slice steps for AllReduce if single node
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -329,7 +330,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS };
if (!mscclIsCaller()) // when msccl falls back to
{
+15 -2
View File
@@ -175,6 +175,18 @@ namespace {
}
}
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS_SINGLE_NODE, ALLGATHER_SLICESTEPS_SINGLE_NODE>, false>(tid, nthreads, work); \
} else{ \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work); \
}
#else
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work);
#endif
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
@@ -185,8 +197,9 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
#endif
if (isNetOffload)
runRing<T, RedOp, ProtoSimple<1, 1>, true>(tid, nthreads, work);
else
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work);
else{
rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work);
}
}
};
+16 -2
View File
@@ -558,11 +558,25 @@ namespace {
}
}
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
} else{ \
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
}
#else
#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work);
#endif
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
runRing<T, RedOp, Proto>(tid, nthreads, work);
rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
}
};
+16 -2
View File
@@ -131,11 +131,25 @@ namespace {
}
}
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
if(work->rcclUseOneSlice){ \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS_SINGLE_NODE, REDUCESCATTER_SLICESTEPS_SINGLE_NODE>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
} else{ \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work); \
}
#else
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
runRing<T, RedOp, Proto>(tid, nthreads, work);
#endif
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
runRing<T, RedOp, Proto>(tid, nthreads, work);
rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work);
}
};
+1
View File
@@ -332,6 +332,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
devWork.redOpArg = task->opDev.scalarArg;
devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr;
devWork.oneNode = (comm->nNodes == 1);
devWork.rcclUseOneSlice = comm->rcclUseOneSlice;
devWork.isOneRPN = comm->isOneRPN;
devWork.netRegUsed = devWork.regUsed = 0;
if (task->regBufType & NCCL_NET_REG_BUFFER)
+10
View File
@@ -15,11 +15,17 @@
#define NCCL_MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two.
// CHUNKSIZE must be a multiple of SLICESIZE
// RCCL: Benchmarking on single node for MI300X showed improved throughput for single node always using
// a single slice, so we have separate configurations for single node and multi-node. Single node configs
// are suffixed with _SINGLE_NODE.
#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
#define ALLREDUCE_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
#define ALLGATHER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
#define REDUCESCATTER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
#define BROADCAST_SLICESTEPS 1
#define BROADCAST_CHUNKSTEPS 1
@@ -30,6 +36,10 @@
#define ALLTOALL_PIVOT_SLICESTEPS 2
#define ALLTOALL_PIVOT_CHUNKSTEPS 4
static_assert(ALLREDUCE_CHUNKSTEPS == ALLREDUCE_SLICESTEPS_SINGLE_NODE, "ALLREDUCE_CHUNKSTEPS must be equal to ALLREDUCE_SLICESTEPS_SINGLE_NODE");
static_assert(ALLGATHER_CHUNKSTEPS == ALLGATHER_SLICESTEPS_SINGLE_NODE, "ALLGATHER_CHUNKSTEPS must be equal to ALLGATHER_SLICESTEPS_SINGLE_NODE");
static_assert(REDUCESCATTER_CHUNKSTEPS == REDUCESCATTER_SLICESTEPS_SINGLE_NODE, "REDUCESCATTER_CHUNKSTEPS must be equal to REDUCESCATTER_SLICESTEPS_SINGLE_NODE");
const char* ncclFuncToString(ncclFunc_t op);
const char* ncclDevRedOpToString(ncclDevRedOp_t op);
const char* ncclDatatypeToString(ncclDataType_t type);
+1
View File
@@ -481,6 +481,7 @@ struct ncclComm {
int node;
int nNodes;
int rcclUseOneSlice; // RCCL: true if this comm is using one slice per primitive
int localRank;
int localRanks;
int maxLocalRanks;
+1 -1
View File
@@ -288,7 +288,7 @@ struct alignas(16) ncclDevWorkColl {
// nChannels == (channelHi - channelLo) + 1
uint32_t channelLo:8, channelHi:8;
uint32_t nWarps:8;
uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1;
uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1;
uint32_t root:30, connIndex:2;
uint16_t pivotA2ANumBiRings;
void* recvbuff;
+2
View File
@@ -1353,6 +1353,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
// Multi-node MI300A
int managed = 0;
CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0));
// RCCL: Only use one slice per primitive on some single node gfx9xx systems
comm->rcclUseOneSlice = !managed && nNodes == 1;
if (managed && nNodes > 1) {
// This forces the minimum channels to 24
allGather3Data[rank].nc = 6;