diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 37fe97f705..4d22026d3d 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -26,7 +26,7 @@ #ifdef __GFX9__ #define STORE(DST, SRC) \ - { __threadfence(); __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); } + { __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); } #else #define STORE(DST, SRC) \ { __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); } diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h index 586b50e37d..fc37e29302 100644 --- a/src/collectives/device/op128.h +++ b/src/collectives/device/op128.h @@ -227,8 +227,8 @@ DEFINE_ld_st(8, uint64_t, b64, l, global, uintptr_t, l) } \ template<> \ __device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \ - *((uint64_t*)addr) = value.u64[0]; \ - *((uint64_t*)addr+1) = value.u64[1]; \ + __builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \ + __builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \ } DEFINE_ld_st_16(global, uintptr_t, l) //DEFINE_ld_st_16(shared, uint32_t, r) diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index a1a5d173cb..9c6c266a80 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -15,9 +15,25 @@ #define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define barrier_by_group() do { \ if (nthreads == NCCL_MAX_NTHREADS) { \ - __asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier\ns_waitcnt lgkmcnt(0)"); \ + __builtin_amdgcn_s_barrier(); \ + } else { \ + const int w = threadIdx.x/WARP_SIZE; \ + const int wid = threadIdx.x%WARP_SIZE; \ + if (wid == 0) { \ + barrier_next[w] += nthreads/WARP_SIZE; \ + atomicAdd((unsigned long long *)barriers, 1); \ + while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \ + __asm__ __volatile__("s_wakeup"); \ + } \ + } \ +} while (0) +#else +#define barrier_by_group() do { \ + if (nthreads == NCCL_MAX_NTHREADS) { \ + __threadfence(); __builtin_amdgcn_s_barrier(); \ } else { \ const int w = threadIdx.x/WARP_SIZE; \ const int wid = threadIdx.x%WARP_SIZE; \ @@ -30,6 +46,7 @@ } \ } \ } while (0) +#endif /* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128 * We use these as template args to the Primtiives class instead of integral diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 858ae5ca71..561a490bd1 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -114,7 +114,7 @@ private: if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1); } inline __device__ void postSend() { - if (sendConnTailPtr) { __threadfence(); STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); } + if (sendConnTailPtr) { STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); } } template diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 70e6333ca0..5337675fd2 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -161,9 +161,9 @@ private: inline __device__ void postPeer(bool dataStored) { if (Send && (flags & RolePostSend) && dataStored) #ifdef __GFX9__ - __builtin_amdgcn_buffer_wbinvl1(); + __threadfence(); #else - __threadfence_system(); + __threadfence_system(); #endif if ((flags & Send*RolePostSend) && next_hdp_reg)