Merge pull request #742 from whchung/skip_done_event_msccl

Allow skipping doneEvent inside MSCCL.

[ROCm/rccl commit: eba4e9e100]
This commit is contained in:
Wen-Heng (Jack) Chung
2023-05-18 10:17:20 -05:00
committed by GitHub
+17 -2
View File
@@ -13,6 +13,10 @@
#include "msccl/msccl_setup.h"
#include "msccl/msccl_status.h"
#ifndef HIP_EVENT_DISABLE_FENCE
RCCL_PARAM(MscclEnableDoneEvent, "MSCCL_ENABLE_DONE_EVENT", 1);
#endif
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
mscclStatus& status = mscclGetStatus();
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
@@ -260,7 +264,14 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
if (status.lastStream != stream && status.lastStream != nullptr) {
bool enableDoneEvent =
#ifndef HIP_EVENT_DISABLE_FENCE
(rcclParamMscclEnableDoneEvent() == 1);
#else
true;
#endif
if (enableDoneEvent && (status.lastStream != stream && status.lastStream != nullptr)) {
CUDACHECK(hipStreamWaitEvent(stream, comm->doneEvent, 0));
}
@@ -284,7 +295,11 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
void *args[3] = {&comm->devComm, &devAlgo, &work};
void *func = mscclKernelEntries[(opFull.op * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + hostAlgo->protocol];
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
if (enableDoneEvent) {
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
} else {
CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, NULL, 0));
}
status.workIndex++;
status.lastStream = stream;
return ncclSuccess;