[DEVICE] Add unroll=2 for gfx950 multi-node (#1824)

[ROCm/rccl commit: bd55f876e9]
此提交包含在:
Nilesh M Negi
2025-07-31 02:35:26 -05:00
提交者 GitHub
父節點 39c508b80d
當前提交 be810f10f3
共有 3 個檔案被更改,包括 35 行新增和 30 行删除
+23 -26
查看文件
@@ -52,16 +52,16 @@ else:
# make ONLY_FUNCS="AllReduce * * Sum i32"
#
# # Only AllReduce RING Max float (but all protos)
# make ONLY_FUNCS="AllReduce RING * Max float"
# make ONLY_FUNCS="AllReduce RING * Max f32"
#
# # AllReduce TREE LL128 Prod rccl_bfloat16
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|AllGather RING LL/SIMPLE Sum i8|AllToAllPivot RING SIMPLE Sum i8|Broadcast RING LL/SIMPLE Sum i8|Reduce RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|ReduceScatter RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|SendRecv RING SIMPLE Sum i8"
# Paste all non-None arguments together with `sep`.
def paste(sep, *args):
@@ -134,14 +134,14 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
################################################################################
def calc_unroll_for_local_arch():
if not is_local_arch_only:
return
if not is_local_arch_only:
return all_unroll
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
rocminfo_output = res.stdout
# Parse rocminfo binary output
gfx_targets = {}
curr_name = None
@@ -156,26 +156,28 @@ def calc_unroll_for_local_arch():
cu_count = int(line.split(':')[-1].strip())
gfx_targets[(curr_name, cu_count)] = None
curr_name = None
# We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
gfx_targets = list(gfx_targets.keys())
# Homogeneous system is required to build for only 1 varient of unroll factor
# Homogeneous system is required to build for only 1 variant of unroll factor (except for gfx950)
if len(gfx_targets) == 1:
gfx_name, cu_count = gfx_targets[0]
if "gfx950" == gfx_name:
return 1
return ["1", "2"]
elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
return 2
return ["2"]
else:
return 4
return ["4"]
else:
return all_unroll
# Helper function to check if the conditions for the collective is being met
def func_validate(coll, algo, proto, redop, ty, unroll):
  """Return True when the given function parameters form a combination to keep.

  A combination is rejected when:
    * redop is "SumPostDiv" and the type name does not start with "i" or "u";
    * algo, proto, redop or ty is missing from the per-collective tables
      (algos_of_coll, protos_of_coll, redops_of_coll, tys_of_coll) for `coll`;
    * unroll is not one of the entries of the module-level all_unroll.
  """
  # SumPostDiv is restricted to types whose name begins with "i" or "u".
  if redop == "SumPostDiv" and ty[0] not in ("i", "u"):
    return False
  # Every remaining parameter has to be listed for this collective, and the
  # unroll value has to be one that is being generated.
  return (algo in algos_of_coll[coll]
          and proto in protos_of_coll[coll]
          and redop in redops_of_coll[coll]
          and ty in tys_of_coll[coll]
          and unroll in all_unroll)
@@ -191,9 +193,6 @@ def func_filter(function_params, current_idx, item_list=None):
# If the parameter is equal to '*', include all possible cases for it
if current_element == "*":
if current_idx == 0:
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
# Get the current list from all_params
current_list = all_params[current_idx]
@@ -225,7 +224,7 @@ def func_filter(function_params, current_idx, item_list=None):
else:
coll, algo, proto, redop, ty, unroll = item_list
if func_validate(coll, algo, proto, redop, ty):
if func_validate(coll, algo, proto, redop, ty, unroll):
yield(coll, algo, proto, redop, ty, unroll)
# Parse ONLY_FUNCS input and feed it to func_filter
@@ -247,10 +246,6 @@ def parse_input(func_pattern):
# Maps functions to the chosen representative for the equivalence class it
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
# if local arch only, we only need to build for 1 varient of coll_unroll.
# map the other varient of coll_unroll to this one.
if coll_unroll:
unroll = str(coll_unroll)
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
# map signed integer sum/prod to unsigned
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
@@ -268,7 +263,7 @@ def enumerate_func_rows():
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty):
if func_validate(coll, algo, proto, redop, ty, unroll):
yield (coll, algo, proto, redop, ty, unroll)
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
@@ -286,7 +281,9 @@ def custom_sort_key(fn):
################################################################################
coll_unroll = calc_unroll_for_local_arch()
# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets,
# except for gfx950
all_unroll = calc_unroll_for_local_arch()
# Corresponds to ncclDevFuncRowToId[]
func_rows = [fn for fn in enumerate_func_rows()]
@@ -459,7 +456,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
# The mapping from function rows to valid primary function ids.
out("extern int const ncclDevFuncRowToId[] = {\n")
index = 0
for fn in func_rows[:len(func_rows)//3]:
for fn in func_rows[:len(func_rows)//len(all_unroll)]:
fn_id, comment = -1, ""
if fn is not None:
fn_id = primary_to_index[equivalent_primary(*fn)]
+4 -2
查看文件
@@ -612,8 +612,7 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
// RCCL: create persistent stream for calloc
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
@@ -1965,6 +1964,9 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail);
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
#ifdef ENABLE_MSCCLPP
if (job->parent) {
if (job->parent->mscclppCompatible) {
+8 -2
查看文件
@@ -157,11 +157,17 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
// Chooses the unroll setting for `comm` and stores it in comm->unroll.
//
// The decision uses the properties hipGetDeviceProperties() reports for
// comm->cudaDev:
//   - arch matches "gfx950": NCCL_UNROLL_1 when comm->nNodes == 1,
//                            NCCL_UNROLL_2 otherwise
//   - arch matches "gfx908", or matches "gfx942" with
//     multiProcessorCount > 80: NCCL_UNROLL_2
//   - anything else:         NCCL_UNROLL_4
//
// Reads comm->nNodes, so the caller must invoke it only after that field has
// been set.
//
// Returns ncclSuccess once comm->unroll is assigned. The device-property query
// is wrapped in CUDACHECK (defined elsewhere; presumably returns early on a
// HIP error -- confirm against the macro definition).
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
  hipDeviceProp_t devProp;
  CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));

  if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
    // gfx950 is the only target whose unroll depends on the node count.
    if (comm->nNodes == 1) {
      comm->unroll = NCCL_UNROLL_1;
    } else {
      comm->unroll = NCCL_UNROLL_2;
    }
  } else if (IsArchMatch(devProp.gcnArchName, "gfx908") ||
             (IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)) {
    comm->unroll = NCCL_UNROLL_2;
  } else {
    comm->unroll = NCCL_UNROLL_4;
  }

  // NOTE(review): the logged value is comm->unroll+1; the numeric values of the
  // NCCL_UNROLL_* constants are not visible here -- confirm this prints the
  // intended factor for every setting.
  INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll+1);
  return ncclSuccess;
}