72d2432094
Symmetric memory API and symmetric kernels
* Redesign from the ground up, enabling major latency and bandwidth
  improvements.
* Add new API calls to register user-allocated memory among communicator
  ranks into a NCCL window: ncclCommWindowRegister() and
  ncclCommWindowDeregister(). The calls currently support symmetric
  registration for P2P and NVLS, and require VMM memory buffers (i.e.,
  CUMEM must be operational).
* Implement specialized kernels taking advantage of symmetrically
  registered memory, with performance gains expected particularly for
  small to medium message sizes.
* The kernels support 32-bit floating point types and smaller, and sum as
  the reduction operator, with no more than one collective operation per
  group.
* Floating point summation is always done in fp32 accumulators (with the
  exception of fp8 on NVLS, where it uses fp16 inside the switch). Thus,
  the accuracy with fp8 and fp16 data types should be much improved.
* This initial implementation supports non-network communicators only
  (P2P and NVLS transports).
* To explore this functionality, users need to use the new memory
  registration API calls with the NCCL_WIN_COLL_SYMMETRIC flag, and all
  ranks of a communicator must pass buffers at the same offset in the
  same registration when invoking a collective NCCL operation.

Add support for DGX Spark.

Add support for DirectNIC (CX8) to the internal IB plugin.

Add a new ncclCommShrink() API call
* It is a non-collective call similar to ncclCommSplit(), which makes it
  possible to exclude some (possibly unresponsive) ranks from the parent
  communicator.

Add support for loading multiple network plugins
* This enables the creation of generic containers that can work across a
  range of providers.
* Allow NCCL_NET_PLUGIN to accept a comma-separated list of plugins to
  load.

NVLink SHARP (NVLS) improvements
* Implement NVLS+IB SHARP support for AllGather and ReduceScatter with
  user buffer registration. This improves performance and reduces the
  number of CTAs needed to achieve peak bandwidth.
* Gracefully fall back by default to other transports if NVLS
  initialization fails (the old behavior of returning an error code from
  a NCCL call can be preserved by setting NCCL_NVLS_ENABLE=1).
* Decrease the NVLS channel count to 24 on Blackwell systems with
  multiple NVLink domains per communicator.
* Enable fine-tuning of NCCL behavior per communicator using new
  "ncclConfig_t" members "collnetEnable", "CTAPolicy", and "nvlsCTAs".

Profiler improvements
* Extend the init function by adding communicator name, comm id (hash),
  rank, number of ranks, number of nodes, and the NCCL log function to
  the argument list. This makes the name and the comm id available to
  all events in the communicator without explicitly passing them to each
  individual event. Add the communicator id and rank to the profiler
  trace filename. Now, the communicator name can be set via a new
  "ncclConfig_t" member "commName".
* Improve the accuracy of the GPU kernel events by providing
  GPU-generated timestamps for the start and stop of every NCCL
  operation.
* Harmonize proxy events, removing overlaps between ProxyOp and
  ProxyStep states.
* Add support for network-defined event updates (through
  "recordEventState").
* Report the correct number of channels used by every collective/p2p
  operation (used to be set to nMaxChannels for collectives and absent
  for p2ps).
* Fix the logic on proxyCtrl Idle/Active events (Issue #1162).
* Fix an issue where the network proxy profiler could lose track of an
  event identifier (Issue #1682).
* Improve the backward compatibility with plugins older than v4.
* Ensure that the work counters are 0-initialized.
* Fix a potential race condition in the network profiler that could
  result in an event being linked to a wrong parent.

MNNVL improvements
* Increase to 16 the number of NICs used to communicate between MNNVL
  domains on GB200 systems, to optimize the performance of collective
  operations.
* Add support for more complex MNNVL topologies with up to 32 NICs per
  node.
* If the MNNVL fabric initialization was unsuccessful, NCCL will now
  fail by default, so as to avoid inadvertently falling back to a
  potentially much slower network transport. Such failures are typically
  due to a misconfigured IMEX support on the system. To continue without
  MNNVL, restart the job with NCCL_MNNVL_ENABLE=0.
* Fix a potential hang in alltoall-like communication patterns at a
  scale of over 80 ranks.
* Make NCCL_P2P_DISABLE=1 imply NCCL_MNNVL_ENABLE=0 (so the latter no
  longer needs to be specified on MNNVL systems).
* Fix an initialization failure when NCCL_TOPO_FILE is used on MNNVL
  systems.
* Fix the graph search to exclude non-local NICs.
* Fix the SHM transport to use fabric handles on MNNVL systems.

NIC Fusion improvements
* Disable the creation of fused NICs for physical devices that haven't
  been merged.
* Flatten multiple ports to a single PCI device within the internal IB
  plugin and reparent dual-port NICs under the first PCI parent. If the
  parent is not a PCI switch, PCI devices for fused NICs won't be
  duplicated.
* Route traffic on GB200-CX8 systems through DirectNIC, not the host
  interface.

Improve support for platforms with C2C connectivity (e.g., GB200)
* Enable GPUDirect RDMA for the NICs by default.
* Add support for P2C (PXN over C2C) and the LL128 protocol.

Extend NCCL fault tolerance in multithreaded scenarios
* Support the creation of multiple nonblocking communicators within a
  single group and polling in parallel for the completion using multiple
  threads (one per communicator).

Enable ncclImplicitOrderLaunch for CUDA 12.9+
* This can potentially speed up NCCL_IMPLICIT_LAUNCH_ORDER.

Improve the netSocket transport latency and control
* Provide finer control over the size of the socket send/receive
  buffers, the task size, and the number of sockets that a single peer
  can open.
* Add support for the inlining of small messages behind the header when
  using multiple sockets per connection.

Improve the readability of the CPU affinity in the debug output
* Print it as a range string rather than a bitmask.

Fix a potential race condition in graph execution
* A contention could arise when mixing graph and non-graph execution.

Improve PXN connection code
* Avoid duplicate and unused connections.

RAS fixes
* Fix a memory corruption at job termination time in case of a
  previously failed initialization of a RAS socket connection.
* Fix a race condition leading to a crash when generating a RAS report
  during communicator initialization (Issues #1669, #1718).
* Fix a potential race condition when gathering data for a RAS status
  report.

Fix a potential memory corruption in ncclCommSplit()
* Memory could get corrupted when resource sharing was in use and the
  size of the NVLink domain in the new communicator was smaller than in
  the old one.

Fix asynchronous graph upload
* Fix a small memory leak.
* Fix oversynchronization.

Add a check for out-of-memory conditions in ncclMemAlloc()

Clean up the NCCL socket code
* accept() will retry also if just reading the magic failed (Issue
  #1613).
* connect() will retry also if poll() did not return a POLLOUT event
  (Issue #1618).
* Add error checking in a few instances (Issue #1539).
* Fix the loop condition in ncclFindInterfaceMatchSubnet() (Issue
  #1574).
* Clean up the debug output, downgrading WARN messages to INFO in
  non-critical cases, and printing the peer's address where relevant.

Switch NCCL_DEBUG_FILE to line buffering
* This should help avoid mixed-up partial output lines in multithreaded
  cases.

Other minor fixes
* Improve the checks for buffer overflows in the graph code (Issue
  #1585).
* Extend logging and state clearing to all four events in the internal
  IB plugin (Issue #1650).
* Fix the error path in case IB communication is not ready (Issue
  #1489).
* Add ECE logging for IB fabric.
* Fix various minor issues in the graph module (Issue #1635).
* Clean up the debug output in the graph code, downgrading WARN messages
  to INFO in non-critical cases.
* Add a missing argument to a directSend() call (Issue #1628).
* Remove duplicate code in sendProxySetup() (Issue #1420).
* Fix the order of arguments of cudaDeviceCanAccessPeer() (Issue #1507).
* Fix compiler warnings with GCC 14.
* Fix a typo in a comment (Issue #1236).
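
The CPU affinity item above (printing a range string rather than a bitmask) can be pictured with a short Python sketch. The function name `mask_to_ranges` and the exact output format are assumptions made for illustration only; NCCL's own implementation is not in Python and its formatting may differ.

```python
def mask_to_ranges(mask: int) -> str:
    """Render the set bits of `mask` as a comma-separated range string.

    Hypothetical helper: bit i set means CPU i is in the affinity set.
    """
    ranges = []
    cpu = 0
    while mask >> cpu:  # stop once no higher bits remain
        if (mask >> cpu) & 1:
            start = cpu
            # Extend the run while the next CPU is also in the set.
            while (mask >> (cpu + 1)) & 1:
                cpu += 1
            ranges.append(str(start) if start == cpu else f"{start}-{cpu}")
        cpu += 1
    return ",".join(ranges)


print(mask_to_ranges(0x10F))  # CPUs 0, 1, 2, 3 and 8
```

A range string stays short and readable on machines with hundreds of cores, where a raw bitmask becomes a long hexadecimal value.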
424 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
import os
import sys

# Order of redops, tys, protos, algos must match src/include/device.h
all_colls = ["Broadcast","Reduce","AllGather","ReduceScatter","AllReduce","SendRecv"]
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
all_protos = ["LL","LL128","SIMPLE"]
all_algos = ["TREE","RING","COLLNET_DIRECT","COLLNET_CHAIN","NVLS","NVLS_TREE","PAT"]

################################################################################
# The first command line argument is the path to the directory to generate and
# populate.

gensrc = sys.argv[1]

if os.path.exists(gensrc):
  for name in os.listdir(gensrc):
    os.remove(os.path.join(gensrc, name))
    #os.truncate(os.path.join(gensrc, name), 0)
else:
  os.mkdir(gensrc)

################################################################################
# The second command line argument is used as a regex to filter the functions
# which make it into libnccl. This is helpful for reducing the binary size when
# developing device code. The regex supports non-space containing globs '*',
# parentheses '(x)', and union 'a|b'. The string representing the function has
# one of the forms:
#
# SendRecv
# (AllGather|Broadcast) <algo> <proto>
# (AllReduce|Reduce|ReduceScatter) <redop> <type> <algo> <proto>
#
# The possible values for redop, type, algo, proto can be found in the all_<foo>
# lists at the top of this file.
#
# Since the Makefile forwards this from the ONLY_FUNCS variable, useful command
# line examples are given:
"""
# Only send/recv:
make ONLY_FUNCS="SendRecv"

# Only non-reductions:
make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"

# Only AllReduce sum f32 (but all algos, protos)
make ONLY_FUNCS="AllReduce Sum f32 * *"

# Only AllReduce minmax i32 NVLS (but all protos)
make ONLY_FUNCS="AllReduce MinMax i32 NVLS *"

# AllReduce sum <all floats> RING LL128
make ONLY_FUNCS="AllReduce Sum (f*|bf16) RING LL128"
"""

# Paste all non-None arguments together with `sep`.
def paste(sep, *args):
  return sep.join(x for x in args if x is not None)

func_pattern = sys.argv[2:3]
if func_pattern and func_pattern[0]:
  import re
  func_pattern = func_pattern[0]
  func_pattern = func_pattern.replace("*", "[^ ]*")
  func_pattern += "$"
  def func_filter(*fn):
    return None is not re.match(func_pattern, paste(" ", *fn), flags=re.IGNORECASE)
else:
  def func_filter(coll, redop, ty, algo, proto):
    return True

################################################################################

algos_of_coll = {
  "AllGather": ["RING","COLLNET_DIRECT","NVLS","PAT"],
  "AllReduce": ["TREE","RING","COLLNET_DIRECT","COLLNET_CHAIN","NVLS","NVLS_TREE"],
  "Broadcast": ["RING"],
  "Reduce": ["RING"],
  "ReduceScatter": ["RING","COLLNET_DIRECT","NVLS","PAT"],
  "SendRecv": [None]
}

coll_camel_to_lower = {
  "AllGather": "all_gather",
  "AllReduce": "all_reduce",
  "Broadcast": "broadcast",
  "Reduce": "reduce",
  "ReduceScatter": "reduce_scatter",
  "SendRecv": "sendrecv"
}
coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}

################################################################################

# Returns pair of minimum required values for (CUDART_VERSION, __CUDA_ARCH__)
# or None if function is never supported. Note that (0, 0) encodes universal
# support.
def required_cuda(coll, redop, ty, algo, proto):
  cudart, arch = 0, 0
  # kernels mapped to by coll="Nop" functions have coll="Generic"
  if coll in ("SendRecv", "Generic", "Nop"): return (cudart, arch)

  if proto!="SIMPLE" and algo not in ("RING","TREE"): return None

  if coll in ("AllReduce","Reduce","ReduceScatter"):
    if redop=="SumPostDiv" and ty[0] not in ("i","u"): return None
    if ty=="bf16": cudart = max(cudart, 11000)
    if ty.startswith("f8"):
      cudart = max(cudart, 11080)
      arch = max(arch, 900)

  if "NVLS" in algo:
    if coll in ("AllReduce","Reduce","ReduceScatter"):
      # Must match ncclNvlsSupported() in src/include/device.h
      nvls_ok = ((ty in ("i32","u32","i64","u64") and redop in ("Sum","MinMax")) or
                 (ty in ("f32","f64") and redop=="Sum") or
                 (ty in ("f16","bf16") and redop in ("Sum","MinMax")))
      if not nvls_ok: return None
    cudart = max(cudart, 12010)
    arch = max(arch, 900)

  return (cudart, arch)

# Maps functions to the chosen representative for the equivalence class it
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
def equivalent_primary(coll, redop, ty, algo, proto):
  if coll in ("AllReduce", "Reduce", "ReduceScatter"):
    # map signed integer sum/prod to unsigned
    if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
      return (coll, redop, "u"+ty[1:], algo, proto)
    # map signed integer min/max to unsigned for non-NVLS
    if redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
      return (coll, redop, "u"+ty[1:], algo, proto)
  return (coll, redop, ty, algo, proto)

# Map to another func representing the best kernel to use. Every distinct value
# returned will instantiate a ncclDevKernel specialized to run this func
# without function call overhead.
def best_kernel(coll, redop, ty, algo, proto):
  def best(coll, redop, ty, algo, proto):
    # Modify this logic to control how many kernels are specialized.
    if coll=="Nop": return ("Generic", None, None, None, None)
    if coll=="SendRecv": return ("SendRecv", None, None, None, None)
    if coll in ("AllGather","Broadcast"): return (coll, None, None, "RING", "LL")
    return (coll, "Sum", ty, ("TREE" if algo=="TREE" else "RING"), "LL")
  # Need to ensure kernel is specialized for a primary function
  kfn = equivalent_primary(*best(coll, redop, ty, algo, proto))
  # And isn't filtered out.
  if not func_filter(*kfn): return ("Generic", None, None, None, None)
  return kfn

# Order rows are enumerated must match formula of `ncclDevFuncId()`:
def enumerate_func_rows():
  yield ("SendRecv", None, None, None, None)
  for coll in ("AllGather", "Broadcast"):
    algos = algos_of_coll[coll]
    for algo in algos:
      for proto in all_protos:
        yield (coll, None, None, algo, proto)
  for coll in ("AllReduce", "Reduce", "ReduceScatter"):
    algos = algos_of_coll[coll]
    for redop in all_redops:
      for ty in all_tys:
        for algo in algos:
          for proto in all_protos:
            yield (coll, redop, ty, algo, proto)

################################################################################

def is_built(coll, redop, ty, algo, proto):
  built = required_cuda(coll, redop, ty, algo, proto)
  built = built and func_filter(coll, redop, ty, algo, proto)
  return built

# Returns None if required_cuda(...) is None.
# Returns the coll="Nop" function if developer has filtered it out.
# Otherwise just returns func it was given.
def validate(coll, redop, ty, algo, proto):
  valid = required_cuda(coll, redop, ty, algo, proto)
  built = valid and func_filter(coll, redop, ty, algo, proto)
  if built: return (coll, redop, ty, algo, proto)
  if valid: return ("Nop", None, None, None, None)
  return None

# Corresponds to ncclDevFuncRowToId[]
func_rows = [validate(*fn) for fn in enumerate_func_rows()]

# Corresponds to ncclDevFuncTable[]
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in func_rows if fn is not None))

# primary_to_index[primary_funcs[i]] == i
primary_to_index = {fn: i for (i,fn) in zip(range(len(primary_funcs)), primary_funcs)}

kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs))

################################################################################

# Generate <gensrc>/device_table.cu
with open(os.path.join(gensrc, "device_table.cu"), "w") as f:
  out = f.write
  out('#include "common.h"\n')
  out("\n")

  for fn in primary_funcs:
    sym = paste("_", "ncclDevFunc", *fn)
    cudart, arch = required_cuda(*fn)
    if (cudart, arch) != (0, 0):
      out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
    out("__device__ void %s();\n" % sym)
    if (cudart, arch) != (0, 0):
      out("#endif\n")
  out("\n")

  out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
  index = 0
  for fn in primary_funcs:
    sym = paste("_", "ncclDevFunc", *fn)
    cudart, arch = required_cuda(*fn)
    if (cudart, arch) != (0, 0):
      out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
    out("/*%4d*/ %s,\n" % (index, sym))
    if (cudart, arch) != (0, 0):
      out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
    index += 1
  out("nullptr};\n")
  out("\n")

  out("// Workaround for https://reviews.llvm.org/D55580\n"
      "__device__ void ncclWorkaroundClangD55580() {}\n")

# Generate <gensrc>/host_table.cc
with open(os.path.join(gensrc, "host_table.cc"), "w") as f:
  out = f.write
  out('#include "device.h"\n')
  out("\n")

  out("extern int const ncclDevFuncIdCount = %d;\n" % len(primary_funcs))

  # The mapping from function rows to valid primary function ids.
  out("extern int const ncclDevFuncRowToId[] = {\n")
  index = 0
  for fn in func_rows:
    fn_id, comment = -1, ""
    if fn is not None:
      fn_id = primary_to_index[equivalent_primary(*fn)]
      comment = " // " + paste(" ", *fn)
    out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
    index += 1
  out("-1};\n")
  out("\n")

  # Forward declarations of kernels.
  for kfn in kernel_funcs:
    cudart, _ = required_cuda(*kfn)
    sym = paste("_", "ncclDevKernel", *kfn)
    if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
    # __global__ below gets removed by the host compiler, which results in
    # Coverity diagnosing a specifiers inconsistency.
    out("// coverity[declaration]\n")
    out("__global__ void %s(ncclDevKernelArgs4K const);\n" % sym)
    if cudart != 0: out("#endif\n")
  out("\n")

  # List of all kernel function pointers.
  out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs))
  out("extern void* const ncclDevKernelList[] = {\n")
  index = 0
  for kfn in kernel_funcs:
    cudart, _ = required_cuda(*kfn)
    sym = paste("_", "ncclDevKernel", *kfn)
    if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
    out("/*%4d*/ (void*)%s,\n" % (index, sym))
    if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
    index += 1
  out("nullptr};\n")
  out("\n")

  # Maps primary id to kernel function pointer.
  out("extern void* const ncclDevKernelForFunc[] = {\n")
  index = 0
  for fn in primary_funcs:
    kfn = best_kernel(*fn)
    sym = paste("_", "ncclDevKernel", *kfn)
    cudart, _ = required_cuda(*kfn)
    if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
    out("/*%4d*/ (void*)%s,\n" % (index, sym))
    if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
    index += 1
  out("nullptr};\n")
  out("\n")

  # Does the prior map use an explicitly specialized kernel.
  out("extern bool const ncclDevKernelForFuncIsSpecialized[] = {\n")
  index = 0
  for fn in primary_funcs:
    kfn = best_kernel(*fn)
    specialized = "1" if fn == kfn else "0"
    out("/*%4d*/ %s,\n" % (index, specialized))
    index += 1
  out("0};\n")

# Maps to .cu filename which implements this func. The only constraint is that
# "coll" is reflected in the name: formally that no two funcs having different
# coll's map to the same filename.
def impl_filename(coll, redop, ty, algo, proto):
  return "%s.cu" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)

# Partition the functions and kernels to the .cu filenames. The partition is
# a dictionary mapping filename to (coll, func-tuple list)
def partition_by_name(fns):
  ans = {}
  for fn in fns:
    name = impl_filename(*fn)
    coll = fn[0]
    if name not in ans:
      ans[name] = (coll, [])
    ans[name][1].append(fn)
  return ans

name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop")
name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Generic")

# Generate <gensrc>/rules.mk
with open(os.path.join(gensrc, "rules.mk"), "w") as f:
  out = f.write
  impl_names = sorted(name_to_funcs.keys())
  names = impl_names + ["host_table.cc", "device_table.cu"]
  out("LIB_OBJS_GEN = $(patsubst %,$(OBJDIR)/genobj/%.o,{names})\n"
      .format(names=" ".join(names)))
  out("\n")

  # For each <coll>_<op>_<ty>.cu compile to a .cu.o file. Notice the dependencies
  # come from the suffix-erased file (e.g. 'gensrc/all_reduce.cu')
  for name in impl_names:
    coll = name_to_funcs[name][0]
    out(
      "$(OBJDIR)/genobj/{name}.o: $(OBJDIR)/gensrc $(OBJDIR)/genobj/{lower_coll}.cu.d\n"
      "\t" "$(call COMPILE,$@,$(OBJDIR)/gensrc/{name})\n"
      "\n"
      .format(name=name, lower_coll=coll_camel_to_lower[coll])
    )

# Add the suffix-erased .cu's which are used only for dependency scraping.
for coll in set(coll for (coll,_,_,_,_) in primary_funcs if coll!="Nop"):
  name = impl_filename(coll, None, None, None, None)
  if name not in name_to_funcs:
    name_to_funcs[name] = (coll, [])

redop_to_cxx = {
  None: "FuncCopy",
  "Sum": "FuncSum",
  "Prod": "FuncProd",
  "MinMax": "FuncMinMax",
  "PreMulSum": "FuncPreMulSum",
  "SumPostDiv": "FuncSumPostDiv"
}

ty_to_cxx = {
  None: "int8_t",
  "i8": "int8_t",
  "u8": "uint8_t",
  "i32": "int32_t",
  "u32": "uint32_t",
  "i64": "int64_t",
  "u64": "uint64_t",
  "f16": "half",
  "f32": "float",
  "f64": "double",
  "bf16": "__nv_bfloat16",
  "f8e4m3": "__nv_fp8_e4m3",
  "f8e5m2": "__nv_fp8_e5m2"
}

# Generate each <gensrc>/<impl>.cu:
for name in name_to_funcs.keys():
  (coll, fns) = name_to_funcs[name]
  with open(os.path.join(gensrc, name), "w") as f:
    out = f.write
    out(
      '#include "common.h"\n'
      '#include "{lower_coll}.h"\n'
      .format(lower_coll=coll_camel_to_lower[coll])
    )

    (_, kfns) = name_to_kernels.get(name) or (None, [])
    for kfn in kfns:
      (coll, redop, ty, algo, proto) = kfn
      sym = paste("_", coll, redop, ty, algo, proto)
      fn_id = primary_to_index[kfn]
      cudart, arch = required_cuda(*kfn)
      s = "DEFINE_ncclDevKernel({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {fn_id})\n"
      if (cudart, arch) != (0, 0):
        # Add conditional compilation logic around s. If CUDART_VERSION is satisfactory
        # we must compile a kernel regardless of __CUDA_ARCH__ since the host code has
        # to link against some stub.
        s = "#if CUDART_VERSION >= {cudart}\n" \
            " #if __CUDA_ARCH__ < {arch}\n" \
            " DEFINE_ncclDevKernel_nop({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {fn_id})\n" \
            " #else\n" \
            " " + s + \
            " #endif\n" \
            "#endif\n"
      out(s.format(
        cudart=cudart, arch=arch, sym=sym, coll=coll,
        redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
        algo=(algo or "RING"), proto=(proto or "SIMPLE"), fn_id=fn_id
      ))

    for fn in fns:
      (coll, redop, ty, algo, proto) = fn
      sym = paste("_", coll, redop, ty, algo, proto)
      cudart, arch = required_cuda(*fn)
      if (cudart, arch) != (0, 0):
        out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
      out(
        "DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto})\n"
        .format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
                algo=(algo or "RING"), proto=(proto or "SIMPLE"))
      )
      if (cudart, arch) != (0, 0):
        out("#endif\n")
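
As a rough illustration of how the ONLY_FUNCS pattern described near the top of this script is applied, the sketch below reproduces the glob-to-regex translation in isolation. The wrapper name `make_func_filter` is hypothetical (the script itself defines `func_filter` at module level from `sys.argv[2]`); the translation steps (replace `*` with `[^ ]*`, append `$`, match case-insensitively against the space-joined tuple) follow the script.

```python
import re


def paste(sep, *args):
    # Join all non-None arguments with `sep`, as the generator script does.
    return sep.join(x for x in args if x is not None)


def make_func_filter(pattern):
    # Hypothetical wrapper: returns a predicate over function tuples.
    if not pattern:
        return lambda *fn: True  # no pattern means nothing is filtered out
    # '*' becomes a glob that cannot cross a space, so it stays in one field.
    regex = pattern.replace("*", "[^ ]*") + "$"

    def func_filter(*fn):
        return re.match(regex, paste(" ", *fn), flags=re.IGNORECASE) is not None

    return func_filter


keep = make_func_filter("AllGather * *|Broadcast * *|SendRecv")
print(keep("AllGather", None, None, "RING", "LL"))    # non-reduction: kept
print(keep("SendRecv", None, None, None, None))       # kept
print(keep("AllReduce", "Sum", "f32", "RING", "LL"))  # reduction: filtered out
```

Because `re.match` anchors only at the start and the appended `$` binds to the last alternative of a union, a pattern without `|` must match the whole function string, which is why `"AllReduce Sum f32 RING LL"` would not also select the `LL128` variant.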