diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index 5c6d8c1aea..945937f264 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -223,9 +223,9 @@ class ncclFunction { #endif __device__ inline bool barrierReduceAny(int bit, uint32_t* abortCount) { - if (bit) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \ + if (bit) atomicAdd(abortCount, 1); \ __syncthreads(); \ - return LOAD(abortCount) != 0; + return atomicAdd(abortCount, 0) != 0; } template diff --git a/projects/rccl/src/collectives/device/primitives.h b/projects/rccl/src/collectives/device/primitives.h index 11146890f6..2b500e230c 100644 --- a/projects/rccl/src/collectives/device/primitives.h +++ b/projects/rccl/src/collectives/device/primitives.h @@ -22,8 +22,9 @@ const int wid = threadIdx.x%WARP_SIZE; \ if (wid == 0) { \ barrier_next[w] += nthreads/WARP_SIZE; \ - __atomic_fetch_add(barriers, 1, __ATOMIC_SEQ_CST); \ - while (LOAD(barriers) < barrier_next[w]) /* spin */; \ + atomicAdd((unsigned long long *)barriers, 1); \ + while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(8); \ + __asm__ __volatile__("s_wakeup"); \ } \ } \ } while (0) diff --git a/projects/rccl/src/collectives/device/prims_simple.h b/projects/rccl/src/collectives/device/prims_simple.h index 5d870fdf1a..fec3a4f51b 100644 --- a/projects/rccl/src/collectives/device/prims_simple.h +++ b/projects/rccl/src/collectives/device/prims_simple.h @@ -76,7 +76,7 @@ class Primitives< inline __device__ bool checkAbort(int &spins) { spins++; if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { - flags |= LOAD(ncclShmem->comm.abortFlag) ? Aborted : 0; + flags |= atomicAdd_system((unsigned int *)ncclShmem->comm.abortFlag, 0) ? Aborted : 0; spins = 0; } return flags & Aborted; @@ -88,10 +88,12 @@ class Primitives< bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { - connStepCache = LOAD(connStepPtr); + __builtin_amdgcn_s_sleep(8); + connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0); if (checkAbort(spins)) break; //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); } + __asm__ __volatile__("s_wakeup"); if (isSendNotRecv && (flags & SizesFifoEnabled)) STORE(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T)); @@ -112,7 +114,7 @@ class Primitives< inline __device__ void postPeer() { if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { step += StepPerSlice; - STORE(connStepPtr, step); + atomicExch_system((unsigned long long *)connStepPtr, step); } } diff --git a/projects/rccl/src/include/devcomm.h b/projects/rccl/src/include/devcomm.h index 001c74cd6d..ae95f19cd0 100644 --- a/projects/rccl/src/include/devcomm.h +++ b/projects/rccl/src/include/devcomm.h @@ -18,8 +18,8 @@ // Convert volatile access to atomic #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) -#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_ACQUIRE) -#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_RELEASE) +#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_SEQ_CST) +#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST) #else #define LOAD(VAR) *(VAR) #define STORE(DST, SRC) *(DST) = (SRC)