From 4b04b540bf2745df04162af0dd9ed4643fa18954 Mon Sep 17 00:00:00 2001 From: Anatolii Rozanov Date: Wed, 3 Dec 2025 14:40:24 +0100 Subject: [PATCH] Add host API for alltoallmem_on_stream collective operation (#333) * Add host-side rocshmem_alltoallmem_on_stream function Function signature: rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream) - The function launches rocshmem_alltoallmem_kernel, which calls the device-side alltoall workgroup collective through a context created for the workgroup on the given team, falling back to the default context if context creation fails. - Uses dynamic block size determination via occupancy API. - Implemented for all backends. * Fix incorrect sync buffer size allocation for alltoall in GDA and IPC backends When allocating memory for alltoall_pSync_pool in setup_teams() and teams_init() functions, the code incorrectly used ROCSHMEM_BCAST_SYNC_SIZE instead of ROCSHMEM_ALLTOALL_SYNC_SIZE. * Add functional test for team_alltoallmem_on_stream This commit adds a new functional test to verify the correctness of the host-side rocshmem_alltoallmem_on_stream API. * Add documentation for rocshmem_alltoallmem_on_stream This commit adds API documentation for the host-side rocshmem_alltoallmem_on_stream function in the collective routines section. 
The documentation includes the function signature, parameter descriptions, and a description of the routine's asynchronous execution on the given stream. [ROCm/rocshmem commit: 5577feb70d90085f7f53b374f244b6b051f048e2] --- projects/rocshmem/docs/api/coll.rst | 23 ++ .../rocshmem/include/rocshmem/rocshmem.hpp | 16 ++ .../include/rocshmem/rocshmem_COLL.hpp | 17 ++ .../scripts/functional_tests/driver.sh | 3 + projects/rocshmem/src/context.hpp | 4 + projects/rocshmem/src/context_host.cpp | 8 + projects/rocshmem/src/gda/backend_gda.cpp | 4 +- .../rocshmem/src/gda/context_gda_host.cpp | 8 + .../rocshmem/src/gda/context_gda_host.hpp | 4 + projects/rocshmem/src/host/host.cpp | 30 +++ projects/rocshmem/src/host/host.hpp | 4 + projects/rocshmem/src/ipc/backend_ipc.cpp | 4 +- .../rocshmem/src/ipc/context_ipc_host.cpp | 8 + .../rocshmem/src/ipc/context_ipc_host.hpp | 4 + .../src/reverse_offload/context_ro_host.cpp | 10 + .../src/reverse_offload/context_ro_host.hpp | 4 + projects/rocshmem/src/rocshmem.cpp | 9 + projects/rocshmem/src/rocshmem_gpu.cpp | 26 +++ projects/rocshmem/src/stats.hpp | 1 + .../team_alltoallmem_on_stream_tester.cpp | 215 ++++++++++++++++++ .../team_alltoallmem_on_stream_tester.hpp | 70 ++++++ .../tests/functional_tests/tester.cpp | 9 +- .../tests/functional_tests/tester.hpp | 1 + .../functional_tests/tester_arguments.cpp | 3 +- 24 files changed, 479 insertions(+), 6 deletions(-) create mode 100644 projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp create mode 100644 projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.hpp diff --git a/projects/rocshmem/docs/api/coll.rst b/projects/rocshmem/docs/api/coll.rst index 5f3874ded5..809f6e53d2 100644 --- a/projects/rocshmem/docs/api/coll.rst +++ b/projects/rocshmem/docs/api/coll.rst @@ -88,6 +88,29 @@ This function must be called as a work-group collective. Valid ``TYPENAME`` and ``TYPE`` values are listed in :ref:`RMA_TYPES`. +ROCSHMEM_ALLTOALLMEM_ON_STREAM +------------------------------- + +.. 
cpp:function:: __host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream) + + :param team: The team participating in the collective. + :param dest: Destination address. Must be an address on the symmetric heap. + :param source: Source address. Must be an address on the symmetric heap. + :param size: Number of bytes to transfer per pair of PEs. The current implementation narrows this value to ``int`` and uses it to bound the launch block size, so it must be non-zero and no larger than ``INT_MAX``. + :param stream: HIP stream on which to enqueue the operation. + :returns: None. + +**Description:** +This routine enqueues an alltoall collective operation on a HIP stream. The function +exchanges a fixed-size contiguous block of data between every pair of PEs participating +in the collective routine. The operation is enqueued on the specified stream and will +execute asynchronously. The caller must synchronize the stream (e.g., using +``hipStreamSynchronize``) to ensure completion. + +The operation is launched as a single workgroup that creates its own context (falling back to the default context if context creation fails) to avoid contention on the +default context, allowing parallel execution across multiple streams. +If ``stream`` is ``nullptr``, the operation will use ``hipStreamDefault``. + ROCSHMEM_BROADCAST ------------------ diff --git a/projects/rocshmem/include/rocshmem/rocshmem.hpp b/projects/rocshmem/include/rocshmem/rocshmem.hpp index f3f39b4ad9..cbe0ca2d6a 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem.hpp @@ -349,6 +349,22 @@ __host__ void rocshmem_barrier_all(); */ __host__ void rocshmem_barrier_all_on_stream(hipStream_t stream); +/** + * @brief enqueues an alltoall collective operation on given stream. + * + * @param[in] team The team participating in the collective. + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] size Number of bytes to transfer per pair of PEs. 
+ * @param[in] stream HIP stream on which to enqueue the operation. + * + * @return void + */ +__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. diff --git a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp index 9d7f2e5437..af105c9bfc 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp @@ -607,6 +607,23 @@ __host__ int rocshmem_ctx_double_prod_reduce( */ __global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(); +/** + * @brief kernel for performing an alltoall collective operation. + * Caller enqueues the kernel on given stream + * + * @param[in] team The team participating in the collective. + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] size Number of bytes to transfer per pair of PEs. + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team, + void *dest, + const void *source, + size_t size); + /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. 
diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 91e3bbe726..f5010ee624 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -109,6 +109,7 @@ declare -A TEST_NUMBERS=( ["teamctxsingleinfra"]="73" ["teamctxblockinfra"]="74" ["teamctxoddeveninfra"]="75" + ["alltoallmem_on_stream"]="76" ) ExecTest() { @@ -428,6 +429,8 @@ TestColl() { ExecTest "fcollect" 2 1 1 32768 ExecTest "teamreduction" 2 1 1 32768 + + ExecTest "alltoallmem_on_stream" 2 1 1 32768 } TestOther() { diff --git a/projects/rocshmem/src/context.hpp b/projects/rocshmem/src/context.hpp index 5269265619..619b8a343c 100644 --- a/projects/rocshmem/src/context.hpp +++ b/projects/rocshmem/src/context.hpp @@ -397,6 +397,10 @@ class Context { __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/context_host.cpp b/projects/rocshmem/src/context_host.cpp index 0a78f14c1f..ead3485b45 100644 --- a/projects/rocshmem/src/context_host.cpp +++ b/projects/rocshmem/src/context_host.cpp @@ -122,4 +122,12 @@ __host__ void Context::barrier_all_on_stream(hipStream_t stream) { HOST_DISPATCH(barrier_all_on_stream(stream)); } +__host__ void Context::alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_ALLTOALL); + + HOST_DISPATCH(alltoallmem_on_stream(team, dest, source, size, stream)); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/gda/backend_gda.cpp b/projects/rocshmem/src/gda/backend_gda.cpp index 82bf5cfb28..b247cd6026 100644 --- a/projects/rocshmem/src/gda/backend_gda.cpp +++ b/projects/rocshmem/src/gda/backend_gda.cpp @@ -453,8 +453,8 @@ void 
GDABackend::setup_teams() { * max_num_teams; alltoall_pSync_pool = reinterpret_cast(wrk_sync_pool_top_); - wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE - * max_num_teams; + wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE * + max_num_teams; /* Accommodating for largest possible data type for pWrk */ pWrk_pool = reinterpret_cast(wrk_sync_pool_top_); diff --git a/projects/rocshmem/src/gda/context_gda_host.cpp b/projects/rocshmem/src/gda/context_gda_host.cpp index c6ffbacb14..2345d0f9ad 100644 --- a/projects/rocshmem/src/gda/context_gda_host.cpp +++ b/projects/rocshmem/src/gda/context_gda_host.cpp @@ -113,4 +113,12 @@ __host__ void GDAHostContext::barrier_all() { host_interface->barrier_all(context_window_info); } +__host__ void GDAHostContext::alltoallmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t size, + hipStream_t stream) { + host_interface->alltoallmem_on_stream(team, dest, source, size, stream); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/gda/context_gda_host.hpp b/projects/rocshmem/src/gda/context_gda_host.hpp index 7f7f86b4d6..3479b999f2 100644 --- a/projects/rocshmem/src/gda/context_gda_host.hpp +++ b/projects/rocshmem/src/gda/context_gda_host.hpp @@ -82,6 +82,10 @@ class GDAHostContext : public Context { __host__ void barrier_all(); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/host/host.cpp b/projects/rocshmem/src/host/host.cpp index fe68cbe499..7b6a782dcd 100644 --- a/projects/rocshmem/src/host/host.cpp +++ b/projects/rocshmem/src/host/host.cpp @@ -333,6 +333,36 @@ __host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) { rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>(); } +__host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t size, + 
hipStream_t stream) { + // launch kernel to do alltoall with given stream, if none, use default stream + if (stream == nullptr) { + stream = hipStreamDefault; + } + + // Use dynamic block size determination: + // - Query optimal block size using occupancy API + // - Limit block size to size (number of bytes) to avoid over-subscription + // - Always use 1 block (single workgroup collective) + int optimal_block_size = 0; + int grid_size = 0; + CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, &optimal_block_size, + rocshmem_alltoallmem_kernel, 0, + 0)); + + // Limit block size to size (bytes) to avoid over-subscription + int num_threads_per_block = (optimal_block_size > static_cast(size)) + ? static_cast(size) + : optimal_block_size; + + dim3 gridSize(1); + dim3 blockSize(num_threads_per_block); + rocshmem_alltoallmem_kernel<<>>(team, dest, + source, size); +} __host__ void HostInterface::barrier_for_sync() { if (host_comm_world_ != MPI_COMM_NULL) { diff --git a/projects/rocshmem/src/host/host.hpp b/projects/rocshmem/src/host/host.hpp index 324e513529..286458344d 100644 --- a/projects/rocshmem/src/host/host.hpp +++ b/projects/rocshmem/src/host/host.hpp @@ -196,6 +196,10 @@ class HostInterface { __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + __host__ void barrier_for_sync(); __host__ void sync_all(WindowInfo* window_info); diff --git a/projects/rocshmem/src/ipc/backend_ipc.cpp b/projects/rocshmem/src/ipc/backend_ipc.cpp index d61fb7e96b..2e073e9c2f 100644 --- a/projects/rocshmem/src/ipc/backend_ipc.cpp +++ b/projects/rocshmem/src/ipc/backend_ipc.cpp @@ -482,8 +482,8 @@ void IPCBackend::teams_init() { * max_num_teams; alltoall_pSync_pool = reinterpret_cast(wrk_sync_pool_top_); - wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE - * max_num_teams; + wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE 
* + max_num_teams; /* Accommodating for largest possible data type for pWrk */ pWrk_pool = reinterpret_cast(wrk_sync_pool_top_); diff --git a/projects/rocshmem/src/ipc/context_ipc_host.cpp b/projects/rocshmem/src/ipc/context_ipc_host.cpp index dbddb50613..e30fa2c379 100644 --- a/projects/rocshmem/src/ipc/context_ipc_host.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_host.cpp @@ -105,4 +105,12 @@ __host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) { host_interface->barrier_all_on_stream(stream); } +__host__ void IPCHostContext::alltoallmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t size, + hipStream_t stream) { + host_interface->alltoallmem_on_stream(team, dest, source, size, stream); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc/context_ipc_host.hpp b/projects/rocshmem/src/ipc/context_ipc_host.hpp index 10f1e6460d..f5317a87c3 100644 --- a/projects/rocshmem/src/ipc/context_ipc_host.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_host.hpp @@ -84,6 +84,10 @@ class IPCHostContext : public Context { __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp index 6f2a034e7a..f72399341b 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp @@ -133,4 +133,14 @@ __host__ void ROHostContext::barrier_all() { host_interface->barrier_for_sync(); } +__host__ void ROHostContext::alltoallmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t size, + hipStream_t stream) { + DPRINTF("Function: ro_net_host_alltoallmem_on_stream\n"); + + host_interface->alltoallmem_on_stream(team, dest, source, size, stream); +} 
+ } // namespace rocshmem diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp index 4e0719a84b..049ce0be6d 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp @@ -131,6 +131,10 @@ class ROHostContext : public Context { __host__ void barrier_all(); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/rocshmem.cpp b/projects/rocshmem/src/rocshmem.cpp index 8a526f9acd..4167df121c 100644 --- a/projects/rocshmem/src/rocshmem.cpp +++ b/projects/rocshmem/src/rocshmem.cpp @@ -999,6 +999,15 @@ __host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) { get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all_on_stream(stream); } +__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t size, + hipStream_t stream) { + DPRINTF("Host function: rocshmem_alltoallmem_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->alltoallmem_on_stream(team, dest, source, size, stream); +} + __host__ void rocshmem_sync_all() { DPRINTF("Host function: rocshmem_sync_all\n"); diff --git a/projects/rocshmem/src/rocshmem_gpu.cpp b/projects/rocshmem/src/rocshmem_gpu.cpp index 858bf499de..9ef0add371 100644 --- a/projects/rocshmem/src/rocshmem_gpu.cpp +++ b/projects/rocshmem/src/rocshmem_gpu.cpp @@ -648,6 +648,32 @@ __global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(){ rocshmem_barrier_all(); } +__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team, + void *dest, + const void *source, + size_t size) { + // Create a context for this workgroup to avoid contention on default context + // This allows parallel execution across multiple streams without serialization + __shared__ rocshmem_ctx_t 
ctx; + __shared__ int ctx_result; + + ctx_result = rocshmem_wg_team_create_ctx(team, 0, &ctx); + + // If context creation failed, fall back to default context + if (ctx_result != 0) { + ctx = ROCSHMEM_CTX_DEFAULT; + __syncthreads(); + } + + // Call device alltoall function with created context and provided team + // Using char type since size is in bytes (1 byte per element) + rocshmem_alltoall_wg(ctx, team, (char *) dest, + (const char *) source, (int) size); + + if (ctx_result == 0) { + rocshmem_wg_ctx_destroy(&ctx); + } +} __device__ void rocshmem_barrier_all() { GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n", diff --git a/projects/rocshmem/src/stats.hpp b/projects/rocshmem/src/stats.hpp index f01d57e00c..2f6cc260a0 100644 --- a/projects/rocshmem/src/stats.hpp +++ b/projects/rocshmem/src/stats.hpp @@ -142,6 +142,7 @@ enum rocshmem_host_stats { NUM_HOST_SHMEM_PTR, NUM_HOST_SYNC_ALL, NUM_HOST_BROADCAST, + NUM_HOST_ALLTOALL, NUM_HOST_STATS }; diff --git a/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp new file mode 100644 index 0000000000..9d8900e603 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ *****************************************************************************/ + +#include "team_alltoallmem_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +TeamAlltoallmemOnStreamTester::TeamAlltoallmemOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } else { + // Default to number of work groups + num_teams = args.num_wgs; + } + + int num_bytes_wg = args.max_msg_size * n_pes; + int total_bytes = num_bytes_wg * num_teams; + buf_size = total_bytes; + + source_buf = static_cast(rocshmem_malloc(buf_size)); + dest_buf = static_cast(rocshmem_malloc(buf_size)); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source_buf << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + team_world_dup.resize(num_teams); + + streams.resize(num_teams); + start_events_timed.resize(num_teams); + stop_events_timed.resize(num_teams); + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamCreate(&streams[i])); + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +TeamAlltoallmemOnStreamTester::~TeamAlltoallmemOnStreamTester() { + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + CHECK_HIP(hipStreamDestroy(streams[i])); + } + rocshmem_free(source_buf); + rocshmem_free(dest_buf); +} + +void TeamAlltoallmemOnStreamTester::preLaunchKernel() { + bw_factor = n_pes; + + for 
(int team_i = 0; team_i < num_teams; team_i++) { + team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_world_dup[team_i]); + if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + std::cerr << "Team " << team_i << " is invalid!" << std::endl; + abort(); + } + } +} + +void TeamAlltoallmemOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each work group from HIP events + for (int wg_id = 0; wg_id < num_teams && wg_id < num_timers; wg_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, start_events_timed[wg_id], + stop_events_timed[wg_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = static_cast( + elapsed_time_ms * static_cast(wall_clk_rate)); + + start_time[wg_id] = 0; + end_time[wg_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_teams + for (int i = num_teams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } + + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_world_dup[team_i]); + } +} + +void TeamAlltoallmemOnStreamTester::resetBuffers(size_t size) { + // Initialize source buffer: each PE fills its portion with its PE number + // For alltoall, PE i sends block j to PE j + // Support multiple work groups (teams) + int idx = 0; + + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + for (int pe = 0; pe < n_pes; pe++) { + // Each block in source buffer is filled with (my_pe * n_pes + pe) + // This makes it easy to verify correctness + int value = my_pe * n_pes + pe; + idx = (wg_id * n_pes + pe) * size; + std::memset(source_buf + idx, value, size); + } + } + + // Clear destination buffer + 
std::memset(dest_buf, 0, buf_size); +} + +void TeamAlltoallmemOnStreamTester::launchKernel(dim3 gridSize, + dim3 blockSize, + int loop, + size_t size) { + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + char *wg_source = source_buf + wg_id * n_pes * size; + char *wg_dest = dest_buf + wg_id * n_pes * size; + rocshmem_alltoallmem_on_stream(team_world_dup[wg_id], wg_dest, + wg_source, size, streams[wg_id]); + } + } + + for (int i = 0; i < loop; i++) { + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + // Record start event for this work group on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[wg_id], streams[wg_id])); + } + + char *wg_source = source_buf + wg_id * n_pes * size; + char *wg_dest = dest_buf + wg_id * n_pes * size; + rocshmem_alltoallmem_on_stream(team_world_dup[wg_id], wg_dest, + wg_source, size, streams[wg_id]); + + // Record stop event for this work group on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[wg_id], streams[wg_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_teams; + num_timed_msgs = loop * num_teams; +} + +void TeamAlltoallmemOnStreamTester::verifyResults(size_t size) { + // Verify correctness: after alltoall, PE i should receive from PE j + // the block that PE j sent to PE i + // PE j sends block i (containing value j * n_pes + i) to PE i + // Support multiple work groups (teams) + int idx = 0; + + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + for (int j = 0; j < n_pes; j++) { + int expected_value = j * n_pes + my_pe; + idx = (wg_id * n_pes + j) * size; + + for (size_t k = 0; k < size; k++) { + if (static_cast(dest_buf[idx + k]) != + static_cast(expected_value)) { + std::cerr << "PE " << my_pe << ": Verification failed for WG " + << wg_id << ", block from PE " << j << " at byte " << k + << std::endl; + std::cerr << "Expected value: " << expected_value + << ", Got: " << 
static_cast(dest_buf[idx + k]) + << std::endl; + rocshmem_global_exit(1); + } + } + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.hpp new file mode 100644 index 0000000000..d96593ecf1 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.hpp @@ -0,0 +1,70 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ *****************************************************************************/ + +#ifndef _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_ +#define _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class TeamAlltoallmemOnStreamTester : public Tester { + public: + explicit TeamAlltoallmemOnStreamTester(TesterArguments args); + virtual ~TeamAlltoallmemOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + char *source_buf; + char *dest_buf; + int my_pe; + int n_pes; + size_t buf_size; + int num_teams = 1; + std::vector team_world_dup; + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "team_alltoallmem_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index cd550db80a..955afb48dc 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -47,6 +47,7 @@ #include "sync_all_tester.hpp" #include "team_sync_tester.hpp" #include "team_alltoall_tester.hpp" +#include "team_alltoallmem_on_stream_tester.hpp" #include "team_barrier_tester.hpp" #include "team_broadcast_tester.hpp" #include "team_ctx_infra_tester.hpp" @@ -227,6 +228,11 @@ std::vector Tester::create(TesterArguments args) { } testers.push_back(new TeamAlltoallTester(args)); return testers; + case TeamAlltoallmemOnStreamTestType: + if (rank == 0) + std::cout << 
"Alltoallmem_On_Stream ###" << std::endl; + testers.push_back(new TeamAlltoallmemOnStreamTester(args)); + return testers; case TeamFCollectTestType: if (rank == 0) { std::cout << "Fcollect Test ###" << std::endl; @@ -585,7 +591,8 @@ bool Tester::peLaunchesKernel() { (_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) || (_type == RandomAccessTestType) || (_type == PingAllTestType) || (_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) || - (_type == TeamWGBarrierTestType); + (_type == TeamWGBarrierTestType) || + (_type == TeamAlltoallmemOnStreamTestType); return is_launcher; } diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index bd3a7d6a7c..4d67df060a 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -113,6 +113,7 @@ enum TestType { TeamCtxInfraTestSingleType = 73, TeamCtxInfraTestBlockType = 74, TeamCtxInfraTestOddEvenType = 75, + TeamAlltoallmemOnStreamTestType = 76, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 9da8326017..afbe541df8 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -182,7 +182,8 @@ void TesterArguments::get_arguments() { (type != TeamBroadcastTestType) && (type != PingAllTestType) && (type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) && (type != TeamWGBarrierTestType) && (type != TeamCtxInfraTestBlockType) && - (type != TeamCtxInfraTestOddEvenType)) { + (type != TeamCtxInfraTestOddEvenType) && + (type != TeamAlltoallmemOnStreamTestType)) { if (numprocs != 2) { if (myid == 0) { std::cerr << "This test requires exactly two processes, we have "