Merge pull request #155 from wenkaidu/direct

Disable direct buffers to reduce scratch memory size
This commit is contained in:
Wenkai Du
2019-11-21 09:39:09 -08:00
committed by GitHub
+18
View File
@@ -47,8 +47,10 @@ class ncclPrimitives {
uint64_t recvStep[NRECV];
uint64_t sendStep[NSEND];
uint64_t sendConnHead[NSEND];
#if defined(RCCL_USE_DIRECT_BUFFER)
const T* recvDirectBuff[NRECV];
T* sendDirectBuff[NSEND];
#endif
const T* recvBuff[NRECV];
T* sendBuff[NSEND];
struct ncclDevComm* comm;
@@ -144,12 +146,20 @@ class ncclPrimitives {
// Selects the source pointer for receive slot `i`.
//
// When RCCL_USE_DIRECT_BUFFER is defined, DIRECTRECV is non-zero and
// recvDirectBuff[i] is non-null, the result is recvDirectBuff[i] advanced by
// `directOffset` elements. In every other case (including builds without
// RCCL_USE_DIRECT_BUFFER, where `directOffset` is ignored) the result is
// whatever recvPtr(i) returns.
template <int DIRECTRECV>
__device__ const T* directRecvPtr(int i, int directOffset) {
#if defined(RCCL_USE_DIRECT_BUFFER)
  // Direct path: only taken when requested at compile time and a direct
  // buffer pointer has been set for this slot.
  if (DIRECTRECV && recvDirectBuff[i]) {
    return recvDirectBuff[i] + directOffset;
  }
#endif
  // Fallback shared by both build configurations.
  return recvPtr(i);
}
// Selects the destination pointer for send slot `i`.
//
// When RCCL_USE_DIRECT_BUFFER is defined, DIRECTSEND is non-zero and
// sendDirectBuff[i] is non-null, the result is sendDirectBuff[i] advanced by
// `directOffset` elements. In every other case (including builds without
// RCCL_USE_DIRECT_BUFFER, where `directOffset` is ignored) the result is
// whatever sendPtr(i) returns.
template <int DIRECTSEND>
__device__ T* directSendPtr(int i, int directOffset) {
#if defined(RCCL_USE_DIRECT_BUFFER)
  // Direct path: only taken when requested at compile time and a direct
  // buffer pointer has been set for this slot.
  if (DIRECTSEND && sendDirectBuff[i]) {
    return sendDirectBuff[i] + directOffset;
  }
#endif
  // Fallback shared by both build configurations.
  return sendPtr(i);
}
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
@@ -179,6 +189,7 @@ class ncclPrimitives {
FOR_RECV(waitRecv);
if (realSize > 0) {
barrier();
#if defined(RCCL_USE_DIRECT_BUFFER)
if (DIRECTRECV && recvDirectBuff[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (SEND) {
@@ -187,6 +198,9 @@ class ncclPrimitives {
} else {
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
}
#else
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
#endif
}
exitIfAbortBarrier(abort, abortCount);
if (tid == 0)
@@ -214,11 +228,13 @@ class ncclPrimitives {
waitPtr = LOAD(&recvConn[i]->tail);
STORE(recvConn[i]->opCountLoc, opCount);
}
#if defined(RCCL_USE_DIRECT_BUFFER)
recvDirectBuff[i] = NULL;
if (directBuff && recvConn[i]->direct) {
recvDirectBuff[i] = directBuff;
if (tid == 0) STORE(recvConn[i]->ptrExchange, directBuff);
}
#endif
nrecv++;
}
@@ -232,6 +248,7 @@ class ncclPrimitives {
sendConnHead[i] = LOAD(waitPtr);
STORE(sendConn[i]->opCountLoc, opCount);
}
#if defined(RCCL_USE_DIRECT_BUFFER)
sendDirectBuff[i] = NULL;
if (directBuff && sendConn[i]->direct) {
void* volatile* ptr = sendConn[i]->ptrExchange;
@@ -239,6 +256,7 @@ class ncclPrimitives {
__syncthreads();
if (tid == 0) STORE(ptr, NULL);
}
#endif
nsend++;
}