RO/collectives: add linear algorithms using RPut/Rget (#58)
* RO/collectives: add linear algorithms using RPut/Rget
- make broadcast, alltoall and fcollect use a simple linear algorithm
based on MPI_Rput/MPI_Rget, without blocking during execution
- remove the to_all interfaces, since they have been deprecated.
- remove the active-set interfaces, since they have been removed from
rocSHMEM
* avoid notification after barrier
Co-authored-by: Avinash Kethineedi <avinash.kethineedi@amd.com>
* disable allocation of ata_buffer
A temporary buffer of 128 MB was allocated when creating a team. In
previous versions of the code, that buffer was used by some collective
operations. This is no longer the case, so the buffer is not allocated
for now. The element itself is not removed from the structure, since we
might need it again in future versions.
---------
Co-authored-by: Avinash Kethineedi <avinash.kethineedi@amd.com>
[ROCm/rocshmem commit: 908bd5bda3]
This commit is contained in:
@@ -37,11 +37,9 @@ enum ro_net_cmds {
|
||||
RO_NET_FENCE,
|
||||
RO_NET_QUIET,
|
||||
RO_NET_FINALIZE,
|
||||
RO_NET_TO_ALL,
|
||||
RO_NET_TEAM_REDUCE,
|
||||
RO_NET_SYNC,
|
||||
RO_NET_BARRIER_ALL,
|
||||
RO_NET_BROADCAST,
|
||||
RO_NET_TEAM_BROADCAST,
|
||||
RO_NET_ALLTOALL,
|
||||
RO_NET_FCOLLECT,
|
||||
|
||||
@@ -610,26 +610,11 @@ __device__ void build_queue_element(
|
||||
queue_element->ol2.pWrk = pWrk;
|
||||
queue_element->datatype = datatype;
|
||||
}
|
||||
if (type == RO_NET_TO_ALL) {
|
||||
queue_element->logPE_stride = logPE_stride;
|
||||
queue_element->PE_size = PE_size;
|
||||
queue_element->ol2.pWrk = pWrk;
|
||||
queue_element->pSync = pSync;
|
||||
queue_element->op = op;
|
||||
queue_element->datatype = datatype;
|
||||
}
|
||||
if (type == RO_NET_TEAM_REDUCE) {
|
||||
queue_element->op = op;
|
||||
queue_element->datatype = datatype;
|
||||
queue_element->team_comm = team_comm;
|
||||
}
|
||||
if (type == RO_NET_BROADCAST) {
|
||||
queue_element->logPE_stride = logPE_stride;
|
||||
queue_element->PE_size = PE_size;
|
||||
queue_element->pSync = pSync;
|
||||
queue_element->PE_root = PE_root;
|
||||
queue_element->datatype = datatype;
|
||||
}
|
||||
if (type == RO_NET_TEAM_BROADCAST) {
|
||||
queue_element->PE_root = PE_root;
|
||||
queue_element->datatype = datatype;
|
||||
|
||||
@@ -75,11 +75,6 @@ class ROContext : public Context {
|
||||
template <typename T>
|
||||
__device__ T g(const T *source, int pe);
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__device__ void to_all(T *dest, const T *source, int nreduce, int PE_start,
|
||||
int logPE_stride, int PE_size, T *pWrk,
|
||||
long *pSync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__device__ int reduce(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nreduce);
|
||||
@@ -145,42 +140,10 @@ class ROContext : public Context {
|
||||
__device__ void alltoall(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void alltoall_broadcast(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void alltoall_mpi(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void alltoall_gcen(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void alltoall_gcen2(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void fcollect(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void fcollect_broadcast(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void fcollect_mpi(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void fcollect_gcen(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems);
|
||||
|
||||
template <typename T>
|
||||
__device__ void fcollect_gcen2(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems);
|
||||
|
||||
__device__ void putmem_wg(void *dest, const void *source, size_t nelems,
|
||||
int pe);
|
||||
|
||||
|
||||
@@ -138,15 +138,6 @@ class ROHostContext : public Context {
|
||||
__host__ void broadcast(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nelems, int pe_root);
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__host__ void to_all(T *dest, const T *source, int nreduce, int pe_start,
|
||||
int log_pe_stride, int pe_size, T *p_wrk,
|
||||
long *p_sync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__host__ void to_all(rocshmem_team_t team, T *dest, const T *source,
|
||||
int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__host__ void wait_until(T *ivars, int cmp, T val);
|
||||
|
||||
|
||||
@@ -127,23 +127,6 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__device__ void ROContext::to_all(T *dest, const T *source, int nreduce,
|
||||
int PE_start, int logPE_stride, int PE_size,
|
||||
T *pWrk, long *pSync) {
|
||||
if (!is_thread_zero_in_block()) {
|
||||
__syncthreads();
|
||||
return;
|
||||
}
|
||||
|
||||
build_queue_element(RO_NET_TO_ALL, dest, const_cast<T *>(source), nreduce,
|
||||
PE_start, logPE_stride, PE_size, 0, pWrk, pSync,
|
||||
(MPI_Comm)NULL, ro_net_win_id, block_handle, true,
|
||||
get_status_flag(), Op, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ROContext::put(T *dest, const T *source, size_t nelems,
|
||||
int pe) {
|
||||
@@ -316,24 +299,6 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest,
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ROContext::broadcast(T *dest, const T *source, int nelems,
|
||||
int pe_root, int pe_start,
|
||||
int log_pe_stride, int pe_size,
|
||||
long *p_sync) {
|
||||
if (!is_thread_zero_in_block()) {
|
||||
__syncthreads();
|
||||
return;
|
||||
}
|
||||
|
||||
build_queue_element(RO_NET_BROADCAST, dest, const_cast<T *>(source), nelems,
|
||||
pe_start, log_pe_stride, pe_size, pe_root, nullptr,
|
||||
p_sync, (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
|
||||
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ROContext::alltoall(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems) {
|
||||
|
||||
@@ -122,24 +122,6 @@ __host__ void ROHostContext::broadcast(rocshmem_team_t team, T *dest,
|
||||
host_interface->broadcast<T>(team, dest, source, nelems, pe_root);
|
||||
}
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__host__ void ROHostContext::to_all(T *dest, const T *source, int nreduce,
|
||||
int pe_start, int log_pe_stride,
|
||||
int pe_size, T *p_wrk, long *p_sync) {
|
||||
DPRINTF("Function: gpu_ib_host_to_all\n");
|
||||
|
||||
host_interface->to_all<T, Op>(dest, source, nreduce, pe_start, log_pe_stride,
|
||||
pe_size, p_wrk, p_sync);
|
||||
}
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__host__ void ROHostContext::to_all(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
DPRINTF("Function: Team-based ro_net_host_to_all\n");
|
||||
|
||||
host_interface->to_all<T, Op>(team, dest, source, nreduce);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ void ROHostContext::wait_until(T *ivars, int cmp, T val) {
|
||||
host_interface->wait_until<T>(ivars, cmp, val, context_window_info);
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
*****************************************************************************/
|
||||
|
||||
#include "mpi_transport.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
@@ -170,21 +169,6 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
break;
|
||||
case RO_NET_TO_ALL:
|
||||
reduction(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.PE, next_element.logPE_stride,
|
||||
next_element.PE_size, next_element.ol2.pWrk, next_element.pSync,
|
||||
static_cast<ROCSHMEM_OP>(next_element.op),
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.status, true);
|
||||
DPRINTF(
|
||||
"Received FLOAT_SUM_TO_ALL dst %p src %p size %lu "
|
||||
"PE_start %d, logPE_stride %d, PE_size %d, pWrk %p, pSync %p\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.logPE_stride, next_element.PE_size,
|
||||
next_element.ol2.pWrk, next_element.pSync);
|
||||
break;
|
||||
case RO_NET_TEAM_BROADCAST:
|
||||
team_broadcast(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.ro_net_win_id, queue_idx,
|
||||
@@ -197,20 +181,6 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm, next_element.PE_root);
|
||||
break;
|
||||
case RO_NET_BROADCAST:
|
||||
broadcast(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.ro_net_win_id, next_element.PE, queue_idx,
|
||||
next_element.PE, next_element.logPE_stride,
|
||||
next_element.PE_size, next_element.PE_root, next_element.pSync,
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.status, true);
|
||||
DPRINTF(
|
||||
"Received BROADCAST dst %p src %p size %lu PE_start %d, "
|
||||
"logPE_stride %d, PE_size %d, PE_root %d, pSync %p\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.logPE_stride, next_element.PE_size,
|
||||
next_element.PE_root, next_element.pSync);
|
||||
break;
|
||||
case RO_NET_ALLTOALL:
|
||||
alltoall(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
|
||||
@@ -294,41 +264,6 @@ void MPITransport::createNewTeam(ROBackend *backend, Team *parent_team,
|
||||
*new_team = get_external_team(new_team_obj);
|
||||
}
|
||||
|
||||
MPI_Comm MPITransport::createComm(int start, int stride, int size) {
|
||||
CommKey key(start, stride, size);
|
||||
auto it{comm_map.find(key)};
|
||||
if (it != comm_map.end()) {
|
||||
DPRINTF("Using cached communicator\n");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int world_size{};
|
||||
NET_CHECK(MPI_Comm_size(ro_net_comm_world, &world_size));
|
||||
|
||||
MPI_Comm comm{};
|
||||
if (start == 0 && stride == 1 && size == world_size) {
|
||||
NET_CHECK(MPI_Comm_dup(ro_net_comm_world, &comm));
|
||||
} else {
|
||||
MPI_Group world_group{};
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_group));
|
||||
|
||||
std::vector<int> group_ranks(size);
|
||||
group_ranks[0] = start;
|
||||
for (int i{1}; i < size; i++) {
|
||||
group_ranks[i] = group_ranks[i - 1] + stride;
|
||||
}
|
||||
|
||||
MPI_Group new_group{};
|
||||
NET_CHECK(MPI_Group_incl(world_group, size, group_ranks.data(), &new_group));
|
||||
NET_CHECK(MPI_Comm_create_group(ro_net_comm_world, new_group, 0, &comm));
|
||||
}
|
||||
|
||||
comm_map.insert(std::pair<CommKey, MPI_Comm>(key, comm));
|
||||
DPRINTF("Creating new communicator\n");
|
||||
|
||||
return comm;
|
||||
}
|
||||
|
||||
void MPITransport::global_exit(int status) {
|
||||
MPI_Abort(ro_net_comm_world, status);
|
||||
}
|
||||
@@ -396,53 +331,6 @@ static MPI_Datatype convertType(ro_net_types type) {
|
||||
}
|
||||
}
|
||||
|
||||
void MPITransport::reduction(void *dst, void *src, int size, int pe,
|
||||
int win_id, int contextId, int start,
|
||||
int logPstride, int sizePE, void *pWrk,
|
||||
long *pSync, ROCSHMEM_OP op, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
MPI_Request request{};
|
||||
MPI_Op mpi_op{get_mpi_op(op)};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};
|
||||
|
||||
if (dst == src) {
|
||||
NET_CHECK(MPI_Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm,
|
||||
&request));
|
||||
} else {
|
||||
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
|
||||
}
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::broadcast(void *dst, void *src, int size, int pe,
|
||||
int win_id, int contextId, int start,
|
||||
int logPstride, int sizePE, int root, long *pSync,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};
|
||||
|
||||
int new_rank{};
|
||||
MPI_Comm_rank(comm, &new_rank);
|
||||
|
||||
void *data{nullptr};
|
||||
if (new_rank == root) {
|
||||
data = src;
|
||||
} else {
|
||||
data = dst;
|
||||
}
|
||||
|
||||
MPI_Request request{};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char* status,
|
||||
@@ -469,570 +357,136 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, int root,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
MPI_Comm comm{team};
|
||||
int new_rank{};
|
||||
MPI_Comm_rank(comm, &new_rank);
|
||||
void *data{nullptr};
|
||||
if (new_rank == root) {
|
||||
data = src;
|
||||
} else {
|
||||
data = dst;
|
||||
}
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
|
||||
MPI_Request req;
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
if (rank != root){
|
||||
NET_CHECK(MPI_Rget(reinterpret_cast<char *>(dst), size, mpi_type, world_ranks[root],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(src)),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &req));
|
||||
|
||||
outstanding[contextId]++;
|
||||
requests.push_back({req, {nullptr, contextId, false}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
barrier(contextId, nullptr, false, comm);
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int pe_size{};
|
||||
NET_CHECK(MPI_Comm_size(team, &pe_size));
|
||||
|
||||
int type_size{};
|
||||
NET_CHECK(MPI_Type_size(convertType(type), &type_size));
|
||||
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size{(pe_size + num_clust - 1) / num_clust};
|
||||
|
||||
#ifdef A2A_HEURISTICS
|
||||
if ((pe_size >= 8 || type_size * size < 2048) &&
|
||||
num_clust * clust_size == pe_size) {
|
||||
return alltoall_gcen(dst, src, size, win_id, contextId, team, ata_buffptr,
|
||||
type, status, blocking);
|
||||
} else if (size <= 512) {
|
||||
#endif // A2A_HEURISTICS
|
||||
return alltoall_mpi(dst, src, size, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
#ifdef A2A_HEURISTICS
|
||||
} else {
|
||||
return alltoall_broadcast(dst, src, size, win_id, contextId, team,
|
||||
ata_buffptr, type, status, blocking);
|
||||
}
|
||||
#endif // A2A_HEURISTICS
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
|
||||
int win_id, int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
MPI_Comm comm{team};
|
||||
int new_rank{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
int pe_size{};
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp{};
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
MPI_Group world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
int grp_size{};
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
|
||||
for (int i{0}; i < grp_size; i++) ranks[i] = i;
|
||||
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size{};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
int type_size{};
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
std::vector<MPI_Request> pe_req(pe_size);
|
||||
|
||||
for (int i{0}; i < pe_size; ++i) {
|
||||
int src_offset{i * type_size * size};
|
||||
int dst_offset{new_rank * type_size * size};
|
||||
if (dst == src) {
|
||||
fprintf(stderr, "IN_PLACE option not support for alltoall in the RO rocSHMEM conduit\n");
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> pe_req(pe_size);
|
||||
for (int i = 0; i < pe_size; ++i) {
|
||||
int target = (rank + i) % pe_size;
|
||||
int src_offset = target * type_size * size;
|
||||
int dst_offset = rank * type_size * size;
|
||||
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src) + src_offset, size,
|
||||
mpi_type, world_ranks[i],
|
||||
bp->heap_window_info[win_id]->get_offset(
|
||||
reinterpret_cast<char *>(dst) + dst_offset),
|
||||
mpi_type, world_ranks[target],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + dst_offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
|
||||
&pe_req[i]));
|
||||
requests.push_back({pe_req[i], {nullptr, contextId, false}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
barrier(contextId, status, blocking, comm);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int new_rank{};
|
||||
NET_CHECK(MPI_Comm_rank(team, &new_rank));
|
||||
int pe_size{};
|
||||
NET_CHECK(MPI_Comm_size(team, &pe_size));
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Alltoall(src, size, mpi_type, dst, size, mpi_type, team));
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
int new_rank{};
|
||||
NET_CHECK(MPI_Comm_rank(team, &new_rank));
|
||||
int pe_size{};
|
||||
NET_CHECK(MPI_Comm_size(team, &pe_size));
|
||||
|
||||
MPI_Group grp{};
|
||||
NET_CHECK(MPI_Comm_group(team, &grp));
|
||||
MPI_Group world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
|
||||
int grp_size{};
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
|
||||
for (int i{0}; i < grp_size; i++) ranks[i] = i;
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size{};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size{(pe_size + num_clust - 1) / num_clust};
|
||||
assert(num_clust * clust_size == pe_size);
|
||||
int clust_id{new_rank / clust_size};
|
||||
|
||||
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
|
||||
fprintf(stderr, "Alltoall size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
|
||||
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> clust_req(pe_size);
|
||||
|
||||
// Step 1: Send data to PEs in cluster
|
||||
for (int i{0}; i < pe_size; ++i) {
|
||||
int src_offset{(new_rank % clust_size + (i / clust_size) * clust_size) *
|
||||
type_size * size};
|
||||
int dst_offset{i * type_size * size};
|
||||
NET_CHECK(MPI_Rget(
|
||||
reinterpret_cast<void *>(
|
||||
(reinterpret_cast<char *>(ata_buffptr) + dst_offset)),
|
||||
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(src) +
|
||||
src_offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
|
||||
&clust_req[i]));
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Waitall(pe_size, clust_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
// Step 2: Send final data to PEs outside cluster
|
||||
for (int i{0}; i < num_clust; ++i) {
|
||||
int src_offset{i * type_size * size * clust_size};
|
||||
int dst_offset{clust_id * type_size * size * clust_size};
|
||||
NET_CHECK(MPI_Put(
|
||||
reinterpret_cast<void *>(
|
||||
(reinterpret_cast<char *>(ata_buffptr) + src_offset)),
|
||||
size * clust_size, mpi_type,
|
||||
world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_offset(dst) + dst_offset,
|
||||
size * clust_size, mpi_type, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet.
|
||||
NET_CHECK(
|
||||
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
}
|
||||
|
||||
int stride{world_ranks[1] - world_ranks[0]};
|
||||
MPI_Comm comm_cluster{
|
||||
createComm(world_ranks[clust_id * clust_size], stride, clust_size)};
|
||||
MPI_Comm comm_ring{createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust)};
|
||||
|
||||
barrier(contextId, status, false, comm_cluster);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric alltoall with in-place blocking synchronization
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp, world_grp;
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
|
||||
int grp_size;
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
|
||||
for (int i = 0; i < grp_size; i++) ranks[i] = i;
|
||||
// Convert comm ranks to global ranks for rput
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size;
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
// Works when number of PEs divisible by root(PE_size)
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size = (pe_size + num_clust - 1) / num_clust;
|
||||
// TODO(bpotter) Allow any size of cluster
|
||||
assert(num_clust * clust_size == pe_size);
|
||||
int clust_id = new_rank / clust_size;
|
||||
|
||||
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
|
||||
fprintf(stderr, "Alltoall size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
|
||||
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> clust_req(pe_size);
|
||||
|
||||
// Step 1: Send data to PEs in cluster
|
||||
for (int i = 0; i < pe_size; ++i) {
|
||||
int src_offset = (new_rank % clust_size + (i / clust_size) * clust_size) *
|
||||
type_size * size;
|
||||
int dst_offset = i * type_size * size;
|
||||
NET_CHECK(MPI_Rget(reinterpret_cast<void *>(
|
||||
reinterpret_cast<char *>(ata_buffptr) + dst_offset),
|
||||
size, mpi_type,
|
||||
world_ranks[clust_id * clust_size + (i % clust_size)],
|
||||
bp->heap_window_info[win_id]->get_offset(
|
||||
reinterpret_cast<char *>(src) + src_offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
|
||||
&clust_req[i]));
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Waitall(pe_size, clust_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
// Now wait
|
||||
int stride = world_ranks[1] - world_ranks[0];
|
||||
MPI_Comm comm_cluster =
|
||||
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
|
||||
MPI_Barrier(comm_cluster);
|
||||
|
||||
// Step 2: Send final data to PEs outside cluster
|
||||
for (int i = 0; i < num_clust; ++i) {
|
||||
int src_offset = i * type_size * size * clust_size;
|
||||
int dst_offset = clust_id * type_size * size * clust_size;
|
||||
NET_CHECK(MPI_Put(
|
||||
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
|
||||
src_offset),
|
||||
size * clust_size, mpi_type,
|
||||
world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_offset(dst) + dst_offset,
|
||||
size * clust_size, mpi_type, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet.
|
||||
NET_CHECK(
|
||||
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
}
|
||||
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int pe_size, type_size;
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
// Currently GPU-centric algo only supports multiples of square root
|
||||
// TODO(bpotter) Allow any size of cluster
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size = (pe_size + num_clust - 1) / num_clust;
|
||||
|
||||
// In most cases the MPI implementation is optimal
|
||||
// But it crashes for > 512 messages
|
||||
if (size <= 512) {
|
||||
fcollect_mpi(dst, src, size, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
} else if (num_clust * clust_size == pe_size) {
|
||||
fcollect_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
} else {
|
||||
fcollect_broadcast(dst, src, size, win_id, contextId, team, ata_buffptr,
|
||||
type, status, blocking);
|
||||
}
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
|
||||
int win_id, int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// Broadcast implementation of fcollect
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
MPI_Comm comm{team};
|
||||
int rank{}, pe_size{};
|
||||
NET_CHECK(MPI_Comm_rank(comm, &rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp, world_grp;
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
MPI_Group grp{}, world_grp{};
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
|
||||
|
||||
int grp_size;
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
std::vector<int> ranks(pe_size);
|
||||
std::vector<int> world_ranks(pe_size);
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
for (int i = 0; i < pe_size; i++) ranks[i] = i;
|
||||
|
||||
for (int i = 0; i < grp_size; i++) ranks[i] = i;
|
||||
// Convert comm ranks to global ranks for rput
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size;
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
int type_size{};
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
if (dst == src) {
|
||||
fprintf(stderr, "IN_PLACE option not support for fcollect in the RO rocSHMEM conduit\n");
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> pe_req(pe_size);
|
||||
|
||||
// Put data to all PEs
|
||||
for (int i = 0; i < pe_size; ++i) {
|
||||
int dst_offset = new_rank * type_size * size;
|
||||
NET_CHECK(MPI_Rput(
|
||||
reinterpret_cast<char *>(src), size, mpi_type, world_ranks[i],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) +
|
||||
dst_offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i]));
|
||||
int target = (rank + i) % pe_size;
|
||||
int offset = rank * type_size * size;
|
||||
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src), size, mpi_type, world_ranks[target],
|
||||
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + offset),
|
||||
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i]));
|
||||
|
||||
requests.push_back({pe_req[i], {nullptr, contextId, false}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Now wait for completion
|
||||
barrier(contextId, status, blocking, comm);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
// MPI's implementation of fcollect
|
||||
int new_rank, pe_size;
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
NET_CHECK(MPI_Allgather(src, size, mpi_type, dst, size, mpi_type, comm));
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric implementation of fcollect
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp, world_grp;
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
|
||||
int grp_size;
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
|
||||
for (int i = 0; i < grp_size; i++) ranks[i] = i;
|
||||
// Convert comm ranks to global ranks for rput
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size;
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
// Works when number of PEs divisible by root(PE_size)
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size = (pe_size + num_clust - 1) / num_clust;
|
||||
// TODO(bpotter) Allow any size of cluster
|
||||
assert(num_clust * clust_size == pe_size);
|
||||
int clust_id = new_rank / clust_size;
|
||||
|
||||
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
|
||||
fprintf(stderr, "Fcollect size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
|
||||
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> clust_req(pe_size);
|
||||
|
||||
// Step 1: Send data to PEs in cluster
|
||||
for (int i = 0; i < clust_size; ++i) {
|
||||
int dst_offset = i * type_size * size;
|
||||
NET_CHECK(MPI_Rget(
|
||||
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
|
||||
dst_offset),
|
||||
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
|
||||
bp->heap_window_info[win_id]->get_offset(src), size, mpi_type,
|
||||
bp->heap_window_info[win_id]->get_win(), &clust_req[i]));
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Waitall(clust_size, clust_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
// Step 2: Send final data to PEs outside cluster
|
||||
for (int i = 0; i < num_clust; ++i) {
|
||||
int src_offset = i * type_size * size * clust_size;
|
||||
int dst_offset = clust_id * type_size * size * clust_size;
|
||||
NET_CHECK(MPI_Put(ata_buffptr, size * clust_size, mpi_type,
|
||||
world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_offset(
|
||||
reinterpret_cast<char *>(dst) + dst_offset),
|
||||
size * clust_size, mpi_type,
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet.
|
||||
NET_CHECK(
|
||||
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
}
|
||||
|
||||
int stride = world_ranks[1] - world_ranks[0];
|
||||
MPI_Comm comm_cluster =
|
||||
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(contextId, status, false, comm_cluster);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric implementation with in-place, blocking synchronization
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
|
||||
MPI_Datatype mpi_type = convertType(type);
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
|
||||
MPI_Group grp, world_grp;
|
||||
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
|
||||
NET_CHECK(MPI_Comm_group(comm, &grp));
|
||||
|
||||
int grp_size;
|
||||
NET_CHECK(MPI_Group_size(grp, &grp_size));
|
||||
|
||||
std::vector<int> ranks(grp_size);
|
||||
std::vector<int> world_ranks(grp_size);
|
||||
|
||||
for (int i = 0; i < grp_size; i++) ranks[i] = i;
|
||||
// Convert comm ranks to global ranks for rput
|
||||
NET_CHECK(
|
||||
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
|
||||
|
||||
int type_size;
|
||||
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
|
||||
|
||||
// Works when number of PEs divisible by root(PE_size)
|
||||
int num_clust = sqrt(pe_size);
|
||||
int clust_size = (pe_size + num_clust - 1) / num_clust;
|
||||
// TODO(bpotter) Allow any size of cluster
|
||||
assert(num_clust * clust_size == pe_size);
|
||||
int clust_id = new_rank / clust_size;
|
||||
|
||||
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
|
||||
fprintf(stderr, "Fcollect size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
|
||||
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
|
||||
abort();
|
||||
}
|
||||
|
||||
std::vector<MPI_Request> clust_req(pe_size);
|
||||
|
||||
// Step 1: Send data to PEs in cluster
|
||||
for (int i = 0; i < clust_size; ++i) {
|
||||
int dst_offset = i * type_size * size;
|
||||
NET_CHECK(MPI_Rget(
|
||||
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
|
||||
dst_offset),
|
||||
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
|
||||
bp->heap_window_info[win_id]->get_offset(src), size, mpi_type,
|
||||
bp->heap_window_info[win_id]->get_win(), &clust_req[i]));
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Waitall(clust_size, clust_req.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
int stride = world_ranks[1] - world_ranks[0];
|
||||
MPI_Comm comm_cluster =
|
||||
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
|
||||
MPI_Barrier(comm_cluster);
|
||||
|
||||
// Step 2: Send final data to PEs outside cluster
|
||||
for (int i = 0; i < num_clust; ++i) {
|
||||
int src_offset = i * type_size * size * clust_size;
|
||||
int dst_offset = clust_id * type_size * size * clust_size;
|
||||
NET_CHECK(MPI_Put(ata_buffptr, size * clust_size, mpi_type,
|
||||
world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_offset(
|
||||
reinterpret_cast<char *>(dst) + dst_offset),
|
||||
size * clust_size, mpi_type,
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Since MPI makes puts as complete as soon as the local buffer is free,
|
||||
// we need a flush to satisfy quiet.
|
||||
NET_CHECK(
|
||||
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
|
||||
bp->heap_window_info[win_id]->get_win()));
|
||||
}
|
||||
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
bool inline_data) {
|
||||
@@ -1136,7 +590,6 @@ void MPITransport::progress() {
|
||||
int outcount{};
|
||||
|
||||
auto uptr_req_arr {raw_requests()};
|
||||
|
||||
NET_CHECK(MPI_Testsome(incount, uptr_req_arr.get(), &outcount,
|
||||
testsome_indices.data(), MPI_STATUSES_IGNORE));
|
||||
|
||||
|
||||
@@ -54,21 +54,11 @@ public:
|
||||
void barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team) override;
|
||||
|
||||
void reduction(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, int start, int logPstride, int sizePE,
|
||||
void *pWrk, long *pSync, ROCSHMEM_OP op, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) override;
|
||||
|
||||
void broadcast(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, int PE_root,
|
||||
ro_net_types type, volatile char *status,
|
||||
@@ -78,44 +68,10 @@ public:
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void alltoall_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking);
|
||||
|
||||
void alltoall_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void alltoall_gcen(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void alltoall_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void fcollect(void *dst, void *src, int size, int win_id, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void fcollect_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking);
|
||||
|
||||
void fcollect_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void fcollect_gcen(void *dst, void *src, int size, int win_id, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void fcollect_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
bool inline_data = false) override;
|
||||
|
||||
@@ -34,7 +34,9 @@ ROTeam::ROTeam(Backend* backend, TeamInfo* team_info_wrt_parent,
|
||||
mpi_comm) {
|
||||
type = BackendType::RO_BACKEND;
|
||||
|
||||
ata_buffer = malloc(MAX_ATA_BUFF_SIZE);
|
||||
// Disable allocating ata_buffer for now. It is not
|
||||
// used at the moment, but might come back in future versions.
|
||||
ata_buffer = nullptr;
|
||||
}
|
||||
|
||||
ROTeam::~ROTeam() {
|
||||
|
||||
@@ -53,22 +53,11 @@ class Transport {
|
||||
virtual void barrier(int wg_id, volatile char *status, bool blocking,
|
||||
MPI_Comm team) = 0;
|
||||
|
||||
virtual void reduction(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
void *pWrk, long *pSync, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) = 0;
|
||||
|
||||
virtual void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) = 0;
|
||||
|
||||
virtual void broadcast(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type,
|
||||
volatile char *status, bool blocking) = 0;
|
||||
|
||||
virtual void team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int wg_id, MPI_Comm team, int PE_root,
|
||||
ro_net_types type, volatile char *status,
|
||||
|
||||
Fai riferimento in un nuovo problema
Block a user