[Device] Add dynamic fetch/reduce pipelining for reduction collectives - Simple protocol (#1861)
* Support pipelining codegen and template specialization
* Support ReduceCopy pipelining for AllReduce, ReduceScatter, and Reduce (currently enabled for bfloat16)
* Remove need for FUNC_INDEX_TOTAL
* Add pipeline field to device function key construction logic
* Avoid unneeded codegen for LL/LL64 kernels
* Modify conditions and add pipeline dtypes env
* Optimize selection for both gfx942 and gfx950
* Increase pipeline bitfield width
* Use __forceinline__ for all device functions
* Realign reduceCopy with original form
* Add opt-out option to enable perf debugs
* Remove force-reduce-pipelining option from README
* Update CHANGELOG.md
---------
Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
[ROCm/rccl commit: 277747c199]
This commit is contained in:
committed by
GitHub
parent
c7fce9b0eb
commit
f37f290134
@@ -22,10 +22,10 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
* Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
|
||||
* LL/LL128 usage ranges for AR, AG, and RS are part of the tuning models, which enable architecture-specific tuning in conjunction with the existing Rome Models scheme in RCCL.
|
||||
* Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled.
|
||||
* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap bf16 arithmetic.
|
||||
* Added `--force-reduce-pipeline` as an option that can be passed to the `install.sh` script. Passing this option will enable software-triggered pipelining of `bfloat16` reductions (i.e. `all_reduce`, `reduce_scatter`, and `reduce`).
|
||||
* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap `bf16` arithmetic and bridge the gap between `fp32` and `bf16` performance on both `gfx942` and `gfx950`. Pipelining is tunable via `rcclSetPipelining`, in the same way as algorithms and protocols, so that regressions at certain message sizes can be avoided.
|
||||
* Added a direct allgather algorithm. This is enabled by default for multi-node if there are 16 nodes or fewer. The message size threshold is 4MB.
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
* Compatibility with NCCL 2.23.4
|
||||
|
||||
@@ -38,7 +38,6 @@ option(PROFILE "Enable profiling"
|
||||
option(TIMETRACE "Enable time-trace during compilation" OFF)
|
||||
option(TRACE "Enable additional tracing" OFF)
|
||||
option(FAULT_INJECTION "Enable fault injection" ON)
|
||||
option(FORCE_REDUCE_PIPELINING "Force reduce pipelining" OFF)
|
||||
|
||||
# Default GPU architectures to build
|
||||
#==================================================================================================
|
||||
@@ -848,18 +847,6 @@ foreach(file ${GENERATED_FILES})
|
||||
list(APPEND HIP_SOURCES ${file})
|
||||
endforeach()
|
||||
|
||||
# Enable SW pipelining where needed
|
||||
foreach(SOURCE_FILE ${HIP_SOURCES})
|
||||
# TODO: enable bf16 pipelining by default upon having the pipelined/scalar switching feature
|
||||
# if (FORCE_REDUCE_PIPELINING AND (SOURCE_FILE MATCHES "gensrc/reduce_.*" OR SOURCE_FILE MATCHES "gensrc/reduce_scatter_.*" OR SOURCE_FILE MATCHES "gensrc/all_reduce_.*"))
|
||||
# message(STATUS "RCCL_ENABLE_SW_PIPELINE enabled for ${SOURCE_FILE}")
|
||||
# set_source_files_properties(${SOURCE_FILE} PROPERTIES COMPILE_FLAGS "-DRCCL_ENABLE_SW_PIPELINE")
|
||||
if(FORCE_REDUCE_PIPELINING AND SOURCE_FILE MATCHES "gensrc/(reduce|reduce_scatter|all_reduce).*_bf16\\.cpp$")
|
||||
message(STATUS "BF16 pipelining support enabled for ${SOURCE_FILE}")
|
||||
set_source_files_properties(${SOURCE_FILE} PROPERTIES COMPILE_FLAGS "-DRCCL_ENABLE_SW_PIPELINE")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Create an initial git_version.cpp file (that will be updated with latest git version)
|
||||
#==================================================================================================
|
||||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/git_version.cpp "")
|
||||
|
||||
@@ -64,7 +64,6 @@ RCCL build & installation helper script
|
||||
-t|--tests_build Build rccl unit tests, but do not run
|
||||
--time-trace Plot the build time of RCCL (requires `ninja-build` package installed on the system)
|
||||
--verbose Show compile commands
|
||||
--force-reduce-pipeline Force reduce_copy sw pipeline to be used for all reduce-based collectives and datatypes
|
||||
```
|
||||
|
||||
By default, RCCL builds for all GPU targets defined in `DEFAULT_GPUS` in `CMakeLists.txt`. To target specific GPU(s), and potentially reduce build time, use `--amdgpu_targets` as a `;` separated string listing GPU(s) to target.
|
||||
|
||||
Normal file → Executable file
+13
-4
@@ -21,13 +21,22 @@
|
||||
HIP_FILE=$1
|
||||
|
||||
if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then
|
||||
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL\2>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL, int Pipeline\2>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?(?:, int RCCLMetadata)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL, int Pipeline\2>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, USE_ACC, COLL_UNROLL>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, USE_ACC, COLL_UNROLL\2/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
|
||||
echo "Added COLL_UNROLL and USE_ACC template arguments to $HIP_FILE"
|
||||
|
||||
perl -pi -e 's/(runTreeSplit<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeUpDown<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T, RedOp, (ProtoLL|ProtoLL128), (RCCL_ONE_NODE_RING_SIMPLE|RCCL_METADATA_EMPTY), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
|
||||
|
||||
perl -pi -e 's/(runRing<T, RedOp, Proto, (RCCL_ONE_NODE_RING_SIMPLE|RCCL_METADATA_EMPTY), USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeSplit<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeUpDown<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL, Pipeline>/" "$HIP_FILE"
|
||||
sed -i "s/\\(RunWorkColl<[^,]*,[^,]*,[^,]*,[^,]*,[^>]*\\)>/\\1, USE_ACC, COLL_UNROLL, Pipeline>/" "$HIP_FILE"
|
||||
fi
|
||||
@@ -14,7 +14,7 @@
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int USE_ACC, int COLL_UNROLL>
|
||||
template<typename T, typename RedOp, typename Proto, int RCCLMetadata>
|
||||
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__)
|
||||
__device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
#else
|
||||
@@ -61,7 +61,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &ring->next as an array. However, due to the use of
|
||||
// FanSymmetric<1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata> prims
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata, Pipeline> prims
|
||||
(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -252,7 +252,7 @@ namespace {
|
||||
#endif
|
||||
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
|
||||
(tid, nthreads, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -301,7 +301,7 @@ namespace {
|
||||
}
|
||||
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
|
||||
(tid, nthreads, &tree->up, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -420,7 +420,7 @@ namespace {
|
||||
|
||||
if (tree->up == -1) {
|
||||
// Reduce and broadcast. Max number of recv is 2, max number of send is 2
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto,USE_ACC >
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC>
|
||||
prims(tid, nthreads, tree->down, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -463,7 +463,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &tree->up as an array. However, due to the use of
|
||||
// FanAsymmetric<n, 1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0>
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -508,7 +508,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &tree->up as an array. However, due to the use of
|
||||
// FanAsymmetric<1, n>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, work->sendbuff, work->recvbuff,
|
||||
work->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, work);
|
||||
|
||||
@@ -560,7 +560,7 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
|
||||
#define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
|
||||
if(work->rcclUseOneSlice){ \
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
|
||||
@@ -579,7 +579,7 @@ namespace {
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
|
||||
rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -392,14 +392,14 @@ __device__ __forceinline__ void loadWorkBatchToShmem(
|
||||
}
|
||||
}
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
|
||||
struct RunWorkColl {
|
||||
__device__ void run(int tid, int tn, struct ncclDevWorkColl* work) {
|
||||
// Put NOT IMPLEMENTED behavior here.
|
||||
}
|
||||
};
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
|
||||
struct RunWorkBatch;
|
||||
|
||||
// Specialized for P2p in sendrecv.h
|
||||
@@ -407,7 +407,7 @@ template<typename T, typename RedOp>
|
||||
struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE>;
|
||||
|
||||
// Specialized here for non-P2p (Coll and CollReg)
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
|
||||
struct RunWorkBatch {
|
||||
// This __forceinline__ is necessary. The compiler was inserting a function call
|
||||
// here from the LL ncclKernel.
|
||||
@@ -437,7 +437,7 @@ struct RunWorkBatch {
|
||||
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
|
||||
// However, the code ensures that the participation is on a per-warp basis.
|
||||
// coverity[device_thread_diverged:FALSE]
|
||||
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, USE_ACC, COLL_UNROLL>().run(tid, subtn, work);
|
||||
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto>().run(tid, subtn, work);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -672,14 +672,14 @@ __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONST
|
||||
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \
|
||||
__device__ void ncclDevFunc_##suffix() { \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll, pipeline>().run(); \
|
||||
}
|
||||
#else
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \
|
||||
__device__ __attribute__((noinline)) void ncclDevFunc_##suffix() { \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll, pipeline>().run(); \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -27,12 +27,11 @@ inline __device__ int loadInt(int* ptr) {
|
||||
return v;
|
||||
}
|
||||
|
||||
#ifndef RCCL_ENABLE_SW_PIPELINE
|
||||
template<typename RedFn, typename T, int Unroll, int BytePerPack,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
|
||||
__device__ __forceinline__ void reduceCopyPacks(
|
||||
__device__ __forceinline__ static void reduceCopyPacks(
|
||||
int nThreads, int &thread,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
|
||||
@@ -207,16 +206,15 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
warp = -nHunksAhead;
|
||||
thread = warp*WARP_SIZE + lane;
|
||||
}
|
||||
#else
|
||||
|
||||
template <typename RedFn, typename SrcPtrFn, typename IntBytes, int MultimemSrcs, int MinSrcs, int MaxSrcs, int PreOpSrcs, int Unroll, int BytePerPack>
|
||||
__device__ __forceinline__ void loadSources(
|
||||
const RedFn& redFn,
|
||||
const SrcPtrFn& srcPtrFn,
|
||||
IntBytes& globalOffset,
|
||||
uintptr_t* minSrcs,
|
||||
const RedFn& redFn,
|
||||
const SrcPtrFn& srcPtrFn,
|
||||
IntBytes& globalOffset,
|
||||
uintptr_t* minSrcs,
|
||||
uint64_t *preOpArgs,
|
||||
BytePack<BytePerPack> buff[MaxSrcs + !MaxSrcs][Unroll],
|
||||
BytePack<BytePerPack> buff[MaxSrcs + !MaxSrcs][Unroll],
|
||||
int nSrcs
|
||||
) {
|
||||
#pragma unroll Unroll
|
||||
@@ -295,7 +293,7 @@ template<typename RedFn, typename T, int Unroll, int BytePerPack,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
|
||||
__device__ __forceinline__ void reduceCopyPacks(
|
||||
__device__ __forceinline__ static void reduceCopyPacksPipelined(
|
||||
int nThreads, int &thread,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
|
||||
@@ -354,12 +352,12 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
loadSources<RedFn, SrcPtrFn, IntBytes, MultimemSrcs, MinSrcs, MaxSrcs, PreOpSrcs, Unroll, BytePerPack>(
|
||||
redFn, srcPtrFn, threadBytesBehind, minSrcs, preOpArgs, acc1, nSrcs
|
||||
);
|
||||
|
||||
|
||||
if(tailProcess) {
|
||||
reduceAndStore<RedFn, DstPtrFn, IntBytes, MultimemDsts, MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs, Unroll, BytePerPack>(
|
||||
redFn, preOpArgs, acc2, minDsts, postOp, nDsts, dstPtrFn, tailThreadBytesBehind, nSrcs
|
||||
);
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int d=0; d < MinDsts; d++) {
|
||||
minDsts[d] += (nWarps-1)*BytePerHunk;
|
||||
@@ -373,7 +371,7 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
threadBytesAhead -= nWarps*BytePerHunk;
|
||||
nHunksAhead -= nWarps;
|
||||
tailProcess = Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead);
|
||||
|
||||
|
||||
tailThreadBytesBehind = threadBytesBehind;
|
||||
threadBytesBehind += nWarps*BytePerHunk;
|
||||
if(tailProcess) {
|
||||
@@ -400,7 +398,7 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
nHunksAhead -= nWarps;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(tailProcess) {
|
||||
reduceAndStore<RedFn, DstPtrFn, IntBytes, MultimemDsts, MinSrcs, MaxSrcs, MinDsts, MaxDsts, PreOpSrcs, Unroll, BytePerPack>(
|
||||
redFn, preOpArgs, acc2, minDsts, postOp, nDsts, dstPtrFn, tailThreadBytesBehind, nSrcs
|
||||
@@ -418,7 +416,6 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
warp = -nHunksAhead;
|
||||
thread = warp*WARP_SIZE + lane;
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename RedFn, typename T, int Unroll, int BytePerPack,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
@@ -613,7 +610,7 @@ __device__ __forceinline__ void reduceCopyPacksWithBias(
|
||||
template<int Unroll, int useAcc, typename RedFn, typename T,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
|
||||
typename IntBytes, int Pipeline, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
|
||||
__device__ __forceinline__ void reduceCopy(
|
||||
int thread, int nThreads,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
@@ -647,40 +644,56 @@ __device__ __forceinline__ void reduceCopy(
|
||||
aligned = !(__any(!aligned));
|
||||
if (aligned) {
|
||||
#if defined(__gfx90a__)
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
#else
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
#endif
|
||||
if (nBytesAhead == 0) return;
|
||||
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
|
||||
if (nBytesAhead == 0) return;
|
||||
}
|
||||
}
|
||||
@@ -707,58 +720,81 @@ __device__ __forceinline__ void reduceCopy(
|
||||
*/
|
||||
#if defined(__gfx90a__)
|
||||
if (MinSrcs > 1) {
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
} else {
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
}
|
||||
#else
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
|
||||
#endif
|
||||
if (nBytesAhead == 0) return;
|
||||
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
if constexpr (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else if constexpr (Pipeline)
|
||||
reduceCopyPacksPipelined<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<int Unroll, int useAcc, typename RedFn, typename T,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes>
|
||||
int Pipeline = 0, typename IntBytes>
|
||||
__device__ __forceinline__ void reduceCopy(
|
||||
int thread, int nThreads,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
@@ -767,10 +803,10 @@ __device__ __forceinline__ void reduceCopy(
|
||||
) {
|
||||
reduceCopy<Unroll, useAcc, RedFn, T,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs,
|
||||
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes>
|
||||
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes, Pipeline>
|
||||
(thread, nThreads, redArg, preOpArgs, postOp,
|
||||
nSrcs, [=]__device__(int i) { return srcPtrs[i]; },
|
||||
nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts, [=]__device__() { return accPtr; });
|
||||
}
|
||||
|
||||
#endif // COMMON_KERNEL_H_
|
||||
#endif // COMMON_KERNEL_H_
|
||||
@@ -11,7 +11,13 @@ all_protos = ["LL","LL128","SIMPLE"]
|
||||
all_algos = ["TREE","RING", "", "", "", "", "PAT"]
|
||||
all_unroll = ["1", "2", "4"]
|
||||
use_acc = ["0", "1"]
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_unroll]
|
||||
|
||||
# Pipelining is not supported for LL/LL128 prims, so "1" is not a valid value for the low-latency protocols.
|
||||
# However, if it needs to be supported, equivalent_primary() can be modified to avoid the "non-zero"->"0" mapping.
|
||||
all_pipeline = ["0", "1"]
|
||||
pipelined_types = ["bf16"]
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_pipeline, all_unroll]
|
||||
|
||||
|
||||
################################################################################
|
||||
# The first command line argument is the path to the directory to generate and
|
||||
@@ -114,14 +120,25 @@ redops_of_coll = {
|
||||
}
|
||||
|
||||
tys_of_coll = {
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllReduceWithBias": all_tys,
|
||||
"AllToAllPivot": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
"ReduceScatter": all_tys,
|
||||
"SendRecv": ["i8"]
|
||||
"AllToAllPivot": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
"ReduceScatter": all_tys,
|
||||
"SendRecv": ["i8"]
|
||||
}
|
||||
|
||||
# Pipeline values allowed per collective. Only AllReduce, Reduce and
# ReduceScatter list every entry of all_pipeline; every other collective is
# restricted to "0" (non-pipelined).
pipelines_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": all_pipeline,
|
||||
"AllReduceWithBias": ["0"],
|
||||
"AllToAllPivot": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": all_pipeline,
|
||||
"ReduceScatter": all_pipeline,
|
||||
"SendRecv": ["0"]
|
||||
}
|
||||
|
||||
coll_camel_to_lower = {
|
||||
@@ -179,7 +196,7 @@ def calc_unroll_for_local_arch():
|
||||
return all_unroll
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
if acc == "1" and coll != "AllReduceWithBias":
|
||||
return False
|
||||
if acc == "0" and coll == "AllReduceWithBias":
|
||||
@@ -188,7 +205,7 @@ def func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
return False
|
||||
if coll == "" or algo == "":
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or acc not in use_acc or unroll not in all_unroll:
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or acc not in use_acc or unroll not in all_unroll or pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -233,10 +250,10 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
# For each loop layer remove the last element in item_list
|
||||
item_list.pop()
|
||||
else:
|
||||
coll, algo, proto, redop, ty, acc, unroll = item_list
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = item_list
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
yield(coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
yield(coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
def parse_input(func_pattern):
|
||||
@@ -256,7 +273,7 @@ def parse_input(func_pattern):
|
||||
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll):
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
@@ -264,7 +281,11 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll):
|
||||
# map signed integer min/max to unsigned for non-NVLS
|
||||
elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
|
||||
ty = "u"+ty[1:]
|
||||
return (coll, algo, proto, redop, ty, acc, unroll)
|
||||
# map pipelined to non-pipelined for LL/LL128 to avoid extra device codegen
|
||||
if (pipeline != "0" and proto != "SIMPLE"):
|
||||
pipeline = "0"
|
||||
|
||||
return (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
|
||||
# The order in which rows are enumerated must match the formula of `ncclDevFuncId()`:
|
||||
# outermost loop should be for unroll factor; refer to host_table section
|
||||
@@ -276,12 +297,12 @@ def enumerate_func_rows():
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
yield (coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
for pipeline in all_pipeline:
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
yield (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
use_acc.index(acc),
|
||||
@@ -289,7 +310,8 @@ def custom_sort_key(fn):
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
all_redops.index(redop),
|
||||
all_tys.index(ty)
|
||||
all_tys.index(ty),
|
||||
all_pipeline.index(pipeline)
|
||||
)
|
||||
|
||||
################################################################################
|
||||
@@ -333,7 +355,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
|
||||
index1 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "1": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -350,7 +372,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
|
||||
index2 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -367,7 +389,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -448,7 +470,7 @@ if is_colltrace:
|
||||
out("\n")
|
||||
|
||||
seen_fns = set()
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
out("const char* funcNames[] = {\n")
|
||||
for fn in primary_funcs:
|
||||
fn_no_unroll = fn[:-1]
|
||||
if fn_no_unroll not in seen_fns:
|
||||
@@ -466,6 +488,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
out('#include "device.h"\n')
|
||||
out("\n")
|
||||
out("// The key for the ncclDevFuncNameToId map is a 64-bit unsigned integer.\n")
|
||||
out("// Each field (coll, algo, proto, redop, ty, pipeline) is packed into 4 bits,\n")
|
||||
out("// Each field (coll, algo, proto, redop, ty) is packed into 4 bits,\n")
|
||||
out("// This allows up to 16 unique values per field. The layout is:\n")
|
||||
out("// bits 0-3: coll index\n")
|
||||
@@ -473,8 +496,9 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
out("// bits 8-11: proto index\n")
|
||||
out("// bits 12-15: redop index\n")
|
||||
out("// bits 16-19: ty index\n")
|
||||
out("// bits 20-23: pipeline index\n")
|
||||
out("#include <unordered_map>\n")
|
||||
out("extern std::unordered_map<uint64_t, int> ncclDevFuncNameToId = {\n")
|
||||
out("std::unordered_map<uint64_t, int> ncclDevFuncNameToId = {\n")
|
||||
|
||||
# host_table entries map device functions based on collective, algorithm, protocol, redop, and datatype
|
||||
# For GPU targets that support multiple unrolls, e.g., gfx950
|
||||
@@ -485,17 +509,20 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
comment = " // " + paste(" ", *fn[:-1])
|
||||
# Build the function signature string: "<coll> <algo> <proto> <redop> <ty>"
|
||||
# get parts indexes in order (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
coll_idx = all_colls.index(fn[0])
|
||||
algo_idx = all_algos.index(fn[1])
|
||||
proto_idx = all_protos.index(fn[2])
|
||||
redop_idx = all_redops.index(fn[3])
|
||||
ty_idx = all_tys.index(fn[4])
|
||||
pipeline_idx = all_pipeline.index(fn[6])
|
||||
# Assert that 4 bits (16 values) is enough to map all_colls, all_algos, etc.
|
||||
assert len(all_colls) <= 16, "Error: all_colls has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_algos) <= 16, "Error: all_algos has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_protos) <= 16, "Error: all_protos has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_redops) <= 16, "Error: all_redops has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_tys) <= 16, "Error: all_tys has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_pipeline) <= 16, "Error: all_pipeline has more than 16 values, which exceeds 4-bit capacity."
|
||||
# Create a 64-bit unsigned integer key and pack the indices into 4 bits each
|
||||
key = (
|
||||
(coll_idx & 0xF)
|
||||
@@ -503,8 +530,9 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
| ((proto_idx & 0xF) << 8)
|
||||
| ((redop_idx & 0xF) << 12)
|
||||
| ((ty_idx & 0xF) << 16)
|
||||
| ((pipeline_idx & 0xF) << 20)
|
||||
)
|
||||
fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx}"
|
||||
fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {pipeline_idx}"
|
||||
if fn[0] == "Broadcast":
|
||||
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
|
||||
if fn[0] in ["SendRecv", "AllToAllPivot"]:
|
||||
@@ -515,7 +543,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# Maps to .cu filename which implements this func. The only constraint is that
|
||||
# "coll" is reflected in the name: formally that no two funcs having different
|
||||
# coll's map to the same filename.
|
||||
def impl_filename(coll, algo, proto, redop, ty, acc, unroll):
|
||||
def impl_filename(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
|
||||
|
||||
# Partition the functions and kernels to the .cu filenames. The partition is
|
||||
@@ -573,14 +601,14 @@ for name in name_to_funcs.keys():
|
||||
)
|
||||
|
||||
for fn in fns:
|
||||
(coll, algo, proto, redop, ty, acc, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, unroll)
|
||||
(coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
if proto == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out(
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {unroll})\n"
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {pipeline}, {unroll})\n"
|
||||
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, unroll=unroll)
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, pipeline=pipeline, unroll=unroll)
|
||||
)
|
||||
if proto == "LL128":
|
||||
out("#endif\n")
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
* to how that protocol operates with a consistent interface so that our
|
||||
* algorithm code can operate protocol parametrically.
|
||||
*/
|
||||
template<int SlicePerChunk_1, int StepPerSlice_1,int useAcc, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
|
||||
template<int SlicePerChunk_1, int StepPerSlice_1, int useAcc, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
|
||||
struct ProtoSimple {
|
||||
static constexpr int Id = NCCL_PROTO_SIMPLE;
|
||||
static constexpr int SlicePerChunk = SlicePerChunk_1;
|
||||
@@ -137,7 +137,7 @@ struct FanSymmetric {
|
||||
};
|
||||
|
||||
// The primitives class. Specialized per protocol in the other headers.
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, typename Proto, int P2p, bool isNetOffload = false, int Metadata = RCCL_METADATA_EMPTY>
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, typename Proto, int P2p, bool isNetOffload = false, int Metadata = RCCL_METADATA_EMPTY, int Pipeline = 0, int useAcc = 0>
|
||||
class Primitives;
|
||||
|
||||
// Used by LL & LL128 to implement direct members in the naive way.
|
||||
|
||||
@@ -22,9 +22,9 @@ enum primsMode {
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct,
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata, int useAcc>
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata, int Pipeline, int useAcc>
|
||||
class Primitives<
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, useAcc, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, useAcc, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata, Pipeline, useAcc
|
||||
> {
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
@@ -80,9 +80,9 @@ private:
|
||||
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
inline __device__ void barrier() {
|
||||
if (nthreads == WARP_SIZE)
|
||||
if (nthreads == WARP_SIZE)
|
||||
__syncwarp();
|
||||
else
|
||||
else
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
barrier_generic(__threadfence_block(), nworkers, barrier_next, barriers);
|
||||
#else
|
||||
@@ -380,7 +380,7 @@ private:
|
||||
// this case should only be directCopySend() with registered buffers and send to net peer
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
0, Recv + Src, Recv * MaxRecv + Src,
|
||||
0, 1, 1, PreOpSrcs>
|
||||
0, 1, 1, PreOpSrcs, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs,
|
||||
1, ncclShmem.groups[group].dsts,
|
||||
@@ -388,7 +388,7 @@ private:
|
||||
} else {
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
MultimemSrcs, Recv + Src, Recv * MaxRecv + Src,
|
||||
MultimemDsts, Send + Dst, Send * MaxSend + Dst, PreOpSrcs>
|
||||
MultimemDsts, Send + Dst, Send * MaxSend + Dst, PreOpSrcs, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs,
|
||||
Send * fan.nsend() + Dst, ncclShmem.groups[group].dsts,
|
||||
@@ -458,10 +458,10 @@ private:
|
||||
srcs[nsrcs] = dsts[0];
|
||||
nsrcs++;
|
||||
if (MULTISRCS){
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem);
|
||||
} else {
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 2, 2, 0, 1, 1, 0, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &ring->next as an array. However, due to the use of
|
||||
// FanSymmetric<1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, 0, Pipeline>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex);
|
||||
|
||||
if (prevRank == root) {
|
||||
|
||||
@@ -56,7 +56,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &ring->next as an array. However, due to the use of
|
||||
// FanSymmetric<1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, 0, Pipeline>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -213,7 +213,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_PAT, NCCL_PROTO_SI
|
||||
int nGroups = nworkers / groupSize;
|
||||
int tidInGroup = tid - group*groupSize;
|
||||
// We don't use recvPeers/sendPeers so let's pass shmem structs instead
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0> prims
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, 0, Pipeline> prims
|
||||
(tidInGroup, groupSize, (int*)shmem->recvDims, (int*)shmem->sendDims, inputBuf, outputBuf, work->redOpArg, group, 0, 0, nullptr, nullptr, 0, primsModePatRs);
|
||||
|
||||
int step = group;
|
||||
|
||||
@@ -457,7 +457,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
}
|
||||
|
||||
NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo));
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol);
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
if (agg.devFuncId < 0) {
|
||||
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
|
||||
return ncclInvalidUsage;
|
||||
@@ -480,6 +480,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
struct ncclTaskColl* next = aggBeg->next;
|
||||
aggBeg->algorithm = agg.algorithm;
|
||||
aggBeg->protocol = agg.protocol;
|
||||
aggBeg->pipeline = agg.pipeline;
|
||||
if (aggBeg->protocol == NCCL_PROTO_LL) aggBeg->trafficBytes *= 4;
|
||||
aggBeg->nMaxChannels = agg.nMaxChannels;
|
||||
aggBeg->nWarps = agg.nWarps;
|
||||
@@ -1941,6 +1942,7 @@ static ncclResult_t topoGetAlgoInfo(
|
||||
return (algoEnv || protoEnv) ? ncclInvalidUsage : ncclInternalError;
|
||||
}
|
||||
rcclUpdateCollectiveProtocol(comm, nBytes, info);
|
||||
rcclSetPipelining(comm, nBytes, info);
|
||||
if (simInfo) simInfo->estimatedTime = time;
|
||||
TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", nBytes, info->algorithm, info->protocol, time);
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@ struct ncclTaskColl {
|
||||
size_t trafficBytes;
|
||||
int32_t nMaxChannels:8;
|
||||
int32_t nWarps:8;
|
||||
int32_t algorithm:8, protocol:8;
|
||||
int32_t algorithm:8, protocol:8, pipeline:8;
|
||||
uint32_t isCollnet:1, isNvls:1;
|
||||
uint32_t devFuncId:30;
|
||||
int regBufType;
|
||||
|
||||
@@ -30,7 +30,7 @@ extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
|
||||
|
||||
extern const char* ncclProtoStr[NCCL_NUM_PROTOCOLS];
|
||||
|
||||
extern const char* funcNames[FUNC_INDEX_TOTAL];
|
||||
extern const char* funcNames[];
|
||||
|
||||
#define NCCL_MAX_OPS 2048
|
||||
#define NCCL_STEPS 8
|
||||
@@ -134,6 +134,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
|
||||
#define RCCL_PROTO_SHIFT 8
|
||||
#define RCCL_REDOP_SHIFT 12
|
||||
#define RCCL_DTYPE_SHIFT 16
|
||||
#define RCCL_PIPELINE_SHIFT 20
|
||||
|
||||
struct ncclConnInfo {
|
||||
// Regular comm mechanism
|
||||
@@ -701,7 +702,7 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
|
||||
extern std::unordered_map<uint64_t, int> ncclDevFuncNameToId;
|
||||
|
||||
// `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) {
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int pipeline = 0) {
|
||||
int row = -1;
|
||||
uint64_t key;
|
||||
// Pack 4-bit fields from right (LSB) to left in order:
|
||||
@@ -717,14 +718,15 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
|
||||
((uint64_t)(algo & RCCL_FUNC_ID_MASK) << RCCL_ALGO_SHIFT ) |
|
||||
((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT) |
|
||||
((uint64_t)(devRedOp & RCCL_FUNC_ID_MASK) << RCCL_REDOP_SHIFT) |
|
||||
((uint64_t)(type & RCCL_FUNC_ID_MASK) << RCCL_DTYPE_SHIFT);
|
||||
((uint64_t)(type & RCCL_FUNC_ID_MASK) << RCCL_DTYPE_SHIFT) |
|
||||
((uint64_t)(pipeline & RCCL_FUNC_ID_MASK) << RCCL_PIPELINE_SHIFT);
|
||||
}
|
||||
auto it = ncclDevFuncNameToId.find(key);
|
||||
if (it != ncclDevFuncNameToId.end()) {
|
||||
row = it->second;
|
||||
}
|
||||
if(row < 0) {
|
||||
WARN("Fatal error: ncclDevFuncId: %llu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d", key, coll, algo, proto, devRedOp, type);
|
||||
WARN("Fatal error: ncclDevFuncId: %lu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d", key, coll, algo, proto, devRedOp, type);
|
||||
return -1;
|
||||
}
|
||||
return row;
|
||||
|
||||
@@ -39,10 +39,6 @@ typedef enum {
|
||||
|
||||
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
|
||||
|
||||
#define NCCL_NUM_ONERANK 12
|
||||
#define AR_WITH_BIAS_FUNC_COUNTS 324
|
||||
#define FUNC_INDEX_TOTAL 821 + AR_WITH_BIAS_FUNC_COUNTS + NCCL_NUM_ONERANK
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
|
||||
typedef enum {
|
||||
ncclFuncBroadcast = 0,
|
||||
|
||||
@@ -83,6 +83,7 @@ inline size_t rcclGetSizePerRank(ncclFunc_t const& func, size_t const& nBytes, i
|
||||
}
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info, int& threadThreshold);
|
||||
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info);
|
||||
ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t count, ncclDataType_t dataType,
|
||||
int collNetSupport, int nvlsSupport, int numPipeOps,
|
||||
int* algo, int* protocol, int* maxChannels);
|
||||
|
||||
@@ -25,6 +25,14 @@ THE SOFTWARE.
|
||||
#include "graph/topo.h"
|
||||
#include "enqueue.h"
|
||||
|
||||
// Use this param to experiment with pipelining data types other than bfloat16
|
||||
// Make sure you generate the device code with the new data type (i.e. in generate.py)
|
||||
RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0);
|
||||
|
||||
// Use this to assess impact of pipelining on performance.
|
||||
// Otherwise, it is automatically set for certain archs, datatypes and reduction collectives
|
||||
RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0);
|
||||
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
|
||||
// Honor user input for protocol choice
|
||||
static int userProtocolInput = -2;
|
||||
@@ -100,6 +108,54 @@ void rcclUpdateThreadThreshold(struct ncclComm* comm, size_t const& nBytes, stru
|
||||
}
|
||||
}
|
||||
|
||||
// Chooses the value of info->pipeline for a collective task.
//
// comm   - communicator; this function reads comm->nNodes and the gcn string of
//          the first GPU entry in comm->topo to identify the architecture.
// nBytes - message size; only consulted for the gfx942 AllReduce threshold.
// info   - task being configured; info->datatype and info->func are read and
//          info->pipeline is written (0 = no pipelining, 1 = pipelining).
//
// Decision summary, as implemented below:
//   * pipeline stays 0 when the disableReduceCopyPipelining param is non-zero.
//   * the data type must be ncclBfloat16 unless the PipelineAllDTypes param is set.
//   * gfx950: AllReduce/ReduceScatter/Reduce get 1 only when comm->nNodes > 1.
//   * gfx942: ReduceScatter/Reduce always get 1; AllReduce gets 1 on a single
//     node, or on multiple nodes when nBytes is within the size limit.
//   * any other architecture or collective keeps 0.
void rcclSetPipelining(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
|
||||
info->pipeline = 0; // Default to no pipelining
|
||||
if (rcclParamdisableReduceCopyPipelining()) {
|
||||
return;
|
||||
}
|
||||
// Eligible data types: bfloat16 only, unless the PipelineAllDTypes param overrides it.
const bool dtypeOK = (info->datatype == ncclBfloat16) || rcclParamPipelineAllDTypes();
|
||||
|
||||
// gfx950 path: nBytes is not consulted here. The unconditional return at the
// end of this branch means the gfx942 check below is never reached once the
// architecture matched gfx950.
if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && dtypeOK) {
|
||||
if (comm->nNodes > 1) {
|
||||
switch (info->func) {
|
||||
case ncclFuncAllReduce:
|
||||
case ncclFuncReduceScatter:
|
||||
case ncclFuncReduce:
|
||||
// Enable for multi-node
|
||||
info->pipeline = 1;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// gfx942 path: the reduction collectives are pipelined on single-node and
// multi-node alike, with AllReduce additionally gated by message size.
if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && dtypeOK) {
|
||||
switch (info->func) {
|
||||
// For multi-node case, we check if the number of bytes (`nBytes`) satisfies
|
||||
// the Bf16 Limit Equation for bf16 all_reduce on MI300:
|
||||
// 512MB × 2^(log2[nNodes] - 1), nNodes > 1
|
||||
// The above equation is derived from the tuning results of the bf16 all_reduce on MI300.
|
||||
// NOTE(review): presumably log2i() is an integer floor(log2), which would give
// a 512MB limit for 2-3 nodes and 1GB for 4-7 nodes -- confirm against log2i().
case ncclFuncAllReduce:
|
||||
if ( comm->nNodes == 1 ||
|
||||
((comm->nNodes > 1) &&
|
||||
nBytes <= (1ULL << 29 /*512MB*/) * (1ULL << (log2i(comm->nNodes) - 1))) ) {
|
||||
info->pipeline = 1;
|
||||
}
|
||||
break;
|
||||
|
||||
case ncclFuncReduceScatter:
|
||||
case ncclFuncReduce:
|
||||
info->pipeline = 1;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern ncclResult_t getAlgoInfo(
|
||||
struct ncclComm* comm, struct ncclTaskColl* task,
|
||||
int collNetSupport, int nvlsSupport, int numPipeOps, ncclSimInfo_t* simInfo = NULL
|
||||
|
||||
Ссылка в новой задаче
Block a user