RO/collectives: add linear algorithms using RPut/Rget (#58)

* RO/collectives: add linear algorithms using RPut/Rget

- make broadcast, alltoall and fcollect use a simple linear algorithm
  using MPI_Rput/MPI_Rget, without blocking during execution
- remove the to_all interfaces, since they have been deprecated.
- remove the active-set interfaces, since they have been removed from
  rocSHMEM

* avoid notification after barrier

Co-authored-by: Avinash Kethineedi <avinash.kethineedi@amd.com>

* disable allocation of ata_buffer

A temporary buffer of 128MB was allocated when creating a team. In
previous versions of the code, that buffer was used by some collective
operations. This is no longer the case, so the buffer is not allocated
for now. I am not removing the element itself from the
structure, since we might need it again in future versions.

---------

Co-authored-by: Avinash Kethineedi <avinash.kethineedi@amd.com>

[ROCm/rocshmem commit: 908bd5bda3]
This commit is contained in:
Edgar Gabriel
2025-03-21 12:49:39 -05:00
committato da GitHub
parent 1380f43156
commit 033253fbdf
10 ha cambiato i file con 80 aggiunte e 796 eliminazioni
@@ -37,11 +37,9 @@ enum ro_net_cmds {
RO_NET_FENCE,
RO_NET_QUIET,
RO_NET_FINALIZE,
RO_NET_TO_ALL,
RO_NET_TEAM_REDUCE,
RO_NET_SYNC,
RO_NET_BARRIER_ALL,
RO_NET_BROADCAST,
RO_NET_TEAM_BROADCAST,
RO_NET_ALLTOALL,
RO_NET_FCOLLECT,
@@ -610,26 +610,11 @@ __device__ void build_queue_element(
queue_element->ol2.pWrk = pWrk;
queue_element->datatype = datatype;
}
if (type == RO_NET_TO_ALL) {
queue_element->logPE_stride = logPE_stride;
queue_element->PE_size = PE_size;
queue_element->ol2.pWrk = pWrk;
queue_element->pSync = pSync;
queue_element->op = op;
queue_element->datatype = datatype;
}
if (type == RO_NET_TEAM_REDUCE) {
queue_element->op = op;
queue_element->datatype = datatype;
queue_element->team_comm = team_comm;
}
if (type == RO_NET_BROADCAST) {
queue_element->logPE_stride = logPE_stride;
queue_element->PE_size = PE_size;
queue_element->pSync = pSync;
queue_element->PE_root = PE_root;
queue_element->datatype = datatype;
}
if (type == RO_NET_TEAM_BROADCAST) {
queue_element->PE_root = PE_root;
queue_element->datatype = datatype;
@@ -75,11 +75,6 @@ class ROContext : public Context {
template <typename T>
__device__ T g(const T *source, int pe);
template <typename T, ROCSHMEM_OP Op>
__device__ void to_all(T *dest, const T *source, int nreduce, int PE_start,
int logPE_stride, int PE_size, T *pWrk,
long *pSync); // NOLINT(runtime/int)
template <typename T, ROCSHMEM_OP Op>
__device__ int reduce(rocshmem_team_t team, T *dest, const T *source,
int nreduce);
@@ -145,42 +140,10 @@ class ROContext : public Context {
__device__ void alltoall(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void alltoall_broadcast(rocshmem_team_t team, T *dest,
const T *source, int nelems);
template <typename T>
__device__ void alltoall_mpi(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void alltoall_gcen(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void alltoall_gcen2(rocshmem_team_t team, T *dest,
const T *source, int nelems);
template <typename T>
__device__ void fcollect(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void fcollect_broadcast(rocshmem_team_t team, T *dest,
const T *source, int nelems);
template <typename T>
__device__ void fcollect_mpi(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void fcollect_gcen(rocshmem_team_t team, T *dest, const T *source,
int nelems);
template <typename T>
__device__ void fcollect_gcen2(rocshmem_team_t team, T *dest,
const T *source, int nelems);
__device__ void putmem_wg(void *dest, const void *source, size_t nelems,
int pe);
@@ -138,15 +138,6 @@ class ROHostContext : public Context {
__host__ void broadcast(rocshmem_team_t team, T *dest, const T *source,
int nelems, int pe_root);
template <typename T, ROCSHMEM_OP Op>
__host__ void to_all(T *dest, const T *source, int nreduce, int pe_start,
int log_pe_stride, int pe_size, T *p_wrk,
long *p_sync); // NOLINT(runtime/int)
template <typename T, ROCSHMEM_OP Op>
__host__ void to_all(rocshmem_team_t team, T *dest, const T *source,
int nreduce);
template <typename T>
__host__ void wait_until(T *ivars, int cmp, T val);
@@ -127,23 +127,6 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
return ROCSHMEM_SUCCESS;
}
/**
 * Active-set reduction, called by every thread of a block.
 *
 * Thread zero of the block calls build_queue_element() with RO_NET_TO_ALL,
 * passing the reduction operands, the active-set description
 * (PE_start / logPE_stride / PE_size), the work and sync arrays, the
 * operator Op and the element type derived from T. All threads of the
 * block then meet at a single block-wide barrier before returning.
 *
 * The barrier is deliberately placed outside the thread-zero branch so that
 * every thread reaches the same __syncthreads() call instead of pairing
 * barriers issued from divergent code paths.
 */
template <typename T, ROCSHMEM_OP Op>
__device__ void ROContext::to_all(T *dest, const T *source, int nreduce,
                                  int PE_start, int logPE_stride, int PE_size,
                                  T *pWrk, long *pSync) {
  if (is_thread_zero_in_block()) {
    // No team communicator is involved for active-set collectives, hence the
    // NULL MPI_Comm; PE_root is unused for reductions and passed as 0.
    build_queue_element(RO_NET_TO_ALL, dest, const_cast<T *>(source), nreduce,
                        PE_start, logPE_stride, PE_size, 0, pWrk, pSync,
                        (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
                        get_status_flag(), Op, GetROType<T>::Type);
  }

  // Single convergence point for the whole block.
  __syncthreads();
}
template <typename T>
__device__ void ROContext::put(T *dest, const T *source, size_t nelems,
int pe) {
@@ -316,24 +299,6 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest,
__syncthreads();
}
/**
 * Active-set broadcast, called by every thread of a block.
 *
 * Thread zero of the block calls build_queue_element() with
 * RO_NET_BROADCAST, passing the buffers, the element count, the active-set
 * description (pe_start / log_pe_stride / pe_size), the root PE and the
 * sync array. All threads of the block then meet at a single block-wide
 * barrier before returning.
 *
 * The barrier is deliberately placed outside the thread-zero branch so that
 * every thread reaches the same __syncthreads() call instead of pairing
 * barriers issued from divergent code paths.
 */
template <typename T>
__device__ void ROContext::broadcast(T *dest, const T *source, int nelems,
                                     int pe_root, int pe_start,
                                     int log_pe_stride, int pe_size,
                                     long *p_sync) {
  if (is_thread_zero_in_block()) {
    // A broadcast needs neither a work array nor a reduction operator:
    // nullptr is passed for pWrk and ROCSHMEM_SUM only fills the operator
    // slot of build_queue_element().
    build_queue_element(RO_NET_BROADCAST, dest, const_cast<T *>(source), nelems,
                        pe_start, log_pe_stride, pe_size, pe_root, nullptr,
                        p_sync, (MPI_Comm)NULL, ro_net_win_id, block_handle,
                        true, get_status_flag(), ROCSHMEM_SUM,
                        GetROType<T>::Type);
  }

  // Single convergence point for the whole block.
  __syncthreads();
}
template <typename T>
__device__ void ROContext::alltoall(rocshmem_team_t team, T *dest,
const T *source, int nelems) {
@@ -122,24 +122,6 @@ __host__ void ROHostContext::broadcast(rocshmem_team_t team, T *dest,
host_interface->broadcast<T>(team, dest, source, nelems, pe_root);
}
/**
 * Host-side active-set reduction.
 *
 * Logs the call and forwards every argument unchanged to
 * host_interface->to_all<T, Op>().
 */
template <typename T, ROCSHMEM_OP Op>
__host__ void ROHostContext::to_all(T *dest, const T *source, int nreduce,
                                    int pe_start, int log_pe_stride,
                                    int pe_size, T *p_wrk, long *p_sync) {
  // The message previously named the gpu_ib backend; this is the RO host
  // context, matching the wording used by the team-based overload.
  DPRINTF("Function: ro_net_host_to_all\n");

  host_interface->to_all<T, Op>(dest, source, nreduce, pe_start, log_pe_stride,
                                pe_size, p_wrk, p_sync);
}
// Host-side team-based reduction: logs the call and forwards all arguments
// unchanged to host_interface->to_all<T, Op>().
template <typename T, ROCSHMEM_OP Op>
__host__ void ROHostContext::to_all(rocshmem_team_t team, T *dest,
const T *source, int nreduce) {
DPRINTF("Function: Team-based ro_net_host_to_all\n");
host_interface->to_all<T, Op>(team, dest, source, nreduce);
}
template <typename T>
__host__ void ROHostContext::wait_until(T *ivars, int cmp, T val) {
host_interface->wait_until<T>(ivars, cmp, val, context_window_info);
@@ -21,7 +21,6 @@
*****************************************************************************/
#include "mpi_transport.hpp"
#include <algorithm>
#include <functional>
#include <utility>
@@ -170,21 +169,6 @@ void MPITransport::submitRequestsToMPI() {
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
break;
case RO_NET_TO_ALL:
reduction(next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.PE, next_element.logPE_stride,
next_element.PE_size, next_element.ol2.pWrk, next_element.pSync,
static_cast<ROCSHMEM_OP>(next_element.op),
static_cast<ro_net_types>(next_element.datatype),
next_element.status, true);
DPRINTF(
"Received FLOAT_SUM_TO_ALL dst %p src %p size %lu "
"PE_start %d, logPE_stride %d, PE_size %d, pWrk %p, pSync %p\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.logPE_stride, next_element.PE_size,
next_element.ol2.pWrk, next_element.pSync);
break;
case RO_NET_TEAM_BROADCAST:
team_broadcast(next_element.dst, next_element.src, next_element.ol1.size,
next_element.ro_net_win_id, queue_idx,
@@ -197,20 +181,6 @@ void MPITransport::submitRequestsToMPI() {
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm, next_element.PE_root);
break;
case RO_NET_BROADCAST:
broadcast(next_element.dst, next_element.src, next_element.ol1.size,
next_element.ro_net_win_id, next_element.PE, queue_idx,
next_element.PE, next_element.logPE_stride,
next_element.PE_size, next_element.PE_root, next_element.pSync,
static_cast<ro_net_types>(next_element.datatype),
next_element.status, true);
DPRINTF(
"Received BROADCAST dst %p src %p size %lu PE_start %d, "
"logPE_stride %d, PE_size %d, PE_root %d, pSync %p\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.logPE_stride, next_element.PE_size,
next_element.PE_root, next_element.pSync);
break;
case RO_NET_ALLTOALL:
alltoall(next_element.dst, next_element.src, next_element.ol1.size,
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
@@ -294,41 +264,6 @@ void MPITransport::createNewTeam(ROBackend *backend, Team *parent_team,
*new_team = get_external_team(new_team_obj);
}
/**
 * Returns a communicator covering the ranks
 * start, start + stride, ..., start + (size - 1) * stride of
 * ro_net_comm_world.
 *
 * Communicators are cached in comm_map, keyed by (start, stride, size); a
 * cached communicator is returned when one exists. When the triple describes
 * the whole of ro_net_comm_world, the world communicator is duplicated;
 * otherwise a group with the selected ranks is built and a communicator is
 * created from it.
 *
 * @param start  first rank (in ro_net_comm_world) of the set
 * @param stride distance between consecutive ranks of the set
 * @param size   number of ranks in the set
 * @return the cached or newly created communicator
 */
MPI_Comm MPITransport::createComm(int start, int stride, int size) {
  CommKey key(start, stride, size);

  auto it{comm_map.find(key)};
  if (it != comm_map.end()) {
    DPRINTF("Using cached communicator\n");
    return it->second;
  }

  int world_size{};
  NET_CHECK(MPI_Comm_size(ro_net_comm_world, &world_size));

  MPI_Comm comm{};
  if (start == 0 && stride == 1 && size == world_size) {
    NET_CHECK(MPI_Comm_dup(ro_net_comm_world, &comm));
  } else {
    MPI_Group world_group{};
    NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_group));

    // Filling by index (rather than seeding element 0 first) avoids writing
    // into an empty vector when size is 0.
    std::vector<int> group_ranks(size);
    for (int i{0}; i < size; i++) {
      group_ranks[i] = start + i * stride;
    }

    MPI_Group new_group{};
    NET_CHECK(
        MPI_Group_incl(world_group, size, group_ranks.data(), &new_group));
    NET_CHECK(MPI_Comm_create_group(ro_net_comm_world, new_group, 0, &comm));

    // The communicator keeps its own reference to the group; release the
    // local handles so they are not leaked on every cache miss.
    if (new_group != MPI_GROUP_EMPTY) {
      NET_CHECK(MPI_Group_free(&new_group));
    }
    NET_CHECK(MPI_Group_free(&world_group));
  }

  comm_map.insert(std::pair<CommKey, MPI_Comm>(key, comm));
  DPRINTF("Creating new communicator\n");

  return comm;
}
// Terminates the job by calling MPI_Abort on ro_net_comm_world with `status`
// as the error code.
// NOTE(review): the return value of MPI_Abort is not checked here.
void MPITransport::global_exit(int status) {
MPI_Abort(ro_net_comm_world, status);
}
@@ -396,53 +331,6 @@ static MPI_Datatype convertType(ro_net_types type) {
}
}
void MPITransport::reduction(void *dst, void *src, int size, int pe,
int win_id, int contextId, int start,
int logPstride, int sizePE, void *pWrk,
long *pSync, ROCSHMEM_OP op, ro_net_types type,
volatile char *status, bool blocking) {
MPI_Request request{};
MPI_Op mpi_op{get_mpi_op(op)};
MPI_Datatype mpi_type{convertType(type)};
MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};
if (dst == src) {
NET_CHECK(MPI_Iallreduce(MPI_IN_PLACE, dst, size, mpi_type, mpi_op, comm,
&request));
} else {
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
}
requests.push_back({request, {status, contextId, blocking}});
outstanding[contextId]++;
}
/**
 * Starts a non-blocking broadcast over an active set of PEs.
 *
 * The active set (start, 1 << logPstride, sizePE) is mapped to a communicator
 * through createComm(). The root passes `src` to MPI_Ibcast, every other rank
 * passes `dst`. The resulting request is appended to `requests` together with
 * the caller's status flag, and the outstanding-operation counter of
 * `contextId` is incremented.
 *
 * pe, win_id and pSync are accepted for interface compatibility but are not
 * used by this implementation.
 */
void MPITransport::broadcast(void *dst, void *src, int size, int pe,
                             int win_id, int contextId, int start,
                             int logPstride, int sizePE, int root, long *pSync,
                             ro_net_types type, volatile char *status,
                             bool blocking) {
  MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};

  // Checked like every other MPI call in this file; previously the return
  // code of MPI_Comm_rank was silently dropped.
  int new_rank{};
  NET_CHECK(MPI_Comm_rank(comm, &new_rank));

  // MPI_Ibcast uses one buffer argument: the payload on the root, the
  // receive buffer everywhere else.
  void *data{(new_rank == root) ? src : dst};

  MPI_Request request{};
  MPI_Datatype mpi_type{convertType(type)};
  NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));

  requests.push_back({request, {status, contextId, blocking}});
  outstanding[contextId]++;
}
void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char* status,
@@ -469,570 +357,136 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, int root,
ro_net_types type, volatile char *status,
bool blocking) {
auto *bp{backend_proxy->get()};
MPI_Comm comm{team};
int new_rank{};
MPI_Comm_rank(comm, &new_rank);
void *data{nullptr};
if (new_rank == root) {
data = src;
} else {
data = dst;
}
int rank{}, pe_size{};
NET_CHECK(MPI_Comm_rank(comm, &rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp{}, world_grp{};
NET_CHECK(MPI_Comm_group(comm, &grp));
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
std::vector<int> ranks(pe_size);
std::vector<int> world_ranks(pe_size);
for (int i = 0; i < pe_size; i++) ranks[i] = i;
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
MPI_Datatype mpi_type{convertType(type)};
MPI_Request request{};
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
MPI_Request req;
requests.push_back({request, {status, contextId, blocking}});
if (rank != root){
NET_CHECK(MPI_Rget(reinterpret_cast<char *>(dst), size, mpi_type, world_ranks[root],
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(src)),
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &req));
outstanding[contextId]++;
requests.push_back({req, {nullptr, contextId, false}});
outstanding[contextId]++;
}
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
barrier(contextId, nullptr, false, comm);
quiet(contextId, status);
}
// Entry point for the team alltoall: picks one of the alltoall_*
// implementations and forwards the request to it.
//
// When A2A_HEURISTICS is not defined, the preprocessor leaves only the
// alltoall_mpi call, so every request takes that path. When it is defined:
//   - alltoall_gcen is used if (pe_size >= 8 or type_size * size < 2048) and
//     pe_size splits exactly into num_clust clusters of clust_size PEs,
//   - otherwise alltoall_mpi is used if size <= 512,
//   - otherwise alltoall_broadcast is used.
void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
int pe_size{};
NET_CHECK(MPI_Comm_size(team, &pe_size));
int type_size{};
NET_CHECK(MPI_Type_size(convertType(type), &type_size));
// num_clust is sqrt(pe_size) truncated to int; clust_size is
// pe_size / num_clust rounded up.
int num_clust = sqrt(pe_size);
int clust_size{(pe_size + num_clust - 1) / num_clust};
#ifdef A2A_HEURISTICS
if ((pe_size >= 8 || type_size * size < 2048) &&
num_clust * clust_size == pe_size) {
return alltoall_gcen(dst, src, size, win_id, contextId, team, ata_buffptr,
type, status, blocking);
} else if (size <= 512) {
#endif  // A2A_HEURISTICS
// Unconditional when A2A_HEURISTICS is undefined; the `else if` branch body
// otherwise.
return alltoall_mpi(dst, src, size, contextId, team, ata_buffptr, type,
status, blocking);
#ifdef A2A_HEURISTICS
} else {
return alltoall_broadcast(dst, src, size, win_id, contextId, team,
ata_buffptr, type, status, blocking);
}
#endif  // A2A_HEURISTICS
}
void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
int win_id, int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
auto *bp{backend_proxy->get()};
MPI_Comm comm{team};
int new_rank{};
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
int pe_size{};
int rank{}, pe_size{};
NET_CHECK(MPI_Comm_rank(comm, &rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp{};
MPI_Group grp{}, world_grp{};
NET_CHECK(MPI_Comm_group(comm, &grp));
MPI_Group world_grp{};
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
int grp_size{};
std::vector<int> ranks(pe_size);
std::vector<int> world_ranks(pe_size);
for (int i = 0; i < pe_size; i++) ranks[i] = i;
NET_CHECK(MPI_Group_size(grp, &grp_size));
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i{0}; i < grp_size; i++) ranks[i] = i;
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
int type_size{};
MPI_Datatype mpi_type{convertType(type)};
int type_size{};
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
std::vector<MPI_Request> pe_req(pe_size);
for (int i{0}; i < pe_size; ++i) {
int src_offset{i * type_size * size};
int dst_offset{new_rank * type_size * size};
if (dst == src) {
fprintf(stderr, "IN_PLACE option not support for alltoall in the RO rocSHMEM conduit\n");
abort();
}
std::vector<MPI_Request> pe_req(pe_size);
for (int i = 0; i < pe_size; ++i) {
int target = (rank + i) % pe_size;
int src_offset = target * type_size * size;
int dst_offset = rank * type_size * size;
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src) + src_offset, size,
mpi_type, world_ranks[i],
bp->heap_window_info[win_id]->get_offset(
reinterpret_cast<char *>(dst) + dst_offset),
mpi_type, world_ranks[target],
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + dst_offset),
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
&pe_req[i]));
requests.push_back({pe_req[i], {nullptr, contextId, false}});
outstanding[contextId]++;
}
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
barrier(contextId, status, blocking, comm);
}
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
int new_rank{};
NET_CHECK(MPI_Comm_rank(team, &new_rank));
int pe_size{};
NET_CHECK(MPI_Comm_size(team, &pe_size));
MPI_Datatype mpi_type{convertType(type)};
NET_CHECK(MPI_Alltoall(src, size, mpi_type, dst, size, mpi_type, team));
quiet(contextId, status);
}
// Two-step, cluster-based alltoall that stages data in ata_buffptr.
//
// The team is split into num_clust clusters of clust_size consecutive team
// ranks (num_clust = truncated sqrt(pe_size)); the split must be exact.
// Step 1 fetches blocks from the PEs of the caller's own cluster into
// ata_buffptr with MPI_Rget. Step 2 forwards one contiguous chunk of
// clust_size blocks to the PE holding the same intra-cluster position in
// every cluster with MPI_Put. Completion is signalled through two barriers,
// one on the cluster communicator and one on the cross-cluster ("ring")
// communicator.
//
// Aborts if the total payload (type_size * size * pe_size bytes) exceeds
// MAX_ATA_BUFF_SIZE.
// NOTE(review): grp and world_grp are not freed in this function.
void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
auto *bp{backend_proxy->get()};
int new_rank{};
NET_CHECK(MPI_Comm_rank(team, &new_rank));
int pe_size{};
NET_CHECK(MPI_Comm_size(team, &pe_size));
MPI_Group grp{};
NET_CHECK(MPI_Comm_group(team, &grp));
MPI_Group world_grp{};
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
int grp_size{};
NET_CHECK(MPI_Group_size(grp, &grp_size));
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i{0}; i < grp_size; i++) ranks[i] = i;
// Translate team ranks 0..grp_size-1 into MPI_COMM_WORLD ranks; the RMA
// calls below address their targets through world_ranks.
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
int type_size{};
MPI_Datatype mpi_type{convertType(type)};
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
// num_clust is sqrt(pe_size) truncated to int; clust_size is
// pe_size / num_clust rounded up. The algorithm requires an exact split.
int num_clust = sqrt(pe_size);
int clust_size{(pe_size + num_clust - 1) / num_clust};
assert(num_clust * clust_size == pe_size);
int clust_id{new_rank / clust_size};
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
fprintf(stderr, "Alltoall size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
abort();
}
std::vector<MPI_Request> clust_req(pe_size);
// Step 1: fetch data from the PEs of this PE's cluster. Slot i of
// ata_buffptr is read from cluster member (i % clust_size), at the block of
// its src indexed by (own intra-cluster position, cluster i / clust_size).
for (int i{0}; i < pe_size; ++i) {
int src_offset{(new_rank % clust_size + (i / clust_size) * clust_size) *
type_size * size};
int dst_offset{i * type_size * size};
NET_CHECK(MPI_Rget(
reinterpret_cast<void *>(
(reinterpret_cast<char *>(ata_buffptr) + dst_offset)),
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(src) +
src_offset),
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
&clust_req[i]));
}
// All staged blocks must have arrived before they are forwarded.
NET_CHECK(MPI_Waitall(pe_size, clust_req.data(), MPI_STATUSES_IGNORE));
// Step 2: Send final data to PEs outside cluster
// Chunk i of the staging buffer (clust_size blocks) goes to the PE with the
// same intra-cluster position in cluster i, landing at the region of dst
// that corresponds to this PE's cluster.
for (int i{0}; i < num_clust; ++i) {
int src_offset{i * type_size * size * clust_size};
int dst_offset{clust_id * type_size * size * clust_size};
// NOTE(review): here dst_offset is added to the result of get_offset(dst),
// whereas step 1 adds the byte offset to the pointer before calling
// get_offset -- confirm both forms yield the same displacement.
NET_CHECK(MPI_Put(
reinterpret_cast<void *>(
(reinterpret_cast<char *>(ata_buffptr) + src_offset)),
size * clust_size, mpi_type,
world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_offset(dst) + dst_offset,
size * clust_size, mpi_type, bp->heap_window_info[win_id]->get_win()));
// Since MPI makes puts as complete as soon as the local buffer is free,
// we need a flush to satisfy quiet.
NET_CHECK(
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_win()));
}
// NOTE(review): reads world_ranks[1], so this needs at least two PEs in the
// team, and it presumably assumes a uniform world-rank stride -- confirm.
int stride{world_ranks[1] - world_ranks[0]};
MPI_Comm comm_cluster{
createComm(world_ranks[clust_id * clust_size], stride, clust_size)};
MPI_Comm comm_ring{createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust)};
// Both barriers receive the caller's status flag; only the second one is
// given the caller's `blocking` value.
barrier(contextId, status, false, comm_cluster);
barrier(contextId, status, blocking, comm_ring);
}
// Variant of the two-step, cluster-based alltoall that synchronizes the
// cluster with a blocking MPI_Barrier between the two steps.
//
// Step 1 fetches blocks from the PEs of the caller's own cluster into
// ata_buffptr with MPI_Rget; after a blocking barrier on the cluster
// communicator, step 2 forwards one chunk of clust_size blocks to the PE
// holding the same intra-cluster position in every cluster with MPI_Put.
// Completion is signalled through a barrier on the cross-cluster ("ring")
// communicator.
//
// Aborts if the total payload (type_size * size * pe_size bytes) exceeds
// MAX_ATA_BUFF_SIZE.
// NOTE(review): grp and world_grp are not freed in this function.
void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric alltoall with in-place blocking synchronization
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
MPI_Datatype mpi_type = convertType(type);
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp, world_grp;
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
NET_CHECK(MPI_Comm_group(comm, &grp));
int grp_size;
NET_CHECK(MPI_Group_size(grp, &grp_size));
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i = 0; i < grp_size; i++) ranks[i] = i;
// Convert comm ranks to global ranks for rput
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
int type_size;
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
// Works when number of PEs divisible by root(PE_size)
// (num_clust is sqrt(pe_size) truncated to int; clust_size is
// pe_size / num_clust rounded up.)
int num_clust = sqrt(pe_size);
int clust_size = (pe_size + num_clust - 1) / num_clust;
// TODO(bpotter) Allow any size of cluster
assert(num_clust * clust_size == pe_size);
int clust_id = new_rank / clust_size;
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
fprintf(stderr, "Alltoall size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
abort();
}
std::vector<MPI_Request> clust_req(pe_size);
// Step 1: fetch data from the PEs of this PE's cluster. Slot i of
// ata_buffptr is read from cluster member (i % clust_size), at the block of
// its src indexed by (own intra-cluster position, cluster i / clust_size).
for (int i = 0; i < pe_size; ++i) {
int src_offset = (new_rank % clust_size + (i / clust_size) * clust_size) *
type_size * size;
int dst_offset = i * type_size * size;
NET_CHECK(MPI_Rget(reinterpret_cast<void *>(
reinterpret_cast<char *>(ata_buffptr) + dst_offset),
size, mpi_type,
world_ranks[clust_id * clust_size + (i % clust_size)],
bp->heap_window_info[win_id]->get_offset(
reinterpret_cast<char *>(src) + src_offset),
size, mpi_type, bp->heap_window_info[win_id]->get_win(),
&clust_req[i]));
}
NET_CHECK(MPI_Waitall(pe_size, clust_req.data(), MPI_STATUSES_IGNORE));
// Now wait
// NOTE(review): reads world_ranks[1], so this needs at least two PEs in the
// team, and it presumably assumes a uniform world-rank stride -- confirm.
int stride = world_ranks[1] - world_ranks[0];
MPI_Comm comm_cluster =
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
// NOTE(review): blocking call, and its return code is not passed through
// NET_CHECK.
MPI_Barrier(comm_cluster);
// Step 2: Send final data to PEs outside cluster
// Chunk i of the staging buffer (clust_size blocks) goes to the PE with the
// same intra-cluster position in cluster i.
for (int i = 0; i < num_clust; ++i) {
int src_offset = i * type_size * size * clust_size;
int dst_offset = clust_id * type_size * size * clust_size;
// NOTE(review): here dst_offset is added to the result of get_offset(dst),
// whereas step 1 adds the byte offset to the pointer before calling
// get_offset -- confirm both forms yield the same displacement.
NET_CHECK(MPI_Put(
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
src_offset),
size * clust_size, mpi_type,
world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_offset(dst) + dst_offset,
size * clust_size, mpi_type, bp->heap_window_info[win_id]->get_win()));
// Since MPI makes puts as complete as soon as the local buffer is free,
// we need a flush to satisfy quiet.
NET_CHECK(
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_win()));
}
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
barrier(contextId, status, blocking, comm_ring);
}
/**
 * Entry point for the team fcollect: picks one of the fcollect_*
 * implementations and forwards the request to it.
 *
 *  - size <= 512 elements: fcollect_mpi
 *  - otherwise, if the team splits exactly into num_clust clusters of
 *    clust_size PEs (num_clust = truncated sqrt(pe_size)): fcollect_gcen
 *  - otherwise: fcollect_broadcast
 */
void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
                            int contextId, MPI_Comm team, void *ata_buffptr,
                            ro_net_types type, volatile char *status,
                            bool blocking) {
  MPI_Comm comm = team;

  int pe_size, type_size;
  NET_CHECK(MPI_Comm_size(comm, &pe_size));

  MPI_Datatype mpi_type = convertType(type);
  NET_CHECK(MPI_Type_size(mpi_type, &type_size));

  // The GPU-centric algorithm currently only handles team sizes that are a
  // multiple of the (truncated) square root.
  // TODO(bpotter) Allow any size of cluster
  int num_clust = sqrt(pe_size);
  int clust_size = (pe_size + num_clust - 1) / num_clust;
  const bool splits_evenly = (num_clust * clust_size == pe_size);

  // The MPI implementation is optimal in most cases, but it crashes for
  // messages of more than 512 elements.
  if (size <= 512) {
    fcollect_mpi(dst, src, size, contextId, team, ata_buffptr, type,
                 status, blocking);
    return;
  }

  if (splits_evenly) {
    fcollect_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, type,
                  status, blocking);
    return;
  }

  fcollect_broadcast(dst, src, size, win_id, contextId, team, ata_buffptr,
                     type, status, blocking);
}
void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
int win_id, int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// Broadcast implementation of fcollect
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
MPI_Datatype mpi_type = convertType(type);
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
MPI_Comm comm{team};
int rank{}, pe_size{};
NET_CHECK(MPI_Comm_rank(comm, &rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp, world_grp;
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
MPI_Group grp{}, world_grp{};
NET_CHECK(MPI_Comm_group(comm, &grp));
NET_CHECK(MPI_Comm_group(ro_net_comm_world, &world_grp));
int grp_size;
NET_CHECK(MPI_Group_size(grp, &grp_size));
std::vector<int> ranks(pe_size);
std::vector<int> world_ranks(pe_size);
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i = 0; i < pe_size; i++) ranks[i] = i;
for (int i = 0; i < grp_size; i++) ranks[i] = i;
// Convert comm ranks to global ranks for rput
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
NET_CHECK(MPI_Group_translate_ranks(grp, pe_size, ranks.data(), world_grp, world_ranks.data()));
int type_size;
MPI_Datatype mpi_type{convertType(type)};
int type_size{};
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
if (dst == src) {
fprintf(stderr, "IN_PLACE option not support for fcollect in the RO rocSHMEM conduit\n");
abort();
}
std::vector<MPI_Request> pe_req(pe_size);
// Put data to all PEs
for (int i = 0; i < pe_size; ++i) {
int dst_offset = new_rank * type_size * size;
NET_CHECK(MPI_Rput(
reinterpret_cast<char *>(src), size, mpi_type, world_ranks[i],
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) +
dst_offset),
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i]));
int target = (rank + i) % pe_size;
int offset = rank * type_size * size;
NET_CHECK(MPI_Rput(reinterpret_cast<char *>(src), size, mpi_type, world_ranks[target],
bp->heap_window_info[win_id]->get_offset(reinterpret_cast<char *>(dst) + offset),
size, mpi_type, bp->heap_window_info[win_id]->get_win(), &pe_req[i]));
requests.push_back({pe_req[i], {nullptr, contextId, false}});
outstanding[contextId]++;
}
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
// Now wait for completion
barrier(contextId, status, blocking, comm);
}
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
// MPI's implementation of fcollect
int new_rank, pe_size;
MPI_Datatype mpi_type = convertType(type);
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
NET_CHECK(MPI_Allgather(src, size, mpi_type, dst, size, mpi_type, comm));
quiet(contextId, status);
}
// Two-step, cluster-based fcollect that stages data in ata_buffptr.
//
// The team is split into num_clust clusters of clust_size consecutive team
// ranks (num_clust = truncated sqrt(pe_size)); the split must be exact.
// Step 1 fetches the contribution of every PE of the caller's own cluster
// into ata_buffptr with MPI_Rget. Step 2 puts that staged chunk of
// clust_size contributions to the PE holding the same intra-cluster position
// in every cluster. Completion is signalled through two barriers, one on the
// cluster communicator and one on the cross-cluster ("ring") communicator.
//
// Aborts if type_size * size * pe_size bytes exceeds MAX_ATA_BUFF_SIZE.
// NOTE(review): grp and world_grp are not freed in this function.
void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric implementation of fcollect
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
MPI_Datatype mpi_type = convertType(type);
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp, world_grp;
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
NET_CHECK(MPI_Comm_group(comm, &grp));
int grp_size;
NET_CHECK(MPI_Group_size(grp, &grp_size));
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i = 0; i < grp_size; i++) ranks[i] = i;
// Convert comm ranks to global ranks for rput
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
int type_size;
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
// Works when number of PEs divisible by root(PE_size)
// (num_clust is sqrt(pe_size) truncated to int; clust_size is
// pe_size / num_clust rounded up.)
int num_clust = sqrt(pe_size);
int clust_size = (pe_size + num_clust - 1) / num_clust;
// TODO(bpotter) Allow any size of cluster
assert(num_clust * clust_size == pe_size);
int clust_id = new_rank / clust_size;
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
fprintf(stderr, "Fcollect size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
abort();
}
// Sized for pe_size requests although only the first clust_size are used.
std::vector<MPI_Request> clust_req(pe_size);
// Step 1: fetch the contribution (src) of every PE in this PE's cluster
// into consecutive slots of ata_buffptr.
for (int i = 0; i < clust_size; ++i) {
int dst_offset = i * type_size * size;
NET_CHECK(MPI_Rget(
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
dst_offset),
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
bp->heap_window_info[win_id]->get_offset(src), size, mpi_type,
bp->heap_window_info[win_id]->get_win(), &clust_req[i]));
}
NET_CHECK(MPI_Waitall(clust_size, clust_req.data(), MPI_STATUSES_IGNORE));
// Step 2: Send final data to PEs outside cluster
// The same staged chunk (start of ata_buffptr) is put to the PE with the
// same intra-cluster position in each cluster i, at the region of dst that
// corresponds to this PE's cluster.
for (int i = 0; i < num_clust; ++i) {
// NOTE(review): src_offset is computed but never used in this loop.
int src_offset = i * type_size * size * clust_size;
int dst_offset = clust_id * type_size * size * clust_size;
NET_CHECK(MPI_Put(ata_buffptr, size * clust_size, mpi_type,
world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_offset(
reinterpret_cast<char *>(dst) + dst_offset),
size * clust_size, mpi_type,
bp->heap_window_info[win_id]->get_win()));
// Since MPI makes puts as complete as soon as the local buffer is free,
// we need a flush to satisfy quiet.
NET_CHECK(
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_win()));
}
// NOTE(review): reads world_ranks[1], so this needs at least two PEs in the
// team, and it presumably assumes a uniform world-rank stride -- confirm.
int stride = world_ranks[1] - world_ranks[0];
MPI_Comm comm_cluster =
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
// Both barriers receive the caller's status flag; only the second one is
// given the caller's `blocking` value.
barrier(contextId, status, false, comm_cluster);
barrier(contextId, status, blocking, comm_ring);
}
// Variant of the two-step, cluster-based fcollect that synchronizes the
// cluster with a blocking MPI_Barrier between the two steps.
//
// Step 1 fetches the contribution of every PE of the caller's own cluster
// into ata_buffptr with MPI_Rget; after a blocking barrier on the cluster
// communicator, step 2 puts that staged chunk to the PE holding the same
// intra-cluster position in every cluster. Completion is signalled through a
// barrier on the cross-cluster ("ring") communicator.
//
// Aborts if type_size * size * pe_size bytes exceeds MAX_ATA_BUFF_SIZE.
// NOTE(review): grp and world_grp are not freed in this function.
void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric implementation with in-place, blocking synchronization
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
MPI_Datatype mpi_type = convertType(type);
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
MPI_Group grp, world_grp;
NET_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_grp));
NET_CHECK(MPI_Comm_group(comm, &grp));
int grp_size;
NET_CHECK(MPI_Group_size(grp, &grp_size));
std::vector<int> ranks(grp_size);
std::vector<int> world_ranks(grp_size);
for (int i = 0; i < grp_size; i++) ranks[i] = i;
// Convert comm ranks to global ranks for rput
NET_CHECK(
MPI_Group_translate_ranks(grp, grp_size, ranks.data(), world_grp, world_ranks.data()));
int type_size;
NET_CHECK(MPI_Type_size(mpi_type, &type_size));
// Works when number of PEs divisible by root(PE_size)
// (num_clust is sqrt(pe_size) truncated to int; clust_size is
// pe_size / num_clust rounded up.)
int num_clust = sqrt(pe_size);
int clust_size = (pe_size + num_clust - 1) / num_clust;
// TODO(bpotter) Allow any size of cluster
assert(num_clust * clust_size == pe_size);
int clust_id = new_rank / clust_size;
if (MAX_ATA_BUFF_SIZE < type_size * size * pe_size) {
fprintf(stderr, "Fcollect size %d exceeds max MAX_ATA_BUFF_SIZE %d\n",
type_size * size * pe_size, MAX_ATA_BUFF_SIZE);
abort();
}
// Sized for pe_size requests although only the first clust_size are used.
std::vector<MPI_Request> clust_req(pe_size);
// Step 1: fetch the contribution (src) of every PE in this PE's cluster
// into consecutive slots of ata_buffptr.
for (int i = 0; i < clust_size; ++i) {
int dst_offset = i * type_size * size;
NET_CHECK(MPI_Rget(
reinterpret_cast<void *>(reinterpret_cast<char *>(ata_buffptr) +
dst_offset),
size, mpi_type, world_ranks[clust_id * clust_size + (i % clust_size)],
bp->heap_window_info[win_id]->get_offset(src), size, mpi_type,
bp->heap_window_info[win_id]->get_win(), &clust_req[i]));
}
NET_CHECK(MPI_Waitall(clust_size, clust_req.data(), MPI_STATUSES_IGNORE));
// NOTE(review): reads world_ranks[1], so this needs at least two PEs in the
// team, and it presumably assumes a uniform world-rank stride -- confirm.
int stride = world_ranks[1] - world_ranks[0];
MPI_Comm comm_cluster =
createComm(world_ranks[clust_id * clust_size], stride, clust_size);
// NOTE(review): blocking call, and its return code is not passed through
// NET_CHECK.
MPI_Barrier(comm_cluster);
// Step 2: Send final data to PEs outside cluster
// The same staged chunk (start of ata_buffptr) is put to the PE with the
// same intra-cluster position in each cluster i, at the region of dst that
// corresponds to this PE's cluster.
for (int i = 0; i < num_clust; ++i) {
// NOTE(review): src_offset is computed but never used in this loop.
int src_offset = i * type_size * size * clust_size;
int dst_offset = clust_id * type_size * size * clust_size;
NET_CHECK(MPI_Put(ata_buffptr, size * clust_size, mpi_type,
world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_offset(
reinterpret_cast<char *>(dst) + dst_offset),
size * clust_size, mpi_type,
bp->heap_window_info[win_id]->get_win()));
// Since MPI makes puts as complete as soon as the local buffer is free,
// we need a flush to satisfy quiet.
NET_CHECK(
MPI_Win_flush(world_ranks[(new_rank % clust_size) + i * clust_size],
bp->heap_window_info[win_id]->get_win()));
}
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
barrier(contextId, status, blocking, comm_ring);
}
void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
int contextId, volatile char *status, bool blocking,
bool inline_data) {
@@ -1136,7 +590,6 @@ void MPITransport::progress() {
int outcount{};
auto uptr_req_arr {raw_requests()};
NET_CHECK(MPI_Testsome(incount, uptr_req_arr.get(), &outcount,
testsome_indices.data(), MPI_STATUSES_IGNORE));
@@ -54,21 +54,11 @@ public:
// Barrier declaration; `override` marks it as replacing a virtual function
// from a base class. Takes an MPI communicator (`team`).
// NOTE(review): `status` and `blocking` presumably control how completion is
// reported back to the caller -- confirm against the definition.
void barrier(int contextId, volatile char *status, bool blocking,
MPI_Comm team) override;
// Reduction declaration taking start / logPstride / sizePE together with the
// pWrk / pSync pointers, plus the operator (`op`) and element type (`type`).
// NOTE(review): this looks like the active-set form of the interface --
// confirm whether it is still supported before adding new callers.
// NOTE(review): assumes `size` is an element count rather than bytes -- TODO confirm.
void reduction(void *dst, void *src, int size, int pe, int win_id,
int contextId, int start, int logPstride, int sizePE,
void *pWrk, long *pSync, ROCSHMEM_OP op, ro_net_types type,
volatile char *status, bool blocking) override;
// Reduction declaration that takes an MPI communicator (`team`) rather than a
// start / logPstride / sizePE triple; `op` and `type` carry the operator and
// element type. Marked `override`.
// NOTE(review): assumes `size` is an element count rather than bytes -- TODO confirm.
void team_reduction(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) override;
// Broadcast declaration taking start / logPstride / sizePE, a root PE
// (`PE_root`) and a pSync pointer. Marked `override`.
// NOTE(review): this looks like the active-set form of the interface --
// confirm whether it is still supported before adding new callers.
void broadcast(void *dst, void *src, int size, int pe, int win_id,
int contextId, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type,
volatile char *status, bool blocking) override;
void team_broadcast(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, int PE_root,
ro_net_types type, volatile char *status,
@@ -78,44 +68,10 @@ public:
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) override;
// One of several alltoall_* declarations sharing the same parameter list
// (dst, src, size, win_id, contextId, team, ata_buffptr, type, status,
// blocking). Not marked `override`.
// NOTE(review): presumably an alltoall variant built on broadcasts, and
// `ata_buffptr` presumably a scratch buffer -- both inferred from names only;
// confirm against the definition.
void alltoall_broadcast(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking);
// alltoall_* declaration that, unlike its siblings, takes no `win_id`
// parameter. Not marked `override`.
// NOTE(review): presumably delegates to an MPI collective and therefore does
// not need a window -- inferred from the name and signature only; confirm.
void alltoall_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
// alltoall_* declaration with the common parameter list (includes `win_id`
// and `ata_buffptr`). Not marked `override`.
// NOTE(review): "gcen" presumably stands for a GPU-centric algorithm --
// inferred from the name only; confirm against the definition.
void alltoall_gcen(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
// alltoall_* declaration with the same parameter list as alltoall_gcen.
// Not marked `override`.
// NOTE(review): presumably a second GPU-centric variant; how it differs from
// alltoall_gcen is not visible from the declaration -- confirm.
void alltoall_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
// fcollect declaration marked `override`; takes an MPI communicator (`team`),
// a window id (`win_id`) and an extra buffer pointer (`ata_buffptr`).
// NOTE(review): assumes `size` is a per-PE element count and `ata_buffptr` a
// scratch buffer -- TODO confirm against the definition.
void fcollect(void *dst, void *src, int size, int win_id, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) override;
// One of several fcollect_* declarations sharing the parameter list of
// fcollect. Not marked `override`.
// NOTE(review): presumably an fcollect variant built on broadcasts --
// inferred from the name only; confirm against the definition.
void fcollect_broadcast(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking);
// fcollect_* declaration that, unlike its siblings, takes no `win_id`
// parameter. Not marked `override`.
// NOTE(review): presumably delegates to an MPI collective and therefore does
// not need a window -- inferred from the name and signature only; confirm.
void fcollect_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
// fcollect_* declaration with the same parameter list as fcollect.
// Not marked `override`.
// NOTE(review): "gcen" presumably stands for a GPU-centric algorithm --
// inferred from the name only; confirm against the definition.
void fcollect_gcen(void *dst, void *src, int size, int win_id, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
// fcollect_* declaration with the same parameter list as fcollect_gcen.
// Not marked `override`.
// NOTE(review): presumably a second GPU-centric variant; how it differs from
// fcollect_gcen is not visible from the declaration -- confirm.
void fcollect_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
// putMem declaration marked `override`. `inline_data` defaults to false, so
// callers may omit it.
// NOTE(review): presumably writes `size` bytes from `src` to `dst` on PE `pe`
// through the window selected by `win_id` -- units and direction are not
// visible from the declaration; confirm against the definition.
void putMem(void *dst, void *src, int size, int pe, int win_id,
int contextId, volatile char *status, bool blocking,
bool inline_data = false) override;
@@ -34,7 +34,9 @@ ROTeam::ROTeam(Backend* backend, TeamInfo* team_info_wrt_parent,
mpi_comm) {
type = BackendType::RO_BACKEND;
ata_buffer = malloc(MAX_ATA_BUFF_SIZE);
// Disable allocating ata_buffer for now. It is not
// used at the moment, but might come back in future versions.
ata_buffer = nullptr;
}
ROTeam::~ROTeam() {
@@ -53,22 +53,11 @@ class Transport {
// Pure virtual barrier (`= 0`): concrete transports must provide it.
// The first parameter is named `wg_id` here.
// NOTE(review): `status` and `blocking` presumably control how completion is
// reported back to the caller -- confirm with an implementation.
virtual void barrier(int wg_id, volatile char *status, bool blocking,
MPI_Comm team) = 0;
// Pure virtual reduction (`= 0`) taking start / logPstride / sizePE together
// with the pWrk / pSync pointers, plus the operator (`op`) and element type
// (`type`).
// NOTE(review): this looks like the active-set form of the interface --
// confirm whether it is still part of the supported API.
virtual void reduction(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int start, int logPstride, int sizePE,
void *pWrk, long *pSync, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) = 0;
// Pure virtual reduction (`= 0`) that takes an MPI communicator (`team`)
// rather than a start / logPstride / sizePE triple; `op` and `type` carry the
// operator and element type.
// NOTE(review): assumes `size` is an element count rather than bytes -- TODO confirm.
virtual void team_reduction(void *dst, void *src, int size, int win_id,
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) = 0;
// Pure virtual broadcast (`= 0`) taking start / logPstride / sizePE, a root
// PE (`PE_root`) and a pSync pointer.
// NOTE(review): this looks like the active-set form of the interface --
// confirm whether it is still part of the supported API.
virtual void broadcast(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type,
volatile char *status, bool blocking) = 0;
virtual void team_broadcast(void *dst, void *src, int size, int win_id,
int wg_id, MPI_Comm team, int PE_root,
ro_net_types type, volatile char *status,