Rework kernel launch code (#449)
Этот коммит содержится в:
@@ -348,28 +348,31 @@ __device__ void ncclKernel(ncclWorkElem first) {
|
||||
turn = copyToShmem(&shmem.channel, channel, turn);
|
||||
|
||||
// To optimize for latency, (only) the first operation is passed as argument.
|
||||
struct ncclWorkElem* elems = NULL;
|
||||
bool firstLaunch = true;
|
||||
if (bid == 0 && first.funcIndex != FUNC_INDEX_P2P) elems = &first;
|
||||
if (bid == 0 && first.active != 0)
|
||||
turn = copyToShmem(&shmem.work.elems[0], &first, turn);
|
||||
|
||||
ncclWork *workFifoHost = channel->workFifo;
|
||||
ncclWork *workFifoDev = channel->workFifoDev;
|
||||
int workFifoIx = channel->index;
|
||||
struct ncclWorkElem* elems = shmem.work.elems;
|
||||
__syncthreads(); // publish shmem
|
||||
|
||||
while (1) {
|
||||
if (elems == NULL) {
|
||||
elems = shmem.work.elems;
|
||||
__syncthreads();
|
||||
copyToShmem(&shmem.work, &workFifoDev[workFifoIx]);
|
||||
{ // Check whether the last operation was aborted and make sure all threads exit
|
||||
int aborted = tid == 0 ? *shmem.comm.abortFlag : 0;
|
||||
if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work
|
||||
if (COLLTRACE && tid == 0) traceAbort(0xffff);
|
||||
break;
|
||||
}
|
||||
if (tid == 0)
|
||||
workFifoHost[workFifoIx].elems[0].active = 0;
|
||||
ncclWork *workFifoHost = shmem.channel.workFifo;
|
||||
ncclWork *workFifoDev = shmem.channel.workFifoDev;
|
||||
int workFifoIx = shmem.channel.index;
|
||||
|
||||
bool skipLoadWork = false, firstLaunch = true;
|
||||
if (bid == 0 && first.active != 0)
|
||||
skipLoadWork = true;
|
||||
|
||||
while (true) {
|
||||
if (!skipLoadWork) {
|
||||
copyToShmem(&shmem.work, &workFifoDev[workFifoIx]); // turn no longer helps
|
||||
// Check whether the last operation was aborted and make sure all threads exit
|
||||
int aborted = tid == 0 ? *shmem.comm.abortFlag : 0;
|
||||
if (barrierReduceAny(aborted, &abortCount)) { // publish shmem.work
|
||||
if (COLLTRACE && tid == 0) traceAbort(elems->funcIndex);
|
||||
break;
|
||||
}
|
||||
if (tid == 0)
|
||||
workFifoHost[workFifoIx].elems[0].active = 0;
|
||||
if (COLLTRACE && tid == 0) {
|
||||
if (firstLaunch) traceKernelLaunch(elems->funcIndex);
|
||||
if (!firstLaunch) traceCollEnd(elems->funcIndex);
|
||||
@@ -379,21 +382,22 @@ __device__ void ncclKernel(ncclWorkElem first) {
|
||||
traceKernelLaunch(elems->funcIndex);
|
||||
firstLaunch = false;
|
||||
}
|
||||
|
||||
workFifoIx = (workFifoIx + 1)%NCCL_MAX_OPS;
|
||||
if (tid == 0)
|
||||
channel->index = workFifoIx; // write back to real channel, not shmem shadow
|
||||
if (tid < elems->nThreads && elems->active != 0) {
|
||||
if (elems->funcIndex == FnIndex) {
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
|
||||
} else {
|
||||
NCCL_CALL_FUNCTIONS(elems);
|
||||
}
|
||||
|
||||
if (shmem.work.elems[0].funcIndex == FnIndex)
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
|
||||
else
|
||||
NCCL_CALL_FUNCTIONS(&elems[0]);
|
||||
|
||||
if (shmem.work.elems[0].active == 2) {
|
||||
if (COLLTRACE && tid == 0) traceCollEnd(0xffff);
|
||||
break;
|
||||
}
|
||||
if (elems->active == 2) {
|
||||
if (COLLTRACE && tid == 0) traceCollEnd(0xffff);
|
||||
return;
|
||||
}
|
||||
elems = NULL;
|
||||
__syncthreads();
|
||||
skipLoadWork = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user