diff --git a/src/ipc/context_ipc_device.cpp b/src/ipc/context_ipc_device.cpp index 3ed4831d70..6b1d56b051 100644 --- a/src/ipc/context_ipc_device.cpp +++ b/src/ipc/context_ipc_device.cpp @@ -164,13 +164,7 @@ __device__ void IPCContext::internal_putmem(void *dest, const void *source, size_t nelems, int pe) { uint64_t L_offset = reinterpret_cast(dest) - wrk_sync_pool_bases_[my_pe]; memcpy_lane(wrk_sync_pool_bases_[pe] + L_offset, const_cast(source), nelems); -#if defined(__gfx90a__) - __threadfence_system(); -#elif defined (__gfx1201__) || defined (__gfx1100__) - fence(pe); -#else ipcImpl_.ipcFence(); -#endif } __device__ void IPCContext::internal_getmem(void *dest, const void *source, @@ -186,15 +180,7 @@ __device__ void IPCContext::internal_putmem_wg(void *dest, const void *source, uint64_t L_offset = reinterpret_cast(dest) - wrk_sync_pool_bases_[my_pe]; memcpy_wg(wrk_sync_pool_bases_[pe] + L_offset, const_cast(source), nelems); __syncthreads(); -#if defined(__gfx90a__) - __threadfence_system(); -#elif defined (__gfx1201__) || defined (__gfx1100__) - if (is_thread_zero_in_block() ) { - fence(pe); - } -#else ipcImpl_.ipcFence(); -#endif } __device__ void IPCContext::internal_getmem_wg(void *dest, const void *source, @@ -210,15 +196,7 @@ __device__ void IPCContext::internal_putmem_wave(void *dest, const void *source, size_t nelems, int pe) { uint64_t L_offset = reinterpret_cast(dest) - wrk_sync_pool_bases_[my_pe]; memcpy_wave(wrk_sync_pool_bases_[pe] + L_offset, const_cast(source), nelems); -#if defined(__gfx90a__) - __threadfence_system(); -#elif defined (__gfx1201__) || defined (__gfx1100__) - if (is_thread_zero_in_wave() ) { - fence(pe); - } -#else ipcImpl_.ipcFence(); -#endif } __device__ void IPCContext::internal_getmem_wave(void *dest, diff --git a/src/ipc/context_ipc_tmpl_device.hpp b/src/ipc/context_ipc_tmpl_device.hpp index 36f1c17247..ce86d2aa62 100644 --- a/src/ipc/context_ipc_tmpl_device.hpp +++ b/src/ipc/context_ipc_tmpl_device.hpp @@ -174,7 +174,7 @@ __device__ void IPCContext::internal_direct_allreduce( int stride = team_obj->tinfo_wrt_world->stride; int PE_start = team_obj->tinfo_wrt_world->pe_start; int PE_size = team_obj->tinfo_wrt_world->size; - long *pSync = team_obj->barrier_pSync; + long *pSync = team_obj->reduce_pSync; T *pWrk = reinterpret_cast(team_obj->pWrk); int finish = PE_start + stride * PE_size; @@ -195,7 +195,7 @@ __device__ void IPCContext::internal_direct_allreduce( nelems * sizeof(T), i); if (is_thread_zero_in_block()) { - fence(); + fence(i); internal_putmem(&pSync[pe], &flag_val, sizeof(*pSync), i); } } @@ -222,7 +222,6 @@ __device__ void IPCContext::internal_direct_allreduce( for (int i = wg_id; i < num_pes; i += wg_size) { pSync[i] = ROCSHMEM_SYNC_VALUE; } - threadfence_system(); __syncthreads(); } @@ -290,7 +289,7 @@ __device__ void IPCContext::internal_ring_allreduce( int stride = team_obj->tinfo_wrt_world->stride; int PE_start = team_obj->tinfo_wrt_world->pe_start; int PE_size = team_obj->tinfo_wrt_world->size; - long *pSync = team_obj->barrier_pSync; + long *pSync = team_obj->reduce_pSync; T *pWrk = reinterpret_cast(team_obj->pWrk); int my_pe_in_team = team_obj->my_pe; @@ -320,8 +319,8 @@ __device__ void IPCContext::internal_ring_allreduce( chunk_size * sizeof(T), send_pe); if (is_thread_zero_in_block()) { - fence(); wait_val = seg + 100; + fence(send_pe); internal_putmem(&pSync[iter], &wait_val, sizeof(*pSync), send_pe); wait_until(&pSync[iter], ROCSHMEM_CMP_EQ, wait_val); } @@ -338,8 +337,8 @@ __device__ void IPCContext::internal_ring_allreduce( chunk_size * sizeof(T), send_pe); if (is_thread_zero_in_block()) { - fence(); wait_val = seg + 10; + fence(send_pe); internal_putmem(&pSync[iter], &wait_val, sizeof(*pSync), send_pe); wait_until(&pSync[iter], ROCSHMEM_CMP_EQ, wait_val); } @@ -350,7 +349,6 @@ __device__ void IPCContext::internal_ring_allreduce( for (int i = wg_id; i < 2 * num_pes - 2; i += wg_size) { pSync[i] = ROCSHMEM_SYNC_VALUE; } - threadfence_system(); __syncthreads(); } @@ -409,6 +407,7 @@ __device__ int IPCContext::reduce(rocshmem_team_t team, T *dest, return ROCSHMEM_ERROR; } } + barrier_wg(team); return ROCSHMEM_SUCCESS; } diff --git a/tests/functional_tests/team_reduction_tester.cpp b/tests/functional_tests/team_reduction_tester.cpp index ec19bee00a..b34643eea9 100644 --- a/tests/functional_tests/team_reduction_tester.cpp +++ b/tests/functional_tests/team_reduction_tester.cpp @@ -83,7 +83,7 @@ __global__ void TeamReductionTest(int loop, int skip, long long int *start_time, __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); - rocshmem_wg_ctx_create(ctx_type, &ctx); + rocshmem_wg_team_create_ctx(team, ctx_type, &ctx); int n_pes = rocshmem_ctx_n_pes(ctx);