Multi-Node rocshmem_finalize() bug (#138)
This commit is contained in:
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
@@ -123,16 +113,13 @@ int main (int argc, char **argv)
|
||||
nelem = atoi(argv[1]);
|
||||
}
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = my_pe % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
|
||||
|
||||
rocshmem_init();
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int *source = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
int *dest = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
if (NULL == source || NULL == dest) {
|
||||
|
||||
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
@@ -128,16 +118,13 @@ int main (int argc, char **argv)
|
||||
nelem = atoi(argv[1]);
|
||||
}
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = my_pe % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
|
||||
|
||||
rocshmem_init();
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int *source = (int *)rocshmem_malloc(nelem * npes * sizeof(int));
|
||||
int *dest = (int *)rocshmem_malloc(nelem * npes * sizeof(int));
|
||||
if (NULL == source || NULL == dest) {
|
||||
|
||||
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
@@ -121,16 +111,13 @@ int main(int argc, char **argv)
|
||||
nelem = atoi(argv[1]);
|
||||
}
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = my_pe % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
|
||||
|
||||
rocshmem_init();
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int *source = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
int *dest = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
if (NULL == source || NULL == dest) {
|
||||
|
||||
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
@@ -76,8 +66,8 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
|
||||
|
||||
int threadId = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (threadId == 0) {
|
||||
int rank = rocshmem_my_pe();
|
||||
int peer = rank ? 0 : 1;
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int peer = my_pe ? 0 : 1;
|
||||
rocshmem_getmem(dst, src, nelem * sizeof(int), peer);
|
||||
rocshmem_quiet();
|
||||
}
|
||||
@@ -90,19 +80,19 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
|
||||
|
||||
int main (int argc, char **argv)
|
||||
{
|
||||
int rank = rocshmem_my_pe();
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = rank % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
int nelem = MAX_ELEM;
|
||||
|
||||
if (argc > 1) {
|
||||
nelem = atoi(argv[1]);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
|
||||
|
||||
rocshmem_init();
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
|
||||
int *src = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
int *dst = (int *)rocshmem_malloc(nelem * sizeof(int));
|
||||
if (NULL == src || NULL == dst) {
|
||||
@@ -128,7 +118,7 @@ int main (int argc, char **argv)
|
||||
if (dst[i] != 0) {
|
||||
pass = false;
|
||||
#if VERBOSE
|
||||
printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]);
|
||||
printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
|
||||
@@ -54,19 +54,9 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
#include "util.h"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
@@ -95,20 +85,20 @@ __global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t
|
||||
|
||||
int main (int argc, char **argv)
|
||||
{
|
||||
int rank = rocshmem_my_pe();
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = rank % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
int nelem = MAX_ELEM;
|
||||
|
||||
if (argc > 1) {
|
||||
nelem = atoi(argv[1]);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
|
||||
|
||||
rocshmem_init();
|
||||
|
||||
int my_pe = rocshmem_my_pe();
|
||||
int npes = rocshmem_n_pes();
|
||||
int dst_pe = (rank + 1) % npes;
|
||||
|
||||
int dst_pe = (my_pe + 1) % npes;
|
||||
uint64_t *message = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t));
|
||||
uint64_t *data = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t));
|
||||
uint64_t *sig_addr = (uint64_t*)rocshmem_malloc(sizeof(uint64_t));
|
||||
@@ -123,14 +113,14 @@ int main (int argc, char **argv)
|
||||
}
|
||||
|
||||
for (int i=0; i<nelem; i++) {
|
||||
message[i] = rank;
|
||||
message[i] = my_pe;
|
||||
}
|
||||
|
||||
CHECK_HIP(hipMemset(data, 0, (nelem * sizeof(uint64_t))));
|
||||
CHECK_HIP(hipDeviceSynchronize());
|
||||
|
||||
int threadsPerBlock=256;
|
||||
simple_put_signal_test<<<dim3(1), dim3(threadsPerBlock), 0, 0>>>(data, message, nelem, sig_addr, rank, dst_pe);
|
||||
simple_put_signal_test<<<dim3(1), dim3(threadsPerBlock), 0, 0>>>(data, message, nelem, sig_addr, my_pe, dst_pe);
|
||||
rocshmem_barrier_all();
|
||||
CHECK_HIP(hipDeviceSynchronize());
|
||||
|
||||
@@ -139,11 +129,11 @@ int main (int argc, char **argv)
|
||||
if (data[i] != 0) {
|
||||
pass = false;
|
||||
#if VERBOSE
|
||||
printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]);
|
||||
printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
printf("[%d] Test %s \t %s\n", rank, argv[0], pass ? "[PASS]" : "[FAIL]");
|
||||
printf("[%d] Test %s \t %s\n", my_pe, argv[0], pass ? "[PASS]" : "[FAIL]");
|
||||
|
||||
rocshmem_free(data);
|
||||
rocshmem_free(message);
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef __ROCSHMEM_EXAMPLES_UTIL_H__
|
||||
#define __ROCSHMEM_EXAMPLES_UTIL_H__
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#define CHECK_HIP(condition) { \
|
||||
hipError_t error = condition; \
|
||||
if(error != hipSuccess){ \
|
||||
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
|
||||
MPI_Abort(MPI_COMM_WORLD, error); \
|
||||
} \
|
||||
}
|
||||
|
||||
static int get_launcher_local_rank() {
|
||||
char *local_rank_str = nullptr;
|
||||
|
||||
local_rank_str = getenv("OMPI_COMM_WORLD_LOCAL_RANK");
|
||||
if (nullptr != local_rank_str) {
|
||||
return atoi(local_rank_str);
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
#endif /* __ROCSHMEM_EXAMPLES_UTIL_H__ */
|
||||
@@ -32,7 +32,7 @@ target_sources(
|
||||
backend_bc.cpp
|
||||
context_host.cpp
|
||||
context_device.cpp
|
||||
mpi_init_singleton.cpp
|
||||
mpi_instance.cpp
|
||||
rocshmem_gpu.cpp
|
||||
rocshmem.cpp
|
||||
team.cpp
|
||||
|
||||
@@ -86,18 +86,7 @@ Backend::Backend(MPI_Comm comm) : heap{comm} {
|
||||
}
|
||||
|
||||
void Backend::init_mpi_once(MPI_Comm comm) {
|
||||
int init_done{};
|
||||
NET_CHECK(MPI_Initialized(&init_done));
|
||||
|
||||
int provided{};
|
||||
if (!init_done) {
|
||||
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
|
||||
if (provided != MPI_THREAD_MULTIPLE) {
|
||||
fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n");
|
||||
}
|
||||
}
|
||||
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &backend_comm));
|
||||
NET_CHECK(MPI_Comm_size(backend_comm, &num_pes));
|
||||
NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe));
|
||||
|
||||
@@ -55,15 +55,8 @@ class CommunicatorMPI {
|
||||
CommunicatorMPI(char* heap_base, size_t heap_size,
|
||||
MPI_Comm comm = MPI_COMM_WORLD)
|
||||
: comm_{comm} {
|
||||
int initialized;
|
||||
MPI_Initialized(&initialized);
|
||||
if (!initialized) {
|
||||
int provided;
|
||||
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||
}
|
||||
MPI_Comm_rank(comm_, &my_pe_);
|
||||
MPI_Comm_size(comm_, &num_pes_);
|
||||
|
||||
heap_window_info_ = WindowInfo(comm_, heap_base, heap_size);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,13 +22,11 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "mpi_init_singleton.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
MPIInitSingleton* MPIInitSingleton::instance{nullptr};
|
||||
|
||||
MPIInitSingleton::MPIInitSingleton() {
|
||||
MPIInstance::MPIInstance(MPI_Comm comm) {
|
||||
MPI_Initialized(&pre_init_done);
|
||||
|
||||
if (!pre_init_done) {
|
||||
@@ -36,11 +34,15 @@ MPIInitSingleton::MPIInitSingleton() {
|
||||
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||
}
|
||||
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &nprocs_);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);
|
||||
if (comm == MPI_COMM_NULL) {
|
||||
comm = MPI_COMM_WORLD;
|
||||
}
|
||||
|
||||
MPI_Comm_size(comm, &nprocs_);
|
||||
MPI_Comm_rank(comm, &my_rank_);
|
||||
}
|
||||
|
||||
MPIInitSingleton::~MPIInitSingleton() {
|
||||
MPIInstance::~MPIInstance() {
|
||||
int finalized{0};
|
||||
MPI_Finalized(&finalized);
|
||||
if (!finalized && !pre_init_done) {
|
||||
@@ -48,16 +50,8 @@ MPIInitSingleton::~MPIInitSingleton() {
|
||||
}
|
||||
}
|
||||
|
||||
MPIInitSingleton* MPIInitSingleton::GetInstance() {
|
||||
if (!instance) {
|
||||
instance = new MPIInitSingleton();
|
||||
return instance;
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
int MPIInstance::get_rank() { return my_rank_; }
|
||||
|
||||
int MPIInitSingleton::get_rank() { return my_rank_; }
|
||||
|
||||
int MPIInitSingleton::get_nprocs() { return nprocs_; }
|
||||
int MPIInstance::get_nprocs() { return nprocs_; }
|
||||
|
||||
} // namespace rocshmem
|
||||
@@ -22,77 +22,64 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
|
||||
#define LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
|
||||
#ifndef LIBRARY_SRC_MPI_INSTANCE_HPP_
|
||||
#define LIBRARY_SRC_MPI_INSTANCE_HPP_
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
/**
|
||||
* @file mpi_init_singleton.hpp
|
||||
* @file mpi_instance.hpp
|
||||
*
|
||||
* @brief Contains MPI library initialization code
|
||||
*/
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
class MPIInitSingleton {
|
||||
private:
|
||||
/**
|
||||
* @brief Primary constructor
|
||||
*/
|
||||
MPIInitSingleton();
|
||||
class MPIInstance {
|
||||
public:
|
||||
/**
|
||||
* @brief Primary constructor
|
||||
*/
|
||||
MPIInstance(MPI_Comm comm);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
~MPIInitSingleton();
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
~MPIInstance();
|
||||
|
||||
/**
|
||||
* @brief Invoke singleton construction or return handle
|
||||
*
|
||||
* @return Initialized handle to singleton
|
||||
*/
|
||||
static MPIInitSingleton* GetInstance();
|
||||
/**
|
||||
* @brief Accessor for my COMM_WORLD rank identifier
|
||||
*
|
||||
* @return My COMM_WORLD rank identifier
|
||||
*/
|
||||
int get_rank();
|
||||
|
||||
/**
|
||||
* @brief Accessor for my COMM_WORLD rank identifier
|
||||
*
|
||||
* @return My COMM_WORLD rank identifier
|
||||
*/
|
||||
int get_rank();
|
||||
/**
|
||||
* @brief Accessor for number or processes in COMM_WORLD
|
||||
*
|
||||
* @return Number of processes in COMM_WORLD
|
||||
*/
|
||||
int get_nprocs();
|
||||
|
||||
/**
|
||||
* @brief Accessor for number or processes in COMM_WORLD
|
||||
*
|
||||
* @return Number of processes in COMM_WORLD
|
||||
*/
|
||||
int get_nprocs();
|
||||
private:
|
||||
/**
|
||||
* @brief My MPI rank identifier
|
||||
*/
|
||||
int my_rank_{-1};
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief My MPI rank identifier
|
||||
*/
|
||||
int my_rank_{-1};
|
||||
/**
|
||||
* @brief Number of MPI processes
|
||||
*/
|
||||
int nprocs_{-1};
|
||||
|
||||
/**
|
||||
* @brief Number of MPI processes
|
||||
*/
|
||||
int nprocs_{-1};
|
||||
|
||||
/**
|
||||
* @brief Was MPI initialized before rocshmem_init call
|
||||
*/
|
||||
int pre_init_done{0};
|
||||
|
||||
/**
|
||||
* @brief Refers to global variable
|
||||
*/
|
||||
static MPIInitSingleton* instance;
|
||||
/**
|
||||
* @brief Was MPI initialized before rocshmem_init call
|
||||
*/
|
||||
int pre_init_done{0};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif // LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
|
||||
#endif // LIBRARY_SRC_MPI_INSTANCE_HPP_
|
||||
@@ -46,16 +46,7 @@ namespace rocshmem {
|
||||
|
||||
MPITransport::MPITransport(MPI_Comm comm, Queue* q)
|
||||
: queue{q}, Transport{} {
|
||||
int init_done{};
|
||||
NET_CHECK(MPI_Initialized(&init_done));
|
||||
|
||||
int provided{};
|
||||
if (!init_done) {
|
||||
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
|
||||
if (provided != MPI_THREAD_MULTIPLE) {
|
||||
fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n");
|
||||
}
|
||||
}
|
||||
assert(comm != MPI_COMM_NULL);
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world));
|
||||
|
||||
+16
-16
@@ -47,7 +47,7 @@
|
||||
#include "ipc/backend_ipc.hpp"
|
||||
#include "ipc/context_ipc_tmpl_host.hpp"
|
||||
#endif
|
||||
#include "mpi_init_singleton.hpp"
|
||||
#include "mpi_instance.hpp"
|
||||
#include "team.hpp"
|
||||
#include "templates_host.hpp"
|
||||
#include "util.hpp"
|
||||
@@ -67,6 +67,7 @@ namespace rocshmem {
|
||||
}
|
||||
|
||||
Backend *backend = nullptr;
|
||||
MPIInstance *mpi_instance = nullptr;
|
||||
|
||||
rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
@@ -86,6 +87,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
rocm_init();
|
||||
|
||||
mpi_instance = new MPIInstance(comm);
|
||||
|
||||
#ifdef USE_RO
|
||||
CHECK_HIP(hipHostMalloc(&backend, sizeof(ROBackend)));
|
||||
backend = new (backend) ROBackend(comm);
|
||||
@@ -103,7 +106,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
rocshmem_init_attr_t *attr) {
|
||||
MPI_Comm comm = MPI_COMM_NULL;
|
||||
|
||||
if ((attr == nullptr) ||
|
||||
if ((attr == nullptr) ||
|
||||
((flags != ROCSHMEM_INIT_WITH_UNIQUEID) &&
|
||||
(flags != ROCSHMEM_INIT_WITH_MPI_COMM)) ) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
@@ -224,24 +227,21 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_my_pe() {
|
||||
if(backend == nullptr) {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_rank();
|
||||
}
|
||||
else
|
||||
{
|
||||
return backend->getMyPE();
|
||||
if (mpi_instance != nullptr) {
|
||||
return mpi_instance->get_rank();
|
||||
}
|
||||
|
||||
fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_n_pes() {
|
||||
if(backend == nullptr) {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_nprocs();
|
||||
}
|
||||
else {
|
||||
return backend->getNumPEs();
|
||||
if (mpi_instance != nullptr) {
|
||||
return mpi_instance->get_nprocs();
|
||||
}
|
||||
|
||||
fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) {
|
||||
@@ -294,7 +294,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
backend->~Backend();
|
||||
CHECK_HIP(hipHostFree(backend));
|
||||
|
||||
delete MPIInitSingleton::GetInstance();
|
||||
delete mpi_instance;
|
||||
}
|
||||
|
||||
__host__ void rocshmem_query_thread(int *provided) {
|
||||
|
||||
@@ -39,11 +39,8 @@ int main(int argc, char *argv[]) {
|
||||
/***
|
||||
* Select a GPU
|
||||
*/
|
||||
int rank = rocshmem_my_pe();
|
||||
int ndevices, my_device = 0;
|
||||
CHECK_HIP(hipGetDeviceCount(&ndevices));
|
||||
my_device = rank % ndevices;
|
||||
CHECK_HIP(hipSetDevice(my_device));
|
||||
char* ompi_local_rank = getenv("OMPI_COMM_WORLD_LOCAL_RANK");
|
||||
CHECK_HIP(hipSetDevice(atoi(ompi_local_rank)));
|
||||
|
||||
/**
|
||||
* Must initialize rocshmem to access arguments needed by the tester.
|
||||
|
||||
@@ -86,7 +86,7 @@ target_sources(
|
||||
pow2_bins_gtest.cpp
|
||||
dlmalloc_gtest.cpp
|
||||
remote_heap_info_gtest.cpp
|
||||
mpi_init_singleton_gtest.cpp
|
||||
mpi_instance_gtest.cpp
|
||||
abql_block_mutex_gtest.cpp
|
||||
notifier_gtest.cpp
|
||||
free_list_gtest.cpp
|
||||
|
||||
+4
-4
@@ -22,16 +22,16 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "mpi_init_singleton_gtest.hpp"
|
||||
#include "mpi_instance_gtest.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
TEST_F(MPIInitSingletonTestFixture, library_initialize_destroy) {}
|
||||
TEST_F(MPIInstanceTestFixture, library_initialize_destroy) {}
|
||||
|
||||
TEST_F(MPIInitSingletonTestFixture, rank) {
|
||||
TEST_F(MPIInstanceTestFixture, rank) {
|
||||
ASSERT_NO_FATAL_FAILURE(s_ptr_->get_rank());
|
||||
}
|
||||
|
||||
TEST_F(MPIInitSingletonTestFixture, nprocs) {
|
||||
TEST_F(MPIInstanceTestFixture, nprocs) {
|
||||
ASSERT_EQ(s_ptr_->get_nprocs(), 4);
|
||||
}
|
||||
+13
-9
@@ -22,29 +22,33 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
|
||||
#define ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
|
||||
#ifndef ROCSHMEM_MPI_INSTANCE_GTEST_HPP
|
||||
#define ROCSHMEM_MPI_INSTANCE_GTEST_HPP
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "../src/mpi_init_singleton.hpp"
|
||||
#include "../src/mpi_instance.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
class MPIInitSingletonTestFixture : public ::testing::Test
|
||||
class MPIInstanceTestFixture : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
MPIInitSingletonTestFixture() {
|
||||
s_ptr_ = s_ptr_->GetInstance();
|
||||
MPIInstanceTestFixture() {
|
||||
s_ptr_ = new MPIInstance(MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
~MPIInstanceTestFixture() {
|
||||
delete s_ptr_;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief A singleton object used to initialize MPI
|
||||
* @brief A MPI instance object used to initialize MPI
|
||||
*/
|
||||
MPIInitSingleton* s_ptr_ {nullptr};
|
||||
MPIInstance* s_ptr_ {nullptr};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif // ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
|
||||
#endif // ROCSHMEM_MPI_INSTANCE_GTEST_HPP
|
||||
مرجع در شماره جدید
Block a user