Improved the determination of MPI rank (#61)

* Improved the determination of MPI rank

* C-style cast of MPI_Comm
Este commit está contenido en:
Jonathan R. Madsen
2022-06-21 00:27:52 -05:00
cometido por GitHub
padre 5583168dbc
commit dfda902092
Se han modificado 4 ficheros con 319 adiciones y 102 borrados
+150 -29
Ver fichero
@@ -42,8 +42,9 @@ std::string _name = {};
template <typename Tp, size_t N>
void
all2all(int _rank)
all2all(int _rank, MPI_Comm _comm)
{
if(_comm == MPI_COMM_NULL) return;
static_assert(N > 0, "Error! N must be greater than zero!");
auto _mt = std::mt19937_64{ size_t(_rank + 100) };
@@ -74,7 +75,7 @@ all2all(int _rank)
printf("[%s][%i] values sent (# = %zu) :: %s.\n", _name.c_str(), _rank,
values_sent.size(), _get_values_str(values_sent).c_str());
auto _dtype = MPI_INT;
auto _dtype = MPI_INT; // NOLINT
if(std::is_same<Tp, long>::value)
_dtype = MPI_LONG;
else if(std::is_same<Tp, float>::value)
@@ -82,57 +83,177 @@ all2all(int _rank)
else if(std::is_same<Tp, double>::value)
_dtype = MPI_DOUBLE;
MPI_Alltoall(&values_sent[_rank], 1, _dtype, &values_recv[_rank], 1, _dtype,
MPI_COMM_WORLD);
MPI_Alltoall(&values_sent[_rank], 1, _dtype, &values_recv[_rank], 1, _dtype, _comm);
if(_rank == 0)
printf("[%s][%i] values recv (# = %zu) :: %s.\n", _name.c_str(), _rank,
values_sent.size(), _get_values_str(values_recv).c_str());
}
void
run(MPI_Comm _comm, int nitr)
{
if(_comm == MPI_COMM_NULL) return;
int _rank = 0;
int _size = 0;
MPI_Comm_rank(_comm, &_rank);
MPI_Comm_size(_comm, &_size);
printf("[%s][%i] running %i iterations on %i ranks...\n", _name.c_str(), _rank, nitr,
_size);
MPI_Barrier(_comm);
for(int i = 0; i < nitr; ++i)
{
all2all<int, 3>(_rank, _comm);
all2all<long, 4>(_rank, _comm);
MPI_Barrier(_comm);
all2all<float, 5>(_rank, _comm);
all2all<double, 6>(_rank, _comm);
}
MPI_Barrier(_comm);
}
void
print_info(MPI_Comm _comm, bool _verbose, std::string _msg = {})
{
if(_comm == MPI_COMM_NULL) return;
int _rank = 0;
int _size = 1;
MPI_Comm_rank(_comm, &_rank);
MPI_Comm_size(_comm, &_size);
if(!_msg.empty()) _msg = "[" + _msg + "] ";
if(_verbose)
{
auto _ppid = getppid();
std::ifstream _ifs{ "/proc/" + std::to_string(_ppid) + "/task/" +
std::to_string(_ppid) + "/children" };
std::stringstream _ss{};
while(_ifs)
{
std::string _s{};
_ifs >> _s;
_ss << _s << " ";
}
if(_rank == 0)
printf("[%s]%s RANK = %i (out of %i), PID = %i, PPID = %i :: %s\n",
_name.c_str(), _msg.c_str(), _rank, _size, getpid(), getppid(),
_ss.str().c_str());
}
else
{
if(_rank == 0)
printf("[%s]%s RANK = %i (out of %i), PID = %i, PPID = %i\n", _name.c_str(),
_msg.c_str(), _rank, _size, getpid(), getppid());
}
}
int
main(int argc, char** argv)
{
int _mpi_thread_provided;
MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &_mpi_thread_provided);
int rank = 0;
int size = 1;
int nitr = 1;
if(argc > 1) nitr = atoi(argv[2]);
MPI_Comm_size(MPI_COMM_WORLD, &size);
_name = argv[0];
auto _pos = _name.find_last_of('/');
if(_pos < _name.length()) _name = _name.substr(_pos + 1);
printf("[%s] Number of iterations: %i\n", _name.c_str(), nitr);
int _mpi_thread_provided;
MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &_mpi_thread_provided);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
printf("[%s][%i] running with MPI_COMM_WORLD...\n", _name.c_str(), getpid());
run(MPI_COMM_WORLD, nitr);
auto _ppid = getppid();
std::ifstream _ifs{ "/proc/" + std::to_string(_ppid) + "/task/" +
std::to_string(_ppid) + "/children" };
std::stringstream _ss{};
while(_ifs)
{
std::string _s{};
_ifs >> _s;
_ss << _s << " ";
}
printf("[%s] RANK = %i, PID = %i, PPID = %i :: %s\n", _name.c_str(), rank, getpid(),
getppid(), _ss.str().c_str());
print_info(MPI_COMM_WORLD, true, "MPI_COMM_WORLD");
MPI_Barrier(MPI_COMM_WORLD);
for(int i = 0; i < nitr; ++i)
printf("[%s]\n", _name.c_str());
if(size > 1)
{
all2all<int, 3>(rank);
all2all<long, 4>(rank);
MPI_Barrier(MPI_COMM_WORLD);
all2all<float, 5>(rank);
all2all<double, 6>(rank);
MPI_Comm dup;
printf("[%s][%i] Duplicating MPI_COMM_WORLD...\n", _name.c_str(), getpid());
MPI_Comm_dup(MPI_COMM_WORLD, &dup);
printf("[%s][%i] running with duplicated comm of MPI_COMM_WORLD...\n",
_name.c_str(), getpid());
run(dup, nitr);
MPI_Comm_rank(dup, &rank);
if(rank == 0) printf("[%s]\n", _name.c_str());
printf("[%s][%i] RANK = %i on duplicated MPI_COMM_WORLD...\n", _name.c_str(),
getpid(), rank);
if(size > 3)
{
std::vector<MPI_Comm> comms(3);
for(int i = 0; i < size; ++i)
{
auto _idx = i % 3;
printf("[%s][%i] Splitting duplicated MPI_COMM_WORLD %i (rank = %i)...\n",
_name.c_str(), getpid(), _idx, rank);
MPI_Comm* comm = &comms.at(_idx);
MPI_Comm_split(dup, _idx, rank, comm);
}
for(auto itr : comms) // NOLINT
MPI_Barrier(itr);
for(int i = 0; i < size; ++i)
{
auto _idx = i % 3;
int _rank = 0;
MPI_Comm_rank(comms.at(_idx), &_rank);
printf("[%s][%i] Running on split communicator %i (rank = %i)...\n",
_name.c_str(), getpid(), _idx, _rank);
run(comms.at(_idx), nitr);
}
// Get the group of processes in MPI_COMM_WORLD
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int n = 0;
const int ranks[7] = { 1, 2, 3, 5, 7, 11, 13 };
for(int rank : ranks)
if(rank < size) ++n;
// Construct a group containing all of the prime ranks in world_group
MPI_Group prime_group;
MPI_Group_incl(world_group, n, ranks, &prime_group);
// Create a new communicator based on the group
MPI_Comm prime_comm;
MPI_Comm_create_group(MPI_COMM_WORLD, prime_group, 0, &prime_comm);
MPI_Group nonprime_group;
MPI_Group_difference(world_group, prime_group, &nonprime_group);
MPI_Comm nonprime_comm;
MPI_Comm_create_group(MPI_COMM_WORLD, nonprime_group, 1, &nonprime_comm);
print_info(prime_comm, false, "Prime comm");
print_info(nonprime_comm, false, "Non-prime comm");
run(prime_comm, nitr);
run(nonprime_comm, nitr);
MPI_Group_free(&world_group);
MPI_Group_free(&prime_group);
MPI_Group_free(&nonprime_group);
}
print_info(dup, false);
}
MPI_Barrier(MPI_COMM_WORLD);
MPI_Finalize();
return 0;
}
@@ -27,16 +27,71 @@
#include "library/debug.hpp"
#include "library/mproc.hpp"
#include <thread>
#include <timemory/backends/mpi.hpp>
#include <timemory/backends/process.hpp>
#include <timemory/utility/locking.hpp>
#include <cstdint>
#include <limits>
#include <thread>
#include <unistd.h>
namespace omnitrace
{
namespace
{
uint64_t mpip_index = std::numeric_limits<uint64_t>::max();
std::string mpi_init_string = {};
struct comm_rank_data
{
int rank = -1;
int size = -1;
uintptr_t comm = mpi_gotcha::null_comm();
auto updated() const
{
return comm != mpi_gotcha::null_comm() && rank >= 0 && size > 0;
};
friend bool operator==(const comm_rank_data& _lhs, const comm_rank_data& _rhs)
{
auto _lupd = _lhs.updated();
auto _rupd = _rhs.updated();
return std::tie(_lupd, _lhs.rank, _lhs.size, _lhs.comm) ==
std::tie(_rupd, _rhs.rank, _rhs.size, _rhs.comm);
}
friend bool operator!=(const comm_rank_data& _lhs, const comm_rank_data& _rhs)
{
return !(_lhs == _rhs);
}
friend bool operator>(const comm_rank_data& _lhs, const comm_rank_data& _rhs)
{
OMNITRACE_CI_THROW(!_lhs.updated() && !_rhs.updated(),
"Error! comparing rank data that is not updated");
if(_lhs.updated() && !_rhs.updated()) return true;
if(!_lhs.updated() && _rhs.updated()) return false;
if(_lhs.size != _rhs.size) return _lhs.size > _rhs.size;
if(_lhs.rank != _rhs.rank) return _lhs.rank > _rhs.rank;
// lesser comm is greater
return _lhs.comm < _rhs.comm;
}
friend bool operator<(const comm_rank_data& _lhs, const comm_rank_data& _rhs)
{
return (_lhs != _rhs && !(_lhs > _rhs));
}
};
uint64_t mpip_index = std::numeric_limits<uint64_t>::max();
auto last_comm_record = comm_rank_data{};
auto mproc_comm_record = comm_rank_data{};
auto mpi_comm_records = std::map<uintptr_t, comm_rank_data>{};
using tim::auto_lock_t;
using tim::type_mutex;
// this ensures omnitrace_finalize is called before MPI_Finalize
void
@@ -80,6 +135,50 @@ mpi_gotcha::configure()
};
}
void
mpi_gotcha::stop()
{
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] Stopping MPI gotcha...\n", process::get_id());
update();
}
void
mpi_gotcha::update()
{
auto_lock_t _lk{ type_mutex<mpi_gotcha>(), std::defer_lock };
if(!_lk.owns_lock()) _lk.lock();
comm_rank_data _rank_data = mproc_comm_record;
for(const auto& itr : mpi_comm_records)
{
// skip null comms
if(itr.first == null_comm()) continue;
// if currently have null comm, replace
else if(_rank_data.comm == null_comm())
_rank_data = itr.second;
// if
else if(itr.second > _rank_data)
_rank_data = itr.second;
}
if(_rank_data.updated() && _rank_data != last_comm_record)
{
auto _rank = _rank_data.rank;
auto _size = _rank_data.size;
tim::mpi::set_rank(_rank);
tim::mpi::set_size(_size);
tim::settings::default_process_suffix() = _rank;
get_perfetto_output_filename().clear();
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n", process::get_id(),
tim::mpi::rank(), _rank);
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n", process::get_id(),
tim::mpi::size(), _size);
last_comm_record = _rank_data;
}
}
void
mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, int*, char***)
{
@@ -126,17 +225,19 @@ mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming)
}
void
mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, comm_t, int* _val)
mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, comm_t _comm, int* _val)
{
OMNITRACE_BASIC_DEBUG_F("%s()\n", _data.tool_id.c_str());
omnitrace_push_trace_hidden(_data.tool_id.c_str());
if(_data.tool_id == "MPI_Comm_rank")
{
m_comm_val = (uintptr_t) _comm; // NOLINT
m_rank_ptr = _val;
}
else if(_data.tool_id == "MPI_Comm_size")
{
m_comm_val = (uintptr_t) _comm; // NOLINT
m_size_ptr = _val;
}
else
@@ -172,71 +273,51 @@ mpi_gotcha::audit(const gotcha_data_t& _data, audit::outgoing, int _retval)
api::omnitrace>();
}
auto _size = mproc::get_concurrent_processes().size();
if(_size > 0)
auto_lock_t _lk{ type_mutex<mpi_gotcha>() };
if(!mproc_comm_record.updated())
{
m_size = _size;
tim::mpi::set_size(_size);
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n", process::get_id(),
tim::mpi::size(), m_size);
auto _rank = mproc::get_process_index();
if(_rank >= 0)
auto _pid = getpid();
auto _ppid = getppid();
auto _size = mproc::get_concurrent_processes(_ppid).size();
if(_size > 0)
{
m_rank = _rank;
tim::mpi::set_rank(_rank);
tim::settings::default_process_suffix() = _rank;
get_perfetto_output_filename().clear();
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n",
process::get_id(), tim::mpi::rank(), m_rank);
mproc_comm_record.comm = _ppid;
mproc_comm_record.size = m_size = _size;
auto _rank = mproc::get_process_index(_pid, _ppid);
if(_rank >= 0) mproc_comm_record.rank = m_rank = _rank;
}
}
}
else if(_retval == tim::mpi::success_v && _data.tool_id.find("MPI_Comm_") == 0)
{
if(_data.tool_id == "MPI_Comm_rank")
auto_lock_t _lk{ type_mutex<mpi_gotcha>() };
if(m_comm_val != null_comm())
{
if(m_rank_ptr)
auto& _comm_entry = mpi_comm_records[m_comm_val];
_comm_entry.comm = m_comm_val;
auto _get_rank = [&]() {
return (m_rank_ptr) ? std::max<int>(*m_rank_ptr, m_rank) : m_rank;
};
auto _get_size = [&]() {
return (m_size_ptr) ? std::max<int>(*m_size_ptr, m_size)
: std::max<int>(m_size, _get_rank() + 1);
};
if(_data.tool_id == "MPI_Comm_rank" || _data.tool_id == "MPI_Comm_size")
{
if(mproc::get_concurrent_processes().empty())
{
m_rank = std::max<int>(*m_rank_ptr, m_rank);
tim::mpi::set_rank(m_rank);
tim::settings::default_process_suffix() = m_rank;
get_perfetto_output_filename().clear();
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n",
process::get_id(), tim::mpi::rank(), m_rank);
}
_comm_entry.rank = m_rank = std::max<int>(_comm_entry.rank, _get_rank());
_comm_entry.size = m_size = std::max<int>(_comm_entry.size, _get_size());
}
else
{
OMNITRACE_BASIC_VERBOSE(0, "%s() returned %i :: nullptr to rank\n",
_data.tool_id.c_str(), (int) _retval);
OMNITRACE_BASIC_VERBOSE(
0, "%s() returned %i :: unexpected function wrapper\n",
_data.tool_id.c_str(), (int) _retval);
}
}
else if(_data.tool_id == "MPI_Comm_size")
{
if(m_size_ptr)
{
if(mproc::get_concurrent_processes().empty())
{
m_size = std::max<int>(*m_size_ptr, m_size);
tim::mpi::set_size(m_size);
OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n",
process::get_id(), tim::mpi::size(), m_size);
}
}
else
{
OMNITRACE_BASIC_VERBOSE(0, "%s() returned %i :: nullptr to size\n",
_data.tool_id.c_str(), (int) _retval);
}
}
else
{
OMNITRACE_BASIC_VERBOSE(0,
"%s() returned %i :: unexpected function wrapper\n",
_data.tool_id.c_str(), (int) _retval);
// if(_comm_entry.updated()) update();
}
}
omnitrace_pop_trace_hidden(_data.tool_id.c_str());
@@ -26,6 +26,8 @@
#include "library/defines.hpp"
#include "library/timemory.hpp"
#include <cstdint>
namespace omnitrace
{
// this is used to wrap MPI_Init and MPI_Init_thread
@@ -60,13 +62,17 @@ struct mpi_gotcha : comp::base<mpi_gotcha, void>
// without these you will get a verbosity level 1 warning
static void start() {}
static void stop() {}
static void stop();
static void update();
static uintptr_t null_comm() { return std::numeric_limits<uintptr_t>::max(); }
private:
int* m_rank_ptr = nullptr;
int* m_size_ptr = nullptr;
int m_rank = 0;
int m_size = 1;
int m_rank = 0;
int m_size = 1;
int* m_rank_ptr = nullptr;
int* m_size_ptr = nullptr;
uintptr_t m_comm_val = null_comm();
};
using mpi_gotcha_t = comp::gotcha<5, tim::component_tuple<mpi_gotcha>, api::omnitrace>;
+21 -12
Ver fichero
@@ -70,6 +70,15 @@ get_critical_trace_debug() OMNITRACE_HOT;
namespace debug
{
inline void
flush()
{
fflush(stdout);
std::cout << std::flush;
fflush(stderr);
std::cerr << std::flush;
}
//
struct lock
{
lock();
@@ -154,50 +163,50 @@ get_chars(T&& _c, std::index_sequence<Idx...>)
if((COND) && ::omnitrace::config::get_debug_tid() && \
::omnitrace::config::get_debug_pid()) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
::omnitrace::debug::lock _lk{}; \
fprintf(stderr, "[omnitrace][%i][%li]%s", OMNITRACE_PROCESS_IDENTIFIER, \
OMNITRACE_THREAD_IDENTIFIER, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \
fflush(stderr); \
::omnitrace::debug::flush(); \
}
#define OMNITRACE_CONDITIONAL_BASIC_PRINT(COND, ...) \
if((COND) && ::omnitrace::config::get_debug_tid() && \
::omnitrace::config::get_debug_pid()) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
::omnitrace::debug::lock _lk{}; \
fprintf(stderr, "[omnitrace]%s", \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \
fflush(stderr); \
::omnitrace::debug::flush(); \
}
#define OMNITRACE_CONDITIONAL_PRINT_F(COND, ...) \
if((COND) && ::omnitrace::config::get_debug_tid() && \
::omnitrace::config::get_debug_pid()) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
::omnitrace::debug::lock _lk{}; \
fprintf(stderr, "[omnitrace][%i][%li][%s]%s", OMNITRACE_PROCESS_IDENTIFIER, \
OMNITRACE_THREAD_IDENTIFIER, OMNITRACE_FUNCTION, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \
fflush(stderr); \
::omnitrace::debug::flush(); \
}
#define OMNITRACE_CONDITIONAL_BASIC_PRINT_F(COND, ...) \
if((COND) && ::omnitrace::config::get_debug_tid() && \
::omnitrace::config::get_debug_pid()) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
::omnitrace::debug::lock _lk{}; \
fprintf(stderr, "[omnitrace][%s]%s", OMNITRACE_FUNCTION, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \
fflush(stderr); \
::omnitrace::debug::flush(); \
}
//--------------------------------------------------------------------------------------//
@@ -240,7 +249,7 @@ get_chars(T&& _c, std::index_sequence<Idx...>)
#define OMNITRACE_CONDITIONAL_FAIL(COND, ...) \
if(COND) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
fprintf(stderr, "[omnitrace][%i][%li]%s", OMNITRACE_PROCESS_IDENTIFIER, \
OMNITRACE_THREAD_IDENTIFIER, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
@@ -254,7 +263,7 @@ get_chars(T&& _c, std::index_sequence<Idx...>)
#define OMNITRACE_CONDITIONAL_BASIC_FAIL(COND, ...) \
if(COND) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
fprintf(stderr, "[omnitrace]%s", \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \
@@ -267,7 +276,7 @@ get_chars(T&& _c, std::index_sequence<Idx...>)
#define OMNITRACE_CONDITIONAL_FAIL_F(COND, ...) \
if(COND) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
fprintf(stderr, "[omnitrace][%i][%li][%s]%s", OMNITRACE_PROCESS_IDENTIFIER, \
OMNITRACE_THREAD_IDENTIFIER, OMNITRACE_FUNCTION, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
@@ -281,7 +290,7 @@ get_chars(T&& _c, std::index_sequence<Idx...>)
#define OMNITRACE_CONDITIONAL_BASIC_FAIL_F(COND, ...) \
if(COND) \
{ \
fflush(stderr); \
::omnitrace::debug::flush(); \
fprintf(stderr, "[omnitrace][%s]%s", OMNITRACE_FUNCTION, \
::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \
fprintf(stderr, __VA_ARGS__); \