19105206f6
* Erase the handle key from the mscclpp communicator during deregistration.
* Remove the checks on the buffer size being a multiple of 32 from the registration/deregistration routines, since these checks are applied during enqueue.
* Add a check for a greater-than-zero buffer size in mscclpp registration.
148 rivejä
5.7 KiB
Diff
148 rivejä
5.7 KiB
Diff
diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
|
|
index bfdb226..70d15cf 100644
|
|
--- a/apps/nccl/include/nccl.h
|
|
+++ b/apps/nccl/include/nccl.h
|
|
@@ -370,6 +370,10 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
|
|
ncclResult_t pncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
|
|
cudaStream_t stream);
|
|
|
|
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);
|
|
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff);
|
|
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle);
|
|
+
|
|
/*
|
|
* Receive
|
|
*
|
|
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
|
index 022d398..468fcf2 100644
|
|
--- a/apps/nccl/src/nccl.cu
|
|
+++ b/apps/nccl/src/nccl.cu
|
|
@@ -85,6 +85,7 @@ struct ncclComm {
|
|
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
|
|
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
|
|
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
|
|
+ std::unordered_map<void*, channelKey> handleKeys;
|
|
std::shared_ptr<char> scratchBuff;
|
|
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
|
|
|
|
@@ -92,6 +93,11 @@ struct ncclComm {
|
|
uint32_t buffFlag;
|
|
};
|
|
|
|
+struct handleInfo {
|
|
+ void * buff;
|
|
+ cudaIpcMemHandle_t ipcHandle;
|
|
+};
|
|
+
|
|
static size_t ncclTypeSize(ncclDataType_t type) {
|
|
switch (type) {
|
|
case ncclInt8:
|
|
@@ -561,6 +567,107 @@ NCCL_API ncclResult_t ncclRedOpDestroy(ncclRedOp_t, ncclComm_t) {
|
|
return ncclInternalError;
|
|
}
|
|
|
|
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
|
|
+ size_t buffBytes = size;
|
|
+ CUdeviceptr buffBasePtr;
|
|
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
|
+
|
|
+ int rank = comm->comm->bootstrap()->getRank();
|
|
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
|
+
|
|
+ std::vector<mscclpp::RegisteredMemory> remoteMemories;
|
|
+
|
|
+ // Creating the channels
|
|
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
|
|
+ if (buffIt == comm->channelScratchInfos.end()) {
|
|
+ std::vector<mscclpp::SmChannel> channels =
|
|
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
|
|
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
|
|
+ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first;
|
|
+ }
|
|
+ auto sendIt = comm->channelInInfos.find(buffKey);
|
|
+ if (sendIt == comm->channelInInfos.end()) {
|
|
+ std::vector<mscclpp::SmChannel> channels =
|
|
+ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)buffBasePtr));
|
|
+
|
|
+ remoteMemories =
|
|
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
|
|
+ std::vector<mscclpp::SmChannel> channels1 =
|
|
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
|
|
+
|
|
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
|
|
+ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first;
|
|
+ }
|
|
+ auto recvIt = comm->channelOutInfos.find(buffKey);
|
|
+ if (recvIt == comm->channelOutInfos.end()) {
|
|
+ remoteMemories =
|
|
+ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc);
|
|
+ std::vector<mscclpp::SmChannel> outChannels =
|
|
+ setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)buffBasePtr));
|
|
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
|
|
+ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first;
|
|
+ }
|
|
+
|
|
+ cudaIpcMemHandle_t ipcHandle;
|
|
+ MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, buffBasePtr));
|
|
+
|
|
+ struct handleInfo *p = (struct handleInfo *) malloc(sizeof(struct handleInfo));
|
|
+ p->buff = buffBasePtr;
|
|
+ p->ipcHandle = ipcHandle;
|
|
+ *handle = p;
|
|
+
|
|
+ auto it = comm->handleKeys.find(*handle);
|
|
+ if (it == comm->handleKeys.end()) {
|
|
+ comm->handleKeys[*handle] = buffKey;
|
|
+ }
|
|
+
|
|
+ return ncclSuccess;
|
|
+}
|
|
+
|
|
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
|
|
+ if (comm && handle) {
|
|
+ channelKey buffKey = comm->handleKeys[handle];
|
|
+
|
|
+ auto scratchIt = comm->channelScratchInfos.find(buffKey);
|
|
+ if (scratchIt != comm->channelScratchInfos.end()) {
|
|
+ comm->channelScratchInfos.erase(scratchIt);
|
|
+ }
|
|
+
|
|
+ auto inIt = comm->channelInInfos.find(buffKey);
|
|
+ if (inIt != comm->channelInInfos.end()) {
|
|
+ comm->channelInInfos.erase(inIt);
|
|
+ }
|
|
+
|
|
+ auto outIt = comm->channelOutInfos.find(buffKey);
|
|
+ if (outIt != comm->channelOutInfos.end()) {
|
|
+ comm->channelOutInfos.erase(outIt);
|
|
+ }
|
|
+ comm->handleKeys.erase(handle);
|
|
+ free(handle);
|
|
+ }
|
|
+ return ncclSuccess;
|
|
+}
|
|
+
|
|
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){
|
|
+ if(buff == nullptr)
|
|
+ return false;
|
|
+ size_t buffBytes;
|
|
+ CUdeviceptr buffBasePtr;
|
|
+ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
|
+ channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
|
+ auto buffIt = comm->channelScratchInfos.find(buffKey);
|
|
+ bool registered = buffIt != comm->channelScratchInfos.end();
|
|
+ return registered;
|
|
+}
|
|
+size_t
|
|
+mscclpp_BufferSize(ncclComm_t comm, void* handle){
|
|
+ if (!(comm && handle)){
|
|
+ return 0;
|
|
+ }
|
|
+ auto buffKeyIt = comm->handleKeys.find(handle);
|
|
+ return buffKeyIt != comm->handleKeys.end() ? buffKeyIt->second.bytes : 0;
|
|
+}
|
|
+
|
|
NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, int, ncclComm_t,
|
|
cudaStream_t) {
|
|
// TODO: implement this function
|