diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index 629c5a377c..d025c6b0d7 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -465,6 +465,16 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all( __device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all(); +/** + * @brief perform a collective barrier between all PEs in the team. + * The caller is blocked until the barrier is resolved. + * + * @param[in] team The team on which to perform barrier synchronization + * + * @return void + */ +__device__ void rocshmem_barrier(rocshmem_team_t); + /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. diff --git a/scripts/functional_tests/driver.sh b/scripts/functional_tests/driver.sh index 57c73e0c2a..bbad1f7720 100755 --- a/scripts/functional_tests/driver.sh +++ b/scripts/functional_tests/driver.sh @@ -80,6 +80,7 @@ declare -A TEST_NUMBERS=( ["signalfetch"]="55" ["wgsignalfetch"]="56" ["wavesignalfetch"]="57" + ["teambarrier"]="58" ) ExecTest() { @@ -306,6 +307,7 @@ TestColl() { # | Name | Ranks | Workgroups | Threads | Max Message Size # ############################################################################## ExecTest "barrierall" 2 1 1 + ExecTest "teambarrier" 2 1 1 ExecTest "sync" 2 1 1 diff --git a/src/context.hpp b/src/context.hpp index ac31521a91..8d04c8687f 100644 --- a/src/context.hpp +++ b/src/context.hpp @@ -137,6 +137,8 @@ class Context { __device__ void barrier_all(); + __device__ void barrier(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync(rocshmem_team_t team); diff --git a/src/context_device.cpp b/src/context_device.cpp index 850274ef9f..ea1ae4494f 100644 --- a/src/context_device.cpp +++ b/src/context_device.cpp @@ -148,6 +148,12 @@ __device__ void Context::barrier_all() { DISPATCH(barrier_all()); } +__device__ void Context::barrier(rocshmem_team_t team) { + ctxStats.incStat(NUM_BARRIER_ALL); + + DISPATCH(barrier(team)); +} + __device__ void Context::sync_all() { ctxStats.incStat(NUM_SYNC_ALL); diff --git a/src/ipc/context_ipc_device.hpp b/src/ipc/context_ipc_device.hpp index d2e0f49565..92abb76c2d 100644 --- a/src/ipc/context_ipc_device.hpp +++ b/src/ipc/context_ipc_device.hpp @@ -61,6 +61,8 @@ class IPCContext : public Context { __device__ void barrier_all(); + __device__ void barrier(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync(rocshmem_team_t team); diff --git a/src/ipc/context_ipc_device_coll.cpp b/src/ipc/context_ipc_device_coll.cpp index 337b9106ca..8ba18a8ae8 100644 --- a/src/ipc/context_ipc_device_coll.cpp +++ b/src/ipc/context_ipc_device_coll.cpp @@ -122,4 +122,20 @@ __device__ void IPCContext::barrier_all() { __syncthreads(); } +__device__ void IPCContext::barrier(rocshmem_team_t team) { + IPCTeam *team_obj = reinterpret_cast(team); + + int pe = team_obj->my_pe_in_world; + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_stride = team_obj->tinfo_wrt_world->stride; + int pe_size = team_obj->num_pes; + long *p_sync = team_obj->barrier_pSync; + + if (is_thread_zero_in_block()) { + quiet(); + } + internal_sync(pe, pe_start, pe_stride, pe_size, p_sync); + __syncthreads(); +} + } // namespace rocshmem diff --git a/src/reverse_offload/commands_types.hpp b/src/reverse_offload/commands_types.hpp index edfb14a00c..481e1275e1 100644 --- a/src/reverse_offload/commands_types.hpp +++ b/src/reverse_offload/commands_types.hpp @@ -39,7 +39,7 @@ enum ro_net_cmds { RO_NET_FINALIZE, RO_NET_TEAM_REDUCE, RO_NET_SYNC, - RO_NET_BARRIER_ALL, + RO_NET_BARRIER, RO_NET_TEAM_BROADCAST, RO_NET_ALLTOALL, RO_NET_FCOLLECT, diff --git a/src/reverse_offload/context_ro_device.cpp b/src/reverse_offload/context_ro_device.cpp index 7b147d4d34..cef85b5f54 100644 --- a/src/reverse_offload/context_ro_device.cpp +++ b/src/reverse_offload/context_ro_device.cpp @@ -165,16 +165,26 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) { __device__ void ROContext::barrier_all() { if (is_thread_zero_in_block()) { - build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0, + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true, get_status_flag()); } __syncthreads(); } +__device__ void ROContext::barrier(rocshmem_team_t team) { + ROTeam *team_obj = reinterpret_cast(team); + if (is_thread_zero_in_block()) { + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, + nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, + true, get_status_flag()); + } + __syncthreads(); +} + __device__ void ROContext::sync_all() { if (is_thread_zero_in_block()) { - build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0, + build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true, get_status_flag()); } diff --git a/src/reverse_offload/context_ro_device.hpp b/src/reverse_offload/context_ro_device.hpp index 9c2f91998e..9ebfa2a30f 100644 --- a/src/reverse_offload/context_ro_device.hpp +++ b/src/reverse_offload/context_ro_device.hpp @@ -65,6 +65,8 @@ class ROContext : public Context { __device__ void barrier_all(); + __device__ void barrier(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync(rocshmem_team_t team); diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index 162e5212b6..3b393f0a31 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -201,12 +201,16 @@ void MPITransport::submitRequestsToMPI() { next_element.dst, next_element.src, next_element.ol1.size, next_element.team_comm); break; - case RO_NET_BARRIER_ALL: - barrier(queue_idx, next_element.status, true, ro_net_comm_world); + case RO_NET_BARRIER: + barrier(queue_idx, next_element.status, true, + next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm, + true); DPRINTF("Received Barrier_all\n"); break; case RO_NET_SYNC: - barrier(queue_idx, next_element.status, true, next_element.team_comm); + barrier(queue_idx, next_element.status, true, + next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm, + false); DPRINTF("Received Sync\n"); break; case RO_NET_FENCE: @@ -269,12 +273,19 @@ void MPITransport::global_exit(int status) { } void MPITransport::barrier(int contextId, volatile char *status, bool blocking, - MPI_Comm team) { + MPI_Comm team, bool do_quiet) { MPI_Request request{}; NET_CHECK(MPI_Ibarrier(team, &request)); - requests.push_back({request, {status, contextId, blocking}}); - outstanding[contextId]++; + if (do_quiet) { + requests.push_back({request, {nullptr, contextId, false}}); + outstanding[contextId]++; + + quiet(contextId, status); + } else { + requests.push_back({request, {status, contextId, blocking}}); + outstanding[contextId]++; + } } MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) { @@ -388,7 +399,7 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id, } NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); - barrier(contextId, nullptr, false, comm); + barrier(contextId, nullptr, false, comm, false); quiet(contextId, status); } diff --git a/src/reverse_offload/mpi_transport.hpp b/src/reverse_offload/mpi_transport.hpp index 2129fa29d0..6db776932a 100644 --- a/src/reverse_offload/mpi_transport.hpp +++ b/src/reverse_offload/mpi_transport.hpp @@ -52,7 +52,7 @@ public: rocshmem_team_t *new_team) override; void barrier(int contextId, volatile char *status, bool blocking, - MPI_Comm team) override; + MPI_Comm team, bool quiet) override; void team_reduction(void *dst, void *src, int size, int win_id, int contextId, MPI_Comm team, ROCSHMEM_OP op, diff --git a/src/reverse_offload/transport.hpp b/src/reverse_offload/transport.hpp index e529b5d107..b2bc7fc78a 100644 --- a/src/reverse_offload/transport.hpp +++ b/src/reverse_offload/transport.hpp @@ -51,7 +51,7 @@ class Transport { rocshmem_team_t *new_team) = 0; virtual void barrier(int wg_id, volatile char *status, bool blocking, - MPI_Comm team) = 0; + MPI_Comm team, bool quiet) = 0; virtual void team_reduction(void *dst, void *src, int size, int win_id, int wg_id, MPI_Comm team, ROCSHMEM_OP op, diff --git a/src/rocshmem_gpu.cpp b/src/rocshmem_gpu.cpp index fc78f0931e..07f84c60d4 100644 --- a/src/rocshmem_gpu.cpp +++ b/src/rocshmem_gpu.cpp @@ -580,6 +580,12 @@ __device__ void rocshmem_wg_barrier_all() { rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT); } +__device__ void rocshmem_barrier(rocshmem_team_t team) { + GPU_DPRINTF("Function: rocshmem_barrier\n"); + + get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team); +} + __device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) { GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n"); diff --git a/tests/functional_tests/team_barrier_tester.cpp b/tests/functional_tests/team_barrier_tester.cpp new file mode 100644 index 0000000000..263f2a2c62 --- /dev/null +++ b/tests/functional_tests/team_barrier_tester.cpp @@ -0,0 +1,115 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + + +rocshmem_team_t team_barrier_world_dup; + +/****************************************************************************** + * Device TEST KERNEL + *****************************************************************************/ +__global__ void TeamBarrierTest(int loop, int skip, long long int *start_time, + long long int *end_time, + ShmemContextType ctx_type, + rocshmem_team_t *teams) { + __shared__ rocshmem_ctx_t ctx; + int wg_id = get_flat_grid_id(); + + rocshmem_wg_init(); + rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx); + + int n_pes = rocshmem_ctx_n_pes(ctx); + + __syncthreads(); + + for (int i = 0; i < loop + skip; i++) { + if (i == skip && hipThreadIdx_x == 0) { + start_time[wg_id] = wall_clock64(); + } + + rocshmem_barrier(teams[wg_id]); + } + + __syncthreads(); + + if (hipThreadIdx_x == 0) { + end_time[wg_id] = wall_clock64(); + } + + rocshmem_wg_ctx_destroy(&ctx); + rocshmem_wg_finalize(); +} + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +TeamBarrierTester::TeamBarrierTester(TesterArguments args) + : Tester(args){ + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } + + CHECK_HIP(hipMalloc(&team_barrier_world_dup, + sizeof(rocshmem_team_t) * num_teams)); +} + +TeamBarrierTester::~TeamBarrierTester() { + CHECK_HIP(hipFree(team_barrier_world_dup)); +} + +void TeamBarrierTester::preLaunchKernel() { + for (int team_i = 0; team_i < num_teams; team_i++) { + team_barrier_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_barrier_world_dup[team_i]); + if (team_barrier_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + printf("Team %d is invalid!\n", team_i); + abort(); + } + } +} + +void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, uint64_t size) { + size_t shared_bytes = 0; + + hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize, + shared_bytes, stream, loop, args.skip, + start_time, end_time, _shmem_context, + team_barrier_world_dup); + + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; +} + +void TeamBarrierTester::postLaunchKernel() { + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_barrier_world_dup[team_i]); + } +} + +void TeamBarrierTester::resetBuffers(uint64_t size) {} + +void TeamBarrierTester::verifyResults(uint64_t size) {} diff --git a/tests/functional_tests/team_barrier_tester.hpp b/tests/functional_tests/team_barrier_tester.hpp new file mode 100644 index 0000000000..067895f974 --- /dev/null +++ b/tests/functional_tests/team_barrier_tester.hpp @@ -0,0 +1,66 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _TEAM_BARRIER_TESTER_HPP_ +#define _TEAM_BARRIER_TESTER_HPP_ + +#include +#include + +#include "tester.hpp" + +using namespace rocshmem; + +/************* ***************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class TeamBarrierTester : public Tester { + public: + explicit TeamBarrierTester(TesterArguments args); + virtual ~TeamBarrierTester(); + + protected: + virtual void resetBuffers(uint64_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + uint64_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(uint64_t size) override; + + private: + int my_pe = 0; + int n_pes = 0; + /** + * This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1. + * The default value for the maximum number of teams is 40. + */ + int num_teams = 39; + rocshmem_team_t *team_barrier_world_dup; +}; + +#include "team_barrier_tester.cpp" + +#endif diff --git a/tests/functional_tests/tester.cpp b/tests/functional_tests/tester.cpp index 65e0ea11b6..9bc16a1b73 100644 --- a/tests/functional_tests/tester.cpp +++ b/tests/functional_tests/tester.cpp @@ -45,6 +45,7 @@ #include "sync_tester.hpp" #include "team_alltoall_tester.hpp" #include "team_broadcast_tester.hpp" +#include "team_barrier_tester.hpp" #include "team_ctx_infra_tester.hpp" #include "team_ctx_primitive_tester.hpp" #include "team_fcollect_tester.hpp" @@ -295,6 +296,10 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "Barrier_All ###" << std::endl; testers.push_back(new BarrierAllTester(args)); return testers; + case TeamBarrierTestType: + if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl; + testers.push_back(new TeamBarrierTester(args)); + return testers; case SyncAllTestType: if (rank == 0) std::cout << "SyncAll ###" << std::endl; testers.push_back(new SyncTester(args)); @@ -485,7 +490,8 @@ bool Tester::peLaunchesKernel() { (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || (_type == PingPongTestType) || (_type == BarrierAllTestType) || (_type == SyncTestType) || (_type == SyncAllTestType) || - (_type == RandomAccessTestType) || (_type == PingAllTestType); + (_type == RandomAccessTestType) || (_type == PingAllTestType) || + (_type == TeamBarrierTestType); return is_launcher; } diff --git a/tests/functional_tests/tester.hpp b/tests/functional_tests/tester.hpp index 80a107668b..76871f4b33 100644 --- a/tests/functional_tests/tester.hpp +++ b/tests/functional_tests/tester.hpp @@ -92,6 +92,7 @@ enum TestType { SignalFetchTestType = 55, WGSignalFetchTestType = 56, WAVESignalFetchTestType = 57, + TeamBarrierTestType = 58, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/tests/functional_tests/tester_arguments.cpp b/tests/functional_tests/tester_arguments.cpp index c7258b9f61..c8d8dea516 100644 --- a/tests/functional_tests/tester_arguments.cpp +++ b/tests/functional_tests/tester_arguments.cpp @@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case AMO_IncTestType: case AMO_FetchTestType: case BarrierAllTestType: + case TeamBarrierTestType: case SyncAllTestType: case SyncTestType: case ShmemPtrTestType: @@ -135,7 +136,8 @@ void TesterArguments::get_rocshmem_arguments() { if ((type != BarrierAllTestType) && (type != SyncAllTestType) && (type != SyncTestType) && (type != TeamAllToAllTestType) && (type != TeamFCollectTestType) && (type != TeamReductionTestType) && - (type != TeamBroadcastTestType) && (type != PingAllTestType)) { + (type != TeamBroadcastTestType) && (type != PingAllTestType) && + (type != TeamBarrierTestType)) { if (numprocs != 2) { if (myid == 0) { std::cerr << "This test requires exactly two processes, we have "