Fix memory fence and use non-temporal store (#1007)
* Fix memory fence and use non-temporal store
* Use amdgcn builtin instead of inline asm
* Move threadfence location
* Revert changes to gfx90a
* Rework gfx90a change
* Apply changes to gfx94x
This commit is contained in:
@@ -26,7 +26,7 @@
|
||||
|
||||
#ifdef __GFX9__
|
||||
#define STORE(DST, SRC) \
|
||||
{ __threadfence(); __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
|
||||
{ __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
|
||||
#else
|
||||
#define STORE(DST, SRC) \
|
||||
{ __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
|
||||
|
||||
@@ -227,8 +227,8 @@ DEFINE_ld_st(8, uint64_t, b64, l, global, uintptr_t, l)
|
||||
} \
|
||||
template<> \
|
||||
__device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \
|
||||
*((uint64_t*)addr) = value.u64[0]; \
|
||||
*((uint64_t*)addr+1) = value.u64[1]; \
|
||||
__builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \
|
||||
__builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \
|
||||
}
|
||||
DEFINE_ld_st_16(global, uintptr_t, l)
|
||||
//DEFINE_ld_st_16(shared, uint32_t, r)
|
||||
|
||||
@@ -15,9 +15,25 @@
|
||||
|
||||
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000
|
||||
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define barrier_by_group() do { \
|
||||
if (nthreads == NCCL_MAX_NTHREADS) { \
|
||||
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier\ns_waitcnt lgkmcnt(0)"); \
|
||||
__builtin_amdgcn_s_barrier(); \
|
||||
} else { \
|
||||
const int w = threadIdx.x/WARP_SIZE; \
|
||||
const int wid = threadIdx.x%WARP_SIZE; \
|
||||
if (wid == 0) { \
|
||||
barrier_next[w] += nthreads/WARP_SIZE; \
|
||||
atomicAdd((unsigned long long *)barriers, 1); \
|
||||
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \
|
||||
__asm__ __volatile__("s_wakeup"); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
#define barrier_by_group() do { \
|
||||
if (nthreads == NCCL_MAX_NTHREADS) { \
|
||||
__threadfence(); __builtin_amdgcn_s_barrier(); \
|
||||
} else { \
|
||||
const int w = threadIdx.x/WARP_SIZE; \
|
||||
const int wid = threadIdx.x%WARP_SIZE; \
|
||||
@@ -30,6 +46,7 @@
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
|
||||
* We use these as template args to the Primitives class instead of integral
|
||||
|
||||
@@ -114,7 +114,7 @@ private:
|
||||
if (recvConnHeadPtr) STORE(recvConnHeadPtr, recvConnHead += 1);
|
||||
}
|
||||
inline __device__ void postSend() {
|
||||
if (sendConnTailPtr) { __threadfence(); STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); }
|
||||
if (sendConnTailPtr) { STORE((unsigned long long *)sendConnTailPtr, sendConnTail += 1); }
|
||||
}
|
||||
|
||||
template<int WordPerThread>
|
||||
|
||||
@@ -161,9 +161,9 @@ private:
|
||||
inline __device__ void postPeer(bool dataStored) {
|
||||
if (Send && (flags & RolePostSend) && dataStored)
|
||||
#ifdef __GFX9__
|
||||
__builtin_amdgcn_buffer_wbinvl1();
|
||||
__threadfence();
|
||||
#else
|
||||
__threadfence_system();
|
||||
__threadfence_system();
|
||||
#endif
|
||||
|
||||
if ((flags & Send*RolePostSend) && next_hdp_reg)
|
||||
|
||||
Reference in New Issue
Block a user