diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 4cb18a04e7..f0edb0aa3a 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -348,28 +348,31 @@ __device__ void ncclKernel(ncclWorkElem first) { turn = copyToShmem(&shmem.channel, channel, turn); // To optimize for latency, (only) the first operation is passed as argument. - struct ncclWorkElem* elems = NULL; - bool firstLaunch = true; - if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) elems = &first; + if (bid == 0 && first.active != 0) + turn = copyToShmem(&shmem.work.elems[0], &first, turn); - ncclWork *workFifoHost = channel->workFifo; - ncclWork *workFifoDev = channel->workFifoDev; - int workFifoIx = channel->index; + struct ncclWorkElem* elems = shmem.work.elems; + __syncthreads(); // publish shmem - while (1) { - if (elems == NULL) { - elems = shmem.work.elems; - __syncthreads(); - copyToShmem(&shmem.work, &workFifoDev[workFifoIx]); - { // Check whether the last operation was aborted and make sure all threads exit - int aborted = tid == 0 ? *shmem.comm.abortFlag : 0; - if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work - if (COLLTRACE && tid == 0) traceAbort(0xffff); - break; - } - if (tid == 0) - workFifoHost[workFifoIx].elems[0].active = 0; + ncclWork *workFifoHost = shmem.channel.workFifo; + ncclWork *workFifoDev = shmem.channel.workFifoDev; + int workFifoIx = shmem.channel.index; + + bool skipLoadWork = false, firstLaunch = true; + if (bid == 0 && first.active != 0) + skipLoadWork = true; + + while (true) { + if (!skipLoadWork) { + copyToShmem(&shmem.work, &workFifoDev[workFifoIx]); // turn no longer helps + // Check whether the last operation was aborted and make sure all threads exit + int aborted = tid == 0 ? 
*shmem.comm.abortFlag : 0; + if (barrierReduceAny(aborted, &abortCount)) { // publish shmem.work + if (COLLTRACE && tid == 0) traceAbort(elems->funcIndex); + break; } + if (tid == 0) + workFifoHost[workFifoIx].elems[0].active = 0; if (COLLTRACE && tid == 0) { if (firstLaunch) traceKernelLaunch(elems->funcIndex); if (!firstLaunch) traceCollEnd(elems->funcIndex); @@ -379,21 +382,22 @@ __device__ void ncclKernel(ncclWorkElem first) { traceKernelLaunch(elems->funcIndex); firstLaunch = false; } + workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS; if (tid == 0) channel->index = workFifoIx; // write back to real channel, not shmem shadow - if (tid < elems->nThreads && elems->active != 0) { - if (elems->funcIndex == FnIndex) { - RunWork().run(&shmem.work); - } else { - NCCL_CALL_FUNCTIONS(elems); - } + + if (shmem.work.elems[0].funcIndex == FnIndex) + RunWork().run(&shmem.work); + else + NCCL_CALL_FUNCTIONS(&elems[0]); + + if (shmem.work.elems[0].active == 2) { + if (COLLTRACE && tid == 0) traceCollEnd(0xffff); + break; } - if (elems->active == 2) { - if (COLLTRACE && tid == 0) traceCollEnd(0xffff); - return; - } - elems = NULL; + __syncthreads(); + skipLoadWork = false; } }