From a0be2b881260393e26137162d80de6bcb8747204 Mon Sep 17 00:00:00 2001 From: Wenkai Du Date: Wed, 20 Nov 2019 11:44:33 -0800 Subject: [PATCH] Disable direct buffers to reduce scratch memory size --- src/collectives/device/primitives.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index b0d4a2938d..eb044eeac6 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -47,8 +47,10 @@ class ncclPrimitives { uint64_t recvStep[NRECV]; uint64_t sendStep[NSEND]; uint64_t sendConnHead[NSEND]; +#if defined(RCCL_USE_DIRECT_BUFFER) const T* recvDirectBuff[NRECV]; T* sendDirectBuff[NSEND]; +#endif const T* recvBuff[NRECV]; T* sendBuff[NSEND]; struct ncclDevComm* comm; @@ -144,12 +146,20 @@ class ncclPrimitives { template __device__ const T* directRecvPtr(int i, int directOffset) { +#if defined(RCCL_USE_DIRECT_BUFFER) return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i); +#else + return recvPtr(i); +#endif } template __device__ T* directSendPtr(int i, int directOffset) { + #if defined(RCCL_USE_DIRECT_BUFFER) return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i); + #else + return sendPtr(i); + #endif } template @@ -179,6 +189,7 @@ class ncclPrimitives { FOR_RECV(waitRecv); if (realSize > 0) { barrier(); +#if defined(RCCL_USE_DIRECT_BUFFER) if (DIRECTRECV && recvDirectBuff[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (SEND) { @@ -187,6 +198,9 @@ class ncclPrimitives { } else { ReduceOrCopyMulti(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize); } +#else + ReduceOrCopyMulti(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize); +#endif } exitIfAbortBarrier(abort, abortCount); if (tid == 0) @@ -214,11 +228,13 @@ class ncclPrimitives { waitPtr = LOAD(&recvConn[i]->tail); STORE(recvConn[i]->opCountLoc, opCount); } +#if defined(RCCL_USE_DIRECT_BUFFER) recvDirectBuff[i] = NULL; if (directBuff && recvConn[i]->direct) { recvDirectBuff[i] = directBuff; if (tid == 0) STORE(recvConn[i]->ptrExchange, directBuff); } +#endif nrecv++; } @@ -232,6 +248,7 @@ class ncclPrimitives { sendConnHead[i] = LOAD(waitPtr); STORE(sendConn[i]->opCountLoc, opCount); } +#if defined(RCCL_USE_DIRECT_BUFFER) sendDirectBuff[i] = NULL; if (directBuff && sendConn[i]->direct) { void* volatile* ptr = sendConn[i]->ptrExchange; @@ -239,6 +256,7 @@ class ncclPrimitives { __syncthreads(); if (tid == 0) STORE(ptr, NULL); } +#endif nsend++; }