Add host API for alltoallmem_on_stream collective operation (#333)

* Add host-side rocshmem_alltoallmem_on_stream function

Function signature:
  rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
                                 const void *source, size_t size,
                                 hipStream_t stream)

- The function launches rocshmem_alltoallmem_kernel, which calls the
device-side alltoall<char> workgroup collective on a context created for the
workgroup (falling back to the default context if context creation fails).
- Uses dynamic block size determination via occupancy API.
- Implemented for all backends.

* Fix incorrect sync buffer size allocation for alltoall in GDA and IPC backends

When allocating memory for alltoall_pSync_pool in setup_teams() and
teams_init() functions, the code incorrectly used ROCSHMEM_BCAST_SYNC_SIZE
instead of ROCSHMEM_ALLTOALL_SYNC_SIZE.

* Add functional test for team_alltoallmem_on_stream

This commit adds a new functional test to verify the correctness of
the host-side rocshmem_alltoallmem_on_stream API when it is used with teams.

* Add documentation for rocshmem_alltoallmem_on_stream

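This commit adds API documentation for the host-side
rocshmem_alltoallmem_on_stream function in the collective routines
section. The documentation includes the function signature, a description of
each parameter, and a description of the asynchronous, stream-ordered
behavior of the operation.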

[ROCm/rocshmem commit: 5577feb70d]
Этот коммит содержится в:
Anatolii Rozanov
2025-12-03 14:40:24 +01:00
коммит произвёл GitHub
родитель 0f32739b52
Коммит 4b04b540bf
24 изменённых файлов: 479 добавлений и 6 удалений
+23
Просмотреть файл
@@ -88,6 +88,29 @@ This function must be called as a work-group collective.
Valid ``TYPENAME`` and ``TYPE`` values are listed in :ref:`RMA_TYPES`.
ROCSHMEM_ALLTOALLMEM_ON_STREAM
-------------------------------
.. cpp:function:: __host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream)
:param team: The team participating in the collective.
:param dest: Destination address. Must be an address on the symmetric heap.
:param source: Source address. Must be an address on the symmetric heap.
:param size: Number of bytes to transfer per pair of PEs.
:param stream: HIP stream on which to enqueue the operation.
:returns: None.
**Description:**
This routine enqueues an alltoall collective operation on a HIP stream. The function
exchanges a fixed amount of contiguous data blocks between all pairs of PEs participating
in the collective routine. The operation is enqueued on the specified stream and will
execute asynchronously. The caller must synchronize the stream (e.g., using
``hipStreamSynchronize``) to ensure completion.
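The kernel launched by this function creates a separate context for its workgroup to
avoid contention on the default context, allowing parallel execution across multiple
streams. If the context cannot be created, the kernel falls back to the default context.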
If ``stream`` is ``nullptr``, the operation will use ``hipStreamDefault``.
ROCSHMEM_BROADCAST
------------------
+16
Просмотреть файл
@@ -349,6 +349,22 @@ __host__ void rocshmem_barrier_all();
*/
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
/**
* @brief Enqueues an alltoall collective operation on the given stream.
*
* The call only enqueues the operation; it executes asynchronously with
* respect to the host. The caller must synchronize the stream (e.g. with
* hipStreamSynchronize) before relying on the contents of dest.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address. Must be an address on the symmetric
* heap.
* @param[in] source Source address. Must be an address on the symmetric heap.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*
* @return void
*/
__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
+17
Просмотреть файл
@@ -607,6 +607,23 @@ __host__ int rocshmem_ctx_double_prod_reduce(
*/
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
/**
* @brief Kernel that performs an alltoall collective operation.
* The caller enqueues the kernel on the desired stream.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address. Must be an address on the symmetric
* heap.
* @param[in] source Source address. Must be an address on the symmetric heap.
* @param[in] size Number of bytes to transfer per pair of PEs.
*
* @return void
*/
__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
void *dest,
const void *source,
size_t size);
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
+3
Просмотреть файл
@@ -109,6 +109,7 @@ declare -A TEST_NUMBERS=(
["teamctxsingleinfra"]="73"
["teamctxblockinfra"]="74"
["teamctxoddeveninfra"]="75"
["alltoallmem_on_stream"]="76"
)
ExecTest() {
@@ -428,6 +429,8 @@ TestColl() {
ExecTest "fcollect" 2 1 1 32768
ExecTest "teamreduction" 2 1 1 32768
ExecTest "alltoallmem_on_stream" 2 1 1 32768
}
TestOther() {
+4
Просмотреть файл
@@ -397,6 +397,10 @@ class Context {
__host__ void barrier_all_on_stream(hipStream_t stream);
/**
* @brief Enqueues an alltoall of `size` bytes per PE pair for `team` on
* `stream`; the call is dispatched to the backend-specific host context.
*/
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
__host__ void sync_all();
template <typename T>
+8
Просмотреть файл
@@ -122,4 +122,12 @@ __host__ void Context::barrier_all_on_stream(hipStream_t stream) {
HOST_DISPATCH(barrier_all_on_stream(stream));
}
/**
* @brief Host-side entry point for the stream-ordered alltoall.
*
* Counts the call in the host statistics and then forwards all arguments
* unchanged through HOST_DISPATCH.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*/
__host__ void Context::alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream) {
// Record the call before dispatching it.
ctxHostStats.incStat(NUM_HOST_ALLTOALL);
HOST_DISPATCH(alltoallmem_on_stream(team, dest, source, size, stream));
}
} // namespace rocshmem
+2 -2
Просмотреть файл
@@ -453,8 +453,8 @@ void GDABackend::setup_teams() {
* max_num_teams;
alltoall_pSync_pool = reinterpret_cast<long *>(wrk_sync_pool_top_);
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE
* max_num_teams;
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE *
max_num_teams;
/* Accommodating for largest possible data type for pWrk */
pWrk_pool = reinterpret_cast<void *>(wrk_sync_pool_top_);
+8
Просмотреть файл
@@ -113,4 +113,12 @@ __host__ void GDAHostContext::barrier_all() {
host_interface->barrier_all(context_window_info);
}
/**
* @brief GDA backend implementation of the stream-ordered alltoall.
*
* Forwards all arguments unchanged to the shared host interface, which
* launches the kernel.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*/
__host__ void GDAHostContext::alltoallmem_on_stream(rocshmem_team_t team,
void *dest,
const void *source,
size_t size,
hipStream_t stream) {
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
}
} // namespace rocshmem
+4
Просмотреть файл
@@ -82,6 +82,10 @@ class GDAHostContext : public Context {
__host__ void barrier_all();
/**
* @brief Enqueues an alltoall of `size` bytes per PE pair for `team` on
* `stream` by delegating to the host interface.
*/
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
__host__ void sync_all();
template <typename T>
+30
Просмотреть файл
@@ -333,6 +333,36 @@ __host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) {
rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>();
}
/**
 * @brief Launches rocshmem_alltoallmem_kernel on the given stream.
 *
 * The kernel is always launched with a single block (the alltoall is a
 * single-workgroup collective). The block size is the occupancy-optimal block
 * size reported by HIP, limited to `size` threads so that no more threads
 * than bytes are launched, and never less than one thread.
 *
 * @param[in] team The team participating in the collective.
 * @param[out] dest Destination address.
 * @param[in] source Source address.
 * @param[in] size Number of bytes to transfer per pair of PEs.
 * @param[in] stream HIP stream on which to enqueue the kernel; a null stream
 *                   selects the default stream.
 */
__host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team,
                                                   void *dest,
                                                   const void *source,
                                                   size_t size,
                                                   hipStream_t stream) {
  // Launch on the given stream; if none was given, use the default stream.
  if (stream == nullptr) {
    stream = hipStreamDefault;
  }

  // Query the occupancy-optimal block size for the kernel.
  int optimal_block_size = 0;
  int grid_size = 0;
  CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, &optimal_block_size,
                                              rocshmem_alltoallmem_kernel, 0,
                                              0));

  // Clamp in size_t rather than int: casting `size` to int would yield a
  // block size of 0 for size == 0 (an invalid launch configuration) and a
  // negative or truncated value for sizes above INT_MAX.
  size_t num_threads_per_block =
      (optimal_block_size > 0) ? static_cast<size_t>(optimal_block_size) : 1;
  if (size < num_threads_per_block) {
    num_threads_per_block = size;
  }
  if (num_threads_per_block == 0) {
    num_threads_per_block = 1;
  }

  // NOTE(review): the kernel narrows `size` to int before calling the device
  // collective, so sizes above INT_MAX are still not supported end to end.
  dim3 gridSize(1);
  dim3 blockSize(static_cast<unsigned int>(num_threads_per_block));

  rocshmem_alltoallmem_kernel<<<gridSize, blockSize, 0, stream>>>(team, dest,
                                                                   source, size);
}
__host__ void HostInterface::barrier_for_sync() {
if (host_comm_world_ != MPI_COMM_NULL) {
+4
Просмотреть файл
@@ -196,6 +196,10 @@ class HostInterface {
__host__ void barrier_all_on_stream(hipStream_t stream);
/**
* @brief Launches the alltoall kernel for `team` on `stream`, exchanging
* `size` bytes per pair of PEs between `source` and `dest`.
*/
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
__host__ void barrier_for_sync();
__host__ void sync_all(WindowInfo* window_info);
+2 -2
Просмотреть файл
@@ -482,8 +482,8 @@ void IPCBackend::teams_init() {
* max_num_teams;
alltoall_pSync_pool = reinterpret_cast<long *>(wrk_sync_pool_top_);
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_BCAST_SYNC_SIZE
* max_num_teams;
wrk_sync_pool_top_ += sizeof(long) * ROCSHMEM_ALLTOALL_SYNC_SIZE *
max_num_teams;
/* Accommodating for largest possible data type for pWrk */
pWrk_pool = reinterpret_cast<void *>(wrk_sync_pool_top_);
+8
Просмотреть файл
@@ -105,4 +105,12 @@ __host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) {
host_interface->barrier_all_on_stream(stream);
}
/**
* @brief IPC backend implementation of the stream-ordered alltoall.
*
* Forwards all arguments unchanged to the shared host interface, which
* launches the kernel.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*/
__host__ void IPCHostContext::alltoallmem_on_stream(rocshmem_team_t team,
void *dest,
const void *source,
size_t size,
hipStream_t stream) {
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
}
} // namespace rocshmem
+4
Просмотреть файл
@@ -84,6 +84,10 @@ class IPCHostContext : public Context {
__host__ void barrier_all_on_stream(hipStream_t stream);
/**
* @brief Enqueues an alltoall of `size` bytes per PE pair for `team` on
* `stream` by delegating to the host interface.
*/
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
__host__ void sync_all();
template <typename T>
+10
Просмотреть файл
@@ -133,4 +133,14 @@ __host__ void ROHostContext::barrier_all() {
host_interface->barrier_for_sync();
}
/**
* @brief RO backend implementation of the stream-ordered alltoall.
*
* Emits a debug trace and forwards all arguments unchanged to the shared
* host interface, which launches the kernel.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*/
__host__ void ROHostContext::alltoallmem_on_stream(rocshmem_team_t team,
void *dest,
const void *source,
size_t size,
hipStream_t stream) {
DPRINTF("Function: ro_net_host_alltoallmem_on_stream\n");
host_interface->alltoallmem_on_stream(team, dest, source, size, stream);
}
} // namespace rocshmem
+4
Просмотреть файл
@@ -131,6 +131,10 @@ class ROHostContext : public Context {
__host__ void barrier_all();
/**
* @brief Enqueues an alltoall of `size` bytes per PE pair for `team` on
* `stream` by delegating to the host interface.
*/
__host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream);
__host__ void sync_all();
template <typename T>
+9
Просмотреть файл
@@ -999,6 +999,15 @@ __host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) {
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all_on_stream(stream);
}
/**
* @brief Public host API: enqueues an alltoall collective on a HIP stream.
*
* Emits a debug trace and forwards the call to the default host context.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
* @param[in] stream HIP stream on which to enqueue the operation.
*/
__host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest,
const void *source, size_t size,
hipStream_t stream) {
DPRINTF("Host function: rocshmem_alltoallmem_on_stream\n");
// The operation is always issued through the default host context.
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)
->alltoallmem_on_stream(team, dest, source, size, stream);
}
__host__ void rocshmem_sync_all() {
DPRINTF("Host function: rocshmem_sync_all\n");
+26
Просмотреть файл
@@ -648,6 +648,32 @@ __global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(){
rocshmem_barrier_all();
}
/**
* @brief Kernel that runs a byte-wise alltoall for one workgroup.
*
* Tries to create a team context for the workgroup, runs the workgroup-level
* alltoall<char> on it, and destroys the context afterwards. When context
* creation fails, the default context is used instead and nothing is
* destroyed.
*
* @param[in] team The team participating in the collective.
* @param[out] dest Destination address.
* @param[in] source Source address.
* @param[in] size Number of bytes to transfer per pair of PEs.
*/
__global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
void *dest,
const void *source,
size_t size) {
// Create a context for this workgroup so that the default context is not
// contended. The intent is to let kernels on different streams proceed
// without serializing on the default context.
__shared__ rocshmem_ctx_t ctx;
__shared__ int ctx_result;
// NOTE(review): every thread stores the return value into the same shared
// variable and reads it back without an intervening barrier. This presumably
// relies on rocshmem_wg_team_create_ctx returning the same value to all
// threads of the workgroup -- confirm.
ctx_result = rocshmem_wg_team_create_ctx(team, 0, &ctx);
// If context creation failed, fall back to the default context. The barrier
// sits inside the branch, so it is only reached on the failure path.
if (ctx_result != 0) {
ctx = ROCSHMEM_CTX_DEFAULT;
__syncthreads();
}
// Run the device alltoall on the selected context and the provided team.
// The element type is char because size is given in bytes.
// NOTE(review): size is narrowed from size_t to int here, so sizes above
// INT_MAX would be truncated -- confirm the supported range.
rocshmem_alltoall_wg<char>(ctx, team, (char *) dest,
(const char *) source, (int) size);
// Only destroy a context that this kernel created itself.
if (ctx_result == 0) {
rocshmem_wg_ctx_destroy(&ctx);
}
}
__device__ void rocshmem_barrier_all() {
GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n",
+1
Просмотреть файл
@@ -142,6 +142,7 @@ enum rocshmem_host_stats {
NUM_HOST_SHMEM_PTR,
NUM_HOST_SYNC_ALL,
NUM_HOST_BROADCAST,
NUM_HOST_ALLTOALL,
NUM_HOST_STATS
};
+215
Просмотреть файл
@@ -0,0 +1,215 @@
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#include "team_alltoallmem_on_stream_tester.hpp"
#include <rocshmem/rocshmem.hpp>
#include <hip/hip_runtime.h>
#include <cstring>
#include <cassert>
#include <vector>
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * @brief Sets up buffers and per-team HIP objects for the test.
 *
 * The number of teams is read from ROCSHMEM_MAX_NUM_TEAMS when set, otherwise
 * it defaults to the number of work groups. Source and destination buffers of
 * max_msg_size * n_pes * num_teams bytes are allocated from the symmetric
 * heap, and one stream plus one start/stop event pair is created per team.
 *
 * @param args Tester arguments (message size, number of work groups, ...).
 */
TeamAlltoallmemOnStreamTester::TeamAlltoallmemOnStreamTester(TesterArguments args)
    : Tester(args) {
  my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
  n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  const char *value = getenv("ROCSHMEM_MAX_NUM_TEAMS");
  if (value != nullptr) {
    num_teams = atoi(value);
  } else {
    // Default to number of work groups
    num_teams = args.num_wgs;
  }

  // A non-positive team count would produce empty (or absurdly sized)
  // per-team containers below, so reject it up front.
  if (num_teams < 1) {
    std::cerr << "Invalid number of teams: " << num_teams << std::endl;
    rocshmem_global_exit(1);
  }

  // Compute the buffer size in size_t; the product of message size, PE count
  // and team count can exceed the range of int.
  const size_t num_bytes_wg =
      static_cast<size_t>(args.max_msg_size) * static_cast<size_t>(n_pes);
  buf_size = num_bytes_wg * static_cast<size_t>(num_teams);

  source_buf = static_cast<char *>(rocshmem_malloc(buf_size));
  dest_buf = static_cast<char *>(rocshmem_malloc(buf_size));

  if (source_buf == nullptr || dest_buf == nullptr) {
    std::cerr << "Error allocating memory from symmetric heap" << std::endl;
    // Print the addresses as pointers: streaming a char* would be interpreted
    // as a C string, which is undefined for a null pointer.
    std::cerr << "source: " << static_cast<const void *>(source_buf)
              << ", dest: " << static_cast<const void *>(dest_buf)
              << std::endl;
    rocshmem_global_exit(1);
  }

  team_world_dup.resize(num_teams);
  streams.resize(num_teams);
  start_events_timed.resize(num_teams);
  stop_events_timed.resize(num_teams);

  for (int i = 0; i < num_teams; i++) {
    CHECK_HIP(hipStreamCreate(&streams[i]));
    CHECK_HIP(hipEventCreate(&start_events_timed[i]));
    CHECK_HIP(hipEventCreate(&stop_events_timed[i]));
  }
}
/**
 * @brief Releases the per-team HIP events and streams, then frees the
 * symmetric-heap buffers.
 */
TeamAlltoallmemOnStreamTester::~TeamAlltoallmemOnStreamTester() {
  int team_idx = 0;
  while (team_idx < num_teams) {
    // Events are destroyed before the stream they were recorded on.
    CHECK_HIP(hipEventDestroy(stop_events_timed[team_idx]));
    CHECK_HIP(hipEventDestroy(start_events_timed[team_idx]));
    CHECK_HIP(hipStreamDestroy(streams[team_idx]));
    ++team_idx;
  }

  rocshmem_free(source_buf);
  rocshmem_free(dest_buf);
}
void TeamAlltoallmemOnStreamTester::preLaunchKernel() {
bw_factor = n_pes;
for (int team_i = 0; team_i < num_teams; team_i++) {
team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_world_dup[team_i]);
if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
std::cerr << "Team " << team_i << " is invalid!" << std::endl;
abort();
}
}
}
/**
 * @brief Waits for all streams, converts the per-team event timings into
 * timer entries, and destroys the teams created for the test.
 */
void TeamAlltoallmemOnStreamTester::postLaunchKernel() {
  // Every stream must be idle before its events can be queried.
  for (int s = 0; s < num_teams; ++s) {
    CHECK_HIP(hipStreamSynchronize(streams[s]));
  }

  // One timer slot per team, as far as timer slots are available.
  for (int slot = 0; slot < num_teams && slot < num_timers; ++slot) {
    float ms_elapsed = 0.0f;
    CHECK_HIP(hipEventElapsedTime(&ms_elapsed, start_events_timed[slot],
                                  stop_events_timed[slot]));

    // wall_clk_rate is in kHz, so milliseconds * rate gives cycles.
    const float cycles_f = ms_elapsed * static_cast<float>(wall_clk_rate);
    const long long int cycles = static_cast<long long int>(cycles_f);

    // Timings are stored relative to a zero start time.
    start_time[slot] = 0;
    end_time[slot] = cycles;
  }

  // Timer slots beyond the number of teams carry no measurement.
  for (int slot = num_teams; slot < num_timers; ++slot) {
    start_time[slot] = 0;
    end_time[slot] = 0;
  }

  for (int t = 0; t < num_teams; ++t) {
    rocshmem_team_destroy(team_world_dup[t]);
  }
}
/**
 * @brief Fills the source buffer with a per-block pattern and clears the
 * destination buffer.
 *
 * For every team, block `pe` of the source region is filled with the byte
 * value (my_pe * n_pes + pe), truncated to unsigned char by memset. In an
 * alltoall, block `pe` is the block this PE sends to PE `pe`.
 *
 * @param size Number of bytes per block.
 */
void TeamAlltoallmemOnStreamTester::resetBuffers(size_t size) {
  for (int wg_id = 0; wg_id < num_teams; wg_id++) {
    for (int pe = 0; pe < n_pes; pe++) {
      // Distinct per (sender, receiver) pair so verification can tell the
      // blocks apart.
      const int value = my_pe * n_pes + pe;

      // Keep the byte offset in size_t; narrowing it to int would overflow
      // for buffers larger than INT_MAX bytes.
      const size_t offset =
          (static_cast<size_t>(wg_id) * static_cast<size_t>(n_pes) +
           static_cast<size_t>(pe)) *
          size;
      std::memset(source_buf + offset, value, size);
    }
  }

  // Clear destination buffer
  std::memset(dest_buf, 0, buf_size);
}
/**
 * @brief Enqueues the warm-up and timed alltoall iterations for every team.
 *
 * Each team uses its own stream and its own region of the source and
 * destination buffers. The start event of a team is recorded before its
 * first timed iteration and the stop event after its last one.
 *
 * @param gridSize Unused; the host API chooses the launch configuration.
 * @param blockSize Unused; the host API chooses the launch configuration.
 * @param loop Number of timed iterations.
 * @param size Number of bytes to transfer per pair of PEs.
 */
void TeamAlltoallmemOnStreamTester::launchKernel(dim3 gridSize,
                                                 dim3 blockSize,
                                                 int loop,
                                                 size_t size) {
  // Enqueue one alltoall for the given team on that team's stream.
  auto enqueue_alltoall = [this, size](int team_idx) {
    char *team_src = source_buf + team_idx * n_pes * size;
    char *team_dst = dest_buf + team_idx * n_pes * size;
    rocshmem_alltoallmem_on_stream(team_world_dup[team_idx], team_dst,
                                   team_src, size, streams[team_idx]);
  };

  // Warm-up iterations (not timed).
  for (int iter = 0; iter < args.skip; ++iter) {
    for (int team_idx = 0; team_idx < num_teams; ++team_idx) {
      enqueue_alltoall(team_idx);
    }
  }

  // Timed iterations.
  const int last_iter = loop - 1;
  for (int iter = 0; iter < loop; ++iter) {
    for (int team_idx = 0; team_idx < num_teams; ++team_idx) {
      if (iter == 0) {
        CHECK_HIP(
            hipEventRecord(start_events_timed[team_idx], streams[team_idx]));
      }

      enqueue_alltoall(team_idx);

      if (iter == last_iter) {
        CHECK_HIP(
            hipEventRecord(stop_events_timed[team_idx], streams[team_idx]));
      }
    }
  }

  num_msgs = (loop + args.skip) * num_teams;
  num_timed_msgs = loop * num_teams;
}
/**
 * @brief Checks the destination buffer against the expected alltoall result.
 *
 * After the alltoall, block `j` of every team region must hold the block that
 * PE `j` sent to this PE, i.e. bytes with the value (j * n_pes + my_pe)
 * truncated to unsigned char. On the first mismatch a diagnostic is printed
 * and the program exits through rocshmem_global_exit(1).
 *
 * @param size Number of bytes per block.
 */
void TeamAlltoallmemOnStreamTester::verifyResults(size_t size) {
  for (int wg_id = 0; wg_id < num_teams; wg_id++) {
    for (int j = 0; j < n_pes; j++) {
      // The source pattern is written with memset, so only the low byte of
      // the value is stored; compare (and report) that byte.
      const unsigned char expected_byte =
          static_cast<unsigned char>(j * n_pes + my_pe);

      // Keep the byte offset in size_t; narrowing it to int would overflow
      // for buffers larger than INT_MAX bytes.
      const size_t offset =
          (static_cast<size_t>(wg_id) * static_cast<size_t>(n_pes) +
           static_cast<size_t>(j)) *
          size;

      for (size_t k = 0; k < size; k++) {
        const unsigned char actual_byte =
            static_cast<unsigned char>(dest_buf[offset + k]);
        if (actual_byte != expected_byte) {
          std::cerr << "PE " << my_pe << ": Verification failed for WG "
                    << wg_id << ", block from PE " << j << " at byte " << k
                    << std::endl;
          // Print both values as unsigned bytes so they are comparable.
          std::cerr << "Expected value: " << static_cast<int>(expected_byte)
                    << ", Got: " << static_cast<int>(actual_byte)
                    << std::endl;
          rocshmem_global_exit(1);
        }
      }
    }
  }
}
+70
Просмотреть файл
@@ -0,0 +1,70 @@
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_
#define _TEAM_ALLTOALLMEM_ON_STREAM_TESTER_HPP_
#include "tester.hpp"
#include <vector>
#include <hip/hip_runtime.h>
using namespace rocshmem;
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
/**
 * @brief Functional test for the host-side rocshmem_alltoallmem_on_stream API.
 *
 * The tester enqueues one alltoall per team, each on its own HIP stream, and
 * uses HIP events to time the operations.
 */
class TeamAlltoallmemOnStreamTester : public Tester {
 public:
  explicit TeamAlltoallmemOnStreamTester(TesterArguments args);
  virtual ~TeamAlltoallmemOnStreamTester();

 protected:
  virtual void resetBuffers(size_t size) override;

  virtual void preLaunchKernel() override;

  virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
                            size_t size) override;

  virtual void postLaunchKernel() override;

  virtual void verifyResults(size_t size) override;

 private:
  // Members carry default initializers so that no member is ever read in an
  // indeterminate state; the constructor assigns the real values.

  // Symmetric-heap buffers, buf_size bytes each.
  char *source_buf = nullptr;
  char *dest_buf = nullptr;

  // This PE's index and the number of PEs in the world team.
  int my_pe = 0;
  int n_pes = 0;

  // Size in bytes of each of source_buf and dest_buf.
  size_t buf_size = 0;

  // Number of teams (and streams) exercised by the test.
  int num_teams = 1;

  // Per-team objects; each vector holds num_teams entries.
  std::vector<rocshmem_team_t> team_world_dup;
  std::vector<hipStream_t> streams;
  std::vector<hipEvent_t> start_events_timed;
  std::vector<hipEvent_t> stop_events_timed;
};
#include "team_alltoallmem_on_stream_tester.cpp"
#endif
+8 -1
Просмотреть файл
@@ -47,6 +47,7 @@
#include "sync_all_tester.hpp"
#include "team_sync_tester.hpp"
#include "team_alltoall_tester.hpp"
#include "team_alltoallmem_on_stream_tester.hpp"
#include "team_barrier_tester.hpp"
#include "team_broadcast_tester.hpp"
#include "team_ctx_infra_tester.hpp"
@@ -227,6 +228,11 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
}
testers.push_back(new TeamAlltoallTester<float>(args));
return testers;
case TeamAlltoallmemOnStreamTestType:
if (rank == 0)
std::cout << "Alltoallmem_On_Stream ###" << std::endl;
testers.push_back(new TeamAlltoallmemOnStreamTester(args));
return testers;
case TeamFCollectTestType:
if (rank == 0) {
std::cout << "Fcollect Test ###" << std::endl;
@@ -585,7 +591,8 @@ bool Tester::peLaunchesKernel() {
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
(_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) ||
(_type == TeamWGBarrierTestType);
(_type == TeamWGBarrierTestType) ||
(_type == TeamAlltoallmemOnStreamTestType);
return is_launcher;
}
+1
Просмотреть файл
@@ -113,6 +113,7 @@ enum TestType {
TeamCtxInfraTestSingleType = 73,
TeamCtxInfraTestBlockType = 74,
TeamCtxInfraTestOddEvenType = 75,
TeamAlltoallmemOnStreamTestType = 76,
};
enum OpType { PutType = 0, GetType = 1 };
+2 -1
Просмотреть файл
@@ -182,7 +182,8 @@ void TesterArguments::get_arguments() {
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
(type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) &&
(type != TeamWGBarrierTestType) && (type != TeamCtxInfraTestBlockType) &&
(type != TeamCtxInfraTestOddEvenType)) {
(type != TeamCtxInfraTestOddEvenType) &&
(type != TeamAlltoallmemOnStreamTestType)) {
if (numprocs != 2) {
if (myid == 0) {
std::cerr << "This test requires exactly two processes, we have "