Files
rocm-systems/src/device/symmetric/generate.py
T

257 خطوط
7.5 KiB
Python

# Modification Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
2025-05-29 20:56:40 -07:00
#!/usr/bin/env python3
import os
import sys
2025-09-02 13:21:14 -07:00
import shutil
2025-05-29 20:56:40 -07:00
################################################################################
# The first command line argument is the path to the directory to generate and
# populate. Anything already inside it is deleted first so that sources
# generated by an earlier run do not linger in the output directory.
if len(sys.argv) < 2:
    sys.exit("usage: %s <gensrc-dir>" % sys.argv[0])
gensrc = sys.argv[1]
if os.path.isdir(gensrc):
    for name in os.listdir(gensrc):
        path = os.path.join(gensrc, name)
        # Test for symlinks explicitly: os.path.isdir() follows links and
        # shutil.rmtree() refuses to operate on a symlink, while a dangling
        # link is neither a file nor a directory and would otherwise be left
        # behind. Links (and any other non-directory entry) are unlinked.
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
else:
    # makedirs also creates any missing parent directories.
    os.makedirs(gensrc)
def paste(sep, *args):
    """Concatenate the positional string arguments, separated by ``sep``.

    Example: paste('_', 'a', 'b') returns 'a_b'.
    """
    pieces = args
    joined = sep.join(pieces)
    return joined
# Current nesting depth used by emitln(); the generator code adjusts it with
# `indents += 1` / `indents -= 1` around nested C++ scopes.
indents = 0

def emitln(f, lines):
    """Write ``lines`` to the writable text stream ``f``, one per line.

    ``lines`` is either a single str or an iterable of str. Each line is
    prefixed with ``indents`` copies of the indent string and terminated
    with a newline.
    """
    batch = [lines] if isinstance(lines, str) else lines
    prefix = ' ' * indents
    for text in batch:
        f.write(prefix + text + '\n')
def indent(s):
    """Return ``s`` with a single space prepended to each of its lines.

    Lines are split with str.splitlines() and re-joined with '\\n', so a
    trailing newline in ``s`` is not preserved.
    NOTE(review): not referenced elsewhere in this script as shown.
    """
    shifted = []
    for line in s.splitlines():
        shifted.append(' ' + line)
    return '\n'.join(shifted)
class Rec(object):
    """Lightweight record whose attributes are the keyword arguments given.

    Two records are equal when they carry exactly the same attribute names
    with equal values. Equal records hash equally, so a Rec can be used as a
    dict key or set member (attribute values must then be hashable).
    """

    def __init__(me, **kw):
        me.__dict__.update(kw)

    def __eq__(x, y):
        # Rec defines no __len__/__iter__/__getitem__, so compare the
        # attribute dictionaries directly instead of treating the records
        # themselves as mappings.
        if not isinstance(y, Rec):
            return NotImplemented
        return x.__dict__ == y.__dict__

    def __ne__(x, y):
        result = x.__eq__(y)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(me):
        # Order-independent sum of per-attribute hashes, consistent with
        # __eq__ (which ignores attribute insertion order).
        h = 0
        for k, v in me.__dict__.items():
            h += hash((k, v))
        return h

    def __repr__(me):
        fields = ", ".join("%s=%r" % kv for kv in sorted(me.__dict__.items()))
        return "Rec(%s)" % fields
################################################################################
# Edit this region for introducing new algos etc

# Collectives whose kernels are specialized per reduction op and element type.
reductions = ["AllReduce", "ReduceScatter"]
# Reduction ops and element types that reduction kernels are generated for.
all_reds = ["sum"]
all_tys = ["f32", "f16", "bf16", "f8e4m3", "f8e5m2"]

# NVLS algorithm names per collective; ldmc_algos is the subset using LDMC.
# NOTE(review): enumerate_kernels() in this script yields none of these
# algorithm names, so the NVLS/LDMC branches that consult them are currently
# not exercised.
nvls_algos_by_coll = dict(
    AllReduce=["AGxLLMC_R", "RSxLDMC_AGxSTMC"],
    ReduceScatter=["LDMC"],
)
ldmc_algos = ["RSxLDMC_AGxSTMC", "LDMC"]

# Collective name -> lower_snake_case stem used in file and header names.
coll_to_lower = dict(
    AllGather="all_gather",
    AllReduce="all_reduce",
    ReduceScatter="reduce_scatter",
)
# Reduction name -> ncclDevRedOp_t enumerator used in the generated switch.
red_to_ncclDevRedOp = dict(sum="ncclDevSum")
# Reduction name -> C++ functor passed as the kernel's template argument.
red_to_Func = dict(sum="FuncSum")
# Element type -> ncclDataType_t enumerator used in the generated switch.
ty_to_ncclDataType = dict(
    f32="ncclFloat32",
    f16="ncclFloat16",
    bf16="ncclBfloat16",
    f8e4m3="ncclFloat8e4m3",
    f8e5m2="ncclFloat8e5m2",
)
# Element type -> C++ element type passed as the kernel's template argument.
ty_to_cxxtype = dict(
    f32="float",
    f16="half",
    bf16="hip_bfloat16",
    f8e4m3="rccl_float8",
    f8e5m2="rccl_bfloat8",
)
def enumerate_kernels():
    """Yield a Rec describing every symmetric kernel to generate.

    Order is fixed: the AllGather kernels ("LL", "ST") first. Then, for each
    reduction op in all_reds and each element type in all_tys, the AllReduce
    kernels ("AGxLL_R", "RSxLD_AGxST"), followed by the ReduceScatter kernels
    ("LL", "LD").
    """
    yield from (Rec(coll="AllGather", algo=a) for a in ("LL", "ST"))
    per_coll_algos = (
        ("AllReduce", ("AGxLL_R", "RSxLD_AGxST")),
        ("ReduceScatter", ("LL", "LD")),
    )
    for red in all_reds:
        for ty in all_tys:
            for coll, algos in per_coll_algos:
                for algo in algos:
                    yield Rec(coll=coll, algo=algo, red=red, ty=ty)
def required_cuda(k):
    """Return the (cudart, arch, specific_sms) requirements for kernel ``k``.

    cudart: minimum CUDART_VERSION (0 means no requirement).
    arch: minimum __CUDA_ARCH__, or None when specific_sms is used instead.
    specific_sms: None, or a list of architecture-specific SM names
    (e.g. "100a", "100f") that kernel_conds() turns into an OR of macros.
    """
    cudart, arch, specific_sms = 0, 600, None
    if k.algo in nvls_algos_by_coll.get(k.coll, []):
        cudart, arch = max(cudart, 12010), 900
    if k.coll not in reductions:
        return (cudart, arch, specific_sms)
    if k.ty == "bf16":
        cudart = max(cudart, 11000)
    if k.ty.startswith("f8"):
        cudart = max(cudart, 11080)
        arch = 900
        if k.algo in ldmc_algos:
            # fp8 LDMC kernels are gated on an explicit list of SMs rather
            # than on a minimum architecture.
            cudart, arch = 12070, None
            specific_sms = ["100a", "101a", "100f", "101f", "120a", "121a"]
    return (cudart, arch, specific_sms)
################################################################################
def kernel_fdep(k):
    """Return the per-collective '<coll>.cpp' file name for kernel ``k``.

    NOTE(review): not called elsewhere in this script as shown.
    """
    stem = coll_to_lower[k.coll]
    return stem + '.cpp'
def kernel_fname(k):
    """Return the name of the generated .cpp file that will contain kernel ``k``.

    Non-reduction collectives share one '<coll>.cpp'. Reduction kernels are
    grouped per (op, type) as '<coll>_<red>_<ty>.cpp'; fp8 LDMC kernels get a
    file of their own with the algorithm name appended.
    """
    stem = coll_to_lower[k.coll]
    if k.coll not in reductions:
        return stem + '.cpp'
    parts = [stem, k.red, k.ty]
    if k.algo in ldmc_algos and k.ty.startswith('f8'):
        parts.append(k.algo)
    return paste('_', *parts) + '.cpp'
def kernel_gencode(k):
    """Return the ``$(...)`` variable reference naming the gencode flags for ``k``.

    fp8 LDMC reduction kernels use a dedicated variable; every other kernel
    uses the default one.
    NOTE(review): not called elsewhere in this script as shown.
    """
    ldmc_fp8 = k.coll in reductions and k.algo in ldmc_algos and k.ty.startswith('f8')
    return "$(NVCC_GENCODE_LDMC_FP8)" if ldmc_fp8 else "$(NVCC_GENCODE)"
def kernel_cname(k):
    """Return the C++ symbol name of the __global__ entry point for ``k``.

    Form: ncclSymkDevKernel_<coll>_<algo>, with _<red>_<ty> appended for
    reduction collectives.
    """
    fields = ["ncclSymkDevKernel", k.coll, k.algo]
    if k.coll in reductions:
        fields += [k.red, k.ty]
    return paste("_", *fields)
def kernel_conds(k):
    """Return (cudart_cond, arch_cond) preprocessor condition strings for ``k``.

    Returns (None, None) when required_cuda() reports no requirement.
    Otherwise arch_cond is either a minimum __CUDA_ARCH__ test or, when
    specific SMs are required, "0 || NCCL_CUDA_ARCH_[FAMILY_]SPECIFIC==<n>..."
    where n is ten times the SM number.
    NOTE(review): required_cuda() starts arch at 600, so the `arch == 0`
    early return looks unreachable; also not called elsewhere in this script
    as shown — confirm intent.
    """
    cudart, arch, specific_sms = required_cuda(k)
    if cudart == 0 and arch == 0:
        return (None, None)
    cudart_cond = "CUDART_VERSION >= %d" % cudart
    if specific_sms:
        terms = ["0"]
        for sm in specific_sms:
            family = "FAMILY_" if sm[-1] == "f" else ""
            number = 10 * int(sm.replace('a', '').replace('f', ''))
            terms.append("NCCL_CUDA_ARCH_%sSPECIFIC==%d" % (family, number))
        arch_cond = " || ".join(terms)
    else:
        arch_cond = "__CUDA_ARCH__ >= %d" % arch
    return cudart_cond, arch_cond
def instantiate(k):
    """Return C++ source defining the __global__ entry point for kernel ``k``.

    The generated function forwards args4K->args to ncclSymkRun_<coll>_<algo>.
    For reduction collectives, the call is templated on the reduction functor
    (red_to_Func) and the C++ element type (ty_to_cxxtype).
    """
    kernel_id = k.coll + '_' + k.algo
    cname = kernel_cname(k)
    if k.coll not in reductions:
        plain_template = (
            "__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
            " ncclSymkRun_{id}(args4K->args);\n"
            "}}"
        )
        return plain_template.format(cname=cname, id=kernel_id)
    typed_template = (
        "__global__ void {cname}(ncclSymkDevWorkArgs4K NCCL_GRID_CONSTANT const *args4K) {{\n"
        " ncclSymkRun_{id}<{red}, {ty}>(args4K->args);\n"
        "}}"
    )
    functor = red_to_Func[k.red]
    cxx_ty = ty_to_cxxtype[k.ty]
    return typed_template.format(cname=cname, id=kernel_id, red=functor, ty=cxx_ty)
def prototype(k):
    """Return a forward declaration of the __global__ function for ``k``."""
    decl = "__global__ void {cname}(ncclSymkDevWorkArgs4K const *args4K);"
    cname = kernel_cname(k)
    return decl.format(cname=cname)
################################################################################
def partition(vals, keyfn):
    """Group ``vals`` into a dict mapping keyfn(x) -> list of the x with that key.

    Keys appear in first-seen order, and each list keeps the input order.
    """
    groups = {}
    for item in vals:
        groups.setdefault(keyfn(item), []).append(item)
    return groups
# Group the kernels by (generated file name, collective).
kernels_by_file = partition(enumerate_kernels(), lambda k: (kernel_fname(k), k.coll))
# Add dependency only files (e.g. all_reduce.cpp): make sure every collective
# has a "<coll>.cpp" entry -- the same name kernel_fdep() returns -- even when
# no kernel is instantiated in it. Collectives are visited in first-seen order
# via dict.fromkeys(); iterating a set of strings would make the order of the
# generated files (and of files_to_print) depend on the per-process hash seed.
for coll in dict.fromkeys(k.coll for k in enumerate_kernels()):
    fname = coll_to_lower[coll] + '.cpp'
    if (fname, coll) not in kernels_by_file:
        kernels_by_file[fname, coll] = []
# Generate each kernel instantiation file: the shared includes, then one
# __global__ definition per kernel grouped into that file.
generated_names = []
for (fname, coll), ks in kernels_by_file.items():
    generated_names.append(fname)
    out_path = os.path.join(gensrc, fname)
    with open(out_path, "w") as f:
        print("-- Generating %s" % out_path)
        emitln(f, '#include "sym_kernels.h"')
        emitln(f, '#include "symmetric/kernel.h"')
        emitln(f, '#include "symmetric/{coll}.h"'.format(coll=coll_to_lower[coll]))
        for kernel in ks:
            emitln(f, instantiate(kernel))
# Every generated file name, each followed by ';'.
files_to_print = "".join(name + ";" for name in generated_names)
# Generate <gensrc>/sym_kernels_host.cc, which contains:
#  - a prototype for every kernel,
#  - ncclSymkKernelCount and a nullptr-terminated ncclSymkKernelList table,
#  - ncclSymkGetKernelPtr(), which maps (kernel id, reduction op, data type)
#    to the matching kernel function, or nullptr.
# The kernel list is enumerated once and reused for every section.
host_kernels = list(enumerate_kernels())
host_path = os.path.join(gensrc, "sym_kernels_host.cc")
with open(host_path, "w") as f:
    # Log the file actually being written.
    print("-- Generating %s" % host_path)
    emitln(f, '#include "sym_kernels.h"')
    emitln(f, '#include "device.h"')
    emitln(f, '')
    for k in host_kernels:
        emitln(f, prototype(k))
    emitln(f, '')
    emitln(f, 'extern int const ncclSymkKernelCount = %d;' % len(host_kernels))
    emitln(f, 'extern void* const ncclSymkKernelList[] = {')
    for k in host_kernels:
        emitln(f, '(void*){cname},'.format(cname=kernel_cname(k)))
    emitln(f, 'nullptr};')
    emitln(f, '')
    emitln(f, 'void* ncclSymkGetKernelPtr(ncclSymkKernelId id, int red, ncclDataType_t ty) {')
    indents += 1
    emitln(f, 'switch (id) {')
    emitln(f, 'default: return nullptr;')
    for (coll, algo), coll_algo_ks in partition(host_kernels, lambda k: (k.coll, k.algo)).items():
        emitln(f, 'case ncclSymkKernelId_' + coll + '_' + algo + ':')
        indents += 1
        if coll not in reductions:
            # Non-reduction collectives have exactly one kernel per algo.
            # Dispatching on the collective (rather than on the group size)
            # keeps the red/ty switch below even if all_reds and all_tys each
            # hold a single entry; otherwise any data type would map to it.
            emitln(f, 'return (void*)&' + kernel_cname(coll_algo_ks[0]) + ';')
        else:
            emitln(f, 'switch ((ncclDevRedOp_t)red) {')
            emitln(f, 'default: return nullptr;')
            for red, coll_algo_red_ks in partition(coll_algo_ks, lambda k: k.red).items():
                emitln(f, 'case ' + red_to_ncclDevRedOp[red] + ':')
                indents += 1
                emitln(f, 'switch (ty) {')
                emitln(f, 'default: return nullptr;')
                for k in coll_algo_red_ks:
                    emitln(f, 'case ' + ty_to_ncclDataType[k.ty] + ': return (void*)' + kernel_cname(k) + ';')
                emitln(f, '}')
                indents -= 1
            emitln(f, '}')
        indents -= 1
    emitln(f, '}')
    indents -= 1
    emitln(f, '}')