NPKit update (#844)
* NPKit update 1. Enable NPKit for MSCCL kernels 2. Fix NPKit context index calculation for sendrecv kernels * Update build script for npkit
This commit is contained in:
@@ -318,6 +318,10 @@ if ($npkit_enabled); then
|
||||
-DENABLE_NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_EXIT \
|
||||
-DENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_ENTRY \
|
||||
-DENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT \
|
||||
-DENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY \
|
||||
-DENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT \
|
||||
-DENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY \
|
||||
-DENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT \
|
||||
-DENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME"
|
||||
fi
|
||||
|
||||
|
||||
@@ -185,6 +185,25 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
__synclds(); // publish shmem
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
int npKitCtxIdx = bid;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
if (tid == 0) {
|
||||
uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp;
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp,
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Deference reduce args if required
|
||||
if (tid == 0 && mscclShmem.work.hasReduce && mscclShmem.work.redOpArgIsPtr) {
|
||||
switch (sizeof(T)) {
|
||||
@@ -226,6 +245,12 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0> prims
|
||||
(tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
if (tid == 0) {
|
||||
prims.npKitCtxIdx = npKitCtxIdx;
|
||||
}
|
||||
#endif
|
||||
|
||||
const ssize_t sizePerMscclChunk = mscclShmem.work.count / mscclShmem.work.nChunksPerLoop;
|
||||
uint32_t maxAllowedCount = mscclShmem.work.maxAllowedCount;
|
||||
|
||||
@@ -279,6 +304,13 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
else if (t->type == MSCCL_REDUCE) {
|
||||
int numReductions = t->numReductions;
|
||||
if (thisNelem < nthreads){
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (tid < thisNelem){
|
||||
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
||||
T* dstIndex = dstPointer + dstOffset + tid;
|
||||
@@ -308,6 +340,14 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
store(dstIndex, o);
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier(nthreads, mscclBarrierNext, mscclBarriers);
|
||||
} else {
|
||||
T* srcs[MSCCL_MAX_REDUCE_FUSION+1]; // +1 is for SIMPLE protocol as dst is added in the list of srcs
|
||||
|
||||
@@ -476,6 +476,13 @@ private:
|
||||
|
||||
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
|
||||
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
T *srcElts = srcs[0];
|
||||
T *dstElts = dsts[0];
|
||||
@@ -534,6 +541,14 @@ private:
|
||||
nelem -= eltPerTrip;
|
||||
offset += nthreads;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
|
||||
@@ -374,6 +374,13 @@ private:
|
||||
|
||||
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
|
||||
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
T const *srcPtr = srcs[0];
|
||||
T *dstPtr = dsts[0];
|
||||
int wireOffset = WireWordPerSlice*warp + 2*wid;
|
||||
@@ -447,6 +454,14 @@ private:
|
||||
}
|
||||
nelem -= DataEltPerSlice*nwarps;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
|
||||
@@ -364,6 +364,13 @@ private:
|
||||
|
||||
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
|
||||
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
if (tid < nworkers) {
|
||||
if (REDUCE){
|
||||
@@ -388,6 +395,14 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
bool isNpKitThread = (tid == 0);
|
||||
int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P;
|
||||
int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + group;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
@@ -116,7 +116,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ void runRecv(const int tid, const int nthreads, const uint8_t group, struct ncclWorkElemP2p* args) {
|
||||
#if defined(ENABLE_NPKIT)
|
||||
bool isNpKitThread = (tid == 0);
|
||||
int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + 1;
|
||||
int npKitCtxIdx = blockIdx.x * NCCL_MAX_WORK_ELEMENTS_P2P + group;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
|
||||
@@ -112,5 +112,9 @@
|
||||
#define NPKIT_EVENT_NET_TEST_ENTRY 0x58
|
||||
#define NPKIT_EVENT_NET_TEST_EXIT 0x59
|
||||
|
||||
#define NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY 0x5A
|
||||
#define NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT 0x5B
|
||||
#define NPKIT_EVENT_MSCCL_REDUCE_ENTRY 0x5C
|
||||
#define NPKIT_EVENT_MSCCL_REDUCE_EXIT 0x5D
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user