Message queue refactor to POSIX implementation and leak fix (#355)

* Fixing message queue leak.

* Using POSIX implementation of Message Queues

* Adding unlink to msgqueue

* MsgQueue update

* Adding timeout check to msgqueue broadcast; tightening up system checks

* Removing unnecessary code

* Removing extra argument from print

* Adding explicit msg queue close call to all other ranks
Этот коммит содержится в:
Stanley Tsang
2021-04-23 11:33:20 -06:00
коммит произвёл GitHub
родитель 415c7cd3d1
Коммит 70597789d0
4 изменённых файлов: 98 добавлений и 76 удалений
+4 -29
Просмотреть файл
@@ -72,28 +72,6 @@ CliqueManager::~CliqueManager()
void CliqueManager::CleanUp()
{
if (rcclParamEnableClique())
{
if (m_rank == 0)
{
int pid = getpid();
for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++)
{
std::string msgQueueName = "/tmp/" + it->second + std::to_string(m_hash) + "_" + std::to_string(pid);
ncclResult_t res = MsgQueueClose(msgQueueName, m_hash);
if (res != ncclSuccess)
{
WARN("Unable to close Message Queue: %s\n", msgQueueName.c_str());
}
int ret = unlink(msgQueueName.c_str());
if (ret != 0)
{
WARN("Unable to unlink %s\n", msgQueueName.c_str());
}
}
}
}
if (m_cliqueMode == CLIQUE_DISABLED) return;
// Free variables that are shared between SINGLE_PROCESS / SINGLE_NODE
@@ -131,7 +109,6 @@ void CliqueManager::CleanUp()
ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix)
{
ncclResult_t res;
if (m_init) return ncclSuccess;
m_init = true;
@@ -167,7 +144,6 @@ ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix)
return ncclSuccess;
}
std::string shmSuffix = std::to_string(m_hash) + "_" + std::to_string(suffix);
// Allocate sense barrier variable on local GPU
@@ -531,11 +507,10 @@ ncclResult_t CliqueManager::BootstrapRootInit(int pid, unsigned long hash)
{
for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++)
{
int msgid, fd;
std::string msgQueueName = "/tmp/" + it->second + std::to_string(hash) + "_" + std::to_string(pid);
SYSCHECKVAL(open(msgQueueName.c_str(), O_CREAT | O_RDWR, 0606), "open", fd);
NCCLCHECK(MsgQueueGetId(msgQueueName, hash, true, msgid));
SYSCHECK(close(fd), "close");
mqd_t mq_desc;
std::string msgQueueName = it->second + std::to_string(hash) + "_" + std::to_string(pid);
NCCLCHECK(MsgQueueGetId(msgQueueName, true, mq_desc));
NCCLCHECK(MsgQueueClose(msgQueueName, mq_desc, true));
}
std::string shmDir = "/dev/shm/";
+54 -24
Просмотреть файл
@@ -21,25 +21,31 @@ THE SOFTWARE.
*/
#include "MsgQueue.h"
#include <chrono>
#include <sys/ipc.h>
#include <sys/msg.h>
#define MSG_QUEUE_PERM S_IRUSR | S_IWUSR
#define MSG_QUEUE_MODE O_RDWR
#define MSG_SIZE 1
#define MSG_QUEUE_TIMEOUT 60
#define MSG_QUEUE_PERM 0666
/**
 * Opens the POSIX message queue "/<name>" for reading and writing, creating it if needed.
 *
 * @param name       Queue name without the leading '/'; the slash is prepended here.
 * @param exclusive  When true the queue is opened with O_CREAT | O_EXCL. If a queue with
 *                   that name already exists, it is unlinked and the open is retried once.
 *                   When false an existing queue is opened, or a new one is created.
 * @param mq_desc    [out] Receives the descriptor returned by mq_open.
 * @return ncclSuccess once a descriptor has been obtained, ncclSystemError otherwise.
 */
ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc)
{
  int const flag = (exclusive ? (O_CREAT | O_EXCL) : O_CREAT) | MSG_QUEUE_MODE;

  // Zero-initialize so no field of the attribute struct is left indeterminate.
  struct mq_attr attr = {};
  attr.mq_flags   = 0;
  attr.mq_maxmsg  = 10;
  attr.mq_msgsize = MSG_SIZE;

  std::string const mq_name = "/" + name;
  mq_desc = mq_open(mq_name.c_str(), flag, MSG_QUEUE_PERM, &attr);
  if (mq_desc != -1)
  {
    return ncclSuccess;
  }

  // mq_open reports EEXIST (not EBUSY) when O_CREAT | O_EXCL meets an existing name.
  if (exclusive && errno == EEXIST)
  {
    // Only the name has to be removed: the failed mq_open produced no descriptor,
    // so there is nothing to pass to mq_close.
    NCCLCHECK(MsgQueueUnlink(name));
    // mq_open expects a pointer to the attribute struct.
    SYSCHECKVAL(mq_open(mq_name.c_str(), flag, MSG_QUEUE_PERM, &attr), "mq_open", mq_desc);
    return ncclSuccess;
  }

  WARN("Call to MsgQueueGetId failed : %s", strerror(errno));
  return ncclSystemError;
}
/**
 * Enqueues one message on an open POSIX message queue.
 *
 * @param mq_desc  Descriptor of the queue to send on.
 * @param msgp     Start of the message bytes.
 * @param msgsz    Number of bytes to send.
 * @return ncclSuccess if control reaches the end; a failing mq_send is handled by SYSCHECK.
 */
ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz)
{
  // Every message is sent with the same (lowest) priority.
  unsigned int const priority = 0;
  SYSCHECK(mq_send(mq_desc, msgp, msgsz, priority), "mq_send");
  return ncclSuccess;
}
/**
 * Dequeues one message from an open POSIX message queue.
 *
 * @param mq_desc  Descriptor of the queue to receive from.
 * @param msgp     Buffer that receives the message bytes.
 * @param msgsz    Capacity of the buffer in bytes.
 * @return ncclSuccess if control reaches the end; a failing mq_receive is handled by SYSCHECK.
 */
ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz)
{
  // The message priority is not needed, so no output slot is supplied for it.
  unsigned int* const unusedPriority = nullptr;
  SYSCHECK(mq_receive(mq_desc, msgp, msgsz, unusedPriority), "mq_receive");
  return ncclSuccess;
}
/**
 * Polls a message queue until it holds no messages or a timeout expires.
 *
 * The queue attributes are re-read in a loop without sleeping. Every mq_getattr call
 * is checked, so the message count is never read from an unfilled attribute struct.
 *
 * @param mq_desc  Descriptor of the queue to watch.
 * @return ncclSuccess once the queue is observed empty; ncclSystemError if it still
 *         holds messages after MSG_QUEUE_TIMEOUT seconds.
 */
ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc)
{
  auto const deadline = std::chrono::steady_clock::now() + std::chrono::seconds(MSG_QUEUE_TIMEOUT);

  for (;;)
  {
    mq_attr attr;
    SYSCHECK(mq_getattr(mq_desc, &attr), "mq_getattr");

    // Test for "drained" before the deadline so an empty queue is never reported as a timeout.
    if (attr.mq_curmsgs <= 0)
    {
      return ncclSuccess;
    }

    if (std::chrono::steady_clock::now() > deadline)
    {
      WARN("Message Queue timed out waiting for all ranks to receive messages.");
      return ncclSystemError;
    }
  }
}
/**
 * Closes a message queue descriptor, optionally removing the queue name first.
 *
 * @param name     Queue name without the leading '/' (only used when unlink is true).
 * @param mq_desc  Descriptor to close.
 * @param unlink   When true, MsgQueueUnlink(name) is called before the descriptor is closed.
 * @return The result of the unlink step if it failed; otherwise ncclSuccess if control
 *         reaches the end (a failing mq_close is handled by SYSCHECK).
 */
ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink)
{
  // Remember the unlink outcome instead of returning early: the descriptor must be
  // closed even when the name could not be removed, otherwise it leaks.
  ncclResult_t const unlinkResult = unlink ? MsgQueueUnlink(name) : ncclSuccess;

  SYSCHECK(mq_close(mq_desc), "mq_close");
  return unlinkResult;
}
/**
 * Removes the POSIX message queue name "/<name>".
 *
 * @param name  Queue name without the leading '/'.
 * @return ncclSuccess if control reaches the end; a failing mq_unlink is handled by SYSCHECK.
 */
ncclResult_t MsgQueueUnlink(std::string name)
{
  // name is a by-value copy, so the leading slash can be added in place.
  name.insert(0, "/");
  SYSCHECK(mq_unlink(name.c_str()), "mq_unlink");
  return ncclSuccess;
}
+7 -10
Просмотреть файл
@@ -24,19 +24,16 @@ THE SOFTWARE.
#define RCCL_MSG_QUEUE_HPP_
#include <string>
#include <mqueue.h>
#include "nccl.h"
#include "core.h"
struct MsgBuffer
{
long msg_type;
char msg_text[1];
};
ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& msgid);
ncclResult_t MsgQueueSend(int msgid, const void* msgp, size_t msgsz, int msgflg);
ncclResult_t MsgQueueRecv(int msgid, void* msgp, size_t msgsz, long msgtyp, bool wait);
ncclResult_t MsgQueueClose(std::string name, int projid);
// Opens the POSIX message queue "/" + name for read/write, creating it if needed;
// exclusive adds O_EXCL to the open flags. The descriptor is returned through mq_desc.
ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc);
// Sends msgsz bytes starting at msgp on mq_desc (priority 0).
ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz);
// Receives one message from mq_desc into msgp, a buffer of msgsz bytes; the priority is discarded.
ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz);
// Polls mq_desc until its current message count is zero; returns ncclSystemError if the
// timeout configured in the implementation file elapses first.
ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc);
// Closes mq_desc; when unlink is true, the queue name is removed via MsgQueueUnlink first.
ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink);
// Removes the queue name "/" + name with mq_unlink.
ncclResult_t MsgQueueUnlink(std::string name);
#endif
+33 -13
Просмотреть файл
@@ -92,19 +92,32 @@ ShmObject(size_t size, std::string fileName, int rank, int numRanks, int projid)
return m_shmPtr;
}
protected:
// Puts one single-byte status message on the queue for every rank other than this one:
// 'P' when pass is true, 'F' when it is false.
ncclResult_t BroadcastMessage(mqd_t& mq_desc, bool pass)
{
  char const status = pass ? 'P' : 'F';

  for (int peer = 0; peer < m_numRanks; ++peer)
  {
    // This rank does not send a message to itself.
    if (peer != m_rank)
    {
      NCCLCHECK(MsgQueueSend(mq_desc, &status, sizeof(status)));
    }
  }
  return ncclSuccess;
}
// Sends the pass/fail status to the other ranks, waits for the queue to drain, then
// closes the descriptor and unlinks the queue (MsgQueueClose is called with unlink = true
// on both the success and the failure path).
// Returns ncclSuccess when all three steps succeed; on a broadcast or drain failure it
// warns, still attempts the close/unlink, and returns ncclSystemError.
ncclResult_t BroadcastAndCloseMessageQueue(mqd_t& mq_desc, bool pass)
{
ncclResult_t res;
// NOTE(review): NCCLCHECKGOTO is defined elsewhere; presumably it stores the call's
// result in res and jumps to dropback when the call fails — confirm against its definition.
NCCLCHECKGOTO(BroadcastMessage(mq_desc, pass), res, dropback);
NCCLCHECKGOTO(MsgQueueWaitUntilEmpty(mq_desc), res, dropback);
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, true));
return ncclSuccess;
dropback:
WARN("Root rank unable to broadcast across message queue. Closing message queue.");
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, true));
// The value held in res is not returned; the failure is always reported as ncclSystemError.
return ncclSystemError;
}
// tag for dispatch
template<class U>
struct OpenTag{};
@@ -127,33 +140,35 @@ protected:
template <typename T>
ncclResult_t ShmObject<T>::Open()
{
mqd_t mq_desc;
if (m_alloc == false)
{
int shmFd;
int protection = PROT_READ | PROT_WRITE;
int visibility = MAP_SHARED;
int msgid;
std::string tmpFileName = "/tmp/" + m_shmName;
NCCLCHECK(MsgQueueGetId(tmpFileName, m_projid, false, msgid));
INFO(NCCL_INIT, "Rank %d Initializing message queue for %s\n", m_rank, m_shmName.c_str());
NCCLCHECK(MsgQueueGetId(m_shmName, false, mq_desc));
if (m_rank == 0)
{
ncclResult_t resultSetup = shmSetupExclusive(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 1);
ncclResult_t resultSemInit = InitIfSemaphore(OpenTag<T>{});
if ((resultSetup != ncclSuccess && errno != EEXIST) || (resultSemInit != ncclSuccess))
{
NCCLCHECK(BroadcastMessage(msgid, false));
NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, false));
WARN("Call to ShmObject::Open in root rank failed : %s", strerror(errno));
return ncclSystemError;
}
NCCLCHECK(BroadcastMessage(msgid, true));
NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, true));
}
else
{
MsgBuffer msg;
NCCLCHECK(MsgQueueRecv(msgid, &msg, sizeof(msg), m_rank, true));
if (msg.msg_text[0] == 'P')
char msg_text[1];
ncclResult_t res;
NCCLCHECKGOTO(MsgQueueRecv(mq_desc, &msg_text[0], sizeof(msg_text)), res, dropback);
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false));
if (msg_text[0] == 'P')
{
NCCLCHECK(shmSetup(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 0));
}
@@ -171,6 +186,11 @@ ncclResult_t ShmObject<T>::Open()
return ncclInvalidUsage;
}
return ncclSuccess;
dropback:
WARN("Rank %d unable to receive message from root. Closing message queue.", m_rank);
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false));
return ncclSystemError;
}
template<typename T>