add rocshmem_barrier() (#61)
* Add team-barrier implementation: add a team-barrier API and its implementation in the IPC and RO conduits. Clean up some of the logic in the RO conduit to distinguish between sync, sync_all, barrier, and barrier_all. * Add team_barrier_tests to the functional tests.
このコミットが含まれているのは:
@@ -465,6 +465,16 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all(
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the team.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* @param[in] team The team on which to perform barrier synchronization
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ void rocshmem_barrier(rocshmem_team_t);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
|
||||
@@ -80,6 +80,7 @@ declare -A TEST_NUMBERS=(
|
||||
["signalfetch"]="55"
|
||||
["wgsignalfetch"]="56"
|
||||
["wavesignalfetch"]="57"
|
||||
["teambarrier"]="58"
|
||||
)
|
||||
|
||||
ExecTest() {
|
||||
@@ -306,6 +307,7 @@ TestColl() {
|
||||
# | Name | Ranks | Workgroups | Threads | Max Message Size #
|
||||
##############################################################################
|
||||
ExecTest "barrierall" 2 1 1
|
||||
ExecTest "teambarrier" 2 1 1
|
||||
|
||||
ExecTest "sync" 2 1 1
|
||||
|
||||
|
||||
@@ -137,6 +137,8 @@ class Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
@@ -148,6 +148,12 @@ __device__ void Context::barrier_all() {
|
||||
DISPATCH(barrier_all());
|
||||
}
|
||||
|
||||
/**
 * @brief Generic-context entry point for a team barrier.
 *
 * Records the call in the context statistics and then hands barrier(team)
 * to the DISPATCH macro (the macro is defined outside this view; presumably
 * it selects the concrete backend context).
 *
 * @param[in] team Team handle, passed on unchanged.
 */
__device__ void Context::barrier(rocshmem_team_t team) {
  // NOTE(review): the team barrier is accounted under NUM_BARRIER_ALL, i.e.
  // it is not distinguished from barrier_all in the statistics -- confirm
  // that sharing the counter is intended.
  ctxStats.incStat(NUM_BARRIER_ALL);

  DISPATCH(barrier(team));
}
|
||||
|
||||
__device__ void Context::sync_all() {
|
||||
ctxStats.incStat(NUM_SYNC_ALL);
|
||||
|
||||
|
||||
@@ -61,6 +61,8 @@ class IPCContext : public Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
@@ -122,4 +122,20 @@ __device__ void IPCContext::barrier_all() {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/**
 * @brief Team barrier for the IPC conduit.
 *
 * Thread zero of the block calls quiet() first. Afterwards internal_sync()
 * is invoked with values read from the team object (calling PE, start PE,
 * stride, team size and the team's barrier pSync buffer), and the block is
 * joined with __syncthreads().
 *
 * @param[in] team Opaque team handle; reinterpreted as an IPCTeam pointer.
 */
__device__ void IPCContext::barrier(rocshmem_team_t team) {
  IPCTeam *ipc_team = reinterpret_cast<IPCTeam *>(team);

  // Arguments for internal_sync(), all taken from the team object.
  int my_world_pe = ipc_team->my_pe_in_world;
  int first_pe = ipc_team->tinfo_wrt_world->pe_start;
  int stride = ipc_team->tinfo_wrt_world->stride;
  int team_size = ipc_team->num_pes;
  long *sync_buf = ipc_team->barrier_pSync;

  // quiet() is issued by a single thread per block, ahead of the sync.
  if (is_thread_zero_in_block()) {
    quiet();
  }

  internal_sync(my_world_pe, first_pe, stride, team_size, sync_buf);
  __syncthreads();
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -39,7 +39,7 @@ enum ro_net_cmds {
|
||||
RO_NET_FINALIZE,
|
||||
RO_NET_TEAM_REDUCE,
|
||||
RO_NET_SYNC,
|
||||
RO_NET_BARRIER_ALL,
|
||||
RO_NET_BARRIER,
|
||||
RO_NET_TEAM_BROADCAST,
|
||||
RO_NET_ALLTOALL,
|
||||
RO_NET_FCOLLECT,
|
||||
|
||||
@@ -165,16 +165,26 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
|
||||
|
||||
/**
 * @brief Barrier across all PEs for the RO conduit.
 *
 * Thread zero of the block enqueues one RO_NET_BARRIER queue element with a
 * NULL communicator; all threads of the block then meet at __syncthreads().
 * The NULL communicator is expected to be replaced by the world
 * communicator when the RO_NET_BARRIER command is handled on the host side.
 */
__device__ void ROContext::barrier_all() {
  if (is_thread_zero_in_block()) {
    build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
                        nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
                        block_handle, true, get_status_flag());
  }
  __syncthreads();
}
|
||||
|
||||
/**
 * @brief Team barrier for the RO conduit.
 *
 * A single thread per block enqueues an RO_NET_BARRIER queue element that
 * carries the team's MPI communicator; the whole block then joins at
 * __syncthreads().
 *
 * @param[in] team Opaque team handle; reinterpreted as an ROTeam pointer.
 */
__device__ void ROContext::barrier(rocshmem_team_t team) {
  ROTeam *ro_team = reinterpret_cast<ROTeam *>(team);

  // Only thread zero of the block builds the queue element.
  const bool is_enqueuer = is_thread_zero_in_block();
  if (is_enqueuer) {
    build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
                        nullptr, nullptr, ro_team->mpi_comm, ro_net_win_id,
                        block_handle, true, get_status_flag());
  }

  __syncthreads();
}
|
||||
|
||||
__device__ void ROContext::sync_all() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true, get_status_flag());
|
||||
}
|
||||
|
||||
@@ -65,6 +65,8 @@ class ROContext : public Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
@@ -201,12 +201,16 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
break;
|
||||
case RO_NET_BARRIER_ALL:
|
||||
barrier(queue_idx, next_element.status, true, ro_net_comm_world);
|
||||
case RO_NET_BARRIER:
|
||||
barrier(queue_idx, next_element.status, true,
|
||||
next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm,
|
||||
true);
|
||||
DPRINTF("Received Barrier_all\n");
|
||||
break;
|
||||
case RO_NET_SYNC:
|
||||
barrier(queue_idx, next_element.status, true, next_element.team_comm);
|
||||
barrier(queue_idx, next_element.status, true,
|
||||
next_element.team_comm == NULL ? ro_net_comm_world : next_element.team_comm,
|
||||
false);
|
||||
DPRINTF("Received Sync\n");
|
||||
break;
|
||||
case RO_NET_FENCE:
|
||||
@@ -269,12 +273,19 @@ void MPITransport::global_exit(int status) {
|
||||
}
|
||||
|
||||
void MPITransport::barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team) {
|
||||
MPI_Comm team, bool do_quiet) {
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Ibarrier(team, &request));
|
||||
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
outstanding[contextId]++;
|
||||
if (do_quiet) {
|
||||
requests.push_back({request, {nullptr, contextId, false}});
|
||||
outstanding[contextId]++;
|
||||
|
||||
quiet(contextId, status);
|
||||
} else {
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) {
|
||||
@@ -388,7 +399,7 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
}
|
||||
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
barrier(contextId, nullptr, false, comm);
|
||||
barrier(contextId, nullptr, false, comm, false);
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ public:
|
||||
rocshmem_team_t *new_team) override;
|
||||
|
||||
void barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team) override;
|
||||
MPI_Comm team, bool quiet) override;
|
||||
|
||||
void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
|
||||
@@ -51,7 +51,7 @@ class Transport {
|
||||
rocshmem_team_t *new_team) = 0;
|
||||
|
||||
virtual void barrier(int wg_id, volatile char *status, bool blocking,
|
||||
MPI_Comm team) = 0;
|
||||
MPI_Comm team, bool quiet) = 0;
|
||||
|
||||
virtual void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
|
||||
|
||||
@@ -580,6 +580,12 @@ __device__ void rocshmem_wg_barrier_all() {
|
||||
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
/**
 * @brief Public team-barrier entry point.
 *
 * Emits a debug trace and forwards the team to barrier() on the internal
 * context obtained for ROCSHMEM_CTX_DEFAULT.
 *
 * @param[in] team Team on which the barrier is performed.
 */
__device__ void rocshmem_barrier(rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_barrier\n");

  // The barrier always goes through the default context.
  auto &&default_ctx = get_internal_ctx(ROCSHMEM_CTX_DEFAULT);
  default_ctx->barrier(team);
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
|
||||
// NOTE(review): file-scope team handle. The tester methods below index
// `team_barrier_world_dup` as an array and pass its address to hipMalloc,
// so they presumably resolve to the class member of the same name rather
// than to this object -- confirm whether this global is still needed.
rocshmem_team_t team_barrier_world_dup;
|
||||
|
||||
/******************************************************************************
|
||||
* Device TEST KERNEL
|
||||
*****************************************************************************/
|
||||
/**
 * @brief Timing kernel for rocshmem_barrier() on a team.
 *
 * Each workgroup initializes rocSHMEM, creates a context on the team stored
 * at its own index in `teams`, and then calls rocshmem_barrier() on that
 * team `loop + skip` times. Thread 0 of the workgroup records wall_clock64()
 * into start_time when iteration `skip` begins and into end_time after the
 * loop, so the first `skip` iterations are excluded from the timed interval.
 *
 * @param loop       Number of timed barrier iterations.
 * @param skip       Number of leading iterations left out of the timing.
 * @param start_time Per-workgroup start timestamps, indexed by the value of
 *                   get_flat_grid_id().
 * @param end_time   Per-workgroup end timestamps, indexed the same way.
 * @param ctx_type   Context type passed to rocshmem_wg_team_create_ctx().
 * @param teams      Team handles, indexed by the value of
 *                   get_flat_grid_id(); assumes one valid entry exists for
 *                   every index the launch produces -- TODO confirm against
 *                   the launch configuration.
 */
__global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
                                long long int *end_time,
                                ShmemContextType ctx_type,
                                rocshmem_team_t *teams) {
  __shared__ rocshmem_ctx_t ctx;
  int wg_id = get_flat_grid_id();

  rocshmem_wg_init();
  rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);

  __syncthreads();

  for (int i = 0; i < loop + skip; i++) {
    // Start the clock once the untimed leading iterations are done.
    if (i == skip && hipThreadIdx_x == 0) {
      start_time[wg_id] = wall_clock64();
    }

    rocshmem_barrier(teams[wg_id]);
  }

  __syncthreads();

  if (hipThreadIdx_x == 0) {
    end_time[wg_id] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
/**
 * @brief Set up the tester: query PE info, size the team array, allocate it.
 *
 * Reads this PE's index and the PE count of ROCSHMEM_TEAM_WORLD, optionally
 * overrides `num_teams` from the ROCSHMEM_MAX_NUM_TEAMS environment
 * variable, and allocates one rocshmem_team_t slot per team with hipMalloc.
 *
 * @param args Tester arguments forwarded to the Tester base class.
 */
TeamBarrierTester::TeamBarrierTester(TesterArguments args)
    : Tester(args) {
  my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
  n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  // atoi() yields 0 for text that is not a number; a zero or negative count
  // would produce an empty or wrapped-around allocation size, so only a
  // strictly positive value replaces the default.
  // NOTE(review): the class comment on num_teams asks for
  // ROCSHMEM_MAX_NUM_TEAMS - 1, while the environment value is used as is --
  // confirm which is intended.
  if (const char *value = getenv("ROCSHMEM_MAX_NUM_TEAMS")) {
    const int requested = atoi(value);
    if (requested > 0) {
      num_teams = requested;
    }
  }

  CHECK_HIP(hipMalloc(&team_barrier_world_dup,
                      sizeof(rocshmem_team_t) * num_teams));
}
|
||||
|
||||
/**
 * @brief Release the team-handle array obtained with hipMalloc in the
 *        constructor.
 */
TeamBarrierTester::~TeamBarrierTester() {
  CHECK_HIP(hipFree(team_barrier_world_dup));
}
|
||||
|
||||
void TeamBarrierTester::preLaunchKernel() {
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_barrier_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_barrier_world_dup[team_i]);
|
||||
if (team_barrier_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
printf("Team %d is invalid!\n", team_i);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Launch the TeamBarrierTest kernel and record the message counts.
 *
 * @param gridSize  Grid dimensions of the launch; gridSize.x scales the
 *                  message counters.
 * @param blockSize Block dimensions of the launch.
 * @param loop      Number of timed iterations handed to the kernel.
 * @param size      Not used by this tester.
 */
void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize,
                                     int loop, uint64_t size) {
  // No dynamic shared memory is requested for the launch.
  size_t dynamic_shared_bytes = 0;

  hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize,
                     dynamic_shared_bytes, stream, loop, args.skip,
                     start_time, end_time, _shmem_context,
                     team_barrier_world_dup);

  // Counters scale with gridSize.x; the skipped iterations are included in
  // the total but not in the timed count.
  const auto grid_x = gridSize.x;
  num_msgs = (loop + args.skip) * grid_x;
  num_timed_msgs = loop * grid_x;
}
|
||||
|
||||
void TeamBarrierTester::postLaunchKernel() {
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_barrier_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally empty: this tester does nothing when asked to reset buffers;
// `size` is ignored.
void TeamBarrierTester::resetBuffers(uint64_t size) {}
|
||||
|
||||
// Intentionally empty: no result checking is performed by this tester;
// `size` is ignored.
void TeamBarrierTester::verifyResults(uint64_t size) {}
|
||||
@@ -0,0 +1,66 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef _TEAM_BARRIER_TESTER_HPP_
|
||||
#define _TEAM_BARRIER_TESTER_HPP_
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "tester.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class TeamBarrierTester : public Tester {
|
||||
public:
|
||||
explicit TeamBarrierTester(TesterArguments args);
|
||||
virtual ~TeamBarrierTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
|
||||
virtual void preLaunchKernel() override;
|
||||
|
||||
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) override;
|
||||
|
||||
virtual void postLaunchKernel() override;
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
private:
|
||||
int my_pe = 0;
|
||||
int n_pes = 0;
|
||||
/**
|
||||
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
|
||||
* The default value for the maximum number of teams is 40.
|
||||
*/
|
||||
int num_teams = 39;
|
||||
rocshmem_team_t *team_barrier_world_dup;
|
||||
};
|
||||
|
||||
#include "team_barrier_tester.cpp"
|
||||
|
||||
#endif
|
||||
@@ -45,6 +45,7 @@
|
||||
#include "sync_tester.hpp"
|
||||
#include "team_alltoall_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
#include "team_barrier_tester.hpp"
|
||||
#include "team_ctx_infra_tester.hpp"
|
||||
#include "team_ctx_primitive_tester.hpp"
|
||||
#include "team_fcollect_tester.hpp"
|
||||
@@ -295,6 +296,10 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "Barrier_All ###" << std::endl;
|
||||
testers.push_back(new BarrierAllTester(args));
|
||||
return testers;
|
||||
case TeamBarrierTestType:
|
||||
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
|
||||
testers.push_back(new TeamBarrierTester(args));
|
||||
return testers;
|
||||
case SyncAllTestType:
|
||||
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
@@ -485,7 +490,8 @@ bool Tester::peLaunchesKernel() {
|
||||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
|
||||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
|
||||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType);
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
|
||||
(_type == TeamBarrierTestType);
|
||||
|
||||
return is_launcher;
|
||||
}
|
||||
|
||||
@@ -92,6 +92,7 @@ enum TestType {
|
||||
SignalFetchTestType = 55,
|
||||
WGSignalFetchTestType = 56,
|
||||
WAVESignalFetchTestType = 57,
|
||||
TeamBarrierTestType = 58,
|
||||
};
|
||||
|
||||
enum OpType { PutType = 0, GetType = 1 };
|
||||
|
||||
@@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case AMO_IncTestType:
|
||||
case AMO_FetchTestType:
|
||||
case BarrierAllTestType:
|
||||
case TeamBarrierTestType:
|
||||
case SyncAllTestType:
|
||||
case SyncTestType:
|
||||
case ShmemPtrTestType:
|
||||
@@ -135,7 +136,8 @@ void TesterArguments::get_rocshmem_arguments() {
|
||||
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
|
||||
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
|
||||
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType)) {
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
|
||||
(type != TeamBarrierTestType)) {
|
||||
if (numprocs != 2) {
|
||||
if (myid == 0) {
|
||||
std::cerr << "This test requires exactly two processes, we have "
|
||||
|
||||
新しいイシューから参照
ユーザーをブロックする