From 0aa56fb0a5f402f419266a0dc250e3f40a3fef27 Mon Sep 17 00:00:00 2001 From: Nilesh M Negi Date: Fri, 17 Oct 2025 09:28:57 -0500 Subject: [PATCH] Fix ncclDevFuncId for AllReduceWithBias (#1980) [ROCm/rccl commit: c35bc721adbb4fe500b9e58f6ff59c5afb8c81ac] --- projects/rccl/src/device/generate.py | 66 ++++++++++++++----------- projects/rccl/src/enqueue.cc | 8 +-- projects/rccl/src/include/device.h | 10 ++-- projects/rccl/src/include/nccl_common.h | 11 ++--- 4 files changed, 49 insertions(+), 46 deletions(-) diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index ce3e679606..c4e515824f 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -4,7 +4,7 @@ import sys import subprocess # Order of colls, redops, tys, protos, algos must match src/include/device.h -all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "AllReduceWithBias", "SendRecv", "", "", "AllToAllPivot"] +all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"] all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] all_protos = ["LL","LL128","SIMPLE"] @@ -82,14 +82,13 @@ func_pattern = sys.argv[6:7] if func_pattern and func_pattern[0]: func_pattern = func_pattern[0] else: - func_pattern = "AllGather|AllReduce|AllReduceWithBias|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" + func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" ################################################################################ algos_of_coll = { "AllGather": ["RING", "PAT"], "AllReduce": ["RING", "TREE"], - "AllReduceWithBias": ["RING", "TREE"], "AllToAllPivot": ["RING"], "Broadcast": ["RING"], "Reduce": ["RING"], @@ -100,7 +99,6 @@ algos_of_coll = { protos_of_coll = { "AllGather": all_protos, "AllReduce": all_protos, - 
"AllReduceWithBias": all_protos, "AllToAllPivot": ["SIMPLE"], "Broadcast": all_protos, "Reduce": all_protos, @@ -111,7 +109,6 @@ protos_of_coll = { redops_of_coll = { "AllGather": ["Sum"], "AllReduce": all_redops, - "AllReduceWithBias": all_redops, "AllToAllPivot": ["Sum"], "Broadcast": ["Sum"], "Reduce": all_redops, @@ -122,7 +119,6 @@ redops_of_coll = { tys_of_coll = { "AllGather": ["i8"], "AllReduce": all_tys, - "AllReduceWithBias": all_tys, "AllToAllPivot": ["i8"], "Broadcast": ["i8"], "Reduce": all_tys, @@ -130,10 +126,19 @@ tys_of_coll = { "SendRecv": ["i8"] } +acc_of_coll = { + "AllGather": ["0"], + "AllReduce": use_acc, + "AllToAllPivot": ["0"], + "Broadcast": ["0"], + "Reduce": ["0"], + "ReduceScatter": ["0"], + "SendRecv": ["0"] +} + pipelines_of_coll = { "AllGather": ["0"], "AllReduce": all_pipeline, - "AllReduceWithBias": ["0"], "AllToAllPivot": ["0"], "Broadcast": ["0"], "Reduce": all_pipeline, @@ -144,7 +149,6 @@ pipelines_of_coll = { coll_camel_to_lower = { "AllGather": "all_gather", "AllReduce": "all_reduce", - "AllReduceWithBias": "allreduce_with_bias", "AllToAllPivot": "alltoall_pivot", "Broadcast": "broadcast", "Reduce": "reduce", @@ -197,15 +201,17 @@ def calc_unroll_for_local_arch(): # Helper function to check if the conditions for the collective is being met def func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll): - if acc == "1" and coll != "AllReduceWithBias": - return False - if acc == "0" and coll == "AllReduceWithBias": - return False if redop == "SumPostDiv" and ty[0] not in ("i","u"): return False if coll == "" or algo == "": return False - if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or acc not in use_acc or unroll not in all_unroll or pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types): + if (algo not in algos_of_coll[coll] or + proto not in protos_of_coll[coll] or + redop not in 
redops_of_coll[coll] or + ty not in tys_of_coll[coll] or + acc not in acc_of_coll[coll] or + pipeline not in pipelines_of_coll[coll] or (pipeline in ["1"] and ty not in pipelined_types) or + unroll not in all_unroll): return False return True @@ -274,7 +280,7 @@ def parse_input(func_pattern): # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll): - if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"): + if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": ty = "u"+ty[1:] @@ -291,26 +297,27 @@ def equivalent_primary(coll, algo, proto, redop, ty, acc, pipeline, unroll): # outermost loop should be for unroll factor; refer to host_table section def enumerate_func_rows(): for unroll in all_unroll: - for acc in use_acc: - for coll in all_colls: - for algo in all_algos: - for proto in all_protos: - for redop in all_redops: - for ty in all_tys: + for coll in all_colls: + for algo in all_algos: + for proto in all_protos: + for redop in all_redops: + for ty in all_tys: + for acc in use_acc: for pipeline in all_pipeline: if func_validate(coll, algo, proto, redop, ty, acc, pipeline, unroll): yield (coll, algo, proto, redop, ty, acc, pipeline, unroll) + # Sort the hashmap based on custom key def custom_sort_key(fn): coll, algo, proto, redop, ty, acc, pipeline, unroll = fn return ( all_unroll.index(unroll), - use_acc.index(acc), all_colls.index(coll), all_algos.index(algo), all_protos.index(proto), all_redops.index(redop), all_tys.index(ty), + use_acc.index(acc), all_pipeline.index(pipeline) ) @@ -488,15 +495,15 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: out('#include "device.h"\n') out("\n") out("// The key for the ncclDevFuncNameToId map is a 64-bit unsigned 
integer.\n") - out("// Each field (coll, algo, proto, redop, ty, pipeline) is packed into 4 bits,\n") - out("// Each field (coll, algo, proto, redop, ty) is packed into 4 bits,\n") + out("// Each field (coll, algo, proto, redop, ty, acc, pipeline) is packed into 4 bits,\n") out("// This allows up to 16 unique values per field. The layout is:\n") out("// bits 0-3: coll index\n") out("// bits 4-7: algo index\n") out("// bits 8-11: proto index\n") out("// bits 12-15: redop index\n") out("// bits 16-19: ty index\n") - out("// bits 20-23: pipeline index\n") + out("// bits 20-23: accumulator index\n") + out("// bits 24-27: pipeline index\n") out("#include <unordered_map>\n") out("std::unordered_map<uint64_t, int> ncclDevFuncNameToId = {\n") @@ -515,6 +522,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: proto_idx = all_protos.index(fn[2]) redop_idx = all_redops.index(fn[3]) ty_idx = all_tys.index(fn[4]) + acc_idx = use_acc.index(fn[5]) pipeline_idx = all_pipeline.index(fn[6]) # Assert that 4 bits (16 values) is enough to map all_colls, all_algos, etc. assert len(all_colls) <= 16, "Error: all_colls has more than 16 values, which exceeds 4-bit capacity." @@ -522,6 +530,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: assert len(all_protos) <= 16, "Error: all_protos has more than 16 values, which exceeds 4-bit capacity." assert len(all_redops) <= 16, "Error: all_redops has more than 16 values, which exceeds 4-bit capacity." assert len(all_tys) <= 16, "Error: all_tys has more than 16 values, which exceeds 4-bit capacity." + assert len(use_acc) <= 16, "Error: use_acc has more than 16 values, which exceeds 4-bit capacity." assert len(all_pipeline) <= 16, "Error: all_pipeline has more than 16 values, which exceeds 4-bit capacity." 
# Create a 64-bit unsigned integer key and pack the indices into 4 bits each key = ( @@ -530,9 +539,10 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: | ((proto_idx & 0xF) << 8) | ((redop_idx & 0xF) << 12) | ((ty_idx & 0xF) << 16) - | ((pipeline_idx & 0xF) << 20) + | ((acc_idx & 0xF) << 20) + | ((pipeline_idx & 0xF) << 24) ) - fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {pipeline_idx}" + fn_str = f"{coll_idx} {algo_idx} {proto_idx} {redop_idx} {ty_idx} {acc_idx} {pipeline_idx}" if fn[0] == "Broadcast": key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8)) if fn[0] in ["SendRecv", "AllToAllPivot"]: @@ -592,8 +602,6 @@ for name in name_to_funcs.keys(): print("-- Generating %s" % os.path.join(gensrc, name)) out = f.write - if coll == "AllReduceWithBias": - coll = "AllReduce" out( '#include "common.h"\n' '#include "{lower_coll}.h"\n' @@ -603,8 +611,6 @@ for name in name_to_funcs.keys(): for fn in fns: (coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll) - if coll == "AllReduceWithBias": - coll = "AllReduce" if proto == "LL128": out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") out( diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 6aa589a0b8..575d84680a 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -501,13 +501,9 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo)); if(agg.func==ncclFuncAllReduce && agg.acc != nullptr) - { - agg.devFuncId = ncclDevFuncId(ncclFuncAllReduceWithBias, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline); - } + agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, 1, agg.pipeline); else - { - agg.devFuncId = ncclDevFuncId(agg.func, 
agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline); - } + agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, 0, agg.pipeline); if (agg.devFuncId < 0) { WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__); return ncclInvalidUsage; diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index 9df54adfda..3e0cb12f3c 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -159,7 +159,8 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK #define RCCL_PROTO_SHIFT 8 #define RCCL_REDOP_SHIFT 12 #define RCCL_DTYPE_SHIFT 16 -#define RCCL_PIPELINE_SHIFT 20 +#define RCCL_ACC_SHIFT 20 +#define RCCL_PIPELINE_SHIFT 24 struct ncclConnInfo { // Regular comm mechanism @@ -717,11 +718,11 @@ inline bool ncclNvlsSupported(int devRedOp, int type) { extern std::unordered_map<uint64_t, int> ncclDevFuncNameToId; // `ncclDevFuncId()` needs to be in sync with 'all_colls' in generate.py -inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int pipeline = 0) { +inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, int acc = 0, int pipeline = 0) { int row = -1; uint64_t key; // Pack 4-bit fields from right (LSB) to left in order: - // coll, algo, proto, devRedOp, type + // coll, algo, proto, devRedOp, type, acc, pipeline // This logic must be in sync with the key generation logic in generate.py if (coll == ncclFuncBroadcast) { key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) | @@ -734,6 +735,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, ((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT) | ((uint64_t)(devRedOp & RCCL_FUNC_ID_MASK) << RCCL_REDOP_SHIFT) | ((uint64_t)(type & RCCL_FUNC_ID_MASK) << RCCL_DTYPE_SHIFT) | + ((uint64_t)(acc & RCCL_FUNC_ID_MASK) << RCCL_ACC_SHIFT) | 
((uint64_t)(pipeline & RCCL_FUNC_ID_MASK) << RCCL_PIPELINE_SHIFT); } auto it = ncclDevFuncNameToId.find(key); @@ -741,7 +743,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, row = it->second; } if(row < 0) { - WARN("Fatal error: ncclDevFuncId: %lu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d", key, coll, algo, proto, devRedOp, type); + WARN("Fatal error: ncclDevFuncId: %lu not found for coll: %d, algo: %d, proto: %d, devRedOp: %d, type: %d, acc: %d, pipeline: %d", key, coll, algo, proto, devRedOp, type, acc, pipeline); return -1; } return row; diff --git a/projects/rccl/src/include/nccl_common.h b/projects/rccl/src/include/nccl_common.h index 9738e9e60a..7ddf367417 100644 --- a/projects/rccl/src/include/nccl_common.h +++ b/projects/rccl/src/include/nccl_common.h @@ -59,12 +59,11 @@ typedef enum { ncclFuncAllGather = 2, ncclFuncReduceScatter = 3, ncclFuncAllReduce = 4, - ncclFuncAllReduceWithBias = 5, - ncclFuncSendRecv = 6, - ncclFuncSend = 7, - ncclFuncRecv = 8, - ncclFuncAllToAllPivot = 9, - ncclNumFuncs = 10 + ncclFuncSendRecv = 5, + ncclFuncSend = 6, + ncclFuncRecv = 7, + ncclFuncAllToAllPivot = 8, + ncclNumFuncs = 9 } ncclFunc_t; #define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*/PAT