diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index c4e515824f..1e76b1bf5b 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -2,22 +2,19 @@ import os import sys import subprocess +from dataclasses import dataclass # Order of colls, redops, tys, protos, algos must match src/include/device.h -all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"] -all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] -all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] -all_protos = ["LL","LL128","SIMPLE"] -all_algos = ["TREE","RING", "", "", "", "", "PAT"] -all_unroll = ["1", "2", "4"] -use_acc = ["0", "1"] - -# Pipelining is not supported for LL/LL64 prims, so "1" is not a valid value for low latency protocols. -# However, if it needs to be supported, equivalent_primary() can be modified to avoid the "non-zero"->"0" mapping. -all_pipeline = ["0", "1"] -pipelined_types = ["bf16"] -all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_pipeline, all_unroll] +all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"] +all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] +all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] +all_protos = ["LL","LL128","SIMPLE"] +all_algos = ["TREE","RING", "", "", "", "", "PAT"] +all_accs = ["0", "1"] +all_pipelines = ["0", "1"] +all_unrolls = ["1", "2", "4"] +all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, all_accs, all_pipelines, all_unrolls] ################################################################################ # The first command line argument is the path to the directory to generate and @@ -128,7 +125,7 @@ tys_of_coll = { acc_of_coll = { "AllGather": ["0"], - "AllReduce": use_acc, + "AllReduce": all_accs, "AllToAllPivot": ["0"], "Broadcast": ["0"], "Reduce": ["0"], @@ -138,13 +135,14 @@ acc_of_coll = { pipelines_of_coll = { "AllGather": ["0"], - "AllReduce": all_pipeline, + "AllReduce": all_pipelines, "AllToAllPivot": ["0"], "Broadcast": ["0"], - "Reduce": all_pipeline, - "ReduceScatter": all_pipeline, + "Reduce": all_pipelines, + "ReduceScatter": all_pipelines, "SendRecv": ["0"] } +pipelined_types = ["bf16"] coll_camel_to_lower = { "AllGather": "all_gather", @@ -152,16 +150,29 @@ coll_camel_to_lower = { "AllToAllPivot": "alltoall_pivot", "Broadcast": "broadcast", "Reduce": "reduce", - "ReduceScatter": "reduce_scatter", - "SendRecv": "sendrecv" + "ReduceScatter": "reduce_scatter", + "SendRecv": "sendrecv" } coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower} ################################################################################ +@dataclass(frozen=True) +class Fn: + coll: str + algo: str + proto: str + redop: str + ty: str + acc: str + pipeline: str + unroll: str + + def __iter__(self): + return iter((self.coll, self.algo, self.proto, self.redop, self.ty, self.acc, self.pipeline, self.unroll)) def calc_unroll_for_local_arch(): if not is_local_arch_only: - return all_unroll + return all_unrolls rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo" @@ -197,7 +208,11 @@ def calc_unroll_for_local_arch(): else: return ["4"] else: - return all_unroll + return all_unrolls + +# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets, +# except for gfx950 +local_unroll = calc_unroll_for_local_arch() # Helper function to check if the conditions for the collective is being met def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll): @@ -211,7 +226,7 @@ def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll): ty not in tys_of_coll[coll] or acc not in acc_of_coll[coll] or pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types) or - unroll not in all_unroll): + unroll not in local_unroll): return False return True @@ -296,45 +311,55 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll): # Order rows are enumerated must match formula of `ncclDevFuncId()`: # outermost loop should be for unroll factor; refer to host_table section def enumerate_func_rows(): - for unroll in all_unroll: + for unroll in local_unroll: for coll in all_colls: for algo in all_algos: for proto in all_protos: for redop in all_redops: for ty in all_tys: - for acc in use_acc: - for pipeline in all_pipeline: + for acc in all_accs: + for pipeline in all_pipelines: if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll): yield (coll, algo, proto, redop, ty, acc, pipeline, unroll) # Sort the hashmap based on custom key -def custom_sort_key(fn): - coll, algo, proto, redop, ty, acc, pipeline, unroll = fn +def custom_sort_key(fn: Fn): return ( - all_unroll.index(unroll), - all_colls.index(coll), - all_algos.index(algo), - all_protos.index(proto), - all_redops.index(redop), - all_tys.index(ty), - use_acc.index(acc), - all_pipeline.index(pipeline) + local_unroll.index(fn.unroll), + all_colls.index(fn.coll), + all_algos.index(fn.algo), + all_protos.index(fn.proto), + all_redops.index(fn.redop), + all_tys.index(fn.ty), + all_accs.index(fn.acc), + all_pipelines.index(fn.pipeline) ) +def get_arch_guard(fn): + cond = None + + if fn.proto == "LL128" and fn.acc == "1": + cond = "(defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)" + elif fn.proto == "LL128": + cond = "(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)" + elif fn.acc == "1": + cond = "defined(__gfx942__) || defined(__gfx950__)" + + return cond + ################################################################################ -# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets, -# except for gfx950 -all_unroll = calc_unroll_for_local_arch() - # Corresponds to ncclDevFuncRowToId[] -func_rows = [fn for fn in enumerate_func_rows()] +func_rows = [Fn(*fn) for fn in enumerate_func_rows()] # Corresponds to ncclDevFuncTable[] -primary_funcs = sorted(set(equivalent_primary(*fn) for fn in parse_input(func_pattern)), key=custom_sort_key) +primary_funcs = sorted( + {Fn(*equivalent_primary(*fn)) for fn in parse_input(func_pattern)}, key=custom_sort_key +) # primary_to_index[primary_funcs[i]] == i -primary_to_index = {fn: primary_funcs.index(fn) if fn in primary_funcs else -1 for fn in func_rows} +primary_to_index = {fn: i for i, fn in enumerate(primary_funcs)} +primary_to_index = {fn: primary_to_index.get(Fn(*fn), -1) for fn in func_rows} ################################################################################ @@ -348,124 +373,55 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f: for fn in primary_funcs: sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("%s %s();\n#else\n" % (func_declaration, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("%s %s();\n#endif\n" % (func_declaration, sym_ll)) + guard = get_arch_guard(fn) + if guard: + out("#if %s\n%s %s();\n#endif\n" % (guard, func_declaration, sym)) else: out("%s %s();\n" % (func_declaration, sym)) out("\n") + index = {val: None for val in all_unrolls} out("typedef void(*ncclDevFuncPtr_t)();\n\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n") - index1 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, acc, pipeline, unroll = fn - if unroll != "1": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index1, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index1, sym)) - index1 += 1 - out("nullptr};\n") - out("\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n") - index2 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, acc, pipeline, unroll = fn - if unroll != "2": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index2, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index2, sym)) - index2 += 1 - out("nullptr};\n") - out("\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") - index4 = 0 - for fn in primary_funcs: - coll, algo, proto, redop, ty, acc, pipeline, unroll = fn - if unroll != "4": continue - sym = paste("_", "ncclDevFunc", *fn) - if fn[2] == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") - out("/*%4d*/ %s,\n#else\n" % (index4, sym)) - fn_ll = fn[:2] + ("LL",) + fn[3:] - sym_ll = paste("_", "ncclDevFunc", *fn_ll) - out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll)) - else: - out("/*%4d*/ %s,\n" % (index4, sym)) - index4 += 1 - out("nullptr};\n") - out("\n") + for unroll in all_unrolls: + index[unroll] = 0 + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % unroll) + for fn in primary_funcs: + if fn.unroll != unroll: continue + sym = paste("_", "ncclDevFunc", *fn) + guard = get_arch_guard(fn) + if guard: + out("#if %s\n/*%4d*/ %s,\n#else\n/*%4d*/ nullptr,\n#endif\n" % (guard, index[unroll], sym, index[unroll])) + else: + out("/*%4d*/ %s,\n" % (index[unroll], sym)) + index[unroll] += 1 + out("nullptr};\n") + out("\n") if not is_ifc: - out("template\n" - "struct Caller1 {\n" - " static __forceinline__ __device__ __host__\n" - " void call1(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller1::call1(funcIndex) : Caller1::call1(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller1{\n" - " static __forceinline__ __device__ __host__\n" - " void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n") - out(f" Caller1<0, {index1}>::call1(funcIndex);\n") - out("}\n\n") - out("template\n" - "struct Caller2 {\n" - " static __forceinline__ __device__ __host__\n" - " void call2(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller2::call2(funcIndex) : Caller2::call2(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller2{\n" - " static __forceinline__ __device__ __host__\n" - " void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n") - out(f" Caller2<0, {index2}>::call2(funcIndex);\n") - out("}\n\n") - out("template\n" - "struct Caller4 {\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller4{\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" - "};\n") - out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") - out(f" Caller4<0, {index4}>::call4(funcIndex);\n") - out("}\n\n") + for unroll in all_unrolls: + out(f"template\n" + f"struct Caller{unroll} {{\n" + " static __forceinline__ __device__ __host__\n" + f" void call{unroll}(unsigned short funcIndex) noexcept {{\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + f" return (funcIndex < m)\n" + f" ? Caller{unroll}::call{unroll}(funcIndex)\n" + f" : Caller{unroll}::call{unroll}(funcIndex);\n" + " }\n" + "};\n\n") + + out(f"template\n" + f"struct Caller{unroll} {{\n" + " static __forceinline__ __device__ __host__\n" + f" void call{unroll}(unsigned short funcIndex) noexcept {{\n" + f" ncclDevFuncTable_{unroll}[f]();\n" + " }\n" + "};\n\n") + + # emit NCCL_CALL_FUNCTIONS_ wrapper using last index value + out(f"__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_{unroll}(unsigned short funcIndex) noexcept {{\n") + out(f" Caller{unroll}<0, {index[unroll]}>::call{unroll}(funcIndex);\n") + out("}\n\n") # Generate /device_table.cpp if is_colltrace: @@ -479,7 +435,7 @@ if is_colltrace: seen_fns = set() out("const char* funcNames[] = {\n") for fn in primary_funcs: - fn_no_unroll = fn[:-1] + fn_no_unroll = (fn.coll, fn.algo, fn.proto, fn.redop, fn.ty, fn.acc, fn.pipeline) if fn_no_unroll not in seen_fns: out(' "%s",\n' % paste("_", "ncclDevFunc", *fn_no_unroll)) seen_fns.add(fn_no_unroll) @@ -510,28 +466,29 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # host_table entries map device functions based on collective, algorithm, protocol, redop, and datatype # For GPU targets that support multiple unrolls, e.g., gfx950 # (or) for non-local builds, only a single set of functions are needed in the host_table. - for fn in func_rows[:len(func_rows)//len(all_unroll)]: + for fn in func_rows[:len(func_rows)//len(local_unroll)]: fn_id = -1 if fn is not None: - fn_id = primary_to_index[equivalent_primary(*fn)] - comment = " // " + paste(" ", *fn[:-1]) + guard = get_arch_guard(fn) + fn_id = primary_to_index[Fn(*equivalent_primary(*fn))] + comment = " // " + paste(" ", *fn) # Build the function signature string: " " # get parts indexes in order (coll, algo, proto, redop, ty, acc, pipeline, unroll) - coll_idx = all_colls.index(fn[0]) - algo_idx = all_algos.index(fn[1]) - proto_idx = all_protos.index(fn[2]) - redop_idx = all_redops.index(fn[3]) - ty_idx = all_tys.index(fn[4]) - acc_idx = use_acc.index(fn[5]) - pipeline_idx = all_pipeline.index(fn[6]) + coll_idx = all_colls.index(fn.coll) + algo_idx = all_algos.index(fn.algo) + proto_idx = all_protos.index(fn.proto) + redop_idx = all_redops.index(fn.redop) + ty_idx = all_tys.index(fn.ty) + acc_idx = all_accs.index(fn.acc) + pipeline_idx = all_pipelines.index(fn.pipeline) # Assert that 4 bits (16 values) is enough to map all_colls, all_algos, etc. assert len(all_colls) <= 16, "Error: all_colls has more than 16 values, which exceeds 4-bit capacity." assert len(all_algos) <= 16, "Error: all_algos has more than 16 values, which exceeds 4-bit capacity." assert len(all_protos) <= 16, "Error: all_protos has more than 16 values, which exceeds 4-bit capacity." assert len(all_redops) <= 16, "Error: all_redops has more than 16 values, which exceeds 4-bit capacity." assert len(all_tys) <= 16, "Error: all_tys has more than 16 values, which exceeds 4-bit capacity." - assert len(use_acc) <= 16, "Error: use_acc has more than 16 values, which exceeds 4-bit capacity." - assert len(all_pipeline) <= 16, "Error: all_pipeline has more than 16 values, which exceeds 4-bit capacity." + assert len(all_accs) <= 16, "Error: all_accs has more than 16 values, which exceeds 4-bit capacity." + assert len(all_pipelines) <= 16, "Error: all_pipelines has more than 16 values, which exceeds 4-bit capacity." # Create a 64-bit unsigned integer key and pack the indices into 4 bits each key = ( (coll_idx & 0xF) @@ -542,12 +499,15 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: | ((acc_idx & 0xF) << 20) | ((pipeline_idx & 0xF) << 24) ) - fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {acc_idx} {pipeline_idx}" - if fn[0] == "Broadcast": + if fn.coll == "Broadcast": key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8)) - if fn[0] in ["SendRecv", "AllToAllPivot"]: + if fn.coll in ["SendRecv", "AllToAllPivot"]: key = ((coll_idx & 0x3F)) - out(f' {{{key}, {fn_id}}}, {comment}\n') + + if fn_id != -1 and guard: + out(f'#if {guard}\n {{{key}, {fn_id}}}, {comment}\n#else\n {{{key}, -1}}, {comment}\n#endif\n') + else: + out(f' {{{key}, {fn_id}}}, {comment}\n') out("};\n") # Maps to .cu filename which implements this func. The only constraint is that @@ -562,13 +522,13 @@ def partition_by_name(fns): ans = {} for fn in fns: name = impl_filename(*fn) - coll = fn[0] + coll = fn.coll if name not in ans: ans[name] = (coll, []) ans[name][1].append(fn) return ans -name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop") +name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn.coll !="Nop") redop_to_cxx = { None: "FuncCopy", @@ -609,16 +569,16 @@ for name in name_to_funcs.keys(): ) for fn in fns: - (coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn - sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll) - if proto == "LL128": - out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") + sym = paste("_", fn.coll, fn.algo, fn.proto, fn.redop, fn.ty, fn.acc, fn.pipeline, fn.unroll) + guard = get_arch_guard(fn) + if guard: + out("#if %s\n" % guard) out( "DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {pipeline}, {unroll})\n" - .format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty], - algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, pipeline=pipeline, unroll=unroll) + .format(sym=sym, coll=fn.coll, redop_cxx=redop_to_cxx[fn.redop], ty_cxx=ty_to_cxx[fn.ty], + algo=(fn.algo or "RING"), proto=(fn.proto or "SIMPLE"), acc=fn.acc, pipeline=fn.pipeline, unroll=fn.unroll) ) - if proto == "LL128": + if guard: out("#endif\n") # Generate each /.cpp