Remove MPI compile-time dependency (#264)
* use dlsym for MPI functions
to allow compiling without MPI support, convert the usage of MPI functions and symbols to be based on a dlopen/dlsym based mechanism. Turns out this cannot be done entirely vendor neutral, slightly different solutions might be required for Open MPI, MPICH and the new MPI ABI.
* checkpoint
more work to be done.
* checkpoint 2
* checkpoint 3
* checkpoint 4
examples compile and link correctly
* checkpoitn 5 (I think)
* Checkpoitn 6
* dyld-mpi: adapt GDA
* dyldmpi: tests that depend on MPI need to link with it themselves
* do not ../mpi_instance.h
* dyldmpi: make the symetricHeapTestFixture compile
* dyldmpi: Change cmakery, compiles and run gda w/o external MPI
* Make it also compile in external MPI mode
* dyldmpi: ipc unit tests compile but do not link
* dyldmpi: new approach, if external mpi required, link with mpi,
otherwise use ompi5 abi
* C-style comments in cmakelist..
* dyldmpi: examples: do not fail compiling if MPI not found at build time,
instead do not compile the MPI required examples
* more updates to CMake logic
* convert RO backend
and a few other cleanups
* update some unit tests
to work with the dlopen MPI environment correctly.
---------
Co-authored-by: Aurelien Bouteiller <abouteil@amd.com>
[ROCm/rocshmem commit: e4c427a736]
This commit is contained in:
@@ -59,10 +59,11 @@ option(USE_SHARED_CTX "Request support for shared ctx between WG" OFF)
|
||||
option(USE_SINGLE_NODE "Enable single node support only." OFF)
|
||||
option(USE_HDP_FLUSH "Force flush the HDP cache." OFF)
|
||||
option(USE_HDP_FLUSH_HOST_SIDE "Use a polling thread to flush the HDP cache on the host." OFF)
|
||||
option(USE_EXTERNAL_MPI "Link with an external MPI (required if used MPI is ABI incompatible with Open MPI v5" OFF)
|
||||
|
||||
option(BUILD_FUNCTIONAL_TESTS "Build the functional tests" OFF)
|
||||
option(BUILD_FUNCTIONAL_TESTS "Build the functional tests (Requires MPI)" OFF)
|
||||
option(BUILD_EXAMPLES "Build the examples" ON)
|
||||
option(BUILD_UNIT_TESTS "Build the unit tests" OFF)
|
||||
option(BUILD_UNIT_TESTS "Build the unit tests (Requires MPI)" OFF)
|
||||
option(BUILD_TESTS_ONLY "Build only tests. Used to link agains rocSHMEM in a ROCm Release" OFF)
|
||||
option(BUILD_TOOLS "Build binary tools (e.g., rocshmem_info)" ON)
|
||||
|
||||
@@ -150,7 +151,21 @@ if (NOT BUILD_TESTS_ONLY)
|
||||
#############################################################################
|
||||
# PACKAGE DEPENDENCIES
|
||||
#############################################################################
|
||||
find_package(MPI REQUIRED)
|
||||
find_package(MPI)
|
||||
if (MPI_FOUND)
|
||||
set(HAVE_EXTERNAL_MPI ON)
|
||||
else()
|
||||
set(HAVE_EXTERNAL_MPI OFF)
|
||||
set(BUILD_FUNCTIONAL_TESTS OFF)
|
||||
set(BUILD_UNIT_TESTS OFF)
|
||||
endif()
|
||||
|
||||
if (USE_EXTERNAL_MPI)
|
||||
if(NOT HAVE_EXTERNAL_MPI)
|
||||
message(FATAL_ERROR "External MPI support requested but MPI support not found. Build Aborted")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
find_package(hsa-runtime64 REQUIRED)
|
||||
|
||||
@@ -182,8 +197,8 @@ if (NOT BUILD_TESTS_ONLY)
|
||||
target_link_libraries(
|
||||
${PROJECT_NAME}
|
||||
PUBLIC
|
||||
$<$<BOOL:${HAVE_EXTERNAL_MPI}>:MPI::MPI_CXX>
|
||||
Threads::Threads
|
||||
MPI::MPI_CXX
|
||||
hip::device
|
||||
hip::host
|
||||
hsa-runtime64::hsa-runtime64
|
||||
|
||||
@@ -45,3 +45,5 @@
|
||||
#cmakedefine GDA_IONIC
|
||||
#cmakedefine GDA_BNXT
|
||||
#cmakedefine GDA_MLX5
|
||||
#cmakedefine USE_EXTERNAL_MPI
|
||||
#cmakedefine HAVE_EXTERNAL_MPI
|
||||
|
||||
@@ -29,6 +29,12 @@ cmake_minimum_required(VERSION 3.16.3 FATAL_ERROR)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/setup_project.cmake)
|
||||
project(rocshmem_examples VERSION 1.0.0 LANGUAGES CXX)
|
||||
|
||||
find_package(MPI)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
if (NOT TARGET roc::rocshmem)
|
||||
find_package(rocshmem REQUIRED PATHS /opt/rocm)
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# SOURCES
|
||||
###############################################################################
|
||||
@@ -38,13 +44,11 @@ set(EXAMPLE_SOURCES
|
||||
rocshmem_broadcast_test.cc
|
||||
rocshmem_getmem_test.cc
|
||||
rocshmem_put_signal_test.cc
|
||||
rocshmem_init_attr_test.cc
|
||||
)
|
||||
|
||||
find_package(MPI REQUIRED)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
if(NOT TARGET roc::rocshmem)
|
||||
find_package(rocshmem REQUIRED PATHS /opt/rocm)
|
||||
if (HAVE_EXTERNAL_MPI)
|
||||
list(APPEND EXAMPLE_SOURCES
|
||||
rocshmem_init_attr_test.cc)
|
||||
endif()
|
||||
|
||||
foreach(SOURCE_FILE IN LISTS EXAMPLE_SOURCES)
|
||||
@@ -55,7 +59,7 @@ foreach(SOURCE_FILE IN LISTS EXAMPLE_SOURCES)
|
||||
target_link_libraries(
|
||||
${EXECUTABLE_NAME}
|
||||
PRIVATE
|
||||
MPI::MPI_CXX
|
||||
$<TARGET_NAME_IF_EXISTS:MPI::MPI_CXX>
|
||||
roc::rocshmem
|
||||
)
|
||||
endforeach()
|
||||
|
||||
@@ -59,6 +59,7 @@
|
||||
*/
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "util.h"
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
exit(error); \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
#define LIBRARY_INCLUDE_ROCSHMEM_HPP
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem_config.h"
|
||||
#include "rocshmem_common.hpp"
|
||||
@@ -36,6 +35,10 @@
|
||||
#include "rocshmem_COLL.hpp"
|
||||
#include "rocshmem_P2P_SYNC.hpp"
|
||||
#include "rocshmem_RMA_X.hpp"
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
#include <mpi.h>
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @file rocshmem.hpp
|
||||
* @brief Public header for rocSHMEM device and host libraries.
|
||||
@@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0";
|
||||
/******************************************************************************
|
||||
**************************** HOST INTERFACE **********************************
|
||||
*****************************************************************************/
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
/**
|
||||
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
|
||||
*
|
||||
* @param[in] comm (Optional) MPI Communicator that rocSHMEM will be using
|
||||
* @param[in] comm MPI Communicator that rocSHMEM will be using
|
||||
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
|
||||
*/
|
||||
__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD);
|
||||
[[deprecated]] __host__ void rocshmem_init(MPI_Comm comm);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
|
||||
* This is equivalent to the previous function, using implicitely
|
||||
* MPI_COMM_WORLD for initialization
|
||||
*/
|
||||
__host__ void rocshmem_init(void);
|
||||
|
||||
/**
|
||||
* @brief Query rocSHMEM context from host API
|
||||
@@ -88,6 +100,7 @@ __host__ void * rocshmem_get_device_ctx();
|
||||
*/
|
||||
__host__ void *rocshmem_ptr(void *dest, int pe);
|
||||
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
/**
|
||||
* @brief Initialize the rocSHMEM runtime and underlying transport layer
|
||||
* with an attempt to enable the requested thread support.
|
||||
@@ -102,8 +115,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
|
||||
* @return int returns 0 upon success; otherwise, it returns a nonzero
|
||||
* value
|
||||
*/
|
||||
__host__ int rocshmem_init_thread(int requested, int *provided,
|
||||
MPI_Comm comm = MPI_COMM_WORLD);
|
||||
[[deprecated]] __host__ int rocshmem_init_thread(int requested, int *provided,
|
||||
MPI_Comm comm);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Initialize the rocSHMEM runtime and underlying transport layer
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
|
||||
#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
|
||||
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
#include <mpi.h>
|
||||
#endif
|
||||
|
||||
#if defined(c_plusplus) || defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if !defined(MPI_VERSION)
|
||||
// Open MPI based values for the constants/handles etc.
|
||||
// Even though we did not include an external MPI header file
|
||||
// The includer may have (e.g., a unit test).
|
||||
|
||||
typedef void* MPI_Comm;
|
||||
typedef void* MPI_Win;
|
||||
typedef void* MPI_Group;
|
||||
typedef void* MPI_Op;
|
||||
typedef void* MPI_Datatype;
|
||||
typedef void* MPI_Request;
|
||||
typedef void* MPI_Info;
|
||||
|
||||
struct ompi_status_public_t {
|
||||
int MPI_SOURCE;
|
||||
int MPI_TAG;
|
||||
int MPI_ERROR;
|
||||
int _cancelled;
|
||||
size_t _ucount;
|
||||
};
|
||||
typedef struct ompi_status_public_t MPI_Status;
|
||||
|
||||
#define MPI_Aint uint64_t
|
||||
|
||||
#define MPI_UNDEFINED -32766
|
||||
#define MPI_THREAD_MULTIPLE 3
|
||||
#define MPI_SUCCESS 0
|
||||
#define MPI_IN_PLACE (void*)1
|
||||
#define MPI_MODE_NOCHECK 1
|
||||
#define MPI_COMM_TYPE_SHARED 0
|
||||
|
||||
#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2)))
|
||||
|
||||
struct ompi_internal_symbols_t {
|
||||
void *ompi_mpi_comm_world;
|
||||
void *ompi_mpi_comm_null;
|
||||
void *ompi_request_null;
|
||||
void *ompi_mpi_info_null;
|
||||
void *ompi_mpi_datatype_null;
|
||||
|
||||
void *ompi_mpi_op_max;
|
||||
void *ompi_mpi_op_min;
|
||||
void *ompi_mpi_op_sum;
|
||||
void *ompi_mpi_op_prod;
|
||||
void *ompi_mpi_op_band;
|
||||
void *ompi_mpi_op_bor;
|
||||
void *ompi_mpi_op_bxor;
|
||||
void *ompi_mpi_op_replace;
|
||||
void *ompi_mpi_op_no_op;
|
||||
|
||||
void *ompi_mpi_char;
|
||||
void *ompi_mpi_unsigned_char;
|
||||
void *ompi_mpi_signed_char;
|
||||
void *ompi_mpi_short;
|
||||
void *ompi_mpi_unsigned_short;
|
||||
void *ompi_mpi_int;
|
||||
void *ompi_mpi_unsigned;
|
||||
void *ompi_mpi_long;
|
||||
void *ompi_mpi_unsigned_long;
|
||||
void *ompi_mpi_long_long_int;
|
||||
void *ompi_mpi_unsigned_long_long;
|
||||
void *ompi_mpi_float;
|
||||
void *ompi_mpi_double;
|
||||
void *ompi_mpi_long_double;
|
||||
};
|
||||
|
||||
extern struct ompi_internal_symbols_t ompi_symbols_;
|
||||
|
||||
#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast<type> (global))
|
||||
#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world)
|
||||
#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null)
|
||||
#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null)
|
||||
#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null)
|
||||
#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null)
|
||||
|
||||
#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max)
|
||||
#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min)
|
||||
#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum)
|
||||
#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod)
|
||||
#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band)
|
||||
#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor)
|
||||
#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor)
|
||||
#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace)
|
||||
#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op)
|
||||
|
||||
#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null)
|
||||
#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char)
|
||||
#define MPI_UNSIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char)
|
||||
#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char)
|
||||
#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short)
|
||||
#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short)
|
||||
#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int)
|
||||
#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned)
|
||||
#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long)
|
||||
#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long)
|
||||
#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int)
|
||||
#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long)
|
||||
#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float)
|
||||
#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double)
|
||||
#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double)
|
||||
|
||||
#endif //!defined(MPI_VERSION)
|
||||
|
||||
#if defined(c_plusplus) || defined(__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
|
||||
@@ -80,7 +80,7 @@ void init_g_ret(SymmetricHeap* heap_handle, MPI_Comm thread_comm, int num_wg,
|
||||
* Make sure that all processing elements have done this before
|
||||
* continuing.
|
||||
*/
|
||||
MPI_Barrier(thread_comm);
|
||||
mpilib_ftable_.Barrier(thread_comm);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
#define LIBRARY_SRC_ATOMIC_RETURN_HPP_
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "memory/symmetric_heap.hpp"
|
||||
#include "util.hpp"
|
||||
|
||||
@@ -59,6 +59,7 @@ Backend::Backend(MPI_Comm comm) : heap(comm, nullptr) {
|
||||
Backend::Backend(TcpBootstrap* bootstrap) : heap(MPI_COMM_NULL, bootstrap) {
|
||||
init();
|
||||
backend_bootstr = bootstrap;
|
||||
backend_comm = MPI_COMM_NULL;
|
||||
|
||||
my_pe = bootstrap->getRank();
|
||||
num_pes = bootstrap->getNranks();
|
||||
@@ -106,9 +107,9 @@ void Backend::init(void) {
|
||||
|
||||
void Backend::init_mpi_once(MPI_Comm comm) {
|
||||
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
|
||||
NET_CHECK(MPI_Comm_dup(comm, &backend_comm));
|
||||
NET_CHECK(MPI_Comm_size(backend_comm, &num_pes));
|
||||
NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe));
|
||||
NET_CHECK(mpilib_ftable_.Comm_dup(comm, &backend_comm));
|
||||
NET_CHECK(mpilib_ftable_.Comm_size(backend_comm, &num_pes));
|
||||
NET_CHECK(mpilib_ftable_.Comm_rank(backend_comm, &my_pe));
|
||||
}
|
||||
|
||||
void Backend::track_ctx(Context* ctx) {
|
||||
@@ -140,7 +141,7 @@ void Backend::destroy_remaining_ctxs() {
|
||||
Backend::~Backend() {
|
||||
CHECK_HIP(hipFree(print_lock));
|
||||
if (backend_comm != MPI_COMM_NULL)
|
||||
NET_CHECK(MPI_Comm_free(&backend_comm));
|
||||
NET_CHECK(mpilib_ftable_.Comm_free(&backend_comm));
|
||||
}
|
||||
|
||||
void Backend::dump_stats() {
|
||||
|
||||
@@ -33,12 +33,11 @@
|
||||
* It is the top-level interface for these resources.
|
||||
*/
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
#include "backend_type.hpp"
|
||||
#include "ipc_policy.hpp"
|
||||
#include "memory/symmetric_heap.hpp"
|
||||
@@ -225,7 +224,7 @@ class Backend {
|
||||
* @todo document where this is used and try to coalesce this into another
|
||||
* class
|
||||
*/
|
||||
MPI_Comm backend_comm{MPI_COMM_NULL};
|
||||
MPI_Comm backend_comm;
|
||||
|
||||
/**
|
||||
* @todo document where this is used
|
||||
|
||||
@@ -24,15 +24,16 @@
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "backend_gda.hpp"
|
||||
#include "gda_team.hpp"
|
||||
#include "util.hpp"
|
||||
#include "topology.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include "backend_gda.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
#include "gda_team.hpp"
|
||||
#include "util.hpp"
|
||||
#include "topology.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
#define NET_CHECK(cmd) { \
|
||||
@@ -303,7 +304,7 @@ void GDABackend::create_new_team([[maybe_unused]] Team *parent_team,
|
||||
* the pool of available work arrays.
|
||||
*/
|
||||
if (team_comm != MPI_COMM_NULL) {
|
||||
NET_CHECK(MPI_Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_,
|
||||
NET_CHECK(mpilib_ftable_.Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_,
|
||||
MPI_CHAR, MPI_BAND, team_comm));
|
||||
} else {
|
||||
Allreduce_char_BAND (team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, parent_team);
|
||||
@@ -361,7 +362,7 @@ void GDABackend::dump_backend_stats() {
|
||||
|
||||
__host__ void GDABackend::global_exit(int status) {
|
||||
if (backend_comm != MPI_COMM_NULL)
|
||||
MPI_Abort(backend_comm, status);
|
||||
mpilib_ftable_.Abort(backend_comm, status);
|
||||
else
|
||||
abort();
|
||||
}
|
||||
@@ -536,7 +537,7 @@ void GDABackend::setup_teams() {
|
||||
|
||||
void GDABackend::rte_barrier() {
|
||||
if (backend_comm != MPI_COMM_NULL) {
|
||||
NET_CHECK(MPI_Barrier(backend_comm));
|
||||
NET_CHECK(mpilib_ftable_.Barrier(backend_comm));
|
||||
} else {
|
||||
backend_bootstr->barrier();
|
||||
}
|
||||
@@ -668,7 +669,7 @@ void GDABackend::exchange_qp_dest_info() {
|
||||
|
||||
for (int i = 0; i < maximum_num_contexts_ + 1; i++) {
|
||||
if (backend_comm != MPI_COMM_NULL) {
|
||||
MPI_Alltoall(MPI_IN_PLACE, sizeof(dest_info_t), MPI_CHAR, dest_info.data() + i * num_pes, sizeof(dest_info_t), MPI_CHAR, backend_comm);
|
||||
mpilib_ftable_.Alltoall(MPI_IN_PLACE, sizeof(dest_info_t), MPI_CHAR, dest_info.data() + i * num_pes, sizeof(dest_info_t), MPI_CHAR, backend_comm);
|
||||
} else {
|
||||
Alltoall_char_inplace(reinterpret_cast<char*>(dest_info.data() + i * num_pes), sizeof(dest_info_t), ROCSHMEM_TEAM_WORLD);
|
||||
}
|
||||
@@ -695,7 +696,7 @@ void GDABackend::setup_heap_memory_rkey() {
|
||||
CHECK_HIP(hipStreamSynchronize(stream));
|
||||
|
||||
if (backend_comm != MPI_COMM_NULL)
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(uint32_t), MPI_CHAR, host_rkey_cpy, sizeof(uint32_t), MPI_CHAR, backend_comm);
|
||||
mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(uint32_t), MPI_CHAR, host_rkey_cpy, sizeof(uint32_t), MPI_CHAR, backend_comm);
|
||||
else
|
||||
backend_bootstr->allGather(host_rkey_cpy, sizeof(uint32_t));
|
||||
|
||||
|
||||
@@ -66,24 +66,6 @@ struct mlx5dv_funcs_t {
|
||||
int (*init_obj)(struct mlx5dv_obj *obj, uint64_t obj_type);
|
||||
};
|
||||
|
||||
/* Helper Macros for handling dynamic libraries */
|
||||
#define PPCAT_NX(prefix, func_name) prefix##func_name
|
||||
#define PPCAT(prefix, func_name) PPCAT_NX(prefix, func_name)
|
||||
|
||||
#define STRINGIFY_NX(name) #name
|
||||
#define STRINGIFY(name) STRINGIFY_NX(name)
|
||||
|
||||
#define DLSYM_HELPER(func_struct, prefix, handle, func_name) \
|
||||
do { \
|
||||
*(void **) (&func_struct.func_name) = dlsym(handle, STRINGIFY(PPCAT(prefix, func_name))); \
|
||||
if (!func_struct.func_name) { \
|
||||
DPRINTF("Failed to find function %s \n", STRINGIFY(PPCAT(prefix, func_name))); \
|
||||
dlclose(handle); \
|
||||
handle = nullptr; \
|
||||
return ROCSHMEM_ERROR; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
class GDAContext;
|
||||
|
||||
@@ -22,11 +22,9 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "context_gda_host.hpp"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "context_gda_host.hpp"
|
||||
#include "backend_type.hpp"
|
||||
#include "context_incl.hpp"
|
||||
#include "backend_gda.hpp"
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "constants.hpp"
|
||||
#include "backend_type.hpp"
|
||||
#include "backend_gda.hpp"
|
||||
#include "rocshmem/rocshmem_mpi.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@
|
||||
|
||||
#include "host.hpp"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "host_helpers.hpp"
|
||||
#include "memory/window_info.hpp"
|
||||
@@ -98,9 +96,9 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy,
|
||||
* Duplicate a communicator from roc_shem's comm
|
||||
* world for the host interface
|
||||
*/
|
||||
MPI_Comm_dup(rocshmem_comm, &host_comm_world_);
|
||||
MPI_Comm_rank(host_comm_world_, &my_pe_);
|
||||
MPI_Comm_rank(host_comm_world_, &num_pes_);
|
||||
mpilib_ftable_.Comm_dup(rocshmem_comm, &host_comm_world_);
|
||||
mpilib_ftable_.Comm_rank(host_comm_world_, &my_pe_);
|
||||
mpilib_ftable_.Comm_size(host_comm_world_, &num_pes_);
|
||||
|
||||
/*
|
||||
* Create an MPI window on the HDP so that it can be flushed
|
||||
@@ -136,18 +134,18 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy,
|
||||
|
||||
#if defined USE_HDP_FLUSH
|
||||
__host__ void HostInterface::create_hdp_window() {
|
||||
MPI_Win_create(hdp_policy_->get_hdp_flush_ptr(),
|
||||
sizeof(unsigned int), /* size of window */
|
||||
sizeof(unsigned int), /* displacement */
|
||||
MPI_INFO_NULL, host_comm_world_, &hdp_win);
|
||||
|
||||
mpilib_ftable_.Win_create(hdp_policy_->get_hdp_flush_ptr(),
|
||||
sizeof(unsigned int), /* size of window */
|
||||
sizeof(unsigned int), /* displacement */
|
||||
MPI_INFO_NULL, host_comm_world_, &hdp_win);
|
||||
|
||||
/*
|
||||
* Start a shared access epoch on windows of all ranks,
|
||||
* and let the library there is no need to check for
|
||||
* lock exclusivity during operations on this window
|
||||
* (MPI_MODE_NOCHECK).
|
||||
*/
|
||||
MPI_Win_lock_all(MPI_MODE_NOCHECK, hdp_win);
|
||||
mpilib_ftable_.Win_lock_all(MPI_MODE_NOCHECK, hdp_win);
|
||||
}
|
||||
#endif // USE_HDP_FLUSH
|
||||
|
||||
@@ -188,9 +186,9 @@ __host__ HostInterface::HostInterface(HdpPolicy* hdp_policy,
|
||||
|
||||
__host__ HostInterface::~HostInterface() {
|
||||
#if defined USE_HDP_FLUSH
|
||||
MPI_Win_unlock_all(hdp_win);
|
||||
mpilib_ftable_.Win_unlock_all(hdp_win);
|
||||
|
||||
MPI_Win_free(&hdp_win);
|
||||
mpilib_ftable_.Win_free(&hdp_win);
|
||||
#endif // USE_HDP_FLUSH
|
||||
|
||||
/* Detroy the pool of contexts */
|
||||
@@ -203,7 +201,7 @@ __host__ HostInterface::~HostInterface() {
|
||||
}
|
||||
|
||||
if (host_comm_world_ != MPI_COMM_NULL) {
|
||||
MPI_Comm_free(&host_comm_world_);
|
||||
mpilib_ftable_.Comm_free(&host_comm_world_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +234,7 @@ __host__ void HostInterface::putmem(void* dest, const void* source,
|
||||
}
|
||||
initiate_put(dest, source, nelems, pe, window_info_mpi);
|
||||
|
||||
MPI_Win_flush_local(pe, window_info_mpi->get_win());
|
||||
mpilib_ftable_.Win_flush_local(pe, window_info_mpi->get_win());
|
||||
}
|
||||
|
||||
__host__ void HostInterface::getmem(void* dest, const void* source,
|
||||
@@ -248,7 +246,7 @@ __host__ void HostInterface::getmem(void* dest, const void* source,
|
||||
}
|
||||
initiate_get(dest, source, nelems, pe, window_info_mpi);
|
||||
|
||||
MPI_Win_flush_local(pe, window_info_mpi->get_win());
|
||||
mpilib_ftable_.Win_flush_local(pe, window_info_mpi->get_win());
|
||||
|
||||
/*
|
||||
* Flush local HDP to ensure that the NIC's write
|
||||
@@ -295,8 +293,8 @@ __host__ void HostInterface::quiet(WindowInfo* window_info) {
|
||||
|
||||
__host__ void HostInterface::sync_all(WindowInfo* window_info) {
|
||||
WindowInfoMPI* window_info_mpi = dynamic_cast<WindowInfoMPI*>(window_info);
|
||||
if (!window_info_mpi) {
|
||||
MPI_Win_sync(window_info_mpi->get_win());
|
||||
if (window_info_mpi) {
|
||||
mpilib_ftable_.Win_sync(window_info_mpi->get_win());
|
||||
|
||||
hdp_policy_->hdp_flush();
|
||||
/*
|
||||
@@ -305,7 +303,7 @@ __host__ void HostInterface::sync_all(WindowInfo* window_info) {
|
||||
* participating.
|
||||
*/
|
||||
|
||||
MPI_Barrier(host_comm_world_);
|
||||
mpilib_ftable_.Barrier(host_comm_world_);
|
||||
} else {
|
||||
hdp_policy_->hdp_flush();
|
||||
host_bootstrap_->barrier();
|
||||
@@ -325,7 +323,7 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) {
|
||||
*/
|
||||
hdp_policy_->hdp_flush();
|
||||
|
||||
MPI_Barrier(host_comm_world_);
|
||||
mpilib_ftable_.Barrier(host_comm_world_);
|
||||
} else {
|
||||
// Probably not required
|
||||
hdp_policy_->hdp_flush();
|
||||
@@ -337,7 +335,7 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) {
|
||||
|
||||
__host__ void HostInterface::barrier_for_sync() {
|
||||
if (host_comm_world_ != MPI_COMM_NULL) {
|
||||
MPI_Barrier(host_comm_world_);
|
||||
mpilib_ftable_.Barrier(host_comm_world_);
|
||||
} else {
|
||||
host_bootstrap_->barrier();
|
||||
}
|
||||
|
||||
@@ -34,8 +34,6 @@
|
||||
* any backend type.
|
||||
*/
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
@@ -43,6 +41,7 @@
|
||||
#include "memory/symmetric_heap.hpp"
|
||||
#include "memory/window_info.hpp"
|
||||
#include "bootstrap/bootstrap.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
@@ -268,17 +267,17 @@ class HostInterface {
|
||||
if (i == my_pe_) {
|
||||
continue;
|
||||
}
|
||||
MPI_Put(&flush_val, 1, MPI_UNSIGNED, i, 0, 1, MPI_UNSIGNED, hdp_win);
|
||||
mpilib_ftable_.Put(&flush_val, 1, MPI_UNSIGNED, i, 0, 1, MPI_UNSIGNED, hdp_win);
|
||||
}
|
||||
MPI_Win_flush_all(hdp_win);
|
||||
mpilib_ftable_.Win_flush_all(hdp_win);
|
||||
#endif // USE_HDP_FLUSH
|
||||
}
|
||||
|
||||
__host__ void flush_remote_hdp(int pe) {
|
||||
#if defined USE_HDP_FLUSH
|
||||
unsigned flush_val{HdpPolicy::HDP_FLUSH_VAL};
|
||||
MPI_Put(&flush_val, 1, MPI_UNSIGNED, pe, 0, 1, MPI_UNSIGNED, hdp_win);
|
||||
MPI_Win_flush(pe, hdp_win);
|
||||
mpilib_ftable_.Put(&flush_val, 1, MPI_UNSIGNED, pe, 0, 1, MPI_UNSIGNED, hdp_win);
|
||||
mpilib_ftable_.Win_flush(pe, hdp_win);
|
||||
#endif // USE_HDP_FLUSH
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
|
||||
#include "host.hpp"
|
||||
#include "memory/window_info.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@@ -42,15 +43,15 @@ __host__ inline MPI_Aint HostInterface::compute_offset(
|
||||
MPI_Aint dest_disp{};
|
||||
MPI_Aint start_disp{};
|
||||
|
||||
MPI_Get_address(dest, &dest_disp);
|
||||
MPI_Get_address(win_start, &start_disp);
|
||||
mpilib_ftable_.Get_address(dest, &dest_disp);
|
||||
mpilib_ftable_.Get_address(win_start, &start_disp);
|
||||
|
||||
return MPI_Aint_diff(dest_disp, start_disp);
|
||||
}
|
||||
|
||||
__host__ inline void HostInterface::complete_all(MPI_Win win) {
|
||||
MPI_Win_flush_all(win); /* RMA operations */
|
||||
MPI_Win_sync(win); /* memory stores */
|
||||
mpilib_ftable_.Win_flush_all(win); /* RMA operations */
|
||||
mpilib_ftable_.Win_sync(win); /* memory stores */
|
||||
}
|
||||
|
||||
__host__ inline void HostInterface::initiate_put(void* dest, const void* source,
|
||||
@@ -74,7 +75,7 @@ __host__ inline void HostInterface::initiate_put(void* dest, const void* source,
|
||||
hdp_policy_->hdp_flush();
|
||||
|
||||
/* Offload remote write operation to MPI */
|
||||
MPI_Put(source, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win);
|
||||
mpilib_ftable_.Put(source, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win);
|
||||
}
|
||||
|
||||
__host__ inline void HostInterface::initiate_get(void* dest, const void* source,
|
||||
@@ -88,7 +89,7 @@ __host__ inline void HostInterface::initiate_get(void* dest, const void* source,
|
||||
MPI_Aint offset = compute_offset(source, win_start, win_end);
|
||||
|
||||
/* Offload remote fetch operation to MPI */
|
||||
MPI_Get(dest, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win);
|
||||
mpilib_ftable_.Get(dest, nelems, MPI_CHAR, pe, offset, nelems, MPI_CHAR, win);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -74,7 +74,7 @@ __host__ T HostInterface::g(const T* source, int pe, WindowInfo* window_info) {
|
||||
*/
|
||||
getmem_nbi(&ret, source, sizeof(T), pe, window_info);
|
||||
|
||||
MPI_Win_flush_local(pe, window_info_mpi->get_win());
|
||||
mpilib_ftable_.Win_flush_local(pe, window_info_mpi->get_win());
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -101,7 +101,7 @@ __host__ MPI_Comm HostInterface::get_mpi_comm(int pe_start, int log_pe_stride,
|
||||
* First, check to see if the active set is the same as COMM_WORLD
|
||||
*/
|
||||
int comm_world_size{-1};
|
||||
MPI_Comm_size(host_comm_world_, &comm_world_size);
|
||||
mpilib_ftable_.Comm_size(host_comm_world_, &comm_world_size);
|
||||
|
||||
if (pe_start == 0 && log_pe_stride == 0 && pe_size == comm_world_size) {
|
||||
/*
|
||||
@@ -139,12 +139,12 @@ __host__ MPI_Comm HostInterface::get_mpi_comm(int pe_start, int log_pe_stride,
|
||||
MPI_Group comm_world_group{};
|
||||
MPI_Group active_set_group{};
|
||||
|
||||
MPI_Comm_group(host_comm_world_, &comm_world_group);
|
||||
mpilib_ftable_.Comm_group(host_comm_world_, &comm_world_group);
|
||||
|
||||
MPI_Group_incl(comm_world_group, pe_size, active_set_ranks.data(),
|
||||
&active_set_group);
|
||||
mpilib_ftable_.Group_incl(comm_world_group, pe_size, active_set_ranks.data(),
|
||||
&active_set_group);
|
||||
|
||||
MPI_Comm_create_group(host_comm_world_, active_set_group, 0,
|
||||
mpilib_ftable_.Comm_create_group(host_comm_world_, active_set_group, 0,
|
||||
&active_set_comm);
|
||||
|
||||
/*
|
||||
@@ -168,7 +168,7 @@ __host__ void HostInterface::broadcast_internal(MPI_Comm mpi_comm, T* dest,
|
||||
*/
|
||||
int active_set_rank{-1};
|
||||
void* buffer{nullptr};
|
||||
MPI_Comm_rank(mpi_comm, &active_set_rank);
|
||||
mpilib_ftable_.Comm_rank(mpi_comm, &active_set_rank);
|
||||
if (pe_root == active_set_rank) {
|
||||
buffer = const_cast<T*>(source);
|
||||
} else {
|
||||
@@ -183,7 +183,7 @@ __host__ void HostInterface::broadcast_internal(MPI_Comm mpi_comm, T* dest,
|
||||
/*
|
||||
* Offload the broadcast to MPI
|
||||
*/
|
||||
MPI_Bcast(buffer, nelems * sizeof(T), MPI_CHAR, pe_root, mpi_comm);
|
||||
mpilib_ftable_.Bcast(buffer, nelems * sizeof(T), MPI_CHAR, pe_root, mpi_comm);
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -312,9 +312,9 @@ __host__ T HostInterface::amo_fetch_add(void* dst, T value, int pe,
|
||||
T ret{};
|
||||
MPI_Win win{window_info_mpi->get_win()};
|
||||
MPI_Datatype mpi_type{get_mpi_type<T>()};
|
||||
MPI_Fetch_and_op(&value, &ret, mpi_type, pe, offset, MPI_SUM, win);
|
||||
mpilib_ftable_.Fetch_and_op(&value, &ret, mpi_type, pe, offset, MPI_SUM, win);
|
||||
|
||||
MPI_Win_flush_local(pe, win);
|
||||
mpilib_ftable_.Win_flush_local(pe, win);
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -341,9 +341,9 @@ __host__ T HostInterface::amo_fetch_cas(void* dst, T value, T cond, int pe,
|
||||
T ret{};
|
||||
MPI_Win win{window_info_mpi->get_win()};
|
||||
MPI_Datatype mpi_type{get_mpi_type<T>()};
|
||||
MPI_Compare_and_swap(&value, &cond, &ret, mpi_type, pe, offset, win);
|
||||
mpilib_ftable_.Compare_and_swap(&value, &cond, &ret, mpi_type, pe, offset, win);
|
||||
|
||||
MPI_Win_flush_local(pe, win);
|
||||
mpilib_ftable_.Win_flush_local(pe, win);
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -368,8 +368,8 @@ __host__ void HostInterface::to_all_internal(MPI_Comm mpi_comm, T* dest,
|
||||
/*
|
||||
* Offload the allreduce to MPI
|
||||
*/
|
||||
MPI_Allreduce((dest == source) ? MPI_IN_PLACE : send_buf, recv_buf, nreduce,
|
||||
mpi_type, mpi_op, mpi_comm);
|
||||
mpilib_ftable_.Allreduce((dest == source) ? MPI_IN_PLACE : send_buf, recv_buf, nreduce,
|
||||
mpi_type, mpi_op, mpi_comm);
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -453,9 +453,9 @@ __host__ inline int HostInterface::test_and_compare(MPI_Aint offset,
|
||||
*/
|
||||
hdp_policy_->hdp_flush();
|
||||
|
||||
MPI_Fetch_and_op(nullptr, // because no operation happening here
|
||||
&fetched_val, mpi_type, my_pe_, offset, MPI_NO_OP, win);
|
||||
MPI_Win_flush_local(my_pe_, win);
|
||||
mpilib_ftable_.Fetch_and_op(nullptr, // because no operation happening here
|
||||
&fetched_val, mpi_type, my_pe_, offset, MPI_NO_OP, win);
|
||||
mpilib_ftable_.Win_flush_local(my_pe_, win);
|
||||
|
||||
/*
|
||||
* Compare based on the operation
|
||||
|
||||
@@ -24,13 +24,14 @@
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "backend_ipc.hpp"
|
||||
#include "ipc_team.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include "backend_ipc.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
#include "ipc_team.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
#define NET_CHECK(cmd) \
|
||||
@@ -249,8 +250,8 @@ void IPCBackend::create_new_team([[maybe_unused]] Team *parent_team,
|
||||
* the pool of available work arrays.
|
||||
*/
|
||||
if (team_comm != MPI_COMM_NULL) {
|
||||
NET_CHECK(MPI_Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_,
|
||||
MPI_CHAR, MPI_BAND, team_comm));
|
||||
NET_CHECK(mpilib_ftable_.Allreduce(team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_,
|
||||
MPI_CHAR, MPI_BAND, team_comm));
|
||||
} else {
|
||||
Allreduce_char_BAND (team_pool_bitmask_, team_reduced_bitmask_, team_bitmask_size_, parent_team);
|
||||
}
|
||||
@@ -321,7 +322,7 @@ void IPCBackend::initIPC(TcpBootstrap *bootstr) {
|
||||
|
||||
void IPCBackend::global_exit(int status) {
|
||||
if (backend_comm != MPI_COMM_NULL)
|
||||
MPI_Abort(backend_comm, status);
|
||||
mpilib_ftable_.Abort(backend_comm, status);
|
||||
else
|
||||
abort();
|
||||
}
|
||||
@@ -388,8 +389,8 @@ void IPCBackend::setup_wrk_sync_buffers() {
|
||||
* all-to-all exchange with each PE to share the IPC handles.
|
||||
*/
|
||||
if (backend_comm != MPI_COMM_NULL) {
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
|
||||
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm);
|
||||
mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
|
||||
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm);
|
||||
} else {
|
||||
assert (backend_bootstr != nullptr);
|
||||
backend_bootstr->allGather(ipc_handle, sizeof(hipIpcMemHandle_t));
|
||||
@@ -460,7 +461,7 @@ void IPCBackend::rocshmem_collective_init() {
|
||||
* continuing.
|
||||
*/
|
||||
if (backend_comm != MPI_COMM_NULL) {
|
||||
NET_CHECK(MPI_Barrier(backend_comm));
|
||||
NET_CHECK(mpilib_ftable_.Barrier(backend_comm));
|
||||
} else {
|
||||
backend_bootstr->barrier();
|
||||
}
|
||||
@@ -551,7 +552,7 @@ void IPCBackend::teams_init() {
|
||||
* continuing.
|
||||
*/
|
||||
if (backend_comm != MPI_COMM_NULL) {
|
||||
NET_CHECK(MPI_Barrier(backend_comm));
|
||||
NET_CHECK(mpilib_ftable_.Barrier(backend_comm));
|
||||
} else {
|
||||
backend_bootstr->barrier();
|
||||
}
|
||||
|
||||
@@ -24,8 +24,6 @@
|
||||
|
||||
#include "context_ipc_host.hpp"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "backend_type.hpp"
|
||||
#include "context_incl.hpp"
|
||||
|
||||
@@ -24,8 +24,6 @@
|
||||
|
||||
#include "ipc_policy.hpp"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "backend_bc.hpp"
|
||||
#include "context_incl.hpp"
|
||||
@@ -39,20 +37,20 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases,
|
||||
* Create an MPI communicator that deals only with local processes.
|
||||
*/
|
||||
MPI_Comm shmcomm;
|
||||
MPI_Comm_split_type(thread_comm, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
|
||||
&shmcomm);
|
||||
mpilib_ftable_.Comm_split_type(thread_comm, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
|
||||
&shmcomm);
|
||||
|
||||
/*
|
||||
* Figure out how many local process there are.
|
||||
*/
|
||||
int Shm_size;
|
||||
MPI_Comm_size(shmcomm, &Shm_size);
|
||||
mpilib_ftable_.Comm_size(shmcomm, &Shm_size);
|
||||
shm_size = Shm_size;
|
||||
|
||||
/*
|
||||
* Figure out how this process' rank among local processes.
|
||||
*/
|
||||
MPI_Comm_rank(shmcomm, &shm_rank);
|
||||
mpilib_ftable_.Comm_rank(shmcomm, &shm_rank);
|
||||
|
||||
/*
|
||||
* Allocate a host-side c-array to hold the IPC handles.
|
||||
@@ -73,8 +71,8 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases,
|
||||
* Do an all-to-all exchange with each local processing element to
|
||||
* share the symmetric heap IPC handles.
|
||||
*/
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
|
||||
vec_ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, shmcomm);
|
||||
mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
|
||||
vec_ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, shmcomm);
|
||||
|
||||
/*
|
||||
* Allocate device-side array to hold the IPC symmetric heap base
|
||||
@@ -114,16 +112,17 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases,
|
||||
|
||||
CHECK_HIP(hipMalloc(reinterpret_cast<void**>(&pes_with_ipc_avail), shm_size * sizeof(int)));
|
||||
|
||||
MPI_Group thread_grp, shm_grp;
|
||||
MPI_Comm_group(thread_comm, &thread_grp);
|
||||
MPI_Comm_group(shmcomm, &shm_grp);
|
||||
MPI_Group thread_grp;
|
||||
MPI_Group shm_grp;
|
||||
mpilib_ftable_.Comm_group(thread_comm, &thread_grp);
|
||||
mpilib_ftable_.Comm_group(shmcomm, &shm_grp);
|
||||
int *seqranks = new int[shm_size];
|
||||
for(int i = 0; i < shm_size; i++)
|
||||
seqranks[i] = i;
|
||||
MPI_Group_translate_ranks(shm_grp, shm_size, seqranks, thread_grp, pes_with_ipc_avail);
|
||||
mpilib_ftable_.Group_translate_ranks(shm_grp, shm_size, seqranks, thread_grp, pes_with_ipc_avail);
|
||||
delete [] seqranks;
|
||||
MPI_Group_free(&shm_grp);
|
||||
MPI_Group_free(&thread_grp);
|
||||
mpilib_ftable_.Group_free(&shm_grp);
|
||||
mpilib_ftable_.Group_free(&thread_grp);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,12 +26,12 @@
|
||||
#define LIBRARY_SRC_IPC_POLICY_HPP_
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "mpi_instance.hpp"
|
||||
#include "memory/hip_allocator.hpp"
|
||||
#include "util.hpp"
|
||||
#include "bootstrap/bootstrap.hpp"
|
||||
|
||||
@@ -26,10 +26,10 @@
|
||||
#define LIBRARY_SRC_MEMORY_REMOTE_HEAP_INFO_HPP_
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mpi_instance.hpp"
|
||||
#include "hip_allocator.hpp"
|
||||
#include "window_info.hpp"
|
||||
#include "bootstrap/bootstrap.hpp"
|
||||
@@ -53,10 +53,10 @@ class CommunicatorMPI {
|
||||
* @brief Primary constructor
|
||||
*/
|
||||
CommunicatorMPI(char* heap_base, size_t heap_size,
|
||||
MPI_Comm comm = MPI_COMM_WORLD)
|
||||
MPI_Comm comm)
|
||||
: comm_{comm} {
|
||||
MPI_Comm_rank(comm_, &my_pe_);
|
||||
MPI_Comm_size(comm_, &num_pes_);
|
||||
mpilib_ftable_.Comm_rank(comm_, &my_pe_);
|
||||
mpilib_ftable_.Comm_size(comm_, &num_pes_);
|
||||
heap_window_info_ = WindowInfoMPI(comm_, heap_base, heap_size);
|
||||
}
|
||||
|
||||
@@ -78,14 +78,14 @@ class CommunicatorMPI {
|
||||
/**
|
||||
* @brief Performs MPI_Barrier
|
||||
*/
|
||||
void barrier() { MPI_Barrier(comm_); }
|
||||
void barrier() { mpilib_ftable_.Barrier(comm_); }
|
||||
|
||||
/**
|
||||
* @brief Performs MPI_Allgather on recvbuf
|
||||
*/
|
||||
void allgather(void* recvbuf) {
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf,
|
||||
sizeof(void*), MPI_CHAR, comm_);
|
||||
mpilib_ftable_.Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf,
|
||||
sizeof(void*), MPI_CHAR, comm_);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -206,7 +206,7 @@ class RemoteHeapInfo {
|
||||
* @param[in] The total number of processing elements
|
||||
*/
|
||||
RemoteHeapInfo(char* heap_ptr, size_t heap_size,
|
||||
MPI_Comm comm = MPI_COMM_WORLD)
|
||||
MPI_Comm comm)
|
||||
: communicator_{heap_ptr, heap_size, comm} {
|
||||
init(heap_ptr, heap_size);
|
||||
}
|
||||
|
||||
@@ -83,14 +83,14 @@ private:
|
||||
class SymmetricHeap {
|
||||
|
||||
public:
|
||||
SymmetricHeap(MPI_Comm comm = MPI_COMM_NULL, TcpBootstrap* bootstrap = nullptr) {
|
||||
SymmetricHeap(MPI_Comm comm, TcpBootstrap* bootstrap = nullptr) {
|
||||
|
||||
if (comm != MPI_COMM_NULL) {
|
||||
remote_heap_info_ = new RemoteHeapInfoMPI(single_heap_.get_base_ptr(),
|
||||
single_heap_.get_size(), comm);
|
||||
single_heap_.get_size(), comm);
|
||||
} else {
|
||||
remote_heap_info_ = new RemoteHeapInfoTCP(single_heap_.get_base_ptr(),
|
||||
single_heap_.get_size(), bootstrap);
|
||||
single_heap_.get_size(), bootstrap);
|
||||
}
|
||||
}
|
||||
/**
|
||||
|
||||
@@ -25,10 +25,9 @@
|
||||
#ifndef LIBRARY_SRC_MEMORY_WINDOW_INFO_HPP_
|
||||
#define LIBRARY_SRC_MEMORY_WINDOW_INFO_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
/**
|
||||
* @file window_info.hpp
|
||||
@@ -155,8 +154,8 @@ class WindowInfoMPI: public WindowInfo {
|
||||
win_end_ = reinterpret_cast<char*>(start) + size;
|
||||
|
||||
up_win_ = std::unique_ptr<MPI_Win>(new MPI_Win);
|
||||
MPI_Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get());
|
||||
MPI_Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get());
|
||||
mpilib_ftable_.Win_create(win_start_, size, 1, MPI_INFO_NULL, comm_, up_win_.get());
|
||||
mpilib_ftable_.Win_lock_all(MPI_MODE_NOCHECK, *up_win_.get());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -164,8 +163,8 @@ class WindowInfoMPI: public WindowInfo {
|
||||
*/
|
||||
~WindowInfoMPI() {
|
||||
if (up_win_) {
|
||||
MPI_Win_unlock_all(*up_win_.get());
|
||||
MPI_Win_free(up_win_.get());
|
||||
mpilib_ftable_.Win_unlock_all(*up_win_.get());
|
||||
mpilib_ftable_.Win_free(up_win_.get());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,9 +221,9 @@ class WindowInfoMPI: public WindowInfo {
|
||||
reinterpret_cast<char*>(win_end_));
|
||||
|
||||
MPI_Aint dest_disp;
|
||||
MPI_Get_address(dest, &dest_disp);
|
||||
mpilib_ftable_.Get_address(dest, &dest_disp);
|
||||
MPI_Aint start_disp;
|
||||
MPI_Get_address(win_start_, &start_disp);
|
||||
mpilib_ftable_.Get_address(win_start_, &start_disp);
|
||||
|
||||
return static_cast<ptrdiff_t>(MPI_Aint_diff(dest_disp, start_disp));
|
||||
}
|
||||
|
||||
@@ -22,33 +22,139 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
#include "util.hpp"
|
||||
|
||||
#if !defined(HAVE_EXTERNAL_MPI)
|
||||
// Open MPI specific symbols
|
||||
struct ompi_internal_symbols_t ompi_symbols_;
|
||||
#endif //!defined(HAVE_EXTERNAL_MPI)
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
static void* mpilib_handle_{nullptr};
|
||||
struct mpilib_funcs_t mpilib_ftable_;
|
||||
|
||||
int MPIInstance::mpilib_dl_init() {
|
||||
if (mpilib_handle_ != nullptr)
|
||||
return ROCSHMEM_SUCCESS;
|
||||
|
||||
mpilib_handle_ = dlopen("libmpi.so", RTLD_NOW);
|
||||
if (!mpilib_handle_) {
|
||||
printf("Could not open libmpi.so. Returning\n");
|
||||
return ROCSHMEM_ERROR;
|
||||
}
|
||||
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Init_thread);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Initialized);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Finalize);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Finalized);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_rank);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_size);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Abort);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Get_address);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Type_size);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Iprobe);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Testsome);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_split);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_split_type);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_group);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_create_group);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_dup);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Comm_free);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Group_free);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Group_translate_ranks);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Group_incl);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Allgather);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Alltoall);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Allreduce);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Bcast);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Barrier);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Iallreduce);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Ibarrier);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_create);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_free);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush_all);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_flush_local);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock_all);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_unlock);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_unlock_all);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_lock);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Win_sync);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Get);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Rget);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Put);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Rput);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Compare_and_swap);
|
||||
DLSYM_HELPER(mpilib_ftable_, MPI_, mpilib_handle_, Fetch_and_op);
|
||||
|
||||
#if !defined(HAVE_EXTERNAL_MPI)
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_comm_world);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_comm_null);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_datatype_null);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_request_null);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_info_null);
|
||||
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_max);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_min);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_sum);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_prod);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_band);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_bor);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_bxor);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_replace);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_op_no_op);
|
||||
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_char);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_char);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_signed_char);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_short);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_short);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_int);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_long);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long_long_int);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_unsigned_long_long);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_float);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_double);
|
||||
DLSYM_VAR_HELPER(ompi_symbols_, mpilib_handle_, ompi_mpi_long_double);
|
||||
#endif //!defined(HAVE_EXTERNAL_MPI)
|
||||
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
void MPIInstance::mpilib_dl_close() {
|
||||
if (mpilib_handle_ != nullptr) {
|
||||
dlclose(mpilib_handle_);
|
||||
mpilib_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MPIInstance::MPIInstance(MPI_Comm comm) {
|
||||
int is_init{0};
|
||||
MPI_Initialized(&is_init);
|
||||
mpilib_ftable_.Initialized(&is_init);
|
||||
|
||||
if (!is_init) {
|
||||
int provided;
|
||||
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||
mpilib_ftable_.Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||
init_in_this_class = 1;
|
||||
}
|
||||
|
||||
if (comm == MPI_COMM_NULL) {
|
||||
comm = MPI_COMM_WORLD;
|
||||
}
|
||||
|
||||
MPI_Comm_size(comm, &nprocs_);
|
||||
MPI_Comm_rank(comm, &my_rank_);
|
||||
mpilib_ftable_.Comm_size(comm, &nprocs_);
|
||||
mpilib_ftable_.Comm_rank(comm, &my_rank_);
|
||||
}
|
||||
|
||||
MPIInstance::~MPIInstance() {
|
||||
int finalized{0};
|
||||
MPI_Finalized(&finalized);
|
||||
mpilib_ftable_.Finalized(&finalized);
|
||||
if (!finalized && init_in_this_class) {
|
||||
MPI_Finalize();
|
||||
mpilib_ftable_.Finalize();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@
|
||||
#ifndef LIBRARY_SRC_MPI_INSTANCE_HPP_
|
||||
#define LIBRARY_SRC_MPI_INSTANCE_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <rocshmem/rocshmem_config.h>
|
||||
#include <rocshmem/rocshmem_mpi.hpp>
|
||||
#include <memory>
|
||||
|
||||
/**
|
||||
@@ -37,6 +37,63 @@
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
struct mpilib_funcs_t {
|
||||
int (*Init_thread)(int *argc, char ***argv, int required, int *provided);
|
||||
int (*Initialized)(int *flag);
|
||||
int (*Finalize)(void);
|
||||
int (*Finalized)(int *flag);
|
||||
int (*Comm_rank)(MPI_Comm comm, int *rank);
|
||||
int (*Comm_size)(MPI_Comm comm, int *size);
|
||||
int (*Abort)(MPI_Comm comm, int errorcode);
|
||||
int (*Get_address)(const void *location, MPI_Aint *address);
|
||||
int (*Type_size)(MPI_Datatype type, int *size);
|
||||
int (*Iprobe)(int source, int tag, MPI_Comm comm, int *flag, MPI_Status *status);
|
||||
int (*Testsome)(int incount, MPI_Request array_of_requests[], int *outcount, int array_of_indices[],
|
||||
MPI_Status array_of_statuses[]);
|
||||
int (*Comm_split)(MPI_Comm comm, int color, int key, MPI_Comm *newcomm);
|
||||
int (*Comm_split_type)(MPI_Comm comm, int split_type, int key, MPI_Info info, MPI_Comm *newcomm);
|
||||
int (*Comm_group)(MPI_Comm comm, MPI_Group *group);
|
||||
int (*Comm_create_group)(MPI_Comm comm, MPI_Group group, int tag, MPI_Comm *newcomm);
|
||||
int (*Comm_dup)(MPI_Comm comm, MPI_Comm *newcomm);
|
||||
int (*Comm_free)(MPI_Comm *comm);
|
||||
int (*Group_free)(MPI_Group *group);
|
||||
int (*Group_translate_ranks)(MPI_Group group1, int n, const int ranks1[], MPI_Group group2, int ranks2[]);
|
||||
int (*Group_incl)(MPI_Group group, int n, const int ranks[], MPI_Group *newgroup);
|
||||
int (*Allgather)(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm);
|
||||
int (*Allreduce)(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
|
||||
MPI_Op op, MPI_Comm comm);
|
||||
int (*Alltoall)(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount,
|
||||
MPI_Datatype recvtype, MPI_Comm comm);
|
||||
int (*Bcast)(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
|
||||
int (*Barrier)(MPI_Comm comm);
|
||||
int (*Iallreduce)(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
|
||||
MPI_Op op, MPI_Comm comm, MPI_Request *request);
|
||||
int (*Ibarrier)(MPI_Comm comm, MPI_Request *request);
|
||||
int (*Win_create)(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm, MPI_Win *win);
|
||||
int (*Win_free)(MPI_Win *win);
|
||||
int (*Win_flush)(MPI_Win win);
|
||||
int (*Win_flush_all)(MPI_Win win);
|
||||
int (*Win_flush_local)(int rank, MPI_Win win);
|
||||
int (*Win_lock)(int lock_type, int rank, int mpi_assert, MPI_Win win);
|
||||
int (*Win_lock_all)(int mpi_assert, MPI_Win win);
|
||||
int (*Win_sync)(MPI_Win win);
|
||||
int (*Win_unlock)(int rank, MPI_Win win);
|
||||
int (*Win_unlock_all)(MPI_Win win);
|
||||
int (*Get)(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
|
||||
MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win);
|
||||
int (*Rget)(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp,
|
||||
int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request *request);
|
||||
int (*Put)(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp,
|
||||
int target_count, MPI_Datatype target_datatype, MPI_Win win);
|
||||
int (*Rput)(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp,
|
||||
int target_cout, MPI_Datatype target_datatype, MPI_Win win, MPI_Request *request);
|
||||
int (*Compare_and_swap)(const void *origin_addr, const void *compare_addr, void *result_addr, MPI_Datatype datatype, int target_rank,
|
||||
MPI_Aint target_disp, MPI_Win win);
|
||||
int (*Fetch_and_op)(const void *origin_addr, void *result_addr, MPI_Datatype datatype,
|
||||
int target_rank, MPI_Aint target_disp, MPI_Op op, MPI_Win win);
|
||||
};
|
||||
extern struct mpilib_funcs_t mpilib_ftable_;
|
||||
|
||||
class MPIInstance {
|
||||
public:
|
||||
/**
|
||||
@@ -63,6 +120,19 @@ class MPIInstance {
|
||||
*/
|
||||
int get_nprocs();
|
||||
|
||||
/**
|
||||
* @brief dlopen the MPI library and set
|
||||
* function pointers.
|
||||
* @return ROCSHMEM_SUCCESS on success,
|
||||
* ROCSHMEM_ERROR otherwise.
|
||||
*/
|
||||
static int mpilib_dl_init(void);
|
||||
|
||||
/**
|
||||
* @brief dlclose the MPI library
|
||||
*/
|
||||
static void mpilib_dl_close(void);
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief My MPI rank identifier
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
|
||||
#include "context_ro_host.hpp"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir)
|
||||
#include "backend_type.hpp"
|
||||
|
||||
@@ -50,14 +50,14 @@ MPITransport::MPITransport(MPI_Comm comm, Queue* q)
|
||||
|
||||
assert(comm != MPI_COMM_NULL);
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world));
|
||||
NET_CHECK(MPI_Comm_size(ro_net_comm_world, &num_pes));
|
||||
NET_CHECK(MPI_Comm_rank(ro_net_comm_world, &my_pe));
|
||||
NET_CHECK(mpilib_ftable_.Comm_dup(comm, &ro_net_comm_world));
|
||||
NET_CHECK(mpilib_ftable_.Comm_size(ro_net_comm_world, &num_pes));
|
||||
NET_CHECK(mpilib_ftable_.Comm_rank(ro_net_comm_world, &my_pe));
|
||||
}
|
||||
|
||||
MPITransport::~MPITransport() {
|
||||
if (ro_net_comm_world != MPI_COMM_NULL)
|
||||
NET_CHECK(MPI_Comm_free(&ro_net_comm_world));
|
||||
NET_CHECK(mpilib_ftable_.Comm_free(&ro_net_comm_world));
|
||||
}
|
||||
|
||||
void MPITransport::threadProgressEngine() {
|
||||
@@ -267,13 +267,13 @@ void MPITransport::createNewTeam(ROBackend *backend, Team *parent_team,
|
||||
}
|
||||
|
||||
void MPITransport::global_exit(int status) {
|
||||
MPI_Abort(ro_net_comm_world, status);
|
||||
mpilib_ftable_.Abort(ro_net_comm_world, status);
|
||||
}
|
||||
|
||||
void MPITransport::barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team, bool do_quiet) {
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Ibarrier(team, &request));
|
||||
NET_CHECK(mpilib_ftable_.Ibarrier(team, &request));
|
||||
|
||||
if (do_quiet) {
|
||||
requests.push_back({request, {nullptr, contextId, false}});
|
||||
@@ -351,10 +351,10 @@ void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
|
||||
MPI_Comm comm{team};
|
||||
|
||||
if (dst == src) {
|
||||
NET_CHECK(MPI_Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm,
|
||||
NET_CHECK(mpilib_ftable_.Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm,
|
||||
&request));
|
||||
} else {
|
||||
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
|
||||
NET_CHECK(mpilib_ftable_.Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
|
||||
}
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
@@ -370,25 +370,25 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
|
||||
MPI_Comm comm{team};
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank));
|
||||
NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
MPI_Request req;
|
||||
|
||||
if (rank != root){
|
||||
NET_CHECK(MPI_Rget(reinterpret_cast<char *>(dst), size, mpi_type, world_ranks[root],
|
||||
NET_CHECK(mpilib_ftable_.Rget(reinterpret_cast<char *>(dst), size, mpi_type, world_ranks[root],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(src)),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &req));
|
||||
|
||||
@@ -396,7 +396,7 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
barrier(contextId, nullptr, false, comm, false);
|
||||
quiet(contextId, status);
|
||||
}
|
||||
@@ -409,22 +409,22 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
|
||||
MPI_Comm comm{team};
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank));
|
||||
NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
int type_size{};
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
NET_CHECK(mpilib_ftable_.Type_size(mpi_type, &type_size));
|
||||
|
||||
if (dst == src) {
|
||||
fprintf(stderr, "IN_PLACE option not support for alltoall in the RO rocSHMEM conduit\n");
|
||||
@@ -436,7 +436,7 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
int target = (rank + i) % pe_size;
|
||||
int src_offset = target * type_size * size;
|
||||
int dst_offset = rank * type_size * size;
|
||||
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src) + src_offset, size,
|
||||
NET_CHECK(mpilib_ftable_.Rput(reinterpret_cast<char *>(src) + src_offset, size,
|
||||
mpi_type, world_ranks[target],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + dst_offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
|
||||
@@ -445,7 +445,7 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
@@ -457,23 +457,23 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
|
||||
MPI_Comm comm{team};
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
NET_CHECK(mpilib_ftable_.Comm_rank(comm, &rank));
|
||||
NET_CHECK(mpilib_ftable_.Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(comm, &grp));
|
||||
NET_CHECK(mpilib_ftable_.Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
NET_CHECK(mpilib_ftable_.Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
int type_size{};
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
NET_CHECK(mpilib_ftable_.Type_size(mpi_type, &type_size));
|
||||
|
||||
if (dst == src) {
|
||||
fprintf(stderr, "IN_PLACE option not support for fcollect in the RO rocSHMEM conduit\n");
|
||||
@@ -484,7 +484,7 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
for (int i = 0; i < pe_size; ++i) {
|
||||
int target = (rank + i) % pe_size;
|
||||
int offset = rank * type_size * size;
|
||||
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src), size, mpi_type, world_ranks[target],
|
||||
NET_CHECK(mpilib_ftable_.Rput(reinterpret_cast<char *>(src), size, mpi_type, world_ranks[target],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i]));
|
||||
|
||||
@@ -492,7 +492,7 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
@@ -504,14 +504,14 @@ void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
auto *bp{backend_proxy->get()};
|
||||
MPI_Request request{};
|
||||
|
||||
NET_CHECK(MPI_Rput(
|
||||
NET_CHECK(mpilib_ftable_.Rput(
|
||||
src, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(dst),
|
||||
size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request));
|
||||
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet. Put it here as a hack for now even
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
@@ -525,7 +525,7 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Fetch_and_op(reinterpret_cast<void *>(val), src, mpi_type, pe,
|
||||
NET_CHECK(mpilib_ftable_.Fetch_and_op(reinterpret_cast<void *>(val), src, mpi_type, pe,
|
||||
bp->heap_window_info[win_id]->get_offset(dst),
|
||||
get_mpi_op(op),
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
@@ -533,7 +533,7 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet. Put it here as a hack for now even
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
queue->notify(status);
|
||||
|
||||
@@ -547,7 +547,7 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Compare_and_swap((const void *)val, (const void *)cond, src,
|
||||
NET_CHECK(mpilib_ftable_.Compare_and_swap((const void *)val, (const void *)cond, src,
|
||||
mpi_type, pe,
|
||||
bp->heap_window_info[win_id]->get_offset(dst),
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
@@ -555,7 +555,7 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet. Put it here as a hack for now even
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
NET_CHECK(mpilib_ftable_.Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
queue->notify(status);
|
||||
|
||||
@@ -569,7 +569,7 @@ void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Rget(
|
||||
NET_CHECK(mpilib_ftable_.Rget(
|
||||
dst, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(src),
|
||||
size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request));
|
||||
|
||||
@@ -595,7 +595,7 @@ void MPITransport::progress() {
|
||||
// Slowing the progress engine down a bit avoid hammering the memory subsystem.
|
||||
// This leads to significant performance benefits
|
||||
usleep (progress_delay);
|
||||
NET_CHECK(MPI_Iprobe(MPI_ANY_SOURCE, tag, ro_net_comm_world, &flag, &status));
|
||||
NET_CHECK(mpilib_ftable_.Iprobe(MPI_ANY_SOURCE, tag, ro_net_comm_world, &flag, &status));
|
||||
} else {
|
||||
DPRINTF("Testing all outstanding requests (%zu)\n", requests.size());
|
||||
|
||||
@@ -605,7 +605,7 @@ void MPITransport::progress() {
|
||||
int outcount{};
|
||||
|
||||
auto uptr_req_arr {raw_requests()};
|
||||
NET_CHECK(MPI_Testsome(incount, uptr_req_arr.get(), &outcount,
|
||||
NET_CHECK(mpilib_ftable_.Testsome(incount, uptr_req_arr.get(), &outcount,
|
||||
testsome_indices.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
@@ -25,8 +25,6 @@
|
||||
#ifndef LIBRARY_SRC_REVERSE_OFFLOAD_QUEUE_PROXY_HPP_
|
||||
#define LIBRARY_SRC_REVERSE_OFFLOAD_QUEUE_PROXY_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "atomic_return.hpp"
|
||||
#include "device_proxy.hpp"
|
||||
#include "hdp_policy.hpp"
|
||||
|
||||
@@ -25,11 +25,10 @@
|
||||
#ifndef LIBRARY_SRC_REVERSE_OFFLOAD_RO_TEAM_PROXY_HPP_
|
||||
#define LIBRARY_SRC_REVERSE_OFFLOAD_RO_TEAM_PROXY_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "device_proxy.hpp"
|
||||
#include "ro_net_team.hpp"
|
||||
#include "team_info_proxy.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
|
||||
@@ -25,13 +25,12 @@
|
||||
#ifndef LIBRARY_SRC_REVERSE_OFFLOAD_TRANSPORT_HPP_
|
||||
#define LIBRARY_SRC_REVERSE_OFFLOAD_TRANSPORT_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
#include "backend_proxy.hpp"
|
||||
#include "ro_net_team.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ MPIInstance *mpi_instance = nullptr;
|
||||
TcpBootstrap *bootstr = nullptr;
|
||||
rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
/**
|
||||
/**
|
||||
* Begin Host Code
|
||||
**/
|
||||
|
||||
@@ -92,12 +92,18 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
rocm_init();
|
||||
|
||||
int ret;
|
||||
ret = MPIInstance::mpilib_dl_init();
|
||||
mpi_instance = new MPIInstance(comm);
|
||||
|
||||
#if defined(USE_GDA)
|
||||
CHECK_HIP(hipHostMalloc(&backend, sizeof(GDABackend)));
|
||||
backend = new (backend) GDABackend(comm);
|
||||
#elif defined(USE_RO)
|
||||
if (ret != ROCSHMEM_SUCCESS) {
|
||||
printf("Could not initialize MPI library. RO conduit requires MPI library to be loaded at runtime. Aborting\n");
|
||||
abort();
|
||||
}
|
||||
CHECK_HIP(hipHostMalloc(&backend, sizeof(ROBackend)));
|
||||
backend = new (backend) ROBackend(comm);
|
||||
#elif defined(USE_IPC)
|
||||
@@ -113,7 +119,15 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
[[maybe_unused]] __host__ static void inline library_init_subcomm(TcpBootstrap *bootstrap, int nranks, int rank) {
|
||||
int initialized;
|
||||
int world_size = -1;
|
||||
MPI_Initialized(&initialized);
|
||||
|
||||
int ret;
|
||||
ret = MPIInstance::mpilib_dl_init();
|
||||
if (ret == ROCSHMEM_SUCCESS) {
|
||||
printf("Could not initialize MPI library. This initialization method of "
|
||||
"rocSHMEM requires MPI library to be loaded at runtime. Aborting\n");
|
||||
abort();
|
||||
}
|
||||
mpilib_ftable_.Initialized(&initialized);
|
||||
|
||||
if (!initialized) {
|
||||
// This is an Open MPI specific solution to retrieve the number of
|
||||
@@ -131,7 +145,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
abort();
|
||||
}
|
||||
} else {
|
||||
MPI_Comm_size (MPI_COMM_WORLD, &world_size);
|
||||
mpilib_ftable_.Comm_size (MPI_COMM_WORLD, &world_size);
|
||||
}
|
||||
|
||||
if (world_size == nranks) {
|
||||
@@ -140,8 +154,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
MPI_Group world_group;
|
||||
int world_rank;
|
||||
|
||||
MPI_Comm_rank (MPI_COMM_WORLD, &world_rank);
|
||||
MPI_Comm_group (MPI_COMM_WORLD, &world_group);
|
||||
mpilib_ftable_.Comm_rank (MPI_COMM_WORLD, &world_rank);
|
||||
mpilib_ftable_.Comm_group (MPI_COMM_WORLD, &world_group);
|
||||
|
||||
int *inc_ranks = new int[nranks];
|
||||
inc_ranks[rank] = world_rank;
|
||||
@@ -150,14 +164,14 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
MPI_Group sub_group;
|
||||
MPI_Comm sub_comm;
|
||||
MPI_Group_incl (world_group, nranks, inc_ranks, &sub_group);
|
||||
MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm);
|
||||
mpilib_ftable_.Group_incl (world_group, nranks, inc_ranks, &sub_group);
|
||||
mpilib_ftable_.Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm);
|
||||
|
||||
library_init(sub_comm);
|
||||
|
||||
MPI_Group_free (&sub_group);
|
||||
MPI_Group_free (&world_group);
|
||||
MPI_Comm_free (&sub_comm);
|
||||
mpilib_ftable_.Group_free (&sub_group);
|
||||
mpilib_ftable_.Group_free (&world_group);
|
||||
mpilib_ftable_.Comm_free (&sub_comm);
|
||||
delete[] inc_ranks;
|
||||
}
|
||||
}
|
||||
@@ -192,7 +206,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int flags,
|
||||
rocshmem_init_attr_t *attr) {
|
||||
MPI_Comm comm = MPI_COMM_NULL;
|
||||
MPI_Comm comm;
|
||||
|
||||
if ((attr == nullptr) ||
|
||||
((flags != ROCSHMEM_INIT_WITH_UNIQUEID) &&
|
||||
@@ -228,12 +242,12 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_set_attr_uniqueid_args(int rank, int nranks,
|
||||
rocshmem_uniqueid_t *uid,
|
||||
rocshmem_init_attr_t *attr) {
|
||||
rocshmem_uniqueid_t *uid,
|
||||
rocshmem_init_attr_t *attr) {
|
||||
if (uid == nullptr || attr == nullptr) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
"Call 'rocshmem_get_uniqueid: invalid input argument'",
|
||||
__FILE__, __LINE__);
|
||||
__FILE__, __LINE__);
|
||||
return ROCSHMEM_ERROR;
|
||||
}
|
||||
|
||||
@@ -252,7 +266,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
if (uid == nullptr) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
"Call 'rocshmem_get_uniqueid: invalid input argument'",
|
||||
__FILE__, __LINE__);
|
||||
__FILE__, __LINE__);
|
||||
return ROCSHMEM_ERROR;
|
||||
}
|
||||
|
||||
@@ -262,17 +276,29 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
[[maybe_unused]] __host__ void rocshmem_init(MPI_Comm comm) {
|
||||
library_init(comm);
|
||||
}
|
||||
#endif
|
||||
|
||||
[[maybe_unused]] __host__ void rocshmem_init() {
|
||||
MPIInstance::mpilib_dl_init();
|
||||
library_init(MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
#if defined(HAVE_EXTERNAL_MPI)
|
||||
[[maybe_unused]] __host__ int rocshmem_init_thread(
|
||||
[[maybe_unused]] int required, int *provided, MPI_Comm comm) {
|
||||
if (comm == static_cast<MPI_Comm>(0) || comm == MPI_COMM_NULL) {
|
||||
comm = MPI_COMM_WORLD;
|
||||
}
|
||||
library_init(comm);
|
||||
rocshmem_query_thread(provided);
|
||||
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
#endif
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_my_pe() {
|
||||
if (backend != nullptr) {
|
||||
@@ -297,7 +323,6 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
void *ptr;
|
||||
backend->heap.malloc(&ptr, size);
|
||||
|
||||
rocshmem_barrier_all();
|
||||
|
||||
return ptr;
|
||||
@@ -455,7 +480,8 @@ __host__ int rocshmem_team_split_strided(
|
||||
TeamInfo(team_world, pe_start_in_world, stride_in_world, size);
|
||||
|
||||
MPI_Comm team_comm{MPI_COMM_NULL};
|
||||
if (parent_team_obj->mpi_comm != MPI_COMM_NULL) {
|
||||
if (parent_team_obj->mpi_comm != MPI_COMM_NULL &&
|
||||
parent_team_obj->mpi_comm != static_cast<MPI_Comm>(0)) {
|
||||
/* Create a new MPI communicator for this team */
|
||||
int color;
|
||||
if (my_pe_in_new_team < 0) {
|
||||
@@ -464,8 +490,8 @@ __host__ int rocshmem_team_split_strided(
|
||||
color = 1;
|
||||
}
|
||||
|
||||
MPI_Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm);
|
||||
}
|
||||
mpilib_ftable_.Comm_split(parent_team_obj->mpi_comm, color, my_pe_in_world, &team_comm);
|
||||
}
|
||||
/**
|
||||
* Allocate new team for GPU-inittiated communication with backend-specific
|
||||
* objects
|
||||
@@ -484,8 +510,8 @@ __host__ int rocshmem_team_split_strided(
|
||||
backend->team_tracker.track(*new_team);
|
||||
}
|
||||
|
||||
if (team_comm != MPI_COMM_NULL) {
|
||||
MPI_Comm_free (&team_comm);
|
||||
if (team_comm != MPI_COMM_NULL && team_comm != static_cast<MPI_Comm>(0)) {
|
||||
mpilib_ftable_.Comm_free (&team_comm);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#ifndef LIBRARY_SRC_STATS_HPP_
|
||||
#define LIBRARY_SRC_STATS_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
#include <time.h>
|
||||
|
||||
#include <atomic>
|
||||
|
||||
@@ -179,10 +179,10 @@ class HostStats {
|
||||
AtomicStatType stats[I] = {};
|
||||
|
||||
public:
|
||||
__host__ uint64_t startTimer() const { return MPI_Wtime(); }
|
||||
__host__ uint64_t startTimer() const { return wtime(); }
|
||||
|
||||
__host__ void endTimer(uint64_t start, int index) {
|
||||
incStat(index, MPI_Wtime() - start);
|
||||
incStat(index, wtime() - start);
|
||||
}
|
||||
|
||||
__host__ void incStat(int index, int value = 1) { stats[index] += value; }
|
||||
@@ -197,6 +197,16 @@ class HostStats {
|
||||
}
|
||||
|
||||
__host__ StatType getStat(int index) const { return stats[index].load(); }
|
||||
private:
|
||||
double wtime(void) {
|
||||
double wt;
|
||||
struct timespec tp;
|
||||
|
||||
(void) clock_gettime(CLOCK_MONOTONIC, &tp);
|
||||
wt = (static_cast<double>(tp.tv_nsec))/1.0e+9;
|
||||
wt += static_cast<double>(tp.tv_sec);
|
||||
return wt;
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
#include "backend_bc.hpp"
|
||||
#include "util.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
@@ -84,7 +85,7 @@ __host__ Team::Team(Backend* handle, TeamInfo* team_info_wrt_parent,
|
||||
num_pes(_num_pes),
|
||||
my_pe(_my_pe) {
|
||||
if (_mpi_comm != MPI_COMM_NULL) {
|
||||
MPI_Comm_dup (_mpi_comm, &mpi_comm);
|
||||
mpilib_ftable_.Comm_dup (_mpi_comm, &mpi_comm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +118,7 @@ __host__ __device__ int Team::get_pe_in_my_team(int pe_in_world) {
|
||||
|
||||
__host__ Team::~Team() {
|
||||
if (mpi_comm != MPI_COMM_NULL)
|
||||
MPI_Comm_free (&mpi_comm);
|
||||
mpilib_ftable_.Comm_free (&mpi_comm);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -25,9 +25,8 @@
|
||||
#ifndef LIBRARY_SRC_TEAM_HPP_
|
||||
#define LIBRARY_SRC_TEAM_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "rocshmem/rocshmem.hpp"
|
||||
#include "rocshmem/rocshmem_mpi.hpp"
|
||||
#include "backend_type.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
@@ -115,6 +115,35 @@ namespace rocshmem {
|
||||
} while (0);
|
||||
#endif
|
||||
|
||||
/* Helper Macros for handling dynamic libraries */
|
||||
#define PPCAT_NX(prefix, func_name) prefix##func_name
|
||||
#define PPCAT(prefix, func_name) PPCAT_NX(prefix, func_name)
|
||||
|
||||
#define STRINGIFY_NX(name) #name
|
||||
#define STRINGIFY(name) STRINGIFY_NX(name)
|
||||
|
||||
#define DLSYM_HELPER(func_struct, prefix, handle, func_name) \
|
||||
do { \
|
||||
*(void **) (&func_struct.func_name) = dlsym(handle, STRINGIFY(PPCAT(prefix, func_name))); \
|
||||
if (!func_struct.func_name) { \
|
||||
DPRINTF("Failed to find function %s \n", STRINGIFY(PPCAT(prefix, func_name))); \
|
||||
dlclose(handle); \
|
||||
handle = nullptr; \
|
||||
return ROCSHMEM_ERROR; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DLSYM_VAR_HELPER(func_struct, handle, var_name) \
|
||||
do { \
|
||||
*(void **) (&func_struct.var_name) = dlsym(handle, STRINGIFY(var_name)); \
|
||||
if (!func_struct.var_name) { \
|
||||
DPRINTF("Failed to find function %s \n", STRINGIFY(var_name)); \
|
||||
dlclose(handle); \
|
||||
handle = nullptr; \
|
||||
return ROCSHMEM_ERROR; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
extern const int gpu_clock_freq_mhz;
|
||||
|
||||
/* Device-side internal functions */
|
||||
|
||||
@@ -72,8 +72,9 @@ target_sources(
|
||||
# ROCSHMEM
|
||||
###############################################################################
|
||||
if (BUILD_TESTS_ONLY)
|
||||
find_package(MPI REQUIRED)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
#TODO check that build_test_only still works with external-mpi
|
||||
#find_package(MPI REQUIRED)
|
||||
#find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
find_package(rocshmem REQUIRED PATHS /opt/rocm)
|
||||
|
||||
target_include_directories(
|
||||
|
||||
@@ -25,7 +25,6 @@
|
||||
#include "tester.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
@@ -79,9 +79,9 @@ endif()
|
||||
# ROCSHMEM DEPENDENCY
|
||||
###############################################################################
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
find_package(MPI REQUIRED)
|
||||
|
||||
if (BUILD_TESTS_ONLY)
|
||||
find_package(MPI REQUIRED)
|
||||
find_package(rocshmem REQUIRED PATHS /opt/rocm)
|
||||
|
||||
target_include_directories(
|
||||
@@ -95,6 +95,7 @@ endif()
|
||||
target_link_libraries(
|
||||
${PROJECT_NAME}
|
||||
PRIVATE
|
||||
MPI::MPI_CXX
|
||||
roc::rocshmem
|
||||
)
|
||||
|
||||
|
||||
@@ -32,13 +32,13 @@ TEST_P(DegenerateSimpleCoarse, ptr_check) {
|
||||
}
|
||||
|
||||
TEST_P(DegenerateSimpleCoarse, MPI_num_pes) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
}
|
||||
|
||||
TEST_P(DegenerateSimpleCoarse, IPC_bases) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
ASSERT_NE(ipc_impl_.ipc_bases, nullptr);
|
||||
for(int i{0}; i < mpi_.num_pes(); i++) {
|
||||
for(int i{0}; i < mpi_->num_pes(); i++) {
|
||||
ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,10 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
|
||||
public:
|
||||
IPCImplSimpleCoarse() {
|
||||
ipc_impl_.ipcHostInit(mpi_.my_pe(), mpi_.get_heap_bases() , MPI_COMM_WORLD);
|
||||
MPIInstance::mpilib_dl_init();
|
||||
mpi_ = new MPI_T (heap_mem_.get_ptr(), heap_mem_.get_size(), MPI_COMM_WORLD);
|
||||
|
||||
ipc_impl_.ipcHostInit(mpi_->my_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD);
|
||||
assert(ipc_impl_dptr_ == nullptr);
|
||||
hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl));
|
||||
CHECK_HIP(hipMemcpy(ipc_impl_dptr_, &ipc_impl_,
|
||||
@@ -85,6 +88,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
hip_allocator_.deallocate(ipc_impl_dptr_);
|
||||
}
|
||||
ipc_impl_.ipcHostStop();
|
||||
MPIInstance::mpilib_dl_close();
|
||||
}
|
||||
|
||||
void launch(FN_T f, const dim3 grid, const dim3 block, int* src, int* dest, size_t bytes) {
|
||||
@@ -132,7 +136,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
return;
|
||||
}
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice));
|
||||
CHECK_HIP(hipStreamSynchronize(nullptr));
|
||||
}
|
||||
@@ -140,14 +144,14 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
bool pe_initializes_src_buffer(TestType test) {
|
||||
bool is_write_test = test;
|
||||
bool is_read_test = !test;
|
||||
return (is_write_test && mpi_.my_pe() == 0) ||
|
||||
(is_read_test && mpi_.my_pe() == 1);
|
||||
return (is_write_test && mpi_->my_pe() == 0) ||
|
||||
(is_read_test && mpi_->my_pe() == 1);
|
||||
}
|
||||
|
||||
void execute(TestType test, FN_T fn, const dim3 grid, const dim3 block) {
|
||||
if (mpi_.my_pe()) {
|
||||
mpi_.barrier();
|
||||
mpi_.barrier();
|
||||
if (mpi_->my_pe()) {
|
||||
mpi_->barrier();
|
||||
mpi_->barrier();
|
||||
return;
|
||||
}
|
||||
int *src{nullptr};
|
||||
@@ -160,9 +164,9 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[0]);
|
||||
}
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
launch(fn, grid, block, src, dest, bytes);
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
}
|
||||
|
||||
void validate_dest_buffer(TestType test) {
|
||||
@@ -170,7 +174,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
return;
|
||||
}
|
||||
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
for (int i = 0; i < static_cast<int>(golden_.size()); i++) {
|
||||
ASSERT_EQ(golden_[i], dev_dest[i]);
|
||||
}
|
||||
@@ -184,7 +188,7 @@ class IPCImplSimpleCoarse : public ::testing::TestWithParam<std::tuple<int, int,
|
||||
std::vector<int> golden_;
|
||||
|
||||
HEAP_T heap_mem_ {};
|
||||
MPI_T mpi_ {heap_mem_.get_ptr(), heap_mem_.get_size()};
|
||||
MPI_T *mpi_{nullptr};
|
||||
|
||||
IpcImpl ipc_impl_ {};
|
||||
IpcImpl *ipc_impl_dptr_ {nullptr};
|
||||
|
||||
@@ -31,13 +31,13 @@ TEST_P(DegenerateSimpleFine, ptr_check) {
|
||||
}
|
||||
|
||||
TEST_P(DegenerateSimpleFine, MPI_num_pes) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
}
|
||||
|
||||
TEST_P(DegenerateSimpleFine, IPC_bases) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
ASSERT_NE(ipc_impl_.ipc_bases, nullptr);
|
||||
for(int i{0}; i < mpi_.num_pes(); i++) {
|
||||
for(int i{0}; i < mpi_->num_pes(); i++) {
|
||||
ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,7 +140,10 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
|
||||
public:
|
||||
IPCImplSimpleFine() {
|
||||
ipc_impl_.ipcHostInit(mpi_.my_pe(), mpi_.get_heap_bases() , MPI_COMM_WORLD);
|
||||
MPIInstance::mpilib_dl_init();
|
||||
mpi_ = new MPI_T (heap_mem_.get_ptr(), heap_mem_.get_size(), MPI_COMM_WORLD);
|
||||
|
||||
ipc_impl_.ipcHostInit(mpi_->my_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD);
|
||||
|
||||
assert(ipc_impl_dptr_ == nullptr);
|
||||
hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl));
|
||||
@@ -163,6 +166,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
}
|
||||
|
||||
ipc_impl_.ipcHostStop();
|
||||
MPIInstance::mpilib_dl_close();
|
||||
}
|
||||
|
||||
void launch(FN_T1 f, const dim3 grid, const dim3 block, int* src, int* dest, size_t bytes, TestType test) {
|
||||
@@ -214,7 +218,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
|
||||
void initialize_signal(TestType test) {
|
||||
bool is_write_test = test;
|
||||
if (is_write_test && mpi_.my_pe() == 0) {
|
||||
if (is_write_test && mpi_->my_pe() == 0) {
|
||||
int *dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
*(dest + SIGNAL_OFFSET) = 0;
|
||||
}
|
||||
@@ -225,27 +229,27 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
return;
|
||||
}
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
bool pe_initializes_src_buffer(TestType test) {
|
||||
bool is_write_test = test;
|
||||
bool is_read_test = !test;
|
||||
return (is_write_test && mpi_.my_pe() == 0) ||
|
||||
(is_read_test && mpi_.my_pe() == 1);
|
||||
return (is_write_test && mpi_->my_pe() == 0) ||
|
||||
(is_read_test && mpi_->my_pe() == 1);
|
||||
}
|
||||
|
||||
void execute(TestType test, FN_T1 fn, const dim3 grid, const dim3 block) {
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
if (mpi_.my_pe()) {
|
||||
mpi_.barrier();
|
||||
if (mpi_->my_pe()) {
|
||||
mpi_->barrier();
|
||||
if (test == WRITE) {
|
||||
int *dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
FN_T2 val_fn = kernel_put_with_signal_simple_validator;
|
||||
launch(val_fn, grid, block, dest, bytes);
|
||||
}
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
return;
|
||||
}
|
||||
int *src{nullptr};
|
||||
@@ -257,9 +261,9 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[0]);
|
||||
}
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
launch(fn, grid, block, src, dest, bytes, test);
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
}
|
||||
|
||||
void check_device_validation_errors(TestType test) {
|
||||
@@ -274,7 +278,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
return;
|
||||
}
|
||||
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
for (int i = 0; i < static_cast<int>(golden_.size()); i++) {
|
||||
ASSERT_EQ(golden_[i], dev_dest[i]);
|
||||
}
|
||||
@@ -291,7 +295,7 @@ class IPCImplSimpleFine : public ::testing::TestWithParam<std::tuple<int, int, i
|
||||
|
||||
HEAP_T heap_mem_ {};
|
||||
|
||||
MPI_T mpi_ {heap_mem_.get_ptr(), heap_mem_.get_size()};
|
||||
MPI_T *mpi_{nullptr};
|
||||
|
||||
std::vector<int> golden_;
|
||||
|
||||
|
||||
@@ -33,13 +33,13 @@ TEST_F(DegenerateTiledFine, ptr_check) {
|
||||
}
|
||||
|
||||
TEST_F(DegenerateTiledFine, MPI_num_pes) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
}
|
||||
|
||||
TEST_F(DegenerateTiledFine, IPC_bases) {
|
||||
ASSERT_EQ(mpi_.num_pes(), 2);
|
||||
ASSERT_EQ(mpi_->num_pes(), 2);
|
||||
ASSERT_NE(ipc_impl_.ipc_bases, nullptr);
|
||||
for(int i{0}; i < mpi_.num_pes(); i++) {
|
||||
for(int i{0}; i < mpi_->num_pes(); i++) {
|
||||
ASSERT_NE(ipc_impl_.ipc_bases[i], nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +158,10 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
|
||||
public:
|
||||
IPCImplTiledFine() {
|
||||
ipc_impl_.ipcHostInit(mpi_.my_pe(), mpi_.get_heap_bases() , MPI_COMM_WORLD);
|
||||
MPIInstance::mpilib_dl_init();
|
||||
mpi_ = new MPI_T (heap_mem_.get_ptr(), heap_mem_.get_size(), MPI_COMM_WORLD);
|
||||
|
||||
ipc_impl_.ipcHostInit(mpi_->my_pe(), mpi_->get_heap_bases(), MPI_COMM_WORLD);
|
||||
|
||||
assert(ipc_impl_dptr_ == nullptr);
|
||||
hip_allocator_.allocate((void**)&ipc_impl_dptr_, sizeof(IpcImpl));
|
||||
@@ -181,6 +184,7 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
}
|
||||
|
||||
ipc_impl_.ipcHostStop();
|
||||
MPIInstance::mpilib_dl_close();
|
||||
}
|
||||
|
||||
void launch(FN_T1 f, const dim3 grid, const dim3 block, int* src, int* dest, size_t bytes, TestType test) {
|
||||
@@ -232,7 +236,7 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
|
||||
void initialize_signal(TestType test, int signal_value = 0) {
|
||||
bool is_write_test = test;
|
||||
if (is_write_test && mpi_.my_pe() == 0) {
|
||||
if (is_write_test && mpi_->my_pe() == 0) {
|
||||
int *dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
*(dest + SIGNAL_OFFSET) = signal_value;
|
||||
}
|
||||
@@ -243,28 +247,28 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
return;
|
||||
}
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
CHECK_HIP(hipMemcpy(dev_src, golden_.data(), bytes, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
bool pe_initializes_src_buffer(TestType test) {
|
||||
bool is_write_test = test;
|
||||
bool is_read_test = !test;
|
||||
return (is_write_test && mpi_.my_pe() == 0) ||
|
||||
(is_read_test && mpi_.my_pe() == 1);
|
||||
return (is_write_test && mpi_->my_pe() == 0) ||
|
||||
(is_read_test && mpi_->my_pe() == 1);
|
||||
}
|
||||
|
||||
void execute(TestType test, FN_T1 fn, const dim3 grid, const dim3 block) {
|
||||
size_t bytes = golden_.size() * sizeof(int);
|
||||
if (mpi_.my_pe()) {
|
||||
mpi_.barrier();
|
||||
if (mpi_->my_pe()) {
|
||||
mpi_->barrier();
|
||||
if (test == WRITE) {
|
||||
int *dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
FN_T2 val_fn = kernel_put_with_signal_tiled_validator;
|
||||
launch(val_fn, grid, block, dest, bytes);
|
||||
ASSERT_EQ(*(dest + SIGNAL_OFFSET), 0);
|
||||
}
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
return;
|
||||
}
|
||||
int *src{nullptr};
|
||||
@@ -276,9 +280,9 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
src = reinterpret_cast<int*>(ipc_impl_.ipc_bases[1]);
|
||||
dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[0]);
|
||||
}
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
launch(fn, grid, block, src, dest, bytes, test);
|
||||
mpi_.barrier();
|
||||
mpi_->barrier();
|
||||
}
|
||||
|
||||
void check_device_validation_errors(TestType test) {
|
||||
@@ -293,7 +297,7 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
return;
|
||||
}
|
||||
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_.my_pe()]);
|
||||
auto dev_dest = reinterpret_cast<int*>(ipc_impl_.ipc_bases[mpi_->my_pe()]);
|
||||
for (int i = 0; i < static_cast<int>(golden_.size()); i++) {
|
||||
ASSERT_EQ(golden_[i], dev_dest[i]);
|
||||
}
|
||||
@@ -310,7 +314,7 @@ class IPCImplTiledFine : public ::testing::TestWithParam<std::tuple<int, int, in
|
||||
|
||||
HEAP_T heap_mem_ {};
|
||||
|
||||
MPI_T mpi_ {heap_mem_.get_ptr(), heap_mem_.get_size()};
|
||||
MPI_T *mpi_ {nullptr};
|
||||
|
||||
std::vector<int> golden_;
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "../src/memory/heap_memory.hpp"
|
||||
#include "../src/memory/hip_allocator.hpp"
|
||||
#include "../src/memory/remote_heap_info.hpp"
|
||||
@@ -55,7 +57,8 @@ class RemoteHeapInfoTestFixture : public ::testing::Test
|
||||
* @brief Remote heap info with MPI Communicator
|
||||
*/
|
||||
MPI_T mpi_ {heap_mem_.get_ptr(),
|
||||
heap_mem_.get_size()};
|
||||
heap_mem_.get_size(),
|
||||
MPI_COMM_WORLD};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -30,27 +30,27 @@ TEST_F(SymmetricHeapTestFixture, malloc_free) {
|
||||
void *ptr{nullptr};
|
||||
size_t request_bytes{48};
|
||||
|
||||
symmetric_heap_.malloc(&ptr, request_bytes);
|
||||
symmetric_heap_->malloc(&ptr, request_bytes);
|
||||
ASSERT_NE(ptr, nullptr);
|
||||
ASSERT_NO_FATAL_FAILURE(symmetric_heap_.free(ptr));
|
||||
ASSERT_NO_FATAL_FAILURE(symmetric_heap_->free(ptr));
|
||||
}
|
||||
|
||||
TEST_F(SymmetricHeapTestFixture, window_info) {
|
||||
auto win_info_ptr{symmetric_heap_.get_window_info()};
|
||||
auto win_info_ptr{symmetric_heap_->get_window_info()};
|
||||
|
||||
WindowInfoMPI* window_info_mpi = dynamic_cast<WindowInfoMPI*>(win_info_ptr);
|
||||
if (window_info_mpi) {
|
||||
void *window_base_addr{nullptr};
|
||||
int flag{0};
|
||||
MPI_Win_get_attr(window_info_mpi->get_win(), MPI_WIN_BASE, &window_base_addr,
|
||||
&flag);
|
||||
&flag);
|
||||
ASSERT_NE(0, flag);
|
||||
ASSERT_NE(nullptr, window_base_addr);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SymmetricHeapTestFixture, heap_bases) {
|
||||
auto heap_bases{symmetric_heap_.get_heap_bases()};
|
||||
auto heap_bases{symmetric_heap_->get_heap_bases()};
|
||||
for (const auto &base : heap_bases) {
|
||||
ASSERT_NE(nullptr, base);
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@
|
||||
#ifndef ROCSHMEM_SYMMETRIC_HEAP_GTEST_HPP
|
||||
#define ROCSHMEM_SYMMETRIC_HEAP_GTEST_HPP
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "../src/memory/symmetric_heap.hpp"
|
||||
@@ -37,7 +39,16 @@ class SymmetricHeapTestFixture : public ::testing::Test
|
||||
/**
|
||||
* @brief Symmetric heap object
|
||||
*/
|
||||
SymmetricHeap symmetric_heap_ {MPI_COMM_WORLD};
|
||||
SymmetricHeap *symmetric_heap_;
|
||||
|
||||
void SetUp() override {
|
||||
MPIInstance::mpilib_dl_init();
|
||||
symmetric_heap_ = new SymmetricHeap(MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
MPIInstance::mpilib_dl_close();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
Viittaa uudesa ongelmassa
Block a user