Reduce NPKit latency overhead in MSCCL kernel (#893)
* Reduce NPKit latency overhead in MSCCL kernel * Fix build error without NPKit enable
Este cometimento está contido em:
cometido por
GitHub
ascendente
16dd05a58a
cometimento
26e982d913
@@ -359,6 +359,8 @@ struct ncclShmemGroup {
|
||||
uint64_t barrier_next[NCCL_MAX_GROUPS];
|
||||
};
|
||||
|
||||
#define LDS_NUM_EVENTS 64
|
||||
|
||||
struct ncclShmemData {
|
||||
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
|
||||
uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1];
|
||||
@@ -374,6 +376,10 @@ struct ncclShmemData {
|
||||
#ifdef ENABLE_PROFILING
|
||||
struct ncclProf prof;
|
||||
#endif
|
||||
#if defined(ENABLE_NPKIT)
|
||||
NpKitEvent event_buffer[LDS_NUM_EVENTS];
|
||||
uint64_t event_buffer_head;
|
||||
#endif
|
||||
};
|
||||
static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "ncclShmem.work needs to be 16B aligned");
|
||||
|
||||
|
||||
@@ -186,24 +186,23 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
copyToShmem8(tid%WARP_SIZE, dst, src, bytes);
|
||||
}
|
||||
__synclds(); // publish shmem
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
int npKitCtxIdx = bid;
|
||||
if (tid == 0) ncclShmem.event_buffer_head = 0;
|
||||
#endif
|
||||
__synclds(); // publish shmem
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
if (tid == 0) {
|
||||
uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp;
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp,
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -263,8 +262,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_XCC_ID)" : "=s" (xcc_id));
|
||||
#endif
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -317,31 +315,27 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
if (t->type == MSCCL_SEND) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
prims.sendWithBarrier(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if (t->type == MSCCL_RECV) {
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
prims.recv(dstOffset, thisNelem);
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -350,8 +344,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
if (thisNelem < nthreads){
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -387,8 +380,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -443,10 +435,15 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
}
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RUN_EXIT)
|
||||
if (tid == 0) {
|
||||
NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
|
||||
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NPKIT)
|
||||
__synclds();
|
||||
NpKitEventCollectContext* ctx = ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx;
|
||||
copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head);
|
||||
if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head;
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type) \
|
||||
|
||||
@@ -13,9 +13,11 @@
|
||||
|
||||
#include "npkit/npkit_event.h"
|
||||
#include "npkit/npkit_struct.h"
|
||||
#include "common.h"
|
||||
|
||||
#define NPKIT_GET_GPU_TIMESTAMP wall_clock64
|
||||
|
||||
|
||||
class NpKit {
|
||||
public:
|
||||
static const uint64_t kNumGpuEventBuffers = 512;
|
||||
@@ -43,6 +45,19 @@ class NpKit {
|
||||
}
|
||||
}
|
||||
|
||||
static inline __device__ void CollectGpuEventLDS(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp) {
|
||||
#if defined(ENABLE_NPKIT)
|
||||
if (ncclShmem.event_buffer_head < LDS_NUM_EVENTS) {
|
||||
NpKitEvent& event = ncclShmem.event_buffer[ncclShmem.event_buffer_head];
|
||||
event.fields.type = type;
|
||||
event.fields.size = size < 0 ? 0 : size;
|
||||
event.fields.rsvd = rsvd;
|
||||
event.fields.timestamp = timestamp;
|
||||
ncclShmem.event_buffer_head++;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void CollectCpuEvent(uint8_t type, int64_t size, uint32_t rsvd, uint64_t timestamp, int channel_id);
|
||||
|
||||
static uint64_t *GetCpuTimestamp();
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador