From 70597789d046c19eb0855c1865505669e81ee834 Mon Sep 17 00:00:00 2001 From: Stanley Tsang Date: Fri, 23 Apr 2021 11:33:20 -0600 Subject: [PATCH] Message queue refactor to POSIX implementation and leak fix (#355) * Fixing message queue leak. * Using POSIX implementation of Message Queues * Adding unlink to msgqueue * MsgQueue update * Adding timeout check to msgqueue broadcast; tightening up system checks * Removing unnecessary code * Removing extra argument from print * Adding explicit msg queue close call to all other ranks --- src/clique/CliqueManager.cc | 33 ++-------------- src/clique/MsgQueue.cc | 78 +++++++++++++++++++++++++------------ src/clique/MsgQueue.h | 17 ++++---- src/clique/ShmObject.h | 46 +++++++++++++++------- 4 files changed, 98 insertions(+), 76 deletions(-) diff --git a/src/clique/CliqueManager.cc b/src/clique/CliqueManager.cc index 665d882e41..7e6100d603 100644 --- a/src/clique/CliqueManager.cc +++ b/src/clique/CliqueManager.cc @@ -72,28 +72,6 @@ CliqueManager::~CliqueManager() void CliqueManager::CleanUp() { - if (rcclParamEnableClique()) - { - if (m_rank == 0) - { - int pid = getpid(); - for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++) - { - std::string msgQueueName = "/tmp/" + it->second + std::to_string(m_hash) + "_" + std::to_string(pid); - ncclResult_t res = MsgQueueClose(msgQueueName, m_hash); - if (res != ncclSuccess) - { - WARN("Unable to close Message Queue: %s\n", msgQueueName.c_str()); - } - int ret = unlink(msgQueueName.c_str()); - if (ret != 0) - { - WARN("Unable to unlink %s\n", msgQueueName.c_str()); - } - } - } - } - if (m_cliqueMode == CLIQUE_DISABLED) return; // Free variables that are shared between SINGLE_PROCESS / SINGLE_NODE @@ -131,7 +109,6 @@ void CliqueManager::CleanUp() ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix) { ncclResult_t res; - if (m_init) return ncclSuccess; m_init = true; @@ -167,7 +144,6 @@ ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix) return ncclSuccess; } - std::string shmSuffix = std::to_string(m_hash) + "_" + std::to_string(suffix); // Allocate sense barrier variable on local GPU @@ -531,11 +507,10 @@ ncclResult_t CliqueManager::BootstrapRootInit(int pid, unsigned long hash) { for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++) { - int msgid, fd; - std::string msgQueueName = "/tmp/" + it->second + std::to_string(hash) + "_" + std::to_string(pid); - SYSCHECKVAL(open(msgQueueName.c_str(), O_CREAT | O_RDWR, 0606), "open", fd); - NCCLCHECK(MsgQueueGetId(msgQueueName, hash, true, msgid)); - SYSCHECK(close(fd), "close"); + mqd_t mq_desc; + std::string msgQueueName = it->second + std::to_string(hash) + "_" + std::to_string(pid); + NCCLCHECK(MsgQueueGetId(msgQueueName, true, mq_desc)); + NCCLCHECK(MsgQueueClose(msgQueueName, mq_desc, true)); } std::string shmDir = "/dev/shm/"; diff --git a/src/clique/MsgQueue.cc b/src/clique/MsgQueue.cc index ef5dd6fa19..ba1da846bc 100644 --- a/src/clique/MsgQueue.cc +++ b/src/clique/MsgQueue.cc @@ -21,25 +21,31 @@ THE SOFTWARE. */ #include "MsgQueue.h" +#include -#include -#include +#define MSG_QUEUE_PERM S_IRUSR | S_IWUSR +#define MSG_QUEUE_MODE O_RDWR +#define MSG_SIZE 1 +#define MSG_QUEUE_TIMEOUT 60 -#define MSG_QUEUE_PERM 0666 - -ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& msgid) +ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc) { - key_t key; - SYSCHECKVAL(ftok(name.c_str(), projid), "ftok", key); - int flag = (exclusive == true ? IPC_CREAT | IPC_EXCL : IPC_CREAT); - msgid = msgget(key, MSG_QUEUE_PERM | flag); + int flag = (exclusive == true ? O_CREAT | O_EXCL : O_CREAT); + struct mq_attr attr; + attr.mq_maxmsg = 10; + attr.mq_msgsize = MSG_SIZE; + attr.mq_flags = 0; + + std::string mq_name = "/" + name; + mq_desc = mq_open(mq_name.c_str(), flag | MSG_QUEUE_MODE, MSG_QUEUE_PERM, &attr); + // Check if we're trying to create message queue and it already exists; if so, delete existing queue - if (msgid == -1 && exclusive == true && errno == EEXIST) + if (mq_desc == -1 && exclusive == true && errno == EBUSY) { - NCCLCHECK(MsgQueueClose(name, projid)); - SYSCHECKVAL(msgget(key, MSG_QUEUE_PERM | flag), "msgget", msgid); + NCCLCHECK(MsgQueueClose(name, mq_desc, true)); + SYSCHECKVAL(mq_open(mq_name.c_str(), flag | MSG_QUEUE_MODE, MSG_QUEUE_PERM, attr), "mq_open", mq_desc); } - else if (msgid == -1) + else if (mq_desc == -1) { WARN("Call to MsgQueueGetId failed : %s", strerror(errno)); return ncclSystemError; @@ -47,25 +53,49 @@ ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& ms return ncclSuccess; } -ncclResult_t MsgQueueSend(int msgid, const void* msgp, size_t msgsz, int msgflg) +ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz) { - SYSCHECK(msgsnd(msgid, msgp, msgsz, msgflg), "msgsnd"); + SYSCHECK(mq_send(mq_desc, msgp, msgsz, 0), "mq_send"); return ncclSuccess; } -ncclResult_t MsgQueueRecv(int msgid, void* msgp, size_t msgsz, long msgtyp, bool wait) +ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz) { - int msgflg = (wait == false ? IPC_NOWAIT : 0); - SYSCHECK(msgrcv(msgid, msgp, msgsz, msgtyp, msgflg), "msgrcv"); + SYSCHECK(mq_receive(mq_desc, msgp, msgsz, NULL), "mq_receive"); return ncclSuccess; } -ncclResult_t MsgQueueClose(std::string name, int projid) +ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc) { - key_t key; - int msgid; - key = ftok(name.c_str(), projid); - SYSCHECKVAL(msgget(key, 0), "msgget", msgid); - SYSCHECK(msgctl(msgid, IPC_RMID, NULL), "msgctl"); + mq_attr attr; + mq_getattr(mq_desc, &attr); + + auto start = std::chrono::steady_clock::now(); + while(attr.mq_curmsgs > 0) + { + SYSCHECK(mq_getattr(mq_desc, &attr), "mq_getattr"); + if(std::chrono::steady_clock::now() - start > std::chrono::seconds(MSG_QUEUE_TIMEOUT)) + { + WARN("Message Queue timed out waiting for all ranks to receive messages."); + return ncclSystemError; + } + } return ncclSuccess; } + +ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink) +{ + if (unlink) + { + NCCLCHECK(MsgQueueUnlink(name)); + } + SYSCHECK(mq_close(mq_desc), "mq_close"); + return ncclSuccess; +} + +ncclResult_t MsgQueueUnlink(std::string name) +{ + std::string mq_name = "/" + name; + SYSCHECK(mq_unlink(mq_name.c_str()), "mq_unlink"); + return ncclSuccess; +} \ No newline at end of file diff --git a/src/clique/MsgQueue.h b/src/clique/MsgQueue.h index 346208a6e8..af91add388 100644 --- a/src/clique/MsgQueue.h +++ b/src/clique/MsgQueue.h @@ -24,19 +24,16 @@ THE SOFTWARE. #define RCCL_MSG_QUEUE_HPP_ #include +#include #include "nccl.h" #include "core.h" -struct MsgBuffer -{ - long msg_type; - char msg_text[1]; -}; - -ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& msgid); -ncclResult_t MsgQueueSend(int msgid, const void* msgp, size_t msgsz, int msgflg); -ncclResult_t MsgQueueRecv(int msgid, void* msgp, size_t msgsz, long msgtyp, bool wait); -ncclResult_t MsgQueueClose(std::string name, int projid); +ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc); +ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz); +ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz); +ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc); +ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink); +ncclResult_t MsgQueueUnlink(std::string name); #endif diff --git a/src/clique/ShmObject.h b/src/clique/ShmObject.h index 144f10d22c..9759292c88 100644 --- a/src/clique/ShmObject.h +++ b/src/clique/ShmObject.h @@ -92,19 +92,32 @@ ShmObject(size_t size, std::string fileName, int rank, int numRanks, int projid) return m_shmPtr; } protected: - ncclResult_t BroadcastMessage(int msgid, bool pass) + ncclResult_t BroadcastMessage(mqd_t& mq_desc, bool pass) { - MsgBuffer msg; - msg.msg_text[0] = (pass == 0 ? 'F': 'P'); + char msg_text[1]; + msg_text[0] = (pass == 0 ? 'F': 'P'); for (int rank = 0; rank < m_numRanks; rank++) { if (rank == m_rank) continue; - msg.msg_type = rank; - NCCLCHECK(MsgQueueSend(msgid, &msg, sizeof(msg), 0)); + NCCLCHECK(MsgQueueSend(mq_desc, &msg_text[0], sizeof(msg_text))); } return ncclSuccess; } + ncclResult_t BroadcastAndCloseMessageQueue(mqd_t& mq_desc, bool pass) + { + ncclResult_t res; + NCCLCHECKGOTO(BroadcastMessage(mq_desc, pass), res, dropback); + NCCLCHECKGOTO(MsgQueueWaitUntilEmpty(mq_desc), res, dropback); + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, true)); + return ncclSuccess; + +dropback: + WARN("Root rank unable to broadcast across message queue. Closing message queue."); + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, true)); + return ncclSystemError; + } + // tag for dispatch template struct OpenTag{}; @@ -127,33 +140,35 @@ protected: template ncclResult_t ShmObject::Open() { + mqd_t mq_desc; if (m_alloc == false) { int shmFd; int protection = PROT_READ | PROT_WRITE; int visibility = MAP_SHARED; - int msgid; - std::string tmpFileName = "/tmp/" + m_shmName; - NCCLCHECK(MsgQueueGetId(tmpFileName, m_projid, false, msgid)); + INFO(NCCL_INIT, "Rank %d Initializing message queue for %s\n", m_rank, m_shmName.c_str()); + NCCLCHECK(MsgQueueGetId(m_shmName, false, mq_desc)); if (m_rank == 0) { ncclResult_t resultSetup = shmSetupExclusive(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 1); ncclResult_t resultSemInit = InitIfSemaphore(OpenTag{}); if ((resultSetup != ncclSuccess && errno != EEXIST) || (resultSemInit != ncclSuccess)) { - NCCLCHECK(BroadcastMessage(msgid, false)); + NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, false)); WARN("Call to ShmObject::Open in root rank failed : %s", strerror(errno)); return ncclSystemError; } - NCCLCHECK(BroadcastMessage(msgid, true)); + NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, true)); } else { - MsgBuffer msg; - NCCLCHECK(MsgQueueRecv(msgid, &msg, sizeof(msg), m_rank, true)); - if (msg.msg_text[0] == 'P') + char msg_text[1]; + ncclResult_t res; + NCCLCHECKGOTO(MsgQueueRecv(mq_desc, &msg_text[0], sizeof(msg_text)), res, dropback); + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); + if (msg_text[0] == 'P') { NCCLCHECK(shmSetup(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 0)); } @@ -171,6 +186,11 @@ ncclResult_t ShmObject::Open() return ncclInvalidUsage; } return ncclSuccess; + +dropback: + WARN("Rank %d unable to receive message from root. Closing message queue.", m_rank); + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); + return ncclSystemError; } template