Files
rocm-systems/ext-src/mem-reg.patch
T
Nusrat Islam e9b6bbca8a Add MSCCLPP user buffer registration APIs and integrate with RCCL (#1477)
* ext-src: add MSCCLPP memory registration APIs

* update mem-reg patch with mscclpp helper routine to check if buffer is registered

* RCCL integration of MSCCL++ user-buffer registration APIs

* only include mscclpp_nccl header if ENABLE_MSCCLPP is defined

* ext-src: update mscclpp mem-reg patch

* add helper routine to patch

* check handle before MSCCL++ deregister

* fix typo to replace send buff with recv buff

* in case of no mscclpp registration, during deRegister call, ont fall back to rccl deRegister which will return an error

* Apply suggestions from code review

Whitespace suggestions and reducing diffs to avoid future merge conflicts

Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com>

* rename helper functions and change their return type

* set RCCL user-buffer registration to occur if attempting MSCCL++ registration with a buffer in managed memory

---------

Co-authored-by: isaki001 <Ioannis.Sakiotis@amd.com>
Co-authored-by: isaki001 <36317038+isaki001@users.noreply.github.com>
Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com>
2025-01-14 08:20:24 -06:00

148 строки
5.8 KiB
Diff

diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
index 7f50792..b8b146d 100644
--- a/apps/nccl/include/nccl.h
+++ b/apps/nccl/include/nccl.h
@@ -344,6 +344,13 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream);
+/*
+ * Register/Deregister (MSCCL++ user-buffer registration)
+ */
+ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle);  // registers the device allocation containing buff; *handle receives an opaque handle
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);  // drops the channels created for handle and frees the handle
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count);  // true if the allocation containing buff was registered on comm
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle);  // bytes of the allocation behind handle; 0 for null args or unknown handle
/*
* Send
*
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index a697be2..1d4af61 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -65,6 +65,7 @@ struct ncclComm {
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
+ std::unordered_map<void*, channelKey> handleKeys;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
@@ -73,6 +74,11 @@ struct ncclComm {
uint32_t buffFlag;
};
+struct handleInfo {  // opaque handle returned by ncclCommRegister (malloc'd there, free'd by ncclCommDeregister)
+  void * buff;  // base address of the registered allocation
+  cudaIpcMemHandle_t ipcHandle;  // IPC handle for that allocation, from cudaIpcGetMemHandle
+};
+
static size_t ncclTypeSize(ncclDataType_t type) {
switch (type) {
case ncclInt8:
@@ -577,6 +583,104 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
return ncclSuccess;
}
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
+  // Registers the whole device allocation containing `buff` (found with cuMemGetAddressRange, so the
+  // incoming `size` is only an initial value that the query overwrites) and creates the scratch, input
+  // and output SM channels for it, keyed by {allocation base, allocation bytes}.
+  // Returns ncclInvalidArgument for null comm/handle, ncclSystemError if the handle cannot be allocated.
+  if (comm == nullptr || handle == nullptr) return ncclInvalidArgument;
+  size_t buffBytes = size;
+  CUdeviceptr buffBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+  void* basePtr = (void*)buffBasePtr;
+  int rank = comm->comm->bootstrap()->getRank();
+  channelKey buffKey{basePtr, buffBytes};
+
+  // Scratch channels: this buffer paired with the peers' scratch memories.
+  if (comm->channelScratchInfos.find(buffKey) == comm->channelScratchInfos.end()) {
+    std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
+    comm->channelScratchInfos.emplace(buffKey, channelInfo);
+  }
+  // Input channels: scratch-backed channels plus channels to the peers' registered copies of this buffer.
+  if (comm->channelInInfos.find(buffKey) == comm->channelInInfos.end()) {
+    std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    std::vector<mscclpp::RegisteredMemory> remoteMemories =
+        setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
+    std::vector<mscclpp::SmChannel> channels1 = setupSmChannels(comm, remoteMemories, basePtr);
+    ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
+    comm->channelInInfos.emplace(buffKey, channelInfo);
+  }
+  // Output channels: channels to the peers' registered copies of this buffer.
+  if (comm->channelOutInfos.find(buffKey) == comm->channelOutInfos.end()) {
+    std::vector<mscclpp::RegisteredMemory> remoteMemories =
+        setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
+    std::vector<mscclpp::SmChannel> outChannels = setupSmChannels(comm, remoteMemories, basePtr);
+    ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
+    comm->channelOutInfos.emplace(buffKey, channelInfo);
+  }
+
+  cudaIpcMemHandle_t ipcHandle;
+  // CUdeviceptr is an integer type under CUDA, so hand the runtime API the void* form.
+  MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, basePtr));
+
+  // The opaque handle is a malloc'd handleInfo; ncclCommDeregister releases it with free().
+  struct handleInfo* p = (struct handleInfo*)malloc(sizeof(struct handleInfo));
+  if (p == nullptr) {
+    return ncclSystemError;
+  }
+  p->buff = basePtr;
+  p->ipcHandle = ipcHandle;
+  *handle = p;
+
+  // malloc may reuse the address of a previously freed handle, so always overwrite the
+  // mapping rather than keeping a stale key left over from an earlier registration.
+  comm->handleKeys[p] = buffKey;
+
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
+  // Tears down the channels created by ncclCommRegister for `handle` and frees the handle.
+  // Null comm/handle is a no-op returning ncclSuccess. A handle this comm did not issue returns
+  // ncclInvalidArgument and is NOT freed (it was not allocated by ncclCommRegister on this comm).
+  if (comm == nullptr || handle == nullptr) {
+    return ncclSuccess;
+  }
+  // find() rather than operator[]: operator[] would insert a default key for unknown handles.
+  auto keyIt = comm->handleKeys.find(handle);
+  if (keyIt == comm->handleKeys.end()) {
+    return ncclInvalidArgument;
+  }
+  const channelKey buffKey = keyIt->second;
+  comm->handleKeys.erase(keyIt);
+
+  // Channels are keyed per allocation: other handles to the same buffer lose them as well.
+  comm->channelScratchInfos.erase(buffKey);
+  comm->channelInInfos.erase(buffKey);
+  comm->channelOutInfos.erase(buffKey);
+
+  free(handle);
+  return ncclSuccess;
+}
+
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count){
+  // True iff ncclCommRegister created channels for the allocation containing `buff`.
+  // `count` is unused: lookup is by the whole allocation. A null comm reports "not registered".
+  if (comm == nullptr) return false;
+  size_t allocBytes = 0;
+  CUdeviceptr allocBase;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&allocBase, &allocBytes, (CUdeviceptr)buff));
+  return comm->channelScratchInfos.count(channelKey{(void*)allocBase, allocBytes}) != 0;
+}
+size_t
+mscclpp_BufferSize(ncclComm_t comm, void* handle){
+  // Bytes of the allocation registered under `handle`; 0 for null arguments or an unknown handle.
+  if (comm == nullptr || handle == nullptr) return 0;
+  const auto entry = comm->handleKeys.find(handle);
+  if (entry == comm->handleKeys.end()) return 0;
+  return entry->second.bytes;
+}
NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
// TODO: implement this function
return ncclInternalError;