diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index c9d6d9b266..88f3456caf 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -52,16 +52,16 @@ else: # make ONLY_FUNCS="AllReduce * * Sum i32" # # # Only AllReduce RING Max float (but all protos) -# make ONLY_FUNCS="AllReduce RING * Max float" +# make ONLY_FUNCS="AllReduce RING * Max f32" # # # AllReduce TREE LL128 Prod rccl_bfloat16 -# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" +# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16" # # # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) -# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32" # --- or --- -# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" -# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t" +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32" +# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|AllGather RING LL/SIMPLE Sum i8|AllToAllPivot RING SIMPLE Sum i8|Broadcast RING LL/SIMPLE Sum i8|Reduce RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|ReduceScatter RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2|SendRecv RING SIMPLE Sum i8" # Paste all non-None arguments together with `sep`. def paste(sep, *args): @@ -134,14 +134,14 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower} ################################################################################ def calc_unroll_for_local_arch(): - if not is_local_arch_only: - return + if not is_local_arch_only: + return all_unroll rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo" res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True) rocminfo_output = res.stdout - + # Parse rocminfo binary output gfx_targets = {} curr_name = None @@ -156,26 +156,28 @@ def calc_unroll_for_local_arch(): cu_count = int(line.split(':')[-1].strip()) gfx_targets[(curr_name, cu_count)] = None curr_name = None - + # We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts # Use (gfx_name, cu_count) as key for dictionary and convert it to list here gfx_targets = list(gfx_targets.keys()) - # Homogeneous system is required to build for only 1 varient of unroll factor + # Homogeneous system is required to build for only 1 variant of unroll factor (except for gfx950) if len(gfx_targets) == 1: gfx_name, cu_count = gfx_targets[0] if "gfx950" == gfx_name: - return 1 + return ["1", "2"] elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80): - return 2 + return ["2"] else: - return 4 + return ["4"] + else: + return all_unroll # Helper function to check if the conditions for the collective is being met -def func_validate(coll, algo, proto, redop, ty): +def func_validate(coll, algo, proto, redop, ty, unroll): if redop == "SumPostDiv" and ty[0] not in ("i","u"): return False - if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]: + if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll: return False return True @@ -191,9 +193,6 @@ def func_filter(function_params, current_idx, item_list=None): # If the paramter is equal to '*', include all possible cases for it if current_element == "*": - if current_idx == 0: - raise ValueError("Error: Paramter 'COLL' can not be type all '*'.") - # all_params list must be in the same order as function_params --> # Get the current list from all_params current_list = all_params[current_idx] @@ -225,7 +224,7 @@ def func_filter(function_params, current_idx, item_list=None): else: coll, algo, proto, redop, ty, unroll = item_list - if func_validate(coll, algo, proto, redop, ty): + if func_validate(coll, algo, proto, redop, ty, unroll): yield(coll, algo, proto, redop, ty, unroll) # Parse ONLY_FUNCS input and feed it to func_filter @@ -247,10 +246,6 @@ def parse_input(func_pattern): # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). def equivalent_primary(coll, algo, proto, redop, ty, unroll): - # if local arch only, we only need to build for 1 varient of coll_unroll. - # map the other varient of coll_unroll to this one. - if coll_unroll: - unroll = str(coll_unroll) if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": @@ -268,7 +263,7 @@ def enumerate_func_rows(): for proto in all_protos: for redop in all_redops: for ty in all_tys: - if func_validate(coll, algo, proto, redop, ty): + if func_validate(coll, algo, proto, redop, ty, unroll): yield (coll, algo, proto, redop, ty, unroll) # Sort the hashmap based on custom key @@ -286,7 +281,9 @@ def custom_sort_key(fn): ################################################################################ -coll_unroll = calc_unroll_for_local_arch() +# if building for local arch only, we only need to build for 1 variant of unroll for most gfx targets, +# except for gfx950 +all_unroll = calc_unroll_for_local_arch() # Corresponds to ncclDevFuncRowToId[] func_rows = [fn for fn in enumerate_func_rows()] @@ -459,7 +456,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: # The mapping from function rows to valid primary function ids. out("extern int const ncclDevFuncRowToId[] = {\n") index = 0 - for fn in func_rows[:len(func_rows)//3]: + for fn in func_rows[:len(func_rows)//len(all_unroll)]: fn_id, comment = -1, "" if fn is not None: fn_id = primary_to_index[equivalent_primary(*fn)] diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 392d6b4c7c..e4c1c5d9f7 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -612,8 +612,7 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in // RCCL: create persistent stream for calloc CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking)); - // RCCL: determine and set unroll factor for comm - NCCLCHECK(commSetUnrollFactor(comm)); + comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false; comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false; @@ -1965,6 +1964,9 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail); + // RCCL: determine and set unroll factor for comm + NCCLCHECK(commSetUnrollFactor(comm)); + #ifdef ENABLE_MSCCLPP if (job->parent) { if (job->parent->mscclppCompatible) { diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index 162df74862..bf258a39ce 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -157,11 +157,17 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { hipDeviceProp_t devProp; CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev)); - if(IsArchMatch(devProp.gcnArchName, "gfx950")) - comm->unroll = NCCL_UNROLL_1; + if(IsArchMatch(devProp.gcnArchName, "gfx950")) { + if(comm->nNodes == 1) + comm->unroll = NCCL_UNROLL_1; + else + comm->unroll = NCCL_UNROLL_2; + } else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))) comm->unroll = NCCL_UNROLL_2; else comm->unroll = NCCL_UNROLL_4; + + INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll+1); return ncclSuccess; }