reverted the syncLDS back to syncthreads (#1554)

Этот коммит содержится в:
akolliasAMD
2025-02-19 10:44:32 -07:00
коммит произвёл GitHub
родитель baaa2ac64d
Коммит aedbc95735
2 изменённых файлов: 11 добавлений и 19 удалений
+9 -17
Просмотреть файл
@@ -19,14 +19,6 @@
#define __syncwarp()
#ifdef __GFX12__
#define __synclds() \
asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier_signal -1 \n s_barrier_wait -1");
#else
#define __synclds() \
asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier");
#endif
#ifdef __GFX9__
#define STORE(DST, SRC) \
{ __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
@@ -352,7 +344,7 @@ __device__ __forceinline__ void loadWorkBatchToShmem(
if (ncclShmem.args.workStorageType == ncclDevWorkStorageTypeArgs) {
char* src = (char*)args + (batch.offsetBase + srcWork*workSize + packInWork*16);
tmp = *(ulong2*)src; // becomes ld.param.v2.u64
}
}
if (ncclShmem.args.workStorageType != ncclDevWorkStorageTypeArgs) {
char* src = (char*)ncclShmem.args.workBuf + ((batch.offsetBase + srcWork*workSize + packInWork*16) & ncclShmem.args.workMask);
tmp = *(ulong2*)src; // becomes ld.v2.u64
@@ -411,7 +403,7 @@ struct RunWorkBatch {
work->redOpArg = RedOpArg<RedOp>::loadArg(reinterpret_cast<void*>(work->redOpArg));
}
}
__synclds();
__syncthreads();
}
#pragma unroll 1
@@ -419,7 +411,7 @@ struct RunWorkBatch {
struct ncclDevWorkColl* work = (struct ncclDevWorkColl*)(ncclShmem.workStorage + w*ncclShmem.workSize);
if (w != 0) {
struct ncclDevWorkColl* workPrev = (struct ncclDevWorkColl*)(ncclShmem.workStorage + (w-1)*ncclShmem.workSize);
if (work->nWarps != workPrev->nWarps) __synclds();
if (work->nWarps != workPrev->nWarps) __syncthreads();
}
int subtn = work->nWarps*WARP_SIZE;
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, COLL_UNROLL>().run(tid, subtn, work);
@@ -484,7 +476,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
default:
break;
}
__synclds(); // publish ncclShmem.{args, channelId}
__syncthreads(); // publish ncclShmem.{args, channelId}
// Use first 2 warps to load comm and channel, and reamaining load work batch.
switch (tid/WARP_SIZE) {
@@ -515,7 +507,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
ncclShmem.collTraceTail = args->comm->collTraceTail + ncclShmem.channelId;
}
#endif
__synclds(); // publish shmem
__syncthreads(); // publish shmem
#ifdef ENABLE_PROFILING
if (tid == 0) {
ncclShmem.prof.count = 0;
@@ -532,7 +524,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
while (true) {
if (tid == 0) __insert_timestamp(__LINE__);
if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) {
SpecializedRunWorkBatch().run();
} else {
@@ -551,7 +543,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
if (ncclShmem.nextBatchIx == -1) break;
int batchIx = ncclShmem.nextBatchIx;
__synclds();
__syncthreads();
switch (tid/WARP_SIZE) {
case 1:
if (tid < WARP_SIZE + NCCL_MAX_GROUPS)
@@ -561,7 +553,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
break;
}
loadWorkBatchToShmem(tid%WARP_SIZE, tn, args, batchIx);
__synclds();
__syncthreads();
// Check whether the last operation was aborted and make sure all threads exit
bool aborted = false;
@@ -579,7 +571,7 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
#ifdef ENABLE_PROFILING
if (ncclShmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) {
__synclds();
__syncthreads();
copyToShmem16(tid, ncclShmem.comm.devProf+MAXCHANNELS*ncclShmem.prof.seq+blockIdx.x, &ncclShmem.prof, sizeof(struct ncclProf));
if (tid == 0) ncclShmem.comm.devProf[blockIdx.x].seq++;
}
+2 -2
Просмотреть файл
@@ -108,7 +108,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
threadBlockCopy(
(uint32_t *)&mscclShmem.mscclTB, (uint32_t *)(algo->mscclTBs + bid),
sizeof(struct mscclThreadBlock) / sizeof(uint32_t), tid, nthreads);
__synclds(); // publish mscclShmem.mscclTB.channelId
__syncthreads(); // publish mscclShmem.mscclTB.channelId
// initialize ncclShmem and mscclShmem.work
int channelId = mscclShmem.mscclTB.channelId;
@@ -146,7 +146,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
}
if (bytes) copyToShmem8(tid%WARP_SIZE, dst, src, bytes);
}
__synclds(); // publish shmem
__syncthreads(); // publish shmem
#if defined(ENABLE_NPKIT)
int npKitCtxIdx = bid;