Fix typo in ncclGetKernelIndex macro (#1424)
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
4336a0f3a3
Коммит
dfe4a3ed81
@@ -398,10 +398,13 @@ if is_colltrace:
|
||||
out('#include "nccl_common.h"\n#include "device.h"\n')
|
||||
out("\n")
|
||||
|
||||
seen_fns = set()
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
for fn in primary_funcs:
|
||||
if fn[5] == "4": continue
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1]))
|
||||
fn_no_unroll = fn[:-1]
|
||||
if fn_no_unroll not in seen_fns:
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn_no_unroll))
|
||||
seen_fns.add(fn_no_unroll)
|
||||
for ty in all_tys:
|
||||
out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n')
|
||||
out("};\n")
|
||||
|
||||
@@ -31,7 +31,7 @@ struct ncclKernelMatch {
|
||||
};
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
-#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + (p_comm)->collTraceThread ? 2 : 0)
|
||||
+#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceThread ? 2 : 0))
|
||||
static ncclKernelMatch const ncclKerns[4] = {
|
||||
{(void *)ncclDevKernel_Generic, true},
|
||||
{(void *)ncclDevKernel_Generic_4, true},
|
||||
|
||||
Ссылка в новой задаче
Block a user