[GEN/BUILD] Refactor generate.py and reduce build time for older archs (#2006)
This commit is contained in:
committed by
GitHub
parent
8444b3c6e9
commit
bed7cdf863
+139
-179
@@ -2,22 +2,19 @@
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Order of colls, redops, tys, protos, algos must match src/include/device.h
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
all_algos = ["TREE","RING", "", "", "", "", "PAT"]
|
||||
all_unroll = ["1", "2", "4"]
|
||||
use_acc = ["0", "1"]
|
||||
|
||||
# Pipelining is not supported for LL/LL64 prims, so "1" is not a valid value for low latency protocols.
|
||||
# However, if it needs to be supported, equivalent_primary() can be modified to avoid the "non-zero"->"0" mapping.
|
||||
all_pipeline = ["0", "1"]
|
||||
pipelined_types = ["bf16"]
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_pipeline, all_unroll]
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
all_algos = ["TREE","RING", "", "", "", "", "PAT"]
|
||||
all_accs = ["0", "1"]
|
||||
all_pipelines = ["0", "1"]
|
||||
all_unrolls = ["1", "2", "4"]
|
||||
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, all_accs, all_pipelines, all_unrolls]
|
||||
|
||||
################################################################################
|
||||
# The first command line argument is the path to the directory to generate and
|
||||
@@ -128,7 +125,7 @@ tys_of_coll = {
|
||||
|
||||
acc_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": use_acc,
|
||||
"AllReduce": all_accs,
|
||||
"AllToAllPivot": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": ["0"],
|
||||
@@ -138,13 +135,14 @@ acc_of_coll = {
|
||||
|
||||
pipelines_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": all_pipeline,
|
||||
"AllReduce": all_pipelines,
|
||||
"AllToAllPivot": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": all_pipeline,
|
||||
"ReduceScatter": all_pipeline,
|
||||
"Reduce": all_pipelines,
|
||||
"ReduceScatter": all_pipelines,
|
||||
"SendRecv": ["0"]
|
||||
}
|
||||
pipelined_types = ["bf16"]
|
||||
|
||||
coll_camel_to_lower = {
|
||||
"AllGather": "all_gather",
|
||||
@@ -152,16 +150,29 @@ coll_camel_to_lower = {
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
"ReduceScatter": "reduce_scatter",
|
||||
"SendRecv": "sendrecv"
|
||||
"ReduceScatter": "reduce_scatter",
|
||||
"SendRecv": "sendrecv"
|
||||
}
|
||||
coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
||||
|
||||
################################################################################
|
||||
@dataclass(frozen=True)
|
||||
class Fn:
|
||||
coll: str
|
||||
algo: str
|
||||
proto: str
|
||||
redop: str
|
||||
ty: str
|
||||
acc: str
|
||||
pipeline: str
|
||||
unroll: str
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.coll, self.algo, self.proto, self.redop, self.ty, self.acc, self.pipeline, self.unroll))
|
||||
|
||||
def calc_unroll_for_local_arch():
|
||||
if not is_local_arch_only:
|
||||
return all_unroll
|
||||
return all_unrolls
|
||||
|
||||
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
|
||||
|
||||
@@ -197,7 +208,11 @@ def calc_unroll_for_local_arch():
|
||||
else:
|
||||
return ["4"]
|
||||
else:
|
||||
return all_unroll
|
||||
return all_unrolls
|
||||
|
||||
# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets,
|
||||
# except for gfx950
|
||||
local_unroll = calc_unroll_for_local_arch()
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
@@ -211,7 +226,7 @@ def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
ty not in tys_of_coll[coll] or
|
||||
acc not in acc_of_coll[coll] or
|
||||
pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types) or
|
||||
unroll not in all_unroll):
|
||||
unroll not in local_unroll):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -296,45 +311,55 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
# The order in which rows are enumerated must match the formula of `ncclDevFuncId()`:
|
||||
# outermost loop should be for unroll factor; refer to host_table section
|
||||
def enumerate_func_rows():
|
||||
for unroll in all_unroll:
|
||||
for unroll in local_unroll:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
for acc in use_acc:
|
||||
for pipeline in all_pipeline:
|
||||
for acc in all_accs:
|
||||
for pipeline in all_pipelines:
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
yield (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
|
||||
# Sort key for functions: orders by <unroll> <coll> <algo> <proto> <redop> <ty> <acc> <pipeline>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
def custom_sort_key(fn: Fn):
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
all_colls.index(coll),
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
all_redops.index(redop),
|
||||
all_tys.index(ty),
|
||||
use_acc.index(acc),
|
||||
all_pipeline.index(pipeline)
|
||||
local_unroll.index(fn.unroll),
|
||||
all_colls.index(fn.coll),
|
||||
all_algos.index(fn.algo),
|
||||
all_protos.index(fn.proto),
|
||||
all_redops.index(fn.redop),
|
||||
all_tys.index(fn.ty),
|
||||
all_accs.index(fn.acc),
|
||||
all_pipelines.index(fn.pipeline)
|
||||
)
|
||||
|
||||
def get_arch_guard(fn):
|
||||
cond = None
|
||||
|
||||
if fn.proto == "LL128" and fn.acc == "1":
|
||||
cond = "(defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)"
|
||||
elif fn.proto == "LL128":
|
||||
cond = "(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)"
|
||||
elif fn.acc == "1":
|
||||
cond = "defined(__gfx942__) || defined(__gfx950__)"
|
||||
|
||||
return cond
|
||||
|
||||
################################################################################
|
||||
|
||||
# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets,
|
||||
# except for gfx950
|
||||
all_unroll = calc_unroll_for_local_arch()
|
||||
|
||||
# Corresponds to ncclDevFuncRowToId[]
|
||||
func_rows = [fn for fn in enumerate_func_rows()]
|
||||
func_rows = [Fn(*fn) for fn in enumerate_func_rows()]
|
||||
|
||||
# Corresponds to ncclDevFuncTable[]
|
||||
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in parse_input(func_pattern)), key=custom_sort_key)
|
||||
primary_funcs = sorted(
|
||||
{Fn(*equivalent_primary(*fn)) for fn in parse_input(func_pattern)}, key=custom_sort_key
|
||||
)
|
||||
|
||||
# primary_to_index[primary_funcs[i]] == i
|
||||
primary_to_index = {fn: primary_funcs.index(fn) if fn in primary_funcs else -1 for fn in func_rows}
|
||||
primary_to_index = {fn: i for i, fn in enumerate(primary_funcs)}
|
||||
primary_to_index = {fn: primary_to_index.get(Fn(*fn), -1) for fn in func_rows}
|
||||
|
||||
################################################################################
|
||||
|
||||
@@ -348,124 +373,55 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
|
||||
for fn in primary_funcs:
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("%s %s();\n#else\n" % (func_declaration, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("%s %s();\n#endif\n" % (func_declaration, sym_ll))
|
||||
guard = get_arch_guard(fn)
|
||||
if guard:
|
||||
out("#if %s\n%s %s();\n#endif\n" % (guard, func_declaration, sym))
|
||||
else:
|
||||
out("%s %s();\n" % (func_declaration, sym))
|
||||
out("\n")
|
||||
|
||||
index = {val: None for val in all_unrolls}
|
||||
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
|
||||
index1 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "1": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index1, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index1, sym))
|
||||
index1 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
|
||||
index2 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index2, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index2, sym))
|
||||
index2 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index4, sym))
|
||||
index4 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
for unroll in all_unrolls:
|
||||
index[unroll] = 0
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % unroll)
|
||||
for fn in primary_funcs:
|
||||
if fn.unroll != unroll: continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
guard = get_arch_guard(fn)
|
||||
if guard:
|
||||
out("#if %s\n/*%4d*/ %s,\n#else\n/*%4d*/ nullptr,\n#endif\n" % (guard, index[unroll], sym, index[unroll]))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index[unroll], sym))
|
||||
index[unroll] += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
if not is_ifc:
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller1 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call1(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller1<f, m>::call1(funcIndex) : Caller1<m, l>::call1(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller1<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller1<0, {index1}>::call1(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller2 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller2<f, m>::call2(funcIndex) : Caller2<m, l>::call2(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller2<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller2<0, {index2}>::call2(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller4 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller4<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
for unroll in all_unrolls:
|
||||
out(f"template<unsigned short f, unsigned short l>\n"
|
||||
f"struct Caller{unroll} {{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
f" void call{unroll}(unsigned short funcIndex) noexcept {{\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
f" return (funcIndex < m)\n"
|
||||
f" ? Caller{unroll}<f, m>::call{unroll}(funcIndex)\n"
|
||||
f" : Caller{unroll}<m, l>::call{unroll}(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n\n")
|
||||
|
||||
out(f"template<unsigned short f>\n"
|
||||
f"struct Caller{unroll}<f, f + 1> {{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
f" void call{unroll}(unsigned short funcIndex) noexcept {{\n"
|
||||
f" ncclDevFuncTable_{unroll}[f]();\n"
|
||||
" }\n"
|
||||
"};\n\n")
|
||||
|
||||
# emit NCCL_CALL_FUNCTIONS_<unroll> wrapper using last index value
|
||||
out(f"__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_{unroll}(unsigned short funcIndex) noexcept {{\n")
|
||||
out(f" Caller{unroll}<0, {index[unroll]}>::call{unroll}(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Generate <gensrc>/device_table.cpp
|
||||
if is_colltrace:
|
||||
@@ -479,7 +435,7 @@ if is_colltrace:
|
||||
seen_fns = set()
|
||||
out("const char* funcNames[] = {\n")
|
||||
for fn in primary_funcs:
|
||||
fn_no_unroll = fn[:-1]
|
||||
fn_no_unroll = (fn.coll, fn.algo, fn.proto, fn.redop, fn.ty, fn.acc, fn.pipeline)
|
||||
if fn_no_unroll not in seen_fns:
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn_no_unroll))
|
||||
seen_fns.add(fn_no_unroll)
|
||||
@@ -510,28 +466,29 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# host_table entries map device functions based on collective, algorithm, protocol, redop, and datatype
|
||||
# For GPU targets that support multiple unrolls, e.g., gfx950
|
||||
# (or) for non-local builds, only a single set of functions is needed in the host_table.
|
||||
for fn in func_rows[:len(func_rows)//len(all_unroll)]:
|
||||
for fn in func_rows[:len(func_rows)//len(local_unroll)]:
|
||||
fn_id = -1
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
comment = " // " + paste(" ", *fn[:-1])
|
||||
guard = get_arch_guard(fn)
|
||||
fn_id = primary_to_index[Fn(*equivalent_primary(*fn))]
|
||||
comment = " // " + paste(" ", *fn)
|
||||
# Build the function signature string: "<coll> <algo> <proto> <redop> <ty> <acc> <pipeline>"
|
||||
# get parts indexes in order (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
coll_idx = all_colls.index(fn[0])
|
||||
algo_idx = all_algos.index(fn[1])
|
||||
proto_idx = all_protos.index(fn[2])
|
||||
redop_idx = all_redops.index(fn[3])
|
||||
ty_idx = all_tys.index(fn[4])
|
||||
acc_idx = use_acc.index(fn[5])
|
||||
pipeline_idx = all_pipeline.index(fn[6])
|
||||
coll_idx = all_colls.index(fn.coll)
|
||||
algo_idx = all_algos.index(fn.algo)
|
||||
proto_idx = all_protos.index(fn.proto)
|
||||
redop_idx = all_redops.index(fn.redop)
|
||||
ty_idx = all_tys.index(fn.ty)
|
||||
acc_idx = all_accs.index(fn.acc)
|
||||
pipeline_idx = all_pipelines.index(fn.pipeline)
|
||||
# Assert that 4 bits (16 values) is enough to map all_colls, all_algos, etc.
|
||||
assert len(all_colls) <= 16, "Error: all_colls has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_algos) <= 16, "Error: all_algos has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_protos) <= 16, "Error: all_protos has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_redops) <= 16, "Error: all_redops has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_tys) <= 16, "Error: all_tys has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(use_acc) <= 16, "Error: use_acc has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_pipeline) <= 16, "Error: all_pipeline has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_accs) <= 16, "Error: all_accs has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_pipelines) <= 16, "Error: all_pipelines has more than 16 values, which exceeds 4-bit capacity."
|
||||
# Create a 64-bit unsigned integer key and pack the indices into 4 bits each
|
||||
key = (
|
||||
(coll_idx & 0xF)
|
||||
@@ -542,12 +499,15 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
| ((acc_idx & 0xF) << 20)
|
||||
| ((pipeline_idx & 0xF) << 24)
|
||||
)
|
||||
fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {acc_idx} {pipeline_idx}"
|
||||
if fn[0] == "Broadcast":
|
||||
if fn.coll == "Broadcast":
|
||||
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
|
||||
if fn[0] in ["SendRecv", "AllToAllPivot"]:
|
||||
if fn.coll in ["SendRecv", "AllToAllPivot"]:
|
||||
key = ((coll_idx & 0x3F))
|
||||
out(f' {{{key}, {fn_id}}}, {comment}\n')
|
||||
|
||||
if fn_id != -1 and guard:
|
||||
out(f'#if {guard}\n {{{key}, {fn_id}}}, {comment}\n#else\n {{{key}, -1}}, {comment}\n#endif\n')
|
||||
else:
|
||||
out(f' {{{key}, {fn_id}}}, {comment}\n')
|
||||
out("};\n")
|
||||
|
||||
# Maps to .cu filename which implements this func. The only constraint is that
|
||||
@@ -562,13 +522,13 @@ def partition_by_name(fns):
|
||||
ans = {}
|
||||
for fn in fns:
|
||||
name = impl_filename(*fn)
|
||||
coll = fn[0]
|
||||
coll = fn.coll
|
||||
if name not in ans:
|
||||
ans[name] = (coll, [])
|
||||
ans[name][1].append(fn)
|
||||
return ans
|
||||
|
||||
name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop")
|
||||
name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn.coll !="Nop")
|
||||
|
||||
redop_to_cxx = {
|
||||
None: "FuncCopy",
|
||||
@@ -609,16 +569,16 @@ for name in name_to_funcs.keys():
|
||||
)
|
||||
|
||||
for fn in fns:
|
||||
(coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
if proto == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
sym = paste("_", fn.coll, fn.algo, fn.proto, fn.redop, fn.ty, fn.acc, fn.pipeline, fn.unroll)
|
||||
guard = get_arch_guard(fn)
|
||||
if guard:
|
||||
out("#if %s\n" % guard)
|
||||
out(
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {pipeline}, {unroll})\n"
|
||||
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, pipeline=pipeline, unroll=unroll)
|
||||
.format(sym=sym, coll=fn.coll, redop_cxx=redop_to_cxx[fn.redop], ty_cxx=ty_to_cxx[fn.ty],
|
||||
algo=(fn.algo or "RING"), proto=(fn.proto or "SIMPLE"), acc=fn.acc, pipeline=fn.pipeline, unroll=fn.unroll)
|
||||
)
|
||||
if proto == "LL128":
|
||||
if guard:
|
||||
out("#endif\n")
|
||||
|
||||
# Generate each <gensrc>/<msccl_impl>.cpp
|
||||
|
||||
Reference in New Issue
Block a user