diff --git a/projects/rocshmem/src/ipc/context_ipc_device.hpp b/projects/rocshmem/src/ipc/context_ipc_device.hpp index 9ff8a688a6..03b72efe89 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.hpp @@ -196,11 +196,6 @@ class IPCContext : public Context { char* g_ret; //internal functions used by collective operations - template - __device__ void internal_to_all(T *dest, const T *source, int nreduce, int PE_start, - int stride, int PE_size, T *pWrk, - long *pSync); // NOLINT(runtime/int) - template __device__ void internal_broadcast(T *dest, const T *source, int nelems, int pe_root, int pe_start, int stride, int pe_size, @@ -234,12 +229,10 @@ class IPCContext : public Context { template __device__ void internal_direct_allreduce(T *dst, const T *src, - int nelems, int PE_start, int logPE_stride, - int PE_size, T *pWrk, long *pSync); + int nelems, IPCTeam *team_obj); template __device__ void internal_ring_allreduce(T *dst, const T *src, - int nelems, int PE_start, int logPE_stride, - int PE_size, T *pWrk, long *pSync, + int nelems, IPCTeam *team_obj, int n_seg, int seg_size, int chunk_size); //internal functions used by collectives routines to write/read to diff --git a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp index e0e06ecf7a..7a03233d89 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp @@ -101,13 +101,9 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride, __device__ void IPCContext::sync(roc_shmem_team_t team) { IPCTeam *team_obj = reinterpret_cast(team); - /** - * Ensure that the stride is a multiple of 2. - */ - int log_pe_stride = static_cast(team_obj->tinfo_wrt_world->log_stride); int pe = team_obj->my_pe_in_world; int pe_start = team_obj->tinfo_wrt_world->pe_start; - int pe_stride = (1 << log_pe_stride); + int pe_stride = team_obj->tinfo_wrt_world->stride; int pe_size = team_obj->num_pes; internal_sync(pe, pe_start, pe_stride, pe_size, barrier_sync); diff --git a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp index 0e1504a6e7..4da49b6f7b 100644 --- a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp @@ -164,9 +164,13 @@ __device__ void compute_reduce(T *src, T *dst, int size, int wg_id, template __device__ void IPCContext::internal_direct_allreduce( - T *dst, const T *src, int nelems, int PE_start, int stride, - int PE_size, T *pWrk, - long *pSync) { // NOLINT(runtime/int) + T *dst, const T *src, int nelems, IPCTeam *team_obj) { // NOLINT(runtime/int) + + int stride = team_obj->tinfo_wrt_world->stride; + int PE_start = team_obj->tinfo_wrt_world->pe_start; + int PE_size = team_obj->tinfo_wrt_world->size; + long *pSync = team_obj->barrier_pSync; + T *pWrk = reinterpret_cast(team_obj->pWrk); int finish = PE_start + stride * PE_size; int pe = my_pe; @@ -276,12 +280,20 @@ __device__ void IPCContext::internal_direct_allreduce( */ template __device__ void IPCContext::internal_ring_allreduce( - T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start, - [[maybe_unused]] int stride, [[maybe_unused]] int PE_size, T *pWrk, - long *pSync, // NOLINT(runtime/int) + T *dst, const T *src, int nelems, IPCTeam *team_obj, // NOLINT(runtime/int) int n_seg, int seg_size, int chunk_size) { + + int stride = team_obj->tinfo_wrt_world->stride; + int PE_start = team_obj->tinfo_wrt_world->pe_start; + int PE_size = team_obj->tinfo_wrt_world->size; + long *pSync = team_obj->barrier_pSync; + T *pWrk = reinterpret_cast(team_obj->pWrk); + int my_pe_in_team = team_obj->my_pe; + int off_seg, off_send, off_recv; - int send_pe = (my_pe + 1) % num_pes; + int send_pe = (my_pe_in_team + 1) % PE_size; + // send_pe is relative to team, convert it relative to team world + send_pe = team_obj->get_pe_in_world(send_pe); long wait_val; // NOLINT(runtime/int) int wg_size = get_flat_block_size(); @@ -295,9 +307,9 @@ __device__ void IPCContext::internal_ring_allreduce( for (size_t seg = 0; seg < n_seg; seg++) { off_seg = seg * seg_size; // Loop 2 in the algorithm above - for (int iter = 0; iter < num_pes - 1; iter++) { - off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); - off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size); + for (int iter = 0; iter < PE_size - 1; iter++) { + off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size); + off_recv = (((my_pe_in_team - iter + 2 * PE_size) % PE_size) * chunk_size); internal_putmem_wg(reinterpret_cast(&pWrk[off_send]), reinterpret_cast(&dst[off_send + off_seg]), @@ -319,8 +331,8 @@ __device__ void IPCContext::internal_ring_allreduce( } // Loop 2 in the example above - for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) { - off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); + for (size_t iter = PE_size - 1; iter < 2 * PE_size - 2; iter++) { + off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size); putmem_nbi_wg(reinterpret_cast(&dst[off_send + off_seg]), reinterpret_cast(&dst[off_send + off_seg]), chunk_size * sizeof(T), send_pe); @@ -350,76 +362,58 @@ __device__ int IPCContext::reduce(roc_shmem_team_t team, T *dest, const T *source, int nreduce) { IPCTeam *team_obj = reinterpret_cast(team); - /** - * Ensure that the stride is a multiple of 2 for GPU_IB. - */ - int stride = team_obj->tinfo_wrt_world->stride; - int pe_start = team_obj->tinfo_wrt_world->pe_start; - int pe_size = team_obj->tinfo_wrt_world->size; - long *p_sync = team_obj->barrier_pSync; - T *pWrk = reinterpret_cast(team_obj->pWrk); + int PE_size = team_obj->tinfo_wrt_world->size; - internal_to_all(dest, source, nreduce, pe_start, stride, pe_size, pWrk, - p_sync); - return ROC_SHMEM_SUCCESS; -} - -template -__device__ void IPCContext::internal_to_all(T *dest, const T *source, int nreduce, - int PE_start, int stride, - int PE_size, T *pWrk, - long *pSync) { // NOLINT(runtime/int) - size_t direct_pWrk = num_pes * nreduce; - size_t direct_pSync = num_pes; - size_t ring_pSync = 2 * num_pes; + size_t direct_pWrk = PE_size * nreduce; + size_t direct_pSync = PE_size; + size_t ring_pSync = 2 * PE_size; size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE); size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE; if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) { - internal_direct_allreduce(dest, source, nreduce, PE_start, stride, - PE_size, pWrk, pSync); + internal_direct_allreduce(dest, source, nreduce, team_obj); } else { if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) { size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE; // integer division truncating value - int chunk_size = ring_pWrk / num_pes; - int seg_size = chunk_size * num_pes; + int chunk_size = ring_pWrk / PE_size; + int seg_size = chunk_size * PE_size; // integer division truncating value int n_seg = nreduce / seg_size; // integer division rounding up - int n_seg_up = (nreduce + (seg_size -1)) / seg_size; + int n_seg_up = (nreduce - 1) / seg_size + 1; // recalculate chunk_size - chunk_size = seg_size / num_pes; + chunk_size = seg_size / PE_size; if (n_seg == 0) { n_seg = 1; } - internal_ring_allreduce(dest, source, nreduce, PE_start, - stride, PE_size, pWrk, pSync, n_seg, + internal_ring_allreduce(dest, source, nreduce, team_obj, n_seg, seg_size, chunk_size); if (n_seg_up > n_seg) { T *p_dst = (dest + (n_seg * seg_size)); const T *p_src = (source + (n_seg * seg_size)); int p_count = nreduce - (n_seg * seg_size); - int p_chunk = p_count / num_pes; + int p_chunk = p_count / PE_size; - internal_ring_allreduce(p_dst, p_src, p_count, PE_start, stride, - PE_size, pWrk, pSync, 1, (p_chunk * num_pes), p_chunk); + internal_ring_allreduce(p_dst, p_src, p_count, team_obj, 1, + (p_chunk * PE_size), p_chunk); - if ((p_chunk * num_pes) < p_count) { + if ((p_chunk * PE_size) < p_count) { // Final elements need to use direct_allreduce - p_count -= (p_chunk * num_pes); - p_dst += (p_chunk * num_pes); - const T *p_src2 = p_src + (p_chunk * num_pes); + p_count -= (p_chunk * PE_size); + p_dst += (p_chunk * PE_size); + const T *p_src2 = p_src + (p_chunk * PE_size); - internal_direct_allreduce(p_dst, p_src2, p_count, PE_start, stride, - PE_size, pWrk, pSync); + internal_direct_allreduce(p_dst, p_src2, p_count, team_obj); } } } else { GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n"); + return ROC_SHMEM_ERROR; } } + return ROC_SHMEM_SUCCESS; } template @@ -449,9 +443,6 @@ __device__ void IPCContext::broadcast(roc_shmem_team_t team, T *dst, const T *src, int nelems, int pe_root) { IPCTeam *team_obj = reinterpret_cast(team); - /** - * Ensure that the stride is a multiple of 2 . - */ int stride = team_obj->tinfo_wrt_world->stride; int pe_start = team_obj->tinfo_wrt_world->pe_start; int pe_size = team_obj->tinfo_wrt_world->size; @@ -490,13 +481,9 @@ __device__ void IPCContext::alltoall_linear(roc_shmem_team_t team, T *dst, const T *src, int nelems) { IPCTeam *team_obj = reinterpret_cast(team); - /** - * Ensure that the stride is a multiple of 2 - */ - int log_pe_stride = static_cast(team_obj->tinfo_wrt_world->log_stride); int pe_start = team_obj->tinfo_wrt_world->pe_start; int pe_size = team_obj->num_pes; - int stride = 1 << log_pe_stride; + int stride = team_obj->tinfo_wrt_world->stride; long *pSync = team_obj->alltoall_pSync; int my_pe_in_team = team_obj->my_pe; @@ -523,13 +510,9 @@ __device__ void IPCContext::fcollect_linear(roc_shmem_team_t team, T *dst, const T *src, int nelems) { IPCTeam *team_obj = reinterpret_cast(team); - /** - * Ensure that the stride is a multiple of 2. - */ - int log_pe_stride = static_cast(team_obj->tinfo_wrt_world->log_stride); int pe_start = team_obj->tinfo_wrt_world->pe_start; int pe_size = team_obj->num_pes; - int stride = 1 << log_pe_stride; + int stride = team_obj->tinfo_wrt_world->stride; long *pSync = team_obj->alltoall_pSync; int my_pe_in_team = team_obj->my_pe;