Add MSCCLPP user buffer registration APIs and integrate with RCCL (#1477)
* ext-src: add MSCCLPP memory registration APIs * update mem-reg patch with mscclpp helper routine to check if buffer is registered * RCCL integration of MSCCL++ user-buffer registration APIs * only include mscclpp_nccl header if ENABLE_MSCCLPP is defined * ext-src: update mscclpp mem-reg patch * add helper routine to patch * check handle before MSCCL++ deregister * fix typo to replace send buff with recv buff * in case of no mscclpp registration, dduring deRegister call, ont fall back to rccl deRegister which will return an error * Apply suggestions from code review Whitespace suggestions and reducing diffs to avoid future merge conflicts Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com> * rename helper functions and change their return type * set RCCL user-buffer registration to occur if attempting MSCCL++ registration with a buffer in managed memory --------- Co-authored-by: isaki001 <Ioannis.Sakiotis@amd.com> Co-authored-by: isaki001 <36317038+isaki001@users.noreply.github.com> Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com>
Bu işleme şunda yer alıyor:
@@ -68,14 +68,20 @@ if(ENABLE_MSCCLPP)
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
execute_process(
|
||||
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
message(STATUS "Building mscclpp only for gfx942.")
|
||||
|
||||
mscclpp_cmake_arg(CMAKE_PREFIX_PATH)
|
||||
@@ -102,17 +108,23 @@ if(ENABLE_MSCCLPP)
|
||||
|
||||
find_package(mscclpp_nccl REQUIRED)
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
|
||||
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
|
||||
)
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND objcopy
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
|
||||
index 7f50792..b8b146d 100644
|
||||
--- a/apps/nccl/include/nccl.h
|
||||
+++ b/apps/nccl/include/nccl.h
|
||||
@@ -344,6 +344,13 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
|
||||
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
|
||||
ncclComm_t comm, cudaStream_t stream);
|
||||
|
||||
+/*
|
||||
+ * Register/Deregister
|
||||
+ */
|
||||
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle);
|
||||
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);
|
||||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count);
|
||||
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle);
|
||||
/*
|
||||
* Send
|
||||
*
|
||||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
||||
index a697be2..1d4af61 100644
|
||||
--- a/apps/nccl/src/nccl.cu
|
||||
+++ b/apps/nccl/src/nccl.cu
|
||||
@@ -65,6 +65,7 @@ struct ncclComm {
|
||||
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
|
||||
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
|
||||
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
|
||||
+ std::unordered_map<void*, channelKey> handleKeys;
|
||||
std::shared_ptr<char> scratchBuff;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
|
||||
|
||||
@@ -73,6 +74,11 @@ struct ncclComm {
|
||||
uint32_t buffFlag;
|
||||
};
|
||||
|
||||
+struct handleInfo {
|
||||
+ void * buff;
|
||||
+ cudaIpcMemHandle_t ipcHandle;
|
||||
+};
|
||||
+
|
||||
static size_t ncclTypeSize(ncclDataType_t type) {
|
||||
switch (type) {
|
||||
case ncclInt8:
|
||||
@@ -577,6 +583,104 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
|
||||
+ size_t buffBytes = size;
|
||||
+ CUdeviceptr buffBasePtr;
|
||||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
||||
+
|
||||
+ int rank = comm->comm->bootstrap()->getRank();
|
||||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
||||
+
|
||||
+ std::vector<mscclpp::RegisteredMemory> remoteMemories;
|
||||
+
|
||||
+ // Creating the channels
|
||||
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
|
||||
+ if (buffIt == comm->channelScratchInfos.end()) {
|
||||
+ std::vector<mscclpp::SmChannel> channels =
|
||||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
|
||||
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
|
||||
+ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first;
|
||||
+ }
|
||||
+ auto sendIt = comm->channelInInfos.find(buffKey);
|
||||
+ if (sendIt == comm->channelInInfos.end()) {
|
||||
+ std::vector<mscclpp::SmChannel> channels =
|
||||
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
|
||||
+
|
||||
+ remoteMemories =
|
||||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
|
||||
+ std::vector<mscclpp::SmChannel> channels1 =
|
||||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
|
||||
+
|
||||
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
|
||||
+ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first;
|
||||
+ }
|
||||
+ auto recvIt = comm->channelOutInfos.find(buffKey);
|
||||
+ if (recvIt == comm->channelOutInfos.end()) {
|
||||
+ remoteMemories =
|
||||
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
|
||||
+ std::vector<mscclpp::SmChannel> outChannels =
|
||||
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
|
||||
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
|
||||
+ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first;
|
||||
+ }
|
||||
+
|
||||
+ cudaIpcMemHandle_t ipcHandle;
|
||||
+ MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, buffBasePtr));
|
||||
+
|
||||
+ struct handleInfo *p = (struct handleInfo *) malloc(sizeof(struct handleInfo));
|
||||
+ p->buff = buffBasePtr;
|
||||
+ p->ipcHandle = ipcHandle;
|
||||
+ *handle = p;
|
||||
+
|
||||
+ auto it = comm->handleKeys.find(*handle);
|
||||
+ if (it == comm->handleKeys.end()) {
|
||||
+ comm->handleKeys[*handle] = buffKey;
|
||||
+ }
|
||||
+
|
||||
+ return ncclSuccess;
|
||||
+}
|
||||
+
|
||||
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
|
||||
+ if (comm && handle) {
|
||||
+ channelKey buffKey = comm->handleKeys[handle];
|
||||
+
|
||||
+ auto scratchIt = comm->channelScratchInfos.find(buffKey);
|
||||
+ if (scratchIt != comm->channelScratchInfos.end()) {
|
||||
+ comm->channelScratchInfos.erase(scratchIt);
|
||||
+ }
|
||||
+
|
||||
+ auto inIt = comm->channelInInfos.find(buffKey);
|
||||
+ if (inIt != comm->channelInInfos.end()) {
|
||||
+ comm->channelInInfos.erase(inIt);
|
||||
+ }
|
||||
+
|
||||
+ auto outIt = comm->channelOutInfos.find(buffKey);
|
||||
+ if (outIt != comm->channelOutInfos.end()) {
|
||||
+ comm->channelOutInfos.erase(outIt);
|
||||
+ }
|
||||
+
|
||||
+ free(handle);
|
||||
+ }
|
||||
+ return ncclSuccess;
|
||||
+}
|
||||
+
|
||||
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count){
|
||||
+ size_t buffBytes;
|
||||
+ CUdeviceptr buffBasePtr;
|
||||
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
||||
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
||||
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
|
||||
+ bool registered = buffIt != comm->channelScratchInfos.end();
|
||||
+ return registered;
|
||||
+}
|
||||
+size_t
|
||||
+mscclpp_BufferSize(ncclComm_t comm, void* handle){
|
||||
+ if (!(comm && handle)){
|
||||
+ return 0;
|
||||
+ }
|
||||
+ auto buffKeyIt = comm->handleKeys.find(handle);
|
||||
+ return buffKeyIt != comm->handleKeys.end() ? buffKeyIt->second.bytes : 0;
|
||||
+}
|
||||
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
|
||||
// TODO: implement this function
|
||||
return ncclInternalError;
|
||||
@@ -38,6 +38,14 @@ extern "C" {
|
||||
/* See ncclAllGather. */
|
||||
ncclResult_t mscclpp_ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
ncclDataType_t datatype, mscclppComm_t comm, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclpp_ncclCommRegister(mscclppComm_t comm, void* buff, size_t size, void** handle);
|
||||
|
||||
ncclResult_t mscclpp_ncclCommDeregister(mscclppComm_t comm, void* handle);
|
||||
|
||||
bool mscclpp_BuffIsRegistered(mscclppComm_t comm, const void* buff, size_t count);
|
||||
|
||||
size_t mscclpp_BufferSize(mscclppComm_t comm, void* handle);
|
||||
}
|
||||
|
||||
namespace std {
|
||||
|
||||
@@ -524,8 +524,13 @@ ncclResult_t mscclEnqueueCheck(
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
}
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff, count);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff, count);
|
||||
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
|
||||
const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
/* check if one rank per GPU and graph mode is enabled */
|
||||
if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
@@ -565,8 +570,13 @@ ncclResult_t mscclEnqueueCheck(
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
}
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff, count);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff, count);
|
||||
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
|
||||
const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
/* check if one rank per GPU and graph mode is enabled */
|
||||
if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
|
||||
@@ -30,3 +30,5 @@ ncclRedOpDestroy mscclpp_ncclRedOpDestroy
|
||||
ncclReduce mscclpp_ncclReduce
|
||||
ncclReduceScatter mscclpp_ncclReduceScatter
|
||||
ncclSend mscclpp_ncclSend
|
||||
ncclCommRegister mscclpp_ncclCommRegister
|
||||
ncclCommDeregister mscclpp_ncclCommDeregister
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
#include "net.h"
|
||||
#include "register.h"
|
||||
#include "api_trace.h"
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
#include "mscclpp/mscclpp_nccl.h"
|
||||
#endif
|
||||
|
||||
ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) {
|
||||
struct ncclRegCache* cache = &comm->regCache;
|
||||
@@ -155,12 +158,36 @@ NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size
|
||||
ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) {
|
||||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
|
||||
if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister"));
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){
|
||||
bool isManagedBuffer = false;
|
||||
CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(buff)));
|
||||
if(!isManagedBuffer){
|
||||
INFO(NCCL_INIT, "MSCCL++: ncclCommRegister");
|
||||
NCCLCHECK(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle));
|
||||
return ncclSuccess;
|
||||
}
|
||||
else{
|
||||
WARN("MSCCL++: Cannot register user-buffers on managed memory. RCCL user-buffer registration will occur.");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
INFO(NCCL_INIT, "RCCL: ncclCommRegister");
|
||||
NCCLCHECK(ncclRegister(comm, buff, size, handle));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle);
|
||||
ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) {
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle);
|
||||
if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold) {
|
||||
NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle));
|
||||
return ncclSuccess;
|
||||
}
|
||||
#endif
|
||||
|
||||
NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm"));
|
||||
struct ncclReg* reg = (struct ncclReg*)handle;
|
||||
struct ncclRegCache* cache = &comm->regCache;
|
||||
|
||||
Yeni konuda referans
Bir kullanıcı engelle