msccl: enable basic collective trace (#959)
To avoid increasing number of kernels, colltrace is only enabled with RCCL_MSCCL_FORCE_FULLOPS=1
This commit is contained in:
@@ -26,6 +26,28 @@ extern __shared__ struct mscclShmemData mscclShmem;
|
||||
#define GET_WORKINDEX_FROM_FLAG(__FLAG__) \
|
||||
(__FLAG__) / (MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS)
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define INC_COLL_TRACE \
|
||||
uint32_t pos = atomicAdd(&ncclShmem.collTraceTail->tail, 1)%COLLTRACE_NUM_ITEMS; \
|
||||
struct ncclCollTrace* collTrace = ncclShmem.collTrace+pos; \
|
||||
collTrace->timeStamp = wall_clock64(); \
|
||||
collTrace->bid = blockIdx.x;
|
||||
// TODO: switch to atomicInc after llvm crash is fixed
|
||||
// uint32_t pos = atomicInc(&ncclShmem.collTraceTail->tail, COLLTRACE_NUM_ITEMS)
|
||||
|
||||
#define traceData(data2, data4, data8_0, data8_1) { \
|
||||
INC_COLL_TRACE \
|
||||
collTrace->funcIndex = data2; \
|
||||
collTrace->data_0 = data4; \
|
||||
collTrace->opCount = data8_0; \
|
||||
collTrace->data_1 = data8_1; \
|
||||
collTrace->type = ncclCollTraceDataType; \
|
||||
}
|
||||
#else
|
||||
#define traceData(data2, data4, data8_0, data8_1)
|
||||
#endif
|
||||
|
||||
|
||||
// a copy of the volatile load/store from prims_ll
|
||||
template<typename U>
|
||||
__device__ static U load(U *src) {
|
||||
@@ -173,7 +195,11 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
break;
|
||||
case 3:
|
||||
/* set abort flag to 0 */
|
||||
if (tid == 3 * WARP_SIZE) ncclShmem.aborted = 0;
|
||||
if (tid%WARP_SIZE == 0) ncclShmem.aborted = 0;
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
else if (tid%WARP_SIZE == 1) ncclShmem.collTrace = comm->collTrace + COLLTRACE_NUM_ITEMS*channelId;
|
||||
else if (tid%WARP_SIZE == 2) ncclShmem.collTraceTail = comm->collTraceTail + channelId;
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -193,6 +219,10 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
#endif
|
||||
__synclds(); // publish shmem
|
||||
|
||||
if (fullOps && tid == 0) {
|
||||
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
|
||||
}
|
||||
|
||||
if (tid == 0)
|
||||
*mscclShmem.work.workFifoDone = mscclShmem.work.workFifoDoneAck;
|
||||
|
||||
@@ -300,6 +330,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
srcPointer = (t->srcBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->srcBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
||||
dstPointer = (t->dstBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->dstBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
||||
prims.setDataPtrs(srcPointer, dstPointer);
|
||||
|
||||
int count = t->count;
|
||||
for (int c = 0; c < count; c += maxAllowedCount) {
|
||||
srcOffset = gridOffset + (ssize_t) (t->srcOffset+c) * sizePerMscclChunk;
|
||||
@@ -441,6 +472,10 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head);
|
||||
if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head;
|
||||
#endif
|
||||
|
||||
if (fullOps && tid == 0) {
|
||||
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
|
||||
|
||||
@@ -36,6 +36,7 @@ struct mscclSchedulerParam {
|
||||
int nRanks;
|
||||
bool scheduled;
|
||||
mscclAlgoHandle_t handle;
|
||||
uint64_t opCount;
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -239,7 +239,7 @@ struct mscclWork {
|
||||
int nChunksPerLoop;
|
||||
bool hasReduce;
|
||||
bool redOpArgIsPtr;
|
||||
uint32_t pad[1];
|
||||
uint32_t fnIndex;
|
||||
};
|
||||
static_assert(sizeof(struct mscclWork) % 16 == 0, "mscclWork needs to be 16B aligned");
|
||||
|
||||
|
||||
@@ -305,6 +305,7 @@ static ncclResult_t mscclSetSavedSchedulerParam(
|
||||
param->p.nRanks = comm->nRanks;
|
||||
param->comm = comm;
|
||||
param->stream = stream;
|
||||
param->p.opCount = comm->opCount;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -322,9 +323,27 @@ static ncclResult_t mscclSaveCountsAndDispls(struct mscclSavedSchedulerParam* pa
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
const char *mscclFuncNames[] = {
|
||||
"mscclFuncReduce",
|
||||
"mscclFuncBroadcast",
|
||||
"mscclFuncAllReduce",
|
||||
"mscclFuncReduceScatter",
|
||||
"mscclFuncAllGather",
|
||||
"mscclFuncSend",
|
||||
"mscclFuncRecv",
|
||||
"mscclFuncGather",
|
||||
"mscclFuncScatter",
|
||||
"mscclFuncAllToAll",
|
||||
"mscclFuncAllToAllv",
|
||||
};
|
||||
|
||||
static ncclResult_t mscclRunSavedParams() {
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
for (auto& param : threadLocalStatus.savedSchedulerParams) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
mscclFuncNames[param.p.func], param.p.opCount, param.p.sendBuff, param.p.recvBuff, param.p.count,
|
||||
param.p.dataType, param.p.op, param.p.root, param.comm, param.p.nRanks, param.stream);
|
||||
|
||||
NCCLCHECK(mscclRunAlgo(
|
||||
param.p.sendBuff, param.p.sendCounts, param.p.sDisPls,
|
||||
param.p.recvBuff, param.p.recvCounts, param.p.rDisPls,
|
||||
|
||||
@@ -391,6 +391,16 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
ncclDevRedOpFull opFull = {};
|
||||
NCCLCHECK(hostToDevRedOp(&opFull, op, dataType, comm));
|
||||
|
||||
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
|
||||
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY) |
|
||||
(1<<MSCCL_LOCAL_COPY);
|
||||
//check if need full ops msccl kernel
|
||||
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
|
||||
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
|
||||
|
||||
mscclWork work;
|
||||
work.syncFlags = status.syncFlags;
|
||||
work.scratchBuffer = status.scratchBuffer;
|
||||
@@ -403,7 +413,8 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
work.maxAllowedCount = status.maxAllowedCount;
|
||||
work.hasReduce = hostAlgo->hasReduce;
|
||||
work.redOpArgIsPtr = opFull.scalarArgIsPtr;
|
||||
INFO(NCCL_COLL, "MSCCL: typeMask %x Setup Kernel finished", hostAlgo->typeMask);
|
||||
work.fnIndex = fnIndex;
|
||||
INFO(NCCL_COLL, "MSCCL: typeMask %x fnIndex %d Setup Kernel finished", hostAlgo->typeMask, fnIndex);
|
||||
|
||||
uint32_t workFifoIdxMask = status.workFifoDepth - 1;
|
||||
uint32_t workFifoSent = status.workFifoSent;
|
||||
@@ -428,15 +439,6 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
|
||||
struct mscclWork *workPtr = status.workFifo + (workFifoSent & workFifoIdxMask);
|
||||
void *args[3] = {&comm->devComm, &devAlgo, &workPtr};
|
||||
uint32_t fnIndex = (opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol;
|
||||
uint8_t fullOpMask = (1<<MSCCL_RECV_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY_SEND) |
|
||||
(1<<MSCCL_RECV_REDUCE_COPY) |
|
||||
(1<<MSCCL_LOCAL_COPY);
|
||||
//check if need full ops msccl kernel
|
||||
if ((hostAlgo->typeMask & fullOpMask) || rcclParamMscclForceFullOps())
|
||||
fnIndex += sizeof(mscclKernelEntries)/sizeof(void *)/2;
|
||||
void *func = mscclKernelEntries[fnIndex];
|
||||
if (enableDoneEvent) {
|
||||
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
|
||||
|
||||
Reference in New Issue
Block a user