diff --git a/src/device/common.h b/src/device/common.h index 54be067108..0eb3ea9915 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -19,14 +19,6 @@ #define __syncwarp() -#ifdef __GFX12__ -#define __synclds() \ - asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier_signal -1 \n s_barrier_wait -1"); -#else -#define __synclds() \ - asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier"); -#endif - #ifdef __GFX9__ #define STORE(DST, SRC) \ { __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); } @@ -352,7 +344,7 @@ __device__ __forceinline__ void loadWorkBatchToShmem( if (ncclShmem.args.workStorageType == ncclDevWorkStorageTypeArgs) { char* src = (char*)args + (batch.offsetBase + srcWork*workSize + packInWork*16); tmp = *(ulong2*)src; // becomes ld.param.v2.u64 - } + } if (ncclShmem.args.workStorageType != ncclDevWorkStorageTypeArgs) { char* src = (char*)ncclShmem.args.workBuf + ((batch.offsetBase + srcWork*workSize + packInWork*16) & ncclShmem.args.workMask); tmp = *(ulong2*)src; // becomes ld.v2.u64 @@ -411,7 +403,7 @@ struct RunWorkBatch { work->redOpArg = RedOpArg::loadArg(reinterpret_cast(work->redOpArg)); } } - __synclds(); + __syncthreads(); } #pragma unroll 1 @@ -419,7 +411,7 @@ struct RunWorkBatch { struct ncclDevWorkColl* work = (struct ncclDevWorkColl*)(ncclShmem.workStorage + w*ncclShmem.workSize); if (w != 0) { struct ncclDevWorkColl* workPrev = (struct ncclDevWorkColl*)(ncclShmem.workStorage + (w-1)*ncclShmem.workSize); - if (work->nWarps != workPrev->nWarps) __synclds(); + if (work->nWarps != workPrev->nWarps) __syncthreads(); } int subtn = work->nWarps*WARP_SIZE; if (tid < subtn) RunWorkColl().run(tid, subtn, work); @@ -484,7 +476,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a default: break; } - __synclds(); // publish ncclShmem.{args, channelId} + __syncthreads(); // publish ncclShmem.{args, channelId} // Use first 2 warps to load comm and channel, and reamaining load work batch. switch (tid/WARP_SIZE) { @@ -515,7 +507,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a ncclShmem.collTraceTail = args->comm->collTraceTail + ncclShmem.channelId; } #endif - __synclds(); // publish shmem + __syncthreads(); // publish shmem #ifdef ENABLE_PROFILING if (tid == 0) { ncclShmem.prof.count = 0; @@ -532,7 +524,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a while (true) { if (tid == 0) __insert_timestamp(__LINE__); - + if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) { SpecializedRunWorkBatch().run(); } else { @@ -551,7 +543,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a if (ncclShmem.nextBatchIx == -1) break; int batchIx = ncclShmem.nextBatchIx; - __synclds(); + __syncthreads(); switch (tid/WARP_SIZE) { case 1: if (tid < WARP_SIZE + NCCL_MAX_GROUPS) @@ -561,7 +553,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a break; } loadWorkBatchToShmem(tid%WARP_SIZE, tn, args, batchIx); - __synclds(); + __syncthreads(); // Check whether the last operation was aborted and make sure all threads exit bool aborted = false; @@ -579,7 +571,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a #ifdef ENABLE_PROFILING if (ncclShmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) { - __synclds(); + __syncthreads(); copyToShmem16(tid, ncclShmem.comm.devProf+MAXCHANNELS*ncclShmem.prof.seq+blockIdx.x, &ncclShmem.prof, sizeof(struct ncclProf)); if (tid == 0) ncclShmem.comm.devProf[blockIdx.x].seq++; } diff --git a/src/device/msccl_kernel_impl.h b/src/device/msccl_kernel_impl.h index a506ddef33..a2aeccf2c0 100644 --- a/src/device/msccl_kernel_impl.h +++ b/src/device/msccl_kernel_impl.h @@ -108,7 +108,7 @@ __device__ __forceinline__ void mscclRunInterpreter( threadBlockCopy( (uint32_t *)&mscclShmem.mscclTB, (uint32_t *)(algo->mscclTBs + bid), sizeof(struct mscclThreadBlock) / sizeof(uint32_t), tid, nthreads); - __synclds(); // publish mscclShmem.mscclTB.channelId + __syncthreads(); // publish mscclShmem.mscclTB.channelId // initialize ncclShmem and mscclShmem.work int channelId = mscclShmem.mscclTB.channelId; @@ -146,7 +146,7 @@ __device__ __forceinline__ void mscclRunInterpreter( } if (bytes) copyToShmem8(tid%WARP_SIZE, dst, src, bytes); } - __synclds(); // publish shmem + __syncthreads(); // publish shmem #if defined(ENABLE_NPKIT) int npKitCtxIdx = bid;