From 9d72be7b2fc89963e97c31c6cf89140a2c20e1f5 Mon Sep 17 00:00:00 2001 From: Nilesh M Negi Date: Wed, 11 Jun 2025 00:07:59 -0500 Subject: [PATCH] [DEVICE] Adding ability to choose unroll factor at runtime (#1734) * Adding runtime unroll factor selection via RCCL_UNROLL_FACTOR * [BUILD] Add support for user-defined UNROLL for debugging * Update CHANGELOG.md * Fix COLLTRACE errors in CI * Add debug statements for unroll and resolve warnings * Incorporate UNROLL into ONLY_FUNCS for debugging --------- Signed-off-by: nileshnegi Co-authored-by: gilbertlee-amd <44450918+gilbertlee-amd@users.noreply.github.com> Co-authored-by: Jeffrey Novotny --- CHANGELOG.md | 5 +- CMakeLists.txt | 6 +- src/device/common.cu | 25 ---- src/device/common.h | 27 +--- src/device/generate.py | 276 +++++++++++++++----------------------- src/device/primitives.h | 3 + src/enqueue.cc | 59 ++++---- src/include/nccl_common.h | 5 - src/init.cc | 36 +++-- 9 files changed, 184 insertions(+), 258 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e6c9fa02e..6cc66778ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,9 +14,10 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: ### Added * Added new GPU target `gfx950`. -* Added support for `unroll=1` in device-code generation to improve performance, -* Set a default of 112 channels for a single node with `8 * gfx950`, +* Added support for `unroll=1` in device-code generation to improve performance. +* Set a default of 112 channels for a single node with `8 * gfx950`. * Enabled LL128 protocol on `gfx950`. +* Adding ability to choose unroll factor at runtime via `RCCL_UNROLL_FACTOR`. This can be set at runtime to 1, 2, or 4. This change currently increases compilation and linking time because it triples the number of kernels generated. * Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`. * Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. Channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS', or 'NCCL_MAX_NCHANNELS` are set. * Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges. diff --git a/CMakeLists.txt b/CMakeLists.txt index a93c1c0c19..6e9e993823 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -688,9 +688,13 @@ endif() set(GEN_DIR "${HIPIFY_DIR}/gensrc") +if(ONLY_FUNCS) + message(WARNING "Using ONLY_FUNCS = ${ONLY_FUNCS}. Not meant for release builds.") +endif() + # Execute the python script to generate required files execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${BUILD_LOCAL_GPU_TARGET_ONLY} ${ONLY_FUNCS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} RESULT_VARIABLE gen_py_result ERROR_VARIABLE gen_py_error diff --git a/src/device/common.cu b/src/device/common.cu index 36d396fbb8..40aeb5cd76 100644 --- a/src/device/common.cu +++ b/src/device/common.cu @@ -13,31 +13,6 @@ __shared__ ncclShmemData ncclShmem; __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; #endif -struct RunWorkNop { - __device__ void run() {} -}; - -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args); -} -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args); -} -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args); -} -#ifdef ENABLE_COLLTRACE -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args); -} -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args); -} -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args); -} -#endif - #ifdef USE_INDIRECT_FUNCTION_CALL __device__ void ncclDevFunc_Nop(); #else diff --git a/src/device/common.h b/src/device/common.h index 472b923fd1..a8e66ec7b1 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -11,8 +11,8 @@ #include "collectives.h" #include "device.h" #include "op128.h" -#include "reduce_kernel.h" #include "device_table.h" +#include "reduce_kernel.h" #include "network/unpack/unpack_defs.h" #define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree @@ -572,21 +572,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) { SpecializedRunWorkBatch().run(); } else { -#ifdef USE_INDIRECT_FUNCTION_CALL - if (COLL_UNROLL == 1) - ncclDevFuncTable_1[ncclShmem.funcId](); - else if (COLL_UNROLL == 2) - ncclDevFuncTable_2[ncclShmem.funcId](); - else - ncclDevFuncTable_4[ncclShmem.funcId](); -#else - if (COLL_UNROLL == 1) - NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId); - else if (COLL_UNROLL == 2) - NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId); - else - NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId); -#endif + NCCL_CALL_FUNCTIONS(ncclShmem.funcId); } if (ncclShmem.nextBatchIx == -1) break; @@ -628,15 +614,6 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a #endif } -__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -#ifdef ENABLE_COLLTRACE -__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -#endif - #define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \ __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {} diff --git a/src/device/generate.py b/src/device/generate.py index 58096432f9..95604d5a0d 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -32,7 +32,7 @@ else: # developing device code. The regex supports non-space containing globs '*', # and union 'a|b'. The string representing the function has the form: # -# +# # # The possible values for redop, type, algo, proto can be found in the all_ # lists at the top of this file. @@ -45,23 +45,28 @@ else: # # Only AllReduce and Reduce # make ONLY_FUNCS="AllReduce|Reduce" # +# # Only AllGather with unroll=4 +# make ONLY_FUNCS="AllGather * * * * 4" +# # # Only non-reductions: # make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" # # # Only AllReduce Sum int32_t (but all algos, protos) # make ONLY_FUNCS="AllReduce * * Sum int32_t" # -# # Only AllReduce RING Max float (but all protos) +# # Only AllReduce RING Max float (but all protos and unrolls) # make ONLY_FUNCS="AllReduce RING * Max float" # -# # AllReduce TREE LL128 Prod rccl_bfloat16 -# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" +# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1 +# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16 1" # -# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) -# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" +# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter) +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float *" # --- or --- -# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" -# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t" +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float *" +# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|AllGather RING LL/SIMPLE Sum int8_t 1/2/4|AllToAllPivot RING SIMPLE Sum int8_t 1/2/4|Broadcast RING LL/SIMPLE Sum int8_t 1/2/4|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|SendRecv RING SIMPLE Sum int8_t 1/2/4" +# +# # ONLY_FUNCS can be used together for debugging # Paste all non-None arguments together with `sep`. def paste(sep, *args): @@ -70,9 +75,8 @@ def paste(sep, *args): is_ifc = 1 if sys.argv[2] == "ON" else 0 is_colltrace = 1 if sys.argv[3] == "ON" else 0 is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0 -is_local_arch_only = 1 if sys.argv[5] == "ON" else 0 -func_pattern = sys.argv[6:7] +func_pattern = sys.argv[5:] if func_pattern and func_pattern[0]: func_pattern = func_pattern[0] else: @@ -133,49 +137,13 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower} ################################################################################ -def calc_unroll_for_local_arch(): - if not is_local_arch_only: - return - - rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo" - - res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True) - rocminfo_output = res.stdout - - # Parse rocminfo binary output - gfx_targets = {} - curr_name = None - for line in rocminfo_output.splitlines(): - line = line.strip() - - if line.startswith("Name:"): - name = line.split(':')[-1].strip() - if "gfx" in name: - curr_name = name - if line.startswith("Compute Unit:") and curr_name: - cu_count = int(line.split(':')[-1].strip()) - gfx_targets[(curr_name, cu_count)] = None - curr_name = None - - # We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts - # Use (gfx_name, cu_count) as key for dictionary and convert it to list here - gfx_targets = list(gfx_targets.keys()) - - # Homogeneous system is required to build for only 1 varient of unroll factor - if len(gfx_targets) == 1: - gfx_name, cu_count = gfx_targets[0] - if "gfx950" == gfx_name: - return 1 - elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80): - return 2 - else: - return 4 +seen_unroll = [] # Helper function to check if the conditions for the collective is being met -def func_validate(coll, algo, proto, redop, ty): +def func_validate(coll, algo, proto, redop, ty, unroll): if redop == "SumPostDiv" and ty[0] not in ("i","u"): return False - if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]: + if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll: return False return True @@ -193,8 +161,8 @@ def func_filter(function_params, current_idx, item_list=None): if current_element == "*": if current_idx == 0: raise ValueError("Error: Paramter 'COLL' can not be type all '*'.") - - # all_params list must be in the same order as function_params --> + + # all_params list must be in the same order as function_params --> # Get the current list from all_params current_list = all_params[current_idx] @@ -210,12 +178,12 @@ def func_filter(function_params, current_idx, item_list=None): # Check if the current element is recognized elements = current_element.split("/") current_param = all_params[current_idx] - + # Iterate over the elements in the elements list for item in elements: if item not in current_param: raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.") - + for item in elements: item_list.append(item) yield from func_filter(function_params, current_idx+1, item_list) @@ -225,7 +193,9 @@ def func_filter(function_params, current_idx, item_list=None): else: coll, algo, proto, redop, ty, unroll = item_list - if func_validate(coll, algo, proto, redop, ty): + if func_validate(coll, algo, proto, redop, ty, unroll): + if not unroll in seen_unroll: + seen_unroll.append(unroll) yield(coll, algo, proto, redop, ty, unroll) # Parse ONLY_FUNCS input and feed it to func_filter @@ -247,10 +217,6 @@ def parse_input(func_pattern): # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). def equivalent_primary(coll, algo, proto, redop, ty, unroll): - # if local arch only, we only need to build for 1 varient of coll_unroll. - # map the other varient of coll_unroll to this one. - if coll_unroll: - unroll = str(coll_unroll) if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": @@ -268,13 +234,13 @@ def enumerate_func_rows(): for proto in all_protos: for redop in all_redops: for ty in all_tys: - if func_validate(coll, algo, proto, redop, ty): + if func_validate(coll, algo, proto, redop, ty, unroll): yield (coll, algo, proto, redop, ty, unroll) -# Sort the hashmap based on custom key +# Sort the hashmap based on custom key def custom_sort_key(fn): coll, algo, proto, redop, ty, unroll = fn - + return ( all_unroll.index(unroll), all_colls.index(coll), @@ -286,8 +252,6 @@ def custom_sort_key(fn): ################################################################################ -coll_unroll = calc_unroll_for_local_arch() - # Corresponds to ncclDevFuncRowToId[] func_rows = [fn for fn in enumerate_func_rows()] @@ -304,6 +268,8 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: print("-- Generating %s" % os.path.join(gensrc, "device_table.h")) out = f.write + out("#include \"common.h\"\n\n") + if is_ifc: func_declaration = "__device__ void" else: func_declaration = "__device__ __attribute__((noinline)) void" @@ -320,113 +286,86 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: out("\n") out("typedef void(*ncclDevFuncPtr_t)();\n\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n") - index1 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn - if unroll != "1": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index1, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index1, sym)) - index1 += 1 - out("nullptr};\n") - out("\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n") - index2 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn - if unroll != "2": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index2, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index2, sym)) - index2 += 1 - out("nullptr};\n") - out("\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") - index4 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, unroll = fn - if unroll != "4": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index4, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index4, sym)) - index4 += 1 - out("nullptr};\n") - out("\n") - + + # Generate function tables per unroll factor + tableIdx = 0 + for curr_unroll in seen_unroll: + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % curr_unroll) + tableIdx = 0 + for fn in primary_funcs: + coll, algo, proto, redop, ty, unroll = fn + if curr_unroll != unroll: continue + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (tableIdx, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (tableIdx, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (tableIdx, sym)) + tableIdx += 1 + out("nullptr};\n") + out("\n") + + # Construct indirection function workaround if not is_ifc: - out("template\n" - "struct Caller1 {\n" + out("template\n" + "struct Caller {\n" " static __forceinline__ __device__ __host__\n" - " void call1(unsigned short funcIndex) noexcept\n" + " void call(unsigned short funcIndex) noexcept\n" " {\n" " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller1::call1(funcIndex) : Caller1::call1(funcIndex);\n" + " return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex);\n" " }\n" "};\n" - "\n" - "template\n" - "struct Caller1{\n" - " static __forceinline__ __device__ __host__\n" - " void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n") - out(f" Caller1<0, {index1}>::call1(funcIndex);\n") - out("}\n\n") - out("template\n" - "struct Caller2 {\n" - " static __forceinline__ __device__ __host__\n" - " void call2(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller2::call2(funcIndex) : Caller2::call2(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller2{\n" - " static __forceinline__ __device__ __host__\n" - " void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n") - out(f" Caller2<0, {index2}>::call2(funcIndex);\n") - out("}\n\n") - out("template\n" - "struct Caller4 {\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller4{\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") - out(f" Caller4<0, {index4}>::call4(funcIndex);\n") - out("}\n\n") + "\n") + + for curr_unroll in seen_unroll: + out("template\n") + out("struct Caller<%s, f, f + 1>{\n" % curr_unroll) + out(" static __forceinline__ __device__ __host__\n"); + out(" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable_%s[f](); }\n" % curr_unroll) + out("};\n") + + out("\n") + # Create NCCL_CALL_FUNCTION helper function that will call the appropriate device function + out("template \n" + "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n") + if is_ifc: + for curr_unroll in seen_unroll: + out(" if (unroll == %s) { ncclDevFuncTable_%s[funcIndex]();\n" % (curr_unroll, curr_unroll)) + else: + out(f" Caller::call(funcIndex);\n") + out("}\n\n") + + # Create RCCL + out("template\n"); + out("__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args);\n\n"); + + out("struct RunWorkNop {\n"); + out(" __device__ void run() {}\n"); + out("};\n\n"); + + out("template \n" + "__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void rcclGenericKernel(ncclDevKernelArgs4K const args4K) {\n" + " ncclKernelMain<-1, RunWorkNop, COLLTRACE, UNROLL>(&args4K.args);\n" + "}\n\n") + + out("struct rcclKernelItem {\n"); + out(" void* funcPtr;\n"); + out(" int unroll;\n"); + out("};\n\n"); + + out("/* This table contains all the __global__ functions that were compiled */\n"); + out("static struct rcclKernelItem rcclKernelTable[] = {\n") + for unroll in seen_unroll: + out(" {(void*)&(rcclGenericKernel<%s, false>), %s},\n" % (unroll, unroll)) + out("#ifdef ENABLE_COLLTRACE\n") + for unroll in seen_unroll: + out(" {(void*)&(rcclGenericKernel<%s, true>), %s},\n" % (unroll, unroll)) + out("#endif\n"); + out("};\n\n"); # Generate /device_table.cpp if is_colltrace: @@ -436,7 +375,7 @@ if is_colltrace: out = f.write out('#include "nccl_common.h"\n#include "device.h"\n') out("\n") - + seen_fns = set() out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") for fn in primary_funcs: @@ -459,10 +398,13 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # The mapping from function rows to valid primary function ids. out("extern int const ncclDevFuncRowToId[] = {\n") index = 0 - for fn in func_rows[:len(func_rows)//3]: + offset = len(func_rows)//len(all_unroll) + start = all_unroll.index(seen_unroll[0]) * offset + end = start + offset + for fn in func_rows[start:end]: fn_id, comment = -1, "" if fn is not None: - fn_id = primary_to_index[equivalent_primary(*fn)] + fn_id = primary_to_index[equivalent_primary(*fn)] % offset if primary_to_index[equivalent_primary(*fn)] != -1 else -1 comment = " // " + paste(" ", *fn[:-1]) out("/*%4d*/ %d,%s\n" % (index, fn_id, comment)) index += 1 diff --git a/src/device/primitives.h b/src/device/primitives.h index c0536f1cf4..071a259d63 100644 --- a/src/device/primitives.h +++ b/src/device/primitives.h @@ -48,7 +48,10 @@ } while (0) #define barrier_by_group() barrier_by_group_common(__threadfence()) + +#if defined(__gfx942__) || defined(__gfx950__) #define barrier_by_group_block() barrier_by_group_common(__threadfence_block()) +#endif /* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128 * We use these as template args to the Primtiives class instead of integral diff --git a/src/enqueue.cc b/src/enqueue.cc index b4949e5206..d1b33ae192 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -28,30 +28,31 @@ using namespace rccl; -struct ncclKernelMatch { - void* kernelFn; - bool specialized; -}; +/* [RCCL] Determine which GPU kernel to execute */ +void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL) +{ + // At this time, unroll factor is controlled only by passed in unroll argument + // After more investigation, this may be further tuned by the actual task being processed #ifdef ENABLE_COLLTRACE -#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0)) -static ncclKernelMatch const ncclKerns[6] = { - {(void *)ncclDevKernel_Generic_1, true}, - {(void *)ncclDevKernel_Generic_2, true}, - {(void *)ncclDevKernel_Generic_4, true}, - {(void *)ncclDevKernelDebug_Generic_1, true}, - {(void *)ncclDevKernelDebug_Generic_2, true}, - {(void *)ncclDevKernelDebug_Generic_4, true} -}; + int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]) / 2; + int firstKernel = useCollTrace ? numKernels : 0; #else -#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll) -static ncclKernelMatch const ncclKerns[3] = { - {(void*)ncclDevKernel_Generic_1, true}, - {(void*)ncclDevKernel_Generic_2, true}, - {(void*)ncclDevKernel_Generic_4, true} -}; + int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]); + int firstKernel = 0; #endif + // Check if the requested unroll exists + for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) { + if (rcclKernelTable[firstKernel + kernelIdx].unroll == unroll) { + return rcclKernelTable[firstKernel + kernelIdx].funcPtr; + } + } + // Fall back to default unroll + WARN("Requested RCCL_UNROLL_FACTOR: %d does not exist in `rcclKernelTable`. Falling back to default unroll: %d", unroll, rcclKernelTable[firstKernel].unroll); + return rcclKernelTable[firstKernel].funcPtr; +} + static int rcclProtoGrainSize(int proto, ncclComm *comm){ switch (proto) { case NCCL_PROTO_LL: return 16; @@ -81,7 +82,7 @@ NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); // Returns maximum kernel stack size of all CUDA kernels ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize) { - constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]); + constexpr int KernelCount = sizeof(rcclKernelTable)/sizeof(rcclKernelTable[0]); ncclResult_t result = ncclSuccess; int print = 0; @@ -95,7 +96,7 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* ma int ncclMaxSharedMem = rcclShmemDynamicSize(cudaArch, WarpSize); for (int k=0; k < KernelCount; k++) { - void* fn = ncclKerns[k].kernelFn; + void* fn = rcclKernelTable[k].funcPtr; cudaFuncAttributes attr = {0}; if (fn == nullptr) continue; @@ -783,8 +784,12 @@ static ncclResult_t scheduleCollTasksToPlan( //plan->channelMask.masks[channelId/64] |= (2ull<channelHi) - (1ull<channelLo); plan->threadPerBlock = std::max(plan->threadPerBlock, 192 /* 3*WARP_SIZE */); if (!plan->kernelSpecialized) { - plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn; - plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized; +#ifdef ENABLE_COLLTRACE + plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled); +#else + plan->kernelFn = rcclGetKernelIndex(comm->unroll, false); +#endif + plan->kernelSpecialized = true; } if (comm->rank == 0) { @@ -1084,8 +1089,12 @@ static ncclResult_t scheduleP2pTasksToPlan( plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MAX_NTHREADS); if (!plan->kernelSpecialized) { - plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn; - plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized; +#ifdef ENABLE_COLLTRACE + plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled); +#else + plan->kernelFn = rcclGetKernelIndex(comm->unroll, false); +#endif + plan->kernelSpecialized = true; } // Compute how much to split operations diff --git a/src/include/nccl_common.h b/src/include/nccl_common.h index 6a9d18d3b5..3ba71c98ad 100644 --- a/src/include/nccl_common.h +++ b/src/include/nccl_common.h @@ -74,10 +74,5 @@ typedef enum { #define NCCL_ALGO_PROTO_IGNORE -1.0 -#define NCCL_NUM_UNROLLS 3 // 1/2/4 -#define NCCL_UNROLL_1 0 -#define NCCL_UNROLL_2 1 -#define NCCL_UNROLL_4 2 - #define NCCL_NUM_FLOATS 6 // half/float/double/rccl_bfloat16/rccl_float8/rccl_bfloat8 #endif diff --git a/src/init.cc b/src/init.cc index fdb387a254..8f3884fa1b 100644 --- a/src/init.cc +++ b/src/init.cc @@ -98,15 +98,34 @@ static uint64_t hashUniqueId(ncclUniqueId const &id) { return h; } +//RCCL runtime param to set Unroll Factor +RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0); + ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { hipDeviceProp_t devProp; CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); - if(IsArchMatch(devProp.gcnArchName, "gfx950")) - comm->unroll = NCCL_UNROLL_1; - else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))) - comm->unroll = NCCL_UNROLL_2; - else - comm->unroll = NCCL_UNROLL_4; + + //If RCCL runtime param is set, it will override defaults + if (rcclParamUnrollFactor() != 0) { + comm->unroll = rcclParamUnrollFactor(); + INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll); + } + else { + if (IsArchMatch(devProp.gcnArchName, "gfx950")) { + //on gfx950, use unroll=1 for single-node and unroll=2 for multi-node + if (comm->nNodes == 1) + comm->unroll = 1; + else + comm->unroll = 2; + } + else if((IsArchMatch(devProp.gcnArchName, "gfx908")) || + (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) + //on MI300X and gfx908, use unroll=2 + comm->unroll = 2; + else + comm->unroll = 4; + INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll); + } return ncclSuccess; } @@ -617,8 +636,6 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in // RCCL: create persistent stream for calloc CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking)); - // RCCL: determine and set unroll factor for comm - NCCLCHECK(commSetUnrollFactor(comm)); comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false; @@ -1945,6 +1962,9 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail); + // RCCL: determine and set unroll factor for comm + NCCLCHECK(commSetUnrollFactor(comm)); + #ifdef ENABLE_MSCCLPP if (job->parent) { if (job->parent->mscclppCompatible) {