Message queue refactor to POSIX implementation and leak fix (#355)
* Fixing message queue leak.
* Using POSIX implementation of Message Queues
* Adding unlink to msgqueue
* MsgQueue update
* Adding timeout check to msgqueue broadcast; tightening up system checks
* Removing unnecessary code
* Removing extra argument from print
* Adding explicit msg queue close call to all other ranks
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
415c7cd3d1
Коммит
70597789d0
@@ -72,28 +72,6 @@ CliqueManager::~CliqueManager()
|
||||
|
||||
void CliqueManager::CleanUp()
|
||||
{
|
||||
if (rcclParamEnableClique())
|
||||
{
|
||||
if (m_rank == 0)
|
||||
{
|
||||
int pid = getpid();
|
||||
for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++)
|
||||
{
|
||||
std::string msgQueueName = "/tmp/" + it->second + std::to_string(m_hash) + "_" + std::to_string(pid);
|
||||
ncclResult_t res = MsgQueueClose(msgQueueName, m_hash);
|
||||
if (res != ncclSuccess)
|
||||
{
|
||||
WARN("Unable to close Message Queue: %s\n", msgQueueName.c_str());
|
||||
}
|
||||
int ret = unlink(msgQueueName.c_str());
|
||||
if (ret != 0)
|
||||
{
|
||||
WARN("Unable to unlink %s\n", msgQueueName.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (m_cliqueMode == CLIQUE_DISABLED) return;
|
||||
|
||||
// Free variables that are shared between SINGLE_PROCESS / SINGLE_NODE
|
||||
@@ -131,7 +109,6 @@ void CliqueManager::CleanUp()
|
||||
ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix)
|
||||
{
|
||||
ncclResult_t res;
|
||||
|
||||
if (m_init) return ncclSuccess;
|
||||
m_init = true;
|
||||
|
||||
@@ -167,7 +144,6 @@ ncclResult_t CliqueManager::Init(ncclUniqueId const* commId, int suffix)
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
|
||||
std::string shmSuffix = std::to_string(m_hash) + "_" + std::to_string(suffix);
|
||||
|
||||
// Allocate sense barrier variable on local GPU
|
||||
@@ -531,11 +507,10 @@ ncclResult_t CliqueManager::BootstrapRootInit(int pid, unsigned long hash)
|
||||
{
|
||||
for (auto it = CliqueShmNames.begin(); it != CliqueShmNames.end(); it++)
|
||||
{
|
||||
int msgid, fd;
|
||||
std::string msgQueueName = "/tmp/" + it->second + std::to_string(hash) + "_" + std::to_string(pid);
|
||||
SYSCHECKVAL(open(msgQueueName.c_str(), O_CREAT | O_RDWR, 0606), "open", fd);
|
||||
NCCLCHECK(MsgQueueGetId(msgQueueName, hash, true, msgid));
|
||||
SYSCHECK(close(fd), "close");
|
||||
mqd_t mq_desc;
|
||||
std::string msgQueueName = it->second + std::to_string(hash) + "_" + std::to_string(pid);
|
||||
NCCLCHECK(MsgQueueGetId(msgQueueName, true, mq_desc));
|
||||
NCCLCHECK(MsgQueueClose(msgQueueName, mq_desc, true));
|
||||
}
|
||||
|
||||
std::string shmDir = "/dev/shm/";
|
||||
|
||||
@@ -21,25 +21,31 @@ THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "MsgQueue.h"
|
||||
#include <chrono>
|
||||
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
// Permission bits requested when a queue is created: owner read + owner write.
// Parenthesized so the expansion stays a single operand next to operators that
// bind tighter than '|' (e.g. '&' or '==').
#define MSG_QUEUE_PERM (S_IRUSR | S_IWUSR)
// Access mode OR-ed into the mq_open flags: the queue is opened read/write.
#define MSG_QUEUE_MODE O_RDWR
// Size in bytes of one message (used as mq_msgsize): a single status character.
#define MSG_SIZE 1
// Number of seconds MsgQueueWaitUntilEmpty waits for the queue to drain.
#define MSG_QUEUE_TIMEOUT 60
|
||||
|
||||
#define MSG_QUEUE_PERM 0666
|
||||
|
||||
ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& msgid)
|
||||
/**
 * Opens the POSIX message queue named "/" + name, creating it if necessary.
 *
 * @param name      Queue name without the leading '/'.
 * @param exclusive When true the queue is requested with O_CREAT | O_EXCL; if a
 *                  queue with that name already exists it is unlinked and the
 *                  open is retried once. When false an existing queue is reused.
 * @param mq_desc   [out] Receives the descriptor returned by mq_open.
 * @return ncclSuccess when a descriptor was obtained; ncclSystemError when
 *         mq_open fails (other failures are propagated from NCCLCHECK / SYSCHECKVAL).
 */
ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc)
{
  int const openFlags = (exclusive ? (O_CREAT | O_EXCL) : O_CREAT) | MSG_QUEUE_MODE;

  // Zero the whole structure first so no field is handed to mq_open uninitialized.
  struct mq_attr attr = {};
  attr.mq_flags   = 0;
  attr.mq_maxmsg  = 10;
  attr.mq_msgsize = MSG_SIZE;

  std::string const mq_name = "/" + name;
  mq_desc = mq_open(mq_name.c_str(), openFlags, MSG_QUEUE_PERM, &attr);

  // Check if we're trying to create message queue and it already exists; if so, delete existing queue.
  // mq_open reports an existing name under O_CREAT | O_EXCL with EEXIST.
  if (mq_desc == static_cast<mqd_t>(-1) && exclusive && errno == EEXIST)
  {
    // There is no valid descriptor to close at this point, so only the name is removed.
    NCCLCHECK(MsgQueueUnlink(name));
    // mq_open is variadic: the attributes must be passed by address, exactly as in the first call.
    SYSCHECKVAL(mq_open(mq_name.c_str(), openFlags, MSG_QUEUE_PERM, &attr), "mq_open", mq_desc);
  }
  else if (mq_desc == static_cast<mqd_t>(-1))
  {
    WARN("Call to MsgQueueGetId failed : %s", strerror(errno));
    return ncclSystemError;
  }
  return ncclSuccess;
}
|
||||
|
||||
// Sends the msgsz bytes starting at msgp to a message queue and returns
// ncclSuccess once the send call has been issued through SYSCHECK.
// NOTE(review): this span carries two signatures and two send calls (System V
// msgsnd and POSIX mq_send); msgid / msgflg are declared only by the first
// signature, mq_desc only by the second.
ncclResult_t MsgQueueSend(int msgid, const void* msgp, size_t msgsz, int msgflg)
|
||||
ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz)
|
||||
{
|
||||
SYSCHECK(msgsnd(msgid, msgp, msgsz, msgflg), "msgsnd");
|
||||
// Every message is queued with priority 0 (last argument of mq_send).
SYSCHECK(mq_send(mq_desc, msgp, msgsz, 0), "mq_send");
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Receives one message from a message queue into the buffer msgp (capacity
// msgsz bytes) and returns ncclSuccess once the receive call has been issued
// through SYSCHECK.
// NOTE(review): this span carries two signatures and two receive calls (System V
// msgrcv and POSIX mq_receive); msgid / msgtyp / wait are declared only by the
// first signature, mq_desc only by the second.
ncclResult_t MsgQueueRecv(int msgid, void* msgp, size_t msgsz, long msgtyp, bool wait)
|
||||
ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz)
|
||||
{
|
||||
// wait == false selects a non-blocking receive (IPC_NOWAIT); otherwise no flag is set.
int msgflg = (wait == false ? IPC_NOWAIT : 0);
|
||||
SYSCHECK(msgrcv(msgid, msgp, msgsz, msgtyp, msgflg), "msgrcv");
|
||||
// The priority of the received message is not requested (NULL).
// Per POSIX, msgsz must be at least the queue's mq_msgsize or mq_receive fails.
SYSCHECK(mq_receive(mq_desc, msgp, msgsz, NULL), "mq_receive");
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t MsgQueueClose(std::string name, int projid)
|
||||
/**
 * Polls the queue referred to by mq_desc until it holds no messages.
 *
 * @param mq_desc Descriptor of an open POSIX message queue.
 * @return ncclSuccess as soon as the queue is observed empty; ncclSystemError
 *         if it still holds messages after MSG_QUEUE_TIMEOUT seconds. A failing
 *         mq_getattr is reported through SYSCHECK.
 */
ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc)
{
  auto const deadline = std::chrono::steady_clock::now() + std::chrono::seconds(MSG_QUEUE_TIMEOUT);

  // NOTE: this is a busy poll; there is no sleep between queries.
  for (;;)
  {
    // Every query is checked, so mq_curmsgs is never read from an unfilled structure.
    mq_attr attr;
    SYSCHECK(mq_getattr(mq_desc, &attr), "mq_getattr");

    // Test for "empty" before the deadline so a drained queue is never reported as a timeout.
    if (attr.mq_curmsgs <= 0) return ncclSuccess;

    if (std::chrono::steady_clock::now() > deadline)
    {
      WARN("Message Queue timed out waiting for all ranks to receive messages.");
      return ncclSystemError;
    }
  }
}
|
||||
|
||||
/**
 * Closes a message queue descriptor, optionally removing the queue name first.
 *
 * @param name    Queue name without the leading '/'; only used when unlink is true.
 * @param mq_desc Descriptor to close.
 * @param unlink  When true, MsgQueueUnlink(name) is called before the descriptor is closed.
 * @return ncclSuccess when every requested step succeeded; otherwise the error
 *         from mq_close (via SYSCHECK) or, failing that, from the unlink step.
 */
ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink)
{
  // Remember the unlink outcome instead of returning early: the descriptor must
  // still be closed when removing the name fails, otherwise it leaks.
  ncclResult_t unlinkResult = ncclSuccess;
  if (unlink)
  {
    unlinkResult = MsgQueueUnlink(name);
  }

  SYSCHECK(mq_close(mq_desc), "mq_close");
  return unlinkResult;
}
|
||||
|
||||
/**
 * Removes the POSIX message queue named "/" + name.
 *
 * @param name Queue name without the leading '/'.
 * @return ncclSuccess once mq_unlink has been issued through SYSCHECK.
 */
ncclResult_t MsgQueueUnlink(std::string name)
{
  // name is a by-value copy, so the '/' prefix can be added to it in place.
  name.insert(0, "/");
  SYSCHECK(mq_unlink(name.c_str()), "mq_unlink");
  return ncclSuccess;
}
|
||||
@@ -24,19 +24,16 @@ THE SOFTWARE.
|
||||
#define RCCL_MSG_QUEUE_HPP_
|
||||
|
||||
#include <string>
|
||||
#include <mqueue.h>
|
||||
|
||||
#include "nccl.h"
|
||||
#include "core.h"
|
||||
|
||||
// Message record exchanged through the System V queue calls (msgsnd / msgrcv)
// that appear elsewhere in this diff.
struct MsgBuffer
|
||||
{
|
||||
// Message type; BroadcastMessage sets it to the destination rank.
long msg_type;
|
||||
// One-character payload: 'P' or 'F', as written by BroadcastMessage.
char msg_text[1];
|
||||
};
|
||||
|
||||
ncclResult_t MsgQueueGetId(std::string name, int projid, bool exclusive, int& msgid);
|
||||
ncclResult_t MsgQueueSend(int msgid, const void* msgp, size_t msgsz, int msgflg);
|
||||
ncclResult_t MsgQueueRecv(int msgid, void* msgp, size_t msgsz, long msgtyp, bool wait);
|
||||
ncclResult_t MsgQueueClose(std::string name, int projid);
|
||||
// Opens the POSIX queue "/" + name (creating it if needed) and stores the
// descriptor in mq_desc; exclusive == true requests O_CREAT | O_EXCL.
ncclResult_t MsgQueueGetId(std::string name, bool exclusive, mqd_t& mq_desc);
|
||||
// Sends msgsz bytes from msgp to the queue with priority 0.
ncclResult_t MsgQueueSend(mqd_t const& mq_desc, const char* msgp, size_t msgsz);
|
||||
// Receives one message from the queue into msgp (capacity msgsz bytes).
ncclResult_t MsgQueueRecv(mqd_t const& mq_desc, char* msgp, size_t msgsz);
|
||||
// Polls the queue until it holds no messages; gives up after MSG_QUEUE_TIMEOUT seconds.
ncclResult_t MsgQueueWaitUntilEmpty(mqd_t const& mq_desc);
|
||||
// Closes mq_desc; when unlink is true the queue name is removed as well.
ncclResult_t MsgQueueClose(std::string name, mqd_t& mq_desc, bool unlink);
|
||||
// Removes the queue named "/" + name.
ncclResult_t MsgQueueUnlink(std::string name);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -92,19 +92,32 @@ ShmObject(size_t size, std::string fileName, int rank, int numRanks, int projid)
|
||||
return m_shmPtr;
|
||||
}
|
||||
protected:
|
||||
// Sends a one-character status message once for every rank other than m_rank:
// 'P' when pass is true, 'F' otherwise. Returns ncclSuccess after the loop;
// a failing send is reported through NCCLCHECK.
// NOTE(review): this span carries two signatures and two variants of the body
// (a MsgBuffer sent via msgid, and a plain char buffer sent via mq_desc).
ncclResult_t BroadcastMessage(int msgid, bool pass)
|
||||
ncclResult_t BroadcastMessage(mqd_t& mq_desc, bool pass)
|
||||
{
|
||||
MsgBuffer msg;
|
||||
msg.msg_text[0] = (pass == 0 ? 'F': 'P');
|
||||
char msg_text[1];
|
||||
msg_text[0] = (pass == 0 ? 'F': 'P');
|
||||
for (int rank = 0; rank < m_numRanks; rank++)
|
||||
{
|
||||
// The sending rank does not queue a message for itself.
if (rank == m_rank) continue;
|
||||
// MsgBuffer variant: the message type carries the destination rank.
msg.msg_type = rank;
|
||||
NCCLCHECK(MsgQueueSend(msgid, &msg, sizeof(msg), 0));
|
||||
// char-buffer variant: the same one-byte message is queued each iteration.
NCCLCHECK(MsgQueueSend(mq_desc, &msg_text[0], sizeof(msg_text)));
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Broadcasts the pass/fail status to the other ranks, waits for the queue to
// drain, then closes and unlinks the queue named m_shmName.
// The queue is closed and unlinked whether or not the broadcast succeeded.
// Returns ncclSuccess when both steps and the close succeed, ncclSystemError
// when the broadcast or the wait fails; a failing close is reported through NCCLCHECK.
ncclResult_t BroadcastAndCloseMessageQueue(mqd_t& mq_desc, bool pass)
{
  ncclResult_t status;
  bool broadcastFailed = true;

  NCCLCHECKGOTO(BroadcastMessage(mq_desc, pass), status, finish);
  NCCLCHECKGOTO(MsgQueueWaitUntilEmpty(mq_desc), status, finish);
  broadcastFailed = false;

finish:
  if (broadcastFailed)
  {
    WARN("Root rank unable to broadcast across message queue. Closing message queue.");
  }
  // Single shared close for both the success and the failure path.
  NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, true));
  return broadcastFailed ? ncclSystemError : ncclSuccess;
}
|
||||
|
||||
// tag for dispatch
|
||||
template<class U>
|
||||
struct OpenTag{};
|
||||
@@ -127,33 +140,35 @@ protected:
|
||||
template <typename T>
|
||||
ncclResult_t ShmObject<T>::Open()
|
||||
{
|
||||
mqd_t mq_desc;
|
||||
if (m_alloc == false)
|
||||
{
|
||||
int shmFd;
|
||||
int protection = PROT_READ | PROT_WRITE;
|
||||
int visibility = MAP_SHARED;
|
||||
|
||||
int msgid;
|
||||
std::string tmpFileName = "/tmp/" + m_shmName;
|
||||
NCCLCHECK(MsgQueueGetId(tmpFileName, m_projid, false, msgid));
|
||||
INFO(NCCL_INIT, "Rank %d Initializing message queue for %s\n", m_rank, m_shmName.c_str());
|
||||
|
||||
NCCLCHECK(MsgQueueGetId(m_shmName, false, mq_desc));
|
||||
if (m_rank == 0)
|
||||
{
|
||||
ncclResult_t resultSetup = shmSetupExclusive(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 1);
|
||||
ncclResult_t resultSemInit = InitIfSemaphore(OpenTag<T>{});
|
||||
if ((resultSetup != ncclSuccess && errno != EEXIST) || (resultSemInit != ncclSuccess))
|
||||
{
|
||||
NCCLCHECK(BroadcastMessage(msgid, false));
|
||||
NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, false));
|
||||
WARN("Call to ShmObject::Open in root rank failed : %s", strerror(errno));
|
||||
return ncclSystemError;
|
||||
}
|
||||
NCCLCHECK(BroadcastMessage(msgid, true));
|
||||
NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, true));
|
||||
}
|
||||
else
|
||||
{
|
||||
MsgBuffer msg;
|
||||
NCCLCHECK(MsgQueueRecv(msgid, &msg, sizeof(msg), m_rank, true));
|
||||
if (msg.msg_text[0] == 'P')
|
||||
char msg_text[1];
|
||||
ncclResult_t res;
|
||||
NCCLCHECKGOTO(MsgQueueRecv(mq_desc, &msg_text[0], sizeof(msg_text)), res, dropback);
|
||||
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false));
|
||||
if (msg_text[0] == 'P')
|
||||
{
|
||||
NCCLCHECK(shmSetup(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 0));
|
||||
}
|
||||
@@ -171,6 +186,11 @@ ncclResult_t ShmObject<T>::Open()
|
||||
return ncclInvalidUsage;
|
||||
}
|
||||
return ncclSuccess;
|
||||
|
||||
dropback:
|
||||
WARN("Rank %d unable to receive message from root. Closing message queue.", m_rank);
|
||||
NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false));
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
||||
Ссылка в новой задаче
Block a user