diff --git a/projects/rocshmem/examples/rocshmem_allreduce_test.cc b/projects/rocshmem/examples/rocshmem_allreduce_test.cc index cdf8d21c01..50f4483784 100644 --- a/projects/rocshmem/examples/rocshmem_allreduce_test.cc +++ b/projects/rocshmem/examples/rocshmem_allreduce_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; @@ -123,16 +113,13 @@ int main (int argc, char **argv) nelem = atoi(argv[1]); } - int my_pe = rocshmem_my_pe(); - int npes = rocshmem_n_pes(); - - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = my_pe % ndevices; - CHECK_HIP(hipSetDevice(my_device)); + CHECK_HIP(hipSetDevice(get_launcher_local_rank())); rocshmem_init(); + int my_pe = rocshmem_my_pe(); + int npes = rocshmem_n_pes(); + int *source = (int *)rocshmem_malloc(nelem * sizeof(int)); int *dest = (int *)rocshmem_malloc(nelem * sizeof(int)); if (NULL == source || NULL == dest) { diff --git a/projects/rocshmem/examples/rocshmem_alltoall_test.cc b/projects/rocshmem/examples/rocshmem_alltoall_test.cc index 3d5f0fb153..90d7cc9437 100644 --- a/projects/rocshmem/examples/rocshmem_alltoall_test.cc +++ b/projects/rocshmem/examples/rocshmem_alltoall_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; @@ -128,16 +118,13 @@ int main (int argc, char **argv) nelem = atoi(argv[1]); } - int my_pe = rocshmem_my_pe(); - int npes = rocshmem_n_pes(); - - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = my_pe % ndevices; - CHECK_HIP(hipSetDevice(my_device)); + CHECK_HIP(hipSetDevice(get_launcher_local_rank())); rocshmem_init(); + int my_pe = rocshmem_my_pe(); + int npes = rocshmem_n_pes(); + int *source = (int *)rocshmem_malloc(nelem * npes * sizeof(int)); int *dest = (int *)rocshmem_malloc(nelem * npes * sizeof(int)); if (NULL == source || NULL == dest) { diff --git a/projects/rocshmem/examples/rocshmem_broadcast_test.cc b/projects/rocshmem/examples/rocshmem_broadcast_test.cc index 382dd1237c..7fa895bed5 100644 --- a/projects/rocshmem/examples/rocshmem_broadcast_test.cc +++ b/projects/rocshmem/examples/rocshmem_broadcast_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; @@ -121,16 +111,13 @@ int main(int argc, char **argv) nelem = atoi(argv[1]); } - int my_pe = rocshmem_my_pe(); - int npes = rocshmem_n_pes(); - - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = my_pe % ndevices; - CHECK_HIP(hipSetDevice(my_device)); + CHECK_HIP(hipSetDevice(get_launcher_local_rank())); rocshmem_init(); + int my_pe = rocshmem_my_pe(); + int npes = rocshmem_n_pes(); + int *source = (int *)rocshmem_malloc(nelem * sizeof(int)); int *dest = (int *)rocshmem_malloc(nelem * sizeof(int)); if (NULL == source || NULL == dest) { diff --git a/projects/rocshmem/examples/rocshmem_getmem_test.cc b/projects/rocshmem/examples/rocshmem_getmem_test.cc index 13bc05d3f2..3c77dd43ab 100644 --- a/projects/rocshmem/examples/rocshmem_getmem_test.cc +++ b/projects/rocshmem/examples/rocshmem_getmem_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; @@ -76,8 +66,8 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem) int threadId = blockIdx.x * blockDim.x + threadIdx.x; if (threadId == 0) { - int rank = rocshmem_my_pe(); - int peer = rank ? 0 : 1; + int my_pe = rocshmem_my_pe(); + int peer = my_pe ? 0 : 1; rocshmem_getmem(dst, src, nelem * sizeof(int), peer); rocshmem_quiet(); } @@ -90,19 +80,19 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem) int main (int argc, char **argv) { - int rank = rocshmem_my_pe(); - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = rank % ndevices; - CHECK_HIP(hipSetDevice(my_device)); int nelem = MAX_ELEM; if (argc > 1) { nelem = atoi(argv[1]); } + CHECK_HIP(hipSetDevice(get_launcher_local_rank())); + rocshmem_init(); + + int my_pe = rocshmem_my_pe(); int npes = rocshmem_n_pes(); + int *src = (int *)rocshmem_malloc(nelem * sizeof(int)); int *dst = (int *)rocshmem_malloc(nelem * sizeof(int)); if (NULL == src || NULL == dst) { @@ -128,7 +118,7 @@ int main (int argc, char **argv) if (dst[i] != 0) { pass = false; #if VERBOSE - printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]); + printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]); #endif } } diff --git a/projects/rocshmem/examples/rocshmem_init_attr_test.cc b/projects/rocshmem/examples/rocshmem_init_attr_test.cc index ef1353077c..d4c061dff5 100644 --- a/projects/rocshmem/examples/rocshmem_init_attr_test.cc +++ b/projects/rocshmem/examples/rocshmem_init_attr_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; diff --git a/projects/rocshmem/examples/rocshmem_put_signal_test.cc b/projects/rocshmem/examples/rocshmem_put_signal_test.cc index 08a38355ba..27086f4c07 100644 --- a/projects/rocshmem/examples/rocshmem_put_signal_test.cc +++ b/projects/rocshmem/examples/rocshmem_put_signal_test.cc @@ -54,19 +54,9 @@ */ -#include - -#include -#include #include -#define CHECK_HIP(condition) { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ - MPI_Abort(MPI_COMM_WORLD, error); \ - } \ - } +#include "util.h" using namespace rocshmem; @@ -95,20 +85,20 @@ __global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t int main (int argc, char **argv) { - int rank = rocshmem_my_pe(); - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = rank % ndevices; - CHECK_HIP(hipSetDevice(my_device)); int nelem = MAX_ELEM; if (argc > 1) { nelem = atoi(argv[1]); } + CHECK_HIP(hipSetDevice(get_launcher_local_rank())); + rocshmem_init(); + + int my_pe = rocshmem_my_pe(); int npes = rocshmem_n_pes(); - int dst_pe = (rank + 1) % npes; + + int dst_pe = (my_pe + 1) % npes; uint64_t *message = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t)); uint64_t *data = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t)); uint64_t *sig_addr = (uint64_t*)rocshmem_malloc(sizeof(uint64_t)); @@ -123,14 +113,14 @@ int main (int argc, char **argv) } for (int i=0; i>>(data, message, nelem, sig_addr, rank, dst_pe); + simple_put_signal_test<<>>(data, message, nelem, sig_addr, my_pe, dst_pe); rocshmem_barrier_all(); CHECK_HIP(hipDeviceSynchronize()); @@ -139,11 +129,11 @@ int main (int argc, char **argv) if (data[i] != 0) { pass = false; #if VERBOSE - printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]); + printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]); #endif } } - printf("[%d] Test %s \t %s\n", rank, argv[0], pass ? "[PASS]" : "[FAIL]"); + printf("[%d] Test %s \t %s\n", my_pe, argv[0], pass ? "[PASS]" : "[FAIL]"); rocshmem_free(data); rocshmem_free(message); diff --git a/projects/rocshmem/examples/util.h b/projects/rocshmem/examples/util.h new file mode 100644 index 0000000000..3118f52f89 --- /dev/null +++ b/projects/rocshmem/examples/util.h @@ -0,0 +1,52 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef __ROCSHMEM_EXAMPLES_UTIL_H__ +#define __ROCSHMEM_EXAMPLES_UTIL_H__ + +#include + +#include +#include + +#define CHECK_HIP(condition) { \ + hipError_t error = condition; \ + if(error != hipSuccess){ \ + fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ + MPI_Abort(MPI_COMM_WORLD, error); \ + } \ + } + +static int get_launcher_local_rank() { + char *local_rank_str = nullptr; + + local_rank_str = getenv("OMPI_COMM_WORLD_LOCAL_RANK"); + if (nullptr != local_rank_str) { + return atoi(local_rank_str); + } + + return -1; +} + +#endif /* __ROCSHMEM_EXAMPLES_UTIL_H__ */ diff --git a/projects/rocshmem/src/CMakeLists.txt b/projects/rocshmem/src/CMakeLists.txt index 41cec7ded1..498d93cffd 100644 --- a/projects/rocshmem/src/CMakeLists.txt +++ b/projects/rocshmem/src/CMakeLists.txt @@ -32,7 +32,7 @@ target_sources( backend_bc.cpp context_host.cpp context_device.cpp - mpi_init_singleton.cpp + mpi_instance.cpp rocshmem_gpu.cpp rocshmem.cpp team.cpp diff --git a/projects/rocshmem/src/backend_bc.cpp b/projects/rocshmem/src/backend_bc.cpp index d46ee2c1c4..cb9a484c47 100644 --- a/projects/rocshmem/src/backend_bc.cpp +++ b/projects/rocshmem/src/backend_bc.cpp @@ -86,18 +86,7 @@ Backend::Backend(MPI_Comm comm) : heap{comm} { } void Backend::init_mpi_once(MPI_Comm comm) { - int init_done{}; - NET_CHECK(MPI_Initialized(&init_done)); - - int provided{}; - if (!init_done) { - NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided)); - if (provided != MPI_THREAD_MULTIPLE) { - fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n"); - } - } if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD; - NET_CHECK(MPI_Comm_dup(comm, &backend_comm)); NET_CHECK(MPI_Comm_size(backend_comm, &num_pes)); NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe)); diff --git a/projects/rocshmem/src/memory/remote_heap_info.hpp b/projects/rocshmem/src/memory/remote_heap_info.hpp index 71f0483c99..e64a48c023 100644 --- a/projects/rocshmem/src/memory/remote_heap_info.hpp +++ b/projects/rocshmem/src/memory/remote_heap_info.hpp @@ -55,15 +55,8 @@ class CommunicatorMPI { CommunicatorMPI(char* heap_base, size_t heap_size, MPI_Comm comm = MPI_COMM_WORLD) : comm_{comm} { - int initialized; - MPI_Initialized(&initialized); - if (!initialized) { - int provided; - MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); - } MPI_Comm_rank(comm_, &my_pe_); MPI_Comm_size(comm_, &num_pes_); - heap_window_info_ = WindowInfo(comm_, heap_base, heap_size); } diff --git a/projects/rocshmem/src/mpi_init_singleton.cpp b/projects/rocshmem/src/mpi_instance.cpp similarity index 76% rename from projects/rocshmem/src/mpi_init_singleton.cpp rename to projects/rocshmem/src/mpi_instance.cpp index efb4039869..37ff3b4228 100644 --- a/projects/rocshmem/src/mpi_init_singleton.cpp +++ b/projects/rocshmem/src/mpi_instance.cpp @@ -22,13 +22,11 @@ * IN THE SOFTWARE. *****************************************************************************/ -#include "mpi_init_singleton.hpp" +#include "mpi_instance.hpp" namespace rocshmem { -MPIInitSingleton* MPIInitSingleton::instance{nullptr}; - -MPIInitSingleton::MPIInitSingleton() { +MPIInstance::MPIInstance(MPI_Comm comm) { MPI_Initialized(&pre_init_done); if (!pre_init_done) { @@ -36,11 +34,15 @@ MPIInitSingleton::MPIInitSingleton() { MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided); } - MPI_Comm_size(MPI_COMM_WORLD, &nprocs_); - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_); + if (comm == MPI_COMM_NULL) { + comm = MPI_COMM_WORLD; + } + + MPI_Comm_size(comm, &nprocs_); + MPI_Comm_rank(comm, &my_rank_); } -MPIInitSingleton::~MPIInitSingleton() { +MPIInstance::~MPIInstance() { int finalized{0}; MPI_Finalized(&finalized); if (!finalized && !pre_init_done) { @@ -48,16 +50,8 @@ MPIInitSingleton::~MPIInitSingleton() { } } -MPIInitSingleton* MPIInitSingleton::GetInstance() { - if (!instance) { - instance = new MPIInitSingleton(); - return instance; - } - return instance; -} +int MPIInstance::get_rank() { return my_rank_; } -int MPIInitSingleton::get_rank() { return my_rank_; } - -int MPIInitSingleton::get_nprocs() { return nprocs_; } +int MPIInstance::get_nprocs() { return nprocs_; } } // namespace rocshmem diff --git a/projects/rocshmem/src/mpi_init_singleton.hpp b/projects/rocshmem/src/mpi_instance.hpp similarity index 57% rename from projects/rocshmem/src/mpi_init_singleton.hpp rename to projects/rocshmem/src/mpi_instance.hpp index bcf5bb552c..a9b83fd88d 100644 --- a/projects/rocshmem/src/mpi_init_singleton.hpp +++ b/projects/rocshmem/src/mpi_instance.hpp @@ -22,77 +22,64 @@ * IN THE SOFTWARE. *****************************************************************************/ -#ifndef LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_ -#define LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_ +#ifndef LIBRARY_SRC_MPI_INSTANCE_HPP_ +#define LIBRARY_SRC_MPI_INSTANCE_HPP_ #include #include /** - * @file mpi_init_singleton.hpp + * @file mpi_instance.hpp * * @brief Contains MPI library initialization code */ namespace rocshmem { -class MPIInitSingleton { - private: - /** - * @brief Primary constructor - */ - MPIInitSingleton(); +class MPIInstance { + public: + /** + * @brief Primary constructor + */ + MPIInstance(MPI_Comm comm); - public: - /** - * @brief Destructor - */ - ~MPIInitSingleton(); + /** + * @brief Destructor + */ + ~MPIInstance(); - /** - * @brief Invoke singleton construction or return handle - * - * @return Initialized handle to singleton - */ - static MPIInitSingleton* GetInstance(); + /** + * @brief Accessor for my COMM_WORLD rank identifier + * + * @return My COMM_WORLD rank identifier + */ + int get_rank(); - /** - * @brief Accessor for my COMM_WORLD rank identifier - * - * @return My COMM_WORLD rank identifier - */ - int get_rank(); + /** + * @brief Accessor for number or processes in COMM_WORLD + * + * @return Number of processes in COMM_WORLD + */ + int get_nprocs(); - /** - * @brief Accessor for number or processes in COMM_WORLD - * - * @return Number of processes in COMM_WORLD - */ - int get_nprocs(); + private: + /** + * @brief My MPI rank identifier + */ + int my_rank_{-1}; - private: - /** - * @brief My MPI rank identifier - */ - int my_rank_{-1}; + /** + * @brief Number of MPI processes + */ + int nprocs_{-1}; - /** - * @brief Number of MPI processes - */ - int nprocs_{-1}; - - /** - * @brief Was MPI initialized before rocshmem_init call - */ - int pre_init_done{0}; - - /** - * @brief Refers to global variable - */ - static MPIInitSingleton* instance; + /** + * @brief Was MPI initialized before rocshmem_init call + */ + int pre_init_done{0}; }; } // namespace rocshmem -#endif // LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_ +#endif // LIBRARY_SRC_MPI_INSTANCE_HPP_ diff --git a/projects/rocshmem/src/reverse_offload/mpi_transport.cpp b/projects/rocshmem/src/reverse_offload/mpi_transport.cpp index 761721dd19..53812b565d 100644 --- a/projects/rocshmem/src/reverse_offload/mpi_transport.cpp +++ b/projects/rocshmem/src/reverse_offload/mpi_transport.cpp @@ -46,16 +46,7 @@ namespace rocshmem { MPITransport::MPITransport(MPI_Comm comm, Queue* q) : queue{q}, Transport{} { - int init_done{}; - NET_CHECK(MPI_Initialized(&init_done)); - int provided{}; - if (!init_done) { - NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided)); - if (provided != MPI_THREAD_MULTIPLE) { - fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n"); - } - } assert(comm != MPI_COMM_NULL); NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world)); diff --git a/projects/rocshmem/src/rocshmem.cpp b/projects/rocshmem/src/rocshmem.cpp index 8fb5c08ea5..94e42aa410 100644 --- a/projects/rocshmem/src/rocshmem.cpp +++ b/projects/rocshmem/src/rocshmem.cpp @@ -47,7 +47,7 @@ #include "ipc/backend_ipc.hpp" #include "ipc/context_ipc_tmpl_host.hpp" #endif -#include "mpi_init_singleton.hpp" +#include "mpi_instance.hpp" #include "team.hpp" #include "templates_host.hpp" #include "util.hpp" @@ -67,6 +67,7 @@ namespace rocshmem { } Backend *backend = nullptr; +MPIInstance *mpi_instance = nullptr; rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; @@ -86,6 +87,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; rocm_init(); + mpi_instance = new MPIInstance(comm); + #ifdef USE_RO CHECK_HIP(hipHostMalloc(&backend, sizeof(ROBackend))); backend = new (backend) ROBackend(comm); @@ -103,7 +106,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; rocshmem_init_attr_t *attr) { MPI_Comm comm = MPI_COMM_NULL; - if ((attr == nullptr) || + if ((attr == nullptr) || ((flags != ROCSHMEM_INIT_WITH_UNIQUEID) && (flags != ROCSHMEM_INIT_WITH_MPI_COMM)) ) { fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", @@ -224,24 +227,21 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } [[maybe_unused]] __host__ int rocshmem_my_pe() { - if(backend == nullptr) { - MPIInitSingleton *s = s->GetInstance(); - return s->get_rank(); - } - else - { - return backend->getMyPE(); + if (mpi_instance != nullptr) { + return mpi_instance->get_rank(); } + + fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n"); + return -1; } [[maybe_unused]] __host__ int rocshmem_n_pes() { - if(backend == nullptr) { - MPIInitSingleton *s = s->GetInstance(); - return s->get_nprocs(); - } - else { - return backend->getNumPEs(); + if (mpi_instance != nullptr) { + return mpi_instance->get_nprocs(); } + + fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n"); + return -1; } [[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) { @@ -294,7 +294,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; backend->~Backend(); CHECK_HIP(hipHostFree(backend)); - delete MPIInitSingleton::GetInstance(); + delete mpi_instance; } __host__ void rocshmem_query_thread(int *provided) { diff --git a/projects/rocshmem/tests/functional_tests/test_driver.cpp b/projects/rocshmem/tests/functional_tests/test_driver.cpp index 5ab79f59f7..1214c2b6be 100644 --- a/projects/rocshmem/tests/functional_tests/test_driver.cpp +++ b/projects/rocshmem/tests/functional_tests/test_driver.cpp @@ -39,11 +39,8 @@ int main(int argc, char *argv[]) { /*** * Select a GPU */ - int rank = rocshmem_my_pe(); - int ndevices, my_device = 0; - CHECK_HIP(hipGetDeviceCount(&ndevices)); - my_device = rank % ndevices; - CHECK_HIP(hipSetDevice(my_device)); + char* ompi_local_rank = getenv("OMPI_COMM_WORLD_LOCAL_RANK"); + CHECK_HIP(hipSetDevice(atoi(ompi_local_rank))); /** * Must initialize rocshmem to access arguments needed by the tester. diff --git a/projects/rocshmem/tests/unit_tests/CMakeLists.txt b/projects/rocshmem/tests/unit_tests/CMakeLists.txt index 255d39d728..401b2eda62 100644 --- a/projects/rocshmem/tests/unit_tests/CMakeLists.txt +++ b/projects/rocshmem/tests/unit_tests/CMakeLists.txt @@ -86,7 +86,7 @@ target_sources( pow2_bins_gtest.cpp dlmalloc_gtest.cpp remote_heap_info_gtest.cpp - mpi_init_singleton_gtest.cpp + mpi_instance_gtest.cpp abql_block_mutex_gtest.cpp notifier_gtest.cpp free_list_gtest.cpp diff --git a/projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.cpp b/projects/rocshmem/tests/unit_tests/mpi_instance_gtest.cpp similarity index 88% rename from projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.cpp rename to projects/rocshmem/tests/unit_tests/mpi_instance_gtest.cpp index 06fd25231d..8e6055c8a2 100644 --- a/projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.cpp +++ b/projects/rocshmem/tests/unit_tests/mpi_instance_gtest.cpp @@ -22,16 +22,16 @@ * IN THE SOFTWARE. *****************************************************************************/ -#include "mpi_init_singleton_gtest.hpp" +#include "mpi_instance_gtest.hpp" using namespace rocshmem; -TEST_F(MPIInitSingletonTestFixture, library_initialize_destroy) {} +TEST_F(MPIInstanceTestFixture, library_initialize_destroy) {} -TEST_F(MPIInitSingletonTestFixture, rank) { +TEST_F(MPIInstanceTestFixture, rank) { ASSERT_NO_FATAL_FAILURE(s_ptr_->get_rank()); } -TEST_F(MPIInitSingletonTestFixture, nprocs) { +TEST_F(MPIInstanceTestFixture, nprocs) { ASSERT_EQ(s_ptr_->get_nprocs(), 4); } diff --git a/projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.hpp b/projects/rocshmem/tests/unit_tests/mpi_instance_gtest.hpp similarity index 76% rename from projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.hpp rename to projects/rocshmem/tests/unit_tests/mpi_instance_gtest.hpp index 4f07dafbcc..f40e574938 100644 --- a/projects/rocshmem/tests/unit_tests/mpi_init_singleton_gtest.hpp +++ b/projects/rocshmem/tests/unit_tests/mpi_instance_gtest.hpp @@ -22,29 +22,33 @@ * IN THE SOFTWARE. *****************************************************************************/ -#ifndef ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP -#define ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP +#ifndef ROCSHMEM_MPI_INSTANCE_GTEST_HPP +#define ROCSHMEM_MPI_INSTANCE_GTEST_HPP #include "gtest/gtest.h" -#include "../src/mpi_init_singleton.hpp" +#include "../src/mpi_instance.hpp" namespace rocshmem { -class MPIInitSingletonTestFixture : public ::testing::Test +class MPIInstanceTestFixture : public ::testing::Test { public: - MPIInitSingletonTestFixture() { - s_ptr_ = s_ptr_->GetInstance(); + MPIInstanceTestFixture() { + s_ptr_ = new MPIInstance(MPI_COMM_WORLD); + } + + ~MPIInstanceTestFixture() { + delete s_ptr_; } protected: /** - * @brief A singleton object used to initialize MPI + * @brief A MPI instance object used to initialize MPI */ - MPIInitSingleton* s_ptr_ {nullptr}; + MPIInstance* s_ptr_ {nullptr}; }; } // namespace rocshmem -#endif // ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP +#endif // ROCSHMEM_MPI_INSTANCE_GTEST_HPP