diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index 448e2f7b45..cb16565ab0 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -16,6 +16,9 @@ #define __syncwarp() +/* Waits for this wave's outstanding lgkmcnt-tracked operations, then executes s_barrier; the "memory" clobber stops the compiler from moving or register-caching memory accesses across the barrier. */ +#define __synclds() asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier" ::: "memory"); + #define NCCL_FUNC5(func, algo, devredop, type, nullify) \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ @@ -509,7 +512,7 @@ __forceinline__ __device__ void ncclKernel( } } } - __syncthreads(); // publish ncclShmem.channelId + __synclds(); // publish ncclShmem.channelId int channelId = ncclShmem.channelId; if (true) { @@ -542,7 +545,7 @@ __forceinline__ __device__ void ncclKernel( } copyToShmem16(tid%WARP_SIZE, dst, src, bytes); } - __syncthreads(); // publish shmem + __synclds(); // publish shmem #ifdef ENABLE_PROFILING if (tid == 0) { ncclShmem.prof.count = 0; @@ -565,7 +568,7 @@ __forceinline__ __device__ void ncclKernel( } else if (ncclShmem.work.header.type == ncclWorkTypeRegColl) { if (tid < NCCL_MAX_WORK_ELEMENTS_REG) ncclRedopPtrDeref(&ncclShmem.work.regElems[tid].elem); } - __syncthreads(); + __synclds(); if (tid == 0) __insert_timestamp(__LINE__); if (ncclShmem.work.header.funcIndex == FnIndex) { @@ -575,7 +578,7 @@ __forceinline__ __device__ void ncclKernel( } int workIxNext = ncclShmem.work.header.workNext; - __syncthreads(); + __synclds(); if (ncclShmem.work.header.isLast) break; copyToShmem16(tid, &ncclShmem.work, workHead + workIxNext, sizeof(ncclWork)); @@ -592,7 +595,7 @@ __forceinline__ __device__ void ncclKernel( if (COLLTRACE && tid == 0) traceKernelEnd(); #ifdef ENABLE_PROFILING if (ncclShmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) { - __syncthreads(); + __synclds(); copyToShmem16(tid, ncclShmem.comm.devProf+MAXCHANNELS*ncclShmem.prof.seq+blockIdx.x, &ncclShmem.prof, 
sizeof(struct ncclProf)); if (tid == 0) ncclShmem.comm.devProf[blockIdx.x].seq++; }