Add host API for alltoallmem_on_stream collective operation (#333)
* Add host-side rocshmem_alltoallmem_on_stream function
Function signature:
rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream)
- The function launches rocshmem_alltoallmem_kernel which calls
device-side alltoall<char> workgroup collective through default context.
- Uses dynamic block size determination via occupancy API.
- Implemented for all backends.
* Fix incorrect sync buffer size allocation for alltoall in GDA and IPC backends
When allocating memory for alltoall_pSync_pool in setup_teams() and
teams_init() functions, the code incorrectly used ROCSHMEM_BCAST_SYNC_SIZE
instead of ROCSHMEM_ALLTOALL_SYNC_SIZE.
* Add functional test for team_alltoallmem_on_stream
This commit adds a new functional test to verify the correctness of
the host-side rocshmem_team_alltoallmem_on_stream API.
* Add documentation for rocshmem_alltoallmem_on_stream
This commit adds API documentation for the host-side
rocshmem_alltoallmem_on_stream function in the collective routines
section. The documentation includes:
This commit is contained in:
committad av
GitHub
förälder
8b350a51fe
incheckning
5577feb70d
@@ -88,6 +88,29 @@ This function must be called as a work-group collective.
|
||||
|
||||
Valid ``TYPENAME`` and ``TYPE`` values are listed in :ref:`RMA_TYPES`.
|
||||
|
||||
ROCSHMEM_ALLTOALLMEM_ON_STREAM
|
||||
-------------------------------
|
||||
|
||||
.. cpp:function:: __host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream)
|
||||
|
||||
:param team: The team participating in the collective.
|
||||
:param dest: Destination address. Must be an address on the symmetric heap.
|
||||
:param source: Source address. Must be an address on the symmetric heap.
|
||||
:param size: Number of bytes to transfer per pair of PEs.
|
||||
:param stream: HIP stream on which to enqueue the operation.
|
||||
:returns: None.
|
||||
|
||||
**Description:**
|
||||
This routine enqueues an alltoall collective operation on a HIP stream. The function
|
||||
exchanges a fixed amount of contiguous data blocks between all pairs of PEs participating
|
||||
in the collective routine. The operation is enqueued on the specified stream and will
|
||||
execute asynchronously. The caller must synchronize the stream (e.g., using
|
||||
``hipStreamSynchronize``) to ensure completion.
|
||||
|
||||
This function creates a separate context for each workgroup to avoid contention on the
|
||||
default context, allowing parallel execution across multiple streams.
|
||||
If ``stream`` is ``nullptr``, the operation will use ``hipStreamDefault``.
|
||||
|
||||
ROCSHMEM_BROADCAST
|
||||
------------------
|
||||
|
||||
|
||||
@@ -349,6 +349,22 @@ __host__ void rocshmem_barrier_all();
|
||||
*/
|
||||
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
/**
|
||||
* @brief enqueues an alltoall collective operation on given stream.
|
||||
*
|
||||
* @param[in] team The team participating in the collective.
|
||||
* @param[in] dest Destination address. Must be an address on the symmetric
|
||||
* heap.
|
||||
* @param[in] source Source address. Must be an address on the symmetric heap.
|
||||
* @param[in] size Number of bytes to transfer per pair of PEs.
|
||||
* @param[in] stream HIP stream on which to enqueue the operation.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
|
||||
@@ -607,6 +607,23 @@ __host__ int rocshmem_ctx_double_prod_reduce(
|
||||
*/
|
||||
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
|
||||
|
||||
/**
|
||||
* @brief kernel for performing an alltoall collective operation.
|
||||
* Caller enqueues the kernel on given stream
|
||||
*
|
||||
* @param[in] team The team participating in the collective.
|
||||
* @param[in] dest Destination address. Must be an address on the symmetric
|
||||
* heap.
|
||||
* @param[in] source Source address. Must be an address on the symmetric heap.
|
||||
* @param[in] size Number of bytes to transfer per pair of PEs.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size);
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
|
||||
@@ -109,6 +109,7 @@ declare -A TEST_NUMBERS=(
|
||||
["teamctxsingleinfra"]="73"
|
||||
["teamctxblockinfra"]="74"
|
||||
["teamctxoddeveninfra"]="75"
|
||||
["alltoallmem_on_stream"]="76"
|
||||
)
|
||||
|
||||
ExecTest() {
|
||||
@@ -428,6 +429,8 @@ TestColl() {
|
||||
ExecTest "fcollect" 2 1 1 32768
|
||||
|
||||
ExecTest "teamreduction" 2 1 1 32768
|
||||
|
||||
ExecTest "alltoallmem_on_stream" 2 1 1 32768
|
||||
}
|
||||
|
||||
TestOther() {
|
||||
|
||||
@@ -397,6 +397,10 @@ class Context {
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -122,4 +122,12 @@ __host__ void Context::barrier_all_on_stream(hipStream_t stream) {
|
||||
HOST_DISPATCH(barrier_all_on_stream(stream));
|
||||
}
|
||||
|
||||
__host__ void Context::alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream) {
|
||||
ctxHostStats.incStat(NUM_HOST_ALLTOALL);
|
||||
|
||||
HOST_DISPATCH(alltoallmem_on_stream(team, dest, source, size, stream));
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -453,8 +453,8 @@ void GDABackend::setup_teams() {
|
||||
* max_num_teams;
|
||||
|
||||
alltoall_pSync_pool = reinterpret_cast<long *>(wrk_sync_pool_top_);
|
||||
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE
|
||||
* max_num_teams;
|
||||
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE *
|
||||
max_num_teams;
|
||||
|
||||
/* Accommodating for largest possible data type for pWrk */
|
||||
pWrk_pool = reinterpret_cast<void *>(wrk_sync_pool_top_);
|
||||
|
||||
@@ -113,4 +113,12 @@ __host__ void GDAHostContext::barrier_all() {
|
||||
host_interface->barrier_all(context_window_info);
|
||||
}
|
||||
|
||||
__host__ void GDAHostContext::alltoallmem_on_stream(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -82,6 +82,10 @@ class GDAHostContext : public Context {
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -333,6 +333,36 @@ __host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) {
|
||||
rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>();
|
||||
}
|
||||
|
||||
__host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
// launch kernel to do alltoall with given stream, if none, use default stream
|
||||
if (stream == nullptr) {
|
||||
stream = hipStreamDefault;
|
||||
}
|
||||
|
||||
// Use dynamic block size determination:
|
||||
// - Query optimal block size using occupancy API
|
||||
// - Limit block size to size (number of bytes) to avoid over-subscription
|
||||
// - Always use 1 block (single workgroup collective)
|
||||
int optimal_block_size = 0;
|
||||
int grid_size = 0;
|
||||
CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, &optimal_block_size,
|
||||
rocshmem_alltoallmem_kernel, 0,
|
||||
0));
|
||||
|
||||
// Limit block size to size (bytes) to avoid over-subscription
|
||||
int num_threads_per_block = (optimal_block_size > static_cast<int>(size))
|
||||
? static_cast<int>(size)
|
||||
: optimal_block_size;
|
||||
|
||||
dim3 gridSize(1);
|
||||
dim3 blockSize(num_threads_per_block);
|
||||
rocshmem_alltoallmem_kernel<<<gridSize, blockSize, 0, stream>>>(team, dest,
|
||||
source, size);
|
||||
}
|
||||
|
||||
__host__ void HostInterface::barrier_for_sync() {
|
||||
if (host_comm_world_ != MPI_COMM_NULL) {
|
||||
|
||||
@@ -196,6 +196,10 @@ class HostInterface {
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
__host__ void barrier_for_sync();
|
||||
|
||||
__host__ void sync_all(WindowInfo* window_info);
|
||||
|
||||
@@ -482,8 +482,8 @@ void IPCBackend::teams_init() {
|
||||
* max_num_teams;
|
||||
|
||||
alltoall_pSync_pool = reinterpret_cast<long *>(wrk_sync_pool_top_);
|
||||
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE
|
||||
* max_num_teams;
|
||||
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE *
|
||||
max_num_teams;
|
||||
|
||||
/* Accommodating for largest possible data type for pWrk */
|
||||
pWrk_pool = reinterpret_cast<void *>(wrk_sync_pool_top_);
|
||||
|
||||
@@ -105,4 +105,12 @@ __host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) {
|
||||
host_interface->barrier_all_on_stream(stream);
|
||||
}
|
||||
|
||||
__host__ void IPCHostContext::alltoallmem_on_stream(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -84,6 +84,10 @@ class IPCHostContext : public Context {
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -133,4 +133,14 @@ __host__ void ROHostContext::barrier_all() {
|
||||
host_interface->barrier_for_sync();
|
||||
}
|
||||
|
||||
__host__ void ROHostContext::alltoallmem_on_stream(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size,
|
||||
hipStream_t stream) {
|
||||
DPRINTF("Function: ro_net_host_alltoallmem_on_stream\n");
|
||||
|
||||
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -131,6 +131,10 @@ class ROHostContext : public Context {
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -999,6 +999,15 @@ __host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) {
|
||||
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all_on_stream(stream);
|
||||
}
|
||||
|
||||
__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
|
||||
const void *source, size_t size,
|
||||
hipStream_t stream) {
|
||||
DPRINTF("Host function: rocshmem_alltoallmem_on_stream\n");
|
||||
|
||||
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)
|
||||
->alltoallmem_on_stream(team, dest, source, size, stream);
|
||||
}
|
||||
|
||||
__host__ void rocshmem_sync_all() {
|
||||
DPRINTF("Host function: rocshmem_sync_all\n");
|
||||
|
||||
|
||||
@@ -648,6 +648,32 @@ __global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(){
|
||||
rocshmem_barrier_all();
|
||||
}
|
||||
|
||||
__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
|
||||
void *dest,
|
||||
const void *source,
|
||||
size_t size) {
|
||||
// Create a context for this workgroup to avoid contention on default context
|
||||
// This allows parallel execution across multiple streams without serialization
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
__shared__ int ctx_result;
|
||||
|
||||
ctx_result = rocshmem_wg_team_create_ctx(team, 0, &ctx);
|
||||
|
||||
// If context creation failed, fall back to default context
|
||||
if (ctx_result != 0) {
|
||||
ctx = ROCSHMEM_CTX_DEFAULT;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Call device alltoall function with created context and provided team
|
||||
// Using char type since size is in bytes (1 byte per element)
|
||||
rocshmem_alltoall_wg<char>(ctx, team, (char *) dest,
|
||||
(const char *) source, (int) size);
|
||||
|
||||
if (ctx_result == 0) {
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void rocshmem_barrier_all() {
|
||||
GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n",
|
||||
|
||||
@@ -142,6 +142,7 @@ enum rocshmem_host_stats {
|
||||
NUM_HOST_SHMEM_PTR,
|
||||
NUM_HOST_SYNC_ALL,
|
||||
NUM_HOST_BROADCAST,
|
||||
NUM_HOST_ALLTOALL,
|
||||
NUM_HOST_STATS
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "team_alltoallmem_on_stream_tester.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
TeamAlltoallmemOnStreamTester::TeamAlltoallmemOnStreamTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
|
||||
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
char* value{nullptr};
|
||||
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
|
||||
num_teams = atoi(value);
|
||||
} else {
|
||||
// Default to number of work groups
|
||||
num_teams = args.num_wgs;
|
||||
}
|
||||
|
||||
int num_bytes_wg = args.max_msg_size * n_pes;
|
||||
int total_bytes = num_bytes_wg * num_teams;
|
||||
buf_size = total_bytes;
|
||||
|
||||
source_buf = static_cast<char *>(rocshmem_malloc(buf_size));
|
||||
dest_buf = static_cast<char *>(rocshmem_malloc(buf_size));
|
||||
|
||||
if (source_buf == nullptr || dest_buf == nullptr) {
|
||||
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cerr << "source: " << source_buf << ", dest: " << dest_buf
|
||||
<< std::endl;
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
team_world_dup.resize(num_teams);
|
||||
|
||||
streams.resize(num_teams);
|
||||
start_events_timed.resize(num_teams);
|
||||
stop_events_timed.resize(num_teams);
|
||||
for (int i = 0; i < num_teams; i++) {
|
||||
CHECK_HIP(hipStreamCreate(&streams[i]));
|
||||
CHECK_HIP(hipEventCreate(&start_events_timed[i]));
|
||||
CHECK_HIP(hipEventCreate(&stop_events_timed[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TeamAlltoallmemOnStreamTester::~TeamAlltoallmemOnStreamTester() {
|
||||
for (int i = 0; i < num_teams; i++) {
|
||||
CHECK_HIP(hipEventDestroy(stop_events_timed[i]));
|
||||
CHECK_HIP(hipEventDestroy(start_events_timed[i]));
|
||||
CHECK_HIP(hipStreamDestroy(streams[i]));
|
||||
}
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
}
|
||||
|
||||
void TeamAlltoallmemOnStreamTester::preLaunchKernel() {
|
||||
bw_factor = n_pes;
|
||||
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_world_dup[team_i]);
|
||||
if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
std::cerr << "Team " << team_i << " is invalid!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TeamAlltoallmemOnStreamTester::postLaunchKernel() {
|
||||
// Synchronize all streams to ensure events are recorded
|
||||
for (int i = 0; i < num_teams; i++) {
|
||||
CHECK_HIP(hipStreamSynchronize(streams[i]));
|
||||
}
|
||||
|
||||
// Get elapsed time for each work group from HIP events
|
||||
for (int wg_id = 0; wg_id < num_teams && wg_id < num_timers; wg_id++) {
|
||||
float elapsed_time_ms = 0.0f;
|
||||
CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, start_events_timed[wg_id],
|
||||
stop_events_timed[wg_id]));
|
||||
|
||||
// Convert milliseconds to GPU cycles
|
||||
// wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate
|
||||
long long int elapsed_cycles = static_cast<long long int>(
|
||||
elapsed_time_ms * static_cast<float>(wall_clk_rate));
|
||||
|
||||
start_time[wg_id] = 0;
|
||||
end_time[wg_id] = elapsed_cycles;
|
||||
}
|
||||
|
||||
// Fill remaining timers with zero if num_timers > num_teams
|
||||
for (int i = num_teams; i < num_timers; i++) {
|
||||
start_time[i] = 0;
|
||||
end_time[i] = 0;
|
||||
}
|
||||
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TeamAlltoallmemOnStreamTester::resetBuffers(size_t size) {
|
||||
// Initialize source buffer: each PE fills its portion with its PE number
|
||||
// For alltoall, PE i sends block j to PE j
|
||||
// Support multiple work groups (teams)
|
||||
int idx = 0;
|
||||
|
||||
for (int wg_id = 0; wg_id < num_teams; wg_id++) {
|
||||
for (int pe = 0; pe < n_pes; pe++) {
|
||||
// Each block in source buffer is filled with (my_pe * n_pes + pe)
|
||||
// This makes it easy to verify correctness
|
||||
int value = my_pe * n_pes + pe;
|
||||
idx = (wg_id * n_pes + pe) * size;
|
||||
std::memset(source_buf + idx, value, size);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear destination buffer
|
||||
std::memset(dest_buf, 0, buf_size);
|
||||
}
|
||||
|
||||
void TeamAlltoallmemOnStreamTester::launchKernel(dim3 gridSize,
|
||||
dim3 blockSize,
|
||||
int loop,
|
||||
size_t size) {
|
||||
// Execute warmup iterations (skip)
|
||||
for (int i = 0; i < args.skip; i++) {
|
||||
for (int wg_id = 0; wg_id < num_teams; wg_id++) {
|
||||
char *wg_source = source_buf + wg_id * n_pes * size;
|
||||
char *wg_dest = dest_buf + wg_id * n_pes * size;
|
||||
rocshmem_alltoallmem_on_stream(team_world_dup[wg_id], wg_dest,
|
||||
wg_source, size, streams[wg_id]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < loop; i++) {
|
||||
for (int wg_id = 0; wg_id < num_teams; wg_id++) {
|
||||
// Record start event for this work group on first iteration
|
||||
if (i == 0) {
|
||||
CHECK_HIP(hipEventRecord(start_events_timed[wg_id], streams[wg_id]));
|
||||
}
|
||||
|
||||
char *wg_source = source_buf + wg_id * n_pes * size;
|
||||
char *wg_dest = dest_buf + wg_id * n_pes * size;
|
||||
rocshmem_alltoallmem_on_stream(team_world_dup[wg_id], wg_dest,
|
||||
wg_source, size, streams[wg_id]);
|
||||
|
||||
// Record stop event for this work group on last iteration
|
||||
if (i == loop - 1) {
|
||||
CHECK_HIP(hipEventRecord(stop_events_timed[wg_id], streams[wg_id]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
num_msgs = (loop + args.skip) * num_teams;
|
||||
num_timed_msgs = loop * num_teams;
|
||||
}
|
||||
|
||||
void TeamAlltoallmemOnStreamTester::verifyResults(size_t size) {
|
||||
// Verify correctness: after alltoall, PE i should receive from PE j
|
||||
// the block that PE j sent to PE i
|
||||
// PE j sends block i (containing value j * n_pes + i) to PE i
|
||||
// Support multiple work groups (teams)
|
||||
int idx = 0;
|
||||
|
||||
for (int wg_id = 0; wg_id < num_teams; wg_id++) {
|
||||
for (int j = 0; j < n_pes; j++) {
|
||||
int expected_value = j * n_pes + my_pe;
|
||||
idx = (wg_id * n_pes + j) * size;
|
||||
|
||||
for (size_t k = 0; k < size; k++) {
|
||||
if (static_cast<unsigned char>(dest_buf[idx + k]) !=
|
||||
static_cast<unsigned char>(expected_value)) {
|
||||
std::cerr << "PE " << my_pe << ": Verification failed for WG "
|
||||
<< wg_id << ", block from PE " << j << " at byte " << k
|
||||
<< std::endl;
|
||||
std::cerr << "Expected value: " << expected_value
|
||||
<< ", Got: " << static_cast<int>(dest_buf[idx + k])
|
||||
<< std::endl;
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_
|
||||
#define _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class TeamAlltoallmemOnStreamTester : public Tester {
|
||||
public:
|
||||
explicit TeamAlltoallmemOnStreamTester(TesterArguments args);
|
||||
virtual ~TeamAlltoallmemOnStreamTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(size_t size) override;
|
||||
|
||||
virtual void preLaunchKernel() override;
|
||||
|
||||
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
size_t size) override;
|
||||
|
||||
virtual void postLaunchKernel() override;
|
||||
|
||||
virtual void verifyResults(size_t size) override;
|
||||
|
||||
private:
|
||||
char *source_buf;
|
||||
char *dest_buf;
|
||||
int my_pe;
|
||||
int n_pes;
|
||||
size_t buf_size;
|
||||
int num_teams = 1;
|
||||
std::vector<rocshmem_team_t> team_world_dup;
|
||||
std::vector<hipStream_t> streams;
|
||||
std::vector<hipEvent_t> start_events_timed;
|
||||
std::vector<hipEvent_t> stop_events_timed;
|
||||
};
|
||||
|
||||
#include "team_alltoallmem_on_stream_tester.cpp"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
#include "sync_all_tester.hpp"
|
||||
#include "team_sync_tester.hpp"
|
||||
#include "team_alltoall_tester.hpp"
|
||||
#include "team_alltoallmem_on_stream_tester.hpp"
|
||||
#include "team_barrier_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
#include "team_ctx_infra_tester.hpp"
|
||||
@@ -227,6 +228,11 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
}
|
||||
testers.push_back(new TeamAlltoallTester<float>(args));
|
||||
return testers;
|
||||
case TeamAlltoallmemOnStreamTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Alltoallmem_On_Stream ###" << std::endl;
|
||||
testers.push_back(new TeamAlltoallmemOnStreamTester(args));
|
||||
return testers;
|
||||
case TeamFCollectTestType:
|
||||
if (rank == 0) {
|
||||
std::cout << "Fcollect Test ###" << std::endl;
|
||||
@@ -585,7 +591,8 @@ bool Tester::peLaunchesKernel() {
|
||||
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
|
||||
(_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) ||
|
||||
(_type == TeamWGBarrierTestType);
|
||||
(_type == TeamWGBarrierTestType) ||
|
||||
(_type == TeamAlltoallmemOnStreamTestType);
|
||||
|
||||
return is_launcher;
|
||||
}
|
||||
|
||||
@@ -113,6 +113,7 @@ enum TestType {
|
||||
TeamCtxInfraTestSingleType = 73,
|
||||
TeamCtxInfraTestBlockType = 74,
|
||||
TeamCtxInfraTestOddEvenType = 75,
|
||||
TeamAlltoallmemOnStreamTestType = 76,
|
||||
};
|
||||
|
||||
enum OpType { PutType = 0, GetType = 1 };
|
||||
|
||||
@@ -182,7 +182,8 @@ void TesterArguments::get_arguments() {
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
|
||||
(type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) &&
|
||||
(type != TeamWGBarrierTestType) && (type != TeamCtxInfraTestBlockType) &&
|
||||
(type != TeamCtxInfraTestOddEvenType)) {
|
||||
(type != TeamCtxInfraTestOddEvenType) &&
|
||||
(type != TeamAlltoallmemOnStreamTestType)) {
|
||||
if (numprocs != 2) {
|
||||
if (myid == 0) {
|
||||
std::cerr << "This test requires exactly two processes, we have "
|
||||
|
||||
Referens i nytt ärende
Block a user