[DEVICE] Add unroll=2 for gfx950 multi-node (#1824)
This commit is contained in:
committed by
GitHub
orang tua
874cd657ef
melakukan
bd55f876e9
+23
-26
@@ -52,16 +52,16 @@ else:
|
||||
# make ONLY_FUNCS="AllReduce * * Sum i32"
|
||||
#
|
||||
# # Only AllReduce RING Max float (but all protos)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max float"
|
||||
# make ONLY_FUNCS="AllReduce RING * Max f32"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|AllGather RING LL/SIMPLE Sum i8|AllToAllPivot RING SIMPLE Sum i8|Broadcast RING LL/SIMPLE Sum i8|Reduce RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|ReduceScatter RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|SendRecv RING SIMPLE Sum i8"
|
||||
|
||||
# Paste all non-None arguments together with `sep`.
|
||||
def paste(sep, *args):
|
||||
@@ -134,14 +134,14 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
||||
################################################################################
|
||||
|
||||
def calc_unroll_for_local_arch():
|
||||
if not is_local_arch_only:
|
||||
return
|
||||
if not is_local_arch_only:
|
||||
return all_unroll
|
||||
|
||||
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
|
||||
|
||||
res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
rocminfo_output = res.stdout
|
||||
|
||||
|
||||
# Parse rocminfo binary output
|
||||
gfx_targets = {}
|
||||
curr_name = None
|
||||
@@ -156,26 +156,28 @@ def calc_unroll_for_local_arch():
|
||||
cu_count = int(line.split(':')[-1].strip())
|
||||
gfx_targets[(curr_name, cu_count)] = None
|
||||
curr_name = None
|
||||
|
||||
|
||||
# We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts
|
||||
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
|
||||
gfx_targets = list(gfx_targets.keys())
|
||||
|
||||
# Homogeneous system is required to build for only 1 varient of unroll factor
|
||||
# Homogeneous system is required to build for only 1 variant of unroll factor (except for gfx950)
|
||||
if len(gfx_targets) == 1:
|
||||
gfx_name, cu_count = gfx_targets[0]
|
||||
if "gfx950" == gfx_name:
|
||||
return 1
|
||||
return ["1", "2"]
|
||||
elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
|
||||
return 2
|
||||
return ["2"]
|
||||
else:
|
||||
return 4
|
||||
return ["4"]
|
||||
else:
|
||||
return all_unroll
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty):
|
||||
def func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -191,9 +193,6 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
|
||||
# If the paramter is equal to '*', include all possible cases for it
|
||||
if current_element == "*":
|
||||
if current_idx == 0:
|
||||
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
|
||||
# Get the current list from all_params
|
||||
current_list = all_params[current_idx]
|
||||
@@ -225,7 +224,7 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
else:
|
||||
coll, algo, proto, redop, ty, unroll = item_list
|
||||
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
yield(coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
@@ -247,10 +246,6 @@ def parse_input(func_pattern):
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
||||
# if local arch only, we only need to build for 1 varient of coll_unroll.
|
||||
# map the other varient of coll_unroll to this one.
|
||||
if coll_unroll:
|
||||
unroll = str(coll_unroll)
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
@@ -268,7 +263,7 @@ def enumerate_func_rows():
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
yield (coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
@@ -286,7 +281,9 @@ def custom_sort_key(fn):
|
||||
|
||||
################################################################################
|
||||
|
||||
coll_unroll = calc_unroll_for_local_arch()
|
||||
# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets,
|
||||
# except for gfx950
|
||||
all_unroll = calc_unroll_for_local_arch()
|
||||
|
||||
# Corresponds to ncclDevFuncRowToId[]
|
||||
func_rows = [fn for fn in enumerate_func_rows()]
|
||||
@@ -459,7 +456,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# The mapping from function rows to valid primary function ids.
|
||||
out("extern int const ncclDevFuncRowToId[] = {\n")
|
||||
index = 0
|
||||
for fn in func_rows[:len(func_rows)//3]:
|
||||
for fn in func_rows[:len(func_rows)//len(all_unroll)]:
|
||||
fn_id, comment = -1, ""
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
|
||||
+4
-2
@@ -612,8 +612,7 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
|
||||
// RCCL: create persistent stream for calloc
|
||||
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
|
||||
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
||||
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
|
||||
|
||||
@@ -1965,6 +1964,9 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
|
||||
NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail);
|
||||
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (job->parent) {
|
||||
if (job->parent->mscclppCompatible) {
|
||||
|
||||
+8
-2
@@ -157,11 +157,17 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
if(IsArchMatch(devProp.gcnArchName, "gfx950"))
|
||||
comm->unroll = NCCL_UNROLL_1;
|
||||
if(IsArchMatch(devProp.gcnArchName, "gfx950")) {
|
||||
if(comm->nNodes == 1)
|
||||
comm->unroll = NCCL_UNROLL_1;
|
||||
else
|
||||
comm->unroll = NCCL_UNROLL_2;
|
||||
}
|
||||
else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)))
|
||||
comm->unroll = NCCL_UNROLL_2;
|
||||
else
|
||||
comm->unroll = NCCL_UNROLL_4;
|
||||
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll+1);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user