Fix ncclDevFuncId for AllReduceWithBias (#1980)
[ROCm/rccl commit: c35bc721ad]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
fca120343f
Коммит
0aa56fb0a5
@@ -4,7 +4,7 @@ import sys
|
||||
import subprocess
|
||||
|
||||
# Order of colls, redops, tys, protos, algos must match src/include/device.h
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "AllReduceWithBias", "SendRecv", "", "", "AllToAllPivot"]
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
@@ -82,14 +82,13 @@ func_pattern = sys.argv[6:7]
|
||||
if func_pattern and func_pattern[0]:
|
||||
func_pattern = func_pattern[0]
|
||||
else:
|
||||
func_pattern = "AllGather|AllReduce|AllReduceWithBias|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
|
||||
################################################################################
|
||||
|
||||
algos_of_coll = {
|
||||
"AllGather": ["RING", "PAT"],
|
||||
"AllReduce": ["RING", "TREE"],
|
||||
"AllReduceWithBias": ["RING", "TREE"],
|
||||
"AllToAllPivot": ["RING"],
|
||||
"Broadcast": ["RING"],
|
||||
"Reduce": ["RING"],
|
||||
@@ -100,7 +99,6 @@ algos_of_coll = {
|
||||
protos_of_coll = {
|
||||
"AllGather": all_protos,
|
||||
"AllReduce": all_protos,
|
||||
"AllReduceWithBias": all_protos,
|
||||
"AllToAllPivot": ["SIMPLE"],
|
||||
"Broadcast": all_protos,
|
||||
"Reduce": all_protos,
|
||||
@@ -111,7 +109,6 @@ protos_of_coll = {
|
||||
redops_of_coll = {
|
||||
"AllGather": ["Sum"],
|
||||
"AllReduce": all_redops,
|
||||
"AllReduceWithBias": all_redops,
|
||||
"AllToAllPivot": ["Sum"],
|
||||
"Broadcast": ["Sum"],
|
||||
"Reduce": all_redops,
|
||||
@@ -122,7 +119,6 @@ redops_of_coll = {
|
||||
tys_of_coll = {
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllReduceWithBias": all_tys,
|
||||
"AllToAllPivot": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
@@ -130,10 +126,19 @@ tys_of_coll = {
|
||||
"SendRecv": ["i8"]
|
||||
}
|
||||
|
||||
acc_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": use_acc,
|
||||
"AllToAllPivot": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": ["0"],
|
||||
"ReduceScatter": ["0"],
|
||||
"SendRecv": ["0"]
|
||||
}
|
||||
|
||||
pipelines_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": all_pipeline,
|
||||
"AllReduceWithBias": ["0"],
|
||||
"AllToAllPivot": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": all_pipeline,
|
||||
@@ -144,7 +149,6 @@ pipelines_of_coll = {
|
||||
coll_camel_to_lower = {
|
||||
"AllGather": "all_gather",
|
||||
"AllReduce": "all_reduce",
|
||||
"AllReduceWithBias": "allreduce_with_bias",
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
@@ -197,15 +201,17 @@ def calc_unroll_for_local_arch():
|
||||
|
||||
# Helper function to check whether the conditions for the collective are met
|
||||
def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
if acc == "1" and coll != "AllReduceWithBias":
|
||||
return False
|
||||
if acc == "0" and coll == "AllReduceWithBias":
|
||||
return False
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if coll == "" or algo == "":
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or acc not in use_acc or unroll not in all_unroll or pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types):
|
||||
if (algo not in algos_of_coll[coll] or
|
||||
proto not in protos_of_coll[coll] or
|
||||
redop not in redops_of_coll[coll] or
|
||||
ty not in tys_of_coll[coll] or
|
||||
acc not in acc_of_coll[coll] or
|
||||
pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types) or
|
||||
unroll not in all_unroll):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -274,7 +280,7 @@ def parse_input(func_pattern):
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"):
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
ty = "u"+ty[1:]
|
||||
@@ -291,26 +297,27 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
# outermost loop should be for unroll factor; refer to host_table section
|
||||
def enumerate_func_rows():
|
||||
for unroll in all_unroll:
|
||||
for acc in use_acc:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
for acc in use_acc:
|
||||
for pipeline in all_pipeline:
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll):
|
||||
yield (coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
|
||||
# Sort the hashmap based on custom key <unroll> <coll> <algo> <proto> <redop> <ty> <acc> <pipeline>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, acc, pipeline, unroll = fn
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
use_acc.index(acc),
|
||||
all_colls.index(coll),
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
all_redops.index(redop),
|
||||
all_tys.index(ty),
|
||||
use_acc.index(acc),
|
||||
all_pipeline.index(pipeline)
|
||||
)
|
||||
|
||||
@@ -488,15 +495,15 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
out('#include "device.h"\n')
|
||||
out("\n")
|
||||
out("// The key for the ncclDevFuncNameToId map is a 64-bit unsigned integer.\n")
|
||||
out("// Each field (coll, algo, proto, redop, ty, pipeline) is packed into 4 bits,\n")
|
||||
out("// Each field (coll, algo, proto, redop, ty) is packed into 4 bits,\n")
|
||||
out("// Each field (coll, algo, proto, redop, ty, acc, pipeline) is packed into 4 bits,\n")
|
||||
out("// This allows up to 16 unique values per field. The layout is:\n")
|
||||
out("// bits 0-3: coll index\n")
|
||||
out("// bits 4-7: algo index\n")
|
||||
out("// bits 8-11: proto index\n")
|
||||
out("// bits 12-15: redop index\n")
|
||||
out("// bits 16-19: ty index\n")
|
||||
out("// bits 20-23: pipeline index\n")
|
||||
out("// bits 20-23: accumulator index\n")
|
||||
out("// bits 24-27: pipeline index\n")
|
||||
out("#include <unordered_map>\n")
|
||||
out("std::unordered_map<uint64_t, int> ncclDevFuncNameToId = {\n")
|
||||
|
||||
@@ -515,6 +522,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
proto_idx = all_protos.index(fn[2])
|
||||
redop_idx = all_redops.index(fn[3])
|
||||
ty_idx = all_tys.index(fn[4])
|
||||
acc_idx = use_acc.index(fn[5])
|
||||
pipeline_idx = all_pipeline.index(fn[6])
|
||||
# Assert that 4 bits (16 values) is enough to map all_colls, all_algos, etc.
|
||||
assert len(all_colls) <= 16, "Error: all_colls has more than 16 values, which exceeds 4-bit capacity."
|
||||
@@ -522,6 +530,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
assert len(all_protos) <= 16, "Error: all_protos has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_redops) <= 16, "Error: all_redops has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_tys) <= 16, "Error: all_tys has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(use_acc) <= 16, "Error: use_acc has more than 16 values, which exceeds 4-bit capacity."
|
||||
assert len(all_pipeline) <= 16, "Error: all_pipeline has more than 16 values, which exceeds 4-bit capacity."
|
||||
# Create a 64-bit unsigned integer key and pack the indices into 4 bits each
|
||||
key = (
|
||||
@@ -530,9 +539,10 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
| ((proto_idx & 0xF) << 8)
|
||||
| ((redop_idx & 0xF) << 12)
|
||||
| ((ty_idx & 0xF) << 16)
|
||||
| ((pipeline_idx & 0xF) << 20)
|
||||
| ((acc_idx & 0xF) << 20)
|
||||
| ((pipeline_idx & 0xF) << 24)
|
||||
)
|
||||
fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {pipeline_idx}"
|
||||
fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {acc_idx} {pipeline_idx}"
|
||||
if fn[0] == "Broadcast":
|
||||
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
|
||||
if fn[0] in ["SendRecv", "AllToAllPivot"]:
|
||||
@@ -592,8 +602,6 @@ for name in name_to_funcs.keys():
|
||||
print("-- Generating %s" % os.path.join(gensrc, name))
|
||||
|
||||
out = f.write
|
||||
if coll == "AllReduceWithBias":
|
||||
coll = "AllReduce"
|
||||
out(
|
||||
'#include "common.h"\n'
|
||||
'#include "{lower_coll}.h"\n'
|
||||
@@ -603,8 +611,6 @@ for name in name_to_funcs.keys():
|
||||
for fn in fns:
|
||||
(coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
if coll == "AllReduceWithBias":
|
||||
coll = "AllReduce"
|
||||
if proto == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out(
|
||||
|
||||
@@ -501,13 +501,9 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
|
||||
NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo));
|
||||
if(agg.func==ncclFuncAllReduce && agg.acc != nullptr)
|
||||
{
|
||||
agg.devFuncId = ncclDevFuncId(ncclFuncAllReduceWithBias, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
}
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, 1, agg.pipeline);
|
||||
else
|
||||
{
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
}
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, 0, agg.pipeline);
|
||||
if (agg.devFuncId < 0) {
|
||||
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
|
||||
return ncclInvalidUsage;
|
||||
|
||||
@@ -159,7 +159,8 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
|
||||
#define RCCL_PROTO_SHIFT 8
|
||||
#define RCCL_REDOP_SHIFT 12
|
||||
#define RCCL_DTYPE_SHIFT 16
|
||||
#define RCCL_PIPELINE_SHIFT 20
|
||||
#define RCCL_ACC_SHIFT 20
|
||||
#define RCCL_PIPELINE_SHIFT 24
|
||||
|
||||
struct ncclConnInfo {
|
||||
// Regular comm mechanism
|
||||
@@ -717,11 +718,11 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
|
||||
extern std::unordered_map<uint64_t, int> ncclDevFuncNameToId;
|
||||
|
||||
// `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int pipeline = 0) {
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int acc = 0, int pipeline = 0) {
|
||||
int row = -1;
|
||||
uint64_t key;
|
||||
// Pack 4-bit fields from right (LSB) to left in order:
|
||||
// coll, algo, proto, devRedOp, type
|
||||
// coll, algo, proto, devRedOp, type, acc, pipeline
|
||||
// This logic must be in sync with the key generation logic in generate.py
|
||||
if (coll == ncclFuncBroadcast) {
|
||||
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) |
|
||||
@@ -734,6 +735,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto,
|
||||
((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT) |
|
||||
((uint64_t)(devRedOp & RCCL_FUNC_ID_MASK) << RCCL_REDOP_SHIFT) |
|
||||
((uint64_t)(type & RCCL_FUNC_ID_MASK) << RCCL_DTYPE_SHIFT) |
|
||||
((uint64_t)(acc & RCCL_FUNC_ID_MASK) << RCCL_ACC_SHIFT) |
|
||||
((uint64_t)(pipeline & RCCL_FUNC_ID_MASK) << RCCL_PIPELINE_SHIFT);
|
||||
}
|
||||
auto it = ncclDevFuncNameToId.find(key);
|
||||
@@ -741,7 +743,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto,
|
||||
row = it->second;
|
||||
}
|
||||
if(row < 0) {
|
||||
WARN("Fatal error: ncclDevFuncId: %lu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d", key, coll, algo, proto, devRedOp, type);
|
||||
WARN("Fatal error: ncclDevFuncId: %lu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d, acc: %d, pipeline: %d", key, coll, algo, proto, devRedOp, type, acc, pipeline);
|
||||
return -1;
|
||||
}
|
||||
return row;
|
||||
|
||||
@@ -59,12 +59,11 @@ typedef enum {
|
||||
ncclFuncAllGather = 2,
|
||||
ncclFuncReduceScatter = 3,
|
||||
ncclFuncAllReduce = 4,
|
||||
ncclFuncAllReduceWithBias = 5,
|
||||
ncclFuncSendRecv = 6,
|
||||
ncclFuncSend = 7,
|
||||
ncclFuncRecv = 8,
|
||||
ncclFuncAllToAllPivot = 9,
|
||||
ncclNumFuncs = 10
|
||||
ncclFuncSendRecv = 5,
|
||||
ncclFuncSend = 6,
|
||||
ncclFuncRecv = 7,
|
||||
ncclFuncAllToAllPivot = 8,
|
||||
ncclNumFuncs = 9
|
||||
} ncclFunc_t;
|
||||
|
||||
#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*/PAT
|
||||
|
||||
Ссылка в новой задаче
Block a user