Merge remote-tracking branch 'nccl/master' into develop

This commit is contained in:
Marzieh Berenjkoub
2026-01-20 13:01:49 -06:00
parents 239d62f545 f1308997d0
commit 858b4e76eb
240 changed files with 16266 additions and 3578 deletions
+25 -19
View file
@@ -4,6 +4,7 @@
#!/usr/bin/env python3
import os
import sys
import shutil
################################################################################
# The first command line argument is the path to the directory to generate and
@@ -13,8 +14,11 @@ gensrc = sys.argv[1]
if os.path.exists(gensrc):
for name in os.listdir(gensrc):
os.remove(os.path.join(gensrc, name))
#os.truncate(os.path.join(gensrc, name), 0)
path = os.path.join(gensrc, name)
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)
else:
os.mkdir(gensrc)
@@ -97,7 +101,7 @@ def enumerate_kernels():
yield Rec(coll="ReduceScatter", algo=algo, red=red, ty=ty)
def required_cuda(k):
cudart, arch, specific_sms = 0, 0, None
cudart, arch, specific_sms = 0, 600, None
is_nvls = k.algo in nvls_algos_by_coll.get(k.coll, [])
if is_nvls:
cudart = max(cudart, 12010)
@@ -136,13 +140,13 @@ def kernel_gencode(k):
def kernel_cname(k):
if k.coll in reductions:
return paste("_", "ncclSymDevKernel", k.coll, k.algo, k.red, k.ty)
return paste("_", "ncclSymkDevKernel", k.coll, k.algo, k.red, k.ty)
else:
return paste("_", "ncclSymDevKernel", k.coll, k.algo)
return paste("_", "ncclSymkDevKernel", k.coll, k.algo)
def kernel_conds(k):
cudart, arch, specific_sms = required_cuda(k)
if cudart == 0: return (None, None)
if cudart == 0 and arch == 0: return (None, None)
cudart_cond = "CUDART_VERSION >= %d"%cudart
if not specific_sms:
@@ -153,13 +157,13 @@ def kernel_conds(k):
def instantiate(k):
form_red_ty = (
"__global__ void {cname}(ncclSymDevArgs NCCL_GRID_CONSTANT const *args) {{\n"
" ncclSymRun_{id}<{red}, {ty}>(args);\n"
"__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
" ncclSymkRun_{id}<{red}, {ty}>(args4K->args);\n"
"}}"
)
form = (
"__global__ void {cname}(ncclSymDevArgs NCCL_GRID_CONSTANT const *args) {{\n"
" ncclSymRun_{id}(args);\n"
"__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
" ncclSymkRun_{id}(args4K->args);\n"
"}}"
)
@@ -172,7 +176,7 @@ def instantiate(k):
return inst
def prototype(k):
return "__global__ void {cname}(ncclSymDevArgs const *args);".format(cname=kernel_cname(k))
return "__global__ void {cname}(ncclSymkDevWorkArgs4K const *args4K);".format(cname=kernel_cname(k))
################################################################################
@@ -194,20 +198,22 @@ for coll in set(k.coll for k in enumerate_kernels()):
if (fname, coll) not in kernels_by_file:
kernels_by_file[fname, coll] = []
files_to_print = ""
# Generate each kernel instantiation file
for (fname, coll), ks in kernels_by_file.items():
files_to_print += fname + ";"
with open(os.path.join(gensrc, fname), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, fname))
emitln(f, '#include "symmetric.h"')
emitln(f, '#include "sym_kernels.h"')
emitln(f, '#include "symmetric/kernel.h"')
emitln(f, '#include "symmetric/{coll}.h"'.format(coll=coll_to_lower[coll]))
for k in ks:
emitln(f, instantiate(k))
# Generate <gensrc>/symmetric_host.cc
with open(os.path.join(gensrc, "symmetric_kernels.cc"), "w") as f:
# Generate <gensrc>/sym_kernels_host.cc
with open(os.path.join(gensrc, "sym_kernels_host.cc"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, "symmetric_kernels.cc"))
emitln(f, '#include "symmetric.h"')
emitln(f, '#include "sym_kernels.h"')
emitln(f, '#include "device.h"')
emitln(f, '')
@@ -215,19 +221,19 @@ with open(os.path.join(gensrc, "symmetric_kernels.cc"), "w") as f:
emitln(f, prototype(k))
emitln(f, '')
emitln(f, 'extern int const ncclSymKernelCount = %d;' % len(list(enumerate_kernels())))
emitln(f, 'extern void* const ncclSymKernelList[] = {')
emitln(f, 'extern int const ncclSymkKernelCount = %d;' % len(list(enumerate_kernels())))
emitln(f, 'extern void* const ncclSymkKernelList[] = {')
for k in enumerate_kernels():
emitln(f, '(void*){cname},'.format(cname=kernel_cname(k)))
emitln(f, 'nullptr};')
emitln(f, '')
emitln(f, 'void* ncclSymGetKernelPtr(ncclSymKernelId id, int red, ncclDataType_t ty) {')
emitln(f, 'void* ncclSymkGetKernelPtr(ncclSymkKernelId id, int red, ncclDataType_t ty) {')
indents += 1
emitln(f, 'switch (id) {')
emitln(f, 'default: return nullptr;')
for (coll, algo), coll_algo_ks in partition(enumerate_kernels(), lambda k: (k.coll, k.algo)).items():
emitln(f, 'case ncclSymKernelId_'+coll+'_'+algo+':')
emitln(f, 'case ncclSymkKernelId_'+coll+'_'+algo+':')
indents += 1
if len(coll_algo_ks) == 1:
emitln(f, 'return (void*)&'+kernel_cname(coll_algo_ks[0])+';')