From 19105206f61d522ebfb99aaa3435f2a76f7b3dd1 Mon Sep 17 00:00:00 2001 From: isaki001 <36317038+isaki001@users.noreply.github.com> Date: Tue, 4 Feb 2025 09:09:56 -0600 Subject: [PATCH] Update MSCCL++ register/deregister (#1523) * erase handle key from mscclpp communicator during deregistration * remove check on buffer size being a multiple of 32 from registration/deregistration routines since these checks are applied during enqueue * add check for greater than zero buffer size in mscclpp registration --- ext-src/mem-reg.patch | 2 +- src/register.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ext-src/mem-reg.patch b/ext-src/mem-reg.patch index 57c50a5341..c9f522b6d8 100644 --- a/ext-src/mem-reg.patch +++ b/ext-src/mem-reg.patch @@ -116,7 +116,7 @@ index 022d398..468fcf2 100644 + if (outIt != comm->channelOutInfos.end()) { + comm->channelOutInfos.erase(outIt); + } -+ ++ comm->handleKeys.erase(handle); + free(handle); + } + return ncclSuccess; diff --git a/src/register.cc b/src/register.cc index 4a60b7f5b6..980df6e6dd 100644 --- a/src/register.cc +++ b/src/register.cc @@ -161,7 +161,7 @@ ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t siz NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister")); #ifdef ENABLE_MSCCLPP - if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){ + if (comm->mscclCompatible && size > 0){ bool isManagedBuffer = false; CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(buff))); if(!isManagedBuffer){ @@ -184,7 +184,7 @@ ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) { #ifdef ENABLE_MSCCLPP const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle); - if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold) { + if (comm->mscclCompatible && size > 0) { NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); return ncclSuccess; }