diff --git a/src/backend_bc.cpp b/src/backend_bc.cpp index 7b7623154c..56e57bcd52 100644 --- a/src/backend_bc.cpp +++ b/src/backend_bc.cpp @@ -33,7 +33,15 @@ namespace rocshmem { -Backend::Backend() { +#define NET_CHECK(cmd) \ + { \ + if (cmd != MPI_SUCCESS) { \ + fprintf(stderr, "Unrecoverable error: MPI Failure\n"); \ + abort() ; \ + } \ + } + +Backend::Backend(MPI_Comm comm) : heap{comm} { int num_cus{}; if (hipDeviceGetAttribute(&num_cus, hipDeviceAttributeMultiprocessorCount, 0)) { @@ -71,12 +79,31 @@ Backend::Backend() { CHECK_HIP( hipHostMalloc(reinterpret_cast(&done_init), sizeof(uint8_t))); + init_mpi_once(comm); /* * Notify other threads that Backend has been initialized. */ *done_init = 0; } +void Backend::init_mpi_once(MPI_Comm comm) { + int init_done{}; + NET_CHECK(MPI_Initialized(&init_done)); + + int provided{}; + if (!init_done) { + NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided)); + if (provided != MPI_THREAD_MULTIPLE) { + std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n"; + } + } + if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD; + + NET_CHECK(MPI_Comm_dup(comm, &backend_comm)); + NET_CHECK(MPI_Comm_size(backend_comm, &num_pes)); + NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe)); +} + void Backend::track_ctx(Context* ctx) { /** * TODO: Don't track CTX_PRIVATE when we support it diff --git a/src/backend_bc.hpp b/src/backend_bc.hpp index 6b71293ce2..f0b5615823 100644 --- a/src/backend_bc.hpp +++ b/src/backend_bc.hpp @@ -67,7 +67,7 @@ class Backend { * @note Implementation may reduce the number of workgroups if the * number exceeds hardware limits. */ - explicit Backend(); + explicit Backend(MPI_Comm comm); /** * @brief Destructor. @@ -221,7 +221,7 @@ class Backend { * @todo document where this is used and try to coalesce this into another * class */ - MPI_Comm thread_comm{}; + MPI_Comm backend_comm{}; /** * @brief Object contains the interface and internal data structures @@ -295,6 +295,13 @@ class Backend { * @brief List of ctxs created by the user. */ std::vector list_of_ctxs{}; + + /** + * @brief initialize MPI. + * + * Backend relies on MPI to exchange meta data across PEs. + */ + void init_mpi_once(MPI_Comm comm); }; /** diff --git a/src/ipc/backend_ipc.cpp b/src/ipc/backend_ipc.cpp index 387e3c27e2..4784447920 100644 --- a/src/ipc/backend_ipc.cpp +++ b/src/ipc/backend_ipc.cpp @@ -54,7 +54,7 @@ int get_ls_non_zero_bit(char *bitmask, int mask_length) { } IPCBackend::IPCBackend(MPI_Comm comm) - : Backend() { + : Backend(comm) { type = BackendType::IPC_BACKEND; if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { @@ -62,8 +62,6 @@ IPCBackend::IPCBackend(MPI_Comm comm) sstream >> maximum_num_contexts_; } - init_mpi_once(comm); - initIPC(); /** @@ -74,8 +72,8 @@ IPCBackend::IPCBackend(MPI_Comm comm) /* Initialize the host interface */ host_interface = std::make_shared(hdp_proxy_.get(), - thread_comm, - &heap); + backend_comm, + &heap); default_host_ctx = std::make_unique(this, 0); @@ -156,7 +154,7 @@ void IPCBackend::setup_team_world() { IPCTeam *team_world{nullptr}; CHECK_HIP(hipMalloc(&team_world, sizeof(IPCTeam))); new (team_world) IPCTeam(this, team_info_wrt_parent, team_info_wrt_world, - num_pes, my_pe, thread_comm, 0); + num_pes, my_pe, backend_comm, 0); team_tracker.set_team_world(team_world); /** @@ -165,24 +163,6 @@ void IPCBackend::setup_team_world() { ROCSHMEM_TEAM_WORLD = reinterpret_cast(team_world); } -void IPCBackend::init_mpi_once(MPI_Comm comm) { - int init_done{}; - NET_CHECK(MPI_Initialized(&init_done)); - - int provided{}; - if (!init_done) { - NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided)); - if (provided != MPI_THREAD_MULTIPLE) { - std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n"; - } - } - if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD; - - NET_CHECK(MPI_Comm_dup(comm, &thread_comm)); - NET_CHECK(MPI_Comm_size(thread_comm, &num_pes)); - NET_CHECK(MPI_Comm_rank(thread_comm, &my_pe)); -} - void IPCBackend::team_destroy(rocshmem_team_t team) { IPCTeam *team_obj = get_internal_ipc_team(team); @@ -260,11 +240,11 @@ void IPCBackend::initIPC() { const auto &heap_bases{heap.get_heap_bases()}; ipcImpl.ipcHostInit(my_pe, heap_bases, - thread_comm); + backend_comm); } void IPCBackend::global_exit(int status) { - MPI_Abort(MPI_COMM_WORLD, status); + MPI_Abort(backend_comm, status); } void IPCBackend::teams_destroy() { @@ -331,7 +311,7 @@ void IPCBackend::init_wrk_sync_buffer() { * all-to-all exchange with each PE to share the IPC handles. */ MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, - ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, thread_comm); + ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm); /* * Allocate device-side fine grained memory to hold IPC addresses of @@ -399,7 +379,7 @@ void IPCBackend::rocshmem_collective_init() { * Make sure that all processing elements have done this before * continuing. */ - NET_CHECK(MPI_Barrier(thread_comm)); + NET_CHECK(MPI_Barrier(backend_comm)); } void IPCBackend::teams_init() { @@ -491,7 +471,7 @@ void IPCBackend::teams_init() { * Make sure that all processing elements have done this before * continuing. */ - NET_CHECK(MPI_Barrier(thread_comm)); + NET_CHECK(MPI_Barrier(backend_comm)); } } // namespace rocshmem diff --git a/src/ipc/backend_ipc.hpp b/src/ipc/backend_ipc.hpp index cb33080fb2..bcb4723539 100644 --- a/src/ipc/backend_ipc.hpp +++ b/src/ipc/backend_ipc.hpp @@ -65,16 +65,6 @@ class IPCBackend : public Backend { */ void ctx_destroy(Context *ctx) override; - /** - * @brief initialize MPI. - * - * IPC relies on MPI just to exchange the IPC_handle information. - * - * todo: remove the dependency on MPI and make it generic to PMI-X or just - * to OpenSHMEM to have support for both CPU and GPU - */ - void init_mpi_once(MPI_Comm comm); - /** * @brief Helper to initialize IPC interface. */ diff --git a/src/memory/remote_heap_info.hpp b/src/memory/remote_heap_info.hpp index 945d1abf53..b89e20096f 100644 --- a/src/memory/remote_heap_info.hpp +++ b/src/memory/remote_heap_info.hpp @@ -50,7 +50,9 @@ class CommunicatorMPI { /** * @brief Primary constructor */ - CommunicatorMPI(char* heap_base, size_t heap_size) { + CommunicatorMPI(char* heap_base, size_t heap_size, + MPI_Comm comm = MPI_COMM_WORLD) + : comm_{comm} { int initialized; MPI_Initialized(&initialized); if (!initialized) { @@ -87,8 +89,8 @@ class CommunicatorMPI { * @brief Performs MPI_Allgather on recvbuf */ void allgather(void* recvbuf) { - MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, sizeof(void*), - MPI_CHAR, comm_); + MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, + sizeof(void*), MPI_CHAR, comm_); } /** @@ -100,7 +102,7 @@ class CommunicatorMPI { /** * @brief Identifier for this processing element */ - MPI_Comm comm_{MPI_COMM_WORLD}; + MPI_Comm comm_{}; /** * @brief Identifier for this processing element @@ -139,8 +141,9 @@ class RemoteHeapInfo { * @param[in] The identifier for this processing element * @param[in] The total number of processing elements */ - RemoteHeapInfo(char* heap_ptr, size_t heap_size) - : communicator_{heap_ptr, heap_size} { + RemoteHeapInfo(char* heap_ptr, size_t heap_size, + MPI_Comm comm = MPI_COMM_WORLD) + : communicator_{heap_ptr, heap_size, comm} { heap_bases_.resize(communicator_.num_pes()); for (auto& base : heap_bases_) { base = nullptr; diff --git a/src/memory/symmetric_heap.hpp b/src/memory/symmetric_heap.hpp index 4c6a39396c..77d6e33fdd 100644 --- a/src/memory/symmetric_heap.hpp +++ b/src/memory/symmetric_heap.hpp @@ -54,6 +54,10 @@ class SymmetricHeap { using RemoteHeapInfoType = RemoteHeapInfo; public: + SymmetricHeap(MPI_Comm comm = MPI_COMM_WORLD) + : remote_heap_info_{single_heap_.get_base_ptr(), + single_heap_.get_size(), + comm} {} /** * @brief Allocates heap memory and returns ptr to caller * @@ -120,8 +124,7 @@ class SymmetricHeap { /** * @brief Implementation of remote heaps */ - RemoteHeapInfoType remote_heap_info_{single_heap_.get_base_ptr(), - single_heap_.get_size()}; + RemoteHeapInfoType remote_heap_info_{}; }; } // namespace rocshmem diff --git a/src/memory/window_info.hpp b/src/memory/window_info.hpp index 287fa96adb..75dd7f0463 100644 --- a/src/memory/window_info.hpp +++ b/src/memory/window_info.hpp @@ -165,7 +165,7 @@ class WindowInfo { /** * @brief MPI Communicator */ - MPI_Comm comm_{MPI_COMM_WORLD}; + MPI_Comm comm_{}; /** * @brief Owning pointer to MPI_Win diff --git a/src/reverse_offload/backend_ro.cpp b/src/reverse_offload/backend_ro.cpp index f2d4b9c068..ff20908531 100644 --- a/src/reverse_offload/backend_ro.cpp +++ b/src/reverse_offload/backend_ro.cpp @@ -45,7 +45,7 @@ namespace rocshmem { extern rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; ROBackend::ROBackend(MPI_Comm comm) - : Backend() { + : Backend(comm) { type = BackendType::RO_BACKEND; if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { @@ -78,7 +78,7 @@ ROBackend::ROBackend(MPI_Comm comm) queue_ = Queue(maximum_num_contexts_, queue_size_); - transport_ = new MPITransport(comm, &queue_); + transport_ = new MPITransport(backend_comm, &queue_); num_pes = transport_->getNumPes(); my_pe = transport_->getMyPe(); diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index a0e672a88b..f6cee2aae6 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -54,7 +54,7 @@ MPITransport::MPITransport(MPI_Comm comm, Queue* q) std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n"; } } - if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD; + assert(comm != MPI_COMM_NULL); NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world)); NET_CHECK(MPI_Comm_size(ro_net_comm_world, &num_pes)); diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index 9d2844cbc3..1932e221e7 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -197,13 +197,24 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } [[maybe_unused]] __host__ int rocshmem_my_pe() { - MPIInitSingleton *s = s->GetInstance(); - return s->get_rank(); + if(backend == nullptr) { + MPIInitSingleton *s = s->GetInstance(); + return s->get_rank(); + } + else + { + return backend->getMyPE(); + } } [[maybe_unused]] __host__ int rocshmem_n_pes() { - MPIInitSingleton *s = s->GetInstance(); - return s->get_nprocs(); + if(backend == nullptr) { + MPIInitSingleton *s = s->GetInstance(); + return s->get_nprocs(); + } + else { + return backend->getNumPEs(); + } } [[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) {