From 5a38ff192bde2ca6f6b86795b324ef035ffbf6da Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Wed, 31 May 2023 13:36:51 -0700 Subject: [PATCH] Rework barrier and event code (#761) * Rework barrier and event code * Switch to inline asm --- src/channel.cc | 2 ++ src/collectives/device/primitives.h | 7 +++---- src/misc/strongstream.cc | 2 -- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/channel.cc b/src/channel.cc index eecaa20833..ed4c623d30 100644 --- a/src/channel.cc +++ b/src/channel.cc @@ -29,6 +29,8 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) { ncclCommPushCudaFree(comm, channel->devRingUserRanks); NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); + CUDACHECK(hipEventRecord(comm->deviceStream.scratchEvent, comm->deviceStream.cudaStream)); + CUDACHECK(hipStreamWaitEvent(comm->deviceStream.cudaStream, comm->deviceStream.scratchEvent, 0)); for (int r=0; r < nPeers; ++r) { for (int b=0; b < NCCL_MAX_CONNS; b++) { diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 19191b6c12..024df18c64 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -16,12 +16,11 @@ #define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 #define barrier_by_group() do { \ - const int w = threadIdx.x/WARP_SIZE; \ - const int wid = threadIdx.x%WARP_SIZE; \ - __threadfence(); \ if (nthreads == NCCL_MAX_NTHREADS) { \ - __syncthreads(); \ + __asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier\ns_waitcnt lgkmcnt(0)"); \ } else { \ + const int w = threadIdx.x/WARP_SIZE; \ + const int wid = threadIdx.x%WARP_SIZE; \ if (wid == 0) { \ barrier_next[w] += nthreads/WARP_SIZE; \ atomicAdd((unsigned long long *)barriers, 1); \ diff --git a/src/misc/strongstream.cc b/src/misc/strongstream.cc index 27186cc8ac..faeec4bca3 100644 --- a/src/misc/strongstream.cc +++ b/src/misc/strongstream.cc @@ -237,8 +237,6 @@ ncclResult_t ncclStrongStreamRelease(struct ncclCudaGraph graph, struct ncclStro } } #endif - CUDACHECK(cudaEventRecord(ss->scratchEvent, ss->cudaStream)); - CUDACHECK(cudaStreamWaitEvent(ss->cudaStream, ss->scratchEvent, 0)); return ncclSuccess; }