Update backend to use provided MPI communicator during library initialization (#79)

* Update backend to use provided MPI communicator during library initialization, default to `MPI_COMM_WORLD`

* Update `rocshmem_my_pe` and `rocshmem_n_pes` host APIs
   - Return values from backend if initialized; otherwise, fallback to MPI_Singleton.
Cette révision appartient à :
Avinash Kethineedi
2025-04-14 09:18:57 -05:00
révisé par GitHub
Parent 0fd628458c
révision 05755847f5
10 fichiers modifiés avec 79 ajouts et 58 suppressions
+28 -1
Voir le fichier
@@ -33,7 +33,15 @@
namespace rocshmem {
Backend::Backend() {
#define NET_CHECK(cmd) \
{ \
if (cmd != MPI_SUCCESS) { \
fprintf(stderr, "Unrecoverable error: MPI Failure\n"); \
abort() ; \
} \
}
Backend::Backend(MPI_Comm comm) : heap{comm} {
int num_cus{};
if (hipDeviceGetAttribute(&num_cus, hipDeviceAttributeMultiprocessorCount,
0)) {
@@ -71,12 +79,31 @@ Backend::Backend() {
CHECK_HIP(
hipHostMalloc(reinterpret_cast<void**>(&done_init), sizeof(uint8_t)));
init_mpi_once(comm);
/*
* Notify other threads that Backend has been initialized.
*/
*done_init = 0;
}
void Backend::init_mpi_once(MPI_Comm comm) {
int init_done{};
NET_CHECK(MPI_Initialized(&init_done));
int provided{};
if (!init_done) {
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
if (provided != MPI_THREAD_MULTIPLE) {
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
}
}
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
NET_CHECK(MPI_Comm_dup(comm, &backend_comm));
NET_CHECK(MPI_Comm_size(backend_comm, &num_pes));
NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe));
}
void Backend::track_ctx(Context* ctx) {
/**
* TODO: Don't track CTX_PRIVATE when we support it
+9 -2
Voir le fichier
@@ -67,7 +67,7 @@ class Backend {
* @note Implementation may reduce the number of workgroups if the
* number exceeds hardware limits.
*/
explicit Backend();
explicit Backend(MPI_Comm comm);
/**
* @brief Destructor.
@@ -221,7 +221,7 @@ class Backend {
* @todo document where this is used and try to coalesce this into another
* class
*/
MPI_Comm thread_comm{};
MPI_Comm backend_comm{};
/**
* @brief Object contains the interface and internal data structures
@@ -295,6 +295,13 @@ class Backend {
* @brief List of ctxs created by the user.
*/
std::vector<Context*> list_of_ctxs{};
/**
* @brief initialize MPI.
*
* Backend relies on MPI to exchange meta data across PEs.
*/
void init_mpi_once(MPI_Comm comm);
};
/**
+9 -29
Voir le fichier
@@ -54,7 +54,7 @@ int get_ls_non_zero_bit(char *bitmask, int mask_length) {
}
IPCBackend::IPCBackend(MPI_Comm comm)
: Backend() {
: Backend(comm) {
type = BackendType::IPC_BACKEND;
if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) {
@@ -62,8 +62,6 @@ IPCBackend::IPCBackend(MPI_Comm comm)
sstream >> maximum_num_contexts_;
}
init_mpi_once(comm);
initIPC();
/**
@@ -74,8 +72,8 @@ IPCBackend::IPCBackend(MPI_Comm comm)
/* Initialize the host interface */
host_interface = std::make_shared<HostInterface>(hdp_proxy_.get(),
thread_comm,
&heap);
backend_comm,
&heap);
default_host_ctx = std::make_unique<IPCHostContext>(this, 0);
@@ -156,7 +154,7 @@ void IPCBackend::setup_team_world() {
IPCTeam *team_world{nullptr};
CHECK_HIP(hipMalloc(&team_world, sizeof(IPCTeam)));
new (team_world) IPCTeam(this, team_info_wrt_parent, team_info_wrt_world,
num_pes, my_pe, thread_comm, 0);
num_pes, my_pe, backend_comm, 0);
team_tracker.set_team_world(team_world);
/**
@@ -165,24 +163,6 @@ void IPCBackend::setup_team_world() {
ROCSHMEM_TEAM_WORLD = reinterpret_cast<rocshmem_team_t>(team_world);
}
void IPCBackend::init_mpi_once(MPI_Comm comm) {
int init_done{};
NET_CHECK(MPI_Initialized(&init_done));
int provided{};
if (!init_done) {
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
if (provided != MPI_THREAD_MULTIPLE) {
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
}
}
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
NET_CHECK(MPI_Comm_dup(comm, &thread_comm));
NET_CHECK(MPI_Comm_size(thread_comm, &num_pes));
NET_CHECK(MPI_Comm_rank(thread_comm, &my_pe));
}
void IPCBackend::team_destroy(rocshmem_team_t team) {
IPCTeam *team_obj = get_internal_ipc_team(team);
@@ -260,11 +240,11 @@ void IPCBackend::initIPC() {
const auto &heap_bases{heap.get_heap_bases()};
ipcImpl.ipcHostInit(my_pe, heap_bases,
thread_comm);
backend_comm);
}
void IPCBackend::global_exit(int status) {
MPI_Abort(MPI_COMM_WORLD, status);
MPI_Abort(backend_comm, status);
}
void IPCBackend::teams_destroy() {
@@ -331,7 +311,7 @@ void IPCBackend::init_wrk_sync_buffer() {
* all-to-all exchange with each PE to share the IPC handles.
*/
MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, thread_comm);
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm);
/*
* Allocate device-side fine grained memory to hold IPC addresses of
@@ -399,7 +379,7 @@ void IPCBackend::rocshmem_collective_init() {
* Make sure that all processing elements have done this before
* continuing.
*/
NET_CHECK(MPI_Barrier(thread_comm));
NET_CHECK(MPI_Barrier(backend_comm));
}
void IPCBackend::teams_init() {
@@ -491,7 +471,7 @@ void IPCBackend::teams_init() {
* Make sure that all processing elements have done this before
* continuing.
*/
NET_CHECK(MPI_Barrier(thread_comm));
NET_CHECK(MPI_Barrier(backend_comm));
}
} // namespace rocshmem
-10
Voir le fichier
@@ -65,16 +65,6 @@ class IPCBackend : public Backend {
*/
void ctx_destroy(Context *ctx) override;
/**
* @brief initialize MPI.
*
* IPC relies on MPI just to exchange the IPC_handle information.
*
* todo: remove the dependency on MPI and make it generic to PMI-X or just
* to OpenSHMEM to have support for both CPU and GPU
*/
void init_mpi_once(MPI_Comm comm);
/**
* @brief Helper to initialize IPC interface.
*/
+9 -6
Voir le fichier
@@ -50,7 +50,9 @@ class CommunicatorMPI {
/**
* @brief Primary constructor
*/
CommunicatorMPI(char* heap_base, size_t heap_size) {
CommunicatorMPI(char* heap_base, size_t heap_size,
MPI_Comm comm = MPI_COMM_WORLD)
: comm_{comm} {
int initialized;
MPI_Initialized(&initialized);
if (!initialized) {
@@ -87,8 +89,8 @@ class CommunicatorMPI {
* @brief Performs MPI_Allgather on recvbuf
*/
void allgather(void* recvbuf) {
MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, sizeof(void*),
MPI_CHAR, comm_);
MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf,
sizeof(void*), MPI_CHAR, comm_);
}
/**
@@ -100,7 +102,7 @@ class CommunicatorMPI {
/**
* @brief Identifier for this processing element
*/
MPI_Comm comm_{MPI_COMM_WORLD};
MPI_Comm comm_{};
/**
* @brief Identifier for this processing element
@@ -139,8 +141,9 @@ class RemoteHeapInfo {
* @param[in] The identifier for this processing element
* @param[in] The total number of processing elements
*/
RemoteHeapInfo(char* heap_ptr, size_t heap_size)
: communicator_{heap_ptr, heap_size} {
RemoteHeapInfo(char* heap_ptr, size_t heap_size,
MPI_Comm comm = MPI_COMM_WORLD)
: communicator_{heap_ptr, heap_size, comm} {
heap_bases_.resize(communicator_.num_pes());
for (auto& base : heap_bases_) {
base = nullptr;
+5 -2
Voir le fichier
@@ -54,6 +54,10 @@ class SymmetricHeap {
using RemoteHeapInfoType = RemoteHeapInfo<CommunicatorMPI>;
public:
SymmetricHeap(MPI_Comm comm = MPI_COMM_WORLD)
: remote_heap_info_{single_heap_.get_base_ptr(),
single_heap_.get_size(),
comm} {}
/**
* @brief Allocates heap memory and returns ptr to caller
*
@@ -120,8 +124,7 @@ class SymmetricHeap {
/**
* @brief Implementation of remote heaps
*/
RemoteHeapInfoType remote_heap_info_{single_heap_.get_base_ptr(),
single_heap_.get_size()};
RemoteHeapInfoType remote_heap_info_{};
};
} // namespace rocshmem
+1 -1
Voir le fichier
@@ -165,7 +165,7 @@ class WindowInfo {
/**
* @brief MPI Communicator
*/
MPI_Comm comm_{MPI_COMM_WORLD};
MPI_Comm comm_{};
/**
* @brief Owning pointer to MPI_Win
+2 -2
Voir le fichier
@@ -45,7 +45,7 @@ namespace rocshmem {
extern rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
ROBackend::ROBackend(MPI_Comm comm)
: Backend() {
: Backend(comm) {
type = BackendType::RO_BACKEND;
if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) {
@@ -78,7 +78,7 @@ ROBackend::ROBackend(MPI_Comm comm)
queue_ = Queue(maximum_num_contexts_, queue_size_);
transport_ = new MPITransport(comm, &queue_);
transport_ = new MPITransport(backend_comm, &queue_);
num_pes = transport_->getNumPes();
my_pe = transport_->getMyPe();
+1 -1
Voir le fichier
@@ -54,7 +54,7 @@ MPITransport::MPITransport(MPI_Comm comm, Queue* q)
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
}
}
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
assert(comm != MPI_COMM_NULL);
NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world));
NET_CHECK(MPI_Comm_size(ro_net_comm_world, &num_pes));
+15 -4
Voir le fichier
@@ -197,13 +197,24 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
}
[[maybe_unused]] __host__ int rocshmem_my_pe() {
MPIInitSingleton *s = s->GetInstance();
return s->get_rank();
if(backend == nullptr) {
MPIInitSingleton *s = s->GetInstance();
return s->get_rank();
}
else
{
return backend->getMyPE();
}
}
[[maybe_unused]] __host__ int rocshmem_n_pes() {
MPIInitSingleton *s = s->GetInstance();
return s->get_nprocs();
if(backend == nullptr) {
MPIInitSingleton *s = s->GetInstance();
return s->get_nprocs();
}
else {
return backend->getNumPEs();
}
}
[[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) {