diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index bdc711058f..5874994c41 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -99,9 +99,11 @@ private: if (sendConnHeadPtr) { int spins = 0; while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { - sendConnHeadCache = LOAD(sendConnHeadPtr); + __builtin_amdgcn_s_sleep(8); + sendConnHeadCache = atomicAdd_system((unsigned long long *)sendConnHeadPtr, 0); if (checkAbort(spins, 1)) break; } + __asm__ __volatile__("s_wakeup"); if (sendConnFifoPtr) { int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; STORE(sendConnFifoPtr+sendConnHead%NCCL_STEPS, size); diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index cc6f275041..edc0ac9a65 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -107,7 +107,7 @@ private: int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { __builtin_amdgcn_s_sleep(8); - connStepCache = LOAD(connStepPtr); + connStepCache = atomicAdd_system((unsigned long long *)connStepPtr, 0); if (checkAbort(spins)) break; //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));