diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index 63e9c5c88f..6a7260f7a2 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -131,6 +131,12 @@ private: template __device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) { +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_WAIT_PEER_ENTRY) + if (threadIdx.x == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_WAIT_PEER_ENTRY, nelts*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; // Yes, for some template arguments this code will be unreachable. That's fine. // coverity[dead_error_line] @@ -198,6 +204,12 @@ private: } step += StepPerSlice; } +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_WAIT_PEER_EXIT) + if (threadIdx.x == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_WAIT_PEER_EXIT, nelts*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif } template