|
|
|
@@ -52,16 +52,16 @@ else:
|
|
|
|
|
# make ONLY_FUNCS="AllReduce * * Sum i32"
|
|
|
|
|
#
|
|
|
|
|
# # Only AllReduce RING Max float (but all protos)
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING * Max float"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING * Max f32"
|
|
|
|
|
#
|
|
|
|
|
# # AllReduce TREE LL128 Prod rccl_bfloat16
|
|
|
|
|
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16"
|
|
|
|
|
#
|
|
|
|
|
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32"
|
|
|
|
|
# --- or ---
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32"
|
|
|
|
|
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|AllGather RING LL/SIMPLE Sum i8|AllToAllPivot RING SIMPLE Sum i8|Broadcast RING LL/SIMPLE Sum i8|Reduce RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|ReduceScatter RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|SendRecv RING SIMPLE Sum i8"
|
|
|
|
|
|
|
|
|
|
# Paste all non-None arguments together with `sep`.
|
|
|
|
|
def paste(sep, *args):
|
|
|
|
@@ -134,14 +134,14 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
|
|
|
|
################################################################################
|
|
|
|
|
|
|
|
|
|
def calc_unroll_for_local_arch():
|
|
|
|
|
if not is_local_arch_only:
|
|
|
|
|
return
|
|
|
|
|
if not is_local_arch_only:
|
|
|
|
|
return all_unroll
|
|
|
|
|
|
|
|
|
|
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
|
|
|
|
|
|
|
|
|
|
res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
|
|
|
|
|
rocminfo_output = res.stdout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parse rocminfo binary output
|
|
|
|
|
gfx_targets = {}
|
|
|
|
|
curr_name = None
|
|
|
|
@@ -156,26 +156,28 @@ def calc_unroll_for_local_arch():
|
|
|
|
|
cu_count = int(line.split(':')[-1].strip())
|
|
|
|
|
gfx_targets[(curr_name, cu_count)] = None
|
|
|
|
|
curr_name = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# We want to remove duplicates, but cannot key a dictionary on the gfx name alone since the same gfx name can have different cu counts
|
|
|
|
|
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
|
|
|
|
|
gfx_targets = list(gfx_targets.keys())
|
|
|
|
|
|
|
|
|
|
# Homogeneous system is required to build for only 1 varient of unroll factor
|
|
|
|
|
# Homogeneous system is required to build for only 1 variant of unroll factor (except for gfx950)
|
|
|
|
|
if len(gfx_targets) == 1:
|
|
|
|
|
gfx_name, cu_count = gfx_targets[0]
|
|
|
|
|
if "gfx950" == gfx_name:
|
|
|
|
|
return 1
|
|
|
|
|
return ["1", "2"]
|
|
|
|
|
elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
|
|
|
|
|
return 2
|
|
|
|
|
return ["2"]
|
|
|
|
|
else:
|
|
|
|
|
return 4
|
|
|
|
|
return ["4"]
|
|
|
|
|
else:
|
|
|
|
|
return all_unroll
|
|
|
|
|
|
|
|
|
|
# Helper function to check if the conditions for the collective is being met
|
|
|
|
|
def func_validate(coll, algo, proto, redop, ty):
|
|
|
|
|
def func_validate(coll, algo, proto, redop, ty, unroll):
|
|
|
|
|
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
|
|
|
|
return False
|
|
|
|
|
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
|
|
|
|
|
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
@@ -191,9 +193,6 @@ def func_filter(function_params, current_idx, item_list=None):
|
|
|
|
|
|
|
|
|
|
# If the parameter is equal to '*', include all possible cases for it
|
|
|
|
|
if current_element == "*":
|
|
|
|
|
if current_idx == 0:
|
|
|
|
|
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
|
|
|
|
|
|
|
|
|
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
|
|
|
|
|
# Get the current list from all_params
|
|
|
|
|
current_list = all_params[current_idx]
|
|
|
|
@@ -225,7 +224,7 @@ def func_filter(function_params, current_idx, item_list=None):
|
|
|
|
|
else:
|
|
|
|
|
coll, algo, proto, redop, ty, unroll = item_list
|
|
|
|
|
|
|
|
|
|
if func_validate(coll, algo, proto, redop, ty):
|
|
|
|
|
if func_validate(coll, algo, proto, redop, ty, unroll):
|
|
|
|
|
yield(coll, algo, proto, redop, ty, unroll)
|
|
|
|
|
|
|
|
|
|
# Parse ONLY_FUNCS input and feed it to func_filter
|
|
|
|
@@ -247,10 +246,6 @@ def parse_input(func_pattern):
|
|
|
|
|
# Maps functions to the chosen representative for the equivalence class it
|
|
|
|
|
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
|
|
|
|
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
|
|
|
|
# if local arch only, we only need to build for 1 varient of coll_unroll.
|
|
|
|
|
# map the other varient of coll_unroll to this one.
|
|
|
|
|
if coll_unroll:
|
|
|
|
|
unroll = str(coll_unroll)
|
|
|
|
|
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
|
|
|
|
# map signed integer sum/prod to unsigned
|
|
|
|
|
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
|
|
|
@@ -268,7 +263,7 @@ def enumerate_func_rows():
|
|
|
|
|
for proto in all_protos:
|
|
|
|
|
for redop in all_redops:
|
|
|
|
|
for ty in all_tys:
|
|
|
|
|
if func_validate(coll, algo, proto, redop, ty):
|
|
|
|
|
if func_validate(coll, algo, proto, redop, ty, unroll):
|
|
|
|
|
yield (coll, algo, proto, redop, ty, unroll)
|
|
|
|
|
|
|
|
|
|
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
|
|
|
@@ -286,7 +281,9 @@ def custom_sort_key(fn):
|
|
|
|
|
|
|
|
|
|
################################################################################
|
|
|
|
|
|
|
|
|
|
coll_unroll = calc_unroll_for_local_arch()
|
|
|
|
|
# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets,
|
|
|
|
|
# except for gfx950
|
|
|
|
|
all_unroll = calc_unroll_for_local_arch()
|
|
|
|
|
|
|
|
|
|
# Corresponds to ncclDevFuncRowToId[]
|
|
|
|
|
func_rows = [fn for fn in enumerate_func_rows()]
|
|
|
|
@@ -459,7 +456,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
|
|
|
|
# The mapping from function rows to valid primary function ids.
|
|
|
|
|
out("extern int const ncclDevFuncRowToId[] = {\n")
|
|
|
|
|
index = 0
|
|
|
|
|
for fn in func_rows[:len(func_rows)//3]:
|
|
|
|
|
for fn in func_rows[:len(func_rows)//len(all_unroll)]:
|
|
|
|
|
fn_id, comment = -1, ""
|
|
|
|
|
if fn is not None:
|
|
|
|
|
fn_id = primary_to_index[equivalent_primary(*fn)]
|
|
|
|
|