add rocshmem_barrier() (#61)

* add team-barrier implementation

Add a team-barrier API and its implementations in the IPC and RO conduits.
Clean up some of the logic in the RO conduit to distinguish between
sync, sync_all, barrier, and barrier_all.

* add team_barrier_tests to functional tests
このコミットが含まれているのは:
Edgar Gabriel
2025-03-24 11:23:03 -05:00
committed by GitHub
コミット bcbc42e78f
18個のファイルの変更271行の追加14行の削除
+10
ファイルの表示
@@ -465,6 +465,16 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all(
__device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
/**
* @brief Perform a collective barrier between all PEs in the team.
* The caller is blocked until the barrier is resolved.
*
* The call is carried out on the default context (ROCSHMEM_CTX_DEFAULT).
* The IPC and RO conduit implementations both finish with a workgroup-wide
* __syncthreads(), so every thread of the calling workgroup is expected to
* reach this call.
*
* @param[in] team The team on which to perform barrier synchronization
*
* @return void
*/
__device__ void rocshmem_barrier(rocshmem_team_t);
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
+2
ファイルの表示
@@ -80,6 +80,7 @@ declare -A TEST_NUMBERS=(
["signalfetch"]="55"
["wgsignalfetch"]="56"
["wavesignalfetch"]="57"
["teambarrier"]="58"
)
ExecTest() {
@@ -306,6 +307,7 @@ TestColl() {
# | Name | Ranks | Workgroups | Threads | Max Message Size #
##############################################################################
ExecTest "barrierall" 2 1 1
ExecTest "teambarrier" 2 1 1
ExecTest "sync" 2 1 1
+2
ファイルの表示
@@ -137,6 +137,8 @@ class Context {
__device__ void barrier_all();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
+6
ファイルの表示
@@ -148,6 +148,12 @@ __device__ void Context::barrier_all() {
DISPATCH(barrier_all());
}
/**
 * Team-scoped barrier on the generic Context: records a statistic and then
 * forwards the call to the backend-specific barrier(team) through the
 * DISPATCH macro.
 *
 * @param team Handle of the team whose PEs take part in the barrier.
 */
__device__ void Context::barrier(rocshmem_team_t team) {
// NOTE(review): the counter incremented here is NUM_BARRIER_ALL even though
// this is the team barrier - confirm whether a dedicated counter is intended.
ctxStats.incStat(NUM_BARRIER_ALL);
DISPATCH(barrier(team));
}
__device__ void Context::sync_all() {
ctxStats.incStat(NUM_SYNC_ALL);
+2
ファイルの表示
@@ -61,6 +61,8 @@ class IPCContext : public Context {
__device__ void barrier_all();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
+16
ファイルの表示
@@ -122,4 +122,20 @@ __device__ void IPCContext::barrier_all() {
__syncthreads();
}
/**
 * @brief Team barrier for the IPC conduit.
 *
 * Reinterprets the opaque team handle as an IPCTeam, has thread zero of the
 * block issue quiet(), then runs internal_sync() over the team's PE layout
 * using the team's barrier pSync array. The routine ends with a block-wide
 * __syncthreads(), so all threads of the block must call it.
 *
 * @param team Handle of the team to synchronize.
 */
__device__ void IPCContext::barrier(rocshmem_team_t team) {
  IPCTeam *ipc_team = reinterpret_cast<IPCTeam *>(team);

  // Snapshot the team geometry (expressed relative to the world team) and
  // the synchronization array before any thread starts the barrier.
  int world_pe = ipc_team->my_pe_in_world;
  int first_pe = ipc_team->tinfo_wrt_world->pe_start;
  int stride = ipc_team->tinfo_wrt_world->stride;
  int team_size = ipc_team->num_pes;
  long *sync_array = ipc_team->barrier_pSync;

  // Only one thread per block issues quiet().
  if (is_thread_zero_in_block()) {
    quiet();
  }

  internal_sync(world_pe, first_pe, stride, team_size, sync_array);

  __syncthreads();
}
} // namespace rocshmem
+1 -1
ファイルの表示
@@ -39,7 +39,7 @@ enum ro_net_cmds {
RO_NET_FINALIZE,
RO_NET_TEAM_REDUCE,
RO_NET_SYNC,
RO_NET_BARRIER_ALL,
RO_NET_BARRIER,
RO_NET_TEAM_BROADCAST,
RO_NET_ALLTOALL,
RO_NET_FCOLLECT,
+12 -2
ファイルの表示
@@ -165,16 +165,26 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
__device__ void ROContext::barrier_all() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag());
}
__syncthreads();
}
/**
 * @brief Team barrier for the RO conduit.
 *
 * Thread zero of the block enqueues a single RO_NET_BARRIER command that
 * carries the team's MPI communicator; the enqueue is issued with the
 * blocking flag set to true. All threads of the block then meet at
 * __syncthreads(), so every thread of the block must call this routine.
 *
 * @param team Handle of the team to synchronize.
 */
__device__ void ROContext::barrier(rocshmem_team_t team) {
  ROTeam *ro_team = reinterpret_cast<ROTeam *>(team);

  if (is_thread_zero_in_block()) {
    // No buffers, sizes or PE arguments are needed for a barrier; only the
    // communicator, window id, queue handle and status flag are meaningful.
    build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
                        nullptr, nullptr, ro_team->mpi_comm, ro_net_win_id,
                        block_handle, true, get_status_flag());
  }

  __syncthreads();
}
__device__ void ROContext::sync_all() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag());
}
+2
ファイルの表示
@@ -65,6 +65,8 @@ class ROContext : public Context {
__device__ void barrier_all();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
+18 -7
ファイルの表示
@@ -201,12 +201,16 @@ void MPITransport::submitRequestsToMPI() {
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
break;
case RO_NET_BARRIER_ALL:
barrier(queue_idx, next_element.status, true, ro_net_comm_world);
case RO_NET_BARRIER:
barrier(queue_idx, next_element.status, true,
next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm,
true);
DPRINTF("Received Barrier_all\n");
break;
case RO_NET_SYNC:
barrier(queue_idx, next_element.status, true, next_element.team_comm);
barrier(queue_idx, next_element.status, true,
next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm,
false);
DPRINTF("Received Sync\n");
break;
case RO_NET_FENCE:
@@ -269,12 +273,19 @@ void MPITransport::global_exit(int status) {
}
void MPITransport::barrier(int contextId, volatile char *status, bool blocking,
MPI_Comm team) {
MPI_Comm team, bool do_quiet) {
MPI_Request request{};
NET_CHECK(MPI_Ibarrier(team, &request));
requests.push_back({request, {status, contextId, blocking}});
outstanding[contextId]++;
if (do_quiet) {
requests.push_back({request, {nullptr, contextId, false}});
outstanding[contextId]++;
quiet(contextId, status);
} else {
requests.push_back({request, {status, contextId, blocking}});
outstanding[contextId]++;
}
}
MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) {
@@ -388,7 +399,7 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
}
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
barrier(contextId, nullptr, false, comm);
barrier(contextId, nullptr, false, comm, false);
quiet(contextId, status);
}
+1 -1
ファイルの表示
@@ -52,7 +52,7 @@ public:
rocshmem_team_t *new_team) override;
void barrier(int contextId, volatile char *status, bool blocking,
MPI_Comm team) override;
MPI_Comm team, bool quiet) override;
void team_reduction(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, ROCSHMEM_OP op,
+1 -1
ファイルの表示
@@ -51,7 +51,7 @@ class Transport {
rocshmem_team_t *new_team) = 0;
virtual void barrier(int wg_id, volatile char *status, bool blocking,
MPI_Comm team) = 0;
MPI_Comm team, bool quiet) = 0;
virtual void team_reduction(void *dst, void *src, int size, int win_id,
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
+6
ファイルの表示
@@ -580,6 +580,12 @@ __device__ void rocshmem_wg_barrier_all() {
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
}
/**
 * Public device entry point for the team barrier: logs the call and runs
 * barrier(team) on the internal context behind ROCSHMEM_CTX_DEFAULT.
 */
__device__ void rocshmem_barrier(rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_barrier\n");

  auto default_ctx = get_internal_ctx(ROCSHMEM_CTX_DEFAULT);
  default_ctx->barrier(team);
}
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
+115
ファイルの表示
@@ -0,0 +1,115 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
rocshmem_team_t team_barrier_world_dup;
/******************************************************************************
* Device TEST KERNEL
*****************************************************************************/
/**
 * Timing kernel for rocshmem_barrier().
 *
 * Each workgroup initializes rocSHMEM, creates a context on its own team
 * (teams[wg_id]) and then calls rocshmem_barrier() on that team loop + skip
 * times. Thread 0 of the workgroup stores a start timestamp when iteration
 * `skip` begins (so the first `skip` iterations act as warm-up) and an end
 * timestamp after the loop.
 *
 * @param loop       Number of timed barrier iterations.
 * @param skip       Number of untimed warm-up iterations run first.
 * @param start_time Per-workgroup start timestamps, indexed by wg_id.
 * @param end_time   Per-workgroup end timestamps, indexed by wg_id.
 * @param ctx_type   Options used when creating the per-workgroup context.
 * @param teams      Team handles, indexed by wg_id.
 *
 * NOTE(review): teams, start_time and end_time are indexed by
 * get_flat_grid_id() without a bounds check - presumably the launcher
 * guarantees that each holds at least one entry per workgroup; confirm.
 * start_time[wg_id] is written only if the loop reaches iteration `skip`,
 * i.e. it stays untouched when loop is 0.
 */
__global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
                                long long int *end_time,
                                ShmemContextType ctx_type,
                                rocshmem_team_t *teams) {
  __shared__ rocshmem_ctx_t ctx;

  int wg_id = get_flat_grid_id();

  rocshmem_wg_init();
  rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);

  // The previously computed (and never used) PE count has been dropped.
  __syncthreads();

  for (int i = 0; i < loop + skip; i++) {
    // Start the clock once the warm-up iterations are done.
    if (i == skip && hipThreadIdx_x == 0) {
      start_time[wg_id] = wall_clock64();
    }

    // Called by every thread of the workgroup.
    rocshmem_barrier(teams[wg_id]);
  }

  __syncthreads();

  if (hipThreadIdx_x == 0) {
    end_time[wg_id] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * Builds the tester: caches this PE's rank and the PE count of
 * ROCSHMEM_TEAM_WORLD, sizes the team table, and allocates it with hipMalloc.
 *
 * num_teams is documented in the class as ROCSHMEM_MAX_NUM_TEAMS - 1. When
 * the ROCSHMEM_MAX_NUM_TEAMS environment variable is set, the same relation
 * is applied to the parsed value. Values that do not leave room for at least
 * one team (unparsable text, zero, negative numbers, or 1) are ignored and
 * the compiled-in default is kept, so the table is never allocated with a
 * zero or negative element count.
 */
TeamBarrierTester::TeamBarrierTester(TesterArguments args)
    : Tester(args) {
  my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
  n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  char *value{nullptr};
  if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
    const int max_num_teams = atoi(value);
    // Keep the documented "maximum minus one" relation and reject values
    // that would yield an empty or negative-sized team table.
    if (max_num_teams > 1) {
      num_teams = max_num_teams - 1;
    }
  }

  CHECK_HIP(hipMalloc(&team_barrier_world_dup,
                      sizeof(rocshmem_team_t) * num_teams));
}
/**
 * Releases the team table that the constructor allocated with hipMalloc.
 * The team handles themselves are destroyed in postLaunchKernel().
 */
TeamBarrierTester::~TeamBarrierTester() {
CHECK_HIP(hipFree(team_barrier_world_dup));
}
void TeamBarrierTester::preLaunchKernel() {
for (int team_i = 0; team_i < num_teams; team_i++) {
team_barrier_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_barrier_world_dup[team_i]);
if (team_barrier_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
printf("Team %d is invalid!\n", team_i);
abort();
}
}
}
/**
 * Launches the TeamBarrierTest kernel on the tester's stream and records how
 * many barrier calls were issued in total and how many of them are timed.
 *
 * @param gridSize  Grid dimensions of the launch; gridSize.x scales the
 *                  message counters.
 * @param blockSize Block dimensions of the launch.
 * @param loop      Number of timed iterations per workgroup.
 * @param size      Unused by this test.
 */
void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize,
                                     int loop, uint64_t size) {
  size_t shared_bytes = 0;

  hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize,
                     shared_bytes, stream, loop, args.skip,
                     start_time, end_time, _shmem_context,
                     team_barrier_world_dup);
  // A kernel launch does not return a status itself; pick up launch
  // configuration errors here instead of letting them go unnoticed.
  CHECK_HIP(hipGetLastError());

  num_msgs = (loop + args.skip) * gridSize.x;
  num_timed_msgs = loop * gridSize.x;
}
/**
 * Destroys every team stored in the team table, in slot order, after the
 * kernel has been launched.
 */
void TeamBarrierTester::postLaunchKernel() {
  int idx = 0;
  while (idx < num_teams) {
    rocshmem_team_destroy(team_barrier_world_dup[idx]);
    ++idx;
  }
}
/** Intentionally empty: this method does nothing for the barrier test. */
void TeamBarrierTester::resetBuffers(uint64_t size) {}
/** Intentionally empty: this method performs no result verification. */
void TeamBarrierTester::verifyResults(uint64_t size) {}
+66
ファイルの表示
@@ -0,0 +1,66 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef _TEAM_BARRIER_TESTER_HPP_
#define _TEAM_BARRIER_TESTER_HPP_
#include <functional>
#include <utility>
#include "tester.hpp"
using namespace rocshmem;
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
class TeamBarrierTester : public Tester {
public:
explicit TeamBarrierTester(TesterArguments args);
virtual ~TeamBarrierTester();
protected:
virtual void resetBuffers(uint64_t size) override;
virtual void preLaunchKernel() override;
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) override;
virtual void postLaunchKernel() override;
virtual void verifyResults(uint64_t size) override;
private:
int my_pe = 0;
int n_pes = 0;
/**
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
* The default value for the maximum number of teams is 40.
*/
int num_teams = 39;
rocshmem_team_t *team_barrier_world_dup;
};
#include "team_barrier_tester.cpp"
#endif
+7 -1
ファイルの表示
@@ -45,6 +45,7 @@
#include "sync_tester.hpp"
#include "team_alltoall_tester.hpp"
#include "team_broadcast_tester.hpp"
#include "team_barrier_tester.hpp"
#include "team_ctx_infra_tester.hpp"
#include "team_ctx_primitive_tester.hpp"
#include "team_fcollect_tester.hpp"
@@ -295,6 +296,10 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "Barrier_All ###" << std::endl;
testers.push_back(new BarrierAllTester(args));
return testers;
case TeamBarrierTestType:
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
testers.push_back(new TeamBarrierTester(args));
return testers;
case SyncAllTestType:
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
testers.push_back(new SyncTester(args));
@@ -485,7 +490,8 @@ bool Tester::peLaunchesKernel() {
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
(_type == RandomAccessTestType) || (_type == PingAllTestType);
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
(_type == TeamBarrierTestType);
return is_launcher;
}
+1
ファイルの表示
@@ -92,6 +92,7 @@ enum TestType {
SignalFetchTestType = 55,
WGSignalFetchTestType = 56,
WAVESignalFetchTestType = 57,
TeamBarrierTestType = 58,
};
enum OpType { PutType = 0, GetType = 1 };
+3 -1
ファイルの表示
@@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case AMO_IncTestType:
case AMO_FetchTestType:
case BarrierAllTestType:
case TeamBarrierTestType:
case SyncAllTestType:
case SyncTestType:
case ShmemPtrTestType:
@@ -135,7 +136,8 @@ void TesterArguments::get_rocshmem_arguments() {
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
(type != TeamBroadcastTestType) && (type != PingAllTestType)) {
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
(type != TeamBarrierTestType)) {
if (numprocs != 2) {
if (myid == 0) {
std::cerr << "This test requires exactly two processes, we have "