diff --git a/src/collectives/device/msccl_kernel_impl.h b/src/collectives/device/msccl_kernel_impl.h index 31ca61ad1c..a66a083362 100644 --- a/src/collectives/device/msccl_kernel_impl.h +++ b/src/collectives/device/msccl_kernel_impl.h @@ -26,6 +26,28 @@ extern __shared__ struct mscclShmemData mscclShmem; #define GET_WORKINDEX_FROM_FLAG(__FLAG__) \ (__FLAG__) / (MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS) +#ifdef ENABLE_COLLTRACE + #define INC_COLL_TRACE \ + uint32_t pos = atomicAdd(&ncclShmem.collTraceTail->tail, 1)%COLLTRACE_NUM_ITEMS; \ + struct ncclCollTrace* collTrace = ncclShmem.collTrace+pos; \ + collTrace->timeStamp = wall_clock64(); \ + collTrace->bid = blockIdx.x; + // TODO: switch to atomicInc after llvm crash is fixed + // uint32_t pos = atomicInc(&ncclShmem.collTraceTail->tail, COLLTRACE_NUM_ITEMS) + + #define traceData(data2, data4, data8_0, data8_1) { \ + INC_COLL_TRACE \ + collTrace->funcIndex = data2; \ + collTrace->data_0 = data4; \ + collTrace->opCount = data8_0; \ + collTrace->data_1 = data8_1; \ + collTrace->type = ncclCollTraceDataType; \ + } +#else +#define traceData(data2, data4, data8_0, data8_1) +#endif + + // a copy of the volatile load/store from prims_ll template __device__ static U load(U *src) { @@ -173,7 +195,11 @@ __device__ __forceinline__ void mscclRunInterpreter( break; case 3: /* set abort flag to 0 */ - if (tid == 3 * WARP_SIZE) ncclShmem.aborted = 0; + if (tid%WARP_SIZE == 0) ncclShmem.aborted = 0; +#ifdef ENABLE_COLLTRACE + else if (tid%WARP_SIZE == 1) ncclShmem.collTrace = comm->collTrace + COLLTRACE_NUM_ITEMS*channelId; + else if (tid%WARP_SIZE == 2) ncclShmem.collTraceTail = comm->collTraceTail + channelId; +#endif break; default: break; @@ -193,6 +219,10 @@ __device__ __forceinline__ void mscclRunInterpreter( #endif __synclds(); // publish shmem + if (fullOps && tid == 0) { + traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0); + } + if (tid == 0) *mscclShmem.work.workFifoDone = mscclShmem.work.workFifoDoneAck; @@ -300,6 +330,7 @@ __device__ __forceinline__ void mscclRunInterpreter( srcPointer = (t->srcBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->srcBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch); dstPointer = (t->dstBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->dstBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch); prims.setDataPtrs(srcPointer, dstPointer); + int count = t->count; for (int c = 0; c < count; c += maxAllowedCount) { srcOffset = gridOffset + (ssize_t) (t->srcOffset+c) * sizePerMscclChunk; @@ -441,6 +472,10 @@ __device__ __forceinline__ void mscclRunInterpreter( copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head); if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head; #endif + + if (fullOps && tid == 0) { + traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0); + } } #define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \ diff --git a/src/include/msccl/msccl_scheduler.h b/src/include/msccl/msccl_scheduler.h index 1dd9f6b9fb..1776d12cdb 100644 --- a/src/include/msccl/msccl_scheduler.h +++ b/src/include/msccl/msccl_scheduler.h @@ -36,6 +36,7 @@ struct mscclSchedulerParam { int nRanks; bool scheduled; mscclAlgoHandle_t handle; + uint64_t opCount; }; typedef struct { diff --git a/src/include/msccl/msccl_struct.h b/src/include/msccl/msccl_struct.h index 73fc5c5f85..bd288de4ce 100644 --- a/src/include/msccl/msccl_struct.h +++ b/src/include/msccl/msccl_struct.h @@ -239,7 +239,7 @@ struct mscclWork { int nChunksPerLoop; bool hasReduce; bool redOpArgIsPtr; - uint32_t pad[1]; + uint32_t fnIndex; }; static_assert(sizeof(struct mscclWork) % 16 == 0, "mscclWork needs to be 16B aligned"); diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index e1d8e6e3ef..e402fb019d 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -305,6 +305,7 @@ static ncclResult_t mscclSetSavedSchedulerParam( param->p.nRanks = comm->nRanks; param->comm = comm; param->stream = stream; + param->p.opCount = comm->opCount; return ncclSuccess; } @@ -322,9 +323,27 @@ static ncclResult_t mscclSaveCountsAndDispls(struct mscclSavedSchedulerParam* pa return ncclSuccess; } +const char *mscclFuncNames[] = { + "mscclFuncReduce", + "mscclFuncBroadcast", + "mscclFuncAllReduce", + "mscclFuncReduceScatter", + "mscclFuncAllGather", + "mscclFuncSend", + "mscclFuncRecv", + "mscclFuncGather", + "mscclFuncScatter", + "mscclFuncAllToAll", + "mscclFuncAllToAllv", + }; + static ncclResult_t mscclRunSavedParams() { mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); for (auto& param : threadLocalStatus.savedSchedulerParams) { + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", + mscclFuncNames[param.p.func], param.p.opCount, param.p.sendBuff, param.p.recvBuff, param.p.count, + param.p.dataType, param.p.op, param.p.root, param.comm, param.p.nRanks, param.stream); + NCCLCHECK(mscclRunAlgo( param.p.sendBuff, param.p.sendCounts, param.p.sDisPls, param.p.recvBuff, param.p.recvCounts, param.p.rDisPls, diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index 5c8dff5106..7765ec2e71 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -391,6 +391,16 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count ncclDevRedOpFull opFull = {}; NCCLCHECK(hostToDevRedOp(&opFull, op, dataType, comm)); + uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol; + uint8_t fullOpMask = (1<typeMask & fullOpMask) || rcclParamMscclForceFullOps()) + fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2; + mscclWork work; work.syncFlags = status.syncFlags; work.scratchBuffer = status.scratchBuffer; @@ -403,7 +413,8 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count work.maxAllowedCount = status.maxAllowedCount; work.hasReduce = hostAlgo->hasReduce; work.redOpArgIsPtr = opFull.scalarArgIsPtr; - INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask); + work.fnIndex = fnIndex; + INFO(NCCL_COLL, "MSCCL: typeMask %x fnIndex %d Setup Kernel finished", hostAlgo->typeMask, fnIndex); uint32_t workFifoIdxMask = status.workFifoDepth - 1; uint32_t workFifoSent = status.workFifoSent; @@ -428,15 +439,6 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask); void *args[3] = {&comm->devComm, &devAlgo, &workPtr}; - uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol; - uint8_t fullOpMask = (1<typeMask & fullOpMask) || rcclParamMscclForceFullOps()) - fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2; void *func = mscclKernelEntries[fnIndex]; if (enableDoneEvent) { CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));