msccl: enable basic collective trace (#959)

To avoid increasing number of kernels, colltrace is only enabled with
RCCL_MSCCL_FORCE_FULLOPS=1
This commit is contained in:
Wenkai Du
2023-11-08 20:14:28 -08:00
committed by GitHub
orang tua 8e0258a73d
melakukan 5a800e00cd
5 mengubah file dengan 69 tambahan dan 12 penghapusan
+36 -1
Melihat File
@@ -26,6 +26,28 @@ extern __shared__ struct mscclShmemData mscclShmem;
#define GET_WORKINDEX_FROM_FLAG(__FLAG__) \
(__FLAG__) / (MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS)
#ifdef ENABLE_COLLTRACE
#define INC_COLL_TRACE \
uint32_t pos = atomicAdd(&ncclShmem.collTraceTail->tail, 1)%COLLTRACE_NUM_ITEMS; \
struct ncclCollTrace* collTrace = ncclShmem.collTrace+pos; \
collTrace->timeStamp = wall_clock64(); \
collTrace->bid = blockIdx.x;
// TODO: switch to atomicInc after llvm crash is fixed
// uint32_t pos = atomicInc(&ncclShmem.collTraceTail->tail, COLLTRACE_NUM_ITEMS)
#define traceData(data2, data4, data8_0, data8_1) { \
INC_COLL_TRACE \
collTrace->funcIndex = data2; \
collTrace->data_0 = data4; \
collTrace->opCount = data8_0; \
collTrace->data_1 = data8_1; \
collTrace->type = ncclCollTraceDataType; \
}
#else
#define traceData(data2, data4, data8_0, data8_1)
#endif
// a copy of the volatile load/store from prims_ll
template<typename U>
__device__ static U load(U *src) {
@@ -173,7 +195,11 @@ __device__ __forceinline__ void mscclRunInterpreter(
break;
case 3:
/* set abort flag to 0 */
if (tid == 3 * WARP_SIZE) ncclShmem.aborted = 0;
if (tid%WARP_SIZE == 0) ncclShmem.aborted = 0;
#ifdef ENABLE_COLLTRACE
else if (tid%WARP_SIZE == 1) ncclShmem.collTrace = comm->collTrace + COLLTRACE_NUM_ITEMS*channelId;
else if (tid%WARP_SIZE == 2) ncclShmem.collTraceTail = comm->collTraceTail + channelId;
#endif
break;
default:
break;
@@ -193,6 +219,10 @@ __device__ __forceinline__ void mscclRunInterpreter(
#endif
__synclds(); // publish shmem
if (fullOps && tid == 0) {
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
}
if (tid == 0)
*mscclShmem.work.workFifoDone = mscclShmem.work.workFifoDoneAck;
@@ -300,6 +330,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
srcPointer = (t->srcBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->srcBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
dstPointer = (t->dstBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->dstBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
prims.setDataPtrs(srcPointer, dstPointer);
int count = t->count;
for (int c = 0; c < count; c += maxAllowedCount) {
srcOffset = gridOffset + (ssize_t) (t->srcOffset+c) * sizePerMscclChunk;
@@ -441,6 +472,10 @@ __device__ __forceinline__ void mscclRunInterpreter(
copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head);
if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head;
#endif
if (fullOps && tid == 0) {
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
}
}
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
+1
Melihat File
@@ -36,6 +36,7 @@ struct mscclSchedulerParam {
int nRanks;
bool scheduled;
mscclAlgoHandle_t handle;
uint64_t opCount;
};
typedef struct {
+1 -1
Melihat File
@@ -239,7 +239,7 @@ struct mscclWork {
int nChunksPerLoop;
bool hasReduce;
bool redOpArgIsPtr;
uint32_t pad[1];
uint32_t fnIndex;
};
static_assert(sizeof(struct mscclWork) % 16 == 0, "mscclWork needs to be 16B aligned");
+19
Melihat File
@@ -305,6 +305,7 @@ static ncclResult_t mscclSetSavedSchedulerParam(
param->p.nRanks = comm->nRanks;
param->comm = comm;
param->stream = stream;
param->p.opCount = comm->opCount;
return ncclSuccess;
}
@@ -322,9 +323,27 @@ static ncclResult_t mscclSaveCountsAndDispls(struct mscclSavedSchedulerParam* pa
return ncclSuccess;
}
const char *mscclFuncNames[] = {
"mscclFuncReduce",
"mscclFuncBroadcast",
"mscclFuncAllReduce",
"mscclFuncReduceScatter",
"mscclFuncAllGather",
"mscclFuncSend",
"mscclFuncRecv",
"mscclFuncGather",
"mscclFuncScatter",
"mscclFuncAllToAll",
"mscclFuncAllToAllv",
};
static ncclResult_t mscclRunSavedParams() {
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
for (auto& param : threadLocalStatus.savedSchedulerParams) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
mscclFuncNames[param.p.func], param.p.opCount, param.p.sendBuff, param.p.recvBuff, param.p.count,
param.p.dataType, param.p.op, param.p.root, param.comm, param.p.nRanks, param.stream);
NCCLCHECK(mscclRunAlgo(
param.p.sendBuff, param.p.sendCounts, param.p.sDisPls,
param.p.recvBuff, param.p.recvCounts, param.p.rDisPls,
+12 -10
Melihat File
@@ -391,6 +391,16 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
ncclDevRedOpFull opFull = {};
NCCLCHECK(hostToDevRedOp(&opFull, op, dataType, comm));
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY) |
(1<<MSCCL_LOCAL_COPY);
//check if need full ops msccl kernel
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
mscclWork work;
work.syncFlags = status.syncFlags;
work.scratchBuffer = status.scratchBuffer;
@@ -403,7 +413,8 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
work.maxAllowedCount = status.maxAllowedCount;
work.hasReduce = hostAlgo->hasReduce;
work.redOpArgIsPtr = opFull.scalarArgIsPtr;
INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask);
work.fnIndex = fnIndex;
INFO(NCCL_COLL, "MSCCL: typeMask %x fnIndex %d Setup Kernel finished", hostAlgo->typeMask, fnIndex);
uint32_t workFifoIdxMask = status.workFifoDepth - 1;
uint32_t workFifoSent = status.workFifoSent;
@@ -428,15 +439,6 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask);
void *args[3] = {&comm->devComm, &devAlgo, &workPtr};
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
(1<<MSCCL_RECV_REDUCE_COPY) |
(1<<MSCCL_LOCAL_COPY);
//check if need full ops msccl kernel
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
void *func = mscclKernelEntries[fnIndex];
if (enableDoneEvent) {
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));