Support fused all reduce and elementwise operations (#1729)

* Support fused all reduce and elementwise operations

Add additional "acc" parameter to RCCL Replayer logs

Add flag which indicates availability of new API

* Fix Recorder json parsing

* Remove unreachable code

* Remove extra acc pointer check

* .

* Revert "[DEVICE] Adding ability to choose unroll factor at runtime (#1734)"

This reverts commit 4cadf3597c.

* Use noinline to reduce kernels linking time

* Don't use noinline for gfx942 and gfx950 to avoid perf regression

---------

Co-authored-by: AtlantaPepsi <timhu102@amd.com>
Co-authored-by: BertanDogancay <bertan.dogancay@gmail.com>

[ROCm/rccl commit: 9a4213356d]
Этот коммит содержится в:
Wenkai Du
2025-07-23 09:04:17 -07:00
коммит произвёл GitHub
родитель cbb648505a
Коммит caff9764d3
24 изменённых файлов: 656 добавлений и 231 удалений
+2 -3
Просмотреть файл
@@ -14,10 +14,9 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
### Added
* Added new GPU target `gfx950`.
* Added support for `unroll=1` in device-code generation to improve performance.
* Set a default of 112 channels for a single node with `8 * gfx950`.
* Added support for `unroll=1` in device-code generation to improve performance.
* Set a default of 112 channels for a single node with `8 * gfx950`.
* Enabled LL128 protocol on `gfx950`.
* Adding ability to choose unroll factor at runtime via `RCCL_UNROLL_FACTOR`. This can be set at runtime to 1, 2, or 4. This change currently increases compilation and linking time because it triples the number of kernels generated.
* Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`.
* Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. The channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS`, or `NCCL_MAX_NCHANNELS` are set.
* Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
+1 -5
Просмотреть файл
@@ -778,13 +778,9 @@ endif()
set(GEN_DIR "${HIPIFY_DIR}/gensrc")
if(ONLY_FUNCS)
message(WARNING "Using ONLY_FUNCS = ${ONLY_FUNCS}. Not meant for release builds.")
endif()
# Execute the python script to generate required files
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${BUILD_LOCAL_GPU_TARGET_ONLY} ${ONLY_FUNCS}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
RESULT_VARIABLE gen_py_result
ERROR_VARIABLE gen_py_error
+28 -8
Просмотреть файл
@@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS };
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -117,7 +117,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
// RCCL update slice steps for AllReduce if single node
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS };
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -135,6 +135,26 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
return ncclEnqueueCheck(&info);
}
/**
 * Enqueue an AllReduce that carries an extra `acc` buffer alongside the usual
 * send/recv buffers.
 *
 * The call is recorded as rrAllReduceWithBias and enqueued as a regular
 * ncclFuncAllReduce whose trailing ncclInfo initializer is `acc`.
 * NOTE(review): presumably the device side folds `acc` element-wise into the
 * reduced output -- confirm against the kernel code.
 *
 * @param sendbuff  Source buffer.
 * @param recvbuff  Destination buffer.
 * @param count     Number of elements of `datatype`.
 * @param datatype  Element type.
 * @param op        Reduction operator.
 * @param comm      Communicator; must not be nullptr.
 * @param stream    Stream the operation is enqueued on.
 * @param acc       Extra accumulation buffer; must not be nullptr.
 * @return ncclInvalidArgument when `comm` or `acc` is nullptr, otherwise the
 *         result of the recorder / ncclEnqueueCheck.
 */
ncclResult_t ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count,
    ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc) {
  NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
    NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));

  // The slice-step selection below reads comm->rcclUseOneSlice, so a null
  // communicator has to be rejected here rather than left to later checks.
  if (comm == nullptr) {
    WARN("ncclAllReduceWithBias : comm cannot be nullptr");
    return ncclInvalidArgument;
  }

  if (acc == nullptr) {
    WARN("ncclAllReduceWithBias : acc cannot be nullptr");
    return ncclInvalidArgument;
  }

  // RCCL update slice steps for AllReduce if single node
  struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
    sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
    ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, acc };

  NCCLCHECK(Recorder::instance().record(rrAllReduceWithBias, info));

  return ncclEnqueueCheck(&info);
}
RCCL_PARAM(AllToAllPivotEnable, "ALL_TO_ALL_PIVOT_ENABLE", 0);
NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
@@ -164,7 +184,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) {
struct ncclInfo info = { ncclFuncAllToAllPivot, "AllToAllPivot",
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS };
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
return ncclEnqueueCheck(&info);
} else {
int nRanks;
@@ -240,7 +260,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou
struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -311,7 +331,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
struct ncclInfo info = { ncclFuncReduce, "Reduce",
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -338,7 +358,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS };
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -403,7 +423,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da
struct ncclInfo info = { ncclFuncSend, "Send",
NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
1, 1, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
@@ -429,7 +449,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype
struct ncclInfo info = { ncclFuncRecv, "Recv",
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1 };
1, 1, nullptr };
if (!mscclIsCaller()) // when msccl falls back to
{
+6 -1
Просмотреть файл
@@ -589,7 +589,12 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(tid, nthreads, work);
using Proto = ProtoSimple<1, 1>;
if (work->acc != nullptr) {
runTreeSplit<T, RedOp, Proto>(tid, nthreads, work);
} else {
runTreeUpDown<T, RedOp, Proto>(tid, nthreads, work);
}
// Check-here
// #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800
// runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(tid, nthreads, work);
+25
Просмотреть файл
@@ -13,6 +13,31 @@ __shared__ ncclShmemData ncclShmem;
__shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
#endif
// Work runner whose run() does nothing. The generic kernels in this file pass
// it to ncclKernelMain as the runner type.
struct RunWorkNop {
__device__ void run() {}
};
// Generic kernel entry points, one per unroll factor (1, 2 and 4). Each one
// forwards its kernel arguments to ncclKernelMain with COLLTRACE disabled and
// RunWorkNop as the runner type.
// NOTE(review): the leading -1 is presumably the specialized function id, in
// which case ncclKernelMain never takes the specialized path for these kernels
// and always dispatches on the runtime funcId -- confirm against
// ncclKernelMain's template parameter order.
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args);
}
// Debug variants of the generic kernels, compiled only when ENABLE_COLLTRACE is
// defined. They differ from the non-debug kernels only in passing
// COLLTRACE=true to ncclKernelMain; one kernel exists per unroll factor
// (1, 2 and 4).
#ifdef ENABLE_COLLTRACE
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args);
}
#endif
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void ncclDevFunc_Nop();
#else
+27 -2
Просмотреть файл
@@ -11,8 +11,8 @@
#include "collectives.h"
#include "device.h"
#include "op128.h"
#include "device_table.h"
#include "reduce_kernel.h"
#include "device_table.h"
#include "network/unpack/unpack_defs.h"
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
@@ -121,8 +121,10 @@ struct ncclShmemGroup {
ncclConnInfo *sendConns[NCCL_MAX_ARITY];
void* userInput;
void* userOutput;
void* userAcc;
void* srcs[NCCL_MAX_ARITY+1];
void* dsts[NCCL_MAX_ARITY+1];
void* acc;
uint64_t barrier;
union {
unpackGroupShmem unpack;
@@ -606,7 +608,21 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) {
SpecializedRunWorkBatch().run();
} else {
NCCL_CALL_FUNCTIONS<COLL_UNROLL>(ncclShmem.funcId);
#ifdef USE_INDIRECT_FUNCTION_CALL
if (COLL_UNROLL == 1)
ncclDevFuncTable_1[ncclShmem.funcId]();
else if (COLL_UNROLL == 2)
ncclDevFuncTable_2[ncclShmem.funcId]();
else
ncclDevFuncTable_4[ncclShmem.funcId]();
#else
if (COLL_UNROLL == 1)
NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId);
else if (COLL_UNROLL == 2)
NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId);
else
NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId);
#endif
}
if (ncclShmem.nextBatchIx == -1) break;
@@ -643,6 +659,15 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
#endif
}
// Forward declarations of the generic kernels, one per unroll factor
// (1, 2 and 4). The ncclDevKernelDebug_* variants are declared only when
// ENABLE_COLLTRACE is defined.
__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
#ifdef ENABLE_COLLTRACE
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
#endif
// Defines a kernel named ncclDevKernel_<suffix> with an empty body. Only
// `suffix` is used by the expansion; the remaining macro parameters (coll,
// redop, ty, algo, proto, specializedFnId) are accepted but ignored.
#define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
+245 -4
Просмотреть файл
@@ -31,7 +31,11 @@ template<typename RedFn, typename T, int Unroll, int BytePerPack,
int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
#if defined(__gfx942__) || defined(__gfx950__)
__device__ __forceinline__ void reduceCopyPacks(
#else
__device__ __attribute__((noinline)) void reduceCopyPacks(
#endif
int nThreads, int &thread,
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
@@ -207,15 +211,209 @@ __device__ __forceinline__ void reduceCopyPacks(
thread = warp*WARP_SIZE + lane;
}
// Variant of the pack-wise reduce/copy loop that also folds a "bias" buffer
// into the result.
//
// For every pack of BytePerPack bytes this function:
//   1. loads the pack from source 0 and the matching pack from accPtrFn(),
//   2. reduces in the remaining sources (static ones below MinSrcs, then
//      dynamic ones up to min(MaxSrcs, nSrcs)), applying pre-ops where
//      s < PreOpSrcs,
//   3. applies the post-op when `postOp` is set,
//   4. stores the result to every destination; the store to destination 0
//      (when it is a non-multimem destination below MinDsts) is first reduced
//      with the bias pack.
//
// On return nBytesBehind / nBytesAhead are advanced past the bytes consumed,
// and `thread` is rewritten so that the warp which got the least work becomes
// warp 0.
//
// NOTE(review): multimem destinations and destinations handled by the dynamic
// loop (d >= MinDsts) are stored without the bias. With MinDsts == 0,
// destination 0 would therefore not receive it -- confirm callers always
// instantiate this with MinDsts >= 1.
template<typename RedFn, typename T, int Unroll, int BytePerPack,
         int MultimemSrcs, int MinSrcs, int MaxSrcs,
         int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
         typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
#if defined(__gfx942__) || defined(__gfx950__)
__device__ __forceinline__ void reduceCopyPacksWithBias(
#else
__device__ __attribute__((noinline)) void reduceCopyPacksWithBias(
#endif
    int nThreads, int &thread,
    uint64_t redArg, uint64_t *preOpArgs, bool postOp,
    int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
    IntBytes &nBytesBehind, IntBytes &nBytesAhead, AccPtrFn const &accPtrFn
  ) {
  static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type.");
  //if (BytePerPack == 0) __trap();

  // A hunk is the amount of contiguous data a warp consumes per loop iteration
  // assuming all threads partake.
  constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack;
  int nWarps = nThreads/WARP_SIZE;
  int warp = thread/WARP_SIZE;
  int lane = thread%WARP_SIZE;

  // This thread's initial position.
  IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack);
  IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack);
  // Number of hunks to be consumed over all warps.
  IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk);
  // Advance collective position.
  nBytesBehind += nHunksAhead*BytePerHunk;
  nBytesAhead -= nHunksAhead*BytePerHunk;
  if (Unroll==1 && BytePerPack <= nBytesAhead) {
    // Only Unroll=1 can do partial hunks (where not all threads partake).
    nHunksAhead += 1;
    nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack));
    nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack);
  }
  nHunksAhead -= warp;

  RedFn redFn(redArg);
  uintptr_t minSrcs[MinSrcs + !MinSrcs];
  uintptr_t minDsts[MinDsts + !MinDsts];
  // The bias buffer is walked in lock-step with the sources/destinations.
  uintptr_t accPtr = cvta_to_global(accPtrFn()) + threadBytesBehind;
  BytePack<BytePerPack> bias[Unroll];
  #pragma unroll
  for (int s=0; s < MinSrcs; s++) {
    minSrcs[s] = cvta_to_global(srcPtrFn(s)) + threadBytesBehind;
  }

  #pragma unroll
  for (int d=0; d < MinDsts; d++) {
    // Yes, for some template arguments this code will be unreachable.  That's fine.
    // coverity[dead_error_line]
    minDsts[d] = cvta_to_global(dstPtrFn(d)) + threadBytesBehind;
  }

  // We dictate loop termination condition according to whether partial hunks
  // can be handled or not.
  while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) {
    BytePack<BytePerPack> acc[Unroll];

    // minSrcs[0] cannot be nullptr so we always process it
    { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (0 < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          acc[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[0]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]);
          // coverity[dead_error_condition]
          if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]);
        }
        // Load the bias pack and advance accPtr on every path: the store
        // below consumes bias[u] regardless of how source 0 was loaded, and
        // accPtr must stay aligned with minSrcs[0].
        bias[u] = ld_volatile_global<BytePerPack>(accPtr);
        accPtr += WARP_SIZE*BytePerPack;
        minSrcs[0] += WARP_SIZE*BytePerPack;
      }
    }

    #pragma unroll Unroll
    for (int s=1; s < MinSrcs; s++) {
      // Yes, for some template arguments this code will be unreachable.  That's fine.
      // coverity[dead_error_begin]
      BytePack<BytePerPack> tmp[Unroll];
      // coverity[dead_error_line]
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (s < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          // coverity[dead_error_line]
          tmp[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[s]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]);
        }
        minSrcs[s] += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        // coverity[dead_error_line]
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }

    for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) {
      uintptr_t src = cvta_to_global(srcPtrFn(s)) + threadBytesBehind;
      BytePack<BytePerPack> tmp[Unroll];
      // Yes, for some template arguments this code will be unreachable.  That's fine.
      // coverity[dead_error_line]
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        // Use volatile loads in case credits are polled for with volatile (instead of acquire).
        tmp[u] = ld_volatile_global<BytePerPack>(src);
        src += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        // Yes, for some template arguments this code will be unreachable.  That's fine.
        // coverity[dead_error_line]
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }

    if (postOp) {
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++)
        acc[u] = applyPostOp(redFn, acc[u]);
    }

    #pragma unroll Unroll
    for (int d=0; d < MinDsts; d++) {
      #pragma unroll Unroll
      // Yes, for some template arguments this code will be unreachable.  That's fine.
      // coverity[dead_error_begin]
      for (int u=0; u < Unroll; u++) {
        // coverity[dead_error_condition]
        if (d < MultimemDsts) {
          multimem_st_global(minDsts[d], acc[u]);
        } else {
          // Only destination 0 receives the bias; the others get the plain
          // reduction result.
          if (d == 0)
            st_global<BytePerPack>(minDsts[d], applyReduce(redFn, acc[u], bias[u]));
          else
            st_global<BytePerPack>(minDsts[d], acc[u]);
        }
        minDsts[d] += WARP_SIZE*BytePerPack;
      }
    }
    for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) {
      uintptr_t dstPtr = cvta_to_global(dstPtrFn(d));
      uintptr_t dst = dstPtr + threadBytesBehind;
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        st_global<BytePerPack>(dst, acc[u]);
        dst += WARP_SIZE*BytePerPack;
      }
    }

    nWarps = nThreads/WARP_SIZE;
    #pragma unroll
    for (int s=0; s < MinSrcs; s++) {
      minSrcs[s] += (nWarps-1)*BytePerHunk;
    }
    #pragma unroll
    // Yes, for some template arguments this code will be unreachable.  That's fine.
    // coverity[dead_error_line]
    for (int d=0; d < MinDsts; d++) {
      minDsts[d] += (nWarps-1)*BytePerHunk;
    }
    // Skip the hunks taken by the other warps, same as for minSrcs/minDsts.
    accPtr += (nWarps-1)*BytePerHunk;
    threadBytesBehind += nWarps*BytePerHunk;
    threadBytesAhead -= nWarps*BytePerHunk;
    nHunksAhead -= nWarps;
  }

  nWarps = nThreads/WARP_SIZE;
  warp = thread/WARP_SIZE;
  lane = thread%WARP_SIZE;
  // The last loop iteration could have been partial, i.e. not taken by all
  // threads. The threads that weren't included need an extra subtraction to
  // make the value warp uniform.
  if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps;
  // Rotate warps so the warp which got the least work here will be warp 0.
  // This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps;
  warp = -nHunksAhead;
  thread = warp*WARP_SIZE + lane;
}
template<int Unroll, typename RedFn, typename T,
int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
__device__ __forceinline__ void reduceCopy(
int thread, int nThreads,
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
IntBytes nElts
IntBytes nElts, AccPtrFn const &accPtrFn
) {
static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts, "Multimem pointers cannot exceed respective Min values.");
//int nWarps = nThreads/WARP_SIZE;
@@ -230,6 +428,7 @@ __device__ __forceinline__ void reduceCopy(
IntBytes nBytesBehind = 0;
IntBytes nBytesAhead = nElts*sizeof(T);
bool useAcc = accPtrFn() != nullptr;
#if __cpp_if_constexpr
if constexpr (BigPackSize > sizeof(T)) {
@@ -243,11 +442,23 @@ __device__ __forceinline__ void reduceCopy(
aligned = !(__any(!aligned));
if (aligned) {
#if defined(__gfx90a__)
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
#else
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
@@ -255,6 +466,12 @@ __device__ __forceinline__ void reduceCopy(
#endif
if (nBytesAhead == 0) return;
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
@@ -285,17 +502,35 @@ __device__ __forceinline__ void reduceCopy(
*/
#if defined(__gfx90a__)
if (MinSrcs > 1) {
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
} else {
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
}
#else
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
@@ -303,6 +538,12 @@ __device__ __forceinline__ void reduceCopy(
#endif
if (nBytesAhead == 0) return;
if (useAcc)
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
else
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
@@ -317,14 +558,14 @@ __device__ __forceinline__ void reduceCopy(
int thread, int nThreads,
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs,
IntBytes nElts
IntBytes nElts, void *accPtr = nullptr
) {
reduceCopy<Unroll, RedFn, T,
MultimemSrcs, MinSrcs, MaxSrcs,
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes>
(thread, nThreads, redArg, preOpArgs, postOp,
nSrcs, [=]__device__(int i) { return srcPtrs[i]; },
nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts);
nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts, [=]__device__() { return accPtr; });
}
#endif // COMMON_KERNEL_H_
+170 -111
Просмотреть файл
@@ -32,7 +32,7 @@ else:
# developing device code. The regex supports non-space containing globs '*',
# and union 'a|b'. The string representing the function has the form:
#
# <coll> <algo> <proto> <redop> <type> <unroll>
# <coll> <algo> <proto> <redop> <type>
#
# The possible values for redop, type, algo, proto can be found in the all_<foo>
# lists at the top of this file.
@@ -45,30 +45,23 @@ else:
# # Only AllReduce and Reduce
# make ONLY_FUNCS="AllReduce|Reduce"
#
# # Only AllGather with unroll=4
# make ONLY_FUNCS="AllGather * * * * 4"
#
# # Only non-reductions:
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
#
# # Only AllReduce Sum int32_t (but all algos, protos)
# make ONLY_FUNCS="AllReduce * * Sum i32"
#
# # Only AllReduce RING Max float (but all protos and unrolls)
# make ONLY_FUNCS="AllReduce RING * Max f32"
# # Only AllReduce RING Max float (but all protos)
# make ONLY_FUNCS="AllReduce RING * Max float"
#
# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1"
# # AllReduce TREE LL128 Prod rccl_bfloat16
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *"
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *"
#
#
# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4"
#
# # ONLY_FUNCS can be used together for debugging
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
# Paste all non-None arguments together with `sep`.
def paste(sep, *args):
@@ -77,8 +70,9 @@ def paste(sep, *args):
is_ifc = 1 if sys.argv[2] == "ON" else 0
is_colltrace = 1 if sys.argv[3] == "ON" else 0
is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0
is_local_arch_only = 1 if sys.argv[5] == "ON" else 0
func_pattern = sys.argv[5:]
func_pattern = sys.argv[6:7]
if func_pattern and func_pattern[0]:
func_pattern = func_pattern[0]
else:
@@ -139,13 +133,49 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
################################################################################
seen_unroll = []
def calc_unroll_for_local_arch():
  """Pick the single unroll factor to build when targeting only the local GPU.

  Runs `rocminfo` and collects the distinct (gfx name, compute-unit count)
  pairs it reports.

  Returns:
    1, 2 or 4 when local-arch-only mode is on and exactly one distinct
    (gfx name, CU count) pair is found; None otherwise, i.e. when
    local-arch-only mode is off or the system is not homogeneous.
  """
  if not is_local_arch_only:
    return None

  # ROCM_PATH may be unset; fall back to the default ROCm install prefix
  # instead of failing with a TypeError on None + str.
  rocm_path = os.environ.get('ROCM_PATH', '/opt/rocm')
  rocminfo_path = rocm_path + "/bin/rocminfo"

  res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
  rocminfo_output = res.stdout

  # Parse rocminfo text output. The same gfx name can appear with different CU
  # counts, so the (gfx_name, cu_count) pair is the uniqueness key.
  gfx_targets = set()
  curr_name = None
  for line in rocminfo_output.splitlines():
    line = line.strip()

    if line.startswith("Name:"):
      name = line.split(':')[-1].strip()
      if "gfx" in name:
        curr_name = name

    # Pair the most recent gfx name with the next CU count seen.
    if line.startswith("Compute Unit:") and curr_name:
      cu_count = int(line.split(':')[-1].strip())
      gfx_targets.add((curr_name, cu_count))
      curr_name = None

  # Homogeneous system is required to build for only 1 variant of unroll factor
  if len(gfx_targets) != 1:
    return None

  gfx_name, cu_count = next(iter(gfx_targets))
  if "gfx950" == gfx_name:
    return 1
  elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
    return 2
  else:
    return 4
# Helper function to check if the conditions for the collective is being met
def func_validate(coll, algo, proto, redop, ty, unroll):
def func_validate(coll, algo, proto, redop, ty):
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
return False
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
return False
return True
@@ -161,7 +191,10 @@ def func_filter(function_params, current_idx, item_list=None):
# If the paramter is equal to '*', include all possible cases for it
if current_element == "*":
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type> <unroll>
if current_idx == 0:
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
# Get the current list from all_params
current_list = all_params[current_idx]
@@ -177,12 +210,12 @@ def func_filter(function_params, current_idx, item_list=None):
# Check if the current element is recognized
elements = current_element.split("/")
current_param = all_params[current_idx]
# Iterate over the elements in the elements list
for item in elements:
if item not in current_param:
raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.")
for item in elements:
item_list.append(item)
yield from func_filter(function_params, current_idx+1, item_list)
@@ -192,9 +225,7 @@ def func_filter(function_params, current_idx, item_list=None):
else:
coll, algo, proto, redop, ty, unroll = item_list
if func_validate(coll, algo, proto, redop, ty, unroll):
if not unroll in seen_unroll:
seen_unroll.append(unroll)
if func_validate(coll, algo, proto, redop, ty):
yield(coll, algo, proto, redop, ty, unroll)
# Parse ONLY_FUNCS input and feed it to func_filter
@@ -216,6 +247,10 @@ def parse_input(func_pattern):
# Maps functions to the chosen representative for the equivalence class it
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
# If building for the local arch only, we only need to build one variant of coll_unroll.
# Map the other variants of coll_unroll to this one.
if coll_unroll:
unroll = str(coll_unroll)
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
# map signed integer sum/prod to unsigned
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
@@ -233,13 +268,13 @@ def enumerate_func_rows():
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty, unroll):
if func_validate(coll, algo, proto, redop, ty):
yield (coll, algo, proto, redop, ty, unroll)
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty> <unroll>
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
def custom_sort_key(fn):
coll, algo, proto, redop, ty, unroll = fn
return (
all_unroll.index(unroll),
all_colls.index(coll),
@@ -251,6 +286,8 @@ def custom_sort_key(fn):
################################################################################
coll_unroll = calc_unroll_for_local_arch()
# Corresponds to ncclDevFuncRowToId[]
func_rows = [fn for fn in enumerate_func_rows()]
@@ -267,8 +304,6 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, "device_table.h"))
out = f.write
out("#include \"common.h\"\n\n")
if is_ifc: func_declaration = "__device__ void"
else: func_declaration = "__device__ __attribute__((noinline)) void"
@@ -285,86 +320,113 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("\n")
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
# Generate function tables per unroll factor
tableIdx = 0
for curr_unroll in seen_unroll:
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % curr_unroll)
tableIdx = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if curr_unroll != unroll: continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (tableIdx, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (tableIdx, sym_ll))
else:
out("/*%4d*/ %s,\n" % (tableIdx, sym))
tableIdx += 1
out("nullptr};\n")
out("\n")
# Construct indirection function workaround
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
index1 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if unroll != "1": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (index1, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll))
else:
out("/*%4d*/ %s,\n" % (index1, sym))
index1 += 1
out("nullptr};\n")
out("\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
index2 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if unroll != "2": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (index2, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll))
else:
out("/*%4d*/ %s,\n" % (index2, sym))
index2 += 1
out("nullptr};\n")
out("\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
index4 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
if unroll != "4": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
else:
out("/*%4d*/ %s,\n" % (index4, sym))
index4 += 1
out("nullptr};\n")
out("\n")
if not is_ifc:
out("template<int unroll, unsigned short f, unsigned short l>\n"
"struct Caller {\n"
out("template<unsigned short f, unsigned short l>\n"
"struct Caller1 {\n"
" static __forceinline__ __device__ __host__\n"
" void call(unsigned short funcIndex) noexcept\n"
" void call1(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller<unroll, f, m>::call(funcIndex) : Caller<unroll, m, l>::call(funcIndex);\n"
" return (funcIndex < m) ? Caller1<f, m>::call1(funcIndex) : Caller1<m, l>::call1(funcIndex);\n"
" }\n"
"};\n"
"\n")
for curr_unroll in seen_unroll:
out("template<unsigned short f>\n")
out("struct Caller<%s, f, f + 1>{\n" % curr_unroll)
out(" static __forceinline__ __device__ __host__\n");
out(" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable_%s[f](); }\n" % curr_unroll)
out("};\n")
out("\n")
# Create NCCL_CALL_FUNCTION helper function that will call the appropriate device function
out("template <int unroll>\n"
"__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
if is_ifc:
for curr_unroll in seen_unroll:
out(" if (unroll == %s) { ncclDevFuncTable_%s[funcIndex](); }\n" % (curr_unroll, curr_unroll))
else:
out(f" Caller<unroll, 0, {tableIdx}>::call(funcIndex);\n")
out("}\n\n")
# Create RCCL
out("template<int SpecializedFnId, typename SpecializedRunWorkBatch, bool COLLTRACE, int COLL_UNROLL>\n");
out("__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args);\n\n");
out("struct RunWorkNop {\n");
out(" __device__ void run() {}\n");
out("};\n\n");
out("template <int UNROLL, bool COLLTRACE>\n"
"__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void rcclGenericKernel(ncclDevKernelArgs4K const args4K) {\n"
" ncclKernelMain<-1, RunWorkNop, COLLTRACE, UNROLL>(&args4K.args);\n"
"}\n\n")
out("struct rcclKernelItem {\n");
out(" void* funcPtr;\n");
out(" int unroll;\n");
out("};\n\n");
out("/* This table contains all the __global__ functions that were compiled */\n");
out("static struct rcclKernelItem rcclKernelTable[] = {\n")
for unroll in seen_unroll:
out(" {(void*)&(rcclGenericKernel<%s, false>), %s},\n" % (unroll, unroll))
out("#ifdef ENABLE_COLLTRACE\n")
for unroll in seen_unroll:
out(" {(void*)&(rcclGenericKernel<%s, true>), %s},\n" % (unroll, unroll))
out("#endif\n");
out("};\n\n");
"\n"
"template<unsigned short f>\n"
"struct Caller1<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n")
out(f" Caller1<0, {index1}>::call1(funcIndex);\n")
out("}\n\n")
out("template<unsigned short f, unsigned short l>\n"
"struct Caller2 {\n"
" static __forceinline__ __device__ __host__\n"
" void call2(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller2<f, m>::call2(funcIndex) : Caller2<m, l>::call2(funcIndex);\n"
" }\n"
"};\n"
"\n"
"template<unsigned short f>\n"
"struct Caller2<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n")
out(f" Caller2<0, {index2}>::call2(funcIndex);\n")
out("}\n\n")
out("template<unsigned short f, unsigned short l>\n"
"struct Caller4 {\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
" }\n"
"};\n"
"\n"
"template<unsigned short f>\n"
"struct Caller4<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
out("}\n\n")
# Generate <gensrc>/device_table.cpp
if is_colltrace:
@@ -374,7 +436,7 @@ if is_colltrace:
out = f.write
out('#include "nccl_common.h"\n#include "device.h"\n')
out("\n")
seen_fns = set()
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
for fn in primary_funcs:
@@ -397,13 +459,10 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
# The mapping from function rows to valid primary function ids.
out("extern int const ncclDevFuncRowToId[] = {\n")
index = 0
offset = len(func_rows)//len(all_unroll)
start = all_unroll.index(seen_unroll[0]) * offset
end = start + offset
for fn in func_rows[start:end]:
for fn in func_rows[:len(func_rows)//3]:
fn_id, comment = -1, ""
if fn is not None:
fn_id = primary_to_index[equivalent_primary(*fn)] % offset if primary_to_index[equivalent_primary(*fn)] != -1 else -1
fn_id = primary_to_index[equivalent_primary(*fn)]
comment = " // " + paste(" ", *fn[:-1])
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
index += 1
+17 -7
Просмотреть файл
@@ -18,7 +18,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
// This is because of a recv buffer which is allocated to MaxRecv length in send-only cases
static constexpr int MaxRecv = Fan::MaxRecv > 1 ? Fan::MaxRecv : 1;
static constexpr int MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1;
static constexpr int Input=0, Output=1, Acc=2;
RedOp redOp;
const int tid;
const int nthreads;
@@ -26,7 +26,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
const int group;
const int stepLines;
Fan fan;
T *userBufs[2];
T *userBufs[3];
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
@@ -436,6 +436,7 @@ private:
constexpr int DST = DstBuf != -1 ? 1 : 0;
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
// Always waitSend in case of cleanup
nelem = nelem < 0 ? 0 : nelem;
@@ -460,14 +461,15 @@ private:
nelem -= tid*EltPerLine;
srcElts += tid*EltPerLine;
dstElts += tid*EltPerLine;
if (accElts != nullptr) accElts += tid*EltPerLine;
int offset = tid;
int eltPerTrip = nthreads*EltPerLine;
while (nelem > 0) {
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
DataLoader dl;
DataLoader dl, accdl;
ncclLLFifoLine line[MaxRecv];
uint64_t data, peerData;
uint64_t data, peerData, accData;
if (SRC) {
dl.loadBegin(srcElts, eltInLine);
srcElts += eltPerTrip;
@@ -502,7 +504,14 @@ private:
storeLL(sendPtr(0)+offset, data, sendFlag(0));
}
if (DST) {
storeData(dstElts, data, eltInLine);
if (accElts != nullptr) {
accdl.loadBegin(accElts, eltInLine);
accElts += eltPerTrip;
accData = accdl.loadFinish();
storeData(dstElts, applyReduce(redOp, accData, data), eltInLine);
} else {
storeData(dstElts, data, eltInLine);
}
dstElts += eltPerTrip;
}
nelem -= eltPerTrip;
@@ -672,7 +681,7 @@ public:
loadRecvSync();
// coverity[var_deref_model:FALSE]
loadSendSync();
setDataPtrs(inputBuf, outputBuf);
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
}
__device__ ~Primitives() {
@@ -685,9 +694,10 @@ public:
barrier();
}
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) {
userBufs[Input] = (T*)inputBuf;
userBufs[Output] = (T*)outputBuf;
userBufs[Acc] = (T*)acc;
}
__device__ void moveDataPtrs(intptr_t delta) {
+20 -5
Просмотреть файл
@@ -25,7 +25,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>> {
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1;
// Slots of userBufs[]: user input, user output, and the optional accumulator (bias) buffer.
static constexpr int Input=0, Output=1, Acc=2;
RedOp redOp;
const int tid;
const int nthreads;
@@ -36,7 +36,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
const bool flagThread;
const int group;
Fan fan;
T *userBufs[2];
T *userBufs[3];
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
@@ -347,6 +347,7 @@ private:
constexpr int DST = DstBuf != -1 ? 1 : 0;
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
int wireOffset = WireWordPerSlice*warp + 2*wid;
const int nwarps = nthreads/WARP_SIZE;
nelem = nelem < 0 ? 0 : nelem;
@@ -356,12 +357,25 @@ private:
nelem -= DataEltPerSlice*warp;
srcPtr += DataEltPerSlice*warp;
dstPtr += DataEltPerSlice*warp;
if (accPtr != nullptr) accPtr += DataEltPerSlice*warp;
while (nelem > 0) {
const int eltInSlice = min(nelem, DataEltPerSlice);
uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice);
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SrcBuf, DstBuf>(regs, wireOffset, postOp);
if (DST) storeRegs(dstPtr, regs, eltInSlice);
if (DST) {
if (accPtr != nullptr) {
uint64_t accRegs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
loadRegsBegin(accRegs, accPtr, eltInSlice);
loadRegsFinish(accRegs);
accPtr += DataEltPerSlice*nwarps;
#pragma unroll
for (int u=0; u<NCCL_LL128_SHMEM_ELEMS_PER_THREAD; u++) {
regs[u] = applyReduce(redOp, accRegs[u], regs[u]);
}
}
storeRegs(dstPtr, regs, eltInSlice);
}
wireOffset += WireWordPerSlice*nwarps;
srcPtr += DataEltPerSlice*nwarps;
@@ -529,7 +543,7 @@ public:
loadRecvSync();
// coverity[var_deref_model:FALSE]
loadSendSync();
setDataPtrs(inputBuf, outputBuf);
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
}
__device__ ~Primitives() {
@@ -542,9 +556,10 @@ public:
barrier();
}
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) {
userBufs[Input] = (T*)inputBuf;
userBufs[Output] = (T*)outputBuf;
userBufs[Acc] = (T*)acc;
}
__device__ void moveDataPtrs(intptr_t delta) {
+7 -4
Просмотреть файл
@@ -265,6 +265,8 @@ private:
T* userOutput = (T*)ncclShmem.groups[group].userOutput;
if (Src) ncclShmem.groups[group].srcs[0] = (SrcBuf==Input ? userInput : userOutput) + srcIx + offset;
if (Dst) ncclShmem.groups[group].dsts[0] = (DstBuf==Input ? userInput : userOutput) + dstIx + offset;
T* userAcc = (T*)ncclShmem.groups[group].userAcc;
ncclShmem.groups[group].acc = (Dst && userAcc != nullptr) ? userAcc + dstIx + offset : nullptr;
}
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(srcIx, dstIx, offset, sliceSize);
subBarrier();
@@ -385,7 +387,7 @@ private:
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs,
Send * fan.nsend() + Dst, ncclShmem.groups[group].dsts,
workSize);
workSize, ncclShmem.groups[group].acc);
}
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
@@ -827,7 +829,7 @@ public:
// coverity[negative_returns:FALSE] => coverity thinks that index could be -1 but that's not actually the case
// coverity[var_deref_model] => coverity thinks work can dereferenced if NULL but this is not the case
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer);
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer, collWork != nullptr ? collWork->acc : nullptr);
// coverity[uninit_member] => coverity thinks fan.n is not initialized
} else if (mode == primsModePatRs || mode == primsModePatAg) { // Connect to all ranks +/- 2^n
flags |= PatMode;
@@ -902,10 +904,11 @@ public:
}
}
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer) {
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer, void const *acc) {
if (tid==0) {
ncclShmem.groups[group].userInput = (void*)inputBuf;
ncclShmem.groups[group].userOutput = (void*)outputBuf;
ncclShmem.groups[group].userAcc = (void*)acc;
ncclShmem.redOpArgs[0] = redOpArg; // scaler for local input
}
@@ -1023,7 +1026,7 @@ public:
}
// Set MSCCL data pointers
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf = nullptr) {
if (tid==0) {
ncclShmem.groups[group].userInput = (T*)inputBuf;
ncclShmem.groups[group].userOutput = (T*)outputBuf;
+29 -36
Просмотреть файл
@@ -28,31 +28,30 @@
using namespace rccl;
/* [RCCL] Determine which GPU kernel to execute */
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task)
{
// At this time, unroll factor is controlled only by passed in unroll argument
// After more investigation, this may be further tuned by the actual task being processed
struct ncclKernelMatch {
void* kernelFn;
bool specialized;
};
#ifdef ENABLE_COLLTRACE
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]) / 2;
int firstKernel = useCollTrace ? numKernels : 0;
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0))
static ncclKernelMatch const ncclKerns[6] = {
{(void *)ncclDevKernel_Generic_1, true},
{(void *)ncclDevKernel_Generic_2, true},
{(void *)ncclDevKernel_Generic_4, true},
{(void *)ncclDevKernelDebug_Generic_1, true},
{(void *)ncclDevKernelDebug_Generic_2, true},
{(void *)ncclDevKernelDebug_Generic_4, true}
};
#else
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]);
int firstKernel = 0;
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
static ncclKernelMatch const ncclKerns[3] = {
{(void*)ncclDevKernel_Generic_1, true},
{(void*)ncclDevKernel_Generic_2, true},
{(void*)ncclDevKernel_Generic_4, true}
};
#endif
// Check if the requested unroll exists
for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
if (rcclKernelTable[firstKernel + kernelIdx].unroll == unroll) {
return rcclKernelTable[firstKernel + kernelIdx].funcPtr;
}
}
// If does not match, return null
return nullptr;
}
static int rcclProtoGrainSize(int proto, ncclComm *comm){
switch (proto) {
case NCCL_PROTO_LL: return 16;
@@ -82,7 +81,7 @@ NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0);
// Returns maximum kernel stack size of all CUDA kernels
ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize) {
constexpr int KernelCount = sizeof(rcclKernelTable)/sizeof(rcclKernelTable[0]);
constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
ncclResult_t result = ncclSuccess;
if (maxStackSize) *maxStackSize = 0;
@@ -95,7 +94,7 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* ma
int ncclMaxSharedMem = rcclShmemDynamicSize(cudaArch, WarpSize);
for (int k=0; k < KernelCount; k++) {
void* fn = rcclKernelTable[k].funcPtr;
void* fn = ncclKerns[k].kernelFn;
cudaFuncAttributes attr = {0};
if (fn == nullptr) continue;
@@ -356,6 +355,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
devWork.sendbuff = (void*)task->sendbuff;
devWork.recvbuff = (void*)task->recvbuff;
devWork.acc = (void*)task->acc;
devWork.sendbuffOffset = task->sendbuffOffset;
devWork.recvbuffOffset = task->recvbuffOffset;
devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;
@@ -806,12 +806,8 @@ static ncclResult_t scheduleCollTasksToPlan(
//plan->channelMask.masks[channelId/64] |= (2ull<<devWork->channelHi) - (1ull<<devWork->channelLo);
plan->threadPerBlock = std::max(plan->threadPerBlock, 192 /* 3*WARP_SIZE */);
if (!plan->kernelSpecialized) {
#ifdef ENABLE_COLLTRACE
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
#else
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
#endif
plan->kernelSpecialized = true;
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
}
if (comm->rank == 0) {
@@ -1127,12 +1123,8 @@ static ncclResult_t scheduleP2pTasksToPlan(
plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MAX_NTHREADS);
if (!plan->kernelSpecialized) {
#ifdef ENABLE_COLLTRACE
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
#else
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
#endif
plan->kernelSpecialized = true;
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
}
// Compute how much to split operations
@@ -2505,6 +2497,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {
t->sliceSteps = info->sliceSteps;
t->eActivationMask = __atomic_load_n(&ncclProfilerEventMask, __ATOMIC_RELAXED);
t->opCount = comm->opCount;
t->acc = info->acc;
planner->nTasksColl += 1;
ncclTaskCollSorterInsert(&planner->collSorter, t, t->trafficBytes);
@@ -2554,8 +2547,8 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
}
NCCLCHECKGOTO(ArgsCheck(info), ret, fail);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count,
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p acc %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->acc, info->count,
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream,
info->comm->planner.nTasksP2p + info->comm->planner.nTasksColl,
info->comm->localRankToRank[info->comm->localRank]);
+6 -1
Просмотреть файл
@@ -31,7 +31,7 @@
#define RCCL_API_TRACE_VERSION_MAJOR 0
// should be increased every time new members are added to existing dispatch tables
#define RCCL_API_TRACE_VERSION_PATCH 0
#define RCCL_API_TRACE_VERSION_PATCH 1
#if !defined(RCCL_EXTERN_C_INIT)
# ifdef __cplusplus
@@ -61,6 +61,10 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff, void* recvbuff,
size_t count, ncclDataType_t datatype,
ncclRedOp_t op, struct ncclComm* comm,
hipStream_t stream);
// Function-pointer type for the ncclAllReduceWithBias entry of rcclApiFuncTable.
// Takes the same leading arguments as the ncclAllReduce entry, followed by a
// trailing read-only `acc` buffer pointer.
typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff, void* recvbuff,
size_t count, ncclDataType_t datatype,
ncclRedOp_t op, struct ncclComm* comm,
hipStream_t stream, const void* acc);
typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff, void* recvbuff,
size_t count, ncclDataType_t datatype,
ncclComm_t comm, hipStream_t stream);
@@ -194,6 +198,7 @@ typedef struct rcclApiFuncTable
mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn;
ncclCommRegister_fn_t ncclCommRegister_fn;
ncclCommDeregister_fn_t ncclCommDeregister_fn;
ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn;
} rcclApiFuncTable;
+1
Просмотреть файл
@@ -195,6 +195,7 @@ struct ncclTaskColl {
ncclFunc_t func;
void const* sendbuff;
void* recvbuff;
void const* acc;
size_t count;
int root;
ncclDataType_t datatype;
+1
Просмотреть файл
@@ -310,6 +310,7 @@ struct alignas(16) ncclDevWorkColl {
uint16_t pivotA2ANumBiRings:15, profilerEnabled:1;
void* recvbuff;
void* sendbuff;
void *acc;
uintptr_t sendbuffOffset;
uintptr_t recvbuffOffset;
uintptr_t* sendbuffRmtAddrs;
+1
Просмотреть файл
@@ -29,6 +29,7 @@ struct ncclInfo {
// Algorithm details
int chunkSteps;
int sliceSteps;
const void* acc;
};
#endif
+5
Просмотреть файл
@@ -74,5 +74,10 @@ typedef enum {
#define NCCL_ALGO_PROTO_IGNORE -1.0
#define NCCL_NUM_UNROLLS 3 // 1/2/4
#define NCCL_UNROLL_1 0
#define NCCL_UNROLL_2 1
#define NCCL_UNROLL_4 2
#define NCCL_NUM_FLOATS 6 // half/float/double/rccl_bfloat16/rccl_float8/rccl_bfloat8
#endif
+3
Просмотреть файл
@@ -16,6 +16,7 @@ typedef enum {
rrAllGather,
rrReduceScatter,
rrAllReduce,
rrAllReduceWithBias,
rrSend,
rrRecv,
rrAllToAll,
@@ -51,6 +52,7 @@ constexpr const char* rcclCallStr[]
"AllGather",
"ReduceScatter",
"AllReduce",
"AllReduceWithBias",
"Send",
"Recv",
"AllToAll",
@@ -94,6 +96,7 @@ struct rcclApiCall {
uint64_t opCount = 0;
const void* sendbuff = NULL;
void* recvbuff = NULL;
const void* acc = NULL;
void* sendPtrBase = NULL;
void* recvPtrBase = NULL;
size_t sendPtrExtent = 0;
+2 -3
Просмотреть файл
@@ -610,6 +610,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
// RCCL: create persistent stream for calloc
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
@@ -1958,9 +1960,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail);
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
#ifdef ENABLE_MSCCLPP
if (job->parent) {
if (job->parent->mscclppCompatible) {
+20 -2
Просмотреть файл
@@ -153,6 +153,11 @@ ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** han
ncclResult_t
ncclCommDeregister_impl(const ncclComm_t comm, void* handle);
ncclResult_t
ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm,
cudaStream_t stream, const void* acc);
namespace rccl
{
namespace
@@ -211,10 +216,11 @@ RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclRunAlgo_fn, 33);
RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclUnloadAlgo_fn, 34);
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommRegister_fn, 35);
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommDeregister_fn, 36);
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclAllReduceWithBias_fn, 37);
#undef RCCL_ASSERT_OFFSET
static_assert(sizeof(rcclApiFuncTable) == compute_table_size(37),
static_assert(sizeof(rcclApiFuncTable) == compute_table_size(38),
"Update table major/step version and add a new offset assertion if this "
"fails to compile");
@@ -261,7 +267,8 @@ RcclGetFunctionTable_impl()
&mscclRunAlgo_impl,
&mscclUnloadAlgo_impl,
&ncclCommRegister_impl,
&ncclCommDeregister_impl };
&ncclCommDeregister_impl,
&ncclAllReduceWithBias_impl };
#if defined(RCCL_ROCPROFILER_REGISTER) && RCCL_ROCPROFILER_REGISTER > 0
std::array<void*, 1> table_array{ tbl };
@@ -301,6 +308,9 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff,
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream);
NCCL_API(ncclResult_t, ncclAllReduceWithBias, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream, const void* acc);
NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
@@ -411,6 +421,14 @@ ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t
datatype, op, comm, stream);
}
// Public entry point for the fused all-reduce + bias collective.
// Forwards every argument unchanged to the ncclAllReduceWithBias_fn slot of the
// table returned by RcclGetFunctionTable() and returns that call's result.
ncclResult_t
ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                      ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc)
{
    const auto* const apiTable = ::rccl::RcclGetFunctionTable();
    return apiTable->ncclAllReduceWithBias_fn(
        sendbuff, recvbuff, count, datatype, op, comm, stream, acc);
}
ncclResult_t
ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, hipStream_t stream)
+5 -4
Просмотреть файл
@@ -35,6 +35,7 @@ rcclApiCall::rcclApiCall(rcclCall_t type, const ncclInfo& info)://name(rcclCallS
opCount(info.comm->opCount),
sendbuff(info.sendbuff),
recvbuff(info.recvbuff),
acc(info.acc),
count(info.count),
datatype(info.datatype),
op(info.op),
@@ -69,7 +70,7 @@ std::string alloc_fmt = "%s : [returned ptr : %p, size : %zu, context : [";
std::string free_fmt = "%s : [ptr : %p, context : [";
std::string redop_fmt = "%s : [scalar : %p, datatype : %d, op : %d, residence : %d, comm : %p, context : [";
std::string redopdestroy_fmt = "%s : [op : %d, comm : %p, context : [";
std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : [";
std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], acc : %p, count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : [";
Recorder::Recorder()
{
@@ -256,7 +257,7 @@ void Recorder::write(const rcclApiCall &call)
default: // collectives
len = snprintf(buffer, 4096, coll_fmt.c_str(),
rcclCallStr[call.type], call.opCount, call.sendbuff, call.sendPtrBase, call.sendPtrExtent,
call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.count, call.datatype,
call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.acc, call.count, call.datatype,
call.op, call.root, call.comm, call.nRanks, call.stream, call.nTasks, call.globalRank);
}
@@ -686,9 +687,9 @@ void parseJsonEntry(const char* entry, std::vector<rcclApiCall>& calls)
default:
assert(sscanf(str.c_str() + end + 3, (coll_fmt.substr(5) + ctxt_fmt).c_str(),
&call.opCount, &call.sendbuff, &call.sendPtrBase, &call.sendPtrExtent, &call.recvbuff, &call.recvPtrBase, &call.recvPtrExtent,
&call.count, &call.datatype, &call.op, &call.root,
&call.acc, &call.count, &call.datatype, &call.op, &call.root,
&call.comm, &call.nRanks, &call.stream, &call.nTasks, &call.globalRank, &call.timestamp, &call.tid,
&call.hipDev, &call.graphCaptured, &call.graphID) == 21);
&call.hipDev, &call.graphCaptured, &call.graphID) == 22);
}
calls.push_back(call);
}
+22
Просмотреть файл
@@ -23,6 +23,7 @@
#define RCCL_FLOAT8 1
#define RCCL_GATHER_SCATTER 1
#define RCCL_ALLTOALLV 1
#define RCCL_ALLREDUCE_WITH_BIAS 1
#ifdef __cplusplus
extern "C" {
@@ -563,6 +564,27 @@ ncclResult_t pncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream);
/*! @endcond */
/*! @brief All-Reduce-with-Bias
@details Reduces data arrays of length *count* in *sendbuff* using *op* operation, folds the
bias array *acc* element-wise into the reduced result, and leaves identical copies of the
final result on each *recvbuff*.
In-place operation will happen if sendbuff == recvbuff.
@return Result code. See @ref rccl_result_code for more details.
@param[in] sendbuff Input data array to reduce
@param[out] recvbuff Data array to store reduced result array
@param[in] count Number of elements in data buffer
@param[in] datatype Data buffer element datatype
@param[in] op Reduction operator
@param[in] comm Communicator group object to execute on
@param[in] stream HIP stream to execute collective on
@param[in] acc Bias data array that is combined into the reduced result */
ncclResult_t ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc);
/*! @cond include_hidden */
ncclResult_t pncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc);
/*! @endcond */
/*! @brief Reduce-Scatter
@details Reduces data in *sendbuff* using *op* operation and leaves reduced result
scattered over the devices so that *recvbuff* on rank i will contain the i-th
+5 -34
Просмотреть файл
@@ -127,43 +127,14 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
return ncclSuccess;
}
//RCCL runtime param to set Unroll Factor
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
/*! @brief Select the kernel unroll setting for a communicator.
    @details Queries the properties of the communicator's device and stores the
             unroll setting in comm->unroll based on the GPU architecture name:
               - gfx950                                   -> NCCL_UNROLL_1
               - gfx908, or gfx942 with more than 80
                 multiprocessors (MI300X)                 -> NCCL_UNROLL_2
               - anything else                            -> NCCL_UNROLL_4
             Exactly one branch assigns comm->unroll; there is no runtime
             override and no assignment after the chain, so the architecture
             choice is what callers observe.
    @param[in,out] comm Communicator whose cudaDev is queried and whose unroll
                        field is written.
    @return ncclSuccess on success; a failing device-property query is handled
            by CUDACHECK. */
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
  hipDeviceProp_t devProp;
  CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));

  if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
    comm->unroll = NCCL_UNROLL_1;
  } else if (IsArchMatch(devProp.gcnArchName, "gfx908") ||
             (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) {
    // The multiprocessor-count check distinguishes MI300X from other gfx942 parts.
    comm->unroll = NCCL_UNROLL_2;
  } else {
    comm->unroll = NCCL_UNROLL_4;
  }
  return ncclSuccess;
}
+8 -1
Просмотреть файл
@@ -487,6 +487,13 @@ void Replayer::replay()
NCCL_CALL(ncclAllReduce(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first));
break;
}
// Replays a recorded all-reduce-with-bias call. The acc pointer from the log is
// not reused; a zero-initialized scratch buffer of
// count * ncclTypeSize(datatype) bytes stands in for the bias input.
case rrAllReduceWithBias:
{
// NOTE(review): this buffer is a std::vector, i.e. host memory, while the other
// buffers passed to the collective are sbuffer/rbuffer - confirm that
// ncclAllReduceWithBias accepts a host pointer for acc.
std::vector<char> acc(call.count * ncclTypeSize(call.datatype));
NCCL_CALL(ncclAllReduceWithBias(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first, acc.data()));
// Waits for the stream to finish before leaving this scope.
// NOTE(review): acc is destroyed at the closing brace, so presumably this sync
// also keeps the bias buffer alive until the collective has consumed it -
// confirm before acting on the TODO below.
HIP_CALL(hipStreamSynchronize(streams[call.stream].first)); // TODO: remove, and further verify behavior of fused AR
break;
}
// a2av
case rrAllToAllv:
{
@@ -670,4 +677,4 @@ int main(int argc, char **argv)
replayer.replay();
MPI_Finalize();
return 0;
}
}