Этот коммит содержится в:
Wenkai Du
2021-10-28 07:26:11 -07:00
коммит произвёл GitHub
родитель ec36c4c326
Коммит d221fb672a
+34 -30
Просмотреть файл
@@ -348,28 +348,31 @@ __device__ void ncclKernel(ncclWorkElem first) {
turn = copyToShmem(&shmem.channel, channel, turn);
// To optimize for latency, (only) the first operation is passed as argument.
struct ncclWorkElem* elems = NULL;
bool firstLaunch = true;
if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) elems = &first;
if (bid == 0 && first.active != 0)
turn = copyToShmem(&shmem.work.elems[0], &first, turn);
ncclWork *workFifoHost = channel->workFifo;
ncclWork *workFifoDev = channel->workFifoDev;
int workFifoIx = channel->index;
struct ncclWorkElem* elems = shmem.work.elems;
__syncthreads(); // publish shmem
while (1) {
if (elems == NULL) {
elems = shmem.work.elems;
__syncthreads();
copyToShmem(&shmem.work, &workFifoDev[workFifoIx]);
{ // Check whether the last operation was aborted and make sure all threads exit
int aborted = tid == 0 ? *shmem.comm.abortFlag : 0;
if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work
if (COLLTRACE && tid == 0) traceAbort(0xffff);
break;
}
if (tid == 0)
workFifoHost[workFifoIx].elems[0].active = 0;
ncclWork *workFifoHost = shmem.channel.workFifo;
ncclWork *workFifoDev = shmem.channel.workFifoDev;
int workFifoIx = shmem.channel.index;
bool skipLoadWork = false, firstLaunch = true;
if (bid == 0 && first.active != 0)
skipLoadWork = true;
while (true) {
if (!skipLoadWork) {
copyToShmem(&shmem.work, &workFifoDev[workFifoIx]); // turn no longer helps
// Check whether the last operation was aborted and make sure all threads exit
int aborted = tid == 0 ? *shmem.comm.abortFlag : 0;
if (barrierReduceAny(aborted, &abortCount)) { // publish shmem.work
if (COLLTRACE && tid == 0) traceAbort(elems->funcIndex);
break;
}
if (tid == 0)
workFifoHost[workFifoIx].elems[0].active = 0;
if (COLLTRACE && tid == 0) {
if (firstLaunch) traceKernelLaunch(elems->funcIndex);
if (!firstLaunch) traceCollEnd(elems->funcIndex);
@@ -379,21 +382,22 @@ __device__ void ncclKernel(ncclWorkElem first) {
traceKernelLaunch(elems->funcIndex);
firstLaunch = false;
}
workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS;
if (tid == 0)
channel->index = workFifoIx; // write back to real channel, not shmem shadow
if (tid < elems->nThreads && elems->active != 0) {
if (elems->funcIndex == FnIndex) {
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
} else {
NCCL_CALL_FUNCTIONS(elems);
}
if (shmem.work.elems[0].funcIndex == FnIndex)
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
else
NCCL_CALL_FUNCTIONS(&elems[0]);
if (shmem.work.elems[0].active == 2) {
if (COLLTRACE && tid == 0) traceCollEnd(0xffff)
break;
}
if (elems->active == 2) {
if (COLLTRACE && tid == 0) traceCollEnd(0xffff);
return;
}
elems = NULL;
__syncthreads();
skipLoadWork = false;
}
}