From 04553b802a4e2cfd6cdb2309ac1d2a37f0809e66 Mon Sep 17 00:00:00 2001 From: John Bachan Date: Tue, 31 Aug 2021 14:33:48 -0700 Subject: [PATCH] Fix to https://github.com/NVIDIA/nccl/issues/560 ncclGroup's containing operations of mixed datatype, element, or collective would induce crash. [ROCm/rccl commit: 5f2f2f670f2604ff44d1f996733e218c0743af73] --- projects/rccl/src/collectives/device/common.h | 27 ++++++++++++++++--- projects/rccl/src/enqueue.cc | 22 ++++++++------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index f37995d0ef..2b5d51640e 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -71,9 +71,30 @@ template struct RunWork { __device__ void run(ncclWork *w) { int tid = threadIdx.x; - #pragma unroll 1 - for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) { - if (tid < w->elems[e].nThreads) + /* Some invariants that must hold: + * 1. All elems[] have same funcIndex. + * 2. All elems[] have same nThreads. + * 3. The thread-to-group relation (as in prims group numbers) is the same + * for all elems[]. + * + * If (1) isn't true then we might be in the wrong function since dispatch + * on ncclFuncs[w->elems[0].funcIndex] is how we got here. + * + * If (2) or (3) aren't true, then threads from different work elements + * could race for barrier resources (barrier numbers 0...15) which is fatal. + * + * Important, to ensure (3), implementations of + * `RunWorkElement::run()` may only use values which + * are the same for all elems[] when deciding how to map threads to groups, + * such as the following: + * Fn, T, RedOp, Algo, Proto, nThreads + * + * This last one is difficult to enforce and diagnosing it is a headache. + * Device-side developers, consider yourselves warned. 
+ */ + if (tid < w->elems[0].nThreads) { + #pragma unroll 1 + for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) RunWorkElement().run(&w->elems[e]); } } diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 5f8c6ab234..df09166568 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -681,29 +681,29 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { // Reduce the per-channel size if we cannot fully utilize the channels while (comm->asyncTotalSize < channelSize * comm->nChannels && channelSize > NCCL_MIN_CHANNEL_SIZE) channelSize /= 2; int channelUsed = 0; - ncclFunc_t commonColl = ncclNumFuncs; - int fastPath = 1; + int homogeneous = 1; int allCollNetSupport = comm->collNetSupport; for (int c = 0; c < comm->asyncOpCount; c++) { struct ncclInfo* info = comm->asyncOps+c; info->nChannels = std::min(std::max(1, (int)DIVUP(info->nBytes, channelSize)), comm->nChannels); // assign number of channels channelUsed += info->nChannels; // We can use fast path if all collectives are the same - if (commonColl == ncclNumFuncs) commonColl = info->coll; - else if (commonColl != info->coll) fastPath = 0; - else if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport)); + homogeneous &= info->coll == comm->asyncOps[0].coll && + info->op == comm->asyncOps[0].op && + info->datatype == comm->asyncOps[0].datatype; + if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport)); } // Compute algo, proto, nthreads for the entire kernel struct ncclInfo total; total.comm = comm; - total.coll = commonColl; + total.coll = comm->asyncOps[0].coll; total.nBytes = comm->asyncTotalSize; total.nChannels = std::min(channelUsed, comm->nChannels); int perChannelOps = DIVUP(channelUsed, total.nChannels); - if (fastPath) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps)); + if (homogeneous) NCCLCHECK(getAlgoInfo(&total, allCollNetSupport, perChannelOps)); for (int c = 0; 
c < comm->asyncOpCount; c++) { struct ncclInfo* info = comm->asyncOps+c; - if (fastPath) { + if (homogeneous) { info->algorithm = total.algorithm; info->protocol = total.protocol; info->nThreads = total.nThreads; @@ -883,7 +883,11 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; struct ncclWork* w = channel->workFifo+opIndex; int segment = -1; - if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) { + if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 && + // All elems in work must have same (funcIndex,nThreads), + // see "src/collectives/device/common.h" + w->elems[0].funcIndex == work->funcIndex && + w->elems[0].nThreads == work->nThreads) { // Try to pack more segments into a single operation segment = getSegment(COLL_SEGMENT, 0, w); }