Merge pull request #681 from wenkaidu/gfx9

Add HIP event optimization and remove special code for gfx90a
This commit is contained in:
Wenkai Du
2023-02-13 08:04:59 -08:00
gecommit door GitHub
bovenliggende f525b8e1e6 39534e8724
commit 9461a43168
7 gewijzigde bestanden met toevoegingen van 31 en 36 verwijderingen
+4
Bestand weergeven
@@ -103,6 +103,10 @@ list(APPEND CMAKE_PREFIX_PATH
find_package(hip REQUIRED)
message(STATUS "HIP compiler: ${HIP_COMPILER}")
message(STATUS "HIP runtime: ${HIP_RUNTIME}")
check_symbol_exists("hipEventDisableSystemFence" "hip/hip_runtime_api.h" HIP_EVENT_DISABLE_FENCE)
if(${HIP_EVENT_DISABLE_FENCE})
add_definitions(-DHIP_EVENT_DISABLE_FENCE)
endif()
find_package(hsa-runtime64 REQUIRED)
get_target_property(HSA_INCLUDE_PATH hsa-runtime64::hsa-runtime64 INTERFACE_INCLUDE_DIRECTORIES)
+2 -2
Bestand weergeven
@@ -19,9 +19,9 @@
#define __synclds() \
asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier");
#if defined(__gfx90a__)
#ifdef __GFX9__
#define STORE(DST, SRC) \
{ __threadfence_block(); atomicExch((unsigned long long *)(DST), (SRC)); }
{ atomicExch((unsigned long long *)(DST), (SRC)); }
#else
#define STORE(DST, SRC) \
{ __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
@@ -20,7 +20,7 @@ static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }
inline __device__ int loadInt(int* ptr) {
int v;
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
v = atomicAdd_system((unsigned long long *)ptr, 0);
v = atomicAdd((unsigned long long *)ptr, 0);
#else
asm volatile("ld.volatile.global.u32 %0, [%1];"
: "=r"(v) : "l"(ptr));
+8 -11
Bestand weergeven
@@ -15,17 +15,14 @@
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000
#define barrier_by_group() do { \
if (nthreads == NCCL_MAX_NTHREADS) \
__syncthreads(); \
else { \
const int w = threadIdx.x/WARP_SIZE; \
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
barrier_next[w] += nthreads/WARP_SIZE; \
atomicAdd((unsigned long long *)barriers, 1); \
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \
__asm__ __volatile__("s_wakeup"); \
} \
const int w = threadIdx.x/WARP_SIZE; \
const int wid = threadIdx.x%WARP_SIZE; \
if (wid == 0) { \
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)"); \
barrier_next[w] += nthreads/WARP_SIZE; \
atomicAdd((unsigned long long *)barriers, 1); \
while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \
__asm__ __volatile__("s_wakeup"); \
} \
} while (0)
+1 -1
Bestand weergeven
@@ -104,7 +104,7 @@ private:
int spins = 0;
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
__builtin_amdgcn_s_sleep(1);
sendConnHeadCache = atomicAdd_system((unsigned long long *)sendConnHeadPtr, 0);
sendConnHeadCache = atomicAdd((unsigned long long *)sendConnHeadPtr, 0);
if (checkAbort(spins, 1)) break;
}
__asm__ __volatile__("s_wakeup");
+11 -21
Bestand weergeven
@@ -112,7 +112,7 @@ private:
int spins = 0;
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
__builtin_amdgcn_s_sleep(1);
connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0);
connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0);
if (checkAbort(spins)) break;
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
@@ -327,13 +327,8 @@ private:
}
barrier(); // This barrier has a counterpart in following loop
#if defined(__gfx90a__)
if (Send && (flags & RolePostSend) && index == 0) {
if (MaxSend == 0 || MaxRecv == 0)
__threadfence_system();
else
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0); buffer_wbinvl1_vol");
}
#ifdef __GFX9__
if (Send && (flags & RolePostSend) && index == 0) __asm__ __volatile__("buffer_wbinvl1_vol");
#else
if (Send && (flags & RolePostSend) && index == 0) __threadfence_system();
#endif
@@ -355,13 +350,8 @@ private:
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(0, 0, 0, 0);
}
barrier(); // Has couterpart in preceding worker-only loop.
#if defined(__gfx90a__)
if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) {
if (MaxSend == 0 || MaxRecv == 0)
__threadfence_system();
else
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0); buffer_wbinvl1_vol");
}
#ifdef __GFX9__
if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __asm__ __volatile__("buffer_wbinvl1_vol");
#else
if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system();
#endif
@@ -482,7 +472,7 @@ private:
if (flags & RoleWaitRecv) {
ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
connStepPtr = conn->tail;
connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0);
connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0);
flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
if (Direct) {
// User buffers have been registered
@@ -522,7 +512,7 @@ private:
if (flags & RoleWaitSend) {
ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs()
connStepPtr = conn->head;
connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0);
connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0);
flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
if (flags & OffsFifoEnabled)
connOffsFifoPtr = conn->offsFifo;
@@ -634,7 +624,7 @@ private:
int spins = 0;
void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
// Wait for consumer to consume previous value before trampling it.
while ((void *)atomicAdd_system((unsigned long long *) slot,0) != nullptr && !checkAbort(spins));
while ((void *)atomicAdd((unsigned long long *) slot,0) != nullptr && !checkAbort(spins));
directBuff = (T*)outputBuf;
// Encode pointer by XOR'ing against some address they definitely wouldn't send
// since we want to allow them sending us nullptr while not colliding with
@@ -646,7 +636,7 @@ private:
void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
void *ptr;
while (true) {
ptr = (void *)atomicAdd_system((unsigned long long *) slot,0);
ptr = (void *)atomicAdd((unsigned long long *) slot,0);
if (ptr != nullptr || checkAbort(spins)) break;
}
directBuff = regUsed ? (T*)(e->dnOutputs[index]) :
@@ -659,7 +649,7 @@ private:
volatile uint64_t* argSlot0 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange;
volatile uint64_t* argSlot1 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange+1;
// Wait for consumer to consume previous value before trampling it.
while (((void *)atomicAdd_system((unsigned long long *) slot,0) != nullptr || *argSlot0 != 0 || *argSlot1 !=0) && !checkAbort(spins));
while (((void *)atomicAdd((unsigned long long *) slot,0) != nullptr || *argSlot0 != 0 || *argSlot1 !=0) && !checkAbort(spins));
// If there is no recv, then we are directly pulling from input buffer (e.g. directScatter)
// Otherwise, we are pulling from output buffer (e.g. recvCopyDirectSend)
directBuff = MaxRecv == 0 ? (T*)inputBuf : (T*)outputBuf;
@@ -678,7 +668,7 @@ private:
volatile uint64_t* argSlot1 = ncclShmem.groups[group].recvConns[index]->redOpArgExchange+1;
void *ptr;
while (true) {
ptr = (void *)atomicAdd_system((unsigned long long *) slot,0);
ptr = (void *)atomicAdd((unsigned long long *) slot,0);
if (ptr != nullptr || checkAbort(spins)) break;
}
directBuff = regUsed ? (T*)(MaxSend == 0 ? e->upOutputs[index] : e->dnInputs[index]) :
+4
Bestand weergeven
@@ -472,7 +472,11 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank, int virtua
// Try to create a CUDA object right away. If there is something wrong with
// the device we're on (failure cause #1) , better know it early.
hipEvent_t doneEvent;
#ifdef HIP_EVENT_DISABLE_FENCE
CUDACHECK(hipEventCreateWithFlags(&doneEvent, hipEventDisableTiming|hipEventDisableSystemFence));
#else
CUDACHECK(hipEventCreateWithFlags(&doneEvent, hipEventDisableTiming));
#endif
NCCLCHECK(ncclStrongStreamConstruct(&comm->deviceStream));
NCCLCHECK(ncclStrongStreamConstruct(&comm->hostStream));