diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h
index bfdb226..70d15cf 100644
--- a/apps/nccl/include/nccl.h
+++ b/apps/nccl/include/nccl.h
@@ -370,6 +370,10 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
 ncclResult_t pncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
                        cudaStream_t stream);
 
+ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle);
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff);
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle);
+
 /*
  * Receive
  *
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 022d398..468fcf2 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -85,6 +85,7 @@ struct ncclComm {
   std::unordered_map channelInInfos;
   std::unordered_map channelOutInfos;
   std::unordered_map channelScratchInfos;
+  std::unordered_map<void*, channelKey> handleKeys;  // registration handle -> buffer key
   std::shared_ptr scratchBuff;
   std::vector remoteScratchRegMemories;
 
@@ -92,6 +93,13 @@ struct ncclComm {
   uint32_t buffFlag;
 };
 
+// Record behind the opaque handle produced by ncclCommRegister(); it is
+// malloc()ed there and free()d by ncclCommDeregister().
+struct handleInfo {
+  void* buff;                    // base address stored at registration time
+  cudaIpcMemHandle_t ipcHandle;  // result of cudaIpcGetMemHandle() for buff
+};
+
 static size_t ncclTypeSize(ncclDataType_t type) {
   switch (type) {
     case ncclInt8:
@@ -561,6 +569,121 @@ NCCL_API ncclResult_t ncclRedOpDestroy(ncclRedOp_t, ncclComm_t) {
   return ncclInternalError;
 }
 
+// Registers a user buffer with the communicator ahead of use.
+//
+// The buffer is resolved to its enclosing allocation with cuMemGetAddressRange(),
+// so everything is keyed on (allocation base, allocation size); `size` only seeds
+// the byte count and is overwritten by that query. Scratch, input and output
+// channel entries are created for the key when they are not cached yet.
+//
+// comm:   communicator to register with; must not be null.
+// buff:   pointer inside the allocation to register; must not be null.
+// size:   caller-supplied size in bytes (see above).
+// handle: receives an opaque malloc()ed handle for ncclCommDeregister(); must not be null.
+//
+// Returns ncclSuccess, ncclInvalidArgument for a null argument, or
+// ncclInternalError when the handle cannot be allocated. Driver/runtime failures
+// are reported through MSCCLPP_CUTHROW / MSCCLPP_CUDATHROW.
+NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) {
+  if (comm == nullptr || buff == nullptr || handle == nullptr) return ncclInvalidArgument;
+  *handle = nullptr;
+
+  size_t buffBytes = size;
+  CUdeviceptr buffBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+  // CUdeviceptr is not guaranteed to be a pointer type, so convert it once, explicitly.
+  void* basePtr = (void*)buffBasePtr;
+
+  int rank = comm->comm->bootstrap()->getRank();
+  channelKey buffKey{basePtr, buffBytes};
+
+  // Scratch entry: channels built against comm->remoteScratchRegMemories.
+  if (comm->channelScratchInfos.find(buffKey) == comm->channelScratchInfos.end()) {
+    auto channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels),
+                            setupSmChannelDeviceHandles(channels)};
+    comm->channelScratchInfos.emplace(buffKey, channelInfo);
+  }
+
+  // Input entry: one channel set against the remote scratch memories and one
+  // against the memories setupRemoteMemories() returns for this buffer.
+  if (comm->channelInInfos.find(buffKey) == comm->channelInInfos.end()) {
+    auto channels = setupSmChannels(comm, comm->remoteScratchRegMemories, basePtr);
+    auto remoteMemories = setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
+    auto channels1 = setupSmChannels(comm, remoteMemories, basePtr);
+    ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels),
+                            setupSmChannelDeviceHandles(channels1)};
+    comm->channelInInfos.emplace(buffKey, channelInfo);
+  }
+
+  // Output entry: a single channel set against the buffer's remote memories.
+  if (comm->channelOutInfos.find(buffKey) == comm->channelOutInfos.end()) {
+    auto remoteMemories = setupRemoteMemories(comm->comm, rank, basePtr, buffBytes, mscclpp::Transport::CudaIpc);
+    auto outChannels = setupSmChannels(comm, remoteMemories, basePtr);
+    ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels),
+                            setupSmChannelDeviceHandles(outChannels)};
+    comm->channelOutInfos.emplace(buffKey, channelInfo);
+  }
+
+  cudaIpcMemHandle_t ipcHandle;
+  MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, basePtr));
+
+  handleInfo* info = static_cast<handleInfo*>(malloc(sizeof(handleInfo)));
+  if (info == nullptr) return ncclInternalError;
+  info->buff = basePtr;
+  info->ipcHandle = ipcHandle;
+
+  // A freshly allocated handle cannot already be a key, so no lookup is needed.
+  comm->handleKeys[info] = buffKey;
+  *handle = info;
+  return ncclSuccess;
+}
+
+// Releases a registration created by ncclCommRegister().
+//
+// Removes the scratch/input/output channel entries cached under the handle's
+// buffer key, forgets the handle and frees it. A null comm or handle is a no-op
+// that returns ncclSuccess; a handle this communicator does not know returns
+// ncclInvalidArgument and is left untouched.
+NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
+  if (comm == nullptr || handle == nullptr) return ncclSuccess;
+
+  // find() instead of operator[]: operator[] would insert a default key for an
+  // unknown handle, and the foreign pointer would then be passed to free().
+  auto keyIt = comm->handleKeys.find(handle);
+  if (keyIt == comm->handleKeys.end()) return ncclInvalidArgument;
+
+  const channelKey buffKey = keyIt->second;
+  comm->channelScratchInfos.erase(buffKey);
+  comm->channelInInfos.erase(buffKey);
+  comm->channelOutInfos.erase(buffKey);
+  comm->handleKeys.erase(keyIt);
+  free(handle);
+  return ncclSuccess;
+}
+
+// Reports whether the allocation containing buff has a scratch-channel entry in
+// comm. Returns false when comm or buff is null. A failing address-range query
+// is reported through MSCCLPP_CUTHROW.
+bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff) {
+  if (comm == nullptr || buff == nullptr) return false;
+
+  size_t buffBytes;
+  CUdeviceptr buffBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
+  channelKey buffKey{(void*)buffBasePtr, buffBytes};
+  return comm->channelScratchInfos.find(buffKey) != comm->channelScratchInfos.end();
+}
+
+// Returns the byte count recorded for the registration behind handle, or 0 when
+// comm or handle is null or the handle is not known to comm.
+size_t mscclpp_BufferSize(ncclComm_t comm, void* handle) {
+  if (comm == nullptr || handle == nullptr) return 0;
+
+  auto keyIt = comm->handleKeys.find(handle);
+  return keyIt != comm->handleKeys.end() ? keyIt->second.bytes : 0;
+}
+
 NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, int, ncclComm_t,
                                  cudaStream_t) {
   // TODO: implement this function