misc/msccl: Read graph capture status for every collective call (#1576)
* misc/msccl: read graphCaptureStatus for every collective call
* fix a bug in checking whether UBR is enabled in MSCCLPP
* cmake: Fix patch reversal order
* misc/msccl: add logging
[ROCm/rccl commit: 23c0b7bd84]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
3be905ca83
Коммит
f70f406463
@@ -95,6 +95,11 @@ if(ENABLE_MSCCLPP)
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
message(STATUS "Building mscclpp only for supported variants:gfx942,gfx950")
|
||||
mscclpp_cmake_arg(CMAKE_PREFIX_PATH)
|
||||
mscclpp_cmake_arg(CMAKE_INSTALL_RPATH_USE_LINK_PATH)
|
||||
@@ -121,23 +126,14 @@ if(ENABLE_MSCCLPP)
|
||||
|
||||
|
||||
find_package(mscclpp_nccl REQUIRED)
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
@@ -147,10 +143,25 @@ if(ENABLE_MSCCLPP)
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
#endif()
|
||||
|
||||
execute_process(COMMAND objcopy
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
||||
index 5c19dc6..5fb99ef 100644
|
||||
--- a/apps/nccl/src/nccl.cu
|
||||
+++ b/apps/nccl/src/nccl.cu
|
||||
@@ -85,6 +85,7 @@ struct ncclComm {
|
||||
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
|
||||
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
|
||||
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
|
||||
+ std::unordered_map<channelKey, cudaIpcMemHandle_t> regHandles;
|
||||
std::unordered_map<void*, channelKey> handleKeys;
|
||||
std::shared_ptr<char> scratchBuff;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
|
||||
@@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size,
|
||||
p->ipcHandle = ipcHandle;
|
||||
*handle = p;
|
||||
|
||||
+ auto regIt = comm->regHandles.find(buffKey);
|
||||
+ if (regIt == comm->regHandles.end()) {
|
||||
+ comm->regHandles[buffKey] = ipcHandle;
|
||||
+ }
|
||||
+
|
||||
auto it = comm->handleKeys.find(*handle);
|
||||
if (it == comm->handleKeys.end()) {
|
||||
comm->handleKeys[*handle] = buffKey;
|
||||
@@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
|
||||
if (outIt != comm->channelOutInfos.end()) {
|
||||
comm->channelOutInfos.erase(outIt);
|
||||
}
|
||||
+ comm->regHandles.erase(buffKey);
|
||||
comm->handleKeys.erase(handle);
|
||||
free(handle);
|
||||
}
|
||||
@@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){
|
||||
CUdeviceptr buffBasePtr;
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
||||
channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
||||
- auto buffIt = comm->channelScratchInfos.find(buffKey);
|
||||
- bool registered = buffIt != comm->channelScratchInfos.end();
|
||||
+ auto buffIt = comm->regHandles.find(buffKey);
|
||||
+ bool registered = buffIt != comm->regHandles.end();
|
||||
return registered;
|
||||
}
|
||||
size_t
|
||||
@@ -519,33 +519,31 @@ ncclResult_t mscclEnqueueCheck(
|
||||
case mscclNoGroup:
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (comm->mscclppCompatible) {
|
||||
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
}
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
|
||||
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
|
||||
const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
|
||||
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
/* check if one rank per GPU and graph mode is enabled */
|
||||
if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
|
||||
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
|
||||
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
}
|
||||
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
@@ -565,33 +563,31 @@ ncclResult_t mscclEnqueueCheck(
|
||||
case mscclGroupSupportedOp:
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (comm->mscclppCompatible) {
|
||||
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
}
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
|
||||
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
|
||||
const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
|
||||
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
/* check if one rank per GPU and graph mode is enabled */
|
||||
if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
|
||||
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
|
||||
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
}
|
||||
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" ,
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
|
||||
Ссылка в новой задаче
Block a user