From 022b2c27e752e903c6aa779df51d9ce71f92eb32 Mon Sep 17 00:00:00 2001 From: Yiltan Hassan Temucin Date: Thu, 6 Feb 2025 16:23:45 -0600 Subject: [PATCH] Fix Team reduction intra-node --- src/context_tmpl_device.hpp | 2 +- src/reverse_offload/commands_types.hpp | 2 +- src/reverse_offload/context_ro_device.cpp | 2 +- src/reverse_offload/context_ro_device.hpp | 4 ++-- src/reverse_offload/context_ro_tmpl_device.hpp | 9 +++++---- src/reverse_offload/mpi_transport.cpp | 4 ++-- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/context_tmpl_device.hpp b/src/context_tmpl_device.hpp index 4a7862cb64..a9c1dda624 100644 --- a/src/context_tmpl_device.hpp +++ b/src/context_tmpl_device.hpp @@ -81,7 +81,7 @@ __device__ void Context::to_all(T *dest, const T *source, int nreduce, template __device__ int Context::reduce(rocshmem_team_t team, T *dest, const T *source, - int nreduce) { + int nreduce) { if (nreduce == 0) { return ROCSHMEM_SUCCESS; } diff --git a/src/reverse_offload/commands_types.hpp b/src/reverse_offload/commands_types.hpp index e28ff572cd..6fd7836a10 100644 --- a/src/reverse_offload/commands_types.hpp +++ b/src/reverse_offload/commands_types.hpp @@ -38,7 +38,7 @@ enum ro_net_cmds { RO_NET_QUIET, RO_NET_FINALIZE, RO_NET_TO_ALL, - RO_NET_TEAM_TO_ALL, + RO_NET_TEAM_REDUCE, RO_NET_SYNC, RO_NET_BARRIER_ALL, RO_NET_BROADCAST, diff --git a/src/reverse_offload/context_ro_device.cpp b/src/reverse_offload/context_ro_device.cpp index 4cc5951de0..21de4b9b2f 100644 --- a/src/reverse_offload/context_ro_device.cpp +++ b/src/reverse_offload/context_ro_device.cpp @@ -507,7 +507,7 @@ __device__ void build_queue_element( queue_element->op = op; queue_element->datatype = datatype; } - if (type == RO_NET_TEAM_TO_ALL) { + if (type == RO_NET_TEAM_REDUCE) { queue_element->op = op; queue_element->datatype = datatype; queue_element->team_comm = team_comm; diff --git a/src/reverse_offload/context_ro_device.hpp b/src/reverse_offload/context_ro_device.hpp index ceccf5e776..3413ded9e9 100644 --- a/src/reverse_offload/context_ro_device.hpp +++ b/src/reverse_offload/context_ro_device.hpp @@ -81,8 +81,8 @@ class ROContext : public Context { long *pSync); // NOLINT(runtime/int) template - __device__ void to_all(rocshmem_team_t team, T *dest, const T *source, - int nreduce); + __device__ int reduce(rocshmem_team_t team, T *dest, const T *source, + int nreduce); template __device__ void put(T *dest, const T *source, size_t nelems, int pe); diff --git a/src/reverse_offload/context_ro_tmpl_device.hpp b/src/reverse_offload/context_ro_tmpl_device.hpp index d76a9eb222..c8622bfe6f 100644 --- a/src/reverse_offload/context_ro_tmpl_device.hpp +++ b/src/reverse_offload/context_ro_tmpl_device.hpp @@ -109,20 +109,21 @@ struct GetROType { *****************************************************************************/ template -__device__ void ROContext::to_all(rocshmem_team_t team, T *dest, - const T *source, int nreduce) { +__device__ int ROContext::reduce(rocshmem_team_t team, T *dest, + const T *source, int nreduce) { if (!is_thread_zero_in_block()) { __syncthreads(); - return; + return ROCSHMEM_SUCCESS; } ROTeam *team_obj{reinterpret_cast(team)}; - build_queue_element(RO_NET_TEAM_TO_ALL, dest, const_cast(source), + build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast(source), nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, true, Op, GetROType::Type); __syncthreads(); + return ROCSHMEM_SUCCESS; } template diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index 9aa660d30c..ee057e3153 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -159,14 +159,14 @@ void MPITransport::submitRequestsToMPI() { next_element.PE, reinterpret_cast(next_element.ol2.pWrk)); break; - case RO_NET_TEAM_TO_ALL: + case RO_NET_TEAM_REDUCE: team_reduction(next_element.dst, next_element.src, next_element.ol1.size, next_element.ro_net_win_id, queue_idx, next_element.team_comm, static_cast(next_element.op), static_cast(next_element.datatype), next_element.threadId, true); - DPRINTF("Received FLOAT_SUM_TEAM_TO_ALL dst %p src %p size %lu team %d\n", + DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.team_comm); break;