From dfe4a3ed815e4bc5daaa6be8b2e851f989bb8fb6 Mon Sep 17 00:00:00 2001 From: Bertan Dogancay <111835151+BertanDogancay@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:40:05 -0500 Subject: [PATCH] Fix typo in ncclGetKernelIndex macro (#1424) --- src/device/generate.py | 7 +++++-- src/enqueue.cc | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/device/generate.py b/src/device/generate.py index 6fa02c44dd..482975538d 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -398,10 +398,13 @@ if is_colltrace: out('#include "nccl_common.h"\n#include "device.h"\n') out("\n") + seen_fns = set() out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") for fn in primary_funcs: - if fn[5] == "4": continue - out(' "%s",\n' % paste("_", "ncclDevFunc", *fn[:-1])) + fn_no_unroll = fn[:-1] + if fn_no_unroll not in seen_fns: + out(' "%s",\n' % paste("_", "ncclDevFunc", *fn_no_unroll)) + seen_fns.add(fn_no_unroll) for ty in all_tys: out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n') out("};\n") diff --git a/src/enqueue.cc b/src/enqueue.cc index 769389b813..9cb0044a89 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -31,7 +31,7 @@ struct ncclKernelMatch { }; #ifdef ENABLE_COLLTRACE -#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + (p_comm)->collTraceThread ? 2 : 0) +#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceThread ? 2 : 0)) static ncclKernelMatch const ncclKerns[4] = { {(void *)ncclDevKernel_Generic, true}, {(void *)ncclDevKernel_Generic_4, true},