Merge remote-tracking branch 'origin/develop' into 2.10.3

[ROCm/rccl commit: 3667d308ab]
This commit is contained in:
Wenkai Du
2021-09-13 17:19:07 -07:00
commit d6064367f0
4 muutettua tiedostoa jossa 12 lisäystä ja 9 poistoa
@@ -223,9 +223,9 @@ class ncclFunction {
#endif
__device__ inline bool barrierReduceAny(int bit, uint32_t* abortCount) {
if (bit) __atomic_fetch_add(abortCount, 1, __ATOMIC_SEQ_CST); \
if (bit) atomicAdd(abortCount, 1); \
__syncthreads(); \
return LOAD(abortCount) != 0;
return atomicAdd(abortCount, 0) != 0;
}
template<typename T>
@@ -22,8 +22,9 @@
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
barrier_next[w] += nthreads/WARP_SIZE; \
__atomic_fetch_add(barriers, 1, __ATOMIC_SEQ_CST); \
while (LOAD(barriers) < barrier_next[w]) /* spin */; \
atomicAdd((unsigned long long *)barriers, 1); \
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(8); \
__asm__ __volatile__("s_wakeup"); \
} \
} \
} while (0)
@@ -76,7 +76,7 @@ class Primitives<
inline __device__ bool checkAbort(int &spins) {
spins++;
if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
flags |= LOAD(ncclShmem->comm.abortFlag) ? Aborted : 0;
flags |= atomicAdd_system((unsigned int *)ncclShmem->comm.abortFlag, 0) ? Aborted : 0;
spins = 0;
}
return flags & Aborted;
@@ -88,10 +88,12 @@ class Primitives<
bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
int spins = 0;
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
connStepCache = LOAD(connStepPtr);
__builtin_amdgcn_s_sleep(8);
connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0);
if (checkAbort(spins)) break;
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
}
__asm__ __volatile__("s_wakeup");
if (isSendNotRecv && (flags & SizesFifoEnabled))
STORE(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T));
@@ -112,7 +114,7 @@ class Primitives<
inline __device__ void postPeer() {
if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
step += StepPerSlice;
STORE(connStepPtr, step);
atomicExch_system((unsigned long long *)connStepPtr, step);
}
}
+2 -2
Näytä tiedosto
@@ -18,8 +18,8 @@
// Convert volatile access to atomic
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_ACQUIRE)
#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_RELEASE)
#define LOAD(VAR) __atomic_load_n((VAR), __ATOMIC_SEQ_CST)
#define STORE(DST, SRC) __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST)
#else
#define LOAD(VAR) *(VAR)
#define STORE(DST, SRC) *(DST) = (SRC)