From c76a4492f1c7c108f783e0a4cbcb4a96078e293a Mon Sep 17 00:00:00 2001 From: mberenjk <146776561+mberenjk@users.noreply.github.com> Date: Thu, 14 Aug 2025 15:58:54 -0500 Subject: [PATCH] Added useAcc as a template parameter to address the performance regression (#1856) * Added useAcc as a template parameter to address the 2% performance regression in allreduceWithBias --------- Co-authored-by: Marzieh Berenjkoub [ROCm/rccl commit: c61152baa406c49f7bfa1a8ce3b96efc367002c6] --- projects/rccl/cmake/scripts/add_unroll.sh | 16 +-- projects/rccl/src/device/all_gather.h | 8 +- projects/rccl/src/device/all_reduce.h | 28 ++--- projects/rccl/src/device/alltoall_pivot.h | 2 +- projects/rccl/src/device/broadcast.h | 4 +- projects/rccl/src/device/common.h | 16 +-- projects/rccl/src/device/common_kernel.h | 8 +- projects/rccl/src/device/generate.py | 120 +++++++++++-------- projects/rccl/src/device/msccl_kernel_impl.h | 2 +- projects/rccl/src/device/onerank.cu | 2 +- projects/rccl/src/device/primitives.h | 2 +- projects/rccl/src/device/prims_ll.h | 8 +- projects/rccl/src/device/prims_ll128.h | 8 +- projects/rccl/src/device/prims_simple.h | 28 ++--- projects/rccl/src/device/reduce_scatter.h | 8 +- projects/rccl/src/device/sendrecv.h | 14 +-- projects/rccl/src/include/device.h | 2 +- projects/rccl/src/include/nccl_common.h | 14 ++- 18 files changed, 153 insertions(+), 137 deletions(-) diff --git a/projects/rccl/cmake/scripts/add_unroll.sh b/projects/rccl/cmake/scripts/add_unroll.sh index 8b1cddfff3..2c74479153 100644 --- a/projects/rccl/cmake/scripts/add_unroll.sh +++ b/projects/rccl/cmake/scripts/add_unroll.sh @@ -21,13 +21,13 @@ HIP_FILE=$1 if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then - perl -pi -e 's/(template/\1, int COLL_UNROLL\2>/g' "$HIP_FILE" - perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, COLL_UNROLL>/g' "$HIP_FILE" - perl -pi -e 's/(runRing\()/\1, COLL_UNROLL\2/g' "$HIP_FILE" - perl -pi -e 's/(runTreeUpDown\(/\1, COLL_UNROLL>(/' "$HIP_FILE" - perl -pi -e 's/(runTreeSplit\(/\1, COLL_UNROLL>(/' "$HIP_FILE" - sed -i "s/\\(struct RunWorkColl]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/\\(struct RunWorkBatch]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + perl -pi -e 's/(template/\1, int USE_ACC, int COLL_UNROLL\2>/g' "$HIP_FILE" + perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, USE_ACC, COLL_UNROLL>/g' "$HIP_FILE" + perl -pi -e 's/(runRing\()/\1, USE_ACC, COLL_UNROLL\2/g' "$HIP_FILE" + perl -pi -e 's/(runTreeUpDown\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE" + perl -pi -e 's/(runTreeSplit\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE" + sed -i "s/\\(struct RunWorkColl]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(struct RunWorkBatch]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE" + echo "Added COLL_UNROLL and USE_ACC template arguments to $HIP_FILE" - echo "Added COLL_UNROLL template argument to $HIP_FILE" fi \ No newline at end of file diff --git a/projects/rccl/src/device/all_gather.h b/projects/rccl/src/device/all_gather.h index ddaf386b20..046b063592 100644 --- a/projects/rccl/src/device/all_gather.h +++ b/projects/rccl/src/device/all_gather.h @@ -162,7 +162,7 @@ namespace { } else if (inputBuf != outputBuf + ringRanks[0] * count) { inputBuf = inputBuf + partOffset; outputBuf = outputBuf + partOffset + ringRanks[0] * count; - reduceCopy + reduceCopy (tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, partCount); } #if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__) @@ -303,7 +303,7 @@ struct RunWorkCollregUsed) { if (tid < tidEndGather) { // Gather - using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -329,7 +329,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsGather, nvls->up, nvls->up, NULL, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -409,7 +409,7 @@ struct RunWorkColl diff --git a/projects/rccl/src/device/all_reduce.h b/projects/rccl/src/device/all_reduce.h index 901420352c..ac2347bcd5 100644 --- a/projects/rccl/src/device/all_reduce.h +++ b/projects/rccl/src/device/all_reduce.h @@ -14,7 +14,7 @@ #endif namespace { - template + template #if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__) __device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #else @@ -420,7 +420,7 @@ namespace { if (tree->up == -1) { // Reduce and broadcast. Max number of recv is 2, max number of send is 2 - Primitives, /*Direct=*/0, Proto, 0> + Primitives, /*Direct=*/0, Proto,USE_ACC > prims(tid, nthreads, tree->down, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work); #if defined(ENABLE_NPKIT) @@ -566,7 +566,7 @@ namespace { using Proto = ProtoSimple; \ if(work->regUsed || work->netRegUsed || work->gfx942CheapFenceOff){ \ runRing(tid, nthreads, work); \ - } \ + } \ else { \ runRing(tid, nthreads, work); \ } \ @@ -772,7 +772,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -784,7 +784,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff, work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); @@ -819,7 +819,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -831,7 +831,7 @@ struct RunWorkColl Coverity think prims.index can be greater than 1 } else if (tid < tidEndGather) { // Gather - using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff, work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); @@ -842,7 +842,7 @@ struct RunWorkCollheadRank != -1) { // Reduce, send to network - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>; // Coverity complains about a possible overrun inside the class below, but that's actually // a false positive. // coverity[identity_transfer:FALSE] @@ -857,7 +857,7 @@ struct RunWorkCollheadRank != -1) { // Recv from network, broadcast - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>; // Coverity complains about a possible overrun inside the class below, but that's actually // a false positive. // coverity[identity_transfer:FALSE] @@ -907,7 +907,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -919,7 +919,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff, work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); @@ -932,7 +932,7 @@ struct RunWorkCollheadRank != -1) { if (!hasUp) { // Reduce and Broadcast - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 1>; Primitives, /*Direct=*/1, Proto, 0> prims(tid - tidEndGather, nThreadsReduce, treeDown, treeDown, NULL, NULL, work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, work); @@ -946,7 +946,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>; // Coverity reports that the callee treats &treeUp as an array. However, due to the use of // FanAsymmetric<3, 1>, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] @@ -964,7 +964,7 @@ struct RunWorkCollheadRank != -1) { // Recv from network, broadcast - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>; // Coverity reports that the callee treats &treeUp as an array. However, due to the use of // FanAsymmetric<1, 3>, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] diff --git a/projects/rccl/src/device/alltoall_pivot.h b/projects/rccl/src/device/alltoall_pivot.h index 4715619aa6..9e988605e9 100644 --- a/projects/rccl/src/device/alltoall_pivot.h +++ b/projects/rccl/src/device/alltoall_pivot.h @@ -52,7 +52,7 @@ namespace { if (num_hops == 0 && work->sendbuff != work->recvbuff) { const T* sendbuff = (const T*)work->sendbuff + send_offset; T* recvbuff = (T *)work->recvbuff + recv_offset; - reduceCopy( + reduceCopy( tid, nthreads, 0, nullptr, false, 1, (void **)&sendbuff, 1, (void **)&recvbuff, send_recv_size); } else { for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) { diff --git a/projects/rccl/src/device/broadcast.h b/projects/rccl/src/device/broadcast.h index e39da0bad4..364e87ee2b 100644 --- a/projects/rccl/src/device/broadcast.h +++ b/projects/rccl/src/device/broadcast.h @@ -91,7 +91,7 @@ namespace { } else if (inputBuf != outputBuf && rank == root) { inputBuf = inputBuf + gridOffset; outputBuf = outputBuf + gridOffset; - reduceCopy + reduceCopy (tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, channelCount); } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_EXIT) @@ -126,4 +126,4 @@ struct RunWorkColl(tid, nthreads, work); } -}; \ No newline at end of file +}; diff --git a/projects/rccl/src/device/common.h b/projects/rccl/src/device/common.h index c6c61021ad..028fda3189 100644 --- a/projects/rccl/src/device/common.h +++ b/projects/rccl/src/device/common.h @@ -392,14 +392,14 @@ __device__ __forceinline__ void loadWorkBatchToShmem( } } -template +template struct RunWorkColl { __device__ void run(int tid, int tn, struct ncclDevWorkColl* work) { // Put NOT IMPLEMENTED behavior here. } }; -template +template struct RunWorkBatch; // Specialized for P2p in sendrecv.h @@ -407,7 +407,7 @@ template struct RunWorkBatch; // Specialized here for non-P2p (Coll and CollReg) -template +template struct RunWorkBatch { // This __forceinline__ is necessary. The compiler was inserting a function call // here from the LL ncclKernel. @@ -437,7 +437,7 @@ struct RunWorkBatch { // Coverity reports a possible thread divergence due to not all threads participating in the collective. // However, the code ensures that the participation is on a per-warp basis. // coverity[device_thread_diverged:FALSE] - if (tid < subtn) RunWorkColl().run(tid, subtn, work); + if (tid < subtn) RunWorkColl().run(tid, subtn, work); } } }; @@ -672,14 +672,14 @@ __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONST __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {} #ifdef USE_INDIRECT_FUNCTION_CALL -#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \ +#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \ __device__ void ncclDevFunc_##suffix() { \ - RunWorkBatch, algo, proto, unroll>().run(); \ + RunWorkBatch, algo, proto, acc, unroll>().run(); \ } #else -#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \ +#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \ __device__ __attribute__((noinline)) void ncclDevFunc_##suffix() { \ - RunWorkBatch, algo, proto, unroll>().run(); \ + RunWorkBatch, algo, proto, acc, unroll>().run(); \ } #endif diff --git a/projects/rccl/src/device/common_kernel.h b/projects/rccl/src/device/common_kernel.h index 35cace5d88..98113702fc 100644 --- a/projects/rccl/src/device/common_kernel.h +++ b/projects/rccl/src/device/common_kernel.h @@ -618,7 +618,7 @@ __device__ __attribute__((noinline)) void reduceCopyPacksWithBias( thread = warp*WARP_SIZE + lane; } -template @@ -641,7 +641,7 @@ __device__ __forceinline__ void reduceCopy( IntBytes nBytesBehind = 0; IntBytes nBytesAhead = nElts*sizeof(T); - bool useAcc = accPtrFn() != nullptr; + //bool useAcc = accPtrFn() != nullptr; #if __cpp_if_constexpr if constexpr (BigPackSize > sizeof(T)) { @@ -763,7 +763,7 @@ __device__ __forceinline__ void reduceCopy( nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); } -template @@ -773,7 +773,7 @@ __device__ __forceinline__ void reduceCopy( int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, IntBytes nElts, void *accPtr = nullptr ) { - reduceCopy (thread, nThreads, redArg, preOpArgs, postOp, diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index 88f3456caf..a129fc65ca 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -4,14 +4,15 @@ import sys import subprocess # Order of redops, tys, protos, algos must match src/include/device.h -all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"] +all_colls = ["AllGather","AllReduce","AllReduceWithBias","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"] all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] all_protos = ["LL","LL128","SIMPLE"] all_algos = ["TREE","RING", "PAT"] all_unroll = ["1", "2", "4"] +use_acc = ["0", "1"] -all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, all_unroll] +all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_unroll] ################################################################################ # The first command line argument is the path to the directory to generate and @@ -76,43 +77,47 @@ func_pattern = sys.argv[6:7] if func_pattern and func_pattern[0]: func_pattern = func_pattern[0] else: - func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" + func_pattern = "AllGather|AllReduce|AllReduceWithBias|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" ################################################################################ algos_of_coll = { - "AllGather": ["RING", "PAT"], - "AllReduce": ["RING", "TREE"], - "AllToAllPivot": ["RING"], - "Broadcast": ["RING"], - "Reduce": ["RING"], - "ReduceScatter": ["RING", "PAT"], - "SendRecv": ["RING"] + "AllGather": ["RING", "PAT"], + "AllReduce": ["RING", "TREE"], + "AllReduceWithBias": ["RING", "TREE"], + "AllToAllPivot": ["RING"], + "Broadcast": ["RING"], + "Reduce": ["RING"], + "ReduceScatter": ["RING", "PAT"], + "SendRecv": ["RING"] } protos_of_coll = { - "AllGather": all_protos, - "AllReduce": all_protos, - "AllToAllPivot": ["SIMPLE"], - "Broadcast": all_protos, - "Reduce": all_protos, - "ReduceScatter": all_protos, - "SendRecv": ["SIMPLE"] + "AllGather": all_protos, + "AllReduce": all_protos, + "AllReduceWithBias": all_protos, + "AllToAllPivot": ["SIMPLE"], + "Broadcast": all_protos, + "Reduce": all_protos, + "ReduceScatter": all_protos, + "SendRecv": ["SIMPLE"] } redops_of_coll = { - "AllGather": ["Sum"], - "AllReduce": all_redops, - "AllToAllPivot": ["Sum"], - "Broadcast": ["Sum"], - "Reduce": all_redops, - "ReduceScatter": all_redops, - "SendRecv": ["Sum"] + "AllGather": ["Sum"], + "AllReduce": all_redops, + "AllReduceWithBias": all_redops, + "AllToAllPivot": ["Sum"], + "Broadcast": ["Sum"], + "Reduce": all_redops, + "ReduceScatter": all_redops, + "SendRecv": ["Sum"] } tys_of_coll = { "AllGather": ["i8"], "AllReduce": all_tys, + "AllReduceWithBias": all_tys, "AllToAllPivot": ["i8"], "Broadcast": ["i8"], "Reduce": all_tys, @@ -121,11 +126,12 @@ tys_of_coll = { } coll_camel_to_lower = { - "AllGather": "all_gather", - "AllReduce": "all_reduce", - "AllToAllPivot": "alltoall_pivot", - "Broadcast": "broadcast", - "Reduce": "reduce", + "AllGather": "all_gather", + "AllReduce": "all_reduce", + "AllReduceWithBias": "allreduce_with_bias", + "AllToAllPivot": "alltoall_pivot", + "Broadcast": "broadcast", + "Reduce": "reduce", "ReduceScatter": "reduce_scatter", "SendRecv": "sendrecv" } @@ -174,7 +180,11 @@ def calc_unroll_for_local_arch(): return all_unroll # Helper function to check if the conditions for the collective is being met -def func_validate(coll, algo, proto, redop, ty, unroll): +def func_validate(coll, algo, proto, redop, ty, acc, unroll): + if acc == "1" and coll != "AllReduceWithBias": + return False + if acc == "0" and coll == "AllReduceWithBias": + return False if redop == "SumPostDiv" and ty[0] not in ("i","u"): return False if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll: @@ -222,10 +232,10 @@ def func_filter(function_params, current_idx, item_list=None): # For each loop layer remove the last element in item_list item_list.pop() else: - coll, algo, proto, redop, ty, unroll = item_list + coll, algo, proto, redop, ty, acc, unroll = item_list - if func_validate(coll, algo, proto, redop, ty, unroll): - yield(coll, algo, proto, redop, ty, unroll) + if func_validate(coll, algo, proto, redop, ty, acc, unroll): + yield(coll, algo, proto, redop, ty, acc, unroll) # Parse ONLY_FUNCS input and feed it to func_filter def parse_input(func_pattern): @@ -245,33 +255,35 @@ def parse_input(func_pattern): # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). -def equivalent_primary(coll, algo, proto, redop, ty, unroll): - if coll in ("AllReduce", "Reduce", "ReduceScatter"): +def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll): + if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": ty = "u"+ty[1:] # map signed integer min/max to unsigned for non-NVLS elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo): ty = "u"+ty[1:] - return (coll, algo, proto, redop, ty, unroll) + return (coll, algo, proto, redop, ty, acc, unroll) # Order rows are enumerated must match formula of `ncclDevFuncId()`: def enumerate_func_rows(): - for unroll in all_unroll: - for coll in all_colls: - for algo in all_algos: - for proto in all_protos: - for redop in all_redops: - for ty in all_tys: - if func_validate(coll, algo, proto, redop, ty, unroll): - yield (coll, algo, proto, redop, ty, unroll) + for acc in use_acc: + for unroll in all_unroll: + for coll in all_colls: + for algo in all_algos: + for proto in all_protos: + for redop in all_redops: + for ty in all_tys: + if func_validate(coll, algo, proto, redop, ty, acc, unroll): + yield (coll, algo, proto, redop, ty, acc, unroll) # Sort the hashmap based on custom key def custom_sort_key(fn): - coll, algo, proto, redop, ty, unroll = fn + coll, algo, proto, redop, ty, acc, unroll = fn return ( all_unroll.index(unroll), + use_acc.index(acc), all_colls.index(coll), all_algos.index(algo), all_protos.index(proto), @@ -320,7 +332,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n") index1 = 0 for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn + coll, algo, proto, redop, ty, acc, unroll = fn if unroll != "1": continue sym = paste("_", "ncclDevFunc", *fn) if fn[2] == "LL128": @@ -337,7 +349,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n") index2 = 0 for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn + coll, algo, proto, redop, ty, acc, unroll = fn if unroll != "2": continue sym = paste("_", "ncclDevFunc", *fn) if fn[2] == "LL128": @@ -354,7 +366,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") index4 = 0 for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn + coll, algo, proto, redop, ty, acc, unroll = fn if unroll != "4": continue sym = paste("_", "ncclDevFunc", *fn) if fn[2] == "LL128": @@ -469,7 +481,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # Maps to .cu filename which implements this func. The only constraint is that # "coll" is reflected in the name: formally that no two funcs having different # coll's map to the same filename. -def impl_filename(coll, algo, proto, redop, ty, unroll): +def impl_filename(coll, algo, proto, redop, ty, acc, unroll): return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty) # Partition the functions and kernels to the .cu filenames. The partition is @@ -518,6 +530,8 @@ for name in name_to_funcs.keys(): print("-- Generating %s" % os.path.join(gensrc, name)) out = f.write + if coll == "AllReduceWithBias": + coll = "AllReduce" out( '#include "common.h"\n' '#include "{lower_coll}.h"\n' @@ -525,14 +539,14 @@ for name in name_to_funcs.keys(): ) for fn in fns: - (coll, algo, proto, redop, ty, unroll) = fn - sym = paste("_", coll, algo, proto, redop, ty, unroll) + (coll, algo, proto, redop, ty, acc, unroll) = fn + sym = paste("_", coll, algo, proto, redop, ty, acc, unroll) if proto == "LL128": out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") out( - "DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {unroll})\n" + "DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {unroll})\n" .format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty], - algo=(algo or "RING"), proto=(proto or "SIMPLE"), unroll=unroll) + algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, unroll=unroll) ) if proto == "LL128": out("#endif\n") diff --git a/projects/rccl/src/device/msccl_kernel_impl.h b/projects/rccl/src/device/msccl_kernel_impl.h index 5e0c574e27..9eabe54aa4 100644 --- a/projects/rccl/src/device/msccl_kernel_impl.h +++ b/projects/rccl/src/device/msccl_kernel_impl.h @@ -376,7 +376,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct n mscclRunInterpreter, ProtoLL128, fullOps>(comm, algo, work); \ } \ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \ - mscclRunInterpreter, ProtoSimple, fullOps>(comm, algo, work); \ + mscclRunInterpreter, ProtoSimple, fullOps>(comm, algo, work); \ } #define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \ diff --git a/projects/rccl/src/device/onerank.cu b/projects/rccl/src/device/onerank.cu index 25bb2ea442..b183cf62fa 100644 --- a/projects/rccl/src/device/onerank.cu +++ b/projects/rccl/src/device/onerank.cu @@ -49,7 +49,7 @@ namespace { redOpArg = *reinterpret_cast(redOpArg); } } - reduceCopy + reduceCopy (tid, tn, redOpArg, &redOpArg, true, 1, &src, 1, &dst, i1-i0); } } diff --git a/projects/rccl/src/device/primitives.h b/projects/rccl/src/device/primitives.h index 1c56d8c24a..daee34fd1f 100644 --- a/projects/rccl/src/device/primitives.h +++ b/projects/rccl/src/device/primitives.h @@ -55,7 +55,7 @@ * to how that protocol operates with a consistent interface so that our * algorithm code can operate protocol parametrically. */ -template +template struct ProtoSimple { static constexpr int Id = NCCL_PROTO_SIMPLE; static constexpr int SlicePerChunk = SlicePerChunk_1; diff --git a/projects/rccl/src/device/prims_ll.h b/projects/rccl/src/device/prims_ll.h index 0e2dbf300d..2709df2dfc 100644 --- a/projects/rccl/src/device/prims_ll.h +++ b/projects/rccl/src/device/prims_ll.h @@ -10,9 +10,9 @@ #include "npkit/npkit.h" #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { // In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile // This is because of a recv buffer which is allocated to MaxRecv length in send-only cases @@ -437,7 +437,7 @@ private: constexpr int DST = DstBuf != -1 ? 1 : 0; T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; - T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx; + T *accElts = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx; // Always waitSend in case of cleanup nelem = nelem < 0 ? 0 : nelem; diff --git a/projects/rccl/src/device/prims_ll128.h b/projects/rccl/src/device/prims_ll128.h index 7be997fe65..698ab7b032 100644 --- a/projects/rccl/src/device/prims_ll128.h +++ b/projects/rccl/src/device/prims_ll128.h @@ -20,9 +20,9 @@ #endif #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1, Acc=2;; @@ -366,7 +366,7 @@ private: constexpr int DST = DstBuf != -1 ? 1 : 0; T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; - T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx; + T *accPtr = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx; int wireOffset = WireWordPerSlice*warp + 2*wid; const int nwarps = nthreads/WARP_SIZE; nelem = nelem < 0 ? 0 : nelem; diff --git a/projects/rccl/src/device/prims_simple.h b/projects/rccl/src/device/prims_simple.h index 86f7bef84f..9a5f067e04 100644 --- a/projects/rccl/src/device/prims_simple.h +++ b/projects/rccl/src/device/prims_simple.h @@ -23,9 +23,9 @@ enum primsMode { }; template + int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata, int useAcc> class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p, isNetOffload, Metadata + T, RedOp, Fan, Direct, ProtoSimple, P2p, isNetOffload, Metadata > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; @@ -299,7 +299,7 @@ private: } #endif - reduceCopy + reduceCopy (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false, 1, ncclShmem.groups[group].srcs, fan.nsend(), ncclShmem.groups[group].dsts+1, @@ -335,7 +335,7 @@ private: } #endif - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp, Recv, ncclShmem.groups[group].srcs, Dst, ncclShmem.groups[group].dsts, @@ -373,7 +373,7 @@ private: DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) { // this case should only be directCopySend() with registered buffers and send to net peer - reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, @@ -381,7 +381,7 @@ private: 1, ncclShmem.groups[group].dsts, workSize); } else { - reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, @@ -453,19 +453,19 @@ private: srcs[nsrcs] = dsts[0]; nsrcs++; if (MULTISRCS){ - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem); } else { - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem); } } if (COPY){ - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)srcs, 1, (void **)dsts, nelem); if (MULTISRCS) { for (int i = 1; i < nsrcs; i++){ - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)&srcs[i], 1, (void **)&dsts[i], nelem); } } @@ -617,7 +617,7 @@ private: void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset; ssize_t realPeerSize = min(realSize, totalElem-pOffset); if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) { - reduceCopy(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize); + reduceCopy(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize); // Mark for threadfence at the end fenceNeeded |= true; } @@ -637,7 +637,7 @@ private: void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset; ssize_t realPeerSize = min(realSize, totalElem-pOffset); if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0; - if (realPeerSize > 0) reduceCopy(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); + if (realPeerSize > 0) reduceCopy(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); } } } @@ -1190,7 +1190,7 @@ public: int workSize = ncclShmem.aborted ? 0 : nelem; - reduceCopy + reduceCopy (tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false, nSrcs, srcs, 1, ncclShmem.groups[group].dsts, workSize); @@ -1284,7 +1284,7 @@ public: int workSize = ncclShmem.aborted ? 0 : nelem; - reduceCopy + reduceCopy (tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false, 1, ncclShmem.groups[group].srcs, nDsts, dsts, workSize); diff --git a/projects/rccl/src/device/reduce_scatter.h b/projects/rccl/src/device/reduce_scatter.h index ac4c33f079..82b22938b7 100644 --- a/projects/rccl/src/device/reduce_scatter.h +++ b/projects/rccl/src/device/reduce_scatter.h @@ -258,7 +258,7 @@ struct RunWorkCollregUsed) { if (tid < tidEndScatter) { // Scatter - using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -270,7 +270,7 @@ struct RunWorkColl Coverity think prims.index can be greater than 1 } else if (tid < tidEndReduce) { // Reduce through NVLS - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>; Primitives, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, work->recvbuff, work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0); @@ -283,7 +283,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>; Primitives, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, nvls->up, nvls->up, NULL, NULL, work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); @@ -295,7 +295,7 @@ struct RunWorkColl; + using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>; Primitives, /*Direct=*/1, Proto, 0> prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, &nvls->down, NULL, work->recvbuff, work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, work); diff --git a/projects/rccl/src/device/sendrecv.h b/projects/rccl/src/device/sendrecv.h index 723ee68a41..7bf63e1f49 100644 --- a/projects/rccl/src/device/sendrecv.h +++ b/projects/rccl/src/device/sendrecv.h @@ -247,10 +247,10 @@ struct RunWorkBatch + reduceCopy (subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes); #else - reduceCopy + reduceCopy (subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes); #endif } else if (isSend) { @@ -258,9 +258,9 @@ struct RunWorkBatch(subtid, subtn, group, work); } else { #if defined(__gfx90a__) - runSend>(subtid, subtn, group, work); + runSend>(subtid, subtn, group, work); #elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__) - runSend>(subtid, subtn, group, work); + runSend>(subtid, subtn, group, work); #else runSend>(subtid, subtn, group, work); #endif @@ -270,13 +270,13 @@ struct RunWorkBatch(subtid, subtn, group, work); } else { #if defined(__gfx90a__) - runRecv>(subtid, subtn, group, work); + runRecv>(subtid, subtn, group, work); #elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__) - runRecv>(subtid, subtn, group, work); + runRecv>(subtid, subtn, group, work); #else runRecv>(subtid, subtn, group, work); #endif } } } -}; \ No newline at end of file +}; diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index efc1087c36..a1e92a413a 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -755,6 +755,6 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) return ncclDevFuncRowToId[row]; } -inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; } +inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - AR_WITH_BIAS_FUNC_COUNTS - NCCL_NUM_ONERANK - 1]; } #endif diff --git a/projects/rccl/src/include/nccl_common.h b/projects/rccl/src/include/nccl_common.h index 6bab98d954..3512072e87 100644 --- a/projects/rccl/src/include/nccl_common.h +++ b/projects/rccl/src/include/nccl_common.h @@ -40,7 +40,8 @@ typedef enum { typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); #define NCCL_NUM_ONERANK 12 -#define FUNC_INDEX_TOTAL 821 + NCCL_NUM_ONERANK +#define AR_WITH_BIAS_FUNC_COUNTS 324 +#define FUNC_INDEX_TOTAL 821 + AR_WITH_BIAS_FUNC_COUNTS + NCCL_NUM_ONERANK #define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now typedef enum { @@ -49,11 +50,12 @@ typedef enum { ncclFuncAllGather = 2, ncclFuncReduceScatter = 3, ncclFuncAllReduce = 4, - ncclFuncSendRecv = 5, - ncclFuncSend = 6, - ncclFuncRecv = 7, - ncclFuncAllToAllPivot = 8, - ncclNumFuncs = 9 + ncclFuncAllReduceWithBias = 5, + ncclFuncSendRecv = 6, + ncclFuncSend = 7, + ncclFuncRecv = 8, + ncclFuncAllToAllPivot = 9, + ncclNumFuncs = 10 } ncclFunc_t; #define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*