Disable direct buffers to reduce scratch memory size
This commit is contained in:
@@ -47,8 +47,10 @@ class ncclPrimitives {
|
||||
uint64_t recvStep[NRECV];
|
||||
uint64_t sendStep[NSEND];
|
||||
uint64_t sendConnHead[NSEND];
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
const T* recvDirectBuff[NRECV];
|
||||
T* sendDirectBuff[NSEND];
|
||||
#endif
|
||||
const T* recvBuff[NRECV];
|
||||
T* sendBuff[NSEND];
|
||||
struct ncclDevComm* comm;
|
||||
@@ -144,12 +146,20 @@ class ncclPrimitives {
|
||||
|
||||
template <int DIRECTRECV>
|
||||
__device__ const T* directRecvPtr(int i, int directOffset) {
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i);
|
||||
#else
|
||||
return recvPtr(i);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int DIRECTSEND>
|
||||
__device__ T* directSendPtr(int i, int directOffset) {
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i);
|
||||
#else
|
||||
return sendPtr(i);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
|
||||
@@ -179,6 +189,7 @@ class ncclPrimitives {
|
||||
FOR_RECV(waitRecv);
|
||||
if (realSize > 0) {
|
||||
barrier();
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
if (DIRECTRECV && recvDirectBuff[0]) {
|
||||
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
|
||||
if (SEND) {
|
||||
@@ -187,6 +198,9 @@ class ncclPrimitives {
|
||||
} else {
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
|
||||
}
|
||||
#else
|
||||
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
|
||||
#endif
|
||||
}
|
||||
exitIfAbortBarrier(abort, abortCount);
|
||||
if (tid == 0)
|
||||
@@ -214,11 +228,13 @@ class ncclPrimitives {
|
||||
waitPtr = LOAD(&recvConn[i]->tail);
|
||||
STORE(recvConn[i]->opCountLoc, opCount);
|
||||
}
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
recvDirectBuff[i] = NULL;
|
||||
if (directBuff && recvConn[i]->direct) {
|
||||
recvDirectBuff[i] = directBuff;
|
||||
if (tid == 0) STORE(recvConn[i]->ptrExchange, directBuff);
|
||||
}
|
||||
#endif
|
||||
nrecv++;
|
||||
}
|
||||
|
||||
@@ -232,6 +248,7 @@ class ncclPrimitives {
|
||||
sendConnHead[i] = LOAD(waitPtr);
|
||||
STORE(sendConn[i]->opCountLoc, opCount);
|
||||
}
|
||||
#if defined(RCCL_USE_DIRECT_BUFFER)
|
||||
sendDirectBuff[i] = NULL;
|
||||
if (directBuff && sendConn[i]->direct) {
|
||||
void* volatile* ptr = sendConn[i]->ptrExchange;
|
||||
@@ -239,6 +256,7 @@ class ncclPrimitives {
|
||||
__syncthreads();
|
||||
if (tid == 0) STORE(ptr, NULL);
|
||||
}
|
||||
#endif
|
||||
nsend++;
|
||||
}
|
||||
|
||||
|
||||
Reference in new issue
Block a user