diff --git a/projects/rocshmem/CMakeLists.txt b/projects/rocshmem/CMakeLists.txt index 791fac7944..5614bb68ab 100644 --- a/projects/rocshmem/CMakeLists.txt +++ b/projects/rocshmem/CMakeLists.txt @@ -68,7 +68,6 @@ option(BUILD_TOOLS "Build binary tools (e.g., rocshmem_info)" ON) option(BUILD_LOCAL_GPU_TARGET_ONLY "Build only for GPUs detected on this machine" OFF) option(BUILD_CODE_COVERAGE "Build with code coverage flags (gcc only)" OFF) -configure_file(cmake/rocshmem_config.h.in rocshmem_config.h) ############################################################################### # PROJECT @@ -150,6 +149,8 @@ if (NOT BUILD_TESTS_ONLY) set(THREADS_PREFER_PTHREAD_FLAG TRUE) find_package(Threads REQUIRED) + configure_file(cmake/rocshmem_config.h.in rocshmem_config.h) + ############################################################################# # LINKING AND INCLUDE DIRECTORIES ############################################################################# diff --git a/projects/rocshmem/cmake/find_pmix.cmake b/projects/rocshmem/cmake/find_pmix.cmake new file mode 100644 index 0000000000..24011c41bc --- /dev/null +++ b/projects/rocshmem/cmake/find_pmix.cmake @@ -0,0 +1,62 @@ +############################################################################### +# Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +############################################################################### + +# Find pmix installation. +# Two different scenarios need to be covered: +# - pmix installed as part of Open MPI, i.e. it will be in the MPI installation directories +# - pmix deployed with linux distros +# - later: handle pmix deployed with slurm. + +macro(check_pmix) + set(prev_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + find_package(PMIx QUIET) + if (PMIx_FOUND) + message("-- Found pmix at ${PMIx_CONFIG} ${PMIx_INCLUDE_DIRS}") + else() + if(NOT PMIX_HEADER OR NOT PMIX_LIBRARY) + message("-- Cound not find pmix using find_package") + + list(APPEND CMAKE_REQUIRED_INCLUDES "${MPI_CXX_HEADER_DIR}") #prefer Open MPI internal PMIx if any + find_path(PMIX_HEADER pmix.h PATHS ${CMAKE_REQUIRED_INCLUDES}) + if (PMIX_HEADER) + message("-- Found pmix.h at ${PMIX_HEADER}") + get_filename_component(pmix_lib_dir ${PMIX_HEADER} DIRECTORY) + find_library(PMIX_LIBRARY pmix PATHS ${pmix_lib_dir} PATH_SUFFIXES lib lib64 NO_DEFAULT_PATH) + endif() + if(PMIX_HEADER AND PMIX_LIBRARY) + message("-- Found libpmix at ${PMIX_LIBRARY}") + elseif(NOT PMIX_HEADER) + message("-- Cound not find pmix.h") + elseif(NOT PMIX_LIBRARY) + message("-- Could not find libpmix.so") + endif() + endif() + endif() + set(CMAKE_REQUIRED_INCLUDES ${prev_CMAKE_REQUIRED_INCLUDES}) + if (PMIX_HEADER AND PMIX_LIBRARY) + set(PMIX_INCLUDE_DIRECTORIES ${PMIX_HEADER}) + set(PMIX_LIBRARIES ${PMIX_LIBRARY}) + set(HAVE_PMIX TRUE) + endif() +endmacro() diff --git a/projects/rocshmem/docs/compile_and_run.rst b/projects/rocshmem/docs/compile_and_run.rst index cefb88d7c0..9d80dd4981 100644 --- a/projects/rocshmem/docs/compile_and_run.rst +++ b/projects/rocshmem/docs/compile_and_run.rst @@ -81,3 +81,6 @@ You can control the behavior of rocSHMEM by using the following environment vari * - ROCSHMEM_MAX_NUM_TEAMS - 40 - Defines the number of teams an application can use. + * - ROCSHMEM_UNIQUEID_WITH_MPI + - 0 + - Defines whether rocSHMEM is expected to use MPI when using the uniqueId based initialization. diff --git a/projects/rocshmem/examples/rocshmem_init_attr_test.cc b/projects/rocshmem/examples/rocshmem_init_attr_test.cc index d4c061dff5..b7d6d81307 100644 --- a/projects/rocshmem/examples/rocshmem_init_attr_test.cc +++ b/projects/rocshmem/examples/rocshmem_init_attr_test.cc @@ -52,6 +52,10 @@ * To run: mpirun -np 8 -x ROCSHMEM_MAX_NUM_CONTEXTS=2 ./rocshmem_init_attr_test + * Note: + running this test with the Reverse Offload (RO) conduit requires setting + ROCSHMEM_UNIQUEID_WITH_MPI=1 + */ #include diff --git a/projects/rocshmem/src/backend_bc.cpp b/projects/rocshmem/src/backend_bc.cpp index d2a451dfe6..1469d9a774 100644 --- a/projects/rocshmem/src/backend_bc.cpp +++ b/projects/rocshmem/src/backend_bc.cpp @@ -45,7 +45,28 @@ namespace rocshmem { } \ } -Backend::Backend(MPI_Comm comm) : heap{comm} { +Backend::Backend(MPI_Comm comm) : heap(comm, nullptr) { + init(); + init_mpi_once(comm); + /* + * Notify other threads that Backend has been initialized. + */ + *done_init = 0; +} + +Backend::Backend(TcpBootstrap* bootstrap) : heap(MPI_COMM_NULL, bootstrap) { + init(); + backend_bootstr = bootstrap; + + my_pe = bootstrap->getRank(); + num_pes = bootstrap->getNranks(); + /* + * Notify other threads that Backend has been initialized. + */ + *done_init = 0; +} + +void Backend::init(void) { CHECK_HIP(hipGetDevice(&hip_dev_id)); int num_cus{}; @@ -79,12 +100,6 @@ Backend::Backend(MPI_Comm comm) : heap{comm} { CHECK_HIP( hipHostMalloc(reinterpret_cast(&done_init), sizeof(uint8_t))); - - init_mpi_once(comm); - /* - * Notify other threads that Backend has been initialized. - */ - *done_init = 0; } void Backend::init_mpi_once(MPI_Comm comm) { diff --git a/projects/rocshmem/src/backend_bc.hpp b/projects/rocshmem/src/backend_bc.hpp index 9ddf1175eb..149704d5f3 100644 --- a/projects/rocshmem/src/backend_bc.hpp +++ b/projects/rocshmem/src/backend_bc.hpp @@ -44,6 +44,7 @@ #include "memory/symmetric_heap.hpp" #include "stats.hpp" #include "team_tracker.hpp" +#include "bootstrap/bootstrap.hpp" namespace rocshmem { @@ -71,6 +72,7 @@ class Backend { */ explicit Backend(MPI_Comm comm); + explicit Backend(TcpBootstrap* bootstrap); /** * @brief Destructor. */ @@ -225,11 +227,16 @@ class Backend { */ MPI_Comm backend_comm{MPI_COMM_NULL}; + /** + * @todo document where this is used + */ + TcpBootstrap *backend_bootstr{nullptr}; + /** * @brief Object contains the interface and internal data structures * needed to allocate/free memory on the symmetric heap. */ - SymmetricHeap heap{}; + SymmetricHeap heap; /** * @brief Determines which device to launch device kernels onto. @@ -293,6 +300,12 @@ class Backend { virtual void reset_backend_stats() = 0; private: + /** + * @brief initialization code used by all constructors + */ + void init (void); + + /** * @brief List of ctxs created by the user. */ diff --git a/projects/rocshmem/src/host/host.cpp b/projects/rocshmem/src/host/host.cpp index 41ddb957e0..862eaadb69 100644 --- a/projects/rocshmem/src/host/host.cpp +++ b/projects/rocshmem/src/host/host.cpp @@ -38,7 +38,12 @@ namespace rocshmem { __host__ HostContextWindowInfo::HostContextWindowInfo(MPI_Comm comm_world, SymmetricHeap* heap) { window_info_ = - new WindowInfo(comm_world, heap->get_local_heap_base(), heap->get_size()); + new WindowInfoMPI(comm_world, heap->get_local_heap_base(), heap->get_size()); +} + +__host__ HostContextWindowInfo::HostContextWindowInfo(SymmetricHeap* heap) { + window_info_ = + new WindowInfo(heap->get_local_heap_base(), heap->get_size()); } __host__ HostContextWindowInfo::~HostContextWindowInfo() { @@ -146,6 +151,41 @@ __host__ void HostInterface::create_hdp_window() { } #endif // USE_HDP_FLUSH +__host__ HostInterface::HostInterface(HdpPolicy* hdp_policy, + TcpBootstrap *bootstr, + SymmetricHeap* heap) { + host_bootstrap_ = bootstr; + my_pe_ = bootstr->getRank(); + num_pes_ = bootstr->getNranks(); + + /* + * Not sure we need this. + */ + hdp_policy_ = hdp_policy; + + /* + * Allocate and initialize pool of windows for contexts + */ + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_HOST_CONTEXTS"))) { + max_num_ctxs_ = atoi(value); + } + + size_t pool_size = max_num_ctxs_ * sizeof(HostContextWindowInfo*); + host_window_context_pool_ = + reinterpret_cast(malloc(pool_size)); + + for (int ctx_i = 0; ctx_i < max_num_ctxs_; ctx_i++) { + host_window_context_pool_[ctx_i] = + new HostContextWindowInfo(heap); + } + +#if defined USE_HDP_FLUSH && not defined USE_SINGLE_NODE + printf("Non-mpi use-cases only supported with coherent heap at the moment. Aborting.\n"); + abort(); +#endif +} + __host__ HostInterface::~HostInterface() { #if defined USE_HDP_FLUSH MPI_Win_unlock_all(hdp_win); @@ -154,41 +194,61 @@ __host__ HostInterface::~HostInterface() { #endif // USE_HDP_FLUSH /* Detroy the pool of contexts */ - for (int ctx_i = 0; ctx_i < max_num_ctxs_; ctx_i++) { - delete host_window_context_pool_[ctx_i]; + + if (host_window_context_pool_ != nullptr) { + for (int ctx_i = 0; ctx_i < max_num_ctxs_; ctx_i++) { + delete host_window_context_pool_[ctx_i]; + } + free(host_window_context_pool_); } - free(host_window_context_pool_); - - MPI_Comm_free(&host_comm_world_); + if (host_comm_world_ != MPI_COMM_NULL) { + MPI_Comm_free(&host_comm_world_); + } } __host__ void HostInterface::putmem_nbi(void* dest, const void* source, size_t nelems, int pe, WindowInfo* window_info) { - initiate_put(dest, source, nelems, pe, window_info); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + initiate_put(dest, source, nelems, pe, window_info_mpi); } __host__ void HostInterface::getmem_nbi(void* dest, const void* source, size_t nelems, int pe, WindowInfo* window_info) { - initiate_get(dest, source, nelems, pe, window_info); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + initiate_get(dest, source, nelems, pe, window_info_mpi); } __host__ void HostInterface::putmem(void* dest, const void* source, size_t nelems, int pe, WindowInfo* window_info) { - initiate_put(dest, source, nelems, pe, window_info); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + initiate_put(dest, source, nelems, pe, window_info_mpi); - MPI_Win_flush_local(pe, window_info->get_win()); + MPI_Win_flush_local(pe, window_info_mpi->get_win()); } __host__ void HostInterface::getmem(void* dest, const void* source, size_t nelems, int pe, WindowInfo* window_info) { - initiate_get(dest, source, nelems, pe, window_info); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + initiate_get(dest, source, nelems, pe, window_info_mpi); - MPI_Win_flush_local(pe, window_info->get_win()); + MPI_Win_flush_local(pe, window_info_mpi->get_win()); /* * Flush local HDP to ensure that the NIC's write @@ -198,7 +258,11 @@ __host__ void HostInterface::getmem(void* dest, const void* source, } __host__ void HostInterface::fence(WindowInfo* window_info) { - complete_all(window_info->get_win()); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + complete_all(window_info_mpi->get_win()); /* * Flush my HDP and the HDPs of remote GPUs. @@ -216,7 +280,11 @@ __host__ void HostInterface::fence(WindowInfo* window_info) { } __host__ void HostInterface::quiet(WindowInfo* window_info) { - complete_all(window_info->get_win()); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + complete_all(window_info_mpi->get_win()); /* Same explanation as in fence */ hdp_policy_->hdp_flush(); @@ -226,34 +294,53 @@ __host__ void HostInterface::quiet(WindowInfo* window_info) { } __host__ void HostInterface::sync_all(WindowInfo* window_info) { - MPI_Win_sync(window_info->get_win()); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + MPI_Win_sync(window_info_mpi->get_win()); - hdp_policy_->hdp_flush(); - /* - * No need to flush remote - * HDPs here since all PEs are - * participating. - */ + hdp_policy_->hdp_flush(); + /* + * No need to flush remote + * HDPs here since all PEs are + * participating. + */ - MPI_Barrier(host_comm_world_); + MPI_Barrier(host_comm_world_); + } else { + hdp_policy_->hdp_flush(); + host_bootstrap_->barrier(); + } return; } __host__ void HostInterface::barrier_all(WindowInfo* window_info) { - complete_all(window_info->get_win()); + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (window_info_mpi) { + complete_all(window_info_mpi->get_win()); - /* - * Flush my HDP cache so remote NICs will - * see the latest values in device memory - */ - hdp_policy_->hdp_flush(); + /* + * Flush my HDP cache so remote NICs will + * see the latest values in device memory + */ + hdp_policy_->hdp_flush(); - MPI_Barrier(host_comm_world_); + MPI_Barrier(host_comm_world_); + } else { + // Probably not required + hdp_policy_->hdp_flush(); + host_bootstrap_->barrier(); + } + + return; } __host__ void HostInterface::barrier_for_sync() { - MPI_Barrier(host_comm_world_); + if (host_comm_world_ != MPI_COMM_NULL) { + MPI_Barrier(host_comm_world_); + } else { + host_bootstrap_->barrier(); + } } } // namespace rocshmem diff --git a/projects/rocshmem/src/host/host.hpp b/projects/rocshmem/src/host/host.hpp index cbd4ebb1ab..ef1ef32563 100644 --- a/projects/rocshmem/src/host/host.hpp +++ b/projects/rocshmem/src/host/host.hpp @@ -42,6 +42,7 @@ #include "../hdp_policy.hpp" #include "../memory/symmetric_heap.hpp" #include "../memory/window_info.hpp" +#include "../bootstrap/bootstrap.hpp" namespace rocshmem { @@ -59,6 +60,7 @@ class HostContextWindowInfo { * @param[in] team_info information about participating PEs */ HostContextWindowInfo(MPI_Comm comm_world, SymmetricHeap* heap); + HostContextWindowInfo(SymmetricHeap* heap); /** * @brief Destructor @@ -104,11 +106,14 @@ class HostContextWindowInfo { class HostInterface { public: /** - * @brief Primary constructor + * @brief Primary constructors */ __host__ HostInterface(HdpPolicy* hdp_policy, MPI_Comm rocshmem_comm, SymmetricHeap* heap); + __host__ HostInterface(HdpPolicy* hdp_policy, TcpBootstrap *bootstrap, + SymmetricHeap* heap); + /** * @brief Destructor */ @@ -278,10 +283,10 @@ class HostInterface { } __host__ void initiate_put(void* dest, const void* source, size_t nelems, - int pe, WindowInfo* window_info); + int pe, WindowInfoMPI* window_info); __host__ void initiate_get(void* dest, const void* source, size_t nelems, - int pe, WindowInfo* window_info); + int pe, WindowInfoMPI* window_info); __host__ void complete_all(MPI_Win win); @@ -321,7 +326,12 @@ class HostInterface { /** * @brief Global MPI communicator for those host API */ - MPI_Comm host_comm_world_{}; + MPI_Comm host_comm_world_{MPI_COMM_NULL}; + + /** + * @brief Bootstrap object used in the non-mpi workloads + */ + TcpBootstrap *host_bootstrap_{nullptr}; /** * @brief Duplicate of this processing element's id within global rank diff --git a/projects/rocshmem/src/host/host_helpers.hpp b/projects/rocshmem/src/host/host_helpers.hpp index 4da5e7e43c..d6d450a38c 100644 --- a/projects/rocshmem/src/host/host_helpers.hpp +++ b/projects/rocshmem/src/host/host_helpers.hpp @@ -55,7 +55,7 @@ __host__ inline void HostInterface::complete_all(MPI_Win win) { __host__ inline void HostInterface::initiate_put(void* dest, const void* source, size_t nelems, int pe, - WindowInfo* window_info) { + WindowInfoMPI* window_info) { MPI_Win win{window_info->get_win()}; void* win_start{window_info->get_start()}; void* win_end{window_info->get_end()}; @@ -79,7 +79,7 @@ __host__ inline void HostInterface::initiate_put(void* dest, const void* source, __host__ inline void HostInterface::initiate_get(void* dest, const void* source, size_t nelems, int pe, - WindowInfo* window_info) { + WindowInfoMPI* window_info) { MPI_Win win{window_info->get_win()}; void* win_start{window_info->get_start()}; void* win_end{window_info->get_end()}; diff --git a/projects/rocshmem/src/host/host_templates.hpp b/projects/rocshmem/src/host/host_templates.hpp index 56b38dbe71..f384686406 100644 --- a/projects/rocshmem/src/host/host_templates.hpp +++ b/projects/rocshmem/src/host/host_templates.hpp @@ -58,6 +58,10 @@ __host__ void HostInterface::put_nbi(T* dest, const T* source, size_t nelems, template __host__ T HostInterface::g(const T* source, int pe, WindowInfo* window_info) { + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } DPRINTF("Function: host_g\n"); T ret{}; @@ -70,7 +74,7 @@ __host__ T HostInterface::g(const T* source, int pe, WindowInfo* window_info) { */ getmem_nbi(&ret, source, sizeof(T), pe, window_info); - MPI_Win_flush_local(pe, window_info->get_win()); + MPI_Win_flush_local(pe, window_info_mpi->get_win()); return ret; } @@ -289,6 +293,11 @@ __host__ void HostInterface::amo_cas(void* dst, T value, T cond, int pe, template __host__ T HostInterface::amo_fetch_add(void* dst, T value, int pe, WindowInfo* window_info) { + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + /* Calculate offset of remote dest from base address of window */ MPI_Aint offset{ compute_offset(dst, window_info->get_start(), window_info->get_end())}; @@ -301,7 +310,7 @@ __host__ T HostInterface::amo_fetch_add(void* dst, T value, int pe, /* Offload remote fetch and op operation to MPI */ T ret{}; - MPI_Win win{window_info->get_win()}; + MPI_Win win{window_info_mpi->get_win()}; MPI_Datatype mpi_type{get_mpi_type()}; MPI_Fetch_and_op(&value, &ret, mpi_type, pe, offset, MPI_SUM, win); @@ -313,6 +322,11 @@ __host__ T HostInterface::amo_fetch_add(void* dst, T value, int pe, template __host__ T HostInterface::amo_fetch_cas(void* dst, T value, T cond, int pe, WindowInfo* window_info) { + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } + /* Calculate offset of remote dest from base address of window */ MPI_Aint offset{ compute_offset(dst, window_info->get_start(), window_info->get_end())}; @@ -325,7 +339,7 @@ __host__ T HostInterface::amo_fetch_cas(void* dst, T value, T cond, int pe, /* Offload remote compare and swap operation to MPI */ T ret{}; - MPI_Win win{window_info->get_win()}; + MPI_Win win{window_info_mpi->get_win()}; MPI_Datatype mpi_type{get_mpi_type()}; MPI_Compare_and_swap(&value, &cond, &ret, mpi_type, pe, offset, win); @@ -452,6 +466,10 @@ __host__ inline int HostInterface::test_and_compare(MPI_Aint offset, template __host__ void HostInterface::wait_until(T *ivars, int cmp, T val, WindowInfo* window_info) { + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } DPRINTF("Function: host_wait_until\n"); /* @@ -461,7 +479,7 @@ __host__ void HostInterface::wait_until(T *ivars, int cmp, T val, compute_offset(ivars, window_info->get_start(), window_info->get_end())}; MPI_Datatype mpi_type{get_mpi_type()}; - MPI_Win win{window_info->get_win()}; + MPI_Win win{window_info_mpi->get_win()}; /* * Continuously read the ivars atomically until it satisfies the condition @@ -631,6 +649,10 @@ __host__ size_t HostInterface::wait_until_some_vector(T* ivars, size_t nelems, template __host__ int HostInterface::test(T* ivars, int cmp, T val, WindowInfo* window_info) { + WindowInfoMPI* window_info_mpi = dynamic_cast(window_info); + if (!window_info_mpi) { + abort(); + } DPRINTF("Function: host_test\n"); /* @@ -641,7 +663,7 @@ __host__ int HostInterface::test(T* ivars, int cmp, T val, MPI_Datatype mpi_type{get_mpi_type()}; - return test_and_compare(offset, mpi_type, cmp, val, window_info->get_win()); + return test_and_compare(offset, mpi_type, cmp, val, window_info_mpi->get_win()); } } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc/backend_ipc.cpp b/projects/rocshmem/src/ipc/backend_ipc.cpp index 0da67eb9d2..934b07555c 100644 --- a/projects/rocshmem/src/ipc/backend_ipc.cpp +++ b/projects/rocshmem/src/ipc/backend_ipc.cpp @@ -22,6 +22,8 @@ * IN THE SOFTWARE. *****************************************************************************/ +#include + #include "backend_ipc.hpp" #include "ipc_team.hpp" @@ -63,11 +65,6 @@ IPCBackend::IPCBackend(MPI_Comm comm) : Backend(comm) { type = BackendType::IPC_BACKEND; - if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { - std::stringstream sstream(maximum_num_contexts_str); - sstream >> maximum_num_contexts_; - } - initIPC(); /** @@ -83,6 +80,37 @@ IPCBackend::IPCBackend(MPI_Comm comm) default_host_ctx = std::make_unique(this, 0); + init(); +} + +IPCBackend::IPCBackend(TcpBootstrap *bootstrap) + : Backend(bootstrap) { + type = BackendType::IPC_BACKEND; + + initIPC(bootstrap); // no MPI involved + + /** + * Check if num_pes == ipcImpl.shm_size) + * All the PEs must be with in a node for IPC conduit + */ + assert(num_pes == ipcImpl.shm_size); + + /* Initialize the host interface */ + host_interface = std::make_shared(hdp_proxy_.get(), + bootstrap, + &heap); + + default_host_ctx = std::make_unique(this, 0); + + init(); +} + +void IPCBackend::init() { + if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { + std::stringstream sstream(maximum_num_contexts_str); + sstream >> maximum_num_contexts_; + } + ROCSHMEM_HOST_CTX_DEFAULT.ctx_opaque = default_host_ctx.get(); setup_team_world(); @@ -181,6 +209,38 @@ void IPCBackend::team_destroy(rocshmem_team_t team) { CHECK_HIP(hipFree(team_obj)); } +void IPCBackend::Allreduce_char_BAND (char* inbuf, char *outbuf, size_t num_bytes, + Team *team) { + + // Implement an Allreduce outside of MPI. This is specialized for the scenario + // required for the team creation, i.e. assuming bytes and using BAND operation. + // Implementation uses an Allgather operation followed a local reduction. + + IPCTeam *team_obj = reinterpret_cast(team); + int num_pes = team_obj->num_pes; + int my_pe = team_obj->my_pe; + + char *tmp_buffer = new char[num_pes * num_bytes]; + std::memset(tmp_buffer, 0, num_pes * num_bytes); + std::memcpy (&tmp_buffer[my_pe * num_bytes], inbuf, num_bytes); + + if (num_pes == backend_bootstr->getNranks() ) { + backend_bootstr->allGather(tmp_buffer, num_bytes); + } else { + printf("IPCBackend::create_new_team: non-mpi version only supports parent_teams that contain all processes. Aborting.\n"); + abort(); + } + + for (int i = 0; i < num_bytes; i++) { + outbuf[i] = tmp_buffer[i]; + for (int j = 1; j < num_pes; j++) { + outbuf[i] &= tmp_buffer[j * num_bytes + i]; + } + } + + delete[] tmp_buffer; +} + void IPCBackend::create_new_team([[maybe_unused]] Team *parent_team, TeamInfo *team_info_wrt_parent, TeamInfo *team_info_wrt_world, int num_pes, @@ -190,8 +250,12 @@ void IPCBackend::create_new_team([[maybe_unused]] Team *parent_team, * Read the bit mask and find out a common index into * the pool of available work arrays. */ - NET_CHECK(MPI_Allreduce(pool_bitmask_, reduced_bitmask_, bitmask_size_, - MPI_CHAR, MPI_BAND, team_comm)); + if (team_comm != MPI_COMM_NULL) { + NET_CHECK(MPI_Allreduce(pool_bitmask_, reduced_bitmask_, bitmask_size_, + MPI_CHAR, MPI_BAND, team_comm)); + } else { + Allreduce_char_BAND (pool_bitmask_, reduced_bitmask_, bitmask_size_, parent_team); + } /* Pick the least significant non-zero bit (logical layout) in the reduced * bitmask */ @@ -199,6 +263,7 @@ void IPCBackend::create_new_team([[maybe_unused]] Team *parent_team, int common_index = get_ls_non_zero_bit(reduced_bitmask_, max_num_teams); if (common_index < 0) { /* No team available */ + printf("Could not create team, all bits in use. Aborting.\n"); abort(); } @@ -249,8 +314,18 @@ void IPCBackend::initIPC() { backend_comm); } +void IPCBackend::initIPC(TcpBootstrap *bootstr) { + const auto &heap_bases{heap.get_heap_bases()}; + + ipcImpl.ipcHostInit(my_pe, heap_bases, + bootstr); +} + void IPCBackend::global_exit(int status) { - MPI_Abort(backend_comm, status); + if (backend_comm != MPI_COMM_NULL) + MPI_Abort(backend_comm, status); + else + abort(); } void IPCBackend::teams_destroy() { @@ -315,8 +390,13 @@ void IPCBackend::init_wrk_sync_buffer() { /* * all-to-all exchange with each PE to share the IPC handles. */ - MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, - ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm); + if (backend_comm != MPI_COMM_NULL) { + MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR, + ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm); + } else { + assert (backend_bootstr != nullptr); + backend_bootstr->allGather(ipc_handle, sizeof(hipIpcMemHandle_t)); + } /* * Allocate device-side fine grained memory to hold IPC addresses of @@ -382,7 +462,11 @@ void IPCBackend::rocshmem_collective_init() { * Make sure that all processing elements have done this before * continuing. */ - NET_CHECK(MPI_Barrier(backend_comm)); + if (backend_comm != MPI_COMM_NULL) { + NET_CHECK(MPI_Barrier(backend_comm)); + } else { + backend_bootstr->barrier(); + } } void IPCBackend::teams_init() { @@ -474,7 +558,11 @@ void IPCBackend::teams_init() { * Make sure that all processing elements have done this before * continuing. */ - NET_CHECK(MPI_Barrier(backend_comm)); + if (backend_comm != MPI_COMM_NULL) { + NET_CHECK(MPI_Barrier(backend_comm)); + } else { + backend_bootstr->barrier(); + } } } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc/backend_ipc.hpp b/projects/rocshmem/src/ipc/backend_ipc.hpp index 7d1a8f3242..b783fc6dc4 100644 --- a/projects/rocshmem/src/ipc/backend_ipc.hpp +++ b/projects/rocshmem/src/ipc/backend_ipc.hpp @@ -32,6 +32,7 @@ #include "../context_incl.hpp" #include "ipc_context_proxy.hpp" #include "../ipc_policy.hpp" +#include "../bootstrap/bootstrap.hpp" namespace rocshmem { @@ -43,6 +44,7 @@ class IPCBackend : public Backend { * @copydoc Backend::Backend(unsigned) */ explicit IPCBackend(MPI_Comm comm); + explicit IPCBackend(TcpBootstrap *bootstr); /** * @copydoc Backend::~Backend() @@ -72,6 +74,11 @@ class IPCBackend : public Backend { */ void initIPC(); + /** + * @brief Helper to initialize IPC interface, non-MPI based version. + */ + void initIPC(TcpBootstrap *bootstrap); + /** * @brief Allocation and initialization of backend contexts. */ @@ -207,6 +214,11 @@ class IPCBackend : public Backend { void setup_fence_buffer(); private: + /** + * @brief Common code invoked from the different constructors + */ + void init(); + /** * @brief Proxy for the default context * @@ -288,6 +300,11 @@ class IPCBackend : public Backend { */ void cleanup_wrk_sync_buffer(); + /** + * @brief + */ + void Allreduce_char_BAND (char* inbuf, char *outbuf, size_t num_bytes, Team *team); + }; } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc_policy.cpp b/projects/rocshmem/src/ipc_policy.cpp index ada1609f28..0855f1f09c 100644 --- a/projects/rocshmem/src/ipc_policy.cpp +++ b/projects/rocshmem/src/ipc_policy.cpp @@ -119,6 +119,71 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, } } +__host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, + TcpBootstrap *bootstr) { + /* + * The non-MPI based version only works for ipc conduit for now, + * i.e. total number of ranks and number of local ranks have to match. + */ + shm_size = bootstr->getNranksPerNode(); + assert (shm_size == bootstr->getNranks()); + shm_rank = my_pe; + + /* + * Allocate a host-side c-array to hold the IPC handles. + */ + void *ipc_mem_handle_uncast = malloc(shm_size * sizeof(hipIpcMemHandle_t)); + hipIpcMemHandle_t *vec_ipc_handle = + reinterpret_cast(ipc_mem_handle_uncast); + + /* + * Call into the hip runtime to get an IPC handle for my symmetric + * heap and store that IPC handle into the host-side c-array which was + * just allocated. + */ + char *base_heap = heap_bases[my_pe]; + CHECK_HIP(hipIpcGetMemHandle(&vec_ipc_handle[shm_rank], base_heap)); + + /* + * Do an all-to-all exchange with each local processing element to + * share the symmetric heap IPC handles. + */ + bootstr->allGather(vec_ipc_handle, sizeof(hipIpcMemHandle_t)); + + /* + * Allocate device-side array to hold the IPC symmetric heap base + * addresses. + */ + char **ipc_base; + CHECK_HIP(hipMalloc(reinterpret_cast(&ipc_base), + shm_size * sizeof(char **))); + + /* + * For all local processing elements, initialize the device-side array + * with the IPC symmetric heap base addresses. + */ + for (int i = 0; i < shm_size; i++) { + if (i != shm_rank) { + void **ipc_base_uncast = reinterpret_cast(&ipc_base[i]); + CHECK_HIP(hipIpcOpenMemHandle(ipc_base_uncast, vec_ipc_handle[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + ipc_base[i] = base_heap; + } + } + + /* + * Set member variables used by subsequent method calls. + */ + ipc_bases = ipc_base; + + /* + * Free the host-side memory used to exchange the symmetric heap base + * addresses. + */ + free(vec_ipc_handle); +} + __host__ void IpcOnImpl::ipcHostStop() { for (int i = 0; i < shm_size; i++) { if (i != shm_rank) { diff --git a/projects/rocshmem/src/ipc_policy.hpp b/projects/rocshmem/src/ipc_policy.hpp index c1560ab8ba..7973390e40 100644 --- a/projects/rocshmem/src/ipc_policy.hpp +++ b/projects/rocshmem/src/ipc_policy.hpp @@ -34,6 +34,7 @@ #include "rocshmem_config.h" // NOLINT(build/include_subdir) #include "memory/hip_allocator.hpp" #include "util.hpp" +#include "bootstrap/bootstrap.hpp" namespace rocshmem { @@ -55,6 +56,9 @@ class IpcOnImpl { __host__ void ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, MPI_Comm thread_comm); + __host__ void ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, + TcpBootstrap *bootstrap); + __host__ void ipcHostStop(); __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { @@ -137,6 +141,9 @@ class IpcOffImpl { __host__ void ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, MPI_Comm thread_comm) {} + __host__ void ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, + TcpBootstrap *bootstrap){} + __host__ void ipcHostStop() {} __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { return false; } diff --git a/projects/rocshmem/src/memory/remote_heap_info.hpp b/projects/rocshmem/src/memory/remote_heap_info.hpp index f941e146e8..9918540668 100644 --- a/projects/rocshmem/src/memory/remote_heap_info.hpp +++ b/projects/rocshmem/src/memory/remote_heap_info.hpp @@ -32,6 +32,7 @@ #include "hip_allocator.hpp" #include "window_info.hpp" +#include "../bootstrap/bootstrap.hpp" /** * @file remote_heap_info.hpp @@ -56,7 +57,7 @@ class CommunicatorMPI { : comm_{comm} { MPI_Comm_rank(comm_, &my_pe_); MPI_Comm_size(comm_, &num_pes_); - heap_window_info_ = WindowInfo(comm_, heap_base, heap_size); + heap_window_info_ = WindowInfoMPI(comm_, heap_base, heap_size); } /** @@ -90,7 +91,7 @@ class CommunicatorMPI { /** * @brief Accessor method for heap_window_info_ */ - WindowInfo* get_window_info() { return &heap_window_info_; } + WindowInfoMPI* get_window_info() { return &heap_window_info_; } private: /** @@ -111,6 +112,75 @@ class CommunicatorMPI { /** * @brief MPI window on the symmetric GPU heap */ + WindowInfoMPI heap_window_info_{}; +}; + + +class CommunicatorTCP { + public: + + /** + * @brief Primary constructor + */ + CommunicatorTCP(char* heap_base, size_t heap_size, + TcpBootstrap* bootstrap) : bootstrap_{bootstrap} { + my_pe_ = bootstrap_->getRank(); + num_pes_ = bootstrap_->getNranks(); + + heap_window_info_ = WindowInfo(heap_base, heap_size); + } + + /** + * @brief Destructor + */ + ~CommunicatorTCP() {} + + /** + * @brief Returns my processing element ID + */ + int my_pe() { return my_pe_; } + + /** + * @brief Returns number of processing elements + */ + int num_pes() { return num_pes_; } + + /** + * @brief Performs MPI_Barrier + */ + void barrier() {bootstrap_->barrier(); } + + /** + * @brief Performs MPI_Allgather on recvbuf + */ + void allgather(void* recvbuf) { + bootstrap_->allGather(recvbuf, sizeof(void*)); + } + + /** + * @brief Accessor method for heap_window_info_ + */ + WindowInfo* get_window_info() { return &heap_window_info_; } + + private: + /** + * @brief Identifier for this processing element + */ + TcpBootstrap* bootstrap_; + + /** + * @brief Identifier for this processing element + */ + int my_pe_{-1}; + + /** + * @brief The total number of processing elements + */ + int num_pes_{-1}; + + /** + * @brief window on the symmetric GPU heap + */ WindowInfo heap_window_info_{}; }; @@ -138,14 +208,13 @@ class RemoteHeapInfo { RemoteHeapInfo(char* heap_ptr, size_t heap_size, MPI_Comm comm = MPI_COMM_WORLD) : communicator_{heap_ptr, heap_size, comm} { - heap_bases_.resize(communicator_.num_pes()); - for (auto& base : heap_bases_) { - base = nullptr; - } - heap_bases_[communicator_.my_pe()] = heap_ptr; - communicator_.allgather(heap_bases_.data()); + init(heap_ptr, heap_size); + } - device_heap_bases_ = heap_bases_.data(); + RemoteHeapInfo(char* heap_ptr, size_t heap_size, + TcpBootstrap* bootstrap) + : communicator_{heap_ptr, heap_size, bootstrap} { + init(heap_ptr, heap_size); } /** @@ -189,6 +258,20 @@ class RemoteHeapInfo { __device__ auto get_heap_bases() { return device_heap_bases_; } private: + /** + ** @brief common initialization code + */ + void init(char* heap_ptr, size_t heap_size) { + heap_bases_.resize(communicator_.num_pes()); + for (auto& base : heap_bases_) { + base = nullptr; + } + heap_bases_[communicator_.my_pe()] = heap_ptr; + communicator_.allgather(heap_bases_.data()); + + device_heap_bases_ = heap_bases_.data(); + } + /** * @brief Communicator implementation */ diff --git a/projects/rocshmem/src/memory/symmetric_heap.hpp b/projects/rocshmem/src/memory/symmetric_heap.hpp index f7f30daa6b..f7d3cf4871 100644 --- a/projects/rocshmem/src/memory/symmetric_heap.hpp +++ b/projects/rocshmem/src/memory/symmetric_heap.hpp @@ -41,25 +41,58 @@ * which needs to be shared across the network (to access the memory * region). */ - #include #include "remote_heap_info.hpp" #include "single_heap.hpp" +#include "../bootstrap/bootstrap.hpp" namespace rocshmem { +class RemoteHeapInfoAbstract { +public: + virtual WindowInfo* get_window_info() = 0; + __host__ virtual const std::vector>& get_heap_bases() = 0; + __device__ char** get_heap_bases() { return nullptr; } +}; + +class RemoteHeapInfoMPI : public RemoteHeapInfoAbstract { +public: + RemoteHeapInfoMPI(char *base_ptr, size_t size, MPI_Comm comm) : rheap(base_ptr, size, comm) {}; + + WindowInfo* get_window_info() override { return rheap.get_window_info(); }; + __host__ const std::vector>& get_heap_bases() override { return rheap.get_heap_bases(); }; + __device__ char** get_heap_bases() { return rheap.get_heap_bases(); }; + +private: + RemoteHeapInfo rheap; +}; + +class RemoteHeapInfoTCP : public RemoteHeapInfoAbstract { +public: + RemoteHeapInfoTCP(char *base_ptr, size_t size, TcpBootstrap *bootstrap) : rheap(base_ptr, size, bootstrap) {}; + + WindowInfo* get_window_info() override { return rheap.get_window_info(); }; + __host__ const std::vector>& get_heap_bases() override { return rheap.get_heap_bases(); }; + __device__ char** get_heap_bases() { return rheap.get_heap_bases(); }; + +private: + RemoteHeapInfo rheap; +}; + class SymmetricHeap { - /** - * @brief Helper type for RemoteHeapInfo with MPI - */ - using RemoteHeapInfoType = RemoteHeapInfo; public: - SymmetricHeap(MPI_Comm comm = MPI_COMM_WORLD) - : remote_heap_info_{single_heap_.get_base_ptr(), - single_heap_.get_size(), - comm} {} + SymmetricHeap(MPI_Comm comm = MPI_COMM_NULL, TcpBootstrap* bootstrap = nullptr) { + + if (comm != MPI_COMM_NULL) { + remote_heap_info_ = new RemoteHeapInfoMPI(single_heap_.get_base_ptr(), + single_heap_.get_size(), comm); + } else { + remote_heap_info_ = new RemoteHeapInfoTCP(single_heap_.get_base_ptr(), + single_heap_.get_size(), bootstrap); + } + } /** * @brief Allocates heap memory and returns ptr to caller * @@ -87,10 +120,17 @@ class SymmetricHeap { */ auto get_size() { return single_heap_.get_size(); } + /** + * @brief Returns is the heap is allocated with managed memory + * + * @return bool + */ + bool is_managed() { return single_heap_.is_managed(); } + /** * @brief Accessor method for heap_window_info_ */ - auto get_window_info() { return remote_heap_info_.get_window_info(); } + auto get_window_info() { return remote_heap_info_->get_window_info(); } /** * @brief Accessor for heap bases @@ -98,7 +138,7 @@ class SymmetricHeap { * @return Vector containing the addresses of the symmetric heap bases */ __host__ const auto& get_heap_bases() { - return remote_heap_info_.get_heap_bases(); + return remote_heap_info_->get_heap_bases(); } /** @@ -107,16 +147,9 @@ class SymmetricHeap { * @return Vector containing the addresses of the symmetric heap bases */ __device__ auto get_heap_bases() { - return remote_heap_info_.get_heap_bases(); + return remote_heap_info_->get_heap_bases(); } - /** - * @brief Returns is the heap is allocated with managed memory - * - * @return bool - */ - bool is_managed() { return single_heap_.is_managed(); } - private: /** * @brief Processing element's implementation of heap @@ -126,7 +159,7 @@ class SymmetricHeap { /** * @brief Implementation of remote heaps */ - RemoteHeapInfoType remote_heap_info_{}; + RemoteHeapInfoAbstract *remote_heap_info_{nullptr}; }; } // namespace rocshmem diff --git a/projects/rocshmem/src/memory/window_info.hpp b/projects/rocshmem/src/memory/window_info.hpp index 0e0a53a241..82b66a630b 100644 --- a/projects/rocshmem/src/memory/window_info.hpp +++ b/projects/rocshmem/src/memory/window_info.hpp @@ -48,29 +48,18 @@ class WindowInfo { /** * @brief Primary constructor */ - WindowInfo(MPI_Comm comm, void* start, size_t size) - : comm_{comm}, - win_start_{start}, - win_end_{reinterpret_cast(start) + size} { - up_win_ = std::unique_ptr(new MPI_Win); - MPI_Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get()); - MPI_Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get()); - } + WindowInfo(void* start, size_t size) + : win_start_{start}, + win_end_{reinterpret_cast(start) + size} {} /** * @brief Destructor */ - ~WindowInfo() { - if (up_win_) { - MPI_Win_unlock_all(*up_win_.get()); - MPI_Win_free(up_win_.get()); - } - } + ~WindowInfo() = default; /** * @brief Copy constructor * - * @note Disabled due to up_win_ */ WindowInfo(WindowInfo& other) = delete; // NOLINT @@ -81,13 +70,6 @@ class WindowInfo { */ WindowInfo(const WindowInfo& other) = delete; - /** - * @brief Copy assignment - * - * @note Disabled due to up_win_ - */ - WindowInfo& operator=(WindowInfo other) = delete; - /** * @brief Move constructor */ @@ -98,13 +80,6 @@ class WindowInfo { */ WindowInfo& operator=(WindowInfo&& other) = default; - /** - * @brief Accessor for object in up_win_ - * - * @return MPI_Win object - */ - MPI_Win get_win() const { return *up_win_.get(); } - /** * @brief Accessor for win_start_ * @@ -119,13 +94,6 @@ class WindowInfo { */ void* get_end() const { return win_end_; } - /** - * @brief Setter for object in up_win_ - * - * @param[in] An MPI Window object - */ - void set_win(MPI_Win win) { *up_win_.get() = win; } - /** * @brief Setter for win_start_ * @@ -147,7 +115,105 @@ class WindowInfo { * * @return Difference between dest and window start */ - MPI_Aint get_offset(const void* dest) { + virtual ptrdiff_t get_offset(const void* dest) { + assert(reinterpret_cast(const_cast(dest)) >= + reinterpret_cast(win_start_)); + assert(reinterpret_cast(const_cast(dest)) >= + reinterpret_cast(win_start_)); + assert(reinterpret_cast(const_cast(dest)) < + reinterpret_cast(win_end_)); + + return reinterpret_cast(reinterpret_cast(const_cast(dest)) - reinterpret_cast(win_start_)); + } + + protected: + /** + * @brief Raw pointer marking the start of window + */ + void* win_start_{nullptr}; + + /** + * @brief Raw pointer marking the end of window + */ + void* win_end_{nullptr}; +}; + + +class WindowInfoMPI: public WindowInfo { + public: + /** + * @brief Default constructor + */ + WindowInfoMPI() = default; + + /** + * @brief Primary constructor + */ + WindowInfoMPI(MPI_Comm comm, void* start, size_t size) + : comm_{comm} { + win_start_ = start; + win_end_ = reinterpret_cast(start) + size; + + up_win_ = std::unique_ptr(new MPI_Win); + MPI_Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get()); + MPI_Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get()); + } + + /** + * @brief Destructor + */ + ~WindowInfoMPI() { + if (up_win_) { + MPI_Win_unlock_all(*up_win_.get()); + MPI_Win_free(up_win_.get()); + } + } + + /** + * @brief Copy constructor + * + */ + WindowInfoMPI(WindowInfoMPI& other) = delete; // NOLINT + + /** + * @brief Const copy constructor + * + * @note Disabled due to up_win_ + */ + WindowInfoMPI(const WindowInfoMPI& other) = delete; + + /** + * @brief Move constructor + */ + WindowInfoMPI(WindowInfoMPI&& other) = default; + + /** + * @brief Move assignment + */ + WindowInfoMPI& operator=(WindowInfoMPI&& other) = default; + + /** + * @brief Accessor for object in up_win_ + * + * @return MPI_Win object + */ + MPI_Win get_win() const { return *up_win_.get(); } + + /** + * @brief Setter for object in up_win_ + * + * @param[in] An MPI Window object + */ + void set_win(MPI_Win win) { *up_win_.get() = win; } + + /** + * @brief Get offset between address and start of window + * + * @param[in] Address in raw pointer format + * + * @return Difference between dest and window start + */ + ptrdiff_t get_offset(const void* dest) override { assert(reinterpret_cast(const_cast(dest)) >= reinterpret_cast(win_start_)); assert(reinterpret_cast(const_cast(dest)) >= @@ -160,7 +226,7 @@ class WindowInfo { MPI_Aint start_disp; MPI_Get_address(win_start_, &start_disp); - return MPI_Aint_diff(dest_disp, start_disp); + return static_cast(MPI_Aint_diff(dest_disp, start_disp)); } private: @@ -180,15 +246,6 @@ class WindowInfo { */ std::unique_ptr up_win_{nullptr}; - /** - * @brief Raw pointer marking the start of window - */ - void* win_start_{nullptr}; - - /** - * @brief Raw pointer marking the end of window - */ - void* win_end_{nullptr}; }; } // namespace rocshmem diff --git a/projects/rocshmem/src/mpi_instance.cpp b/projects/rocshmem/src/mpi_instance.cpp index 37ff3b4228..d863b003af 100644 --- a/projects/rocshmem/src/mpi_instance.cpp +++ b/projects/rocshmem/src/mpi_instance.cpp @@ -27,11 +27,13 @@ namespace rocshmem { MPIInstance::MPIInstance(MPI_Comm comm) { - MPI_Initialized(&pre_init_done); + int is_init{0}; + MPI_Initialized(&is_init); - if (!pre_init_done) { + if (!is_init) { int provided; MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); + init_in_this_class = 1; } if (comm == MPI_COMM_NULL) { @@ -45,7 +47,7 @@ MPIInstance::MPIInstance(MPI_Comm comm) { MPIInstance::~MPIInstance() { int finalized{0}; MPI_Finalized(&finalized); - if (!finalized && !pre_init_done) { + if (!finalized && init_in_this_class) { MPI_Finalize(); } } diff --git a/projects/rocshmem/src/mpi_instance.hpp b/projects/rocshmem/src/mpi_instance.hpp index a9b83fd88d..5192ed1278 100644 --- a/projects/rocshmem/src/mpi_instance.hpp +++ b/projects/rocshmem/src/mpi_instance.hpp @@ -75,9 +75,9 @@ class MPIInstance { int nprocs_{-1}; /** - * @brief Was MPI initialized before rocshmem_init call + * @brief Was MPI initialized in this class */ - int pre_init_done{0}; + int init_in_this_class{0}; }; } // namespace rocshmem diff --git a/projects/rocshmem/src/reverse_offload/backend_proxy.hpp b/projects/rocshmem/src/reverse_offload/backend_proxy.hpp index ecb1d54bba..7efad56b77 100644 --- a/projects/rocshmem/src/reverse_offload/backend_proxy.hpp +++ b/projects/rocshmem/src/reverse_offload/backend_proxy.hpp @@ -39,7 +39,7 @@ struct BackendRegister { bool *needs_quiet{nullptr}; bool *needs_blocking{nullptr}; HdpPolicy *hdp_policy{nullptr}; - WindowInfo **heap_window_info{nullptr}; + WindowInfoMPI **heap_window_info{nullptr}; SymmetricHeap *heap_ptr{nullptr}; }; diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp index 346ed5d0ab..ecf00673eb 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp @@ -40,7 +40,7 @@ __host__ ROHostContext::ROHostContext(Backend *backend, long options) host_interface = b->host_interface; - context_window_info = host_interface->acquire_window_context(); + context_window_info = dynamic_cast(host_interface->acquire_window_context()); } __host__ ROHostContext::~ROHostContext() { diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp index 0291c525ef..c306c0690e 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp @@ -54,14 +54,14 @@ class ROContextWindowInfo { * * @return WindowInfo pointer */ - WindowInfo *get() { return window_info_; } + WindowInfoMPI *get() { return window_info_; } private: /** * @brief Pointer to the WindowInfo object that manages the MPI Window for * this context */ - WindowInfo *window_info_{nullptr}; + WindowInfoMPI *window_info_{nullptr}; }; class ROHostContext : public Context { @@ -79,7 +79,7 @@ class ROHostContext : public Context { HostInterface *host_interface = nullptr; /* An MPI Window implements a context */ - WindowInfo *context_window_info = nullptr; + WindowInfoMPI *context_window_info = nullptr; /************************************************************************** ****************************** HOST METHODS ****************************** diff --git a/projects/rocshmem/src/reverse_offload/window_proxy.hpp b/projects/rocshmem/src/reverse_offload/window_proxy.hpp index 0afa532913..3883628ebb 100644 --- a/projects/rocshmem/src/reverse_offload/window_proxy.hpp +++ b/projects/rocshmem/src/reverse_offload/window_proxy.hpp @@ -34,7 +34,7 @@ namespace rocshmem { template class WindowProxy { private: - using ProxyT = DeviceProxy; + using ProxyT = DeviceProxy; public: /* @@ -43,11 +43,11 @@ class WindowProxy { WindowProxy(SymmetricHeap *heap, MPI_Comm comm, size_t num_windows) : num_windows_{num_windows}, proxy_{num_windows} { - auto *window_info{proxy_.get()}; + WindowInfoMPI** window_info{proxy_.get()}; for (size_t i{0}; i < num_windows_; i++) { window_info[i] = - new WindowInfo(comm, heap->get_local_heap_base(), heap->get_size()); + new WindowInfoMPI(comm, heap->get_local_heap_base(), heap->get_size()); } } @@ -74,7 +74,7 @@ class WindowProxy { /* * @brief Provide access to the memory referenced by the proxy */ - __host__ __device__ WindowInfo **get() { return proxy_.get(); } + __host__ __device__ WindowInfoMPI **get() { return proxy_.get(); } __host__ size_t get_num_MPI_windows() { return num_windows_; } private: diff --git a/projects/rocshmem/src/rocshmem.cpp b/projects/rocshmem/src/rocshmem.cpp index 37a8826b1c..f992a70bbb 100644 --- a/projects/rocshmem/src/rocshmem.cpp +++ b/projects/rocshmem/src/rocshmem.cpp @@ -68,7 +68,7 @@ namespace rocshmem { Backend *backend = nullptr; MPIInstance *mpi_instance = nullptr; - +TcpBootstrap *bootstr = nullptr; rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; /** @@ -102,8 +102,85 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } } +[[maybe_unused]] __host__ static void inline library_init_subcomm(TcpBootstrap *bootstrap, int nranks, int rank) { + int initialized; + int world_size = -1; + MPI_Initialized(&initialized); + + if (!initialized) { + // This is an Open MPI specific solution to retrieve the number of + // processes that have been started, value can be checked before MPI_Init + char *value = getenv("OMPI_COMM_WORLD_SIZE"); + if (value != NULL) { + world_size = atoi(value); + } + if (world_size != nranks) { + // This solution will require MPI_Sessions. This is planned for the + // future, but is not supported in the current version. + fprintf (stderr, "Unsupported configuration to initialize rocSHMEM. Please " + "initialize the MPI library using MPI_Init first, if you want to " + "initialize rocSHMEM with a subset of the processes\n"); + abort(); + } + } else { + MPI_Comm_size (MPI_COMM_WORLD, &world_size); + } + + if (world_size == nranks) { + library_init(MPI_COMM_WORLD); + } else { + MPI_Group world_group; + int world_rank; + + MPI_Comm_rank (MPI_COMM_WORLD, &world_rank); + MPI_Comm_group (MPI_COMM_WORLD, &world_group); + + int *inc_ranks = new int[nranks]; + inc_ranks[rank] = world_rank; + + bootstr->allGather (inc_ranks, sizeof(int)); + + MPI_Group sub_group; + MPI_Comm sub_comm; + MPI_Group_incl (world_group, nranks, inc_ranks, &sub_group); + MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm); + + library_init(sub_comm); + + MPI_Group_free (&sub_group); + MPI_Group_free (&world_group); + MPI_Comm_free (&sub_comm); + delete[] inc_ranks; + } +} + +[[maybe_unused]] __host__ void inline library_init(TcpBootstrap *bootstrap) { + assert(!backend); + int count = 0; + CHECK_HIP(hipGetDeviceCount(&count)); + + if (count == 0) { + printf("No GPU found! \n"); + abort(); + } + + rocm_init(); + +#ifdef USE_RO + printf("RO Backend requires MPI library to be initialized, even when using uniqueId initializations!\n"); + abort(); +#else + CHECK_HIP(hipHostMalloc(&backend, sizeof(IPCBackend))); + backend = new (backend) IPCBackend(bootstrap); +#endif + + if (!backend) { + abort(); + } +} + [[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int flags, - rocshmem_init_attr_t *attr) { + rocshmem_init_attr_t *attr) { MPI_Comm comm = MPI_COMM_NULL; if ((attr == nullptr) || @@ -122,57 +199,17 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } if (flags == ROCSHMEM_INIT_WITH_UNIQUEID) { - int initialized; - int world_size = -1; - MPI_Initialized(&initialized); + assert (attr->nranks > 0); + assert (attr->rank >= 0); + assert (attr->rank < attr->nranks); - if (!initialized) { - // This is an Open MPI specific solution to retrieve the number of - // processes that have been started, value can be checked before MPI_Init - char *value = getenv("OMPI_COMM_WORLD_SIZE"); - if (value != NULL) { - world_size = atoi(value); - } - if (world_size != attr->nranks) { - // This solution will require MPI_Sessions. This is planned for the - // future, but is not supported in the current version. - fprintf (stderr, "Unsupported configuration to initialize rocSHMEM. Please " - "initialize the MPI library using MPI_Init first, if you want to " - "initialize rocSHMEM with a subset of the processes\n"); - abort(); - } + bootstr = new TcpBootstrap(attr->rank, attr->nranks); + bootstr->initialize(attr->uid, rocshmem_env_.get_bootstrap_timeout()); + + if (rocshmem_env_.get_uniqueid_with_mpi() ) { + library_init_subcomm(bootstr, attr->nranks, attr->rank); } else { - MPI_Comm_size (MPI_COMM_WORLD, &world_size); - } - - if (world_size == attr->nranks) { - library_init(MPI_COMM_WORLD); - } else { - MPI_Group world_group; - int world_rank; - - MPI_Comm_rank (MPI_COMM_WORLD, &world_rank); - MPI_Comm_group (MPI_COMM_WORLD, &world_group); - - TcpBootstrap bootstr(attr->rank, attr->nranks); - - bootstr.initialize(attr->uid, rocshmem_env_.get_bootstrap_timeout()); - int *inc_ranks = new int[attr->nranks]; - inc_ranks[attr->rank] = world_rank; - - bootstr.allGather (inc_ranks, sizeof(int)); - - MPI_Group sub_group; - MPI_Comm sub_comm; - MPI_Group_incl (world_group, attr->nranks, inc_ranks, &sub_group); - MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm); - - library_init(sub_comm); - - MPI_Group_free (&sub_group); - MPI_Group_free (&world_group); - MPI_Comm_free (&sub_comm); - delete[] inc_ranks; + library_init (bootstr); } } @@ -227,8 +264,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } [[maybe_unused]] __host__ int rocshmem_my_pe() { - if (mpi_instance != nullptr) { - return mpi_instance->get_rank(); + if (backend != nullptr) { + return backend->getMyPE(); } fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n"); @@ -236,8 +273,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } [[maybe_unused]] __host__ int rocshmem_n_pes() { - if (mpi_instance != nullptr) { - return mpi_instance->get_nprocs(); + if (backend != nullptr) { + return backend->getNumPEs(); } fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n"); @@ -294,7 +331,11 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; backend->~Backend(); CHECK_HIP(hipHostFree(backend)); - delete mpi_instance; + if (bootstr == nullptr) + delete mpi_instance; + + if (bootstr != nullptr) + delete bootstr; } __host__ void rocshmem_query_thread(int *provided) { @@ -395,22 +436,24 @@ __host__ int rocshmem_team_split_strided( new (team_info_wrt_world) TeamInfo(team_world, pe_start_in_world, stride_in_world, size); - /* Create a new MPI communicator for this team */ - int color; - if (my_pe_in_new_team < 0) { - color = MPI_UNDEFINED; - } else { - color = 1; + MPI_Comm team_comm{MPI_COMM_NULL}; + if (parent_team_obj->mpi_comm != MPI_COMM_NULL) { + /* Create a new MPI communicator for this team */ + int color; + if (my_pe_in_new_team < 0) { + color = MPI_UNDEFINED; + } else { + color = 1; + } + + MPI_Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm); } - - MPI_Comm team_comm; - MPI_Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm); - /** * Allocate new team for GPU-inittiated communication with backend-specific * objects * TODO: are there any backend specific objects? */ + if (my_pe_in_new_team < 0) { *new_team = ROCSHMEM_TEAM_INVALID; } else { @@ -422,7 +465,10 @@ __host__ int rocshmem_team_split_strided( * not */ backend->team_tracker.track(*new_team); } - MPI_Comm_free (&team_comm); + + if (team_comm != MPI_COMM_NULL) { + MPI_Comm_free (&team_comm); + } return 0; } diff --git a/projects/rocshmem/src/team.cpp b/projects/rocshmem/src/team.cpp index d3e5bc74d5..ab06c22d74 100644 --- a/projects/rocshmem/src/team.cpp +++ b/projects/rocshmem/src/team.cpp @@ -79,7 +79,9 @@ __host__ Team::Team(Backend* handle, TeamInfo* team_info_wrt_parent, tinfo_wrt_world(team_info_wrt_world), num_pes(_num_pes), my_pe(_my_pe) { - MPI_Comm_dup (_mpi_comm, &mpi_comm); + if (_mpi_comm != MPI_COMM_NULL) { + MPI_Comm_dup (_mpi_comm, &mpi_comm); + } } __host__ __device__ int Team::get_pe_in_world(int pe) { diff --git a/projects/rocshmem/src/util.cpp b/projects/rocshmem/src/util.cpp index 76ff42aa4f..8c48870557 100644 --- a/projects/rocshmem/src/util.cpp +++ b/projects/rocshmem/src/util.cpp @@ -137,6 +137,11 @@ rocshmem_env_config::rocshmem_env_config() { ro_progress_delay = atoi(env_value); } + env_value = getenv("ROCSHMEM_UNIQUEID_WITH_MPI"); + if (nullptr != env_value) { + uniqueid_with_mpi = atoi(env_value); + } + env_value = getenv("ROCSHMEM_BOOTSTRAP_TIMEOUT"); if (nullptr != env_value) { bootstrap_timeout = atoi(env_value); @@ -166,6 +171,10 @@ int rocshmem_env_config::get_ro_progress_delay() { return ro_progress_delay; } +int rocshmem_env_config::get_uniqueid_with_mpi() { + return uniqueid_with_mpi; +} + int rocshmem_env_config::get_bootstrap_timeout() { return bootstrap_timeout; } diff --git a/projects/rocshmem/src/util.hpp b/projects/rocshmem/src/util.hpp index f4793563c2..d46d2609b6 100644 --- a/projects/rocshmem/src/util.hpp +++ b/projects/rocshmem/src/util.hpp @@ -273,6 +273,7 @@ public: int get_ro_disable_ipc(); int get_ro_progress_delay(); + int get_uniqueid_with_mpi(); int get_bootstrap_timeout(); std::string get_bootstrap_hostid(); std::string get_bootstrap_socket_family(); @@ -282,6 +283,7 @@ private: int ro_disable_ipc = 0; int ro_progress_delay = 3; int bootstrap_timeout = 5; + int uniqueid_with_mpi = 0; std::string bootstrap_hostid; std::string bootstrap_socket_family; std::string bootstrap_socket_ifname; diff --git a/projects/rocshmem/tests/functional_tests/CMakeLists.txt b/projects/rocshmem/tests/functional_tests/CMakeLists.txt index 8df56386b7..b064a8fce2 100644 --- a/projects/rocshmem/tests/functional_tests/CMakeLists.txt +++ b/projects/rocshmem/tests/functional_tests/CMakeLists.txt @@ -71,6 +71,9 @@ target_sources( ############################################################################### # ROCSHMEM ############################################################################### +include(${CMAKE_SOURCE_DIR}/cmake/find_pmix.cmake) +check_pmix() + if (BUILD_TESTS_ONLY) find_package(MPI REQUIRED) find_package(hip REQUIRED PATHS /opt/rocm) @@ -84,8 +87,21 @@ if (BUILD_TESTS_ONLY) ) endif() +target_compile_definitions( + ${PROJECT_NAME} + PRIVATE + $<$:HAVE_PMIX=1> +) + +target_include_directories( + ${PROJECT_NAME} + PRIVATE + ${PMIX_INCLUDE_DIRECTORIES} +) + target_link_libraries( ${PROJECT_NAME} PRIVATE roc::rocshmem + ${PMIX_LIBRARIES} ) diff --git a/projects/rocshmem/tests/functional_tests/test_driver.cpp b/projects/rocshmem/tests/functional_tests/test_driver.cpp index 1214c2b6be..d23f97d2ae 100644 --- a/projects/rocshmem/tests/functional_tests/test_driver.cpp +++ b/projects/rocshmem/tests/functional_tests/test_driver.cpp @@ -28,6 +28,104 @@ #include "tester.hpp" #include "tester_arguments.hpp" +#if defined(HAVE_PMIX) +#include + +static pmix_proc_t pmix_myproc; +static pmix_proc_t pmix_proc; + +static void init_pmix(int *rank, int *nranks) +{ + pmix_status_t rc; + pmix_value_t *val; + + if (PMIX_SUCCESS != (rc = PMIx_Init(&pmix_myproc, NULL, 0))) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Init failed: " << rc << std::endl; + abort(); + } +#ifdef VERBOSE + printf("Client ns %s rank %d: Running\n", pmix_myproc.nspace, pmix_myproc.rank); +#endif + PMIX_PROC_CONSTRUCT(&pmix_proc); + PMIX_LOAD_PROCID(&pmix_proc, pmix_myproc.nspace, PMIX_RANK_WILDCARD); + + /* get our job size */ + if (PMIX_SUCCESS != (rc = PMIx_Get(&pmix_proc, PMIX_JOB_SIZE, NULL, 0, &val))) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Get universe size failed: " + << rc << std::endl; + abort(); + } + + *nranks = val->data.uint32; + *rank = pmix_myproc.rank; + + PMIX_VALUE_RELEASE(val); + return; +} + +static void pmix_bcast(void *buf, size_t nbytes, char *key, int root) +{ + pmix_status_t rc; + pmix_value_t value; + pmix_value_t *val; + pmix_info_t *info; + bool flag; + + if (pmix_myproc.rank == root) { + value.type = PMIX_BYTE_OBJECT; + value.data.bo.bytes = (char *) (buf); + value.data.bo.size = nbytes; + + rc = PMIx_Put(PMIX_GLOBAL, key, &value); + if (PMIX_SUCCESS != rc) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Put failed: " << rc << std::endl; + abort(); + } + + /* push the data to our PMIx server */ + if (PMIX_SUCCESS != (rc = PMIx_Commit())) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Commit failed: " << rc << std::endl; + abort(); + } + } + + /* call fence to synchronize with our peers - instruct + * the fence operation to collect and return all "put" + * data from our peers */ + PMIX_INFO_CREATE(info, 1); + flag = true; + PMIX_INFO_LOAD(info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL); + if (PMIX_SUCCESS != (rc = PMIx_Fence(&pmix_proc, 1, info, 1))) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Fence failed: " << rc << std::endl; + abort(); + } + PMIX_INFO_FREE(info, 1); + + pmix_proc.rank = 0; + if (PMIX_SUCCESS != (rc = PMIx_Get(&pmix_proc, key, NULL, 0, &val))) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Get failed: " << rc << std::endl; + abort(); + } + if (PMIX_BYTE_OBJECT != val->type) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Get returned wrong type: " << val->type << std::endl; + PMIX_VALUE_RELEASE(val); + abort(); + } + + if (pmix_myproc.rank != root) { + if (NULL == val->data.bo.bytes) { + std::cerr << "Rank " << pmix_myproc.rank << " PMIx_Get %d returned NULL pointer\n"; + PMIX_VALUE_RELEASE(val); + abort(); + } + memcpy (buf, val->data.bo.bytes, val->data.bo.size); + } + PMIX_VALUE_RELEASE(val); + + return; +} +#endif + using namespace rocshmem; int main(int argc, char *argv[]) { @@ -40,12 +138,61 @@ int main(int argc, char *argv[]) { * Select a GPU */ char* ompi_local_rank = getenv("OMPI_COMM_WORLD_LOCAL_RANK"); + if (nullptr == ompi_local_rank) { + printf("Could not determine local rank, use Open MPI `mpiexec`\n"); + abort(); + } CHECK_HIP(hipSetDevice(atoi(ompi_local_rank))); /** * Must initialize rocshmem to access arguments needed by the tester. */ +#ifdef HAVE_PMIX + int test_uuid = 0; + char *rocshmem_test_uuid = getenv("ROCSHMEM_TEST_UUID"); + if (rocshmem_test_uuid != nullptr) { + test_uuid = atoi(rocshmem_test_uuid); + } + + if (test_uuid) { + int ret; + int rank, nranks; + rocshmem_uniqueid_t uid; + rocshmem_init_attr_t attr; + + init_pmix(&rank, &nranks); + if (rank == 0) { + ret = rocshmem_get_uniqueid (&uid); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n"; + abort(); + } + } + + char key[] = "rocshmem-uuid"; + pmix_bcast(&uid, sizeof(rocshmem_uniqueid_t), key, 0); + + ret = rocshmem_set_attr_uniqueid_args(rank, nranks, &uid, &attr); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_set_attr_uniqueid_args. Aborting.\n"; + abort(); + } + + ret = rocshmem_init_attr(ROCSHMEM_INIT_WITH_UNIQUEID, &attr); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_init_attr. Aborting.\n"; + abort(); + } + +#ifdef VERBOSE + std::cout << rank << ": rocshmem_init_attr SUCCESS\n"; +#endif + } else { + rocshmem_init(); + } +#else rocshmem_init(); +#endif /** * Now grab the arguments from rocshmem. @@ -77,5 +224,11 @@ int main(int argc, char *argv[]) { */ rocshmem_finalize(); +#ifdef HAVE_PMIX + if (test_uuid) { + PMIx_Finalize(NULL, 0); + } +#endif + return 0; } diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index cccb164c99..2661cce5a0 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -617,7 +617,7 @@ void flush_hdp() { } void Tester::barrier() { - MPI_Barrier(MPI_COMM_WORLD); + rocshmem_barrier_all(); flush_hdp(); } diff --git a/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.cpp b/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.cpp index 7f79dc1c6d..b99d123a2d 100644 --- a/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.cpp +++ b/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.cpp @@ -38,12 +38,15 @@ TEST_F(SymmetricHeapTestFixture, malloc_free) { TEST_F(SymmetricHeapTestFixture, window_info) { auto win_info_ptr{symmetric_heap_.get_window_info()}; - void *window_base_addr{nullptr}; - int flag{0}; - MPI_Win_get_attr(win_info_ptr->get_win(), MPI_WIN_BASE, &window_base_addr, - &flag); - ASSERT_NE(0, flag); - ASSERT_NE(nullptr, window_base_addr); + WindowInfoMPI* window_info_mpi = dynamic_cast(win_info_ptr); + if (window_info_mpi) { + void *window_base_addr{nullptr}; + int flag{0}; + MPI_Win_get_attr(window_info_mpi->get_win(), MPI_WIN_BASE, &window_base_addr, + &flag); + ASSERT_NE(0, flag); + ASSERT_NE(nullptr, window_base_addr); + } } TEST_F(SymmetricHeapTestFixture, heap_bases) { diff --git a/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.hpp b/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.hpp index a7cb2bd1d2..35b4ab6ca2 100644 --- a/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.hpp +++ b/projects/rocshmem/tests/unit_tests/symmetric_heap_gtest.hpp @@ -37,7 +37,7 @@ class SymmetricHeapTestFixture : public ::testing::Test /** * @brief Symmetric heap object */ - SymmetricHeap symmetric_heap_ {}; + SymmetricHeap symmetric_heap_ {MPI_COMM_WORLD}; }; } // namespace rocshmem