Use One Slice per Basic Primitive for AllReduce, ReduceScatter, AllGather (#1681) for Single Node on Some GFX9 Systems
Using a single slice per primitive, rather than the typical two, provides roughly a 5% speedup (sometimes more, sometimes less) for single-node runs on some GFX9 systems.
This commit is contained in:
+4
-3
@@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
|
||||
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
|
||||
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
|
||||
ALLGATHER_CHUNKSTEPS, ALLGATHER_SLICESTEPS };
|
||||
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -114,9 +114,10 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
|
||||
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));
|
||||
|
||||
// RCCL: use the single-node slice steps for AllReduce when comm->rcclUseOneSlice is set
|
||||
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
|
||||
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
|
||||
ALLREDUCE_CHUNKSTEPS, ALLREDUCE_SLICESTEPS };
|
||||
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -329,7 +330,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
|
||||
|
||||
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
|
||||
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
|
||||
REDUCESCATTER_CHUNKSTEPS, REDUCESCATTER_SLICESTEPS };
|
||||
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
|
||||
+15
-2
@@ -175,6 +175,18 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
|
||||
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
if(work->rcclUseOneSlice){ \
|
||||
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS_SINGLE_NODE, ALLGATHER_SLICESTEPS_SINGLE_NODE>, false>(tid, nthreads, work); \
|
||||
} else{ \
|
||||
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work); \
|
||||
}
|
||||
#else
|
||||
#define rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work);
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
@@ -185,8 +197,9 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
#endif
|
||||
if (isNetOffload)
|
||||
runRing<T, RedOp, ProtoSimple<1, 1>, true>(tid, nthreads, work);
|
||||
else
|
||||
runRing<T, RedOp, ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>, false>(tid, nthreads, work);
|
||||
else{
|
||||
rcclAllGatherRunRingSimpleProtoImpl(tid, nthreads, work);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
+16
-2
@@ -558,11 +558,25 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
|
||||
#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
if(work->rcclUseOneSlice){ \
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work); \
|
||||
} else{ \
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work); \
|
||||
}
|
||||
#else
|
||||
#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work);
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work);
|
||||
rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -131,11 +131,25 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
|
||||
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
if(work->rcclUseOneSlice){ \
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS_SINGLE_NODE, REDUCESCATTER_SLICESTEPS_SINGLE_NODE>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work); \
|
||||
} else{ \
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work); \
|
||||
}
|
||||
#else
|
||||
#define rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>; \
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work);
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(tid, nthreads, work);
|
||||
rcclReduceScatterRunRingSimpleProtoImpl(tid, nthreads, work);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -332,6 +332,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
|
||||
devWork.redOpArg = task->opDev.scalarArg;
|
||||
devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr;
|
||||
devWork.oneNode = (comm->nNodes == 1);
|
||||
devWork.rcclUseOneSlice = comm->rcclUseOneSlice;
|
||||
devWork.isOneRPN = comm->isOneRPN;
|
||||
devWork.netRegUsed = devWork.regUsed = 0;
|
||||
if (task->regBufType & NCCL_NET_REG_BUFFER)
|
||||
|
||||
@@ -15,11 +15,17 @@
|
||||
#define NCCL_MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two.
|
||||
|
||||
// CHUNKSIZE must be a multiple of SLICESIZE
|
||||
// RCCL: Benchmarking on MI300X showed improved single-node throughput when always using
|
||||
// a single slice, so we have separate configurations for single node and multi-node. Single-node configs
|
||||
// are suffixed with _SINGLE_NODE.
|
||||
#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
|
||||
#define ALLREDUCE_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
|
||||
#define ALLREDUCE_CHUNKSTEPS (NCCL_STEPS/2)
|
||||
#define ALLGATHER_SLICESTEPS (NCCL_STEPS/4)
|
||||
#define ALLGATHER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
|
||||
#define ALLGATHER_CHUNKSTEPS (NCCL_STEPS/2)
|
||||
#define REDUCESCATTER_SLICESTEPS (NCCL_STEPS/4)
|
||||
#define REDUCESCATTER_SLICESTEPS_SINGLE_NODE (NCCL_STEPS/2)
|
||||
#define REDUCESCATTER_CHUNKSTEPS (NCCL_STEPS/2)
|
||||
#define BROADCAST_SLICESTEPS 1
|
||||
#define BROADCAST_CHUNKSTEPS 1
|
||||
@@ -30,6 +36,10 @@
|
||||
#define ALLTOALL_PIVOT_SLICESTEPS 2
|
||||
#define ALLTOALL_PIVOT_CHUNKSTEPS 4
|
||||
|
||||
static_assert(ALLREDUCE_CHUNKSTEPS == ALLREDUCE_SLICESTEPS_SINGLE_NODE, "ALLREDUCE_CHUNKSTEPS must be equal to ALLREDUCE_SLICESTEPS_SINGLE_NODE");
|
||||
static_assert(ALLGATHER_CHUNKSTEPS == ALLGATHER_SLICESTEPS_SINGLE_NODE, "ALLGATHER_CHUNKSTEPS must be equal to ALLGATHER_SLICESTEPS_SINGLE_NODE");
|
||||
static_assert(REDUCESCATTER_CHUNKSTEPS == REDUCESCATTER_SLICESTEPS_SINGLE_NODE, "REDUCESCATTER_CHUNKSTEPS must be equal to REDUCESCATTER_SLICESTEPS_SINGLE_NODE");
|
||||
|
||||
const char* ncclFuncToString(ncclFunc_t op);
|
||||
const char* ncclDevRedOpToString(ncclDevRedOp_t op);
|
||||
const char* ncclDatatypeToString(ncclDataType_t type);
|
||||
|
||||
@@ -481,6 +481,7 @@ struct ncclComm {
|
||||
|
||||
int node;
|
||||
int nNodes;
|
||||
int rcclUseOneSlice; // RCCL: true if this comm is using one slice per primitive
|
||||
int localRank;
|
||||
int localRanks;
|
||||
int maxLocalRanks;
|
||||
|
||||
@@ -288,7 +288,7 @@ struct alignas(16) ncclDevWorkColl {
|
||||
// nChannels == (channelHi - channelLo) + 1
|
||||
uint32_t channelLo:8, channelHi:8;
|
||||
uint32_t nWarps:8;
|
||||
uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1;
|
||||
uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1;
|
||||
uint32_t root:30, connIndex:2;
|
||||
uint16_t pivotA2ANumBiRings;
|
||||
void* recvbuff;
|
||||
|
||||
@@ -1353,6 +1353,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
|
||||
// Multi-node MI300A
|
||||
int managed = 0;
|
||||
CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0));
|
||||
// RCCL: Only use one slice per primitive on some single node gfx9xx systems
|
||||
comm->rcclUseOneSlice = !managed && nNodes == 1;
|
||||
if (managed && nNodes > 1) {
|
||||
// This forces the minimum channels to 24
|
||||
allGather3Data[rank].nc = 6;
|
||||
|
||||
Reference in New Issue
Block a user