Merge pull request #39 from Yiltan/ro/fix-teamreduce

Fix Team reduction intra-node
This commit is contained in:
Yiltan
2025-02-10 14:56:27 -05:00
committed by GitHub
commit 944444cf12
6 changed files with 12 additions and 11 deletions
+1 -1
View file
@@ -81,7 +81,7 @@ __device__ void Context::to_all(T *dest, const T *source, int nreduce,
template <typename T, ROCSHMEM_OP Op>
__device__ int Context::reduce(rocshmem_team_t team, T *dest, const T *source,
int nreduce) {
int nreduce) {
if (nreduce == 0) {
return ROCSHMEM_SUCCESS;
}
+1 -1
View file
@@ -38,7 +38,7 @@ enum ro_net_cmds {
RO_NET_QUIET,
RO_NET_FINALIZE,
RO_NET_TO_ALL,
RO_NET_TEAM_TO_ALL,
RO_NET_TEAM_REDUCE,
RO_NET_SYNC,
RO_NET_BARRIER_ALL,
RO_NET_BROADCAST,
@@ -507,7 +507,7 @@ __device__ void build_queue_element(
queue_element->op = op;
queue_element->datatype = datatype;
}
if (type == RO_NET_TEAM_TO_ALL) {
if (type == RO_NET_TEAM_REDUCE) {
queue_element->op = op;
queue_element->datatype = datatype;
queue_element->team_comm = team_comm;
@@ -81,8 +81,8 @@ class ROContext : public Context {
long *pSync); // NOLINT(runtime/int)
template <typename T, ROCSHMEM_OP Op>
__device__ void to_all(rocshmem_team_t team, T *dest, const T *source,
int nreduce);
__device__ int reduce(rocshmem_team_t team, T *dest, const T *source,
int nreduce);
template <typename T>
__device__ void put(T *dest, const T *source, size_t nelems, int pe);
@@ -109,20 +109,21 @@ struct GetROType<long double> {
*****************************************************************************/
template <typename T, ROCSHMEM_OP Op>
__device__ void ROContext::to_all(rocshmem_team_t team, T *dest,
const T *source, int nreduce) {
__device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
const T *source, int nreduce) {
if (!is_thread_zero_in_block()) {
__syncthreads();
return;
return ROCSHMEM_SUCCESS;
}
ROTeam *team_obj{reinterpret_cast<ROTeam *>(team)};
build_queue_element(RO_NET_TEAM_TO_ALL, dest, const_cast<T *>(source),
build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast<T *>(source),
nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm,
ro_net_win_id, block_handle, true, Op, GetROType<T>::Type);
__syncthreads();
return ROCSHMEM_SUCCESS;
}
template <typename T, ROCSHMEM_OP Op>
+2 -2
View file
@@ -159,14 +159,14 @@ void MPITransport::submitRequestsToMPI() {
next_element.PE,
reinterpret_cast<int64_t>(next_element.ol2.pWrk));
break;
case RO_NET_TEAM_TO_ALL:
case RO_NET_TEAM_REDUCE:
team_reduction(next_element.dst, next_element.src, next_element.ol1.size,
next_element.ro_net_win_id, queue_idx,
next_element.team_comm,
static_cast<ROCSHMEM_OP>(next_element.op),
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
DPRINTF("Received FLOAT_SUM_TEAM_TO_ALL dst %p src %p size %lu team %d\n",
DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
break;