2
0

Reduce NPKit latency overhead in MSCCL kernel (#893)

* Reduce NPKit latency overhead in MSCCL kernel

* Fix build error when NPKit is not enabled
Este cometimento está contido em:
Wenkai Du
2023-09-15 13:28:26 -07:00
cometido por GitHub
ascendente 16dd05a58a
cometimento 26e982d913
3 ficheiros modificados com 39 adições e 21 eliminações
+6
Ver ficheiro
@@ -359,6 +359,8 @@ struct ncclShmemGroup {
uint64_t barrier_next[NCCL_MAX_GROUPS];
};
// Number of NpKitEvent slots in ncclShmemData::event_buffer (present only when
// ENABLE_NPKIT is defined). NpKit::CollectGpuEventLDS stops recording once
// event_buffer_head reaches this value.
#define LDS_NUM_EVENTS 64
struct ncclShmemData {
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1];
@@ -374,6 +376,10 @@ struct ncclShmemData {
#ifdef ENABLE_PROFILING
struct ncclProf prof;
#endif
#if defined(ENABLE_NPKIT)
NpKitEvent event_buffer[LDS_NUM_EVENTS];
uint64_t event_buffer_head;
#endif
};
static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "ncclShmem.work needs to be 16B aligned");
+18 -21
Ver ficheiro
@@ -186,24 +186,23 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
copyToShmem8(tid%WARP_SIZE, dst, src, bytes);
}
__synclds(); // publish shmem
#if defined(ENABLE_NPKIT)
int npKitCtxIdx = bid;
if (tid == 0) ncclShmem.event_buffer_head = 0;
#endif
__synclds(); // publish shmem
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
if (tid == 0) {
uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp;
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp,
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp);
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
@@ -263,8 +262,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_XCC_ID)" : "=s" (xcc_id));
#endif
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
@@ -317,31 +315,27 @@ __device__ __forceinline__ void mscclRunInterpreter(
if (t->type == MSCCL_SEND) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
prims.sendWithBarrier(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
}
else if (t->type == MSCCL_RECV) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
prims.recv(dstOffset, thisNelem);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
}
@@ -350,8 +344,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
if (thisNelem < nthreads){
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
@@ -387,8 +380,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
@@ -443,10 +435,15 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RUN_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
}
#endif
#if defined(ENABLE_NPKIT)
__synclds();
NpKitEventCollectContext* ctx = ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx;
copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head);
if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head;
#endif
}
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
+15
Ver ficheiro
@@ -13,9 +13,11 @@
#include "npkit/npkit_event.h"
#include "npkit/npkit_struct.h"
#include "common.h"
// Timestamp source for GPU-side NPKit events. Call sites write
// NPKIT_GET_GPU_TIMESTAMP(), which expands to wall_clock64().
#define NPKIT_GET_GPU_TIMESTAMP wall_clock64
class NpKit {
public:
static const uint64_t kNumGpuEventBuffers = 512;
@@ -43,6 +45,19 @@ class NpKit {
}
}
// Record one event in the next free slot of ncclShmem.event_buffer.
//
//   type      - event type tag stored in the entry
//   size      - payload size; negative values are stored as 0
//   rsvd      - reserved field, stored as given
//   timestamp - timestamp stored as given
//
// If event_buffer_head has already reached LDS_NUM_EVENTS the event is dropped
// without any indication to the caller. The head index is advanced with a
// plain (non-atomic) store, so concurrent callers would race on it.
// Without ENABLE_NPKIT the body is compiled out and the call does nothing.
static inline __device__ void CollectGpuEventLDS(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp) {
#if defined(ENABLE_NPKIT)
  const uint64_t slot = ncclShmem.event_buffer_head;
  if (slot >= LDS_NUM_EVENTS) {
    return;  // buffer full: drop the event
  }

  // Clamp negative sizes to zero before storing.
  if (size < 0) {
    size = 0;
  }

  NpKitEvent* entry = &ncclShmem.event_buffer[slot];
  entry->fields.type = type;
  entry->fields.size = size;
  entry->fields.rsvd = rsvd;
  entry->fields.timestamp = timestamp;

  // Publish the new entry by moving the head past it.
  ncclShmem.event_buffer_head = slot + 1;
#endif
}
static void CollectCpuEvent(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp, int channel_id);
static uint64_t *GetCpuTimestamp();