diff --git a/examples/mpi/mpi.cpp b/examples/mpi/mpi.cpp index 2299c7db16..1ff9463e79 100644 --- a/examples/mpi/mpi.cpp +++ b/examples/mpi/mpi.cpp @@ -42,8 +42,9 @@ std::string _name = {}; template void -all2all(int _rank) +all2all(int _rank, MPI_Comm _comm) { + if(_comm == MPI_COMM_NULL) return; static_assert(N > 0, "Error! N must be greater than zero!"); auto _mt = std::mt19937_64{ size_t(_rank + 100) }; @@ -74,7 +75,7 @@ all2all(int _rank) printf("[%s][%i] values sent (# = %zu) :: %s.\n", _name.c_str(), _rank, values_sent.size(), _get_values_str(values_sent).c_str()); - auto _dtype = MPI_INT; + auto _dtype = MPI_INT; // NOLINT if(std::is_same::value) _dtype = MPI_LONG; else if(std::is_same::value) @@ -82,57 +83,177 @@ all2all(int _rank) else if(std::is_same::value) _dtype = MPI_DOUBLE; - MPI_Alltoall(&values_sent[_rank], 1, _dtype, &values_recv[_rank], 1, _dtype, - MPI_COMM_WORLD); + MPI_Alltoall(&values_sent[_rank], 1, _dtype, &values_recv[_rank], 1, _dtype, _comm); if(_rank == 0) printf("[%s][%i] values recv (# = %zu) :: %s.\n", _name.c_str(), _rank, values_sent.size(), _get_values_str(values_recv).c_str()); } +void +run(MPI_Comm _comm, int nitr) +{ + if(_comm == MPI_COMM_NULL) return; + int _rank = 0; + int _size = 0; + MPI_Comm_rank(_comm, &_rank); + MPI_Comm_size(_comm, &_size); + + printf("[%s][%i] running %i iterations on %i ranks...\n", _name.c_str(), _rank, nitr, + _size); + + MPI_Barrier(_comm); + for(int i = 0; i < nitr; ++i) + { + all2all(_rank, _comm); + all2all(_rank, _comm); + MPI_Barrier(_comm); + all2all(_rank, _comm); + all2all(_rank, _comm); + } + MPI_Barrier(_comm); +} + +void +print_info(MPI_Comm _comm, bool _verbose, std::string _msg = {}) +{ + if(_comm == MPI_COMM_NULL) return; + int _rank = 0; + int _size = 1; + MPI_Comm_rank(_comm, &_rank); + MPI_Comm_size(_comm, &_size); + + if(!_msg.empty()) _msg = "[" + _msg + "] "; + + if(_verbose) + { + auto _ppid = getppid(); + std::ifstream _ifs{ "/proc/" + std::to_string(_ppid) + "/task/" + + std::to_string(_ppid) + "/children" }; + std::stringstream _ss{}; + while(_ifs) + { + std::string _s{}; + _ifs >> _s; + _ss << _s << " "; + } + if(_rank == 0) + printf("[%s]%s RANK = %i (out of %i), PID = %i, PPID = %i :: %s\n", + _name.c_str(), _msg.c_str(), _rank, _size, getpid(), getppid(), + _ss.str().c_str()); + } + else + { + if(_rank == 0) + printf("[%s]%s RANK = %i (out of %i), PID = %i, PPID = %i\n", _name.c_str(), + _msg.c_str(), _rank, _size, getpid(), getppid()); + } +} + int main(int argc, char** argv) { + int _mpi_thread_provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &_mpi_thread_provided); + int rank = 0; int size = 1; int nitr = 1; + if(argc > 1) nitr = atoi(argv[2]); + MPI_Comm_size(MPI_COMM_WORLD, &size); + _name = argv[0]; auto _pos = _name.find_last_of('/'); if(_pos < _name.length()) _name = _name.substr(_pos + 1); printf("[%s] Number of iterations: %i\n", _name.c_str(), nitr); - int _mpi_thread_provided; - MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &_mpi_thread_provided); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); + printf("[%s][%i] running with MPI_COMM_WORLD...\n", _name.c_str(), getpid()); + run(MPI_COMM_WORLD, nitr); - auto _ppid = getppid(); - std::ifstream _ifs{ "/proc/" + std::to_string(_ppid) + "/task/" + - std::to_string(_ppid) + "/children" }; - std::stringstream _ss{}; - while(_ifs) - { - std::string _s{}; - _ifs >> _s; - _ss << _s << " "; - } - printf("[%s] RANK = %i, PID = %i, PPID = %i :: %s\n", _name.c_str(), rank, getpid(), - getppid(), _ss.str().c_str()); + print_info(MPI_COMM_WORLD, true, "MPI_COMM_WORLD"); - MPI_Barrier(MPI_COMM_WORLD); - for(int i = 0; i < nitr; ++i) + printf("[%s]\n", _name.c_str()); + + if(size > 1) { - all2all(rank); - all2all(rank); - MPI_Barrier(MPI_COMM_WORLD); - all2all(rank); - all2all(rank); + MPI_Comm dup; + printf("[%s][%i] Duplicating MPI_COMM_WORLD...\n", _name.c_str(), getpid()); + MPI_Comm_dup(MPI_COMM_WORLD, &dup); + + printf("[%s][%i] running with duplicated comm of MPI_COMM_WORLD...\n", + _name.c_str(), getpid()); + run(dup, nitr); + + MPI_Comm_rank(dup, &rank); + if(rank == 0) printf("[%s]\n", _name.c_str()); + printf("[%s][%i] RANK = %i on duplicated MPI_COMM_WORLD...\n", _name.c_str(), + getpid(), rank); + + if(size > 3) + { + std::vector comms(3); + for(int i = 0; i < size; ++i) + { + auto _idx = i % 3; + printf("[%s][%i] Splitting duplicated MPI_COMM_WORLD %i (rank = %i)...\n", + _name.c_str(), getpid(), _idx, rank); + MPI_Comm* comm = &comms.at(_idx); + MPI_Comm_split(dup, _idx, rank, comm); + } + + for(auto itr : comms) // NOLINT + MPI_Barrier(itr); + + for(int i = 0; i < size; ++i) + { + auto _idx = i % 3; + int _rank = 0; + MPI_Comm_rank(comms.at(_idx), &_rank); + printf("[%s][%i] Running on split communicator %i (rank = %i)...\n", + _name.c_str(), getpid(), _idx, _rank); + run(comms.at(_idx), nitr); + } + + // Get the group of processes in MPI_COMM_WORLD + MPI_Group world_group; + MPI_Comm_group(MPI_COMM_WORLD, &world_group); + + int n = 0; + const int ranks[7] = { 1, 2, 3, 5, 7, 11, 13 }; + for(int rank : ranks) + if(rank < size) ++n; + + // Construct a group containing all of the prime ranks in world_group + MPI_Group prime_group; + MPI_Group_incl(world_group, n, ranks, &prime_group); + + // Create a new communicator based on the group + MPI_Comm prime_comm; + MPI_Comm_create_group(MPI_COMM_WORLD, prime_group, 0, &prime_comm); + + MPI_Group nonprime_group; + MPI_Group_difference(world_group, prime_group, &nonprime_group); + + MPI_Comm nonprime_comm; + MPI_Comm_create_group(MPI_COMM_WORLD, nonprime_group, 1, &nonprime_comm); + + print_info(prime_comm, false, "Prime comm"); + print_info(nonprime_comm, false, "Non-prime comm"); + + run(prime_comm, nitr); + run(nonprime_comm, nitr); + + MPI_Group_free(&world_group); + MPI_Group_free(&prime_group); + MPI_Group_free(&nonprime_group); + } + + print_info(dup, false); } - MPI_Barrier(MPI_COMM_WORLD); + MPI_Finalize(); - return 0; } diff --git a/source/lib/omnitrace/library/components/mpi_gotcha.cpp b/source/lib/omnitrace/library/components/mpi_gotcha.cpp index 09f2ec57a2..91db3a8756 100644 --- a/source/lib/omnitrace/library/components/mpi_gotcha.cpp +++ b/source/lib/omnitrace/library/components/mpi_gotcha.cpp @@ -27,16 +27,71 @@ #include "library/debug.hpp" #include "library/mproc.hpp" -#include #include #include +#include + +#include +#include +#include +#include namespace omnitrace { namespace { -uint64_t mpip_index = std::numeric_limits::max(); -std::string mpi_init_string = {}; +struct comm_rank_data +{ + int rank = -1; + int size = -1; + uintptr_t comm = mpi_gotcha::null_comm(); + + auto updated() const + { + return comm != mpi_gotcha::null_comm() && rank >= 0 && size > 0; + }; + + friend bool operator==(const comm_rank_data& _lhs, const comm_rank_data& _rhs) + { + auto _lupd = _lhs.updated(); + auto _rupd = _rhs.updated(); + return std::tie(_lupd, _lhs.rank, _lhs.size, _lhs.comm) == + std::tie(_rupd, _rhs.rank, _rhs.size, _rhs.comm); + } + + friend bool operator!=(const comm_rank_data& _lhs, const comm_rank_data& _rhs) + { + return !(_lhs == _rhs); + } + + friend bool operator>(const comm_rank_data& _lhs, const comm_rank_data& _rhs) + { + OMNITRACE_CI_THROW(!_lhs.updated() && !_rhs.updated(), + "Error! comparing rank data that is not updated"); + + if(_lhs.updated() && !_rhs.updated()) return true; + if(!_lhs.updated() && _rhs.updated()) return false; + + if(_lhs.size != _rhs.size) return _lhs.size > _rhs.size; + if(_lhs.rank != _rhs.rank) return _lhs.rank > _rhs.rank; + + // lesser comm is greater + return _lhs.comm < _rhs.comm; + } + + friend bool operator<(const comm_rank_data& _lhs, const comm_rank_data& _rhs) + { + return (_lhs != _rhs && !(_lhs > _rhs)); + } +}; + +uint64_t mpip_index = std::numeric_limits::max(); +auto last_comm_record = comm_rank_data{}; +auto mproc_comm_record = comm_rank_data{}; +auto mpi_comm_records = std::map{}; + +using tim::auto_lock_t; +using tim::type_mutex; // this ensures omnitrace_finalize is called before MPI_Finalize void @@ -80,6 +135,50 @@ mpi_gotcha::configure() }; } +void +mpi_gotcha::stop() +{ + OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] Stopping MPI gotcha...\n", process::get_id()); + update(); +} + +void +mpi_gotcha::update() +{ + auto_lock_t _lk{ type_mutex(), std::defer_lock }; + if(!_lk.owns_lock()) _lk.lock(); + + comm_rank_data _rank_data = mproc_comm_record; + for(const auto& itr : mpi_comm_records) + { + // skip null comms + if(itr.first == null_comm()) continue; + // if currently have null comm, replace + else if(_rank_data.comm == null_comm()) + _rank_data = itr.second; + // if + else if(itr.second > _rank_data) + _rank_data = itr.second; + } + + if(_rank_data.updated() && _rank_data != last_comm_record) + { + auto _rank = _rank_data.rank; + auto _size = _rank_data.size; + + tim::mpi::set_rank(_rank); + tim::mpi::set_size(_size); + tim::settings::default_process_suffix() = _rank; + get_perfetto_output_filename().clear(); + + OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n", process::get_id(), + tim::mpi::rank(), _rank); + OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n", process::get_id(), + tim::mpi::size(), _size); + last_comm_record = _rank_data; + } +} + void mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, int*, char***) { @@ -126,17 +225,19 @@ mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming) } void -mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, comm_t, int* _val) +mpi_gotcha::audit(const gotcha_data_t& _data, audit::incoming, comm_t _comm, int* _val) { OMNITRACE_BASIC_DEBUG_F("%s()\n", _data.tool_id.c_str()); omnitrace_push_trace_hidden(_data.tool_id.c_str()); if(_data.tool_id == "MPI_Comm_rank") { + m_comm_val = (uintptr_t) _comm; // NOLINT m_rank_ptr = _val; } else if(_data.tool_id == "MPI_Comm_size") { + m_comm_val = (uintptr_t) _comm; // NOLINT m_size_ptr = _val; } else @@ -172,71 +273,51 @@ mpi_gotcha::audit(const gotcha_data_t& _data, audit::outgoing, int _retval) api::omnitrace>(); } - auto _size = mproc::get_concurrent_processes().size(); - if(_size > 0) + auto_lock_t _lk{ type_mutex() }; + if(!mproc_comm_record.updated()) { - m_size = _size; - tim::mpi::set_size(_size); - OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n", process::get_id(), - tim::mpi::size(), m_size); - - auto _rank = mproc::get_process_index(); - if(_rank >= 0) + auto _pid = getpid(); + auto _ppid = getppid(); + auto _size = mproc::get_concurrent_processes(_ppid).size(); + if(_size > 0) { - m_rank = _rank; - tim::mpi::set_rank(_rank); - tim::settings::default_process_suffix() = _rank; - get_perfetto_output_filename().clear(); - OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n", - process::get_id(), tim::mpi::rank(), m_rank); + mproc_comm_record.comm = _ppid; + mproc_comm_record.size = m_size = _size; + auto _rank = mproc::get_process_index(_pid, _ppid); + if(_rank >= 0) mproc_comm_record.rank = m_rank = _rank; } } } else if(_retval == tim::mpi::success_v && _data.tool_id.find("MPI_Comm_") == 0) { - if(_data.tool_id == "MPI_Comm_rank") + auto_lock_t _lk{ type_mutex() }; + if(m_comm_val != null_comm()) { - if(m_rank_ptr) + auto& _comm_entry = mpi_comm_records[m_comm_val]; + _comm_entry.comm = m_comm_val; + + auto _get_rank = [&]() { + return (m_rank_ptr) ? std::max(*m_rank_ptr, m_rank) : m_rank; + }; + + auto _get_size = [&]() { + return (m_size_ptr) ? std::max(*m_size_ptr, m_size) + : std::max(m_size, _get_rank() + 1); + }; + + if(_data.tool_id == "MPI_Comm_rank" || _data.tool_id == "MPI_Comm_size") { - if(mproc::get_concurrent_processes().empty()) - { - m_rank = std::max(*m_rank_ptr, m_rank); - tim::mpi::set_rank(m_rank); - tim::settings::default_process_suffix() = m_rank; - get_perfetto_output_filename().clear(); - OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI rank: %i (%i)\n", - process::get_id(), tim::mpi::rank(), m_rank); - } + _comm_entry.rank = m_rank = std::max(_comm_entry.rank, _get_rank()); + _comm_entry.size = m_size = std::max(_comm_entry.size, _get_size()); } else { - OMNITRACE_BASIC_VERBOSE(0, "%s() returned %i :: nullptr to rank\n", - _data.tool_id.c_str(), (int) _retval); + OMNITRACE_BASIC_VERBOSE( + 0, "%s() returned %i :: unexpected function wrapper\n", + _data.tool_id.c_str(), (int) _retval); } - } - else if(_data.tool_id == "MPI_Comm_size") - { - if(m_size_ptr) - { - if(mproc::get_concurrent_processes().empty()) - { - m_size = std::max(*m_size_ptr, m_size); - tim::mpi::set_size(m_size); - OMNITRACE_BASIC_VERBOSE(0, "[pid=%i] MPI size: %i (%i)\n", - process::get_id(), tim::mpi::size(), m_size); - } - } - else - { - OMNITRACE_BASIC_VERBOSE(0, "%s() returned %i :: nullptr to size\n", - _data.tool_id.c_str(), (int) _retval); - } - } - else - { - OMNITRACE_BASIC_VERBOSE(0, - "%s() returned %i :: unexpected function wrapper\n", - _data.tool_id.c_str(), (int) _retval); + + // if(_comm_entry.updated()) update(); } } omnitrace_pop_trace_hidden(_data.tool_id.c_str()); diff --git a/source/lib/omnitrace/library/components/mpi_gotcha.hpp b/source/lib/omnitrace/library/components/mpi_gotcha.hpp index 0def35f764..5fc01f3a4f 100644 --- a/source/lib/omnitrace/library/components/mpi_gotcha.hpp +++ b/source/lib/omnitrace/library/components/mpi_gotcha.hpp @@ -26,6 +26,8 @@ #include "library/defines.hpp" #include "library/timemory.hpp" +#include + namespace omnitrace { // this is used to wrap MPI_Init and MPI_Init_thread @@ -60,13 +62,17 @@ struct mpi_gotcha : comp::base // without these you will get a verbosity level 1 warning static void start() {} - static void stop() {} + static void stop(); + + static void update(); + static uintptr_t null_comm() { return std::numeric_limits::max(); } private: - int* m_rank_ptr = nullptr; - int* m_size_ptr = nullptr; - int m_rank = 0; - int m_size = 1; + int m_rank = 0; + int m_size = 1; + int* m_rank_ptr = nullptr; + int* m_size_ptr = nullptr; + uintptr_t m_comm_val = null_comm(); }; using mpi_gotcha_t = comp::gotcha<5, tim::component_tuple, api::omnitrace>; diff --git a/source/lib/omnitrace/library/debug.hpp b/source/lib/omnitrace/library/debug.hpp index b297a885f0..7912efafb5 100644 --- a/source/lib/omnitrace/library/debug.hpp +++ b/source/lib/omnitrace/library/debug.hpp @@ -70,6 +70,15 @@ get_critical_trace_debug() OMNITRACE_HOT; namespace debug { +inline void +flush() +{ + fflush(stdout); + std::cout << std::flush; + fflush(stderr); + std::cerr << std::flush; +} +// struct lock { lock(); @@ -154,50 +163,50 @@ get_chars(T&& _c, std::index_sequence) if((COND) && ::omnitrace::config::get_debug_tid() && \ ::omnitrace::config::get_debug_pid()) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ ::omnitrace::debug::lock _lk{}; \ fprintf(stderr, "[omnitrace][%i][%li]%s", OMNITRACE_PROCESS_IDENTIFIER, \ OMNITRACE_THREAD_IDENTIFIER, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ } #define OMNITRACE_CONDITIONAL_BASIC_PRINT(COND, ...) \ if((COND) && ::omnitrace::config::get_debug_tid() && \ ::omnitrace::config::get_debug_pid()) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ ::omnitrace::debug::lock _lk{}; \ fprintf(stderr, "[omnitrace]%s", \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ } #define OMNITRACE_CONDITIONAL_PRINT_F(COND, ...) \ if((COND) && ::omnitrace::config::get_debug_tid() && \ ::omnitrace::config::get_debug_pid()) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ ::omnitrace::debug::lock _lk{}; \ fprintf(stderr, "[omnitrace][%i][%li][%s]%s", OMNITRACE_PROCESS_IDENTIFIER, \ OMNITRACE_THREAD_IDENTIFIER, OMNITRACE_FUNCTION, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ } #define OMNITRACE_CONDITIONAL_BASIC_PRINT_F(COND, ...) \ if((COND) && ::omnitrace::config::get_debug_tid() && \ ::omnitrace::config::get_debug_pid()) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ ::omnitrace::debug::lock _lk{}; \ fprintf(stderr, "[omnitrace][%s]%s", OMNITRACE_FUNCTION, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ } //--------------------------------------------------------------------------------------// @@ -240,7 +249,7 @@ get_chars(T&& _c, std::index_sequence) #define OMNITRACE_CONDITIONAL_FAIL(COND, ...) \ if(COND) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ fprintf(stderr, "[omnitrace][%i][%li]%s", OMNITRACE_PROCESS_IDENTIFIER, \ OMNITRACE_THREAD_IDENTIFIER, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ @@ -254,7 +263,7 @@ get_chars(T&& _c, std::index_sequence) #define OMNITRACE_CONDITIONAL_BASIC_FAIL(COND, ...) \ if(COND) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ fprintf(stderr, "[omnitrace]%s", \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \ @@ -267,7 +276,7 @@ get_chars(T&& _c, std::index_sequence) #define OMNITRACE_CONDITIONAL_FAIL_F(COND, ...) \ if(COND) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ fprintf(stderr, "[omnitrace][%i][%li][%s]%s", OMNITRACE_PROCESS_IDENTIFIER, \ OMNITRACE_THREAD_IDENTIFIER, OMNITRACE_FUNCTION, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ @@ -281,7 +290,7 @@ get_chars(T&& _c, std::index_sequence) #define OMNITRACE_CONDITIONAL_BASIC_FAIL_F(COND, ...) \ if(COND) \ { \ - fflush(stderr); \ + ::omnitrace::debug::flush(); \ fprintf(stderr, "[omnitrace][%s]%s", OMNITRACE_FUNCTION, \ ::omnitrace::debug::is_bracket(__VA_ARGS__) ? "" : " "); \ fprintf(stderr, __VA_ARGS__); \