diff --git a/install.sh b/install.sh index 2c209f5edf..2d778f1928 100755 --- a/install.sh +++ b/install.sh @@ -318,6 +318,10 @@ if ($npkit_enabled); then -DENABLE_NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_EXIT \ -DENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_ENTRY \ -DENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT \ + -DENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY \ + -DENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT \ + -DENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY \ + -DENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT \ -DENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME" fi diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index 74d4162832..416a8ee98c 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -185,6 +185,25 @@ __device__ __forceinline__ void mscclRunInterpreter( } __synclds(); // publish shmem +#if defined(ENABLE_NPKIT) + int npKitCtxIdx = bid; +#endif + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) + if (tid == 0) { + uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp; + NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp, + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + // Deference reduce args if required if (tid == 0 && mscclShmem.work.hasReduce && mscclShmem.work.redOpArgIsPtr) { switch (sizeof(T)) { @@ -226,6 +245,12 @@ __device__ __forceinline__ void mscclRunInterpreter( Primitives, 1, Proto, 0> prims (tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg); +#if defined(ENABLE_NPKIT) + if (tid == 0) { + prims.npKitCtxIdx = npKitCtxIdx; + } +#endif + const ssize_t sizePerMscclChunk = mscclShmem.work.count / mscclShmem.work.nChunksPerLoop; uint32_t maxAllowedCount = mscclShmem.work.maxAllowedCount; @@ -279,6 +304,13 @@ __device__ __forceinline__ void mscclRunInterpreter( else if (t->type == MSCCL_REDUCE) { int numReductions = t->numReductions; if (thisNelem < nthreads){ +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + if (tid < thisNelem){ dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk; T* dstIndex = dstPointer + dstOffset + tid; @@ -308,6 +340,14 @@ __device__ __forceinline__ void mscclRunInterpreter( } store(dstIndex, o); } + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + barrier(nthreads, mscclBarrierNext, mscclBarriers); } else { T* srcs[MSCCL_MAX_REDUCE_FUSION+1]; // +1 is for SIMPLE protocol as dst is added in the list of srcs diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index faa2b03770..d372f979c5 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -476,6 +476,13 @@ private: template __device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) { +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + nelem = nelem < 0 ? 0 : nelem; T *srcElts = srcs[0]; T *dstElts = dsts[0]; @@ -534,6 +541,14 @@ private: nelem -= eltPerTrip; offset += nthreads; } + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + barrier(); } diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 4d8ef8f0ed..bb236c0679 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -374,6 +374,13 @@ private: template __device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) { +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + T const *srcPtr = srcs[0]; T *dstPtr = dsts[0]; int wireOffset = WireWordPerSlice*warp + 2*wid; @@ -447,6 +454,14 @@ private: } nelem -= DataEltPerSlice*nwarps; } + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + barrier(); } diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 39f6e46ed6..3aa21d882c 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -364,6 +364,13 @@ private: template __device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) { +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + nelem = nelem < 0 ? 0 : nelem; if (tid < nworkers) { if (REDUCE){ @@ -388,6 +395,14 @@ private: } } } + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT) + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + barrier(); } diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index 647378f2a3..5b6cf4b647 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -21,7 +21,7 @@ struct RunWork { #if defined(ENABLE_NPKIT) bool isNpKitThread = (tid == 0); - int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P; + int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + group; #endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) @@ -116,7 +116,7 @@ struct RunWork { __device__ void runRecv(const int tid, const int nthreads, const uint8_t group, struct ncclWorkElemP2p* args) { #if defined(ENABLE_NPKIT) bool isNpKitThread = (tid == 0); - int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + 1; + int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + group; #endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) diff --git a/src/include/npkit/npkit_event.h b/src/include/npkit/npkit_event.h index 80ad637c3b..a1d24fd3fe 100644 --- a/src/include/npkit/npkit_event.h +++ b/src/include/npkit/npkit_event.h @@ -112,5 +112,9 @@ #define NPKIT_EVENT_NET_TEST_ENTRY 0x58 #define NPKIT_EVENT_NET_TEST_EXIT 0x59 +#define NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY 0x5A +#define NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT 0x5B +#define NPKIT_EVENT_MSCCL_REDUCE_ENTRY 0x5C +#define NPKIT_EVENT_MSCCL_REDUCE_EXIT 0x5D #endif