From e4c427a736a5550252e7d085b44413aec71b7560 Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Wed, 1 Oct 2025 08:06:56 -0500 Subject: [PATCH] Remove MPI compile-time dependency (#264) * use dlsym for MPI functions to allow compiling without MPI support, convert the usage of MPI functions and symbols to be based on a dlopen/dlsym based mechanism. Turns out this cannot be done entirely vendor neutral; slightly different solutions might be required for Open MPI, MPICH and the new MPI ABI. * checkpoint more work to be done. * checkpoint 2 * checkpoint 3 * checkpoint 4 examples compile and link correctly * checkpoint 5 (I think) * Checkpoint 6 * dyld-mpi: adapt GDA * dyldmpi: tests that depend on MPI need to link with it themselves * do not ../mpi_instance.h * dyldmpi: make the symmetricHeapTestFixture compile * dyldmpi: Change cmakery, compiles and runs gda w/o external MPI * Make it also compile in external MPI mode * dyldmpi: ipc unit tests compile but do not link * dyldmpi: new approach, if external mpi required, link with mpi, otherwise use ompi5 abi * C-style comments in cmakelist.. * dyldmpi: examples: do not fail compiling if MPI not found at build time, instead do not compile the MPI required examples * more updates to CMake logic * convert RO backend and a few other cleanups * update some unit tests to work with the dlopen MPI environment correctly. 
--------- Co-authored-by: Aurelien Bouteiller --- CMakeLists.txt | 23 ++- cmake/rocshmem_config.h.in | 2 + examples/CMakeLists.txt | 16 +- examples/rocshmem_init_attr_test.cc | 1 + examples/util.h | 2 +- include/rocshmem/rocshmem.hpp | 24 ++- include/rocshmem/rocshmem_mpi.hpp | 143 ++++++++++++++++++ src/atomic_return.cpp | 2 +- src/atomic_return.hpp | 1 - src/backend_bc.cpp | 9 +- src/backend_bc.hpp | 5 +- src/gda/backend_gda.cpp | 21 +-- src/gda/backend_gda.hpp | 18 --- src/gda/context_gda_host.cpp | 4 +- src/gda/gda_team.cpp | 1 + src/host/host.cpp | 40 +++-- src/host/host.hpp | 11 +- src/host/host_helpers.hpp | 13 +- src/host/host_templates.hpp | 34 ++--- src/ipc/backend_ipc.cpp | 21 +-- src/ipc/context_ipc_host.cpp | 2 - src/ipc_policy.cpp | 27 ++-- src/ipc_policy.hpp | 2 +- src/memory/remote_heap_info.hpp | 16 +- src/memory/symmetric_heap.hpp | 6 +- src/memory/window_info.hpp | 15 +- src/mpi_instance.cpp | 126 +++++++++++++-- src/mpi_instance.hpp | 74 ++++++++- src/reverse_offload/context_ro_host.cpp | 1 - src/reverse_offload/mpi_transport.cpp | 80 +++++----- src/reverse_offload/queue_proxy.hpp | 2 - src/reverse_offload/ro_team_proxy.hpp | 3 +- src/reverse_offload/transport.hpp | 3 +- src/rocshmem.cpp | 68 ++++++--- src/stats.hpp | 16 +- src/team.cpp | 5 +- src/team.hpp | 3 +- src/util.hpp | 29 ++++ tests/functional_tests/CMakeLists.txt | 5 +- tests/functional_tests/tester.cpp | 1 - tests/unit_tests/CMakeLists.txt | 3 +- .../ipc_impl_simple_coarse_gtest.cpp | 6 +- .../ipc_impl_simple_coarse_gtest.hpp | 26 ++-- .../unit_tests/ipc_impl_simple_fine_gtest.cpp | 6 +- .../unit_tests/ipc_impl_simple_fine_gtest.hpp | 28 ++-- .../unit_tests/ipc_impl_tiled_fine_gtest.cpp | 6 +- .../unit_tests/ipc_impl_tiled_fine_gtest.hpp | 28 ++-- tests/unit_tests/remote_heap_info_gtest.hpp | 5 +- tests/unit_tests/symmetric_heap_gtest.cpp | 10 +- tests/unit_tests/symmetric_heap_gtest.hpp | 13 +- 50 files changed, 712 insertions(+), 294 deletions(-) create mode 100644 
include/rocshmem/rocshmem_mpi.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b12f59688..3c3b785669 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,10 +59,11 @@ option(USE_SHARED_CTX "Request support for shared ctx between WG" OFF) option(USE_SINGLE_NODE "Enable single node support only." OFF) option(USE_HDP_FLUSH "Force flush the HDP cache." OFF) option(USE_HDP_FLUSH_HOST_SIDE "Use a polling thread to flush the HDP cache on the host." OFF) +option(USE_EXTERNAL_MPI "Link with an external MPI (required if used MPI is ABI incompatible with Open MPI v5" OFF) -option(BUILD_FUNCTIONAL_TESTS "Build the functional tests" OFF) +option(BUILD_FUNCTIONAL_TESTS "Build the functional tests (Requires MPI)" OFF) option(BUILD_EXAMPLES "Build the examples" ON) -option(BUILD_UNIT_TESTS "Build the unit tests" OFF) +option(BUILD_UNIT_TESTS "Build the unit tests (Requires MPI)" OFF) option(BUILD_TESTS_ONLY "Build only tests. Used to link agains rocSHMEM in a ROCm Release" OFF) option(BUILD_TOOLS "Build binary tools (e.g., rocshmem_info)" ON) @@ -150,7 +151,21 @@ if (NOT BUILD_TESTS_ONLY) ############################################################################# # PACKAGE DEPENDENCIES ############################################################################# - find_package(MPI REQUIRED) + find_package(MPI) + if (MPI_FOUND) + set(HAVE_EXTERNAL_MPI ON) + else() + set(HAVE_EXTERNAL_MPI OFF) + set(BUILD_FUNCTIONAL_TESTS OFF) + set(BUILD_UNIT_TESTS OFF) + endif() + + if (USE_EXTERNAL_MPI) + if(NOT HAVE_EXTERNAL_MPI) + message(FATAL_ERROR "External MPI support requested but MPI support not found. 
Build Aborted") + endif() + endif() + find_package(hip REQUIRED PATHS /opt/rocm) find_package(hsa-runtime64 REQUIRED) @@ -182,8 +197,8 @@ if (NOT BUILD_TESTS_ONLY) target_link_libraries( ${PROJECT_NAME} PUBLIC + $<$:MPI::MPI_CXX> Threads::Threads - MPI::MPI_CXX hip::device hip::host hsa-runtime64::hsa-runtime64 diff --git a/cmake/rocshmem_config.h.in b/cmake/rocshmem_config.h.in index 7b54099d38..1abc25d2b9 100644 --- a/cmake/rocshmem_config.h.in +++ b/cmake/rocshmem_config.h.in @@ -45,3 +45,5 @@ #cmakedefine GDA_IONIC #cmakedefine GDA_BNXT #cmakedefine GDA_MLX5 +#cmakedefine USE_EXTERNAL_MPI +#cmakedefine HAVE_EXTERNAL_MPI diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66dcde161b..8483106940 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -29,6 +29,12 @@ cmake_minimum_required(VERSION 3.16.3 FATAL_ERROR) include(${CMAKE_SOURCE_DIR}/cmake/setup_project.cmake) project(rocshmem_examples VERSION 1.0.0 LANGUAGES CXX) +find_package(MPI) +find_package(hip REQUIRED PATHS /opt/rocm) +if (NOT TARGET roc::rocshmem) + find_package(rocshmem REQUIRED PATHS /opt/rocm) +endif() + ############################################################################### # SOURCES ############################################################################### @@ -38,13 +44,11 @@ set(EXAMPLE_SOURCES rocshmem_broadcast_test.cc rocshmem_getmem_test.cc rocshmem_put_signal_test.cc - rocshmem_init_attr_test.cc ) -find_package(MPI REQUIRED) -find_package(hip REQUIRED PATHS /opt/rocm) -if(NOT TARGET roc::rocshmem) - find_package(rocshmem REQUIRED PATHS /opt/rocm) +if (HAVE_EXTERNAL_MPI) + list(APPEND EXAMPLE_SOURCES + rocshmem_init_attr_test.cc) endif() foreach(SOURCE_FILE IN LISTS EXAMPLE_SOURCES) @@ -55,7 +59,7 @@ foreach(SOURCE_FILE IN LISTS EXAMPLE_SOURCES) target_link_libraries( ${EXECUTABLE_NAME} PRIVATE - MPI::MPI_CXX + $ roc::rocshmem ) endforeach() diff --git a/examples/rocshmem_init_attr_test.cc b/examples/rocshmem_init_attr_test.cc index 
b7d6d81307..ba68cca727 100644 --- a/examples/rocshmem_init_attr_test.cc +++ b/examples/rocshmem_init_attr_test.cc @@ -59,6 +59,7 @@ */ #include +#include #include "util.h" diff --git a/examples/util.h b/examples/util.h index 3118f52f89..646902dd92 100644 --- a/examples/util.h +++ b/examples/util.h @@ -34,7 +34,7 @@ hipError_t error = condition; \ if(error != hipSuccess){ \ fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ + exit(error); \ } \ } diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index e2bf79abce..4e30438388 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -26,7 +26,6 @@ #define LIBRARY_INCLUDE_ROCSHMEM_HPP #include -#include #include "rocshmem_config.h" #include "rocshmem_common.hpp" @@ -36,6 +35,10 @@ #include "rocshmem_COLL.hpp" #include "rocshmem_P2P_SYNC.hpp" #include "rocshmem_RMA_X.hpp" +#if defined(HAVE_EXTERNAL_MPI) +#include +#endif + /** * @file rocshmem.hpp * @brief Public header for rocSHMEM device and host libraries. @@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0"; /****************************************************************************** **************************** HOST INTERFACE ********************************** *****************************************************************************/ +#if defined(HAVE_EXTERNAL_MPI) /** * @brief Initialize the rocSHMEM runtime and underlying transport layer. * - * @param[in] comm (Optional) MPI Communicator that rocSHMEM will be using + * @param[in] comm MPI Communicator that rocSHMEM will be using * If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD */ -__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD); +[[deprecated]] __host__ void rocshmem_init(MPI_Comm comm); +#endif + +/** + * @brief Initialize the rocSHMEM runtime and underlying transport layer. 
+ * This is equivalent to the previous function, using implicitely + * MPI_COMM_WORLD for initialization + */ +__host__ void rocshmem_init(void); /** * @brief Query rocSHMEM context from host API @@ -88,6 +100,7 @@ __host__ void * rocshmem_get_device_ctx(); */ __host__ void *rocshmem_ptr(void *dest, int pe); +#if defined(HAVE_EXTERNAL_MPI) /** * @brief Initialize the rocSHMEM runtime and underlying transport layer * with an attempt to enable the requested thread support. @@ -102,8 +115,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe); * @return int returns 0 upon success; otherwise, it returns a nonzero * value */ -__host__ int rocshmem_init_thread(int requested, int *provided, - MPI_Comm comm = MPI_COMM_WORLD); +[[deprecated]] __host__ int rocshmem_init_thread(int requested, int *provided, + MPI_Comm comm); +#endif /** * @brief Initialize the rocSHMEM runtime and underlying transport layer diff --git a/include/rocshmem/rocshmem_mpi.hpp b/include/rocshmem/rocshmem_mpi.hpp new file mode 100644 index 0000000000..86a87f22fc --- /dev/null +++ b/include/rocshmem/rocshmem_mpi.hpp @@ -0,0 +1,143 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP +#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP + +#if defined(HAVE_EXTERNAL_MPI) +#include +#endif + +#if defined(c_plusplus) || defined(__cplusplus) +extern "C" { +#endif + +#if !defined(MPI_VERSION) +// Open MPI based values for the constants/handles etc. +// Even though we did not include an external MPI header file +// The includer may have (e.g., a unit test). + +typedef void* MPI_Comm; +typedef void* MPI_Win; +typedef void* MPI_Group; +typedef void* MPI_Op; +typedef void* MPI_Datatype; +typedef void* MPI_Request; +typedef void* MPI_Info; + +struct ompi_status_public_t { + int MPI_SOURCE; + int MPI_TAG; + int MPI_ERROR; + int _cancelled; + size_t _ucount; +}; +typedef struct ompi_status_public_t MPI_Status; + +#define MPI_Aint uint64_t + +#define MPI_UNDEFINED -32766 +#define MPI_THREAD_MULTIPLE 3 +#define MPI_SUCCESS 0 +#define MPI_IN_PLACE (void*)1 +#define MPI_MODE_NOCHECK 1 +#define MPI_COMM_TYPE_SHARED 0 + +#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2))) + +struct ompi_internal_symbols_t { + void *ompi_mpi_comm_world; + void *ompi_mpi_comm_null; + void *ompi_request_null; + void *ompi_mpi_info_null; + void *ompi_mpi_datatype_null; + + void *ompi_mpi_op_max; + void *ompi_mpi_op_min; + void *ompi_mpi_op_sum; + void *ompi_mpi_op_prod; + void *ompi_mpi_op_band; + void *ompi_mpi_op_bor; + void *ompi_mpi_op_bxor; + void 
*ompi_mpi_op_replace; + void *ompi_mpi_op_no_op; + + void *ompi_mpi_char; + void *ompi_mpi_unsigned_char; + void *ompi_mpi_signed_char; + void *ompi_mpi_short; + void *ompi_mpi_unsigned_short; + void *ompi_mpi_int; + void *ompi_mpi_unsigned; + void *ompi_mpi_long; + void *ompi_mpi_unsigned_long; + void *ompi_mpi_long_long_int; + void *ompi_mpi_unsigned_long_long; + void *ompi_mpi_float; + void *ompi_mpi_double; + void *ompi_mpi_long_double; +}; + +extern struct ompi_internal_symbols_t ompi_symbols_; + +#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast (global)) +#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world) +#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null) +#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null) +#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null) +#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null) + +#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max) +#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min) +#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum) +#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod) +#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band) +#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor) +#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor) +#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace) +#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op) + +#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null) +#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char) +#define MPI_UNSIGNED_CHAR 
OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char) +#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char) +#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short) +#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short) +#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int) +#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned) +#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long) +#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long) +#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int) +#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long) +#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float) +#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double) +#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double) + +#endif //!defined(MPI_VERSION) + +#if defined(c_plusplus) || defined(__cplusplus) +} +#endif + +#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP diff --git a/src/atomic_return.cpp b/src/atomic_return.cpp index ee7c5de424..6069d85ac6 100644 --- a/src/atomic_return.cpp +++ b/src/atomic_return.cpp @@ -80,7 +80,7 @@ void init_g_ret(SymmetricHeap* heap_handle, MPI_Comm thread_comm, int num_wg, * Make sure that all processing elements have done this before * continuing. 
*/ - MPI_Barrier(thread_comm); + mpilib_ftable_.Barrier(thread_comm); } } // namespace rocshmem diff --git a/src/atomic_return.hpp b/src/atomic_return.hpp index e6d886c6b9..381ca85ef5 100644 --- a/src/atomic_return.hpp +++ b/src/atomic_return.hpp @@ -26,7 +26,6 @@ #define LIBRARY_SRC_ATOMIC_RETURN_HPP_ #include -#include #include "memory/symmetric_heap.hpp" #include "util.hpp" diff --git a/src/backend_bc.cpp b/src/backend_bc.cpp index beefbeebd8..95ff10d976 100644 --- a/src/backend_bc.cpp +++ b/src/backend_bc.cpp @@ -59,6 +59,7 @@ Backend::Backend(MPI_Comm comm) : heap(comm, nullptr) { Backend::Backend(TcpBootstrap* bootstrap) : heap(MPI_COMM_NULL, bootstrap) { init(); backend_bootstr = bootstrap; + backend_comm = MPI_COMM_NULL; my_pe = bootstrap->getRank(); num_pes = bootstrap->getNranks(); @@ -106,9 +107,9 @@ void Backend::init(void) { void Backend::init_mpi_once(MPI_Comm comm) { if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD; - NET_CHECK(MPI_Comm_dup(comm, &backend_comm)); - NET_CHECK(MPI_Comm_size(backend_comm, &num_pes)); - NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe)); + NET_CHECK(mpilib_ftable_.Comm_dup(comm, &backend_comm)); + NET_CHECK(mpilib_ftable_.Comm_size(backend_comm, &num_pes)); + NET_CHECK(mpilib_ftable_.Comm_rank(backend_comm, &my_pe)); } void Backend::track_ctx(Context* ctx) { @@ -140,7 +141,7 @@ void Backend::destroy_remaining_ctxs() { Backend::~Backend() { CHECK_HIP(hipFree(print_lock)); if (backend_comm != MPI_COMM_NULL) - NET_CHECK(MPI_Comm_free(&backend_comm)); + NET_CHECK(mpilib_ftable_.Comm_free(&backend_comm)); } void Backend::dump_stats() { diff --git a/src/backend_bc.hpp b/src/backend_bc.hpp index da3972c2ea..240277caaa 100644 --- a/src/backend_bc.hpp +++ b/src/backend_bc.hpp @@ -33,12 +33,11 @@ * It is the top-level interface for these resources. 
*/ -#include - #include #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "rocshmem/rocshmem.hpp" +#include "mpi_instance.hpp" #include "backend_type.hpp" #include "ipc_policy.hpp" #include "memory/symmetric_heap.hpp" @@ -225,7 +224,7 @@ class Backend { * @todo document where this is used and try to coalesce this into another * class */ - MPI_Comm backend_comm{MPI_COMM_NULL}; + MPI_Comm backend_comm; /** * @todo document where this is used diff --git a/src/gda/backend_gda.cpp b/src/gda/backend_gda.cpp index afd44127b6..b8ad1e7882 100644 --- a/src/gda/backend_gda.cpp +++ b/src/gda/backend_gda.cpp @@ -24,15 +24,16 @@ #include -#include "backend_gda.hpp" -#include "gda_team.hpp" -#include "util.hpp" -#include "topology.hpp" - #include #include #include +#include "backend_gda.hpp" +#include "mpi_instance.hpp" +#include "gda_team.hpp" +#include "util.hpp" +#include "topology.hpp" + namespace rocshmem { #define NET_CHECK(cmd) { \ @@ -303,7 +304,7 @@ void GDABackend::create_new_team([[maybe_unused]] Team *parent_team, * the pool of available work arrays. 
*/ if (team_comm != MPI_COMM_NULL) { - NET_CHECK(MPI_Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, + NET_CHECK(mpilib_ftable_.Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, MPI_CHAR, MPI_BAND, team_comm)); } else { Allreduce_char_BAND (team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, parent_team); @@ -361,7 +362,7 @@ void GDABackend::dump_backend_stats() { __host__ void GDABackend::global_exit(int status) { if (backend_comm != MPI_COMM_NULL) - MPI_Abort(backend_comm, status); + mpilib_ftable_.Abort(backend_comm, status); else abort(); } @@ -536,7 +537,7 @@ void GDABackend::setup_teams() { void GDABackend::rte_barrier() { if (backend_comm != MPI_COMM_NULL) { - NET_CHECK(MPI_Barrier(backend_comm)); + NET_CHECK(mpilib_ftable_.Barrier(backend_comm)); } else { backend_bootstr->barrier(); } @@ -668,7 +669,7 @@ void GDABackend::exchange_qp_dest_info() { for (int i = 0; i < maximum_num_contexts_ + 1; i++) { if (backend_comm != MPI_COMM_NULL) { - MPI_Alltoall(MPI_IN_PLACE, sizeof(dest_info_t), MPI_CHAR, dest_info.data() + i * num_pes, sizeof(dest_info_t), MPI_CHAR, backend_comm); + mpilib_ftable_.Alltoall(MPI_IN_PLACE, sizeof(dest_info_t), MPI_CHAR, dest_info.data() + i * num_pes, sizeof(dest_info_t), MPI_CHAR, backend_comm); } else { Alltoall_char_inplace(reinterpret_cast(dest_info.data() + i * num_pes), sizeof(dest_info_t), ROCSHMEM_TEAM_WORLD); } @@ -695,7 +696,7 @@ void GDABackend::setup_heap_memory_rkey() { CHECK_HIP(hipStreamSynchronize(stream)); if (backend_comm != MPI_COMM_NULL) - MPI_Allgather(MPI_IN_PLACE, sizeof(uint32_t), MPI_CHAR, host_rkey_cpy, sizeof(uint32_t), MPI_CHAR, backend_comm); + mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(uint32_t), MPI_CHAR, host_rkey_cpy, sizeof(uint32_t), MPI_CHAR, backend_comm); else backend_bootstr->allGather(host_rkey_cpy, sizeof(uint32_t)); diff --git a/src/gda/backend_gda.hpp b/src/gda/backend_gda.hpp index 0401b4df47..2a67d011e0 100644 --- 
a/src/gda/backend_gda.hpp +++ b/src/gda/backend_gda.hpp @@ -66,24 +66,6 @@ struct mlx5dv_funcs_t { int (*init_obj)(struct mlx5dv_obj *obj, uint64_t obj_type); }; -/* Helper Macros for handling dynamic libraries */ -#define PPCAT_NX(prefix, func_name) prefix##func_name -#define PPCAT(prefix, func_name) PPCAT_NX(prefix, func_name) - -#define STRINGIFY_NX(name) #name -#define STRINGIFY(name) STRINGIFY_NX(name) - -#define DLSYM_HELPER(func_struct, prefix, handle, func_name) \ -do { \ - *(void **) (&func_struct.func_name) = dlsym(handle, STRINGIFY(PPCAT(prefix, func_name))); \ - if (!func_struct.func_name) { \ - DPRINTF("Failed to find function %s \n", STRINGIFY(PPCAT(prefix, func_name))); \ - dlclose(handle); \ - handle = nullptr; \ - return ROCSHMEM_ERROR; \ - } \ -} while (0) - namespace rocshmem { class GDAContext; diff --git a/src/gda/context_gda_host.cpp b/src/gda/context_gda_host.cpp index ddc2536d11..c6ffbacb14 100644 --- a/src/gda/context_gda_host.cpp +++ b/src/gda/context_gda_host.cpp @@ -22,11 +22,9 @@ * IN THE SOFTWARE. 
*****************************************************************************/ -#include "context_gda_host.hpp" - -#include #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) +#include "context_gda_host.hpp" #include "backend_type.hpp" #include "context_incl.hpp" #include "backend_gda.hpp" diff --git a/src/gda/gda_team.cpp b/src/gda/gda_team.cpp index 9ca6066450..64bbe38fdb 100644 --- a/src/gda/gda_team.cpp +++ b/src/gda/gda_team.cpp @@ -27,6 +27,7 @@ #include "constants.hpp" #include "backend_type.hpp" #include "backend_gda.hpp" +#include "rocshmem/rocshmem_mpi.hpp" namespace rocshmem { diff --git a/src/host/host.cpp b/src/host/host.cpp index 4cdc5f33a8..ede97eb275 100644 --- a/src/host/host.cpp +++ b/src/host/host.cpp @@ -24,8 +24,6 @@ #include "host.hpp" -#include - #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "host_helpers.hpp" #include "memory/window_info.hpp" @@ -98,9 +96,9 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy, * Duplicate a communicator from roc_shem's comm * world for the host interface */ - MPI_Comm_dup(rocshmem_comm, &host_comm_world_); - MPI_Comm_rank(host_comm_world_, &my_pe_); - MPI_Comm_rank(host_comm_world_, &num_pes_); + mpilib_ftable_.Comm_dup(rocshmem_comm, &host_comm_world_); + mpilib_ftable_.Comm_rank(host_comm_world_, &my_pe_); + mpilib_ftable_.Comm_size(host_comm_world_, &num_pes_); /* * Create an MPI window on the HDP so that it can be flushed @@ -136,18 +134,18 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy, #if defined USE_HDP_FLUSH __host__ void HostInterface::create_hdp_window() { - MPI_Win_create(hdp_policy_->get_hdp_flush_ptr(), - sizeof(unsigned int), /* size of window */ - sizeof(unsigned int), /* displacement */ - MPI_INFO_NULL, host_comm_world_, &hdp_win); - + mpilib_ftable_.Win_create(hdp_policy_->get_hdp_flush_ptr(), + sizeof(unsigned int), /* size of window */ + sizeof(unsigned int), /* displacement */ + MPI_INFO_NULL, host_comm_world_, 
&hdp_win); + /* * Start a shared access epoch on windows of all ranks, * and let the library there is no need to check for * lock exclusivity during operations on this window * (MPI_MODE_NOCHECK). */ - MPI_Win_lock_all(MPI_MODE_NOCHECK, hdp_win); + mpilib_ftable_.Win_lock_all(MPI_MODE_NOCHECK, hdp_win); } #endif // USE_HDP_FLUSH @@ -188,9 +186,9 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy, __host__ HostInterface::~HostInterface() { #if defined USE_HDP_FLUSH - MPI_Win_unlock_all(hdp_win); + mpilib_ftable_.Win_unlock_all(hdp_win); - MPI_Win_free(&hdp_win); + mpilib_ftable_.Win_free(&hdp_win); #endif // USE_HDP_FLUSH /* Detroy the pool of contexts */ @@ -203,7 +201,7 @@ __host__ HostInterface::~HostInterface() { } if (host_comm_world_ != MPI_COMM_NULL) { - MPI_Comm_free(&host_comm_world_); + mpilib_ftable_.Comm_free(&host_comm_world_); } } @@ -236,7 +234,7 @@ __host__ void HostInterface::putmem(void* dest, const void* source, } initiate_put(dest, source, nelems, pe, window_info_mpi); - MPI_Win_flush_local(pe, window_info_mpi->get_win()); + mpilib_ftable_.Win_flush_local(pe, window_info_mpi->get_win()); } __host__ void HostInterface::getmem(void* dest, const void* source, @@ -248,7 +246,7 @@ __host__ void HostInterface::getmem(void* dest, const void* source, } initiate_get(dest, source, nelems, pe, window_info_mpi); - MPI_Win_flush_local(pe, window_info_mpi->get_win()); + mpilib_ftable_.Win_flush_local(pe, window_info_mpi->get_win()); /* * Flush local HDP to ensure that the NIC's write @@ -295,8 +293,8 @@ __host__ void HostInterface::quiet(WindowInfo* window_info) { __host__ void HostInterface::sync_all(WindowInfo* window_info) { WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); - if (!window_info_mpi) { - MPI_Win_sync(window_info_mpi->get_win()); + if (window_info_mpi) { + mpilib_ftable_.Win_sync(window_info_mpi->get_win()); hdp_policy_->hdp_flush(); /* @@ -305,7 +303,7 @@ __host__ void HostInterface::sync_all(WindowInfo* window_info) { * 
participating. */ - MPI_Barrier(host_comm_world_); + mpilib_ftable_.Barrier(host_comm_world_); } else { hdp_policy_->hdp_flush(); host_bootstrap_->barrier(); @@ -325,7 +323,7 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) { */ hdp_policy_->hdp_flush(); - MPI_Barrier(host_comm_world_); + mpilib_ftable_.Barrier(host_comm_world_); } else { // Probably not required hdp_policy_->hdp_flush(); @@ -337,7 +335,7 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) { __host__ void HostInterface::barrier_for_sync() { if (host_comm_world_ != MPI_COMM_NULL) { - MPI_Barrier(host_comm_world_); + mpilib_ftable_.Barrier(host_comm_world_); } else { host_bootstrap_->barrier(); } diff --git a/src/host/host.hpp b/src/host/host.hpp index 9c777d2e0b..d7ea9ebc0f 100644 --- a/src/host/host.hpp +++ b/src/host/host.hpp @@ -34,8 +34,6 @@ * any backend type. */ -#include - #include #include "rocshmem/rocshmem.hpp" @@ -43,6 +41,7 @@ #include "memory/symmetric_heap.hpp" #include "memory/window_info.hpp" #include "bootstrap/bootstrap.hpp" +#include "mpi_instance.hpp" namespace rocshmem { @@ -268,17 +267,17 @@ class HostInterface { if (i == my_pe_) { continue; } - MPI_Put(&flush_val, 1, MPI_UNSIGNED, i, 0, 1, MPI_UNSIGNED, hdp_win); + mpilib_ftable_.Put(&flush_val, 1, MPI_UNSIGNED, i, 0, 1, MPI_UNSIGNED, hdp_win); } - MPI_Win_flush_all(hdp_win); + mpilib_ftable_.Win_flush_all(hdp_win); #endif // USE_HDP_FLUSH } __host__ void flush_remote_hdp(int pe) { #if defined USE_HDP_FLUSH unsigned flush_val{HdpPolicy::HDP_FLUSH_VAL}; - MPI_Put(&flush_val, 1, MPI_UNSIGNED, pe, 0, 1, MPI_UNSIGNED, hdp_win); - MPI_Win_flush(pe, hdp_win); + mpilib_ftable_.Put(&flush_val, 1, MPI_UNSIGNED, pe, 0, 1, MPI_UNSIGNED, hdp_win); + mpilib_ftable_.Win_flush(pe, hdp_win); #endif // USE_HDP_FLUSH } diff --git a/src/host/host_helpers.hpp b/src/host/host_helpers.hpp index 4490c7a9da..45ea7dc612 100644 --- a/src/host/host_helpers.hpp +++ b/src/host/host_helpers.hpp @@ -27,6 +27,7 @@ 
#include "host.hpp" #include "memory/window_info.hpp" +#include "mpi_instance.hpp" #include @@ -42,15 +43,15 @@ __host__ inline MPI_Aint HostInterface::compute_offset( MPI_Aint dest_disp{}; MPI_Aint start_disp{}; - MPI_Get_address(dest, &dest_disp); - MPI_Get_address(win_start, &start_disp); + mpilib_ftable_.Get_address(dest, &dest_disp); + mpilib_ftable_.Get_address(win_start, &start_disp); return MPI_Aint_diff(dest_disp, start_disp); } __host__ inline void HostInterface::complete_all(MPI_Win win) { - MPI_Win_flush_all(win); /* RMA operations */ - MPI_Win_sync(win); /* memory stores */ + mpilib_ftable_.Win_flush_all(win); /* RMA operations */ + mpilib_ftable_.Win_sync(win); /* memory stores */ } __host__ inline void HostInterface::initiate_put(void* dest, const void* source, @@ -74,7 +75,7 @@ __host__ inline void HostInterface::initiate_put(void* dest, const void* source, hdp_policy_->hdp_flush(); /* Offload remote write operation to MPI */ - MPI_Put(source, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win); + mpilib_ftable_.Put(source, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win); } __host__ inline void HostInterface::initiate_get(void* dest, const void* source, @@ -88,7 +89,7 @@ __host__ inline void HostInterface::initiate_get(void* dest, const void* source, MPI_Aint offset = compute_offset(source, win_start, win_end); /* Offload remote fetch operation to MPI */ - MPI_Get(dest, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win); + mpilib_ftable_.Get(dest, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win); } } // namespace rocshmem diff --git a/src/host/host_templates.hpp b/src/host/host_templates.hpp index f95ef62e57..e492094f20 100644 --- a/src/host/host_templates.hpp +++ b/src/host/host_templates.hpp @@ -74,7 +74,7 @@ __host__ T HostInterface::g(const T* source, int pe, WindowInfo* window_info) { */ getmem_nbi(&ret, source, sizeof(T), pe, window_info); - MPI_Win_flush_local(pe, window_info_mpi->get_win()); + mpilib_ftable_.Win_flush_local(pe, 
window_info_mpi->get_win()); return ret; } @@ -101,7 +101,7 @@ __host__ MPI_Comm HostInterface::get_mpi_comm(int pe_start, int log_pe_stride, * First, check to see if the active set is the same as COMM_WORLD */ int comm_world_size{-1}; - MPI_Comm_size(host_comm_world_, &comm_world_size); + mpilib_ftable_.Comm_size(host_comm_world_, &comm_world_size); if (pe_start == 0 && log_pe_stride == 0 && pe_size == comm_world_size) { /* @@ -139,12 +139,12 @@ __host__ MPI_Comm HostInterface::get_mpi_comm(int pe_start, int log_pe_stride, MPI_Group comm_world_group{}; MPI_Group active_set_group{}; - MPI_Comm_group(host_comm_world_, &comm_world_group); + mpilib_ftable_.Comm_group(host_comm_world_, &comm_world_group); - MPI_Group_incl(comm_world_group, pe_size, active_set_ranks.data(), - &active_set_group); + mpilib_ftable_.Group_incl(comm_world_group, pe_size, active_set_ranks.data(), + &active_set_group); - MPI_Comm_create_group(host_comm_world_, active_set_group, 0, + mpilib_ftable_.Comm_create_group(host_comm_world_, active_set_group, 0, &active_set_comm); /* @@ -168,7 +168,7 @@ __host__ void HostInterface::broadcast_internal(MPI_Comm mpi_comm, T* dest, */ int active_set_rank{-1}; void* buffer{nullptr}; - MPI_Comm_rank(mpi_comm, &active_set_rank); + mpilib_ftable_.Comm_rank(mpi_comm, &active_set_rank); if (pe_root == active_set_rank) { buffer = const_cast(source); } else { @@ -183,7 +183,7 @@ __host__ void HostInterface::broadcast_internal(MPI_Comm mpi_comm, T* dest, /* * Offload the broadcast to MPI */ - MPI_Bcast(buffer, nelems * sizeof(T), MPI_CHAR, pe_root, mpi_comm); + mpilib_ftable_.Bcast(buffer, nelems * sizeof(T), MPI_CHAR, pe_root, mpi_comm); return; } @@ -312,9 +312,9 @@ __host__ T HostInterface::amo_fetch_add(void* dst, T value, int pe, T ret{}; MPI_Win win{window_info_mpi->get_win()}; MPI_Datatype mpi_type{get_mpi_type()}; - MPI_Fetch_and_op(&value, &ret, mpi_type, pe, offset, MPI_SUM, win); + mpilib_ftable_.Fetch_and_op(&value, &ret, mpi_type, pe, offset, MPI_SUM, 
win); - MPI_Win_flush_local(pe, win); + mpilib_ftable_.Win_flush_local(pe, win); return ret; } @@ -341,9 +341,9 @@ __host__ T HostInterface::amo_fetch_cas(void* dst, T value, T cond, int pe, T ret{}; MPI_Win win{window_info_mpi->get_win()}; MPI_Datatype mpi_type{get_mpi_type()}; - MPI_Compare_and_swap(&value, &cond, &ret, mpi_type, pe, offset, win); + mpilib_ftable_.Compare_and_swap(&value, &cond, &ret, mpi_type, pe, offset, win); - MPI_Win_flush_local(pe, win); + mpilib_ftable_.Win_flush_local(pe, win); return ret; } @@ -368,8 +368,8 @@ __host__ void HostInterface::to_all_internal(MPI_Comm mpi_comm, T* dest, /* * Offload the allreduce to MPI */ - MPI_Allreduce((dest == source) ? MPI_IN_PLACE : send_buf, recv_buf, nreduce, - mpi_type, mpi_op, mpi_comm); + mpilib_ftable_.Allreduce((dest == source) ? MPI_IN_PLACE : send_buf, recv_buf, nreduce, + mpi_type, mpi_op, mpi_comm); return; } @@ -453,9 +453,9 @@ __host__ inline int HostInterface::test_and_compare(MPI_Aint offset, */ hdp_policy_->hdp_flush(); - MPI_Fetch_and_op(nullptr, // because no operation happening here - &fetched_val, mpi_type, my_pe_, offset, MPI_NO_OP, win); - MPI_Win_flush_local(my_pe_, win); + mpilib_ftable_.Fetch_and_op(nullptr, // because no operation happening here + &fetched_val, mpi_type, my_pe_, offset, MPI_NO_OP, win); + mpilib_ftable_.Win_flush_local(my_pe_, win); /* * Compare based on the operation diff --git a/src/ipc/backend_ipc.cpp b/src/ipc/backend_ipc.cpp index 2d4dfbdf58..e69bce40c8 100644 --- a/src/ipc/backend_ipc.cpp +++ b/src/ipc/backend_ipc.cpp @@ -24,13 +24,14 @@ #include -#include "backend_ipc.hpp" -#include "ipc_team.hpp" - #include #include #include +#include "backend_ipc.hpp" +#include "mpi_instance.hpp" +#include "ipc_team.hpp" + namespace rocshmem { #define NET_CHECK(cmd) \ @@ -249,8 +250,8 @@ void IPCBackend::create_new_team([[maybe_unused]] Team *parent_team, * the pool of available work arrays. 
*/ if (team_comm != MPI_COMM_NULL) { - NET_CHECK(MPI_Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, - MPI_CHAR, MPI_BAND, team_comm)); + NET_CHECK(mpilib_ftable_.Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, + MPI_CHAR, MPI_BAND, team_comm)); } else { Allreduce_char_BAND (team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, parent_team); } @@ -321,7 +322,7 @@ void IPCBackend::initIPC(TcpBootstrap *bootstr) { void IPCBackend::global_exit(int status) { if (backend_comm != MPI_COMM_NULL) - MPI_Abort(backend_comm, status); + mpilib_ftable_.Abort(backend_comm, status); else abort(); } @@ -388,8 +389,8 @@ void IPCBackend::setup_wrk_sync_buffers() { * all-to-all exchange with each PE to share the IPC handles. */ if (backend_comm != MPI_COMM_NULL) { - MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, - ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm); + mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, + ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm); } else { assert (backend_bootstr != nullptr); backend_bootstr->allGather(ipc_handle, sizeof(hipIpcMemHandle_t)); @@ -460,7 +461,7 @@ void IPCBackend::rocshmem_collective_init() { * continuing. */ if (backend_comm != MPI_COMM_NULL) { - NET_CHECK(MPI_Barrier(backend_comm)); + NET_CHECK(mpilib_ftable_.Barrier(backend_comm)); } else { backend_bootstr->barrier(); } @@ -551,7 +552,7 @@ void IPCBackend::teams_init() { * continuing. 
*/ if (backend_comm != MPI_COMM_NULL) { - NET_CHECK(MPI_Barrier(backend_comm)); + NET_CHECK(mpilib_ftable_.Barrier(backend_comm)); } else { backend_bootstr->barrier(); } diff --git a/src/ipc/context_ipc_host.cpp b/src/ipc/context_ipc_host.cpp index b549f8c426..4efd477b3e 100644 --- a/src/ipc/context_ipc_host.cpp +++ b/src/ipc/context_ipc_host.cpp @@ -24,8 +24,6 @@ #include "context_ipc_host.hpp" -#include - #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "backend_type.hpp" #include "context_incl.hpp" diff --git a/src/ipc_policy.cpp b/src/ipc_policy.cpp index cfc169b162..5d688385b8 100644 --- a/src/ipc_policy.cpp +++ b/src/ipc_policy.cpp @@ -24,8 +24,6 @@ #include "ipc_policy.hpp" -#include - #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "backend_bc.hpp" #include "context_incl.hpp" @@ -39,20 +37,20 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, * Create an MPI communicator that deals only with local processes. */ MPI_Comm shmcomm; - MPI_Comm_split_type(thread_comm, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, - &shmcomm); + mpilib_ftable_.Comm_split_type(thread_comm, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, + &shmcomm); /* * Figure out how many local process there are. */ int Shm_size; - MPI_Comm_size(shmcomm, &Shm_size); + mpilib_ftable_.Comm_size(shmcomm, &Shm_size); shm_size = Shm_size; /* * Figure out how this process' rank among local processes. */ - MPI_Comm_rank(shmcomm, &shm_rank); + mpilib_ftable_.Comm_rank(shmcomm, &shm_rank); /* * Allocate a host-side c-array to hold the IPC handles. @@ -73,8 +71,8 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, * Do an all-to-all exchange with each local processing element to * share the symmetric heap IPC handles. 
*/ - MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, - vec_ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, shmcomm); + mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, + vec_ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, shmcomm); /* * Allocate device-side array to hold the IPC symmetric heap base @@ -114,16 +112,17 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, CHECK_HIP(hipMalloc(reinterpret_cast(&pes_with_ipc_avail), shm_size * sizeof(int))); - MPI_Group thread_grp, shm_grp; - MPI_Comm_group(thread_comm, &thread_grp); - MPI_Comm_group(shmcomm, &shm_grp); + MPI_Group thread_grp; + MPI_Group shm_grp; + mpilib_ftable_.Comm_group(thread_comm, &thread_grp); + mpilib_ftable_.Comm_group(shmcomm, &shm_grp); int *seqranks = new int[shm_size]; for(int i = 0; i < shm_size; i++) seqranks[i] = i; - MPI_Group_translate_ranks(shm_grp, shm_size, seqranks, thread_grp, pes_with_ipc_avail); + mpilib_ftable_.Group_translate_ranks(shm_grp, shm_size, seqranks, thread_grp, pes_with_ipc_avail); delete [] seqranks; - MPI_Group_free(&shm_grp); - MPI_Group_free(&thread_grp); + mpilib_ftable_.Group_free(&shm_grp); + mpilib_ftable_.Group_free(&thread_grp); } } diff --git a/src/ipc_policy.hpp b/src/ipc_policy.hpp index 83bcefd4a6..10bcb12470 100644 --- a/src/ipc_policy.hpp +++ b/src/ipc_policy.hpp @@ -26,12 +26,12 @@ #define LIBRARY_SRC_IPC_POLICY_HPP_ #include -#include #include #include #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) +#include "mpi_instance.hpp" #include "memory/hip_allocator.hpp" #include "util.hpp" #include "bootstrap/bootstrap.hpp" diff --git a/src/memory/remote_heap_info.hpp b/src/memory/remote_heap_info.hpp index 29286d6dac..3ff25d2790 100644 --- a/src/memory/remote_heap_info.hpp +++ b/src/memory/remote_heap_info.hpp @@ -26,10 +26,10 @@ #define LIBRARY_SRC_MEMORY_REMOTE_HEAP_INFO_HPP_ #include -#include #include +#include "mpi_instance.hpp" #include 
"hip_allocator.hpp" #include "window_info.hpp" #include "bootstrap/bootstrap.hpp" @@ -53,10 +53,10 @@ class CommunicatorMPI { * @brief Primary constructor */ CommunicatorMPI(char* heap_base, size_t heap_size, - MPI_Comm comm = MPI_COMM_WORLD) + MPI_Comm comm) : comm_{comm} { - MPI_Comm_rank(comm_, &my_pe_); - MPI_Comm_size(comm_, &num_pes_); + mpilib_ftable_.Comm_rank(comm_, &my_pe_); + mpilib_ftable_.Comm_size(comm_, &num_pes_); heap_window_info_ = WindowInfoMPI(comm_, heap_base, heap_size); } @@ -78,14 +78,14 @@ class CommunicatorMPI { /** * @brief Performs MPI_Barrier */ - void barrier() { MPI_Barrier(comm_); } + void barrier() { mpilib_ftable_.Barrier(comm_); } /** * @brief Performs MPI_Allgather on recvbuf */ void allgather(void* recvbuf) { - MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, - sizeof(void*), MPI_CHAR, comm_); + mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, + sizeof(void*), MPI_CHAR, comm_); } /** @@ -206,7 +206,7 @@ class RemoteHeapInfo { * @param[in] The total number of processing elements */ RemoteHeapInfo(char* heap_ptr, size_t heap_size, - MPI_Comm comm = MPI_COMM_WORLD) + MPI_Comm comm) : communicator_{heap_ptr, heap_size, comm} { init(heap_ptr, heap_size); } diff --git a/src/memory/symmetric_heap.hpp b/src/memory/symmetric_heap.hpp index c823918c3b..08bdaeb1f4 100644 --- a/src/memory/symmetric_heap.hpp +++ b/src/memory/symmetric_heap.hpp @@ -83,14 +83,14 @@ private: class SymmetricHeap { public: - SymmetricHeap(MPI_Comm comm = MPI_COMM_NULL, TcpBootstrap* bootstrap = nullptr) { + SymmetricHeap(MPI_Comm comm, TcpBootstrap* bootstrap = nullptr) { if (comm != MPI_COMM_NULL) { remote_heap_info_ = new RemoteHeapInfoMPI(single_heap_.get_base_ptr(), - single_heap_.get_size(), comm); + single_heap_.get_size(), comm); } else { remote_heap_info_ = new RemoteHeapInfoTCP(single_heap_.get_base_ptr(), - single_heap_.get_size(), bootstrap); + single_heap_.get_size(), bootstrap); } } /** diff --git 
a/src/memory/window_info.hpp b/src/memory/window_info.hpp index 82b66a630b..6499468eef 100644 --- a/src/memory/window_info.hpp +++ b/src/memory/window_info.hpp @@ -25,10 +25,9 @@ #ifndef LIBRARY_SRC_MEMORY_WINDOW_INFO_HPP_ #define LIBRARY_SRC_MEMORY_WINDOW_INFO_HPP_ -#include - #include #include +#include "mpi_instance.hpp" /** * @file window_info.hpp @@ -155,8 +154,8 @@ class WindowInfoMPI: public WindowInfo { win_end_ = reinterpret_cast(start) + size; up_win_ = std::unique_ptr(new MPI_Win); - MPI_Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get()); - MPI_Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get()); + mpilib_ftable_.Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get()); + mpilib_ftable_.Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get()); } /** @@ -164,8 +163,8 @@ class WindowInfoMPI: public WindowInfo { */ ~WindowInfoMPI() { if (up_win_) { - MPI_Win_unlock_all(*up_win_.get()); - MPI_Win_free(up_win_.get()); + mpilib_ftable_.Win_unlock_all(*up_win_.get()); + mpilib_ftable_.Win_free(up_win_.get()); } } @@ -222,9 +221,9 @@ class WindowInfoMPI: public WindowInfo { reinterpret_cast(win_end_)); MPI_Aint dest_disp; - MPI_Get_address(dest, &dest_disp); + mpilib_ftable_.Get_address(dest, &dest_disp); MPI_Aint start_disp; - MPI_Get_address(win_start_, &start_disp); + mpilib_ftable_.Get_address(win_start_, &start_disp); return static_cast(MPI_Aint_diff(dest_disp, start_disp)); } diff --git a/src/mpi_instance.cpp b/src/mpi_instance.cpp index d863b003af..b3058c51c9 100644 --- a/src/mpi_instance.cpp +++ b/src/mpi_instance.cpp @@ -22,33 +22,139 @@ * IN THE SOFTWARE. 
*****************************************************************************/ +#include + +#include "rocshmem/rocshmem.hpp" #include "mpi_instance.hpp" +#include "util.hpp" + +#if !defined(HAVE_EXTERNAL_MPI) +// Open MPI specific symbols +struct ompi_internal_symbols_t ompi_symbols_; +#endif //!defined(HAVE_EXTERNAL_MPI) namespace rocshmem { +static void* mpilib_handle_{nullptr}; +struct mpilib_funcs_t mpilib_ftable_; + +int MPIInstance::mpilib_dl_init() { + if (mpilib_handle_ != nullptr) + return ROCSHMEM_SUCCESS; + + mpilib_handle_ = dlopen("libmpi.so", RTLD_NOW); + if (!mpilib_handle_) { + printf("Could not open libmpi.so. Returning\n"); + return ROCSHMEM_ERROR; + } + + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Init_thread); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Initialized); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Finalize); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Finalized); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_rank); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_size); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Abort); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Get_address); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Type_size); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Iprobe); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Testsome); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_split); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_split_type); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_group); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_create_group); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_dup); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_free); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Group_free); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Group_translate_ranks); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, 
Group_incl); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Allgather); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Alltoall); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Allreduce); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Bcast); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Barrier); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Iallreduce); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Ibarrier); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_create); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_free); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush_all); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush_local); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock_all); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_unlock); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_unlock_all); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_sync); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Get); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Rget); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Put); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Rput); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Compare_and_swap); + DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Fetch_and_op); + +#if !defined(HAVE_EXTERNAL_MPI) + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_comm_world); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_comm_null); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_datatype_null); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_request_null); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_info_null); + + DLSYM_VAR_HELPER(ompi_symbols_, 
mpilib_handle_, ompi_mpi_op_max); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_min); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_sum); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_prod); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_band); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_bor); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_bxor); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_replace); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_no_op); + + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_char); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_char); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_signed_char); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_short); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_short); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_int); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_long); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long_long_int); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_long_long); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_float); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_double); + DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long_double); +#endif //!defined(HAVE_EXTERNAL_MPI) + + return ROCSHMEM_SUCCESS; +} + +void MPIInstance::mpilib_dl_close() { + if (mpilib_handle_ != nullptr) { + dlclose(mpilib_handle_); + mpilib_handle_ = nullptr; + } +} + MPIInstance::MPIInstance(MPI_Comm comm) { int is_init{0}; - MPI_Initialized(&is_init); + mpilib_ftable_.Initialized(&is_init); if (!is_init) { int provided; - MPI_Init_thread(nullptr, nullptr, 
MPI_THREAD_MULTIPLE, &provided); + mpilib_ftable_.Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); init_in_this_class = 1; } - if (comm == MPI_COMM_NULL) { - comm = MPI_COMM_WORLD; - } - - MPI_Comm_size(comm, &nprocs_); - MPI_Comm_rank(comm, &my_rank_); + mpilib_ftable_.Comm_size(comm, &nprocs_); + mpilib_ftable_.Comm_rank(comm, &my_rank_); } MPIInstance::~MPIInstance() { int finalized{0}; - MPI_Finalized(&finalized); + mpilib_ftable_.Finalized(&finalized); if (!finalized && init_in_this_class) { - MPI_Finalize(); + mpilib_ftable_.Finalize(); } } diff --git a/src/mpi_instance.hpp b/src/mpi_instance.hpp index 5192ed1278..c18ff1af4b 100644 --- a/src/mpi_instance.hpp +++ b/src/mpi_instance.hpp @@ -25,8 +25,8 @@ #ifndef LIBRARY_SRC_MPI_INSTANCE_HPP_ #define LIBRARY_SRC_MPI_INSTANCE_HPP_ -#include - +#include +#include #include /** @@ -37,6 +37,63 @@ namespace rocshmem { +struct mpilib_funcs_t { + int (*Init_thread)(int *argc, char ***argv, int required, int *provided); + int (*Initialized)(int *flag); + int (*Finalize)(void); + int (*Finalized)(int *flag); + int (*Comm_rank)(MPI_Comm comm, int *rank); + int (*Comm_size)(MPI_Comm comm, int *size); + int (*Abort)(MPI_Comm comm, int errorcode); + int (*Get_address)(const void *location, MPI_Aint *address); + int (*Type_size)(MPI_Datatype type, int *size); + int (*Iprobe)(int source, int tag, MPI_Comm comm, int *flag, MPI_Status *status); + int (*Testsome)(int incount, MPI_Request array_of_requests[], int *outcount, int array_of_indices[], + MPI_Status array_of_statuses[]); + int (*Comm_split)(MPI_Comm comm, int color, int key, MPI_Comm *newcomm); + int (*Comm_split_type)(MPI_Comm comm, int split_type, int key, MPI_Info info, MPI_Comm *newcomm); + int (*Comm_group)(MPI_Comm comm, MPI_Group *group); + int (*Comm_create_group)(MPI_Comm comm, MPI_Group group, int tag, MPI_Comm *newcomm); + int (*Comm_dup)(MPI_Comm comm, MPI_Comm *newcomm); + int (*Comm_free)(MPI_Comm *comm); + int (*Group_free)(MPI_Group 
*group); + int (*Group_translate_ranks)(MPI_Group group1, int n, const int ranks1[], MPI_Group group2, int ranks2[]); + int (*Group_incl)(MPI_Group group, int n, const int ranks[], MPI_Group *newgroup); + int (*Allgather)(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm); + int (*Allreduce)(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, + MPI_Op op, MPI_Comm comm); + int (*Alltoall)(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, + MPI_Datatype recvtype, MPI_Comm comm); + int (*Bcast)(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm); + int (*Barrier)(MPI_Comm comm); + int (*Iallreduce)(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, + MPI_Op op, MPI_Comm comm, MPI_Request *request); + int (*Ibarrier)(MPI_Comm comm, MPI_Request *request); + int (*Win_create)(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm, MPI_Win *win); + int (*Win_free)(MPI_Win *win); + int (*Win_flush)(int rank, MPI_Win win); /* target rank first, as in MPI_Win_flush */ + int (*Win_flush_all)(MPI_Win win); + int (*Win_flush_local)(int rank, MPI_Win win); + int (*Win_lock)(int lock_type, int rank, int mpi_assert, MPI_Win win); + int (*Win_lock_all)(int mpi_assert, MPI_Win win); + int (*Win_sync)(MPI_Win win); + int (*Win_unlock)(int rank, MPI_Win win); + int (*Win_unlock_all)(MPI_Win win); + int (*Get)(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, + MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win); + int (*Rget)(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, + int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request *request); + int (*Put)(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, + int target_count, MPI_Datatype target_datatype, 
MPI_Win win); + int (*Rput)(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, + int target_cout, MPI_Datatype target_datatype, MPI_Win win, MPI_Request *request); + int (*Compare_and_swap)(const void *origin_addr, const void *compare_addr, void *result_addr, MPI_Datatype datatype, int target_rank, + MPI_Aint target_disp, MPI_Win win); + int (*Fetch_and_op)(const void *origin_addr, void *result_addr, MPI_Datatype datatype, + int target_rank, MPI_Aint target_disp, MPI_Op op, MPI_Win win); +}; +extern struct mpilib_funcs_t mpilib_ftable_; + class MPIInstance { public: /** @@ -63,6 +120,19 @@ class MPIInstance { */ int get_nprocs(); + /** + * @brief dlopen the MPI library and set + * function pointers. + * @return ROCSHMEM_SUCCESS on success, + * ROCSHMEM_ERROR otherwise. + */ + static int mpilib_dl_init(void); + + /** + * @brief dlclose the MPI library + */ + static void mpilib_dl_close(void); + private: /** * @brief My MPI rank identifier diff --git a/src/reverse_offload/context_ro_host.cpp b/src/reverse_offload/context_ro_host.cpp index 27671ac76d..6f2a034e7a 100644 --- a/src/reverse_offload/context_ro_host.cpp +++ b/src/reverse_offload/context_ro_host.cpp @@ -24,7 +24,6 @@ #include "context_ro_host.hpp" -#include #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "backend_type.hpp" diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index 27ce188e3c..41418cfa6f 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -50,14 +50,14 @@ MPITransport::MPITransport(MPI_Comm comm, Queue* q) assert(comm != MPI_COMM_NULL); - NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world)); - NET_CHECK(MPI_Comm_size(ro_net_comm_world, &num_pes)); - NET_CHECK(MPI_Comm_rank(ro_net_comm_world, &my_pe)); + NET_CHECK(mpilib_ftable_.Comm_dup(comm, &ro_net_comm_world)); + NET_CHECK(mpilib_ftable_.Comm_size(ro_net_comm_world, 
&num_pes)); + NET_CHECK(mpilib_ftable_.Comm_rank(ro_net_comm_world, &my_pe)); } MPITransport::~MPITransport() { if (ro_net_comm_world != MPI_COMM_NULL) - NET_CHECK(MPI_Comm_free(&ro_net_comm_world)); + NET_CHECK(mpilib_ftable_.Comm_free(&ro_net_comm_world)); } void MPITransport::threadProgressEngine() { @@ -267,13 +267,13 @@ void MPITransport::createNewTeam(ROBackend *backend, Team *parent_team, } void MPITransport::global_exit(int status) { - MPI_Abort(ro_net_comm_world, status); + mpilib_ftable_.Abort(ro_net_comm_world, status); } void MPITransport::barrier(int contextId, volatile char *status, bool blocking, MPI_Comm team, bool do_quiet) { MPI_Request request{}; - NET_CHECK(MPI_Ibarrier(team, &request)); + NET_CHECK(mpilib_ftable_.Ibarrier(team, &request)); if (do_quiet) { requests.push_back({request, {nullptr, contextId, false}}); @@ -351,10 +351,10 @@ void MPITransport::team_reduction(void *dst, void *src, int size, int win_id, MPI_Comm comm{team}; if (dst == src) { - NET_CHECK(MPI_Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm, + NET_CHECK(mpilib_ftable_.Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm, &request)); } else { - NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request)); + NET_CHECK(mpilib_ftable_.Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request)); } requests.push_back({request, {status, contextId, blocking}}); @@ -370,25 +370,25 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id, MPI_Comm comm{team}; int rank{}, pe_size{}; - NET_CHECK(MPI_Comm_rank(comm, &rank)); - NET_CHECK(MPI_Comm_size(comm, &pe_size)); + NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank)); + NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size)); MPI_Group grp{}, world_grp{}; - NET_CHECK(MPI_Comm_group(comm, &grp)); - NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp)); + NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp)); + NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp)); 
std::vector ranks(pe_size); std::vector world_ranks(pe_size); for (int i = 0; i < pe_size; i++) ranks[i] = i; - NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); + NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); MPI_Datatype mpi_type{convertType(type)}; MPI_Request req; if (rank != root){ - NET_CHECK(MPI_Rget(reinterpret_cast(dst), size, mpi_type, world_ranks[root], + NET_CHECK(mpilib_ftable_.Rget(reinterpret_cast(dst), size, mpi_type, world_ranks[root], bp->heap_window_info[win_id]->get_offset(reinterpret_cast(src)), size, mpi_type, bp->heap_window_info[win_id]->get_win(), &req)); @@ -396,7 +396,7 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id, outstanding[contextId]++; } - NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win())); barrier(contextId, nullptr, false, comm, false); quiet(contextId, status); } @@ -409,22 +409,22 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id, MPI_Comm comm{team}; int rank{}, pe_size{}; - NET_CHECK(MPI_Comm_rank(comm, &rank)); - NET_CHECK(MPI_Comm_size(comm, &pe_size)); + NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank)); + NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size)); MPI_Group grp{}, world_grp{}; - NET_CHECK(MPI_Comm_group(comm, &grp)); - NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp)); + NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp)); + NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp)); std::vector ranks(pe_size); std::vector world_ranks(pe_size); for (int i = 0; i < pe_size; i++) ranks[i] = i; - NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); + NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); MPI_Datatype 
mpi_type{convertType(type)}; int type_size{}; - NET_CHECK(MPI_Type_size(mpi_type, &type_size)); + NET_CHECK(mpilib_ftable_.Type_size(mpi_type, &type_size)); if (dst == src) { fprintf(stderr, "IN_PLACE option not support for alltoall in the RO rocSHMEM conduit\n"); @@ -436,7 +436,7 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id, int target = (rank + i) % pe_size; int src_offset = target * type_size * size; int dst_offset = rank * type_size * size; - NET_CHECK(MPI_Rput(reinterpret_cast(src) + src_offset, size, + NET_CHECK(mpilib_ftable_.Rput(reinterpret_cast(src) + src_offset, size, mpi_type, world_ranks[target], bp->heap_window_info[win_id]->get_offset(reinterpret_cast(dst) + dst_offset), size, mpi_type, bp->heap_window_info[win_id]->get_win(), @@ -445,7 +445,7 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id, outstanding[contextId]++; } - NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win())); quiet(contextId, status); } @@ -457,23 +457,23 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id, MPI_Comm comm{team}; int rank{}, pe_size{}; - NET_CHECK(MPI_Comm_rank(comm, &rank)); - NET_CHECK(MPI_Comm_size(comm, &pe_size)); + NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank)); + NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size)); MPI_Group grp{}, world_grp{}; - NET_CHECK(MPI_Comm_group(comm, &grp)); - NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp)); + NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp)); + NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp)); std::vector ranks(pe_size); std::vector world_ranks(pe_size); for (int i = 0; i < pe_size; i++) ranks[i] = i; - NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); + NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data())); 
MPI_Datatype mpi_type{convertType(type)}; int type_size{}; - NET_CHECK(MPI_Type_size(mpi_type, &type_size)); + NET_CHECK(mpilib_ftable_.Type_size(mpi_type, &type_size)); if (dst == src) { fprintf(stderr, "IN_PLACE option not support for fcollect in the RO rocSHMEM conduit\n"); @@ -484,7 +484,7 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id, for (int i = 0; i < pe_size; ++i) { int target = (rank + i) % pe_size; int offset = rank * type_size * size; - NET_CHECK(MPI_Rput(reinterpret_cast(src), size, mpi_type, world_ranks[target], + NET_CHECK(mpilib_ftable_.Rput(reinterpret_cast(src), size, mpi_type, world_ranks[target], bp->heap_window_info[win_id]->get_offset(reinterpret_cast(dst) + offset), size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i])); @@ -492,7 +492,7 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id, outstanding[contextId]++; } - NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win())); quiet(contextId, status); } @@ -504,14 +504,14 @@ void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id, auto *bp{backend_proxy->get()}; MPI_Request request{}; - NET_CHECK(MPI_Rput( + NET_CHECK(mpilib_ftable_.Rput( src, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(dst), size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request)); // Since MPI makes puts as complete as soon as the local buffer is free, // we need a flush to satisfy quiet. Put it here as a hack for now even // though it should be in the progress loop. 
- NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win())); requests.push_back({request, {status, contextId, blocking}}); @@ -525,7 +525,7 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id, auto *bp{backend_proxy->get()}; MPI_Datatype mpi_type{convertType(type)}; - NET_CHECK(MPI_Fetch_and_op(reinterpret_cast(val), src, mpi_type, pe, + NET_CHECK(mpilib_ftable_.Fetch_and_op(reinterpret_cast(val), src, mpi_type, pe, bp->heap_window_info[win_id]->get_offset(dst), get_mpi_op(op), bp->heap_window_info[win_id]->get_win())); @@ -533,7 +533,7 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id, // Since MPI makes puts as complete as soon as the local buffer is free, // we need a flush to satisfy quiet. Put it here as a hack for now even // though it should be in the progress loop. - NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); queue->notify(status); @@ -547,7 +547,7 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe, auto *bp{backend_proxy->get()}; MPI_Datatype mpi_type{convertType(type)}; - NET_CHECK(MPI_Compare_and_swap((const void *)val, (const void *)cond, src, + NET_CHECK(mpilib_ftable_.Compare_and_swap((const void *)val, (const void *)cond, src, mpi_type, pe, bp->heap_window_info[win_id]->get_offset(dst), bp->heap_window_info[win_id]->get_win())); @@ -555,7 +555,7 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe, // Since MPI makes puts as complete as soon as the local buffer is free, // we need a flush to satisfy quiet. Put it here as a hack for now even // though it should be in the progress loop. 
- NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); + NET_CHECK(mpilib_ftable_.Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); queue->notify(status); @@ -569,7 +569,7 @@ void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id, auto *bp{backend_proxy->get()}; MPI_Request request{}; - NET_CHECK(MPI_Rget( + NET_CHECK(mpilib_ftable_.Rget( dst, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(src), size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request)); @@ -595,7 +595,7 @@ void MPITransport::progress() { // Slowing the progress engine down a bit avoid hammering the memory subsystem. // This leads to significant performance benefits usleep (progress_delay); - NET_CHECK(MPI_Iprobe(MPI_ANY_SOURCE, tag, ro_net_comm_world, &flag, &status)); + NET_CHECK(mpilib_ftable_.Iprobe(MPI_ANY_SOURCE, tag, ro_net_comm_world, &flag, &status)); } else { DPRINTF("Testing all outstanding requests (%zu)\n", requests.size()); @@ -605,7 +605,7 @@ void MPITransport::progress() { int outcount{}; auto uptr_req_arr {raw_requests()}; - NET_CHECK(MPI_Testsome(incount, uptr_req_arr.get(), &outcount, + NET_CHECK(mpilib_ftable_.Testsome(incount, uptr_req_arr.get(), &outcount, testsome_indices.data(), MPI_STATUSES_IGNORE)); auto *bp{backend_proxy->get()}; diff --git a/src/reverse_offload/queue_proxy.hpp b/src/reverse_offload/queue_proxy.hpp index d6a2b9be9e..b9289c3a53 100644 --- a/src/reverse_offload/queue_proxy.hpp +++ b/src/reverse_offload/queue_proxy.hpp @@ -25,8 +25,6 @@ #ifndef LIBRARY_SRC_REVERSE_OFFLOAD_QUEUE_PROXY_HPP_ #define LIBRARY_SRC_REVERSE_OFFLOAD_QUEUE_PROXY_HPP_ -#include - #include "atomic_return.hpp" #include "device_proxy.hpp" #include "hdp_policy.hpp" diff --git a/src/reverse_offload/ro_team_proxy.hpp b/src/reverse_offload/ro_team_proxy.hpp index 28e620f3e5..1d4716d755 100644 --- a/src/reverse_offload/ro_team_proxy.hpp +++ b/src/reverse_offload/ro_team_proxy.hpp @@ -25,11 +25,10 @@ #ifndef 
LIBRARY_SRC_REVERSE_OFFLOAD_RO_TEAM_PROXY_HPP_ #define LIBRARY_SRC_REVERSE_OFFLOAD_RO_TEAM_PROXY_HPP_ -#include - #include "device_proxy.hpp" #include "ro_net_team.hpp" #include "team_info_proxy.hpp" +#include "mpi_instance.hpp" namespace rocshmem { diff --git a/src/reverse_offload/transport.hpp b/src/reverse_offload/transport.hpp index 01b72b8bd9..2a2516cd8c 100644 --- a/src/reverse_offload/transport.hpp +++ b/src/reverse_offload/transport.hpp @@ -25,13 +25,12 @@ #ifndef LIBRARY_SRC_REVERSE_OFFLOAD_TRANSPORT_HPP_ #define LIBRARY_SRC_REVERSE_OFFLOAD_TRANSPORT_HPP_ -#include - #include #include "rocshmem/rocshmem.hpp" #include "backend_proxy.hpp" #include "ro_net_team.hpp" +#include "mpi_instance.hpp" namespace rocshmem { diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index eee3bd3c82..225ccbc582 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -76,7 +76,7 @@ MPIInstance *mpi_instance = nullptr; TcpBootstrap *bootstr = nullptr; rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; -/** + /** * Begin Host Code **/ @@ -92,12 +92,18 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; rocm_init(); + int ret; + ret = MPIInstance::mpilib_dl_init(); mpi_instance = new MPIInstance(comm); #if defined(USE_GDA) CHECK_HIP(hipHostMalloc(&backend, sizeof(GDABackend))); backend = new (backend) GDABackend(comm); #elif defined(USE_RO) + if (ret != ROCSHMEM_SUCCESS) { + printf("Could not initialize MPI library. RO conduit requires MPI library to be loaded at runtime. Aborting\n"); + abort(); + } CHECK_HIP(hipHostMalloc(&backend, sizeof(ROBackend))); backend = new (backend) ROBackend(comm); #elif defined(USE_IPC) @@ -113,7 +119,15 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; [[maybe_unused]] __host__ static void inline library_init_subcomm(TcpBootstrap *bootstrap, int nranks, int rank) { int initialized; int world_size = -1; - MPI_Initialized(&initialized); + + int ret; + ret = MPIInstance::mpilib_dl_init(); + if (ret != ROCSHMEM_SUCCESS) { + printf("Could not initialize MPI library. 
This initialization method of " + "rocSHMEM requires MPI library to be loaded at runtime. Aborting\n"); + abort(); + } + mpilib_ftable_.Initialized(&initialized); if (!initialized) { // This is an Open MPI specific solution to retrieve the number of @@ -131,7 +145,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; abort(); } } else { - MPI_Comm_size (MPI_COMM_WORLD, &world_size); + mpilib_ftable_.Comm_size (MPI_COMM_WORLD, &world_size); } if (world_size == nranks) { @@ -140,8 +154,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; MPI_Group world_group; int world_rank; - MPI_Comm_rank (MPI_COMM_WORLD, &world_rank); - MPI_Comm_group (MPI_COMM_WORLD, &world_group); + mpilib_ftable_.Comm_rank (MPI_COMM_WORLD, &world_rank); + mpilib_ftable_.Comm_group (MPI_COMM_WORLD, &world_group); int *inc_ranks = new int[nranks]; inc_ranks[rank] = world_rank; @@ -150,14 +164,14 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; MPI_Group sub_group; MPI_Comm sub_comm; - MPI_Group_incl (world_group, nranks, inc_ranks, &sub_group); - MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm); + mpilib_ftable_.Group_incl (world_group, nranks, inc_ranks, &sub_group); + mpilib_ftable_.Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm); library_init(sub_comm); - MPI_Group_free (&sub_group); - MPI_Group_free (&world_group); - MPI_Comm_free (&sub_comm); + mpilib_ftable_.Group_free (&sub_group); + mpilib_ftable_.Group_free (&world_group); + mpilib_ftable_.Comm_free (&sub_comm); delete[] inc_ranks; } } @@ -192,7 +206,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; [[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int flags, rocshmem_init_attr_t *attr) { - MPI_Comm comm = MPI_COMM_NULL; + MPI_Comm comm; if ((attr == nullptr) || ((flags != ROCSHMEM_INIT_WITH_UNIQUEID) && @@ -228,12 +242,12 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } [[maybe_unused]] __host__ int rocshmem_set_attr_uniqueid_args(int rank, int nranks, - rocshmem_uniqueid_t *uid, - rocshmem_init_attr_t *attr) { + 
rocshmem_uniqueid_t *uid, + rocshmem_init_attr_t *attr) { if (uid == nullptr || attr == nullptr) { fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", "Call 'rocshmem_get_uniqueid: invalid input argument'", - __FILE__, __LINE__); + __FILE__, __LINE__); return ROCSHMEM_ERROR; } @@ -252,7 +266,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; if (uid == nullptr) { fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", "Call 'rocshmem_get_uniqueid: invalid input argument'", - __FILE__, __LINE__); + __FILE__, __LINE__); return ROCSHMEM_ERROR; } @@ -262,17 +276,29 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; return ROCSHMEM_SUCCESS; } +#if defined(HAVE_EXTERNAL_MPI) [[maybe_unused]] __host__ void rocshmem_init(MPI_Comm comm) { library_init(comm); } +#endif +[[maybe_unused]] __host__ void rocshmem_init() { + MPIInstance::mpilib_dl_init(); + library_init(MPI_COMM_WORLD); +} + +#if defined(HAVE_EXTERNAL_MPI) [[maybe_unused]] __host__ int rocshmem_init_thread( [[maybe_unused]] int required, int *provided, MPI_Comm comm) { + if (comm == static_cast(0) || comm == MPI_COMM_NULL) { + comm = MPI_COMM_WORLD; + } library_init(comm); rocshmem_query_thread(provided); return ROCSHMEM_SUCCESS; } +#endif [[maybe_unused]] __host__ int rocshmem_my_pe() { if (backend != nullptr) { @@ -297,7 +323,6 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; void *ptr; backend->heap.malloc(&ptr, size); - rocshmem_barrier_all(); return ptr; @@ -455,7 +480,8 @@ __host__ int rocshmem_team_split_strided( TeamInfo(team_world, pe_start_in_world, stride_in_world, size); MPI_Comm team_comm{MPI_COMM_NULL}; - if (parent_team_obj->mpi_comm != MPI_COMM_NULL) { + if (parent_team_obj->mpi_comm != MPI_COMM_NULL && + parent_team_obj->mpi_comm != static_cast(0)) { /* Create a new MPI communicator for this team */ int color; if (my_pe_in_new_team < 0) { @@ -464,8 +490,8 @@ __host__ int rocshmem_team_split_strided( color = 1; } - MPI_Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm); 
- } + mpilib_ftable_.Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm); +} /** * Allocate new team for GPU-inittiated communication with backend-specific * objects @@ -484,8 +510,8 @@ __host__ int rocshmem_team_split_strided( backend->team_tracker.track(*new_team); } - if (team_comm != MPI_COMM_NULL) { - MPI_Comm_free (&team_comm); + if (team_comm != MPI_COMM_NULL && team_comm != static_cast(0)) { + mpilib_ftable_.Comm_free (&team_comm); } return 0; } diff --git a/src/stats.hpp b/src/stats.hpp index 8c4370d62f..86905596c7 100644 --- a/src/stats.hpp +++ b/src/stats.hpp @@ -25,7 +25,7 @@ #ifndef LIBRARY_SRC_STATS_HPP_ #define LIBRARY_SRC_STATS_HPP_ -#include +#include #include @@ -179,10 +179,10 @@ class HostStats { AtomicStatType stats[I] = {}; public: - __host__ uint64_t startTimer() const { return MPI_Wtime(); } + __host__ uint64_t startTimer() const { return wtime(); } __host__ void endTimer(uint64_t start, int index) { - incStat(index, MPI_Wtime() - start); + incStat(index, wtime() - start); } __host__ void incStat(int index, int value = 1) { stats[index] += value; } @@ -197,6 +197,16 @@ class HostStats { } __host__ StatType getStat(int index) const { return stats[index].load(); } + private: + double wtime(void) const { + double wt; + struct timespec tp; + + (void) clock_gettime(CLOCK_MONOTONIC, &tp); + wt = (static_cast(tp.tv_nsec))/1.0e+9; + wt += static_cast(tp.tv_sec); + return wt; + } }; // clang-format off diff --git a/src/team.cpp b/src/team.cpp index e26420e408..72dcd7bcb8 100644 --- a/src/team.cpp +++ b/src/team.cpp @@ -29,6 +29,7 @@ #include "rocshmem/rocshmem.hpp" #include "backend_bc.hpp" #include "util.hpp" +#include "mpi_instance.hpp" namespace rocshmem { @@ -84,7 +85,7 @@ __host__ Team::Team(Backend* handle, TeamInfo* team_info_wrt_parent, num_pes(_num_pes), my_pe(_my_pe) { if (_mpi_comm != MPI_COMM_NULL) { - MPI_Comm_dup (_mpi_comm, &mpi_comm); + mpilib_ftable_.Comm_dup (_mpi_comm, &mpi_comm); } } @@ -117,7 +118,7 @@ __host__ 
__device__ int Team::get_pe_in_my_team(int pe_in_world) { __host__ Team::~Team() { if (mpi_comm != MPI_COMM_NULL) - MPI_Comm_free (&mpi_comm); + mpilib_ftable_.Comm_free (&mpi_comm); } } // namespace rocshmem diff --git a/src/team.hpp b/src/team.hpp index 343424da39..10c3e4c951 100644 --- a/src/team.hpp +++ b/src/team.hpp @@ -25,9 +25,8 @@ #ifndef LIBRARY_SRC_TEAM_HPP_ #define LIBRARY_SRC_TEAM_HPP_ -#include - #include "rocshmem/rocshmem.hpp" +#include "rocshmem/rocshmem_mpi.hpp" #include "backend_type.hpp" namespace rocshmem { diff --git a/src/util.hpp b/src/util.hpp index 24ba037e70..d06e82c857 100644 --- a/src/util.hpp +++ b/src/util.hpp @@ -115,6 +115,35 @@ namespace rocshmem { } while (0); #endif +/* Helper Macros for handling dynamic libraries */ +#define PPCAT_NX(prefix, func_name) prefix##func_name +#define PPCAT(prefix, func_name) PPCAT_NX(prefix, func_name) + +#define STRINGIFY_NX(name) #name +#define STRINGIFY(name) STRINGIFY_NX(name) + +#define DLSYM_HELPER(func_struct, prefix, handle, func_name) \ +do { \ + *(void **) (&func_struct.func_name) = dlsym(handle, STRINGIFY(PPCAT(prefix, func_name))); \ + if (!func_struct.func_name) { \ + DPRINTF("Failed to find function %s \n", STRINGIFY(PPCAT(prefix, func_name))); \ + dlclose(handle); \ + handle = nullptr; \ + return ROCSHMEM_ERROR; \ + } \ +} while (0) + +#define DLSYM_VAR_HELPER(func_struct, handle, var_name) \ +do { \ + *(void **) (&func_struct.var_name) = dlsym(handle, STRINGIFY(var_name)); \ + if (!func_struct.var_name) { \ + DPRINTF("Failed to find function %s \n", STRINGIFY(var_name)); \ + dlclose(handle); \ + handle = nullptr; \ + return ROCSHMEM_ERROR; \ + } \ +} while (0) + extern const int gpu_clock_freq_mhz; /* Device-side internal functions */ diff --git a/tests/functional_tests/CMakeLists.txt b/tests/functional_tests/CMakeLists.txt index 7352299187..f5069dd70e 100644 --- a/tests/functional_tests/CMakeLists.txt +++ b/tests/functional_tests/CMakeLists.txt @@ -72,8 +72,9 @@ target_sources( # 
ROCSHMEM ############################################################################### if (BUILD_TESTS_ONLY) - find_package(MPI REQUIRED) - find_package(hip REQUIRED PATHS /opt/rocm) + #TODO check that build_test_only still works with external-mpi + #find_package(MPI REQUIRED) + #find_package(hip REQUIRED PATHS /opt/rocm) find_package(rocshmem REQUIRED PATHS /opt/rocm) target_include_directories( diff --git a/tests/functional_tests/tester.cpp b/tests/functional_tests/tester.cpp index 52c9c60bd0..193ad31be4 100644 --- a/tests/functional_tests/tester.cpp +++ b/tests/functional_tests/tester.cpp @@ -25,7 +25,6 @@ #include "tester.hpp" #include -#include #include #include diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index a85c57c0df..e3ca9b41a3 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -79,9 +79,9 @@ endif() # ROCSHMEM DEPENDENCY ############################################################################### find_package(hip REQUIRED PATHS /opt/rocm) +find_package(MPI REQUIRED) if (BUILD_TESTS_ONLY) - find_package(MPI REQUIRED) find_package(rocshmem REQUIRED PATHS /opt/rocm) target_include_directories( @@ -95,6 +95,7 @@ endif() target_link_libraries( ${PROJECT_NAME} PRIVATE + MPI::MPI_CXX roc::rocshmem ) diff --git a/tests/unit_tests/ipc_impl_simple_coarse_gtest.cpp b/tests/unit_tests/ipc_impl_simple_coarse_gtest.cpp index 636a4f9f9c..3f008685fc 100644 --- a/tests/unit_tests/ipc_impl_simple_coarse_gtest.cpp +++ b/tests/unit_tests/ipc_impl_simple_coarse_gtest.cpp @@ -32,13 +32,13 @@ TEST_P(DegenerateSimpleCoarse, ptr_check) { } TEST_P(DegenerateSimpleCoarse, MPI_num_pes) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); } TEST_P(DegenerateSimpleCoarse, IPC_bases) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); ASSERT_NE(ipc_impl_.ipc_bases, nullptr); - for(int i{0}; i < mpi_.num_pes(); i++) { + for(int i{0}; i < mpi_->num_pes(); i++) { 
ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr); } } diff --git a/tests/unit_tests/ipc_impl_simple_coarse_gtest.hpp b/tests/unit_tests/ipc_impl_simple_coarse_gtest.hpp index d78063234f..ab1e24527d 100644 --- a/tests/unit_tests/ipc_impl_simple_coarse_gtest.hpp +++ b/tests/unit_tests/ipc_impl_simple_coarse_gtest.hpp @@ -73,7 +73,10 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParammy_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD); assert(ipc_impl_dptr_ == nullptr); hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl)); CHECK_HIP(hipMemcpy(ipc_impl_dptr_, &ipc_impl_, @@ -85,6 +88,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_src = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice)); CHECK_HIP(hipStreamSynchronize(nullptr)); } @@ -140,14 +144,14 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParammy_pe() == 0) || + (is_read_test && mpi_->my_pe() == 1); } void execute(TestType test, FN_T fn, const dim3 grid, const dim3 block) { - if (mpi_.my_pe()) { - mpi_.barrier(); - mpi_.barrier(); + if (mpi_->my_pe()) { + mpi_->barrier(); + mpi_->barrier(); return; } int *src{nullptr}; @@ -160,9 +164,9 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam(ipc_impl_.ipc_bases[0]); } size_t bytes = golden_.size() * sizeof(int); - mpi_.barrier(); + mpi_->barrier(); launch(fn, grid, block, src, dest, bytes); - mpi_.barrier(); + mpi_->barrier(); } void validate_dest_buffer(TestType test) { @@ -170,7 +174,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_dest = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); for (int i = 0; i < static_cast(golden_.size()); i++) { ASSERT_EQ(golden_[i], dev_dest[i]); } @@ -184,7 +188,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam golden_; HEAP_T heap_mem_ {}; - MPI_T mpi_ {heap_mem_.get_ptr(), 
heap_mem_.get_size()}; + MPI_T *mpi_{nullptr}; IpcImpl ipc_impl_ {}; IpcImpl *ipc_impl_dptr_ {nullptr}; diff --git a/tests/unit_tests/ipc_impl_simple_fine_gtest.cpp b/tests/unit_tests/ipc_impl_simple_fine_gtest.cpp index ae985aa15d..6587b79ef7 100644 --- a/tests/unit_tests/ipc_impl_simple_fine_gtest.cpp +++ b/tests/unit_tests/ipc_impl_simple_fine_gtest.cpp @@ -31,13 +31,13 @@ TEST_P(DegenerateSimpleFine, ptr_check) { } TEST_P(DegenerateSimpleFine, MPI_num_pes) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); } TEST_P(DegenerateSimpleFine, IPC_bases) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); ASSERT_NE(ipc_impl_.ipc_bases, nullptr); - for(int i{0}; i < mpi_.num_pes(); i++) { + for(int i{0}; i < mpi_->num_pes(); i++) { ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr); } } diff --git a/tests/unit_tests/ipc_impl_simple_fine_gtest.hpp b/tests/unit_tests/ipc_impl_simple_fine_gtest.hpp index bb7b4fd31b..a68de223a5 100644 --- a/tests/unit_tests/ipc_impl_simple_fine_gtest.hpp +++ b/tests/unit_tests/ipc_impl_simple_fine_gtest.hpp @@ -140,7 +140,10 @@ class IPCImplSimpleFine : public ::testing::TestWithParammy_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD); assert(ipc_impl_dptr_ == nullptr); hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl)); @@ -163,6 +166,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParammy_pe() == 0) { int *dest = reinterpret_cast(ipc_impl_.ipc_bases[1]); *(dest + SIGNAL_OFFSET) = 0; } @@ -225,27 +229,27 @@ class IPCImplSimpleFine : public ::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_src = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice)); } bool pe_initializes_src_buffer(TestType test) { bool is_write_test = test; bool is_read_test = !test; - return (is_write_test && mpi_.my_pe() == 0) || - (is_read_test && mpi_.my_pe() == 1); + return (is_write_test && mpi_->my_pe() == 0) || + 
(is_read_test && mpi_->my_pe() == 1); } void execute(TestType test, FN_T1 fn, const dim3 grid, const dim3 block) { size_t bytes = golden_.size() * sizeof(int); - if (mpi_.my_pe()) { - mpi_.barrier(); + if (mpi_->my_pe()) { + mpi_->barrier(); if (test == WRITE) { int *dest = reinterpret_cast(ipc_impl_.ipc_bases[1]); FN_T2 val_fn = kernel_put_with_signal_simple_validator; launch(val_fn, grid, block, dest, bytes); } - mpi_.barrier(); + mpi_->barrier(); return; } int *src{nullptr}; @@ -257,9 +261,9 @@ class IPCImplSimpleFine : public ::testing::TestWithParam(ipc_impl_.ipc_bases[1]); dest = reinterpret_cast(ipc_impl_.ipc_bases[0]); } - mpi_.barrier(); + mpi_->barrier(); launch(fn, grid, block, src, dest, bytes, test); - mpi_.barrier(); + mpi_->barrier(); } void check_device_validation_errors(TestType test) { @@ -274,7 +278,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_dest = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); for (int i = 0; i < static_cast(golden_.size()); i++) { ASSERT_EQ(golden_[i], dev_dest[i]); } @@ -291,7 +295,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam golden_; diff --git a/tests/unit_tests/ipc_impl_tiled_fine_gtest.cpp b/tests/unit_tests/ipc_impl_tiled_fine_gtest.cpp index 7522d46afe..78ac95214f 100644 --- a/tests/unit_tests/ipc_impl_tiled_fine_gtest.cpp +++ b/tests/unit_tests/ipc_impl_tiled_fine_gtest.cpp @@ -33,13 +33,13 @@ TEST_F(DegenerateTiledFine, ptr_check) { } TEST_F(DegenerateTiledFine, MPI_num_pes) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); } TEST_F(DegenerateTiledFine, IPC_bases) { - ASSERT_EQ(mpi_.num_pes(), 2); + ASSERT_EQ(mpi_->num_pes(), 2); ASSERT_NE(ipc_impl_.ipc_bases, nullptr); - for(int i{0}; i < mpi_.num_pes(); i++) { + for(int i{0}; i < mpi_->num_pes(); i++) { ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr); } } diff --git a/tests/unit_tests/ipc_impl_tiled_fine_gtest.hpp b/tests/unit_tests/ipc_impl_tiled_fine_gtest.hpp 
index f65e5fbde2..d99d3866f8 100644 --- a/tests/unit_tests/ipc_impl_tiled_fine_gtest.hpp +++ b/tests/unit_tests/ipc_impl_tiled_fine_gtest.hpp @@ -158,7 +158,10 @@ class IPCImplTiledFine : public ::testing::TestWithParammy_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD); assert(ipc_impl_dptr_ == nullptr); hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl)); @@ -181,6 +184,7 @@ class IPCImplTiledFine : public ::testing::TestWithParammy_pe() == 0) { int *dest = reinterpret_cast(ipc_impl_.ipc_bases[1]); *(dest + SIGNAL_OFFSET) = signal_value; } @@ -243,28 +247,28 @@ class IPCImplTiledFine : public ::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_src = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice)); } bool pe_initializes_src_buffer(TestType test) { bool is_write_test = test; bool is_read_test = !test; - return (is_write_test && mpi_.my_pe() == 0) || - (is_read_test && mpi_.my_pe() == 1); + return (is_write_test && mpi_->my_pe() == 0) || + (is_read_test && mpi_->my_pe() == 1); } void execute(TestType test, FN_T1 fn, const dim3 grid, const dim3 block) { size_t bytes = golden_.size() * sizeof(int); - if (mpi_.my_pe()) { - mpi_.barrier(); + if (mpi_->my_pe()) { + mpi_->barrier(); if (test == WRITE) { int *dest = reinterpret_cast(ipc_impl_.ipc_bases[1]); FN_T2 val_fn = kernel_put_with_signal_tiled_validator; launch(val_fn, grid, block, dest, bytes); ASSERT_EQ(*(dest + SIGNAL_OFFSET), 0); } - mpi_.barrier(); + mpi_->barrier(); return; } int *src{nullptr}; @@ -276,9 +280,9 @@ class IPCImplTiledFine : public ::testing::TestWithParam(ipc_impl_.ipc_bases[1]); dest = reinterpret_cast(ipc_impl_.ipc_bases[0]); } - mpi_.barrier(); + mpi_->barrier(); launch(fn, grid, block, src, dest, bytes, test); - mpi_.barrier(); + mpi_->barrier(); } void check_device_validation_errors(TestType test) { @@ -293,7 +297,7 @@ class IPCImplTiledFine : public 
::testing::TestWithParam(ipc_impl_.ipc_bases[mpi_.my_pe()]); + auto dev_dest = reinterpret_cast(ipc_impl_.ipc_bases[mpi_->my_pe()]); for (int i = 0; i < static_cast(golden_.size()); i++) { ASSERT_EQ(golden_[i], dev_dest[i]); } @@ -310,7 +314,7 @@ class IPCImplTiledFine : public ::testing::TestWithParam golden_; diff --git a/tests/unit_tests/remote_heap_info_gtest.hpp b/tests/unit_tests/remote_heap_info_gtest.hpp index 1e3125e2bb..ab5a0efe14 100644 --- a/tests/unit_tests/remote_heap_info_gtest.hpp +++ b/tests/unit_tests/remote_heap_info_gtest.hpp @@ -27,6 +27,8 @@ #include "gtest/gtest.h" +#include + #include "../src/memory/heap_memory.hpp" #include "../src/memory/hip_allocator.hpp" #include "../src/memory/remote_heap_info.hpp" @@ -55,7 +57,8 @@ class RemoteHeapInfoTestFixture : public ::testing::Test * @brief Remote heap info with MPI Communicator */ MPI_T mpi_ {heap_mem_.get_ptr(), - heap_mem_.get_size()}; + heap_mem_.get_size(), + MPI_COMM_WORLD}; }; } // namespace rocshmem diff --git a/tests/unit_tests/symmetric_heap_gtest.cpp b/tests/unit_tests/symmetric_heap_gtest.cpp index b99d123a2d..ab5281e26c 100644 --- a/tests/unit_tests/symmetric_heap_gtest.cpp +++ b/tests/unit_tests/symmetric_heap_gtest.cpp @@ -30,27 +30,27 @@ TEST_F(SymmetricHeapTestFixture, malloc_free) { void *ptr{nullptr}; size_t request_bytes{48}; - symmetric_heap_.malloc(&ptr, request_bytes); + symmetric_heap_->malloc(&ptr, request_bytes); ASSERT_NE(ptr, nullptr); - ASSERT_NO_FATAL_FAILURE(symmetric_heap_.free(ptr)); + ASSERT_NO_FATAL_FAILURE(symmetric_heap_->free(ptr)); } TEST_F(SymmetricHeapTestFixture, window_info) { - auto win_info_ptr{symmetric_heap_.get_window_info()}; + auto win_info_ptr{symmetric_heap_->get_window_info()}; WindowInfoMPI* window_info_mpi = dynamic_cast(win_info_ptr); if (window_info_mpi) { void *window_base_addr{nullptr}; int flag{0}; MPI_Win_get_attr(window_info_mpi->get_win(), MPI_WIN_BASE, &window_base_addr, - &flag); + &flag); ASSERT_NE(0, flag); ASSERT_NE(nullptr, 
window_base_addr); } } TEST_F(SymmetricHeapTestFixture, heap_bases) { - auto heap_bases{symmetric_heap_.get_heap_bases()}; + auto heap_bases{symmetric_heap_->get_heap_bases()}; for (const auto &base : heap_bases) { ASSERT_NE(nullptr, base); } diff --git a/tests/unit_tests/symmetric_heap_gtest.hpp b/tests/unit_tests/symmetric_heap_gtest.hpp index 35b4ab6ca2..3bc8ea683a 100644 --- a/tests/unit_tests/symmetric_heap_gtest.hpp +++ b/tests/unit_tests/symmetric_heap_gtest.hpp @@ -25,6 +25,8 @@ #ifndef ROCSHMEM_SYMMETRIC_HEAP_GTEST_HPP #define ROCSHMEM_SYMMETRIC_HEAP_GTEST_HPP +#include + #include "gtest/gtest.h" #include "../src/memory/symmetric_heap.hpp" @@ -37,7 +39,16 @@ class SymmetricHeapTestFixture : public ::testing::Test /** * @brief Symmetric heap object */ - SymmetricHeap symmetric_heap_ {MPI_COMM_WORLD}; + SymmetricHeap *symmetric_heap_; + + void SetUp() override { + MPIInstance::mpilib_dl_init(); + symmetric_heap_ = new SymmetricHeap(MPI_COMM_WORLD); + } + + void TearDown() override { + MPIInstance::mpilib_dl_close(); + } }; } // namespace rocshmem