MSCCL: add support for out-of-place all reduce (#1156)
This commit is contained in:
@@ -272,7 +272,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
int16_t dependentStep = mscclShmem.mscclTB.dependentStep[dependentPointer+tid];
|
||||
uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
|
||||
while (true){
|
||||
uint64_t curFlag = __atomic_load_n(&(mscclFlags + dependentBid)->flag, __ATOMIC_RELAXED);
|
||||
uint64_t curFlag = __atomic_load_n(&(mscclFlags + dependentBid)->flag, (t->srcBuffer != MSCCL_OUTPUT_BUFFER) ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE);
|
||||
if (curFlag >= goalFlag && GET_WORKINDEX_FROM_FLAG(curFlag) == workIndex) break;
|
||||
}
|
||||
}
|
||||
@@ -372,7 +372,7 @@ __device__ __forceinline__ void mscclRunInterpreter(
|
||||
return;
|
||||
}
|
||||
if (t->hasDependence && tid == nthreads-1)
|
||||
__atomic_store_n(&mscclFlags[bid].flag, (uint64_t) COMPUTE_FLAG(workIndex, iter, step), ((t->type == MSCCL_REDUCE || t->type == MSCCL_RECV) && (t->dstBuffer != MSCCL_SCRATCH_BUFFER)) ? __ATOMIC_RELEASE : __ATOMIC_RELAXED);
|
||||
__atomic_store_n(&mscclFlags[bid].flag, (uint64_t) COMPUTE_FLAG(workIndex, iter, step), (t->dstBuffer != MSCCL_SCRATCH_BUFFER) ? __ATOMIC_RELEASE : __ATOMIC_RELAXED);
|
||||
step++;
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user