From c3bb9e70d074a1dcfbffbdf281a66eeeb3d4a690 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Thu, 23 Jun 2022 09:16:41 -0700 Subject: [PATCH] Use different atomics to check flags in kernel (#568) --- src/collectives/device/prims_ll.h | 4 +++- src/collectives/device/prims_simple.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index bdc711058f..5874994c41 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -99,9 +99,11 @@ private: if (sendConnHeadPtr) { int spins = 0; while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { - sendConnHeadCache = LOAD(sendConnHeadPtr); + __builtin_amdgcn_s_sleep(8); + sendConnHeadCache = atomicAdd_system((unsigned long long *)sendConnHeadPtr, 0); if (checkAbort(spins, 1)) break; } + __asm__ __volatile__("s_wakeup"); if (sendConnFifoPtr) { int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, size); diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index cc6f275041..edc0ac9a65 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -107,7 +107,7 @@ private: int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { __builtin_amdgcn_s_sleep(8); - connStepCache = LOAD(connStepPtr); + connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0); if (checkAbort(spins)) break; //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));