Merge pull request #39 from Yiltan/ro/fix-teamreduce
Fix Team reduction intra-node
Šī revīzija ir iekļauta:
@@ -81,7 +81,7 @@ __device__ void Context::to_all(T *dest, const T *source, int nreduce,
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
__device__ int Context::reduce(rocshmem_team_t team, T *dest, const T *source,
|
||||
- int nreduce) {
|
||||
+ int nreduce) {
|
||||
if (nreduce == 0) {
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ enum ro_net_cmds {
|
||||
RO_NET_QUIET,
|
||||
RO_NET_FINALIZE,
|
||||
RO_NET_TO_ALL,
|
||||
- RO_NET_TEAM_TO_ALL,
|
||||
+ RO_NET_TEAM_REDUCE,
|
||||
RO_NET_SYNC,
|
||||
RO_NET_BARRIER_ALL,
|
||||
RO_NET_BROADCAST,
|
||||
|
||||
@@ -507,7 +507,7 @@ __device__ void build_queue_element(
|
||||
queue_element->op = op;
|
||||
queue_element->datatype = datatype;
|
||||
}
|
||||
- if (type == RO_NET_TEAM_TO_ALL) {
|
||||
+ if (type == RO_NET_TEAM_REDUCE) {
|
||||
queue_element->op = op;
|
||||
queue_element->datatype = datatype;
|
||||
queue_element->team_comm = team_comm;
|
||||
|
||||
@@ -81,8 +81,8 @@ class ROContext : public Context {
|
||||
long *pSync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
- __device__ void to_all(rocshmem_team_t team, T *dest, const T *source,
|
||||
- int nreduce);
|
||||
+ __device__ int reduce(rocshmem_team_t team, T *dest, const T *source,
|
||||
+ int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__device__ void put(T *dest, const T *source, size_t nelems, int pe);
|
||||
|
||||
@@ -109,20 +109,21 @@ struct GetROType<long double> {
|
||||
*****************************************************************************/
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
- __device__ void ROContext::to_all(rocshmem_team_t team, T *dest,
|
||||
- const T *source, int nreduce) {
|
||||
+ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
|
||||
+ const T *source, int nreduce) {
|
||||
if (!is_thread_zero_in_block()) {
|
||||
__syncthreads();
|
||||
- return;
|
||||
+ return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
ROTeam *team_obj{reinterpret_cast<ROTeam *>(team)};
|
||||
|
||||
- build_queue_element(RO_NET_TEAM_TO_ALL, dest, const_cast<T *>(source),
|
||||
+ build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast<T *>(source),
|
||||
nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm,
|
||||
ro_net_win_id, block_handle, true, Op, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
+ return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T, ROCSHMEM_OP Op>
|
||||
|
||||
@@ -159,14 +159,14 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.PE,
|
||||
reinterpret_cast<int64_t>(next_element.ol2.pWrk));
|
||||
break;
|
||||
- case RO_NET_TEAM_TO_ALL:
|
||||
+ case RO_NET_TEAM_REDUCE:
|
||||
team_reduction(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.ro_net_win_id, queue_idx,
|
||||
next_element.team_comm,
|
||||
static_cast<ROCSHMEM_OP>(next_element.op),
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
- DPRINTF("Received FLOAT_SUM_TEAM_TO_ALL dst %p src %p size %lu team %d\n",
|
||||
+ DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
break;
|
||||
|
||||
Atsaukties uz šo jaunā problēmā
Block a user