MSCCL: add support for out-of-place all reduce (#1156)

This commit is contained in:
Wenkai Du
2024-04-28 19:49:09 -07:00
committed by GitHub
parent cd6e840e0b
commit 4e1b8c1cbb
5 changed files with 32794 additions and 2 deletions
+2 -2
View File
@@ -272,7 +272,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
int16_t dependentStep = mscclShmem.mscclTB.dependentStep[dependentPointer+tid];
uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
while (true){
uint64_t curFlag = __atomic_load_n(&(mscclFlags + dependentBid)->flag, __ATOMIC_RELAXED);
uint64_t curFlag = __atomic_load_n(&(mscclFlags + dependentBid)->flag, (t->srcBuffer != MSCCL_OUTPUT_BUFFER) ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE);
if (curFlag >= goalFlag && GET_WORKINDEX_FROM_FLAG(curFlag) == workIndex) break;
}
}
@@ -372,7 +372,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
return;
}
if (t->hasDependence && tid == nthreads-1)
__atomic_store_n(&mscclFlags[bid].flag, (uint64_t) COMPUTE_FLAG(workIndex, iter, step), ((t->type == MSCCL_REDUCE || t->type == MSCCL_RECV) && (t->dstBuffer != MSCCL_SCRATCH_BUFFER)) ? __ATOMIC_RELEASE : __ATOMIC_RELAXED);
__atomic_store_n(&mscclFlags[bid].flag, (uint64_t) COMPUTE_FLAG(workIndex, iter, step), (t->dstBuffer != MSCCL_SCRATCH_BUFFER) ? __ATOMIC_RELEASE : __ATOMIC_RELAXED);
step++;
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff