Use relaxed atomics and add sleep and wakeup in barrier loop (#425)

* Use relaxed atomics and add sleep and wakeup in barrier loop

* atomicAdd in ROCm 4.3 only support unsigned long long

* Switch to atomicAdd and atomicExch in more places

* Restore LOAD/STORE define to __ATOMIC_SEQ_CST

* Restore atomic for sizes FIFO

[ROCm/rccl commit: 020484bf40]
Этот коммит содержится в:
Wenkai Du
2021-09-13 17:03:49 -07:00
коммит произвёл GitHub
родитель 9ffeb41fe1
Коммит f4387b2954
3 изменённых файлов: 17 добавлений и 12 удалений
+3 -3
Просмотреть файл
@@ -19,9 +19,9 @@
// all CTA's threads enter the barrier and do a popc on their predicates being True
// If any of the thread's predicate was True, all the threads call exit()
#define exitIfAbortBarrier(abort, abortCount) \
if (abort) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \
if (abort) atomicAdd(abortCount, 1); \
__syncthreads(); \
if (LOAD(abortCount)) { /*asm volatile ("s_endpgm");*/ return false; }
if (atomicAdd(abortCount, 0)) { /*asm volatile ("s_endpgm");*/ return false; }
#define __syncwarp()
#define NCCL_FUNC5(func, algo, redop, type) \
@@ -184,7 +184,7 @@ static __device__ void load_parallel(void* dst, void* src, size_t size, int tid)
static __device__ bool load_coll(struct ncclWork* localWork, struct ncclWork *hostWork, struct ncclWork* workFifo, int tid, struct ncclDevComm* comm, uint32_t* abortCount) {
load_parallel(localWork, workFifo, sizeof(struct ncclWork), tid);
// Check whether the last operation was aborted and make sure all threads exit
int abort = tid == 0 ? LOAD(comm->abortFlag) : 0;
int abort = tid == 0 ? atomicAdd_system((unsigned int *)comm->abortFlag, 0) : 0;
exitIfAbortBarrier(abort, abortCount);
if (tid == 0) hostWork->elems[0].active = 0;
return true;
+12 -7
Просмотреть файл
@@ -40,8 +40,9 @@
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
barrier_next[w] += nthreads/WARP_SIZE; \
__atomic_fetch_add(barriers, 1, __ATOMIC_SEQ_CST); \
while (LOAD(barriers) < barrier_next[w]) /* spin */; \
atomicAdd((unsigned long long *)barriers, 1); \
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(8); \
__asm__ __volatile__("s_wakeup"); \
} \
} \
} while (0)
@@ -119,7 +120,7 @@ class ncclPrimitives {
inline __device__ int checkAbort() {
spins++;
if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) {
abort = LOAD(comm->abortFlag);
abort = atomicAdd_system((unsigned int *)comm->abortFlag, 0);
spins = 0;
}
return abort;
@@ -134,9 +135,11 @@ class ncclPrimitives {
inline __device__ void waitSend(ssize_t directOffset, int nbytes) {
spins = 0;
while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) {
connHeadCache = LOAD(connHeadPtr);
__builtin_amdgcn_s_sleep(8);
connHeadCache = atomicAdd_system((unsigned long long *)connHeadPtr, 0);
if (checkAbort()) break;
}
__asm__ __volatile__("s_wakeup");
if (connSizesFifoPtr) {
STORE(connSizesFifoPtr+step%NCCL_STEPS, nbytes);
}
@@ -154,9 +157,11 @@ class ncclPrimitives {
if (tid == 0) t0 = __builtin_amdgcn_s_memrealtime();
#endif
while (connTailCache < step + SLICESTEPS) {
connTailCache = LOAD(connTailPtr);
__builtin_amdgcn_s_sleep(8);
connTailCache = atomicAdd_system((unsigned long long *)connTailPtr, 0);
if (checkAbort()) break;
}
__asm__ __volatile__("s_wakeup");
#ifdef ENABLE_PROFILING
if (tid == 0) comm->devProf->elems[blockIdx.x].wait_recv_cycle += (__builtin_amdgcn_s_memrealtime() - t0);
#endif
@@ -166,12 +171,12 @@ class ncclPrimitives {
}
inline __device__ void postRecv() {
STORE(connHeadPtr, step += SLICESTEPS);
atomicExch_system((unsigned long long *)connHeadPtr, step += SLICESTEPS);
}
inline __device__ void postSend() {
if (conn->next_hdp_reg) STORE(conn->next_hdp_reg, 0x1);
STORE(connTailPtr, step += SLICESTEPS);
atomicExch_system((unsigned long long *)connTailPtr, step += SLICESTEPS);
}
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
+2 -2
Просмотреть файл
@@ -18,8 +18,8 @@
// Convert volatile access to atomic
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_ACQUIRE)
#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_RELEASE)
#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_SEQ_CST)
#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST)
#else
#define LOAD(VAR) *(VAR)
#define STORE(DST, SRC) *(DST) = (SRC)