[BUILD] Populate host_table entries only for 1 unroll (#1871)
[ROCm/rccl commit: bf6660ee4e]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
40462cc845
Коммит
fbe014c870
@@ -267,6 +267,7 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll):
|
||||
return (coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
# The order in which rows are enumerated must match the formula of `ncclDevFuncId()`:
|
||||
# the outermost loop should iterate over the unroll factor; refer to the host_table section
|
||||
def enumerate_func_rows():
|
||||
for unroll in all_unroll:
|
||||
for acc in use_acc:
|
||||
@@ -474,7 +475,11 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
out("// bits 16-19: ty index\n")
|
||||
out("#include <unordered_map>\n")
|
||||
out("extern std::unordered_map<uint64_t, int> ncclDevFuncNameToId = {\n")
|
||||
for fn in func_rows:
|
||||
|
||||
# host_table entries map device functions based on collective, algorithm, protocol, redop, and datatype
|
||||
# For GPU targets that support multiple unrolls, e.g., gfx950
|
||||
# or for non-local builds, only a single set of functions is needed in the host_table.
|
||||
for fn in func_rows[:len(func_rows)//len(all_unroll)]:
|
||||
fn_id = -1
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
|
||||
Ссылка в новой задаче
Block a user