diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index 980989ab62..4d46ea1192 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -19,9 +19,9 @@ // all CTA's threads enter the barrier and do a popc on their predicates being True // If any of the thread's predicate was True, all the threads call exit() #define exitIfAbortBarrier(abort, abortCount) \ - if (abort) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \ + if (abort) atomicAdd(abortCount, 1); \ __syncthreads(); \ - if (LOAD(abortCount)) { /*asm volatile ("s_endpgm");*/ return false; } + if (atomicAdd(abortCount, 0)) { /*asm volatile ("s_endpgm");*/ return false; } #define __syncwarp() #define NCCL_FUNC5(func, algo, redop, type) \ @@ -184,7 +184,7 @@ static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) static __device__ bool load_coll(struct ncclWork* localWork, struct ncclWork *hostWork, struct ncclWork* workFifo, int tid, struct ncclDevComm* comm, uint32_t* abortCount) { load_parallel(localWork, workFifo, sizeof(struct ncclWork), tid); // Check whether the last operation was aborted and make sure all threads exit - int abort = tid == 0 ? LOAD(comm->abortFlag) : 0; + int abort = tid == 0 ? atomicAdd_system((unsigned int *)comm->abortFlag, 0) : 0; exitIfAbortBarrier(abort, abortCount); if (tid == 0) hostWork->elems[0].active = 0; return true; diff --git a/projects/rccl/src/collectives/device/primitives.h b/projects/rccl/src/collectives/device/primitives.h index c2c2e2cd10..395dae600a 100644 --- a/projects/rccl/src/collectives/device/primitives.h +++ b/projects/rccl/src/collectives/device/primitives.h @@ -40,8 +40,9 @@ const int wid = threadIdx.x%WARP_SIZE; \ if (wid == 0) { \ barrier_next[w] += nthreads/WARP_SIZE; \ - __atomic_fetch_add(barriers, 1, __ATOMIC_SEQ_CST); \ - while (LOAD(barriers) < barrier_next[w]) /* spin */; \ + atomicAdd((unsigned long long *)barriers, 1); \ + while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(8); \ + __asm__ __volatile__("s_wakeup"); \ } \ } \ } while (0) @@ -119,7 +120,7 @@ class ncclPrimitives { inline __device__ int checkAbort() { spins++; if (abort == 0 && spins == SPINS_BEFORE_CHECK_ABORT) { - abort = LOAD(comm->abortFlag); + abort = atomicAdd_system((unsigned int *)comm->abortFlag, 0); spins = 0; } return abort; @@ -134,9 +135,11 @@ class ncclPrimitives { inline __device__ void waitSend(ssize_t directOffset, int nbytes) { spins = 0; while (connHeadCache + NCCL_STEPS < step + SLICESTEPS) { - connHeadCache = LOAD(connHeadPtr); + __builtin_amdgcn_s_sleep(8); + connHeadCache = atomicAdd_system((unsigned long long *)connHeadPtr, 0); if (checkAbort()) break; } + __asm__ __volatile__("s_wakeup"); if (connSizesFifoPtr) { STORE(connSizesFifoPtr+step%NCCL_STEPS, nbytes); } @@ -154,9 +157,11 @@ class ncclPrimitives { if (tid == 0) t0 = __builtin_amdgcn_s_memrealtime(); #endif while (connTailCache < step + SLICESTEPS) { - connTailCache = LOAD(connTailPtr); + __builtin_amdgcn_s_sleep(8); + connTailCache = atomicAdd_system((unsigned long long *)connTailPtr, 0); if (checkAbort()) break; } + __asm__ __volatile__("s_wakeup"); #ifdef ENABLE_PROFILING if (tid == 0) comm->devProf->elems[blockIdx.x].wait_recv_cycle += (__builtin_amdgcn_s_memrealtime() - t0); #endif @@ -166,12 +171,12 @@ class ncclPrimitives { } inline __device__ void postRecv() { - STORE(connHeadPtr, step += SLICESTEPS); + atomicExch_system((unsigned long long *)connHeadPtr, step += SLICESTEPS); } inline __device__ void postSend() { if (conn->next_hdp_reg) STORE(conn->next_hdp_reg, 0x1); - STORE(connTailPtr, step += SLICESTEPS); + atomicExch_system((unsigned long long *)connTailPtr, step += SLICESTEPS); } template diff --git a/projects/rccl/src/include/devcomm.h b/projects/rccl/src/include/devcomm.h index b7923f97eb..42976e36c7 100644 --- a/projects/rccl/src/include/devcomm.h +++ b/projects/rccl/src/include/devcomm.h @@ -18,8 +18,8 @@ // Convert volatile access to atomic #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) -#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_ACQUIRE) -#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_RELEASE) +#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_SEQ_CST) +#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST) #else #define LOAD(VAR) *(VAR) #define STORE(DST, SRC) *(DST) = (SRC)