Revert "Split primitive class to smaller structures" (#230)

This reverts commit 622b49e80a.

[ROCm/rccl commit: 5215130168]
Bu işleme şunda yer alıyor:
Wenkai Du
2020-07-08 11:06:50 -07:00
işlemeyi yapan: GitHub
ebeveyn 4a3b58ac3a
işleme 47dfed5f01
3 değiştirilmiş dosya ile 151 ekleme ve 210 silme
+4 -8
Dosyayı Görüntüle
@@ -129,8 +129,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclPrimitivesRecvData<T, NCCL_MAX_TREE_ARITY> recvData;
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData);
ncclPrimitives<UNROLL/2, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, 0, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -148,8 +147,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclPrimitivesSendData<T, NCCL_MAX_TREE_ARITY> sendData;
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount, sendData);
ncclPrimitives<UNROLL/2, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, &tree->up, tree->down, thisOutput, stepSize, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
@@ -323,8 +321,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclLLPrimitivesRecvData<T, NCCL_MAX_TREE_ARITY> recvData;
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount, recvData);
ncclLLPrimitives<T, FUNC, NCCL_MAX_TREE_ARITY, 1> LLprims(tid, nthreads, tree->down, &tree->up, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -342,8 +339,7 @@ __device__ void ncclAllReduceTreeLLKernel(struct CollectiveArgs* args) {
do {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclLLPrimitivesSendData<T, NCCL_MAX_TREE_ARITY> sendData;
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount, sendData);
ncclLLPrimitives<T, FUNC, 1, NCCL_MAX_TREE_ARITY> LLprims(tid, nthreads, &tree->up, tree->down, stepLines, channel, comm, args->opCount);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;
+90 -116
Dosyayı Görüntüle
@@ -39,41 +39,6 @@
while (LOAD(barriers+id) < barrier_next[id*MAXWARPS+w]) /* spin */; \
} while (0)
template <typename T, int NRECV>
class ncclPrimitivesRecvData {
public:
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
volatile uint64_t* recvConnTailPtr = NULL;
uint64_t recvConnTail;
uint64_t recvConnTailCache; // Cache last seen value
uint64_t recvStep[NRECV];
#if defined(RCCL_USE_DIRECT_BUFFER)
const T* recvDirectBuff[NRECV];
#endif
const T* recvBuff[NRECV];
};
template <typename T, int NSEND>
class ncclPrimitivesSendData {
public:
struct ncclConnInfo* sendConn = NULL;
volatile int* sendConnFifoPtr = NULL;
volatile uint64_t* sendConnTailPtr = NULL;
uint64_t sendConnTail;
volatile uint64_t* sendConnHeadPtr = NULL;
uint64_t sendConnHead;
uint64_t sendConnHeadCache; // Cache last seen value
uint64_t sendStep[NSEND];
#if defined(RCCL_USE_DIRECT_BUFFER)
const T* sendDirectBuff[NSEND];
#endif
T* sendBuff[NSEND];
};
// Implementation of primitive types
template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, int DIRECT, class FUNC>
class ncclPrimitives {
@@ -84,18 +49,35 @@ class ncclPrimitives {
const int stepSize;
int nrecv = 0;
int nsend = 0;
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
volatile uint64_t* recvConnTailPtr = NULL;
uint64_t recvConnTail;
uint64_t recvConnTailCache; // Cache last seen value
typename std::conditional<NRECV == NCCL_MAX_TREE_ARITY,
ncclPrimitivesRecvData<T, NRECV>&, ncclPrimitivesRecvData<T, NRECV>>::type r;
typename std::conditional<NSEND == NCCL_MAX_TREE_ARITY,
ncclPrimitivesSendData<T, NSEND>&, ncclPrimitivesSendData<T, NSEND>>::type s;
struct ncclConnInfo* sendConn = NULL;
volatile int* sendConnFifoPtr = NULL;
volatile uint64_t* sendConnTailPtr = NULL;
uint64_t sendConnTail;
volatile uint64_t* sendConnHeadPtr = NULL;
uint64_t sendConnHead;
uint64_t sendConnHeadCache; // Cache last seen value
const struct ncclDevComm* comm;
uint64_t recvStep[NRECV];
uint64_t sendStep[NSEND];
#if defined(RCCL_USE_DIRECT_BUFFER)
const T* recvDirectBuff[NRECV];
T* sendDirectBuff[NSEND];
#endif
const T* recvBuff[NRECV];
T* sendBuff[NSEND];
struct ncclDevComm* comm;
inline __device__ int recvOffset(int i) { return (r.recvStep[i]%NCCL_STEPS)*stepSize; }
inline __device__ int sendOffset(int i) { return (s.sendStep[i]%NCCL_STEPS)*stepSize; }
inline __device__ const T* recvPtr(int i) { return ((const T*)r.recvBuff[i])+recvOffset(i); }
inline __device__ T* sendPtr(int i) { return ((T*)s.sendBuff[i])+sendOffset(i); }
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
inline __device__ const T* recvPtr(int i) { return ((const T*)recvBuff[i])+recvOffset(i); }
inline __device__ T* sendPtr(int i) { return ((T*)sendBuff[i])+sendOffset(i); }
uint64_t* barriers;
uint64_t* barrier_next;
@@ -109,7 +91,11 @@ class ncclPrimitives {
else barrier_by_id(1);
}
#else
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE));
if (NSEND>NRECV) {
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE));
} else {
asm volatile ("bar.sync 2, %0;" :: "r"(nthreads+WARP_SIZE));
}
#endif
}
@@ -117,7 +103,11 @@ class ncclPrimitives {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
__syncthreads();
#else
asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
if (NSEND>NRECV) {
asm volatile ("bar.sync 3, %0;" :: "r"(nthreads));
} else {
asm volatile ("bar.sync 4, %0;" :: "r"(nthreads));
}
#endif
}
@@ -140,7 +130,7 @@ class ncclPrimitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = LOAD(comm->abortFlag);
if (wid == i) checkMismatch(send ? s.sendConn : r.recvConn);
if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@@ -149,57 +139,57 @@ class ncclPrimitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
if (s.sendConnHeadPtr) {
while (s.sendConnHeadCache + NCCL_STEPS < s.sendConnHead + SLICESTEPS) {
s.sendConnHeadCache = LOAD(s.sendConnHeadPtr);
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
sendConnHeadCache = LOAD(sendConnHeadPtr);
if (checkAbort(wid, 1)) break;
}
if (s.sendConnFifoPtr) {
STORE(s.sendConnFifoPtr+s.sendConnHead%NCCL_STEPS, nbytes);
if (sendConnFifoPtr) {
STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, nbytes);
}
s.sendConnHead += SLICESTEPS;
sendConnHead += SLICESTEPS;
}
}
inline __device__ void waitRecv() {
spins = 0;
mismatch = 0;
if (r.recvConnTailPtr) {
if (recvConnTailPtr) {
#ifdef ENABLE_PROFILING
uint64_t t0 = __rtc64();
#endif
while (r.recvConnTailCache < r.recvConnTail + SLICESTEPS) {
r.recvConnTailCache = LOAD(r.recvConnTailPtr);
while (recvConnTailCache < recvConnTail + SLICESTEPS) {
recvConnTailCache = LOAD(recvConnTailPtr);
if (checkAbort(wid, 0)) break;
}
#ifdef ENABLE_PROFILING
if (opCount > 0) __atomic_fetch_add(&comm->devProf->wait_recv_cycle[blockIdx.x], __rtc64() - t0, __ATOMIC_SEQ_CST);
#endif
r.recvConnTail += SLICESTEPS;
recvConnTail += SLICESTEPS;
}
}
inline __device__ void incRecv(int i) {
r.recvStep[i] += SLICESTEPS;
recvStep[i] += SLICESTEPS;
}
inline __device__ void postRecv() {
if (r.recvConnHeadPtr) STORE(r.recvConnHeadPtr, r.recvConnHead += SLICESTEPS);
if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += SLICESTEPS);
}
inline __device__ void incSend(int i) {
s.sendStep[i] += SLICESTEPS;
sendStep[i] += SLICESTEPS;
}
inline __device__ void postSend() {
if (s.sendConnTailPtr) {
if (s.sendConn->next_hdp_reg) STORE(s.sendConn->next_hdp_reg, 0x1);
STORE(s.sendConnTailPtr, s.sendConnTail += SLICESTEPS);
if (sendConnTailPtr) {
if (sendConn->next_hdp_reg) STORE(sendConn->next_hdp_reg, 0x1);
STORE(sendConnTailPtr, sendConnTail += SLICESTEPS);
}
}
template <int DIRECTRECV>
inline __device__ const T* directRecvPtr(int i, ssize_t directOffset) {
#if defined(RCCL_USE_DIRECT_BUFFER)
return DIRECTRECV && r.recvDirectBuff[i] ? r.recvDirectBuff[i]+directOffset : recvPtr(i);
return DIRECTRECV && recvDirectBuff[i] ? recvDirectBuff[i]+directOffset : recvPtr(i);
#else
return recvPtr(i);
#endif
@@ -208,7 +198,7 @@ class ncclPrimitives {
template <int DIRECTSEND>
inline __device__ T* directSendPtr(int i, ssize_t directOffset) {
#if defined(RCCL_USE_DIRECT_BUFFER)
return DIRECTSEND && s.sendDirectBuff[i] ? s.sendDirectBuff[i]+directOffset : sendPtr(i);
return DIRECTSEND && sendDirectBuff[i] ? sendDirectBuff[i]+directOffset : sendPtr(i);
#else
return sendPtr(i);
#endif
@@ -217,7 +207,7 @@ class ncclPrimitives {
template <int DIRECTRECV>
inline __device__ int directRecvInc(int i, int directInc, int sliceInc) {
#if defined(RCCL_USE_DIRECT_BUFFER)
return DIRECTRECV && r.recvDirectBuff[i] ? directInc : sliceInc;
return DIRECTRECV && recvDirectBuff[i] ? directInc : sliceInc;
#else
return sliceInc;
#endif
@@ -226,7 +216,7 @@ inline __device__ int directRecvInc(int i, int directInc, int sliceInc) {
template <int DIRECTSEND>
inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
#if defined(RCCL_USE_DIRECT_BUFFER)
return DIRECTSEND && s.sendDirectBuff[i] ? directInc : sliceInc;
return DIRECTSEND && sendDirectBuff[i] ? directInc : sliceInc;
#else
return sliceInc;
#endif
@@ -267,7 +257,7 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
if (tid == 0 && opCount > 0) __atomic_fetch_add(&comm->devProf->wait_cycle[blockIdx.x], __rtc64() - t0, __ATOMIC_SEQ_CST);
#endif
#if defined(RCCL_USE_DIRECT_BUFFER)
if (DIRECTRECV && r.recvDirectBuff[0]) {
if (DIRECTRECV && recvDirectBuff[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (SEND) {
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize);
@@ -299,81 +289,84 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
}
// Loads the receive-side state for peer slot `i` from connection `conn`.
//
// The calling thread caches the NCCL_PROTO_SIMPLE buffer pointer in recvBuff[i]
// and the connection's step counter, passed through
// ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS), in recvStep[i].
// Threads whose `wid` equals `i` additionally remember `conn` and seed
// recvConnTail / recvConnHead from the rounded step.
// When RCCL_USE_DIRECT_BUFFER is defined, DIRECT is non-zero and conn->direct
// has the NCCL_DIRECT_GPU bit set, `directBuff` is recorded in
// recvDirectBuff[i] and thread 0 stores it through conn->ptrExchange.
// Increments nrecv on every call.
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
  recvBuff[i] = (const T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE);
  recvStep[i] = LOAD(&conn->step);
  recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
#if defined(RCCL_USE_DIRECT_BUFFER)
  recvDirectBuff[i] = NULL;
  // Load the flag word first and mask the loaded value; applying the mask to
  // the address (as the previous parenthesization did) is not valid C++.
  if (DIRECT && (LOAD(&conn->direct) & NCCL_DIRECT_GPU)) {
    recvDirectBuff[i] = directBuff;
    if (tid == 0) STORE(conn->ptrExchange, directBuff);
  }
#endif
  if (wid == i) recvConn = conn;
  if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up
  nrecv++;
}
__device__ __forceinline__ void loadRecvSync() {
if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && wid<nrecv) {
r.recvConnTailPtr = LOAD(&r.recvConn->tail);
r.recvConnTailCache = LOAD(r.recvConnTailPtr);
recvConnTailPtr = LOAD(&recvConn->tail);
recvConnTailCache = LOAD(recvConnTailPtr);
}
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
r.recvConnHeadPtr = LOAD(&r.recvConn->head);
recvConnHeadPtr = LOAD(&recvConn->head);
// Return credits in case we rounded up.
STORE(r.recvConnHeadPtr, r.recvConnHead);
STORE(recvConnHeadPtr, recvConnHead);
// Update opCount in case we skipped some operations
STORE(r.recvConn->opCountLoc, opCount);
STORE(recvConn->opCountLoc, opCount);
}
}
// Loads the send-side state for peer slot `i` from connection `conn`.
//
// The calling thread caches the NCCL_PROTO_SIMPLE buffer pointer in sendBuff[i]
// and the connection's step counter, passed through
// ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS), in sendStep[i].
// Threads whose `wid` equals `i` additionally remember `conn` and seed
// sendConnTail / sendConnHead from the rounded step.
// When RCCL_USE_DIRECT_BUFFER is defined, DIRECT is non-zero and conn->direct
// has the NCCL_DIRECT_GPU bit set, the function spins until the slot behind
// conn->ptrExchange holds a non-NULL pointer, keeps it in sendDirectBuff[i],
// calls barrier(), and then thread 0 resets the slot to NULL.
// Increments nsend on every call.
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
  sendBuff[i] = (T*)LOAD(conn->buffs+NCCL_PROTO_SIMPLE);
  sendStep[i] = LOAD(&conn->step);
  sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
#if defined(RCCL_USE_DIRECT_BUFFER)
  sendDirectBuff[i] = NULL;
  // Load the flag word first and mask the loaded value; applying the mask to
  // the address (as the previous parenthesization did) is not valid C++.
  if (DIRECT && (LOAD(&conn->direct) & NCCL_DIRECT_GPU)) {
    void* volatile* ptr = LOAD(&conn->ptrExchange);
    // Busy-wait until the exchange slot contains a pointer.
    while ((sendDirectBuff[i] = (T*)(LOAD(ptr))) == NULL);
    // All threads must have read the slot before it is cleared below.
    barrier();
    if (tid == 0) STORE(ptr, NULL);
  }
#endif
  if (wid == i) sendConn = conn;
  if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up
  nsend++;
}
__device__ __forceinline__ void loadSendSync() {
if (tid < nsend) {
s.sendConnHeadPtr = LOAD(&s.sendConn->head);
s.sendConnHeadCache = LOAD(s.sendConnHeadPtr);
s.sendConnFifoPtr = LOAD(&s.sendConn->fifo);
STORE(s.sendConn->opCountLoc, opCount);
sendConnHeadPtr = LOAD(&sendConn->head);
sendConnHeadCache = LOAD(sendConnHeadPtr);
sendConnFifoPtr = LOAD(&sendConn->fifo);
STORE(sendConn->opCountLoc, opCount);
}
if (tid >= nthreads-WARP_SIZE && wid < nsend) {
s.sendConnTailPtr = LOAD(&s.sendConn->tail);
sendConnTailPtr = LOAD(&sendConn->tail);
}
}
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
STORE(&r.recvConn->step, r.recvConnHead);
STORE(r.recvConn->opCountLoc, opCount+1);
STORE(&recvConn->step, recvConnHead);
STORE(recvConn->opCountLoc, opCount+1);
__threadfence_system();
}
}
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
STORE(&s.sendConn->step, s.sendConnHead);
STORE(s.sendConn->opCountLoc, opCount+1);
STORE(&sendConn->step, sendConnHead);
STORE(sendConn->opCountLoc, opCount+1);
__threadfence_system();
}
}
inline __device__ void init(int* recvPeers, int* sendPeers, struct ncclChannel* channel) {
public:
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) {
barriers = channel->barrier;
barrier_next = channel->barrier_next;
// Make sure step is updated before we read it.
@@ -385,25 +378,6 @@ inline __device__ int directSendInc(int i, int directInc, int sliceInc) {
loadSendSync();
}
public:
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount) {
init(recvPeers, sendPeers, channel);
}
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclPrimitivesRecvData<T, NRECV>& r)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount), r(r) {
init(recvPeers, sendPeers, channel);
}
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclPrimitivesSendData<T, NSEND>& s)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize), opCount(opCount), s(s) {
init(recvPeers, sendPeers, channel);
}
__device__ __forceinline__ void
send(const T* src, int nelem) {
GenericOp<0, 0, 0, 1, 1, 0>(src, NULL, nelem, 0);
+57 -86
Dosyayı Görüntüle
@@ -5,28 +5,6 @@
* See LICENSE.txt for license information
************************************************************************/
template <typename T, int NRECV>
class ncclLLPrimitivesRecvData {
public:
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
uint64_t recvStep[NRECV];
union ncclLLFifoLine* recvBuff[NRECV];
};
template <typename T, int NSEND>
class ncclLLPrimitivesSendData {
public:
struct ncclConnInfo* sendConn = NULL;
volatile int* sendConnFifoPtr = NULL;
volatile uint64_t* sendConnHeadPtr = NULL;
uint64_t sendConnHead;
uint64_t sendConnHeadCache; // Cache last seen value
uint64_t sendStep[NSEND];
union ncclLLFifoLine* sendBuff[NSEND];
};
template <typename T, class FUNC, int NRECV, int NSEND>
class ncclLLPrimitives {
private:
@@ -36,25 +14,34 @@ class ncclLLPrimitives {
const int stepLines;
int nrecv = 0;
int nsend = 0;
struct ncclConnInfo* recvConn = NULL;
volatile uint64_t* recvConnHeadPtr = NULL;
uint64_t recvConnHead;
struct ncclConnInfo* sendConn = NULL;
volatile int* sendConnFifoPtr = NULL;
volatile uint64_t* sendConnHeadPtr = NULL;
uint64_t sendConnHead;
uint64_t sendConnHeadCache; // Cache last seen value
uint64_t recvStep[NRECV];
uint64_t sendStep[NSEND];
union ncclLLFifoLine* recvBuff[NRECV];
union ncclLLFifoLine* sendBuff[NSEND];
struct ncclDevComm* comm;
typename std::conditional<NRECV == NCCL_MAX_TREE_ARITY,
ncclLLPrimitivesRecvData<T, NRECV>&, ncclLLPrimitivesRecvData<T, NRECV>>::type r;
typename std::conditional<NSEND == NCCL_MAX_TREE_ARITY,
ncclLLPrimitivesSendData<T, NSEND>&, ncclLLPrimitivesSendData<T, NSEND>>::type s;
inline __device__ int recvOffset(int i) { return (r.recvStep[i]%NCCL_STEPS)*stepLines; }
inline __device__ int sendOffset(int i) { return (s.sendStep[i]%NCCL_STEPS)*stepLines; }
inline __device__ union ncclLLFifoLine* recvPtr(int i) { return r.recvBuff[i]+recvOffset(i); }
inline __device__ union ncclLLFifoLine* sendPtr(int i) { return s.sendBuff[i]+sendOffset(i); }
inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(r.recvStep[i]+1); }
inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(s.sendStep[i]+1); }
inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepLines; }
inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepLines; }
inline __device__ union ncclLLFifoLine* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
inline __device__ union ncclLLFifoLine* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
inline __device__ uint32_t recvFlag(int i) { return NCCL_LL_FLAG(recvStep[i]+1); }
inline __device__ uint32_t sendFlag(int i) { return NCCL_LL_FLAG(sendStep[i]+1); }
// Barrier used by the LL primitives.
// HIP builds call __syncthreads(); all other builds issue the PTX instruction
// `bar.sync` on barrier resource 1 with `nthreads` as the thread count.
inline __device__ void barrier() {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  __syncthreads();
#else
  // The mnemonic must be spelled "bar.sync"; any other spelling is rejected
  // by the PTX assembler.
  asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
#endif
}
@@ -78,7 +65,7 @@ class ncclLLPrimitives {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = LOAD(comm->abortFlag);
if (wid == i) checkMismatch(send ? s.sendConn : r.recvConn);
if (wid == i) checkMismatch(send ? sendConn : recvConn);
spins = 0;
}
return abort;
@@ -87,35 +74,35 @@ class ncclLLPrimitives {
inline __device__ void waitSend(int nbytes) {
spins = 0;
mismatch = 0;
if (s.sendConnHeadPtr) {
while (s.sendConnHeadCache + NCCL_STEPS < s.sendConnHead + 1) {
s.sendConnHeadCache = LOAD(s.sendConnHeadPtr);
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
sendConnHeadCache = LOAD(sendConnHeadPtr);
if (checkAbort(wid, 1)) break;
}
if (s.sendConnFifoPtr) {
int size = ((s.sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
STORE(s.sendConnFifoPtr+s.sendConnHead%NCCL_STEPS, size);
if (sendConnFifoPtr) {
int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes;
STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, size);
}
s.sendConnHead += 1;
sendConnHead += 1;
}
barrier();
}
inline __device__ void incRecv(int i) {
r.recvStep[i] += 1;
recvStep[i] += 1;
}
inline __device__ void postRecv() {
barrier();
if (r.recvConnHeadPtr) STORE(r.recvConnHeadPtr, r.recvConnHead += 1);
if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1);
}
inline __device__ void incSend(int i, int offset) {
// LL Cleanup : write all flags in the slice to make sure we don't have
// data corruption when flag loops over.
if ((s.sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
// data corruption when flag loops over.
if ((sendStep[i] & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) {
for (int o = offset; o<stepLines; o+=nthreads) storeLL(sendPtr(i)+o, 0, sendFlag(i));
}
s.sendStep[i]++;
sendStep[i]++;
}
__device__ uint64_t readLL(int i, int offset) {
@@ -160,7 +147,7 @@ class ncclLLPrimitives {
#endif
}
// Using memcpy handles misaligned pointers.
// Using memcpy handles misaligned pointers.
__device__ uint64_t readAL(uint64_t* src) {
uint64_t val;
memcpy((char*)&val, (char*)src, sizeof(uint64_t));
@@ -213,53 +200,56 @@ class ncclLLPrimitives {
}
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
r.recvBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
r.recvStep[i] = LOAD(&conn->step);
if (wid == i) r.recvConn = conn;
recvBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
recvStep[i] = LOAD(&conn->step);
if (wid == i) recvConn = conn;
nrecv++;
}
__device__ __forceinline__ void loadRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
r.recvConnHeadPtr = LOAD(&r.recvConn->head);
r.recvConnHead = LOAD(&r.recvConn->step);
recvConnHeadPtr = LOAD(&recvConn->head);
recvConnHead = LOAD(&recvConn->step);
// Update opCount in case we skipped some operations
STORE(r.recvConn->opCountLoc, opCount);
STORE(recvConn->opCountLoc, opCount);
}
}
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
s.sendBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
s.sendStep[i] = LOAD(&conn->step);
if (wid == i) s.sendConn = conn;
sendBuff[i] = (union ncclLLFifoLine*)LOAD(conn->buffs+NCCL_PROTO_LL);
sendStep[i] = LOAD(&conn->step);
if (wid == i) sendConn = conn;
nsend++;
}
__device__ __forceinline__ void loadSendSync() {
if (tid < nsend) {
s.sendConnHeadPtr = LOAD(&s.sendConn->head);
s.sendConnHeadCache = LOAD(s.sendConnHeadPtr);
s.sendConnHead = LOAD(&s.sendConn->step);
s.sendConnFifoPtr = LOAD(&s.sendConn->fifo);
STORE(s.sendConn->opCountLoc, opCount);
sendConnHeadPtr = LOAD(&sendConn->head);
sendConnHeadCache = LOAD(sendConnHeadPtr);
sendConnHead = LOAD(&sendConn->step);
sendConnFifoPtr = LOAD(&sendConn->fifo);
STORE(sendConn->opCountLoc, opCount);
}
}
__device__ __forceinline__ void saveRecvSync() {
if (tid >= nthreads-WARP_SIZE && wid < nrecv) {
STORE(&r.recvConn->step, r.recvConnHead);
STORE(r.recvConn->opCountLoc, opCount+1);
STORE(&recvConn->step, recvConnHead);
STORE(recvConn->opCountLoc, opCount+1);
__threadfence_block();
}
}
__device__ __forceinline__ void saveSendSync() {
if (tid < nsend) {
STORE(&s.sendConn->step, s.sendConnHead);
STORE(s.sendConn->opCountLoc, opCount+1);
STORE(&sendConn->step, sendConnHead);
STORE(sendConn->opCountLoc, opCount+1);
__threadfence_block();
}
}
inline __device__ void init(int* recvPeers, int* sendPeers, struct ncclChannel* channel) {
public:
__device__ __forceinline__
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
// Make sure step is updated before we read it.
barrier();
@@ -269,25 +259,6 @@ class ncclLLPrimitives {
loadSendSync();
}
public:
__device__ __forceinline__
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount) {
init(recvPeers, sendPeers, channel);
}
__device__ __forceinline__
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclLLPrimitivesRecvData<T, NRECV>& r)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount), r(r) {
init(recvPeers, sendPeers, channel);
}
__device__ __forceinline__
ncclLLPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepLines, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount, ncclLLPrimitivesSendData<T, NSEND>& s)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepLines(stepLines), opCount(opCount), s(s) {
init(recvPeers, sendPeers, channel);
}
__device__ void send(const T* src, int nelem) {
return LLGenericOp<0, 1, 1, 0>(src, NULL, nelem);
}