diff --git a/src/device/generate.py b/src/device/generate.py index 6fa02c44dd..482975538d 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -398,10 +398,13 @@ if is_colltrace: out('#include "nccl_common.h"\n#include "device.h"\n') out("\n") + seen_fns = set() out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") for fn in primary_funcs: - if fn[5] == "4": continue - out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1])) + fn_no_unroll = fn[:-1] + if fn_no_unroll not in seen_fns: + out(' "%s",\n' % paste("_", "ncclDevFunc", *fn_no_unroll)) + seen_fns.add(fn_no_unroll) for ty in all_tys: out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n') out("};\n") diff --git a/src/enqueue.cc b/src/enqueue.cc index 769389b813..9cb0044a89 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -31,7 +31,7 @@ struct ncclKernelMatch { }; #ifdef ENABLE_COLLTRACE -#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + (p_comm)->collTraceThread ? 2 : 0) +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceThread ? 2 : 0)) static ncclKernelMatch const ncclKerns[4] = { {(void *)ncclDevKernel_Generic, true}, {(void *)ncclDevKernel_Generic_4, true},