diff --git a/projects/rccl/CHANGELOG.md b/projects/rccl/CHANGELOG.md index 8553ae679a..8522ef7cba 100644 --- a/projects/rccl/CHANGELOG.md +++ b/projects/rccl/CHANGELOG.md @@ -14,10 +14,9 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: ### Added * Added new GPU target `gfx950`. -* Added support for `unroll=1` in device-code generation to improve performance. -* Set a default of 112 channels for a single node with `8 * gfx950`. +* Added support for `unroll=1` in device-code generation to improve performance, +* Set a default of 112 channels for a single node with `8 * gfx950`, * Enabled LL128 protocol on `gfx950`. -* Adding ability to choose unroll factor at runtime via `RCCL_UNROLL_FACTOR`. This can be set at runtime to 1, 2, or 4. This change currently increases compilation and linking time because it triples the number of kernels generated. * Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`. * Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. Channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS', or 'NCCL_MAX_NCHANNELS` are set. * Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges. diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index 1942a1d7a9..439eb9eae5 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -778,13 +778,9 @@ endif() set(GEN_DIR "${HIPIFY_DIR}/gensrc") -if(ONLY_FUNCS) - message(WARNING "Using ONLY_FUNCS = ${ONLY_FUNCS}. Not meant for release builds.") -endif() - # Execute the python script to generate required files execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${BUILD_LOCAL_GPU_TARGET_ONLY} ${ONLY_FUNCS} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} RESULT_VARIABLE gen_py_result ERROR_VARIABLE gen_py_error diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index 09f5eeca37..3da5eb2f3f 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen struct ncclInfo info = { ncclFuncAllGather, "AllGather", sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */ - ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS }; + ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -117,7 +117,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou // RCCL update slice steps for AllReduce if single node struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ - ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS }; + ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -135,6 +135,26 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou return ncclEnqueueCheck(&info); } +ncclResult_t ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc) { + NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce, + NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype)); + + if (acc == nullptr) { + WARN("ncclAllReduceWithBias : acc cannot be nullptr"); + return ncclInvalidArgument; + } + + // RCCL update slice steps for AllReduce if single node + struct ncclInfo info = { ncclFuncAllReduce, "AllReduce", + sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */ + ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, acc }; + + NCCLCHECK(Recorder::instance().record(rrAllReduceWithBias, info)); + + return ncclEnqueueCheck(&info); +} + RCCL_PARAM(AllToAllPivotEnable, "ALL_TO_ALL_PIVOT_ENABLE", 0); NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, @@ -164,7 +184,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) { struct ncclInfo info = { ncclFuncAllToAllPivot, "AllToAllPivot", sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */ - ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS }; + ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr }; return ncclEnqueueCheck(&info); } else { int nRanks; @@ -240,7 +260,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou struct ncclInfo info = { ncclFuncBroadcast, "Broadcast", sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */ - BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS }; + BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -311,7 +331,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count, struct ncclInfo info = { ncclFuncReduce, "Reduce", sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */ - REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS }; + REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -338,7 +358,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter", sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */ - REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS }; + REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -403,7 +423,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da struct ncclInfo info = { ncclFuncSend, "Send", NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ - 1, 1 }; + 1, 1, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { @@ -429,7 +449,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype struct ncclInfo info = { ncclFuncRecv, "Recv", NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */ - 1, 1 }; + 1, 1, nullptr }; if (!mscclIsCaller()) // when msccl falls back to { diff --git a/projects/rccl/src/device/all_reduce.h b/projects/rccl/src/device/all_reduce.h index 6183577708..ee4dc29185 100644 --- a/projects/rccl/src/device/all_reduce.h +++ b/projects/rccl/src/device/all_reduce.h @@ -589,7 +589,12 @@ struct RunWorkColl struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - runTreeUpDown>(tid, nthreads, work); + using Proto = ProtoSimple<1, 1>; + if (work->acc != nullptr) { + runTreeSplit(tid, nthreads, work); + } else { + runTreeUpDown(tid, nthreads, work); + } // Check-here // #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800 // runTreeUpDown>(tid, nthreads, work); diff --git a/projects/rccl/src/device/common.cu b/projects/rccl/src/device/common.cu index 40aeb5cd76..36d396fbb8 100644 --- a/projects/rccl/src/device/common.cu +++ b/projects/rccl/src/device/common.cu @@ -13,6 +13,31 @@ __shared__ ncclShmemData ncclShmem; __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; #endif +struct RunWorkNop { + __device__ void run() {} +}; + +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args); +} +#ifdef ENABLE_COLLTRACE +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args); +} +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args); +} +#endif + #ifdef USE_INDIRECT_FUNCTION_CALL __device__ void ncclDevFunc_Nop(); #else diff --git a/projects/rccl/src/device/common.h b/projects/rccl/src/device/common.h index 84825f2b7d..c6c61021ad 100644 --- a/projects/rccl/src/device/common.h +++ b/projects/rccl/src/device/common.h @@ -11,8 +11,8 @@ #include "collectives.h" #include "device.h" #include "op128.h" -#include "device_table.h" #include "reduce_kernel.h" +#include "device_table.h" #include "network/unpack/unpack_defs.h" #define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree @@ -121,8 +121,10 @@ struct ncclShmemGroup { ncclConnInfo *sendConns[NCCL_MAX_ARITY]; void* userInput; void* userOutput; + void* userAcc; void* srcs[NCCL_MAX_ARITY+1]; void* dsts[NCCL_MAX_ARITY+1]; + void* acc; uint64_t barrier; union { unpackGroupShmem unpack; @@ -606,7 +608,21 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) { SpecializedRunWorkBatch().run(); } else { - NCCL_CALL_FUNCTIONS(ncclShmem.funcId); +#ifdef USE_INDIRECT_FUNCTION_CALL + if (COLL_UNROLL == 1) + ncclDevFuncTable_1[ncclShmem.funcId](); + else if (COLL_UNROLL == 2) + ncclDevFuncTable_2[ncclShmem.funcId](); + else + ncclDevFuncTable_4[ncclShmem.funcId](); +#else + if (COLL_UNROLL == 1) + NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId); + else if (COLL_UNROLL == 2) + NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId); + else + NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId); +#endif } if (ncclShmem.nextBatchIx == -1) break; @@ -643,6 +659,15 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a #endif } +__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +#ifdef ENABLE_COLLTRACE +__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +#endif + #define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \ __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {} diff --git a/projects/rccl/src/device/common_kernel.h b/projects/rccl/src/device/common_kernel.h index a601618aef..9223c8304c 100644 --- a/projects/rccl/src/device/common_kernel.h +++ b/projects/rccl/src/device/common_kernel.h @@ -31,7 +31,11 @@ template +#if defined(__gfx942__) || defined(__gfx950__) __device__ __forceinline__ void reduceCopyPacks( +#else +__device__ __attribute__((noinline)) void reduceCopyPacks( +#endif int nThreads, int &thread, uint64_t redArg, uint64_t *preOpArgs, bool postOp, int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn, @@ -207,15 +211,209 @@ __device__ __forceinline__ void reduceCopyPacks( thread = warp*WARP_SIZE + lane; } +template +#if defined(__gfx942__) || defined(__gfx950__) +__device__ __forceinline__ void reduceCopyPacksWithBias( +#else +__device__ __attribute__((noinline)) void reduceCopyPacksWithBias( +#endif + int nThreads, int &thread, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn, + IntBytes &nBytesBehind, IntBytes &nBytesAhead, AccPtrFn const &accPtrFn + ) { + static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); + //if (BytePerPack == 0) __trap(); + + // A hunk is the amount of contiguous data a warp consumes per loop iteration + // assuming all threads partake. + constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack; + int nWarps = nThreads/WARP_SIZE; + int warp = thread/WARP_SIZE; + int lane = thread%WARP_SIZE; + + // This thread's initial position. + IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack); + IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack); + // Number of hunks to be consumed over all warps. + IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk); + // Advance collective position. + nBytesBehind += nHunksAhead*BytePerHunk; + nBytesAhead -= nHunksAhead*BytePerHunk; + if (Unroll==1 && BytePerPack <= nBytesAhead) { + // Only Unroll=1 can do partial hunks (where not all threads partake). + nHunksAhead += 1; + nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack)); + nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack); + } + nHunksAhead -= warp; + + RedFn redFn(redArg); + uintptr_t minSrcs[MinSrcs + !MinSrcs]; + uintptr_t minDsts[MinDsts + !MinDsts]; + uintptr_t accPtr = cvta_to_global(accPtrFn()) + threadBytesBehind; + BytePack bias[Unroll]; + + #pragma unroll + for (int s=0; s < MinSrcs; s++) { + minSrcs[s] = cvta_to_global(srcPtrFn(s)) + threadBytesBehind; + } + + #pragma unroll + for (int d=0; d < MinDsts; d++) { + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_line] + minDsts[d] = cvta_to_global(dstPtrFn(d)) + threadBytesBehind; + } + + // We dictate loop termination condition according to whether partial hunks + // can be handled or not. + while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) { + BytePack acc[Unroll]; + + // minSrcs[0] cannot be nullptr so we always process it + { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (0 < MultimemSrcs) { + // applyLoadMultimem uses relaxed semantics for same reason we use volatile below. + acc[u] = applyLoadMultimem(redFn, minSrcs[0]); + } else { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + acc[u] = ld_volatile_global(minSrcs[0]); + // coverity[dead_error_condition] + bias[u] = ld_volatile_global(accPtr); + accPtr += WARP_SIZE*BytePerPack; + if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]); + } + minSrcs[0] += WARP_SIZE*BytePerPack; + } + } + + #pragma unroll Unroll + for (int s=1; s < MinSrcs; s++) { + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_begin] + BytePack tmp[Unroll]; + // coverity[dead_error_line] + RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (s < MultimemSrcs) { + // applyLoadMultimem uses relaxed semantics for same reason we use volatile below. + // coverity[dead_error_line] + tmp[u] = applyLoadMultimem(redFn, minSrcs[s]); + } else { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(minSrcs[s]); + } + minSrcs[s] += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // coverity[dead_error_line] + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); + } + } + + for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) { + uintptr_t src = cvta_to_global(srcPtrFn(s)) + threadBytesBehind; + BytePack tmp[Unroll]; + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_line] + RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(src); + src += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_line] + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); + } + } + + if (postOp) { + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) + acc[u] = applyPostOp(redFn, acc[u]); + } + + #pragma unroll Unroll + for (int d=0; d < MinDsts; d++) { + #pragma unroll Unroll + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_begin] + for (int u=0; u < Unroll; u++) { + // coverity[dead_error_condition] + if (d < MultimemDsts) { + multimem_st_global(minDsts[d], acc[u]); + } else { + if (d == 0) + st_global(minDsts[d], applyReduce(redFn, acc[u], bias[u])); + else + st_global(minDsts[d], acc[u]); + } + minDsts[d] += WARP_SIZE*BytePerPack; + } + } + for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) { + uintptr_t dstPtr = cvta_to_global(dstPtrFn(d)); + uintptr_t dst = dstPtr + threadBytesBehind; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + st_global(dst, acc[u]); + dst += WARP_SIZE*BytePerPack; + } + } + + nWarps = nThreads/WARP_SIZE; + #pragma unroll + for (int s=0; s < MinSrcs; s++) { + minSrcs[s] += (nWarps-1)*BytePerHunk; + } + #pragma unroll + // Yes, for some template arguments this code will be unreachable. That's fine. + // coverity[dead_error_line] + for (int d=0; d < MinDsts; d++) { + minDsts[d] += (nWarps-1)*BytePerHunk; + } + accPtr += (nWarps-1)*BytePerHunk; + threadBytesBehind += nWarps*BytePerHunk; + threadBytesAhead -= nWarps*BytePerHunk; + nHunksAhead -= nWarps; + } + + nWarps = nThreads/WARP_SIZE; + warp = thread/WARP_SIZE; + lane = thread%WARP_SIZE; + // The last loop iteration could have been partial, i.e. not taken by all + // threads. The threads that weren't included need an extra subtraction to + // make the value warp uniform. + if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps; + // Rotate warps so the warp which got the least work here will be warp 0. + // This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps; + warp = -nHunksAhead; + thread = warp*WARP_SIZE + lane; +} + template + typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn> __device__ __forceinline__ void reduceCopy( int thread, int nThreads, uint64_t redArg, uint64_t *preOpArgs, bool postOp, int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn, - IntBytes nElts + IntBytes nElts, AccPtrFn const &accPtrFn ) { static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts, "Multimem pointers cannot exceed respective Min values."); //int nWarps = nThreads/WARP_SIZE; @@ -230,6 +428,7 @@ __device__ __forceinline__ void reduceCopy( IntBytes nBytesBehind = 0; IntBytes nBytesAhead = nElts*sizeof(T); + bool useAcc = accPtrFn() != nullptr; #if __cpp_if_constexpr if constexpr (BigPackSize > sizeof(T)) { @@ -243,11 +442,23 @@ __device__ __forceinline__ void reduceCopy( aligned = !(__any(!aligned)); if (aligned) { #if defined(__gfx90a__) + if (useAcc) + reduceCopyPacksWithBias 1) ? 2 : Unroll), BigPackSize, + MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs> + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn); + else reduceCopyPacks 1) ? 2 : Unroll), BigPackSize, MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs> (nThreads, thread, redArg, preOpArgs, postOp, nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead); #else + if (useAcc) + reduceCopyPacksWithBias + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, @@ -255,6 +466,12 @@ __device__ __forceinline__ void reduceCopy( #endif if (nBytesAhead == 0) return; + if (useAcc) + reduceCopyPacksWithBias + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, @@ -285,17 +502,35 @@ __device__ __forceinline__ void reduceCopy( */ #if defined(__gfx90a__) if (MinSrcs > 1) { + if (useAcc) + reduceCopyPacksWithBias + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, thread, redArg, preOpArgs, postOp, nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead); } else { + if (useAcc) + reduceCopyPacksWithBias + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); } #else + if (useAcc) + reduceCopyPacksWithBias + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, @@ -303,6 +538,12 @@ __device__ __forceinline__ void reduceCopy( #endif if (nBytesAhead == 0) return; + if (useAcc) + reduceCopyPacksWithBias + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn); + else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, @@ -317,14 +558,14 @@ __device__ __forceinline__ void reduceCopy( int thread, int nThreads, uint64_t redArg, uint64_t *preOpArgs, bool postOp, int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, - IntBytes nElts + IntBytes nElts, void *accPtr = nullptr ) { reduceCopy (thread, nThreads, redArg, preOpArgs, postOp, nSrcs, [=]__device__(int i) { return srcPtrs[i]; }, - nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts); + nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts, [=]__device__() { return accPtr; }); } #endif // COMMON_KERNEL_H_ diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index 5644d154e4..c9d6d9b266 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -32,7 +32,7 @@ else: # developing device code. The regex supports non-space containing globs '*', # and union 'a|b'. The string representing the function has the form: # -# +# # # The possible values for redop, type, algo, proto can be found in the all_ # lists at the top of this file. @@ -45,30 +45,23 @@ else: # # Only AllReduce and Reduce # make ONLY_FUNCS="AllReduce|Reduce" # -# # Only AllGather with unroll=4 -# make ONLY_FUNCS="AllGather * * * * 4" -# # # Only non-reductions: # make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" # # # Only AllReduce Sum int32_t (but all algos, protos) # make ONLY_FUNCS="AllReduce * * Sum i32" # -# # Only AllReduce RING Max float (but all protos and unrolls) -# make ONLY_FUNCS="AllReduce RING * Max f32" +# # Only AllReduce RING Max float (but all protos) +# make ONLY_FUNCS="AllReduce RING * Max float" # -# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1 -# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1" +# # AllReduce TREE LL128 Prod rccl_bfloat16 +# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" # -# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter) -# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *" +# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" # --- or --- -# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *" -# -# -# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4" -# -# # ONLY_FUNCS can be used together for debugging +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" +# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t" # Paste all non-None arguments together with `sep`. def paste(sep, *args): @@ -77,8 +70,9 @@ def paste(sep, *args): is_ifc = 1 if sys.argv[2] == "ON" else 0 is_colltrace = 1 if sys.argv[3] == "ON" else 0 is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0 +is_local_arch_only = 1 if sys.argv[5] == "ON" else 0 -func_pattern = sys.argv[5:] +func_pattern = sys.argv[6:7] if func_pattern and func_pattern[0]: func_pattern = func_pattern[0] else: @@ -139,13 +133,49 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower} ################################################################################ -seen_unroll = [] +def calc_unroll_for_local_arch(): + if not is_local_arch_only: + return + + rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo" + + res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True) + rocminfo_output = res.stdout + + # Parse rocminfo binary output + gfx_targets = {} + curr_name = None + for line in rocminfo_output.splitlines(): + line = line.strip() + + if line.startswith("Name:"): + name = line.split(':')[-1].strip() + if "gfx" in name: + curr_name = name + if line.startswith("Compute Unit:") and curr_name: + cu_count = int(line.split(':')[-1].strip()) + gfx_targets[(curr_name, cu_count)] = None + curr_name = None + + # We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts + # Use (gfx_name, cu_count) as key for dictionary and convert it to list here + gfx_targets = list(gfx_targets.keys()) + + # Homogeneous system is required to build for only 1 varient of unroll factor + if len(gfx_targets) == 1: + gfx_name, cu_count = gfx_targets[0] + if "gfx950" == gfx_name: + return 1 + elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80): + return 2 + else: + return 4 # Helper function to check if the conditions for the collective is being met -def func_validate(coll, algo, proto, redop, ty, unroll): +def func_validate(coll, algo, proto, redop, ty): if redop == "SumPostDiv" and ty[0] not in ("i","u"): return False - if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll: + if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]: return False return True @@ -161,7 +191,10 @@ def func_filter(function_params, current_idx, item_list=None): # If the paramter is equal to '*', include all possible cases for it if current_element == "*": - # all_params list must be in the same order as function_params --> + if current_idx == 0: + raise ValueError("Error: Paramter 'COLL' can not be type all '*'.") + + # all_params list must be in the same order as function_params --> # Get the current list from all_params current_list = all_params[current_idx] @@ -177,12 +210,12 @@ def func_filter(function_params, current_idx, item_list=None): # Check if the current element is recognized elements = current_element.split("/") current_param = all_params[current_idx] - + # Iterate over the elements in the elements list for item in elements: if item not in current_param: raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.") - + for item in elements: item_list.append(item) yield from func_filter(function_params, current_idx+1, item_list) @@ -192,9 +225,7 @@ def func_filter(function_params, current_idx, item_list=None): else: coll, algo, proto, redop, ty, unroll = item_list - if func_validate(coll, algo, proto, redop, ty, unroll): - if not unroll in seen_unroll: - seen_unroll.append(unroll) + if func_validate(coll, algo, proto, redop, ty): yield(coll, algo, proto, redop, ty, unroll) # Parse ONLY_FUNCS input and feed it to func_filter @@ -216,6 +247,10 @@ def parse_input(func_pattern): # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). def equivalent_primary(coll, algo, proto, redop, ty, unroll): + # if local arch only, we only need to build for 1 varient of coll_unroll. + # map the other varient of coll_unroll to this one. + if coll_unroll: + unroll = str(coll_unroll) if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": @@ -233,13 +268,13 @@ def enumerate_func_rows(): for proto in all_protos: for redop in all_redops: for ty in all_tys: - if func_validate(coll, algo, proto, redop, ty, unroll): + if func_validate(coll, algo, proto, redop, ty): yield (coll, algo, proto, redop, ty, unroll) -# Sort the hashmap based on custom key +# Sort the hashmap based on custom key def custom_sort_key(fn): coll, algo, proto, redop, ty, unroll = fn - + return ( all_unroll.index(unroll), all_colls.index(coll), @@ -251,6 +286,8 @@ def custom_sort_key(fn): ################################################################################ +coll_unroll = calc_unroll_for_local_arch() + # Corresponds to ncclDevFuncRowToId[] func_rows = [fn for fn in enumerate_func_rows()] @@ -267,8 +304,6 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: print("-- Generating %s" % os.path.join(gensrc, "device_table.h")) out = f.write - out("#include \"common.h\"\n\n") - if is_ifc: func_declaration = "__device__ void" else: func_declaration = "__device__ __attribute__((noinline)) void" @@ -285,86 +320,113 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("\n") out("typedef void(*ncclDevFuncPtr_t)();\n\n") - - # Generate function tables per unroll factor - tableIdx = 0 - for curr_unroll in seen_unroll: - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % curr_unroll) - tableIdx = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn - if curr_unroll != unroll: continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (tableIdx, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (tableIdx, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (tableIdx, sym)) - tableIdx += 1 - out("nullptr};\n") - out("\n") - - # Construct indirection function workaround + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n") + index1 = 0 + for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if unroll != "1": continue + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (index1, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (index1, sym)) + index1 += 1 + out("nullptr};\n") + out("\n") + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n") + index2 = 0 + for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if unroll != "2": continue + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (index2, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (index2, sym)) + index2 += 1 + out("nullptr};\n") + out("\n") + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") + index4 = 0 + for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if unroll != "4": continue + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (index4, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (index4, sym)) + index4 += 1 + out("nullptr};\n") + out("\n") + if not is_ifc: - out("template\n" - "struct Caller {\n" + out("template\n" + "struct Caller1 {\n" " static __forceinline__ __device__ __host__\n" - " void call(unsigned short funcIndex) noexcept\n" + " void call1(unsigned short funcIndex) noexcept\n" " {\n" " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex);\n" + " return (funcIndex < m) ? Caller1::call1(funcIndex) : Caller1::call1(funcIndex);\n" " }\n" "};\n" - "\n") - - for curr_unroll in seen_unroll: - out("template\n") - out("struct Caller<%s, f, f + 1>{\n" % curr_unroll) - out(" static __forceinline__ __device__ __host__\n"); - out(" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable_%s[f](); }\n" % curr_unroll) - out("};\n") - - out("\n") - # Create NCCL_CALL_FUNCTION helper function that will call the appropriate device function - out("template \n" - "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n") - if is_ifc: - for curr_unroll in seen_unroll: - out(" if (unroll == %s) { ncclDevFuncTable_%s[funcIndex](); }\n" % (curr_unroll, curr_unroll)) - else: - out(f" Caller::call(funcIndex);\n") - out("}\n\n") - - # Create RCCL - out("template\n"); - out("__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args);\n\n"); - - out("struct RunWorkNop {\n"); - out(" __device__ void run() {}\n"); - out("};\n\n"); - - out("template \n" - "__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void rcclGenericKernel(ncclDevKernelArgs4K const args4K) {\n" - " ncclKernelMain<-1, RunWorkNop, COLLTRACE, UNROLL>(&args4K.args);\n" - "}\n\n") - - out("struct rcclKernelItem {\n"); - out(" void* funcPtr;\n"); - out(" int unroll;\n"); - out("};\n\n"); - - out("/* This table contains all the __global__ functions that were compiled */\n"); - out("static struct rcclKernelItem rcclKernelTable[] = {\n") - for unroll in seen_unroll: - out(" {(void*)&(rcclGenericKernel<%s, false>), %s},\n" % (unroll, unroll)) - out("#ifdef ENABLE_COLLTRACE\n") - for unroll in seen_unroll: - out(" {(void*)&(rcclGenericKernel<%s, true>), %s},\n" % (unroll, unroll)) - out("#endif\n"); - out("};\n\n"); + "\n" + "template\n" + "struct Caller1{\n" + " static __forceinline__ __device__ __host__\n" + " void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n") + out(f" Caller1<0, {index1}>::call1(funcIndex);\n") + out("}\n\n") + out("template\n" + "struct Caller2 {\n" + " static __forceinline__ __device__ __host__\n" + " void call2(unsigned short funcIndex) noexcept\n" + " {\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + " return (funcIndex < m) ? Caller2::call2(funcIndex) : Caller2::call2(funcIndex);\n" + " }\n" + "};\n" + "\n" + "template\n" + "struct Caller2{\n" + " static __forceinline__ __device__ __host__\n" + " void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n") + out(f" Caller2<0, {index2}>::call2(funcIndex);\n") + out("}\n\n") + out("template\n" + "struct Caller4 {\n" + " static __forceinline__ __device__ __host__\n" + " void call4(unsigned short funcIndex) noexcept\n" + " {\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" + " }\n" + "};\n" + "\n" + "template\n" + "struct Caller4{\n" + " static __forceinline__ __device__ __host__\n" + " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") + out(f" Caller4<0, {index4}>::call4(funcIndex);\n") + out("}\n\n") # Generate /device_table.cpp if is_colltrace: @@ -374,7 +436,7 @@ if is_colltrace: out = f.write out('#include "nccl_common.h"\n#include "device.h"\n') out("\n") - + seen_fns = set() out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") for fn in primary_funcs: @@ -397,13 +459,10 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # The mapping from function rows to valid primary function ids. out("extern int const ncclDevFuncRowToId[] = {\n") index = 0 - offset = len(func_rows)//len(all_unroll) - start = all_unroll.index(seen_unroll[0]) * offset - end = start + offset - for fn in func_rows[start:end]: + for fn in func_rows[:len(func_rows)//3]: fn_id, comment = -1, "" if fn is not None: - fn_id = primary_to_index[equivalent_primary(*fn)] % offset if primary_to_index[equivalent_primary(*fn)] != -1 else -1 + fn_id = primary_to_index[equivalent_primary(*fn)] comment = " // " + paste(" ", *fn[:-1]) out("/*%4d*/ %d,%s\n" % (index, fn_id, comment)) index += 1 diff --git a/projects/rccl/src/device/prims_ll.h b/projects/rccl/src/device/prims_ll.h index f05097429f..c81eba89bf 100644 --- a/projects/rccl/src/device/prims_ll.h +++ b/projects/rccl/src/device/prims_ll.h @@ -18,7 +18,7 @@ class Primitives: // This is because of a recv buffer which is allocated to MaxRecv length in send-only cases static constexpr int MaxRecv = Fan::MaxRecv > 1 ? Fan::MaxRecv : 1; static constexpr int MaxSend = Fan::MaxSend; - static constexpr int Input=0, Output=1; + static constexpr int Input=0, Output=1, Acc=2; RedOp redOp; const int tid; const int nthreads; @@ -26,7 +26,7 @@ class Primitives: const int group; const int stepLines; Fan fan; - T *userBufs[2]; + T *userBufs[3]; struct ncclConnInfo* recvConn = NULL; volatile uint64_t* recvConnHeadPtr = NULL; uint64_t recvConnHead; @@ -436,6 +436,7 @@ private: constexpr int DST = DstBuf != -1 ? 1 : 0; T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; + T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx; // Always waitSend in case of cleanup nelem = nelem < 0 ? 0 : nelem; @@ -460,14 +461,15 @@ private: nelem -= tid*EltPerLine; srcElts += tid*EltPerLine; dstElts += tid*EltPerLine; + if (accElts != nullptr) accElts += tid*EltPerLine; int offset = tid; int eltPerTrip = nthreads*EltPerLine; while (nelem > 0) { int eltInLine = EltPerLine < nelem ? EltPerLine : nelem; - DataLoader dl; + DataLoader dl, accdl; ncclLLFifoLine line[MaxRecv]; - uint64_t data, peerData; + uint64_t data, peerData, accData; if (SRC) { dl.loadBegin(srcElts, eltInLine); srcElts += eltPerTrip; @@ -502,7 +504,14 @@ private: storeLL(sendPtr(0)+offset, data, sendFlag(0)); } if (DST) { - storeData(dstElts, data, eltInLine); + if (accElts != nullptr) { + accdl.loadBegin(accElts, eltInLine); + accElts += eltPerTrip; + accData = accdl.loadFinish(); + storeData(dstElts, applyReduce(redOp, accData, data), eltInLine); + } else { + storeData(dstElts, data, eltInLine); + } dstElts += eltPerTrip; } nelem -= eltPerTrip; @@ -672,7 +681,7 @@ public: loadRecvSync(); // coverity[var_deref_model:FALSE] loadSendSync(); - setDataPtrs(inputBuf, outputBuf); + setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr); } __device__ ~Primitives() { @@ -685,9 +694,10 @@ public: barrier(); } - __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) { userBufs[Input] = (T*)inputBuf; userBufs[Output] = (T*)outputBuf; + userBufs[Acc] = (T*)acc; } __device__ void moveDataPtrs(intptr_t delta) { diff --git a/projects/rccl/src/device/prims_ll128.h b/projects/rccl/src/device/prims_ll128.h index c750342947..42a66ab065 100644 --- a/projects/rccl/src/device/prims_ll128.h +++ b/projects/rccl/src/device/prims_ll128.h @@ -25,7 +25,7 @@ class Primitives: public PrimitivesWithoutDirect> { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; - static constexpr int Input=0, Output=1; + static constexpr int Input=0, Output=1, Acc=2;; RedOp redOp; const int tid; const int nthreads; @@ -36,7 +36,7 @@ class Primitives: const bool flagThread; const int group; Fan fan; - T *userBufs[2]; + T *userBufs[3]; struct ncclConnInfo* recvConn = NULL; volatile uint64_t* recvConnHeadPtr = NULL; uint64_t recvConnHead; @@ -347,6 +347,7 @@ private: constexpr int DST = DstBuf != -1 ? 1 : 0; T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; + T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx; int wireOffset = WireWordPerSlice*warp + 2*wid; const int nwarps = nthreads/WARP_SIZE; nelem = nelem < 0 ? 0 : nelem; @@ -356,12 +357,25 @@ private: nelem -= DataEltPerSlice*warp; srcPtr += DataEltPerSlice*warp; dstPtr += DataEltPerSlice*warp; + if (accPtr != nullptr) accPtr += DataEltPerSlice*warp; while (nelem > 0) { const int eltInSlice = min(nelem, DataEltPerSlice); uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD]; if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice); recvReduceSendCopy(regs, wireOffset, postOp); - if (DST) storeRegs(dstPtr, regs, eltInSlice); + if (DST) { + if (accPtr != nullptr) { + uint64_t accRegs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD]; + loadRegsBegin(accRegs, accPtr, eltInSlice); + loadRegsFinish(accRegs); + accPtr += DataEltPerSlice*nwarps; + #pragma unroll + for (int u=0; uacc : nullptr); } __device__ ~Primitives() { @@ -542,9 +556,10 @@ public: barrier(); } - __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) { userBufs[Input] = (T*)inputBuf; userBufs[Output] = (T*)outputBuf; + userBufs[Acc] = (T*)acc; } __device__ void moveDataPtrs(intptr_t delta) { diff --git a/projects/rccl/src/device/prims_simple.h b/projects/rccl/src/device/prims_simple.h index 58526495bd..fd8b75a147 100644 --- a/projects/rccl/src/device/prims_simple.h +++ b/projects/rccl/src/device/prims_simple.h @@ -265,6 +265,8 @@ private: T* userOutput = (T*)ncclShmem.groups[group].userOutput; if (Src) ncclShmem.groups[group].srcs[0] = (SrcBuf==Input ? userInput : userOutput) + srcIx + offset; if (Dst) ncclShmem.groups[group].dsts[0] = (DstBuf==Input ? userInput : userOutput) + dstIx + offset; + T* userAcc = (T*)ncclShmem.groups[group].userAcc; + ncclShmem.groups[group].acc = (Dst && userAcc != nullptr) ? userAcc + dstIx + offset : nullptr; } waitPeer(srcIx, dstIx, offset, sliceSize); subBarrier(); @@ -385,7 +387,7 @@ private: (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs, Send * fan.nsend() + Dst, ncclShmem.groups[group].dsts, - workSize); + workSize, ncclShmem.groups[group].acc); } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) @@ -827,7 +829,7 @@ public: // coverity[negative_returns:FALSE] => coverity thinks that index could be -1 but that's not actually the case // coverity[var_deref_model] => coverity thinks work can dereferenced if NULL but this is not the case - setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer); + setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer, collWork != nullptr ? collWork->acc : nullptr); // coverity[uninit_member] => coverity thinks fan.n is not initialized } else if (mode == primsModePatRs || mode == primsModePatAg) { // Connect to all ranks +/- 2^n flags |= PatMode; @@ -902,10 +904,11 @@ public: } } - __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer) { + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer, void const *acc) { if (tid==0) { ncclShmem.groups[group].userInput = (void*)inputBuf; ncclShmem.groups[group].userOutput = (void*)outputBuf; + ncclShmem.groups[group].userAcc = (void*)acc; ncclShmem.redOpArgs[0] = redOpArg; // scaler for local input } @@ -1023,7 +1026,7 @@ public: } // Set MSCCL data pointers - __device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) { + __device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf = nullptr) { if (tid==0) { ncclShmem.groups[group].userInput = (T*)inputBuf; ncclShmem.groups[group].userOutput = (T*)outputBuf; diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 3241652781..e9934e05d1 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -28,31 +28,30 @@ using namespace rccl; -/* [RCCL] Determine which GPU kernel to execute */ -void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task) -{ - // At this time, unroll factor is controlled only by passed in unroll argument - // After more investigation, this may be further tuned by the actual task being processed +struct ncclKernelMatch { + void* kernelFn; + bool specialized; +}; #ifdef ENABLE_COLLTRACE - int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]) / 2; - int firstKernel = useCollTrace ? numKernels : 0; +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0)) +static ncclKernelMatch const ncclKerns[6] = { + {(void *)ncclDevKernel_Generic_1, true}, + {(void *)ncclDevKernel_Generic_2, true}, + {(void *)ncclDevKernel_Generic_4, true}, + {(void *)ncclDevKernelDebug_Generic_1, true}, + {(void *)ncclDevKernelDebug_Generic_2, true}, + {(void *)ncclDevKernelDebug_Generic_4, true} +}; #else - int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]); - int firstKernel = 0; +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll) +static ncclKernelMatch const ncclKerns[3] = { + {(void*)ncclDevKernel_Generic_1, true}, + {(void*)ncclDevKernel_Generic_2, true}, + {(void*)ncclDevKernel_Generic_4, true} +}; #endif - // Check if the requested unroll exists - for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) { - if (rcclKernelTable[firstKernel + kernelIdx].unroll == unroll) { - return rcclKernelTable[firstKernel + kernelIdx].funcPtr; - } - } - - // If does not match, return null - return nullptr; -} - static int rcclProtoGrainSize(int proto, ncclComm *comm){ switch (proto) { case NCCL_PROTO_LL: return 16; @@ -82,7 +81,7 @@ NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); // Returns maximum kernel stack size of all CUDA kernels ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize) { - constexpr int KernelCount = sizeof(rcclKernelTable)/sizeof(rcclKernelTable[0]); + constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]); ncclResult_t result = ncclSuccess; if (maxStackSize) *maxStackSize = 0; @@ -95,7 +94,7 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* ma int ncclMaxSharedMem = rcclShmemDynamicSize(cudaArch, WarpSize); for (int k=0; k < KernelCount; k++) { - void* fn = rcclKernelTable[k].funcPtr; + void* fn = ncclKerns[k].kernelFn; cudaFuncAttributes attr = {0}; if (fn == nullptr) continue; @@ -356,6 +355,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.sendbuff = (void*)task->sendbuff; devWork.recvbuff = (void*)task->recvbuff; + devWork.acc = (void*)task->acc; devWork.sendbuffOffset = task->sendbuffOffset; devWork.recvbuffOffset = task->recvbuffOffset; devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs; @@ -806,12 +806,8 @@ static ncclResult_t scheduleCollTasksToPlan( //plan->channelMask.masks[channelId/64] |= (2ull<channelHi) - (1ull<channelLo); plan->threadPerBlock = std::max(plan->threadPerBlock, 192 /* 3*WARP_SIZE */); if (!plan->kernelSpecialized) { -#ifdef ENABLE_COLLTRACE - plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled); -#else - plan->kernelFn = rcclGetKernelIndex(comm->unroll, false); -#endif - plan->kernelSpecialized = true; + plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn; + plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized; } if (comm->rank == 0) { @@ -1127,12 +1123,8 @@ static ncclResult_t scheduleP2pTasksToPlan( plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MAX_NTHREADS); if (!plan->kernelSpecialized) { -#ifdef ENABLE_COLLTRACE - plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled); -#else - plan->kernelFn = rcclGetKernelIndex(comm->unroll, false); -#endif - plan->kernelSpecialized = true; + plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn; + plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized; } // Compute how much to split operations @@ -2505,6 +2497,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { t->sliceSteps = info->sliceSteps; t->eActivationMask = __atomic_load_n(&ncclProfilerEventMask, __ATOMIC_RELAXED); t->opCount = comm->opCount; + t->acc = info->acc; planner->nTasksColl += 1; ncclTaskCollSorterInsert(&planner->collSorter, t, t->trafficBytes); @@ -2554,8 +2547,8 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) { } NCCLCHECKGOTO(ArgsCheck(info), ret, fail); - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d", - info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count, + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p acc %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d", + info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->acc, info->count, info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream, info->comm->planner.nTasksP2p + info->comm->planner.nTasksColl, info->comm->localRankToRank[info->comm->localRank]); diff --git a/projects/rccl/src/include/api_trace.h b/projects/rccl/src/include/api_trace.h index 8329718671..1b51753651 100644 --- a/projects/rccl/src/include/api_trace.h +++ b/projects/rccl/src/include/api_trace.h @@ -31,7 +31,7 @@ #define RCCL_API_TRACE_VERSION_MAJOR 0 // should be increased every time new members are added to existing dispatch tables -#define RCCL_API_TRACE_VERSION_PATCH 0 +#define RCCL_API_TRACE_VERSION_PATCH 1 #if !defined(RCCL_EXTERN_C_INIT) # ifdef __cplusplus @@ -61,6 +61,10 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, struct ncclComm* comm, hipStream_t stream); +typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff, void* recvbuff, + size_t count, ncclDataType_t datatype, + ncclRedOp_t op, struct ncclComm* comm, + hipStream_t stream, const void* acc); typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream); @@ -194,6 +198,7 @@ typedef struct rcclApiFuncTable mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn; ncclCommRegister_fn_t ncclCommRegister_fn; ncclCommDeregister_fn_t ncclCommDeregister_fn; + ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn; } rcclApiFuncTable; diff --git a/projects/rccl/src/include/comm.h b/projects/rccl/src/include/comm.h index bd6ef4721d..d918f6eee4 100644 --- a/projects/rccl/src/include/comm.h +++ b/projects/rccl/src/include/comm.h @@ -195,6 +195,7 @@ struct ncclTaskColl { ncclFunc_t func; void const* sendbuff; void* recvbuff; + void const* acc; size_t count; int root; ncclDataType_t datatype; diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index c7536f819e..efc1087c36 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -310,6 +310,7 @@ struct alignas(16) ncclDevWorkColl { uint16_t pivotA2ANumBiRings:15, profilerEnabled:1; void* recvbuff; void* sendbuff; + void *acc; uintptr_t sendbuffOffset; uintptr_t recvbuffOffset; uintptr_t* sendbuffRmtAddrs; diff --git a/projects/rccl/src/include/info.h b/projects/rccl/src/include/info.h index 91bfd09c94..415a4f6197 100644 --- a/projects/rccl/src/include/info.h +++ b/projects/rccl/src/include/info.h @@ -29,6 +29,7 @@ struct ncclInfo { // Algorithm details int chunkSteps; int sliceSteps; + const void* acc; }; #endif \ No newline at end of file diff --git a/projects/rccl/src/include/nccl_common.h b/projects/rccl/src/include/nccl_common.h index 36f7be7bd6..6bab98d954 100644 --- a/projects/rccl/src/include/nccl_common.h +++ b/projects/rccl/src/include/nccl_common.h @@ -74,5 +74,10 @@ typedef enum { #define NCCL_ALGO_PROTO_IGNORE -1.0 +#define NCCL_NUM_UNROLLS 3 // 1/2/4 +#define NCCL_UNROLL_1 0 +#define NCCL_UNROLL_2 1 +#define NCCL_UNROLL_4 2 + #define NCCL_NUM_FLOATS 6 // half/float/double/rccl_bfloat16/rccl_float8/rccl_bfloat8 #endif diff --git a/projects/rccl/src/include/recorder.h b/projects/rccl/src/include/recorder.h index 0d897d634d..a3440a5a57 100644 --- a/projects/rccl/src/include/recorder.h +++ b/projects/rccl/src/include/recorder.h @@ -16,6 +16,7 @@ typedef enum { rrAllGather, rrReduceScatter, rrAllReduce, + rrAllReduceWithBias, rrSend, rrRecv, rrAllToAll, @@ -51,6 +52,7 @@ constexpr const char* rcclCallStr[] "AllGather", "ReduceScatter", "AllReduce", + "AllReduceWithBias", "Send", "Recv", "AllToAll", @@ -94,6 +96,7 @@ struct rcclApiCall { uint64_t opCount = 0; const void* sendbuff = NULL; void* recvbuff = NULL; + const void* acc = NULL; void* sendPtrBase = NULL; void* recvPtrBase = NULL; size_t sendPtrExtent = 0; diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 56c99feaef..16a8a6179a 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -610,6 +610,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in // RCCL: create persistent stream for calloc CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking)); + // RCCL: determine and set unroll factor for comm + NCCLCHECK(commSetUnrollFactor(comm)); comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false; @@ -1958,9 +1960,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail); - // RCCL: determine and set unroll factor for comm - NCCLCHECK(commSetUnrollFactor(comm)); - #ifdef ENABLE_MSCCLPP if (job->parent) { if (job->parent->mscclppCompatible) { diff --git a/projects/rccl/src/misc/api_trace.cc b/projects/rccl/src/misc/api_trace.cc index 92e09eb149..1a77c5ca56 100644 --- a/projects/rccl/src/misc/api_trace.cc +++ b/projects/rccl/src/misc/api_trace.cc @@ -153,6 +153,11 @@ ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** han ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle); +ncclResult_t +ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, + cudaStream_t stream, const void* acc); + namespace rccl { namespace @@ -211,10 +216,11 @@ RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclRunAlgo_fn, 33); RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclUnloadAlgo_fn, 34); RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommRegister_fn, 35); RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommDeregister_fn, 36); +RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclAllReduceWithBias_fn, 37); #undef RCCL_ASSERT_OFFSET -static_assert(sizeof(rcclApiFuncTable) == compute_table_size(37), +static_assert(sizeof(rcclApiFuncTable) == compute_table_size(38), "Update table major/step version and add a new offset assertion if this " "fails to compile"); @@ -261,7 +267,8 @@ RcclGetFunctionTable_impl() &mscclRunAlgo_impl, &mscclUnloadAlgo_impl, &ncclCommRegister_impl, - &ncclCommDeregister_impl }; + &ncclCommDeregister_impl, + &ncclAllReduceWithBias_impl }; #if defined(RCCL_ROCPROFILER_REGISTER) && RCCL_ROCPROFILER_REGISTER > 0 std::array table_array{ tbl }; @@ -301,6 +308,9 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff, NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream); +NCCL_API(ncclResult_t, ncclAllReduceWithBias, const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream, const void* acc); + NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream); @@ -411,6 +421,14 @@ ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, op, comm, stream); } +ncclResult_t +ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc) +{ + return ::rccl::RcclGetFunctionTable()->ncclAllReduceWithBias_fn(sendbuff, recvbuff, count, + datatype, op, comm, stream, acc); +} + ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) diff --git a/projects/rccl/src/misc/recorder.cc b/projects/rccl/src/misc/recorder.cc index 26535e313a..8268378ad0 100644 --- a/projects/rccl/src/misc/recorder.cc +++ b/projects/rccl/src/misc/recorder.cc @@ -35,6 +35,7 @@ rcclApiCall::rcclApiCall(rcclCall_t type, const ncclInfo& info)://name(rcclCallS opCount(info.comm->opCount), sendbuff(info.sendbuff), recvbuff(info.recvbuff), + acc(info.acc), count(info.count), datatype(info.datatype), op(info.op), @@ -69,7 +70,7 @@ std::string alloc_fmt = "%s : [returned ptr : %p, size : %zu, context : ["; std::string free_fmt = "%s : [ptr : %p, context : ["; std::string redop_fmt = "%s : [scalar : %p, datatype : %d, op : %d, residence : %d, comm : %p, context : ["; std::string redopdestroy_fmt = "%s : [op : %d, comm : %p, context : ["; -std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : ["; +std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], acc : %p, count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : ["; Recorder::Recorder() { @@ -256,7 +257,7 @@ void Recorder::write(const rcclApiCall &call) default: // collectives len = snprintf(buffer, 4096, coll_fmt.c_str(), rcclCallStr[call.type], call.opCount, call.sendbuff, call.sendPtrBase, call.sendPtrExtent, - call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.count, call.datatype, + call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.acc, call.count, call.datatype, call.op, call.root, call.comm, call.nRanks, call.stream, call.nTasks, call.globalRank); } @@ -686,9 +687,9 @@ void parseJsonEntry(const char* entry, std::vector& calls) default: assert(sscanf(str.c_str() + end + 3, (coll_fmt.substr(5) + ctxt_fmt).c_str(), &call.opCount, &call.sendbuff, &call.sendPtrBase, &call.sendPtrExtent, &call.recvbuff, &call.recvPtrBase, &call.recvPtrExtent, - &call.count, &call.datatype, &call.op, &call.root, + &call.acc, &call.count, &call.datatype, &call.op, &call.root, &call.comm, &call.nRanks, &call.stream, &call.nTasks, &call.globalRank, &call.timestamp, &call.tid, - &call.hipDev, &call.graphCaptured, &call.graphID) == 21); + &call.hipDev, &call.graphCaptured, &call.graphID) == 22); } calls.push_back(call); } diff --git a/projects/rccl/src/nccl.h.in b/projects/rccl/src/nccl.h.in index 6910dffa9a..22aa4467d9 100644 --- a/projects/rccl/src/nccl.h.in +++ b/projects/rccl/src/nccl.h.in @@ -23,6 +23,7 @@ #define RCCL_FLOAT8 1 #define RCCL_GATHER_SCATTER 1 #define RCCL_ALLTOALLV 1 +#define RCCL_ALLREDUCE_WITH_BIAS 1 #ifdef __cplusplus extern "C" { @@ -563,6 +564,27 @@ ncclResult_t pncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream); /*! @endcond */ +/*! @brief All-Reduce-with-Bias + @details Reduces data arrays of length *count* in *sendbuff* using *op* operation, and + leaves identical copies of result on each *recvbuff*. + In-place operation will happen if sendbuff == recvbuff. + @return Result code. See @ref rccl_result_code for more details. + + @param[in] sendbuff Input data array to reduce + @param[out] recvbuff Data array to store reduced result array + @param[in] count Number of elements in data buffer + @param[in] datatype Data buffer element datatype + @param[in] op Reduction operator + @param[in] comm Communicator group object to execute on + @param[in] stream HIP stream to execute collective on + @param[in] acc Bias data array to reduce */ +ncclResult_t ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc); +/*! @cond include_hidden */ +ncclResult_t pncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc); +/*! @endcond */ + /*! @brief Reduce-Scatter @details Reduces data in *sendbuff* using *op* operation and leaves reduced result scattered over the devices so that *recvbuff* on rank i will contain the i-th diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index 643b21f3d0..0ef494961a 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -127,43 +127,14 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, return ncclSuccess; } -//RCCL runtime param to set Unroll Factor -RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0); - ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { hipDeviceProp_t devProp; CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); - - //If RCCL runtime param is set, it will override defaults, if supported - if (rcclParamUnrollFactor() != 0) { -#if ENABLE_COLLTRACE - if(rcclGetKernelIndex(rcclParamUnrollFactor(), comm->collTraceEnabled)) { -#else - if(rcclGetKernelIndex(rcclParamUnrollFactor(), false)) { -#endif - comm->unroll = rcclParamUnrollFactor(); - INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll); - return ncclSuccess; - } - else { - // Fall back to default unroll - WARN("Requested RCCL_UNROLL_FACTOR: %ld is invalid and does not exist in `rcclKernelTable`. Falling back to pre-set unroll.", rcclParamUnrollFactor()); - } - } - - if (IsArchMatch(devProp.gcnArchName, "gfx950")) { - //on gfx950, use unroll=1 for single-node and unroll=2 for multi-node - if (comm->nNodes == 1) - comm->unroll = 1; - else - comm->unroll = 2; - } - else if((IsArchMatch(devProp.gcnArchName, "gfx908")) || - (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) - //on MI300X and gfx908, use unroll=2 - comm->unroll = 2; + if(IsArchMatch(devProp.gcnArchName, "gfx950")) + comm->unroll = NCCL_UNROLL_1; + else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))) + comm->unroll = NCCL_UNROLL_2; else - comm->unroll = 4; - INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll); + comm->unroll = NCCL_UNROLL_4; return ncclSuccess; } diff --git a/projects/rccl/tools/RcclReplayer/rcclReplayer.cpp b/projects/rccl/tools/RcclReplayer/rcclReplayer.cpp index f93cd86d8d..40d7f788e4 100644 --- a/projects/rccl/tools/RcclReplayer/rcclReplayer.cpp +++ b/projects/rccl/tools/RcclReplayer/rcclReplayer.cpp @@ -487,6 +487,13 @@ void Replayer::replay() NCCL_CALL(ncclAllReduce(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first)); break; } + case rrAllReduceWithBias: + { + std::vector acc(call.count * ncclTypeSize(call.datatype)); + NCCL_CALL(ncclAllReduceWithBias(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first, acc.data())); + HIP_CALL(hipStreamSynchronize(streams[call.stream].first)); // TODO: remove, and further verify behavior of fused AR + break; + } // a2av case rrAllToAllv: { @@ -670,4 +677,4 @@ int main(int argc, char **argv) replayer.replay(); MPI_Finalize(); return 0; -} \ No newline at end of file +}