Multi-Node rocshmem_finalize() bug (#138)

[ROCm/rocshmem commit: 3f01d89207]
This commit is contained in:
Yiltan
2025-06-04 10:02:03 -04:00
committato da GitHub
parent 032d5e5c6b
commit bceeadeb63
18 ha cambiato i file con 175 aggiunte e 237 eliminazioni
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -123,16 +113,13 @@ int main (int argc, char **argv)
nelem = atoi(argv[1]);
}
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = my_pe % ndevices;
CHECK_HIP(hipSetDevice(my_device));
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
rocshmem_init();
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int *source = (int *)rocshmem_malloc(nelem * sizeof(int));
int *dest = (int *)rocshmem_malloc(nelem * sizeof(int));
if (NULL == source || NULL == dest) {
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -128,16 +118,13 @@ int main (int argc, char **argv)
nelem = atoi(argv[1]);
}
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = my_pe % ndevices;
CHECK_HIP(hipSetDevice(my_device));
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
rocshmem_init();
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int *source = (int *)rocshmem_malloc(nelem * npes * sizeof(int));
int *dest = (int *)rocshmem_malloc(nelem * npes * sizeof(int));
if (NULL == source || NULL == dest) {
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -121,16 +111,13 @@ int main(int argc, char **argv)
nelem = atoi(argv[1]);
}
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = my_pe % ndevices;
CHECK_HIP(hipSetDevice(my_device));
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
rocshmem_init();
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int *source = (int *)rocshmem_malloc(nelem * sizeof(int));
int *dest = (int *)rocshmem_malloc(nelem * sizeof(int));
if (NULL == source || NULL == dest) {
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -76,8 +66,8 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
int threadId = blockIdx.x * blockDim.x + threadIdx.x;
if (threadId == 0) {
int rank = rocshmem_my_pe();
int peer = rank ? 0 : 1;
int my_pe = rocshmem_my_pe();
int peer = my_pe ? 0 : 1;
rocshmem_getmem(dst, src, nelem * sizeof(int), peer);
rocshmem_quiet();
}
@@ -90,19 +80,19 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
int main (int argc, char **argv)
{
int rank = rocshmem_my_pe();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = rank % ndevices;
CHECK_HIP(hipSetDevice(my_device));
int nelem = MAX_ELEM;
if (argc > 1) {
nelem = atoi(argv[1]);
}
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
rocshmem_init();
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int *src = (int *)rocshmem_malloc(nelem * sizeof(int));
int *dst = (int *)rocshmem_malloc(nelem * sizeof(int));
if (NULL == src || NULL == dst) {
@@ -128,7 +118,7 @@ int main (int argc, char **argv)
if (dst[i] != 0) {
pass = false;
#if VERBOSE
printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]);
printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]);
#endif
}
}
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -54,19 +54,9 @@
*/
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#define CHECK_HIP(condition) { \
hipError_t error = condition; \
if(error != hipSuccess){ \
fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \
MPI_Abort(MPI_COMM_WORLD, error); \
} \
}
#include "util.h"
using namespace rocshmem;
@@ -95,20 +85,20 @@ __global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t
int main (int argc, char **argv)
{
int rank = rocshmem_my_pe();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = rank % ndevices;
CHECK_HIP(hipSetDevice(my_device));
int nelem = MAX_ELEM;
if (argc > 1) {
nelem = atoi(argv[1]);
}
CHECK_HIP(hipSetDevice(get_launcher_local_rank()));
rocshmem_init();
int my_pe = rocshmem_my_pe();
int npes = rocshmem_n_pes();
int dst_pe = (rank + 1) % npes;
int dst_pe = (my_pe + 1) % npes;
uint64_t *message = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t));
uint64_t *data = (uint64_t*)rocshmem_malloc(nelem * sizeof(uint64_t));
uint64_t *sig_addr = (uint64_t*)rocshmem_malloc(sizeof(uint64_t));
@@ -123,14 +113,14 @@ int main (int argc, char **argv)
}
for (int i=0; i<nelem; i++) {
message[i] = rank;
message[i] = my_pe;
}
CHECK_HIP(hipMemset(data, 0, (nelem * sizeof(uint64_t))));
CHECK_HIP(hipDeviceSynchronize());
int threadsPerBlock=256;
simple_put_signal_test<<<dim3(1), dim3(threadsPerBlock), 0, 0>>>(data, message, nelem, sig_addr, rank, dst_pe);
simple_put_signal_test<<<dim3(1), dim3(threadsPerBlock), 0, 0>>>(data, message, nelem, sig_addr, my_pe, dst_pe);
rocshmem_barrier_all();
CHECK_HIP(hipDeviceSynchronize());
@@ -139,11 +129,11 @@ int main (int argc, char **argv)
if (data[i] != 0) {
pass = false;
#if VERBOSE
printf("[%d] Error in element %d expected 0 got %d\n", rank, i, dst[i]);
printf("[%d] Error in element %d expected 0 got %d\n", my_pe, i, dst[i]);
#endif
}
}
printf("[%d] Test %s \t %s\n", rank, argv[0], pass ? "[PASS]" : "[FAIL]");
printf("[%d] Test %s \t %s\n", my_pe, argv[0], pass ? "[PASS]" : "[FAIL]");
rocshmem_free(data);
rocshmem_free(message);
+52
Vedi File
@@ -0,0 +1,52 @@
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef __ROCSHMEM_EXAMPLES_UTIL_H__
#define __ROCSHMEM_EXAMPLES_UTIL_H__
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
/**
 * @brief Evaluate a HIP runtime call and terminate the program if it failed.
 *
 * On failure the numeric HIP error code and the source line are written to
 * stderr. If MPI is currently usable (initialized and not yet finalized) the
 * whole job is torn down with MPI_Abort; otherwise only this process exits,
 * because calling MPI_Abort outside the init/finalize window is erroneous.
 * The examples invoke this macro on hipSetDevice() before rocshmem_init(),
 * so the "MPI not initialized yet" path is reachable.
 *
 * The expansion is a single statement (do/while(0)), so the macro is safe
 * inside unbraced if/else. The argument is evaluated exactly once.
 *
 * @param condition Expression yielding a hipError_t.
 */
#define CHECK_HIP(condition)                                          \
  do {                                                                \
    const hipError_t check_hip_status_ = (condition);                 \
    if (check_hip_status_ != hipSuccess) {                            \
      fprintf(stderr, "HIP error: %d line: %d\n", check_hip_status_,  \
              __LINE__);                                              \
      int check_hip_mpi_initialized_ = 0;                             \
      int check_hip_mpi_finalized_ = 0;                               \
      MPI_Initialized(&check_hip_mpi_initialized_);                   \
      MPI_Finalized(&check_hip_mpi_finalized_);                       \
      if (check_hip_mpi_initialized_ && !check_hip_mpi_finalized_) {  \
        MPI_Abort(MPI_COMM_WORLD, check_hip_status_);                 \
      }                                                               \
      std::exit(EXIT_FAILURE);                                        \
    }                                                                 \
  } while (0)
/**
 * @brief Look up this process's node-local rank as exported by the launcher.
 *
 * The value is read from the environment, so it is available before any
 * MPI or rocSHMEM initialization. Several launcher conventions are checked
 * in order; the Open MPI variable is tried first so behaviour under Open
 * MPI's mpirun is unchanged. A variable is only accepted when its whole
 * value parses as a non-negative decimal integer that fits in an int;
 * otherwise the next candidate is tried.
 *
 * @return The node-local rank (>= 0), or -1 when no candidate variable holds
 *         a valid value (e.g. the program was not started by a launcher).
 */
static int get_launcher_local_rank() {
  static const char *const local_rank_vars[] = {
      "OMPI_COMM_WORLD_LOCAL_RANK", /* Open MPI */
      "MV2_COMM_WORLD_LOCAL_RANK",  /* MVAPICH2 */
      "MPI_LOCALRANKID",            /* MPICH / Intel MPI (Hydra) */
      "SLURM_LOCALID",              /* Slurm srun */
  };

  for (const char *var_name : local_rank_vars) {
    const char *text = std::getenv(var_name);
    if (text == nullptr || *text == '\0') {
      continue;
    }

    errno = 0;
    char *parse_end = nullptr;
    const long parsed = std::strtol(text, &parse_end, 10);

    /* Reject partial parses ("3x", "abc"), overflow and negative values
     * instead of silently mapping them to device 0 as atoi() would. */
    const bool fully_parsed = (parse_end != text) && (*parse_end == '\0');
    if (errno == 0 && fully_parsed && parsed >= 0 && parsed <= INT_MAX) {
      return static_cast<int>(parsed);
    }
  }

  return -1;
}
#endif /* __ROCSHMEM_EXAMPLES_UTIL_H__ */
+1 -1
Vedi File
@@ -32,7 +32,7 @@ target_sources(
backend_bc.cpp
context_host.cpp
context_device.cpp
mpi_init_singleton.cpp
mpi_instance.cpp
rocshmem_gpu.cpp
rocshmem.cpp
team.cpp
-11
Vedi File
@@ -86,18 +86,7 @@ Backend::Backend(MPI_Comm comm) : heap{comm} {
}
void Backend::init_mpi_once(MPI_Comm comm) {
int init_done{};
NET_CHECK(MPI_Initialized(&init_done));
int provided{};
if (!init_done) {
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
if (provided != MPI_THREAD_MULTIPLE) {
fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n");
}
}
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
NET_CHECK(MPI_Comm_dup(comm, &backend_comm));
NET_CHECK(MPI_Comm_size(backend_comm, &num_pes));
NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe));
@@ -55,15 +55,8 @@ class CommunicatorMPI {
CommunicatorMPI(char* heap_base, size_t heap_size,
MPI_Comm comm = MPI_COMM_WORLD)
: comm_{comm} {
int initialized;
MPI_Initialized(&initialized);
if (!initialized) {
int provided;
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
}
MPI_Comm_rank(comm_, &my_pe_);
MPI_Comm_size(comm_, &num_pes_);
heap_window_info_ = WindowInfo(comm_, heap_base, heap_size);
}
@@ -22,13 +22,11 @@
* IN THE SOFTWARE.
*****************************************************************************/
#include "mpi_init_singleton.hpp"
#include "mpi_instance.hpp"
namespace rocshmem {
MPIInitSingleton* MPIInitSingleton::instance{nullptr};
MPIInitSingleton::MPIInitSingleton() {
MPIInstance::MPIInstance(MPI_Comm comm) {
MPI_Initialized(&pre_init_done);
if (!pre_init_done) {
@@ -36,11 +34,15 @@ MPIInitSingleton::MPIInitSingleton() {
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
}
MPI_Comm_size(MPI_COMM_WORLD, &nprocs_);
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);
if (comm == MPI_COMM_NULL) {
comm = MPI_COMM_WORLD;
}
MPI_Comm_size(comm, &nprocs_);
MPI_Comm_rank(comm, &my_rank_);
}
MPIInitSingleton::~MPIInitSingleton() {
MPIInstance::~MPIInstance() {
int finalized{0};
MPI_Finalized(&finalized);
if (!finalized && !pre_init_done) {
@@ -48,16 +50,8 @@ MPIInitSingleton::~MPIInitSingleton() {
}
}
MPIInitSingleton* MPIInitSingleton::GetInstance() {
if (!instance) {
instance = new MPIInitSingleton();
return instance;
}
return instance;
}
int MPIInstance::get_rank() { return my_rank_; }
int MPIInitSingleton::get_rank() { return my_rank_; }
int MPIInitSingleton::get_nprocs() { return nprocs_; }
int MPIInstance::get_nprocs() { return nprocs_; }
} // namespace rocshmem
@@ -22,77 +22,64 @@
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
#define LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
#ifndef LIBRARY_SRC_MPI_INSTANCE_HPP_
#define LIBRARY_SRC_MPI_INSTANCE_HPP_
#include <mpi.h>
#include <memory>
/**
* @file mpi_init_singleton.hpp
* @file mpi_instance.hpp
*
* @brief Contains MPI library initialization code
*/
namespace rocshmem {
class MPIInitSingleton {
private:
/**
* @brief Primary constructor
*/
MPIInitSingleton();
class MPIInstance {
public:
/**
* @brief Primary constructor
*/
MPIInstance(MPI_Comm comm);
public:
/**
* @brief Destructor
*/
~MPIInitSingleton();
/**
* @brief Destructor
*/
~MPIInstance();
/**
* @brief Invoke singleton construction or return handle
*
* @return Initialized handle to singleton
*/
static MPIInitSingleton* GetInstance();
/**
* @brief Accessor for my COMM_WORLD rank identifier
*
* @return My COMM_WORLD rank identifier
*/
int get_rank();
/**
* @brief Accessor for my COMM_WORLD rank identifier
*
* @return My COMM_WORLD rank identifier
*/
int get_rank();
/**
* @brief Accessor for number of processes in COMM_WORLD
*
* @return Number of processes in COMM_WORLD
*/
int get_nprocs();
/**
* @brief Accessor for number of processes in COMM_WORLD
*
* @return Number of processes in COMM_WORLD
*/
int get_nprocs();
private:
/**
* @brief My MPI rank identifier
*/
int my_rank_{-1};
private:
/**
* @brief My MPI rank identifier
*/
int my_rank_{-1};
/**
* @brief Number of MPI processes
*/
int nprocs_{-1};
/**
* @brief Number of MPI processes
*/
int nprocs_{-1};
/**
* @brief Was MPI initialized before rocshmem_init call
*/
int pre_init_done{0};
/**
* @brief Refers to global variable
*/
static MPIInitSingleton* instance;
/**
* @brief Was MPI initialized before rocshmem_init call
*/
int pre_init_done{0};
};
} // namespace rocshmem
#endif // LIBRARY_SRC_MPI_INIT_SINGLETON_HPP_
#endif // LIBRARY_SRC_MPI_INSTANCE_HPP_
@@ -46,16 +46,7 @@ namespace rocshmem {
MPITransport::MPITransport(MPI_Comm comm, Queue* q)
: queue{q}, Transport{} {
int init_done{};
NET_CHECK(MPI_Initialized(&init_done));
int provided{};
if (!init_done) {
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
if (provided != MPI_THREAD_MULTIPLE) {
fprintf(stderr, "MPI_THREAD_MULTIPLE support disabled.\n");
}
}
assert(comm != MPI_COMM_NULL);
NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world));
+16 -16
Vedi File
@@ -47,7 +47,7 @@
#include "ipc/backend_ipc.hpp"
#include "ipc/context_ipc_tmpl_host.hpp"
#endif
#include "mpi_init_singleton.hpp"
#include "mpi_instance.hpp"
#include "team.hpp"
#include "templates_host.hpp"
#include "util.hpp"
@@ -67,6 +67,7 @@ namespace rocshmem {
}
Backend *backend = nullptr;
MPIInstance *mpi_instance = nullptr;
rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
@@ -86,6 +87,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
rocm_init();
mpi_instance = new MPIInstance(comm);
#ifdef USE_RO
CHECK_HIP(hipHostMalloc(&backend, sizeof(ROBackend)));
backend = new (backend) ROBackend(comm);
@@ -103,7 +106,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
rocshmem_init_attr_t *attr) {
MPI_Comm comm = MPI_COMM_NULL;
if ((attr == nullptr) ||
if ((attr == nullptr) ||
((flags != ROCSHMEM_INIT_WITH_UNIQUEID) &&
(flags != ROCSHMEM_INIT_WITH_MPI_COMM)) ) {
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
@@ -224,24 +227,21 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
}
[[maybe_unused]] __host__ int rocshmem_my_pe() {
if(backend == nullptr) {
MPIInitSingleton *s = s->GetInstance();
return s->get_rank();
}
else
{
return backend->getMyPE();
if (mpi_instance != nullptr) {
return mpi_instance->get_rank();
}
fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n");
return -1;
}
[[maybe_unused]] __host__ int rocshmem_n_pes() {
if(backend == nullptr) {
MPIInitSingleton *s = s->GetInstance();
return s->get_nprocs();
}
else {
return backend->getNumPEs();
if (mpi_instance != nullptr) {
return mpi_instance->get_nprocs();
}
fprintf(stderr, "[WARNING] rocshmem_init() has not been called\n");
return -1;
}
[[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) {
@@ -294,7 +294,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
backend->~Backend();
CHECK_HIP(hipHostFree(backend));
delete MPIInitSingleton::GetInstance();
delete mpi_instance;
}
__host__ void rocshmem_query_thread(int *provided) {
@@ -39,11 +39,8 @@ int main(int argc, char *argv[]) {
/***
* Select a GPU
*/
int rank = rocshmem_my_pe();
int ndevices, my_device = 0;
CHECK_HIP(hipGetDeviceCount(&ndevices));
my_device = rank % ndevices;
CHECK_HIP(hipSetDevice(my_device));
char* ompi_local_rank = getenv("OMPI_COMM_WORLD_LOCAL_RANK");
CHECK_HIP(hipSetDevice(atoi(ompi_local_rank)));
/**
* Must initialize rocshmem to access arguments needed by the tester.
@@ -86,7 +86,7 @@ target_sources(
pow2_bins_gtest.cpp
dlmalloc_gtest.cpp
remote_heap_info_gtest.cpp
mpi_init_singleton_gtest.cpp
mpi_instance_gtest.cpp
abql_block_mutex_gtest.cpp
notifier_gtest.cpp
free_list_gtest.cpp
@@ -22,16 +22,16 @@
* IN THE SOFTWARE.
*****************************************************************************/
#include "mpi_init_singleton_gtest.hpp"
#include "mpi_instance_gtest.hpp"
using namespace rocshmem;
TEST_F(MPIInitSingletonTestFixture, library_initialize_destroy) {}
TEST_F(MPIInstanceTestFixture, library_initialize_destroy) {}
TEST_F(MPIInitSingletonTestFixture, rank) {
TEST_F(MPIInstanceTestFixture, rank) {
ASSERT_NO_FATAL_FAILURE(s_ptr_->get_rank());
}
TEST_F(MPIInitSingletonTestFixture, nprocs) {
TEST_F(MPIInstanceTestFixture, nprocs) {
ASSERT_EQ(s_ptr_->get_nprocs(), 4);
}
@@ -22,29 +22,33 @@
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
#define ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
#ifndef ROCSHMEM_MPI_INSTANCE_GTEST_HPP
#define ROCSHMEM_MPI_INSTANCE_GTEST_HPP
#include "gtest/gtest.h"
#include "../src/mpi_init_singleton.hpp"
#include "../src/mpi_instance.hpp"
namespace rocshmem {
class MPIInitSingletonTestFixture : public ::testing::Test
class MPIInstanceTestFixture : public ::testing::Test
{
public:
MPIInitSingletonTestFixture() {
s_ptr_ = s_ptr_->GetInstance();
MPIInstanceTestFixture() {
s_ptr_ = new MPIInstance(MPI_COMM_WORLD);
}
~MPIInstanceTestFixture() {
delete s_ptr_;
}
protected:
/**
* @brief A singleton object used to initialize MPI
* @brief An MPI instance object used to initialize MPI
*/
MPIInitSingleton* s_ptr_ {nullptr};
MPIInstance* s_ptr_ {nullptr};
};
} // namespace rocshmem
#endif // ROCSHMEM_MPI_INIT_SINGLETON_GTEST_HPP
#endif // ROCSHMEM_MPI_INSTANCE_GTEST_HPP