[DEVICE] Adding ability to choose unroll factor at runtime (#1734)
* Adding runtime unroll factor selection via RCCL_UNROLL_FACTOR * [BUILD] Add support for user-defined UNROLL for debugging * Update CHANGELOG.md * Fix COLLTRACE errors in CI * Add debug statements for unroll and resolve warnings * Incorporate UNROLL into ONLY_FUNCS for debugging --------- Signed-off-by: nileshnegi <Nilesh.Negi@amd.com> Co-authored-by: gilbertlee-amd <44450918+gilbertlee-amd@users.noreply.github.com> Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
682ed36fe6
Коммит
9d72be7b2f
+3
-2
@@ -14,9 +14,10 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
### Added
|
||||
|
||||
* Added new GPU target `gfx950`.
|
||||
* Added support for `unroll=1` in device-code generation to improve performance,
|
||||
* Set a default of 112 channels for a single node with `8 * gfx950`,
|
||||
* Added support for `unroll=1` in device-code generation to improve performance.
|
||||
* Set a default of 112 channels for a single node with `8 * gfx950`.
|
||||
* Enabled LL128 protocol on `gfx950`.
|
||||
* Added the ability to choose the unroll factor at runtime via `RCCL_UNROLL_FACTOR`. This can be set at runtime to 1, 2, or 4. This change currently increases compilation and linking time because it triples the number of kernels generated.
|
||||
* Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`.
|
||||
* Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. Channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS`, or `NCCL_MAX_NCHANNELS` are set.
|
||||
* Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
|
||||
|
||||
@@ -688,9 +688,13 @@ endif()
|
||||
|
||||
set(GEN_DIR "${HIPIFY_DIR}/gensrc")
|
||||
|
||||
if(ONLY_FUNCS)
|
||||
message(WARNING "Using ONLY_FUNCS = ${ONLY_FUNCS}. Not meant for release builds.")
|
||||
endif()
|
||||
|
||||
# Execute the python script to generate required files
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${BUILD_LOCAL_GPU_TARGET_ONLY} ${ONLY_FUNCS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS}
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||
RESULT_VARIABLE gen_py_result
|
||||
ERROR_VARIABLE gen_py_error
|
||||
|
||||
@@ -13,31 +13,6 @@ __shared__ ncclShmemData ncclShmem;
|
||||
__shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
|
||||
#endif
|
||||
|
||||
struct RunWorkNop {
|
||||
__device__ void run() {}
|
||||
};
|
||||
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args);
|
||||
}
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
__device__ void ncclDevFunc_Nop();
|
||||
#else
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include "collectives.h"
|
||||
#include "device.h"
|
||||
#include "op128.h"
|
||||
#include "reduce_kernel.h"
|
||||
#include "device_table.h"
|
||||
#include "reduce_kernel.h"
|
||||
#include "network/unpack/unpack_defs.h"
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
|
||||
@@ -572,21 +572,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
|
||||
if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) {
|
||||
SpecializedRunWorkBatch().run();
|
||||
} else {
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
if (COLL_UNROLL == 1)
|
||||
ncclDevFuncTable_1[ncclShmem.funcId]();
|
||||
else if (COLL_UNROLL == 2)
|
||||
ncclDevFuncTable_2[ncclShmem.funcId]();
|
||||
else
|
||||
ncclDevFuncTable_4[ncclShmem.funcId]();
|
||||
#else
|
||||
if (COLL_UNROLL == 1)
|
||||
NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId);
|
||||
else if (COLL_UNROLL == 2)
|
||||
NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId);
|
||||
else
|
||||
NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId);
|
||||
#endif
|
||||
NCCL_CALL_FUNCTIONS<COLL_UNROLL>(ncclShmem.funcId);
|
||||
}
|
||||
|
||||
if (ncclShmem.nextBatchIx == -1) break;
|
||||
@@ -628,15 +614,6 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
|
||||
#endif
|
||||
}
|
||||
|
||||
__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
#endif
|
||||
|
||||
#define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \
|
||||
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
|
||||
|
||||
|
||||
+109
-167
@@ -32,7 +32,7 @@ else:
|
||||
# developing device code. The regex supports non-space containing globs '*',
|
||||
# and union 'a|b'. The string representing the function has the form:
|
||||
#
|
||||
# <coll> <algo> <proto> <redop> <type>
|
||||
# <coll> <algo> <proto> <redop> <type> <unroll>
|
||||
#
|
||||
# The possible values for redop, type, algo, proto can be found in the all_<foo>
|
||||
# lists at the top of this file.
|
||||
@@ -45,23 +45,28 @@ else:
|
||||
# # Only AllReduce and Reduce
|
||||
# make ONLY_FUNCS="AllReduce|Reduce"
|
||||
#
|
||||
# # Only AllGather with unroll=4
|
||||
# make ONLY_FUNCS="AllGather * * * * 4"
|
||||
#
|
||||
# # Only non-reductions:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum int32_t (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
|
||||
#
|
||||
# # Only AllReduce RING Max float (but all protos)
|
||||
# # Only AllReduce RING Max float (but all protos and unrolls)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max float"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16 1"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float *"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float *"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|AllGather RING LL/SIMPLE Sum int8_t 1/2/4|AllToAllPivot RING SIMPLE Sum int8_t 1/2/4|Broadcast RING LL/SIMPLE Sum int8_t 1/2/4|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|SendRecv RING SIMPLE Sum int8_t 1/2/4"
|
||||
#
|
||||
# # ONLY_FUNCS can be used together for debugging
|
||||
|
||||
# Paste all non-None arguments together with `sep`.
|
||||
def paste(sep, *args):
|
||||
@@ -70,9 +75,8 @@ def paste(sep, *args):
|
||||
is_ifc = 1 if sys.argv[2] == "ON" else 0
|
||||
is_colltrace = 1 if sys.argv[3] == "ON" else 0
|
||||
is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0
|
||||
is_local_arch_only = 1 if sys.argv[5] == "ON" else 0
|
||||
|
||||
func_pattern = sys.argv[6:7]
|
||||
func_pattern = sys.argv[5:]
|
||||
if func_pattern and func_pattern[0]:
|
||||
func_pattern = func_pattern[0]
|
||||
else:
|
||||
@@ -133,49 +137,13 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
||||
|
||||
################################################################################
|
||||
|
||||
def calc_unroll_for_local_arch():
|
||||
if not is_local_arch_only:
|
||||
return
|
||||
|
||||
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
|
||||
|
||||
res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
rocminfo_output = res.stdout
|
||||
|
||||
# Parse rocminfo binary output
|
||||
gfx_targets = {}
|
||||
curr_name = None
|
||||
for line in rocminfo_output.splitlines():
|
||||
line = line.strip()
|
||||
|
||||
if line.startswith("Name:"):
|
||||
name = line.split(':')[-1].strip()
|
||||
if "gfx" in name:
|
||||
curr_name = name
|
||||
if line.startswith("Compute Unit:") and curr_name:
|
||||
cu_count = int(line.split(':')[-1].strip())
|
||||
gfx_targets[(curr_name, cu_count)] = None
|
||||
curr_name = None
|
||||
|
||||
# We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts
|
||||
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
|
||||
gfx_targets = list(gfx_targets.keys())
|
||||
|
||||
# A homogeneous system is required to build for only 1 variant of the unroll factor
|
||||
if len(gfx_targets) == 1:
|
||||
gfx_name, cu_count = gfx_targets[0]
|
||||
if "gfx950" == gfx_name:
|
||||
return 1
|
||||
elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
|
||||
return 2
|
||||
else:
|
||||
return 4
|
||||
seen_unroll = []
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty):
|
||||
def func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -193,8 +161,8 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
if current_element == "*":
|
||||
if current_idx == 0:
|
||||
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type> <unroll>
|
||||
# Get the current list from all_params
|
||||
current_list = all_params[current_idx]
|
||||
|
||||
@@ -210,12 +178,12 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
# Check if the current element is recognized
|
||||
elements = current_element.split("/")
|
||||
current_param = all_params[current_idx]
|
||||
|
||||
|
||||
# Iterate over the elements in the elements list
|
||||
for item in elements:
|
||||
if item not in current_param:
|
||||
raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.")
|
||||
|
||||
|
||||
for item in elements:
|
||||
item_list.append(item)
|
||||
yield from func_filter(function_params, current_idx+1, item_list)
|
||||
@@ -225,7 +193,9 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
else:
|
||||
coll, algo, proto, redop, ty, unroll = item_list
|
||||
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
if not unroll in seen_unroll:
|
||||
seen_unroll.append(unroll)
|
||||
yield(coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
@@ -247,10 +217,6 @@ def parse_input(func_pattern):
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
||||
# if local arch only, we only need to build for 1 variant of coll_unroll.
|
||||
# map the other variants of coll_unroll to this one.
|
||||
if coll_unroll:
|
||||
unroll = str(coll_unroll)
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
@@ -268,13 +234,13 @@ def enumerate_func_rows():
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
yield (coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty> <unroll>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
|
||||
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
all_colls.index(coll),
|
||||
@@ -286,8 +252,6 @@ def custom_sort_key(fn):
|
||||
|
||||
################################################################################
|
||||
|
||||
coll_unroll = calc_unroll_for_local_arch()
|
||||
|
||||
# Corresponds to ncclDevFuncRowToId[]
|
||||
func_rows = [fn for fn in enumerate_func_rows()]
|
||||
|
||||
@@ -304,6 +268,8 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "device_table.h"))
|
||||
out = f.write
|
||||
|
||||
out("#include \"common.h\"\n\n")
|
||||
|
||||
if is_ifc: func_declaration = "__device__ void"
|
||||
else: func_declaration = "__device__ __attribute__((noinline)) void"
|
||||
|
||||
@@ -320,113 +286,86 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("\n")
|
||||
|
||||
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
|
||||
index1 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "1": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index1, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index1, sym))
|
||||
index1 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
|
||||
index2 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index2, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index2, sym))
|
||||
index2 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index4, sym))
|
||||
index4 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
|
||||
# Generate function tables per unroll factor
|
||||
tableIdx = 0
|
||||
for curr_unroll in seen_unroll:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % curr_unroll)
|
||||
tableIdx = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if curr_unroll != unroll: continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (tableIdx, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (tableIdx, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (tableIdx, sym))
|
||||
tableIdx += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
# Construct indirection function workaround
|
||||
if not is_ifc:
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller1 {\n"
|
||||
out("template<int unroll, unsigned short f, unsigned short l>\n"
|
||||
"struct Caller {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call1(unsigned short funcIndex) noexcept\n"
|
||||
" void call(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller1<f, m>::call1(funcIndex) : Caller1<m, l>::call1(funcIndex);\n"
|
||||
" return (funcIndex < m) ? Caller<unroll, f, m>::call(funcIndex) : Caller<unroll, m, l>::call(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller1<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller1<0, {index1}>::call1(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller2 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller2<f, m>::call2(funcIndex) : Caller2<m, l>::call2(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller2<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller2<0, {index2}>::call2(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller4 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller4<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
"\n")
|
||||
|
||||
for curr_unroll in seen_unroll:
|
||||
out("template<unsigned short f>\n")
|
||||
out("struct Caller<%s, f, f + 1>{\n" % curr_unroll)
|
||||
out(" static __forceinline__ __device__ __host__\n");
|
||||
out(" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable_%s[f](); }\n" % curr_unroll)
|
||||
out("};\n")
|
||||
|
||||
out("\n")
|
||||
# Create NCCL_CALL_FUNCTIONS helper function that will call the appropriate device function
|
||||
out("template <int unroll>\n"
|
||||
"__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
|
||||
if is_ifc:
|
||||
for curr_unroll in seen_unroll:
|
||||
out(" if (unroll == %s) { ncclDevFuncTable_%s[funcIndex]();\n" % (curr_unroll, curr_unroll))
|
||||
else:
|
||||
out(f" Caller<unroll, 0, {tableIdx}>::call(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Create RCCL generic kernel entry points and the table of compiled kernels
|
||||
out("template<int SpecializedFnId, typename SpecializedRunWorkBatch, bool COLLTRACE, int COLL_UNROLL>\n");
|
||||
out("__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args);\n\n");
|
||||
|
||||
out("struct RunWorkNop {\n");
|
||||
out(" __device__ void run() {}\n");
|
||||
out("};\n\n");
|
||||
|
||||
out("template <int UNROLL, bool COLLTRACE>\n"
|
||||
"__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void rcclGenericKernel(ncclDevKernelArgs4K const args4K) {\n"
|
||||
" ncclKernelMain<-1, RunWorkNop, COLLTRACE, UNROLL>(&args4K.args);\n"
|
||||
"}\n\n")
|
||||
|
||||
out("struct rcclKernelItem {\n");
|
||||
out(" void* funcPtr;\n");
|
||||
out(" int unroll;\n");
|
||||
out("};\n\n");
|
||||
|
||||
out("/* This table contains all the __global__ functions that were compiled */\n");
|
||||
out("static struct rcclKernelItem rcclKernelTable[] = {\n")
|
||||
for unroll in seen_unroll:
|
||||
out(" {(void*)&(rcclGenericKernel<%s, false>), %s},\n" % (unroll, unroll))
|
||||
out("#ifdef ENABLE_COLLTRACE\n")
|
||||
for unroll in seen_unroll:
|
||||
out(" {(void*)&(rcclGenericKernel<%s, true>), %s},\n" % (unroll, unroll))
|
||||
out("#endif\n");
|
||||
out("};\n\n");
|
||||
|
||||
# Generate <gensrc>/device_table.cpp
|
||||
if is_colltrace:
|
||||
@@ -436,7 +375,7 @@ if is_colltrace:
|
||||
out = f.write
|
||||
out('#include "nccl_common.h"\n#include "device.h"\n')
|
||||
out("\n")
|
||||
|
||||
|
||||
seen_fns = set()
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
for fn in primary_funcs:
|
||||
@@ -459,10 +398,13 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# The mapping from function rows to valid primary function ids.
|
||||
out("extern int const ncclDevFuncRowToId[] = {\n")
|
||||
index = 0
|
||||
for fn in func_rows[:len(func_rows)//3]:
|
||||
offset = len(func_rows)//len(all_unroll)
|
||||
start = all_unroll.index(seen_unroll[0]) * offset
|
||||
end = start + offset
|
||||
for fn in func_rows[start:end]:
|
||||
fn_id, comment = -1, ""
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)] % offset if primary_to_index[equivalent_primary(*fn)] != -1 else -1
|
||||
comment = " // " + paste(" ", *fn[:-1])
|
||||
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
|
||||
index += 1
|
||||
|
||||
@@ -48,7 +48,10 @@
|
||||
} while (0)
|
||||
|
||||
#define barrier_by_group() barrier_by_group_common(__threadfence())
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define barrier_by_group_block() barrier_by_group_common(__threadfence_block())
|
||||
#endif
|
||||
|
||||
/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
|
||||
* We use these as template args to the Primitives class instead of integral
|
||||
|
||||
+34
-25
@@ -28,30 +28,31 @@
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
struct ncclKernelMatch {
|
||||
void* kernelFn;
|
||||
bool specialized;
|
||||
};
|
||||
/* [RCCL] Determine which GPU kernel to execute */
|
||||
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL)
|
||||
{
|
||||
// At this time, unroll factor is controlled only by passed in unroll argument
|
||||
// After more investigation, this may be further tuned by the actual task being processed
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0))
|
||||
static ncclKernelMatch const ncclKerns[6] = {
|
||||
{(void *)ncclDevKernel_Generic_1, true},
|
||||
{(void *)ncclDevKernel_Generic_2, true},
|
||||
{(void *)ncclDevKernel_Generic_4, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_1, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_2, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_4, true}
|
||||
};
|
||||
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]) / 2;
|
||||
int firstKernel = useCollTrace ? numKernels : 0;
|
||||
#else
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
|
||||
static ncclKernelMatch const ncclKerns[3] = {
|
||||
{(void*)ncclDevKernel_Generic_1, true},
|
||||
{(void*)ncclDevKernel_Generic_2, true},
|
||||
{(void*)ncclDevKernel_Generic_4, true}
|
||||
};
|
||||
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]);
|
||||
int firstKernel = 0;
|
||||
#endif
|
||||
|
||||
// Check if the requested unroll exists
|
||||
for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
|
||||
if (rcclKernelTable[firstKernel + kernelIdx].unroll == unroll) {
|
||||
return rcclKernelTable[firstKernel + kernelIdx].funcPtr;
|
||||
}
|
||||
}
|
||||
// Fall back to default unroll
|
||||
WARN("Requested RCCL_UNROLL_FACTOR: %d does not exist in `rcclKernelTable`. Falling back to default unroll: %d", unroll, rcclKernelTable[firstKernel].unroll);
|
||||
return rcclKernelTable[firstKernel].funcPtr;
|
||||
}
|
||||
|
||||
static int rcclProtoGrainSize(int proto, ncclComm *comm){
|
||||
switch (proto) {
|
||||
case NCCL_PROTO_LL: return 16;
|
||||
@@ -81,7 +82,7 @@ NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0);
|
||||
|
||||
// Returns maximum kernel stack size of all CUDA kernels
|
||||
ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize) {
|
||||
constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
|
||||
constexpr int KernelCount = sizeof(rcclKernelTable)/sizeof(rcclKernelTable[0]);
|
||||
ncclResult_t result = ncclSuccess;
|
||||
int print = 0;
|
||||
|
||||
@@ -95,7 +96,7 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* ma
|
||||
int ncclMaxSharedMem = rcclShmemDynamicSize(cudaArch, WarpSize);
|
||||
|
||||
for (int k=0; k < KernelCount; k++) {
|
||||
void* fn = ncclKerns[k].kernelFn;
|
||||
void* fn = rcclKernelTable[k].funcPtr;
|
||||
cudaFuncAttributes attr = {0};
|
||||
if (fn == nullptr) continue;
|
||||
|
||||
@@ -783,8 +784,12 @@ static ncclResult_t scheduleCollTasksToPlan(
|
||||
//plan->channelMask.masks[channelId/64] |= (2ull<<devWork->channelHi) - (1ull<<devWork->channelLo);
|
||||
plan->threadPerBlock = std::max(plan->threadPerBlock, 192 /* 3*WARP_SIZE */);
|
||||
if (!plan->kernelSpecialized) {
|
||||
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
|
||||
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
|
||||
#else
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
|
||||
#endif
|
||||
plan->kernelSpecialized = true;
|
||||
}
|
||||
|
||||
if (comm->rank == 0) {
|
||||
@@ -1084,8 +1089,12 @@ static ncclResult_t scheduleP2pTasksToPlan(
|
||||
|
||||
plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MAX_NTHREADS);
|
||||
if (!plan->kernelSpecialized) {
|
||||
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
|
||||
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
|
||||
#else
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
|
||||
#endif
|
||||
plan->kernelSpecialized = true;
|
||||
}
|
||||
|
||||
// Compute how much to split operations
|
||||
|
||||
@@ -74,10 +74,5 @@ typedef enum {
|
||||
|
||||
#define NCCL_ALGO_PROTO_IGNORE -1.0
|
||||
|
||||
#define NCCL_NUM_UNROLLS 3 // 1/2/4
|
||||
#define NCCL_UNROLL_1 0
|
||||
#define NCCL_UNROLL_2 1
|
||||
#define NCCL_UNROLL_4 2
|
||||
|
||||
#define NCCL_NUM_FLOATS 6 // half/float/double/rccl_bfloat16/rccl_float8/rccl_bfloat8
|
||||
#endif
|
||||
|
||||
+28
-8
@@ -98,15 +98,34 @@ static uint64_t hashUniqueId(ncclUniqueId const &id) {
|
||||
return h;
|
||||
}
|
||||
|
||||
//RCCL runtime param to set Unroll Factor
|
||||
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
|
||||
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
if(IsArchMatch(devProp.gcnArchName, "gfx950"))
|
||||
comm->unroll = NCCL_UNROLL_1;
|
||||
else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)))
|
||||
comm->unroll = NCCL_UNROLL_2;
|
||||
else
|
||||
comm->unroll = NCCL_UNROLL_4;
|
||||
|
||||
//If RCCL runtime param is set, it will override defaults
|
||||
if (rcclParamUnrollFactor() != 0) {
|
||||
comm->unroll = rcclParamUnrollFactor();
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
|
||||
}
|
||||
else {
|
||||
if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
|
||||
//on gfx950, use unroll=1 for single-node and unroll=2 for multi-node
|
||||
if (comm->nNodes == 1)
|
||||
comm->unroll = 1;
|
||||
else
|
||||
comm->unroll = 2;
|
||||
}
|
||||
else if((IsArchMatch(devProp.gcnArchName, "gfx908")) ||
|
||||
(IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))
|
||||
//on MI300X and gfx908, use unroll=2
|
||||
comm->unroll = 2;
|
||||
else
|
||||
comm->unroll = 4;
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -617,8 +636,6 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
|
||||
// RCCL: create persistent stream for calloc
|
||||
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
||||
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
|
||||
|
||||
@@ -1945,6 +1962,9 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
|
||||
NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail);
|
||||
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (job->parent) {
|
||||
if (job->parent->mscclppCompatible) {
|
||||
|
||||
Ссылка в новой задаче
Block a user