From 26e982d91344bbf8d4fec773f4a0752d456187ef Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:28:26 -0700 Subject: [PATCH] Reduce NPKit latency overhead in MSCCL kernel (#893) * Reduce NPKit latency overhead in MSCCL kernel * Fix build error without NPKit enable --- src/collectives/device/common.h | 6 ++++ src/collectives/device/msccl_kernel_impl.h | 39 ++++++++++------------ src/include/npkit/npkit.h | 15 +++++++++ 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 49eea4ee6e..77da1ded1d 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -359,6 +359,8 @@ struct ncclShmemGroup { uint64_t barrier_next[NCCL_MAX_GROUPS]; }; +#define LDS_NUM_EVENTS 64 + struct ncclShmemData { struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1]; @@ -374,6 +376,10 @@ struct ncclShmemData { #ifdef ENABLE_PROFILING struct ncclProf prof; #endif +#if defined(ENABLE_NPKIT) + NpKitEvent event_buffer[LDS_NUM_EVENTS]; + uint64_t event_buffer_head; +#endif }; static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "ncclShmem.work needs to be 16B aligned"); diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index 40cf056f64..ba9460d48f 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -186,24 +186,23 @@ __device__ __forceinline__ void mscclRunInterpreter( } copyToShmem8(tid%WARP_SIZE, dst, src, bytes); } - __synclds(); // publish shmem #if defined(ENABLE_NPKIT) int npKitCtxIdx = bid; + if (tid == 0) ncclShmem.event_buffer_head = 0; #endif + __synclds(); // publish shmem #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) if (tid == 0) { uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp; - NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp, - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp); } #endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif @@ -263,8 +262,7 @@ __device__ __forceinline__ void mscclRunInterpreter( #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_XCC_ID)" : "=s" (xcc_id)); #endif - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif @@ -317,31 +315,27 @@ __device__ __forceinline__ void mscclRunInterpreter( if (t->type == MSCCL_SEND) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_ENTRY) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif prims.sendWithBarrier(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end. #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif } else if (t->type == MSCCL_RECV) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_ENTRY) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif prims.recv(dstOffset, thisNelem); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_EXIT) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif } @@ -350,8 +344,7 @@ __device__ __forceinline__ void mscclRunInterpreter( if (thisNelem < nthreads){ #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif @@ -387,8 +380,7 @@ __device__ __forceinline__ void mscclRunInterpreter( #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif @@ -443,10 +435,15 @@ __device__ __forceinline__ void mscclRunInterpreter( } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RUN_EXIT) if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP()); } #endif +#if defined(ENABLE_NPKIT) + __synclds(); + NpKitEventCollectContext* ctx = ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx; + copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head); + if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head; +#endif } #define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \ diff --git a/src/include/npkit/npkit.h b/src/include/npkit/npkit.h index cfec52aa1b..c535b2ccde 100644 --- a/src/include/npkit/npkit.h +++ b/src/include/npkit/npkit.h @@ -13,9 +13,11 @@ #include "npkit/npkit_event.h" #include "npkit/npkit_struct.h" +#include "common.h" #define NPKIT_GET_GPU_TIMESTAMP wall_clock64 + class NpKit { public: static const uint64_t kNumGpuEventBuffers = 512; @@ -43,6 +45,19 @@ class NpKit { } } + static inline __device__ void CollectGpuEventLDS(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp) { +#if defined(ENABLE_NPKIT) + if (ncclShmem.event_buffer_head < LDS_NUM_EVENTS) { + NpKitEvent& event = ncclShmem.event_buffer[ncclShmem.event_buffer_head]; + event.fields.type = type; + event.fields.size = size < 0 ? 0 : size; + event.fields.rsvd = rsvd; + event.fields.timestamp = timestamp; + ncclShmem.event_buffer_head++; + } +#endif + } + static void CollectCpuEvent(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp, int channel_id); static uint64_t *GetCpuTimestamp();