Fix memory fence and use non-temporal store (#1007)

* Fix memory fence and use non-temporal store

* Use amdgcn builtin instead of inline asm

* Move threadfence location

* Revert changes to gfx90a

* Rework gfx90a change

* Apply changes to gfx94x
This commit is contained in:
Wenkai Du
2023-12-09 12:16:08 -08:00
committed by GitHub
parent c002f20029
commit 7965c8b53c
5 changed files with 24 additions and 7 deletions
+1 -1
View File
@@ -26,7 +26,7 @@
#ifdef __GFX9__
#define STORE(DST, SRC) \
{ __threadfence(); __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
{ __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
#else
#define STORE(DST, SRC) \
{ __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
+2 -2
View File
@@ -227,8 +227,8 @@ DEFINE_ld_st(8, uint64_t, b64, l, global, uintptr_t, l)
} \
template<> \
__device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \
*((uint64_t*)addr) = value.u64[0]; \
*((uint64_t*)addr+1) = value.u64[1]; \
__builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \
__builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \
}
DEFINE_ld_st_16(global, uintptr_t, l)
//DEFINE_ld_st_16(shared, uint32_t, r)
+18 -1
View File
@@ -15,9 +15,25 @@
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define barrier_by_group() do { \
if (nthreads == NCCL_MAX_NTHREADS) { \
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier\ns_waitcnt lgkmcnt(0)"); \
__builtin_amdgcn_s_barrier(); \
} else { \
const int w = threadIdx.x/WARP_SIZE; \
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
barrier_next[w] += nthreads/WARP_SIZE; \
atomicAdd((unsigned long long *)barriers, 1); \
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \
__asm__ __volatile__("s_wakeup"); \
} \
} \
} while (0)
#else
#define barrier_by_group() do { \
if (nthreads == NCCL_MAX_NTHREADS) { \
__threadfence(); __builtin_amdgcn_s_barrier(); \
} else { \
const int w = threadIdx.x/WARP_SIZE; \
const int wid = threadIdx.x%WARP_SIZE; \
@@ -30,6 +46,7 @@
} \
} \
} while (0)
#endif
/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
* We use these as template args to the Primtiives class instead of integral
+1 -1
View File
@@ -114,7 +114,7 @@ private:
if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1);
}
inline __device__ void postSend() {
if (sendConnTailPtr) { __threadfence(); STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); }
if (sendConnTailPtr) { STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); }
}
template<int WordPerThread>
+2 -2
View File
@@ -161,9 +161,9 @@ private:
inline __device__ void postPeer(bool dataStored) {
if (Send && (flags & RolePostSend) && dataStored)
#ifdef __GFX9__
__builtin_amdgcn_buffer_wbinvl1();
__threadfence();
#else
__threadfence_system();
__threadfence_system();
#endif
if ((flags & Send*RolePostSend) && next_hdp_reg)