Update backend to use provided MPI communicator during library initialization (#79)
* Update backend to use provided MPI communicator during library initialization, default to `MPI_COMM_WORLD` * Update `rocshmem_my_pe` and `rocshmem_n_pes` host APIs - Return values from backend if initialized; otherwise, fallback to MPI_Singleton.
Cette révision appartient à :
révisé par
GitHub
Parent
0fd628458c
révision
05755847f5
@@ -33,7 +33,15 @@
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
Backend::Backend() {
|
||||
#define NET_CHECK(cmd) \
|
||||
{ \
|
||||
if (cmd != MPI_SUCCESS) { \
|
||||
fprintf(stderr, "Unrecoverable error: MPI Failure\n"); \
|
||||
abort() ; \
|
||||
} \
|
||||
}
|
||||
|
||||
Backend::Backend(MPI_Comm comm) : heap{comm} {
|
||||
int num_cus{};
|
||||
if (hipDeviceGetAttribute(&num_cus, hipDeviceAttributeMultiprocessorCount,
|
||||
0)) {
|
||||
@@ -71,12 +79,31 @@ Backend::Backend() {
|
||||
CHECK_HIP(
|
||||
hipHostMalloc(reinterpret_cast<void**>(&done_init), sizeof(uint8_t)));
|
||||
|
||||
init_mpi_once(comm);
|
||||
/*
|
||||
* Notify other threads that Backend has been initialized.
|
||||
*/
|
||||
*done_init = 0;
|
||||
}
|
||||
|
||||
void Backend::init_mpi_once(MPI_Comm comm) {
|
||||
int init_done{};
|
||||
NET_CHECK(MPI_Initialized(&init_done));
|
||||
|
||||
int provided{};
|
||||
if (!init_done) {
|
||||
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
|
||||
if (provided != MPI_THREAD_MULTIPLE) {
|
||||
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
|
||||
}
|
||||
}
|
||||
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &backend_comm));
|
||||
NET_CHECK(MPI_Comm_size(backend_comm, &num_pes));
|
||||
NET_CHECK(MPI_Comm_rank(backend_comm, &my_pe));
|
||||
}
|
||||
|
||||
void Backend::track_ctx(Context* ctx) {
|
||||
/**
|
||||
* TODO: Don't track CTX_PRIVATE when we support it
|
||||
|
||||
@@ -67,7 +67,7 @@ class Backend {
|
||||
* @note Implementation may reduce the number of workgroups if the
|
||||
* number exceeds hardware limits.
|
||||
*/
|
||||
explicit Backend();
|
||||
explicit Backend(MPI_Comm comm);
|
||||
|
||||
/**
|
||||
* @brief Destructor.
|
||||
@@ -221,7 +221,7 @@ class Backend {
|
||||
* @todo document where this is used and try to coalesce this into another
|
||||
* class
|
||||
*/
|
||||
MPI_Comm thread_comm{};
|
||||
MPI_Comm backend_comm{};
|
||||
|
||||
/**
|
||||
* @brief Object contains the interface and internal data structures
|
||||
@@ -295,6 +295,13 @@ class Backend {
|
||||
* @brief List of ctxs created by the user.
|
||||
*/
|
||||
std::vector<Context*> list_of_ctxs{};
|
||||
|
||||
/**
|
||||
* @brief initialize MPI.
|
||||
*
|
||||
* Backend relies on MPI to exchange meta data across PEs.
|
||||
*/
|
||||
void init_mpi_once(MPI_Comm comm);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -54,7 +54,7 @@ int get_ls_non_zero_bit(char *bitmask, int mask_length) {
|
||||
}
|
||||
|
||||
IPCBackend::IPCBackend(MPI_Comm comm)
|
||||
: Backend() {
|
||||
: Backend(comm) {
|
||||
type = BackendType::IPC_BACKEND;
|
||||
|
||||
if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) {
|
||||
@@ -62,8 +62,6 @@ IPCBackend::IPCBackend(MPI_Comm comm)
|
||||
sstream >> maximum_num_contexts_;
|
||||
}
|
||||
|
||||
init_mpi_once(comm);
|
||||
|
||||
initIPC();
|
||||
|
||||
/**
|
||||
@@ -74,8 +72,8 @@ IPCBackend::IPCBackend(MPI_Comm comm)
|
||||
|
||||
/* Initialize the host interface */
|
||||
host_interface = std::make_shared<HostInterface>(hdp_proxy_.get(),
|
||||
thread_comm,
|
||||
&heap);
|
||||
backend_comm,
|
||||
&heap);
|
||||
|
||||
default_host_ctx = std::make_unique<IPCHostContext>(this, 0);
|
||||
|
||||
@@ -156,7 +154,7 @@ void IPCBackend::setup_team_world() {
|
||||
IPCTeam *team_world{nullptr};
|
||||
CHECK_HIP(hipMalloc(&team_world, sizeof(IPCTeam)));
|
||||
new (team_world) IPCTeam(this, team_info_wrt_parent, team_info_wrt_world,
|
||||
num_pes, my_pe, thread_comm, 0);
|
||||
num_pes, my_pe, backend_comm, 0);
|
||||
team_tracker.set_team_world(team_world);
|
||||
|
||||
/**
|
||||
@@ -165,24 +163,6 @@ void IPCBackend::setup_team_world() {
|
||||
ROCSHMEM_TEAM_WORLD = reinterpret_cast<rocshmem_team_t>(team_world);
|
||||
}
|
||||
|
||||
void IPCBackend::init_mpi_once(MPI_Comm comm) {
|
||||
int init_done{};
|
||||
NET_CHECK(MPI_Initialized(&init_done));
|
||||
|
||||
int provided{};
|
||||
if (!init_done) {
|
||||
NET_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &provided));
|
||||
if (provided != MPI_THREAD_MULTIPLE) {
|
||||
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
|
||||
}
|
||||
}
|
||||
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &thread_comm));
|
||||
NET_CHECK(MPI_Comm_size(thread_comm, &num_pes));
|
||||
NET_CHECK(MPI_Comm_rank(thread_comm, &my_pe));
|
||||
}
|
||||
|
||||
void IPCBackend::team_destroy(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = get_internal_ipc_team(team);
|
||||
|
||||
@@ -260,11 +240,11 @@ void IPCBackend::initIPC() {
|
||||
const auto &heap_bases{heap.get_heap_bases()};
|
||||
|
||||
ipcImpl.ipcHostInit(my_pe, heap_bases,
|
||||
thread_comm);
|
||||
backend_comm);
|
||||
}
|
||||
|
||||
void IPCBackend::global_exit(int status) {
|
||||
MPI_Abort(MPI_COMM_WORLD, status);
|
||||
MPI_Abort(backend_comm, status);
|
||||
}
|
||||
|
||||
void IPCBackend::teams_destroy() {
|
||||
@@ -331,7 +311,7 @@ void IPCBackend::init_wrk_sync_buffer() {
|
||||
* all-to-all exchange with each PE to share the IPC handles.
|
||||
*/
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(hipIpcMemHandle_t), MPI_CHAR,
|
||||
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, thread_comm);
|
||||
ipc_handle, sizeof(hipIpcMemHandle_t), MPI_CHAR, backend_comm);
|
||||
|
||||
/*
|
||||
* Allocate device-side fine grained memory to hold IPC addresses of
|
||||
@@ -399,7 +379,7 @@ void IPCBackend::rocshmem_collective_init() {
|
||||
* Make sure that all processing elements have done this before
|
||||
* continuing.
|
||||
*/
|
||||
NET_CHECK(MPI_Barrier(thread_comm));
|
||||
NET_CHECK(MPI_Barrier(backend_comm));
|
||||
}
|
||||
|
||||
void IPCBackend::teams_init() {
|
||||
@@ -491,7 +471,7 @@ void IPCBackend::teams_init() {
|
||||
* Make sure that all processing elements have done this before
|
||||
* continuing.
|
||||
*/
|
||||
NET_CHECK(MPI_Barrier(thread_comm));
|
||||
NET_CHECK(MPI_Barrier(backend_comm));
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -65,16 +65,6 @@ class IPCBackend : public Backend {
|
||||
*/
|
||||
void ctx_destroy(Context *ctx) override;
|
||||
|
||||
/**
|
||||
* @brief initialize MPI.
|
||||
*
|
||||
* IPC relies on MPI just to exchange the IPC_handle information.
|
||||
*
|
||||
* todo: remove the dependency on MPI and make it generic to PMI-X or just
|
||||
* to OpenSHMEM to have support for both CPU and GPU
|
||||
*/
|
||||
void init_mpi_once(MPI_Comm comm);
|
||||
|
||||
/**
|
||||
* @brief Helper to initialize IPC interface.
|
||||
*/
|
||||
|
||||
@@ -50,7 +50,9 @@ class CommunicatorMPI {
|
||||
/**
|
||||
* @brief Primary constructor
|
||||
*/
|
||||
CommunicatorMPI(char* heap_base, size_t heap_size) {
|
||||
CommunicatorMPI(char* heap_base, size_t heap_size,
|
||||
MPI_Comm comm = MPI_COMM_WORLD)
|
||||
: comm_{comm} {
|
||||
int initialized;
|
||||
MPI_Initialized(&initialized);
|
||||
if (!initialized) {
|
||||
@@ -87,8 +89,8 @@ class CommunicatorMPI {
|
||||
* @brief Performs MPI_Allgather on recvbuf
|
||||
*/
|
||||
void allgather(void* recvbuf) {
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf, sizeof(void*),
|
||||
MPI_CHAR, comm_);
|
||||
MPI_Allgather(MPI_IN_PLACE, sizeof(void*), MPI_CHAR, recvbuf,
|
||||
sizeof(void*), MPI_CHAR, comm_);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -100,7 +102,7 @@ class CommunicatorMPI {
|
||||
/**
|
||||
* @brief Identifier for this processing element
|
||||
*/
|
||||
MPI_Comm comm_{MPI_COMM_WORLD};
|
||||
MPI_Comm comm_{};
|
||||
|
||||
/**
|
||||
* @brief Identifier for this processing element
|
||||
@@ -139,8 +141,9 @@ class RemoteHeapInfo {
|
||||
* @param[in] The identifier for this processing element
|
||||
* @param[in] The total number of processing elements
|
||||
*/
|
||||
RemoteHeapInfo(char* heap_ptr, size_t heap_size)
|
||||
: communicator_{heap_ptr, heap_size} {
|
||||
RemoteHeapInfo(char* heap_ptr, size_t heap_size,
|
||||
MPI_Comm comm = MPI_COMM_WORLD)
|
||||
: communicator_{heap_ptr, heap_size, comm} {
|
||||
heap_bases_.resize(communicator_.num_pes());
|
||||
for (auto& base : heap_bases_) {
|
||||
base = nullptr;
|
||||
|
||||
@@ -54,6 +54,10 @@ class SymmetricHeap {
|
||||
using RemoteHeapInfoType = RemoteHeapInfo<CommunicatorMPI>;
|
||||
|
||||
public:
|
||||
SymmetricHeap(MPI_Comm comm = MPI_COMM_WORLD)
|
||||
: remote_heap_info_{single_heap_.get_base_ptr(),
|
||||
single_heap_.get_size(),
|
||||
comm} {}
|
||||
/**
|
||||
* @brief Allocates heap memory and returns ptr to caller
|
||||
*
|
||||
@@ -120,8 +124,7 @@ class SymmetricHeap {
|
||||
/**
|
||||
* @brief Implementation of remote heaps
|
||||
*/
|
||||
RemoteHeapInfoType remote_heap_info_{single_heap_.get_base_ptr(),
|
||||
single_heap_.get_size()};
|
||||
RemoteHeapInfoType remote_heap_info_{};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -165,7 +165,7 @@ class WindowInfo {
|
||||
/**
|
||||
* @brief MPI Communicator
|
||||
*/
|
||||
MPI_Comm comm_{MPI_COMM_WORLD};
|
||||
MPI_Comm comm_{};
|
||||
|
||||
/**
|
||||
* @brief Owning pointer to MPI_Win
|
||||
|
||||
@@ -45,7 +45,7 @@ namespace rocshmem {
|
||||
extern rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
ROBackend::ROBackend(MPI_Comm comm)
|
||||
: Backend() {
|
||||
: Backend(comm) {
|
||||
type = BackendType::RO_BACKEND;
|
||||
|
||||
if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) {
|
||||
@@ -78,7 +78,7 @@ ROBackend::ROBackend(MPI_Comm comm)
|
||||
|
||||
queue_ = Queue(maximum_num_contexts_, queue_size_);
|
||||
|
||||
transport_ = new MPITransport(comm, &queue_);
|
||||
transport_ = new MPITransport(backend_comm, &queue_);
|
||||
num_pes = transport_->getNumPes();
|
||||
my_pe = transport_->getMyPe();
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ MPITransport::MPITransport(MPI_Comm comm, Queue* q)
|
||||
std::cerr << "MPI_THREAD_MULTIPLE support disabled.\n";
|
||||
}
|
||||
}
|
||||
if (comm == MPI_COMM_NULL) comm = MPI_COMM_WORLD;
|
||||
assert(comm != MPI_COMM_NULL);
|
||||
|
||||
NET_CHECK(MPI_Comm_dup(comm, &ro_net_comm_world));
|
||||
NET_CHECK(MPI_Comm_size(ro_net_comm_world, &num_pes));
|
||||
|
||||
+15
-4
@@ -197,13 +197,24 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_my_pe() {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_rank();
|
||||
if(backend == nullptr) {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_rank();
|
||||
}
|
||||
else
|
||||
{
|
||||
return backend->getMyPE();
|
||||
}
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_n_pes() {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_nprocs();
|
||||
if(backend == nullptr) {
|
||||
MPIInitSingleton *s = s->GetInstance();
|
||||
return s->get_nprocs();
|
||||
}
|
||||
else {
|
||||
return backend->getNumPEs();
|
||||
}
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ void *rocshmem_malloc(size_t size) {
|
||||
|
||||
Référencer dans un nouveau ticket
Bloquer un utilisateur