Merge remote-tracking branch 'nccl/master' into develop
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
|
||||
################################################################################
|
||||
# The first command line argument is the path to the directory to generate and
|
||||
@@ -13,8 +14,11 @@ gensrc = sys.argv[1]
|
||||
|
||||
if os.path.exists(gensrc):
|
||||
for name in os.listdir(gensrc):
|
||||
os.remove(os.path.join(gensrc, name))
|
||||
#os.truncate(os.path.join(gensrc, name), 0)
|
||||
path = os.path.join(gensrc, name)
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
elif os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
os.mkdir(gensrc)
|
||||
|
||||
@@ -97,7 +101,7 @@ def enumerate_kernels():
|
||||
yield Rec(coll="ReduceScatter", algo=algo, red=red, ty=ty)
|
||||
|
||||
def required_cuda(k):
|
||||
cudart, arch, specific_sms = 0, 0, None
|
||||
cudart, arch, specific_sms = 0, 600, None
|
||||
is_nvls = k.algo in nvls_algos_by_coll.get(k.coll, [])
|
||||
if is_nvls:
|
||||
cudart = max(cudart, 12010)
|
||||
@@ -136,13 +140,13 @@ def kernel_gencode(k):
|
||||
|
||||
def kernel_cname(k):
|
||||
if k.coll in reductions:
|
||||
return paste("_", "ncclSymDevKernel", k.coll, k.algo, k.red, k.ty)
|
||||
return paste("_", "ncclSymkDevKernel", k.coll, k.algo, k.red, k.ty)
|
||||
else:
|
||||
return paste("_", "ncclSymDevKernel", k.coll, k.algo)
|
||||
return paste("_", "ncclSymkDevKernel", k.coll, k.algo)
|
||||
|
||||
def kernel_conds(k):
|
||||
cudart, arch, specific_sms = required_cuda(k)
|
||||
if cudart == 0: return (None, None)
|
||||
if cudart == 0 and arch == 0: return (None, None)
|
||||
|
||||
cudart_cond = "CUDART_VERSION >= %d"%cudart
|
||||
if not specific_sms:
|
||||
@@ -153,13 +157,13 @@ def kernel_conds(k):
|
||||
|
||||
def instantiate(k):
|
||||
form_red_ty = (
|
||||
"__global__ void {cname}(ncclSymDevArgs NCCL_GRID_CONSTANT const *args) {{\n"
|
||||
" ncclSymRun_{id}<{red}, {ty}>(args);\n"
|
||||
"__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
|
||||
" ncclSymkRun_{id}<{red}, {ty}>(args4K->args);\n"
|
||||
"}}"
|
||||
)
|
||||
form = (
|
||||
"__global__ void {cname}(ncclSymDevArgs NCCL_GRID_CONSTANT const *args) {{\n"
|
||||
" ncclSymRun_{id}(args);\n"
|
||||
"__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
|
||||
" ncclSymkRun_{id}(args4K->args);\n"
|
||||
"}}"
|
||||
)
|
||||
|
||||
@@ -172,7 +176,7 @@ def instantiate(k):
|
||||
return inst
|
||||
|
||||
def prototype(k):
|
||||
return "__global__ void {cname}(ncclSymDevArgs const *args);".format(cname=kernel_cname(k))
|
||||
return "__global__ void {cname}(ncclSymkDevWorkArgs4K const *args4K);".format(cname=kernel_cname(k))
|
||||
|
||||
################################################################################
|
||||
|
||||
@@ -194,20 +198,22 @@ for coll in set(k.coll for k in enumerate_kernels()):
|
||||
if (fname, coll) not in kernels_by_file:
|
||||
kernels_by_file[fname, coll] = []
|
||||
|
||||
files_to_print = ""
|
||||
# Generate each kernel instantiation file
|
||||
for (fname, coll), ks in kernels_by_file.items():
|
||||
files_to_print += fname + ";"
|
||||
with open(os.path.join(gensrc, fname), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, fname))
|
||||
emitln(f, '#include "symmetric.h"')
|
||||
emitln(f, '#include "sym_kernels.h"')
|
||||
emitln(f, '#include "symmetric/kernel.h"')
|
||||
emitln(f, '#include "symmetric/{coll}.h"'.format(coll=coll_to_lower[coll]))
|
||||
for k in ks:
|
||||
emitln(f, instantiate(k))
|
||||
|
||||
# Generate <gensrc>/symmetric_host.cc
|
||||
with open(os.path.join(gensrc, "symmetric_kernels.cc"), "w") as f:
|
||||
# Generate <gensrc>/sym_kernels_host.cc
|
||||
with open(os.path.join(gensrc, "sym_kernels_host.cc"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "symmetric_kernels.cc"))
|
||||
emitln(f, '#include "symmetric.h"')
|
||||
emitln(f, '#include "sym_kernels.h"')
|
||||
emitln(f, '#include "device.h"')
|
||||
emitln(f, '')
|
||||
|
||||
@@ -215,19 +221,19 @@ with open(os.path.join(gensrc, "symmetric_kernels.cc"), "w") as f:
|
||||
emitln(f, prototype(k))
|
||||
emitln(f, '')
|
||||
|
||||
emitln(f, 'extern int const ncclSymKernelCount = %d;' % len(list(enumerate_kernels())))
|
||||
emitln(f, 'extern void* const ncclSymKernelList[] = {')
|
||||
emitln(f, 'extern int const ncclSymkKernelCount = %d;' % len(list(enumerate_kernels())))
|
||||
emitln(f, 'extern void* const ncclSymkKernelList[] = {')
|
||||
for k in enumerate_kernels():
|
||||
emitln(f, '(void*){cname},'.format(cname=kernel_cname(k)))
|
||||
emitln(f, 'nullptr};')
|
||||
emitln(f, '')
|
||||
|
||||
emitln(f, 'void* ncclSymGetKernelPtr(ncclSymKernelId id, int red, ncclDataType_t ty) {')
|
||||
emitln(f, 'void* ncclSymkGetKernelPtr(ncclSymkKernelId id, int red, ncclDataType_t ty) {')
|
||||
indents += 1
|
||||
emitln(f, 'switch (id) {')
|
||||
emitln(f, 'default: return nullptr;')
|
||||
for (coll, algo), coll_algo_ks in partition(enumerate_kernels(), lambda k: (k.coll, k.algo)).items():
|
||||
emitln(f, 'case ncclSymKernelId_'+coll+'_'+algo+':')
|
||||
emitln(f, 'case ncclSymkKernelId_'+coll+'_'+algo+':')
|
||||
indents += 1
|
||||
if len(coll_algo_ks) == 1:
|
||||
emitln(f, 'return (void*)&'+kernel_cname(coll_algo_ks[0])+';')
|
||||
|
||||
Verwijs in nieuw issue
Block a user