Add multi work-group support for collective functional tests (#45)
- Added multi-work group support for the All-to-all, Fcollect, Broadcast, Barrier and Sync collective functional tests
- Renamed All-to-all and Fcollect tests to TeamAlltoAll and TeamFcollect
[ROCm/rocshmem commit: 57d60aa727]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
e1ed36e58f
Коммит
65b4ff4c41
@@ -1,165 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T>
|
||||
__device__ void wg_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelem) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Define templates to call rocSHMEM */
|
||||
#define ALLTOALL_DEF_GEN(T, TNAME) \
|
||||
template <> \
|
||||
__device__ void wg_alltoall<T>(rocshmem_ctx_t ctx, rocshmem_team_t team, \
|
||||
T * dest, const T *source, int nelem) { \
|
||||
rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \
|
||||
}
|
||||
|
||||
ALLTOALL_DEF_GEN(float, float)
|
||||
ALLTOALL_DEF_GEN(double, double)
|
||||
ALLTOALL_DEF_GEN(char, char)
|
||||
// ALLTOALL_DEF_GEN(long double, longdouble)
|
||||
ALLTOALL_DEF_GEN(signed char, schar)
|
||||
ALLTOALL_DEF_GEN(short, short)
|
||||
ALLTOALL_DEF_GEN(int, int)
|
||||
ALLTOALL_DEF_GEN(long, long)
|
||||
ALLTOALL_DEF_GEN(long long, longlong)
|
||||
ALLTOALL_DEF_GEN(unsigned char, uchar)
|
||||
ALLTOALL_DEF_GEN(unsigned short, ushort)
|
||||
ALLTOALL_DEF_GEN(unsigned int, uint)
|
||||
ALLTOALL_DEF_GEN(unsigned long, ulong)
|
||||
ALLTOALL_DEF_GEN(unsigned long long, ulonglong)
|
||||
|
||||
rocshmem_team_t team_alltoall_world_dup;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
__global__ void AlltoallTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, T1 *source_buf,
|
||||
T1 *dest_buf, int size, ShmemContextType ctx_type,
|
||||
rocshmem_team_t team) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
wg_alltoall<T1>(ctx, team,
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
size); // int nelement
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
rocshmem_wg_finalize();
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
AlltoallTester<T1>::AlltoallTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2)
|
||||
: Tester(args), init_buf{f1}, verify_buf{f2} {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes);
|
||||
dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
AlltoallTester<T1>::~AlltoallTester() {
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void AlltoallTester<T1>::preLaunchKernel() {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
bw_factor = sizeof(T1) * n_pes;
|
||||
|
||||
team_alltoall_world_dup = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_alltoall_world_dup);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void AlltoallTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(AlltoallTest<T1>, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time, source_buf,
|
||||
dest_buf, size, _shmem_context, team_alltoall_world_dup);
|
||||
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void AlltoallTester<T1>::postLaunchKernel() {
|
||||
rocshmem_team_destroy(team_alltoall_world_dup);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void AlltoallTester<T1>::resetBuffers(uint64_t size) {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
for (int i = 0; i < n_pes; i++) {
|
||||
for (uint64_t j = 0; j < size; j++) {
|
||||
init_buf(source_buf[i * size + j], dest_buf[i * size + j], (T1)i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void AlltoallTester<T1>::verifyResults(uint64_t size) {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
for (int i = 0; i < n_pes; i++) {
|
||||
for (uint64_t j = 0; j < size; j++) {
|
||||
auto r = verify_buf(dest_buf[i * size + j], i);
|
||||
if (r.first == false) {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", j);
|
||||
fprintf(stderr, "%s.\n", r.second.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,13 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
|
||||
|
||||
__syncthreads();
|
||||
|
||||
rocshmem_ctx_wg_barrier_all(ctx);
|
||||
/**
|
||||
* The function `rocshmem_ctx_wg_barrier_all` should be called from only
|
||||
* one group within the grid to avoid unintended behavior.
|
||||
*/
|
||||
if (is_block_zero_in_grid()) {
|
||||
rocshmem_ctx_wg_barrier_all(ctx);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -70,7 +76,7 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream,
|
||||
loop, args.skip, start_time, end_time);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
rocshmem_team_t team_fcollect_world_dup;
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T>
|
||||
__device__ void wg_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelems) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Define templates to call rocSHMEM */
|
||||
#define FCOLLECT_DEF_GEN(T, TNAME) \
|
||||
template <> \
|
||||
__device__ void wg_fcollect<T>(rocshmem_ctx_t ctx, rocshmem_team_t team, \
|
||||
T * dest, const T *source, int nelem) { \
|
||||
rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \
|
||||
}
|
||||
|
||||
FCOLLECT_DEF_GEN(float, float)
|
||||
FCOLLECT_DEF_GEN(double, double)
|
||||
FCOLLECT_DEF_GEN(char, char)
|
||||
// FCOLLECT_DEF_GEN(long double, longdouble)
|
||||
FCOLLECT_DEF_GEN(signed char, schar)
|
||||
FCOLLECT_DEF_GEN(short, short)
|
||||
FCOLLECT_DEF_GEN(int, int)
|
||||
FCOLLECT_DEF_GEN(long, long)
|
||||
FCOLLECT_DEF_GEN(long long, longlong)
|
||||
FCOLLECT_DEF_GEN(unsigned char, uchar)
|
||||
FCOLLECT_DEF_GEN(unsigned short, ushort)
|
||||
FCOLLECT_DEF_GEN(unsigned int, uint)
|
||||
FCOLLECT_DEF_GEN(unsigned long, ulong)
|
||||
FCOLLECT_DEF_GEN(unsigned long long, ulonglong)
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
__global__ void FcollectTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, T1 *source_buf,
|
||||
T1 *dest_buf, int size, ShmemContextType ctx_type,
|
||||
rocshmem_team_t team) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
wg_fcollect<T1>(ctx, team,
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
size); // int nelement
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
rocshmem_wg_finalize();
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
FcollectTester<T1>::FcollectTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2)
|
||||
: Tester(args), init_buf{f1}, verify_buf{f2} {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1));
|
||||
dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
FcollectTester<T1>::~FcollectTester() {
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void FcollectTester<T1>::preLaunchKernel() {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
bw_factor = sizeof(T1) * n_pes;
|
||||
|
||||
team_fcollect_world_dup = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_fcollect_world_dup);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void FcollectTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(FcollectTest<T1>, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time, source_buf,
|
||||
dest_buf, size, _shmem_context, team_fcollect_world_dup);
|
||||
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void FcollectTester<T1>::postLaunchKernel() {
|
||||
rocshmem_team_destroy(team_fcollect_world_dup);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void FcollectTester<T1>::resetBuffers(uint64_t size) {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
for (int i = 0; i < n_pes; i++) {
|
||||
for (uint64_t j = 0; j < size; j++) {
|
||||
// Note: This is redundant work,
|
||||
// source is being reinitialized multiple times
|
||||
init_buf(source_buf[j], dest_buf[i * size + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void FcollectTester<T1>::verifyResults(uint64_t size) {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
for (int i = 0; i < n_pes; i++) {
|
||||
for (uint64_t j = 0; j < size; j++) {
|
||||
auto r = verify_buf(dest_buf[i * size + j], i);
|
||||
if (r.first == false) {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", j);
|
||||
fprintf(stderr, "%s.\n", r.second.c_str());
|
||||
// exit(-1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,17 +22,12 @@
|
||||
|
||||
#include "sync_tester.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
using namespace rocshmem;
|
||||
rocshmem_team_t team_sync_world_dup;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, TestType type,
|
||||
ShmemContextType ctx_type, rocshmem_team_t team) {
|
||||
ShmemContextType ctx_type, rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
@@ -47,10 +42,16 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
__syncthreads();
|
||||
switch (type) {
|
||||
case SyncAllTestType:
|
||||
rocshmem_ctx_wg_sync_all(ctx);
|
||||
/**
|
||||
* The function `rocshmem_ctx_wg_sync_all` should be called from only
|
||||
* one group within the grid to avoid unintended behavior.
|
||||
*/
|
||||
if (is_block_zero_in_grid()) {
|
||||
rocshmem_ctx_wg_sync_all(ctx);
|
||||
}
|
||||
break;
|
||||
case SyncTestType:
|
||||
rocshmem_ctx_wg_team_sync(ctx, team);
|
||||
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -69,28 +70,60 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
SyncTester::SyncTester(TesterArguments args) : Tester(args) {}
|
||||
SyncTester::SyncTester(TesterArguments args) : Tester(args) {
|
||||
|
||||
SyncTester::~SyncTester() {}
|
||||
char* value{nullptr};
|
||||
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
|
||||
num_teams = atoi(value);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipMalloc(&team_sync_world_dup,
|
||||
sizeof(rocshmem_team_t) * num_teams));
|
||||
}
|
||||
|
||||
SyncTester::~SyncTester() {
|
||||
CHECK_HIP(hipFree(team_sync_world_dup));
|
||||
}
|
||||
|
||||
void SyncTester::resetBuffers(uint64_t size) {}
|
||||
|
||||
void SyncTester::preLaunchKernel() {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_sync_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_sync_world_dup[team_i]);
|
||||
if (team_sync_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
printf("Team %d is invalid!\n", team_i);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
team_sync_world_dup = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_sync_world_dup);
|
||||
|
||||
hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop,
|
||||
args.skip, start_time, end_time, _type, _shmem_context,
|
||||
team_sync_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
|
||||
if(_type == SyncTestType) {
|
||||
num_msgs *= gridSize.x;
|
||||
num_timed_msgs *= gridSize.x;
|
||||
}
|
||||
}
|
||||
|
||||
void SyncTester::postLaunchKernel() {
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_sync_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
void SyncTester::verifyResults(uint64_t size) {}
|
||||
|
||||
@@ -23,8 +23,12 @@
|
||||
#ifndef _SYNC_TESTER_HPP_
|
||||
#define _SYNC_TESTER_HPP_
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
#include "tester.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
@@ -36,10 +40,22 @@ class SyncTester : public Tester {
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
|
||||
virtual void preLaunchKernel() override;
|
||||
|
||||
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) override;
|
||||
|
||||
virtual void postLaunchKernel() override;
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
|
||||
* The default value for the maximum number of teams is 40.
|
||||
*/
|
||||
int num_teams = 39;
|
||||
rocshmem_team_t *team_sync_world_dup;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T>
|
||||
__device__ void wg_team_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team,
|
||||
T *dest, const T *source, int nelem) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Define templates to call rocSHMEM */
|
||||
#define TEAM_ALLTOALL_DEF_GEN(T, TNAME) \
|
||||
template <> \
|
||||
__device__ void wg_team_alltoall<T>(rocshmem_ctx_t ctx, rocshmem_team_t team,\
|
||||
T * dest, const T *source, int nelem) { \
|
||||
rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \
|
||||
}
|
||||
|
||||
TEAM_ALLTOALL_DEF_GEN(float, float)
|
||||
TEAM_ALLTOALL_DEF_GEN(double, double)
|
||||
TEAM_ALLTOALL_DEF_GEN(char, char)
|
||||
// TEAM_ALLTOALL_DEF_GEN(long double, longdouble)
|
||||
TEAM_ALLTOALL_DEF_GEN(signed char, schar)
|
||||
TEAM_ALLTOALL_DEF_GEN(short, short)
|
||||
TEAM_ALLTOALL_DEF_GEN(int, int)
|
||||
TEAM_ALLTOALL_DEF_GEN(long, long)
|
||||
TEAM_ALLTOALL_DEF_GEN(long long, longlong)
|
||||
TEAM_ALLTOALL_DEF_GEN(unsigned char, uchar)
|
||||
TEAM_ALLTOALL_DEF_GEN(unsigned short, ushort)
|
||||
TEAM_ALLTOALL_DEF_GEN(unsigned int, uint)
|
||||
TEAM_ALLTOALL_DEF_GEN(unsigned long, ulong)
|
||||
TEAM_ALLTOALL_DEF_GEN(unsigned long long, ulonglong)
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
__global__ void TeamAlltoallTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, T1 *source_buf,
|
||||
T1 *dest_buf, int num_elems,
|
||||
ShmemContextType ctx_type,
|
||||
rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
|
||||
source_buf += wg_id * n_pes * num_elems;
|
||||
dest_buf += wg_id * n_pes * num_elems;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
wg_team_alltoall<T1>(ctx, teams[wg_id],
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
num_elems); // int nelement
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
rocshmem_wg_finalize();
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
TeamAlltoallTester<T1>::TeamAlltoallTester(TesterArguments args)
|
||||
: Tester(args){
|
||||
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
|
||||
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
// Number of elements per work group
|
||||
int num_elems_wg = (args.max_msg_size / sizeof(T1)) * n_pes;
|
||||
// Total number of elements in the GPU kernel
|
||||
int total_elems = num_elems_wg * args.num_wgs;
|
||||
int buff_size = total_elems * sizeof(T1);
|
||||
|
||||
source_buf = (T1 *)rocshmem_malloc(buff_size);
|
||||
dest_buf = (T1 *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source_buf == nullptr || dest_buf == nullptr) {
|
||||
std::cout << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cout << "source: " << source_buf
|
||||
<< ", dest: " << dest_buf
|
||||
<< std::endl;
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
char* value{nullptr};
|
||||
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
|
||||
num_teams = atoi(value);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipMalloc(&team_alltoall_world_dup,
|
||||
sizeof(rocshmem_team_t) * num_teams));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
TeamAlltoallTester<T1>::~TeamAlltoallTester() {
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
CHECK_HIP(hipFree(team_alltoall_world_dup));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamAlltoallTester<T1>::preLaunchKernel() {
|
||||
bw_factor = n_pes;
|
||||
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_alltoall_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_alltoall_world_dup[team_i]);
|
||||
if (team_alltoall_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
std::cout << "Team " << team_i << " is invalid!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamAlltoallTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
|
||||
hipLaunchKernelGGL(TeamAlltoallTest<T1>, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time,
|
||||
source_buf, dest_buf, num_elems, _shmem_context,
|
||||
team_alltoall_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamAlltoallTester<T1>::postLaunchKernel() {
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_alltoall_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamAlltoallTester<T1>::resetBuffers(uint64_t size) {
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
int buff_size = num_elems * sizeof(T1) * args.num_wgs * n_pes;
|
||||
int idx = 0;
|
||||
|
||||
for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
|
||||
for(int pe = 0; pe < n_pes; pe++) {
|
||||
for(int i = 0; i < num_elems; i++) {
|
||||
idx = (wg_id * n_pes + pe) * num_elems + i;
|
||||
if constexpr (std::is_same<T1, char>::value ||
|
||||
std::is_same<T1, signed char>::value ||
|
||||
std::is_same<T1, unsigned char>::value) {
|
||||
source_buf[idx] = static_cast<T1>('a' + my_pe + pe + wg_id);
|
||||
}
|
||||
else if constexpr (std::is_floating_point<T1>::value) {
|
||||
source_buf[idx] = static_cast<T1>(3.14 + my_pe + pe + wg_id);
|
||||
}
|
||||
else if constexpr (std::is_integral<T1>::value) {
|
||||
source_buf[idx] = static_cast<T1>(my_pe + pe + wg_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memset(dest_buf, -1, buff_size);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamAlltoallTester<T1>::verifyResults(uint64_t size) {
|
||||
int num_elems = size / sizeof(T1);
|
||||
int idx = 0;
|
||||
|
||||
for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
|
||||
for(int pe = 0; pe < n_pes; pe++) {
|
||||
for(int i = 0; i < num_elems; i++) {
|
||||
idx = (wg_id * n_pes + pe) * num_elems + i;
|
||||
if (dest_buf[idx] != source_buf[idx]) {
|
||||
std::cerr << "Data validation error at idx " << idx << std::endl;
|
||||
std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
|
||||
<< ", Expected " << source_buf[idx] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+17
-10
@@ -28,16 +28,16 @@
|
||||
|
||||
#include "tester.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/************* *****************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
class AlltoallTester : public Tester {
|
||||
class TeamAlltoallTester : public Tester {
|
||||
public:
|
||||
explicit AlltoallTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2);
|
||||
virtual ~AlltoallTester();
|
||||
explicit TeamAlltoallTester(TesterArguments args);
|
||||
virtual ~TeamAlltoallTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
@@ -51,14 +51,21 @@ class AlltoallTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
T1 *source_buf;
|
||||
T1 *dest_buf;
|
||||
T1 *source_buf = nullptr;
|
||||
T1 *dest_buf = nullptr;
|
||||
|
||||
private:
|
||||
std::function<void(T1 &, T1 &, T1)> init_buf;
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> verify_buf;
|
||||
int my_pe = 0;
|
||||
int n_pes = 0;
|
||||
|
||||
/**
|
||||
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
|
||||
* The default value for the maximum number of teams is 40.
|
||||
*/
|
||||
int num_teams = 39;
|
||||
rocshmem_team_t *team_alltoall_world_dup;
|
||||
};
|
||||
|
||||
#include "alltoall_tester.cpp"
|
||||
#include "team_alltoall_tester.cpp"
|
||||
|
||||
#endif
|
||||
@@ -20,8 +20,6 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T>
|
||||
__device__ void wg_team_broadcast(rocshmem_ctx_t ctx, rocshmem_team_t team,
|
||||
@@ -65,28 +63,29 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, T1 *source_buf,
|
||||
T1 *dest_buf, int size,
|
||||
ShmemContextType ctx_type,
|
||||
rocshmem_team_t team) {
|
||||
rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
source_buf += wg_id * size;
|
||||
dest_buf += wg_id * size;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop; i++) {
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
wg_team_broadcast<T1>(ctx, team,
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
size, // int nelement
|
||||
0); // int PE_root
|
||||
rocshmem_ctx_wg_barrier_all(ctx);
|
||||
wg_team_broadcast<T1>(ctx, teams[wg_id],
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
size, // int nelement
|
||||
0); // int PE_root
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -103,27 +102,51 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
TeamBroadcastTester<T1>::TeamBroadcastTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &)> f2)
|
||||
: Tester(args), init_buf{f1}, verify_buf{f2} {
|
||||
source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1));
|
||||
dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1));
|
||||
TeamBroadcastTester<T1>::TeamBroadcastTester(TesterArguments args)
|
||||
: Tester(args){
|
||||
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
|
||||
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
// Total number of elements in src buffer
|
||||
int total_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs ;
|
||||
int buff_size = total_elems * sizeof(T1);
|
||||
|
||||
source_buf = (T1 *)rocshmem_malloc(buff_size);
|
||||
dest_buf = (T1 *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source_buf == nullptr || dest_buf == nullptr) {
|
||||
std::cout << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cout << "source: " << source_buf << ", dest: " << dest_buf << std::endl;
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
char* value{nullptr};
|
||||
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
|
||||
num_teams = atoi(value);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipMalloc(&team_bcast_world_dup,
|
||||
sizeof(rocshmem_team_t) * num_teams));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
TeamBroadcastTester<T1>::~TeamBroadcastTester() {
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
CHECK_HIP(hipFree(team_bcast_world_dup));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamBroadcastTester<T1>::preLaunchKernel() {
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
team_bcast_world_dup = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_bcast_world_dup);
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_bcast_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_bcast_world_dup[team_i]);
|
||||
if (team_bcast_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
printf("Team %d is invalid!\n", team_i);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
@@ -131,34 +154,90 @@ void TeamBroadcastTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(TeamBroadcastTest<T1>, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time, source_buf,
|
||||
dest_buf, size, _shmem_context, team_bcast_world_dup);
|
||||
int num_elems = size / sizeof(T1);
|
||||
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
hipLaunchKernelGGL(TeamBroadcastTest<T1>, gridSize, blockSize,
|
||||
shared_bytes, stream, loop, args.skip,
|
||||
start_time, end_time, source_buf, dest_buf,
|
||||
num_elems, _shmem_context, team_bcast_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamBroadcastTester<T1>::postLaunchKernel() {
|
||||
rocshmem_team_destroy(team_bcast_world_dup);
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_bcast_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamBroadcastTester<T1>::resetBuffers(uint64_t size) {
|
||||
for (uint64_t i = 0; i < args.max_msg_size; i++) {
|
||||
init_buf(source_buf[i], dest_buf[i]);
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
int buff_size = num_elems * sizeof(T1) * args.num_wgs;
|
||||
int idx = 0;
|
||||
|
||||
for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
idx = wg_id * num_elems + i;
|
||||
if constexpr (std::is_same<T1, char>::value ||
|
||||
std::is_same<T1, signed char>::value ||
|
||||
std::is_same<T1, unsigned char>::value) {
|
||||
source_buf[idx] = static_cast<T1>('a' + n_pes + wg_id);
|
||||
dest_buf[idx] = static_cast<T1>('a' + wg_id);
|
||||
}
|
||||
else if constexpr (std::is_floating_point<T1>::value) {
|
||||
source_buf[idx] = static_cast<T1>(3.14 + n_pes + wg_id);
|
||||
dest_buf[idx] = static_cast<T1>(3.14 + wg_id);
|
||||
}
|
||||
else if constexpr (std::is_integral<T1>::value) {
|
||||
source_buf[idx] = static_cast<T1>(n_pes + wg_id);
|
||||
dest_buf[idx] = static_cast<T1>(wg_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamBroadcastTester<T1>::verifyResults(uint64_t size) {
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
auto r = verify_buf(dest_buf[i]);
|
||||
if (r.first == false) {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "%s.\n", r.second.c_str());
|
||||
exit(-1);
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
int idx = 0;
|
||||
T1 expected;
|
||||
|
||||
/**
|
||||
* The verification routine here requires that the
|
||||
* PE_root value is 0 which denotes that the
|
||||
* sending processing element is rank 0.
|
||||
*
|
||||
* The difference in expected values arises from
|
||||
* the specification for broadcast where the
|
||||
* PE_root processing element does not copy the
|
||||
* contents from its own source to dest during
|
||||
* the broadcast.
|
||||
*/
|
||||
for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
idx = wg_id * num_elems + i;
|
||||
if constexpr (std::is_same<T1, char>::value ||
|
||||
std::is_same<T1, signed char>::value ||
|
||||
std::is_same<T1, unsigned char>::value) {
|
||||
expected = static_cast<T1>('a' + wg_id + (my_pe ? n_pes : 0));
|
||||
}
|
||||
else if constexpr (std::is_floating_point<T1>::value) {
|
||||
expected = static_cast<T1>(3.14 + wg_id + (my_pe ? n_pes : 0));
|
||||
}
|
||||
else if constexpr (std::is_integral<T1>::value) {
|
||||
expected = static_cast<T1>(wg_id + (my_pe ? n_pes : 0));
|
||||
}
|
||||
if (dest_buf[idx] != expected) {
|
||||
std::cerr << "Data validation error at idx " << idx << std::endl;
|
||||
std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
|
||||
<< ", Expected " << expected << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,15 +28,15 @@
|
||||
|
||||
#include "tester.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/************* *****************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
class TeamBroadcastTester : public Tester {
|
||||
public:
|
||||
explicit TeamBroadcastTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &)> f2);
|
||||
explicit TeamBroadcastTester(TesterArguments args);
|
||||
virtual ~TeamBroadcastTester();
|
||||
|
||||
protected:
|
||||
@@ -55,8 +55,14 @@ class TeamBroadcastTester : public Tester {
|
||||
T1 *dest_buf;
|
||||
|
||||
private:
|
||||
std::function<void(T1 &, T1 &)> init_buf;
|
||||
std::function<std::pair<bool, std::string>(const T1 &)> verify_buf;
|
||||
int my_pe = 0;
|
||||
int n_pes = 0;
|
||||
/**
|
||||
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
|
||||
* The default value for the maximum number of teams is 40.
|
||||
*/
|
||||
int num_teams = 39;
|
||||
rocshmem_team_t *team_bcast_world_dup;
|
||||
};
|
||||
|
||||
#include "team_broadcast_tester.cpp"
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T>
|
||||
__device__ void wg_team_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team,
|
||||
T *dest, const T *source, int nelems) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Define templates to call rocSHMEM */
|
||||
#define TEAM_FCOLLECT_DEF_GEN(T, TNAME) \
|
||||
template <> \
|
||||
__device__ void wg_team_fcollect<T>(rocshmem_ctx_t ctx, rocshmem_team_t team,\
|
||||
T * dest, const T *source, int nelem) { \
|
||||
rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \
|
||||
}
|
||||
|
||||
TEAM_FCOLLECT_DEF_GEN(float, float)
|
||||
TEAM_FCOLLECT_DEF_GEN(double, double)
|
||||
TEAM_FCOLLECT_DEF_GEN(char, char)
|
||||
// TEAM_FCOLLECT_DEF_GEN(long double, longdouble)
|
||||
TEAM_FCOLLECT_DEF_GEN(signed char, schar)
|
||||
TEAM_FCOLLECT_DEF_GEN(short, short)
|
||||
TEAM_FCOLLECT_DEF_GEN(int, int)
|
||||
TEAM_FCOLLECT_DEF_GEN(long, long)
|
||||
TEAM_FCOLLECT_DEF_GEN(long long, longlong)
|
||||
TEAM_FCOLLECT_DEF_GEN(unsigned char, uchar)
|
||||
TEAM_FCOLLECT_DEF_GEN(unsigned short, ushort)
|
||||
TEAM_FCOLLECT_DEF_GEN(unsigned int, uint)
|
||||
TEAM_FCOLLECT_DEF_GEN(unsigned long, ulong)
|
||||
TEAM_FCOLLECT_DEF_GEN(unsigned long long, ulonglong)
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
__global__ void TeamFcollectTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, T1 *source_buf,
|
||||
T1 *dest_buf, int num_elems,
|
||||
ShmemContextType ctx_type,
|
||||
rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
source_buf += wg_id * num_elems;
|
||||
dest_buf += wg_id * num_elems * n_pes;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
wg_team_fcollect<T1>(ctx, teams[wg_id],
|
||||
dest_buf, // T* dest
|
||||
source_buf, // const T* source
|
||||
num_elems); // int nelement
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
rocshmem_wg_finalize();
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
TeamFcollectTester<T1>::TeamFcollectTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
|
||||
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
// Total number of elements in src buffer
|
||||
int total_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs ;
|
||||
int buff_size = total_elems * sizeof(T1);
|
||||
|
||||
source_buf = (T1 *)rocshmem_malloc(buff_size);
|
||||
dest_buf = (T1 *)rocshmem_malloc(buff_size * n_pes);
|
||||
|
||||
if (source_buf == nullptr || dest_buf == nullptr) {
|
||||
std::cout << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cout << "source: " << source_buf
|
||||
<< ", dest: " << dest_buf
|
||||
<< std::endl;
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<T1, char>::value ||
|
||||
std::is_same<T1, signed char>::value ||
|
||||
std::is_same<T1, unsigned char>::value) {
|
||||
for (int i = 0; i < total_elems; ++i) {
|
||||
source_buf[i] = static_cast<T1>('a' + my_pe);
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_floating_point<T1>::value) {
|
||||
for (int i = 0; i < total_elems; ++i) {
|
||||
source_buf[i] = static_cast<T1>(3.14 + my_pe);
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_integral<T1>::value) {
|
||||
for (int i = 0; i < total_elems; i++) {
|
||||
source_buf[i] = static_cast<T1>(my_pe);
|
||||
}
|
||||
}
|
||||
|
||||
char* value{nullptr};
|
||||
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
|
||||
num_teams = atoi(value);
|
||||
}
|
||||
|
||||
CHECK_HIP(hipMalloc(&team_fcollect_world_dup,
|
||||
sizeof(rocshmem_team_t) * num_teams));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
TeamFcollectTester<T1>::~TeamFcollectTester() {
|
||||
rocshmem_free(source_buf);
|
||||
rocshmem_free(dest_buf);
|
||||
CHECK_HIP(hipFree(team_fcollect_world_dup));
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamFcollectTester<T1>::preLaunchKernel() {
|
||||
bw_factor = n_pes;
|
||||
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
team_fcollect_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
|
||||
&team_fcollect_world_dup[team_i]);
|
||||
if (team_fcollect_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
|
||||
std::cout << "Team " << team_i << " is invalid!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamFcollectTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
|
||||
int my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
|
||||
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
|
||||
|
||||
hipLaunchKernelGGL(TeamFcollectTest<T1>, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time,
|
||||
source_buf, dest_buf, num_elems, _shmem_context,
|
||||
team_fcollect_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamFcollectTester<T1>::postLaunchKernel() {
|
||||
for (int team_i = 0; team_i < num_teams; team_i++) {
|
||||
rocshmem_team_destroy(team_fcollect_world_dup[team_i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamFcollectTester<T1>::resetBuffers(uint64_t size) {
|
||||
int num_elems = (size / sizeof(T1));
|
||||
int buff_size = num_elems * sizeof(T1) * args.num_wgs * n_pes;
|
||||
|
||||
memset(dest_buf, -1, buff_size);
|
||||
}
|
||||
|
||||
template <typename T1>
|
||||
void TeamFcollectTester<T1>::verifyResults(uint64_t size) {
|
||||
|
||||
int num_elems = size / sizeof(T1);
|
||||
int idx = 0;
|
||||
T1 expected;
|
||||
|
||||
for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
|
||||
for(int pe = 0; pe < n_pes; pe++) {
|
||||
for(int i = 0; i < num_elems; i++) {
|
||||
idx = (wg_id * n_pes + pe) * num_elems + i;
|
||||
if constexpr (std::is_same<T1, char>::value ||
|
||||
std::is_same<T1, signed char>::value ||
|
||||
std::is_same<T1, unsigned char>::value) {
|
||||
expected = static_cast<T1>('a' + pe);
|
||||
}
|
||||
else if constexpr (std::is_floating_point<T1>::value) {
|
||||
expected = static_cast<T1>(3.14 + pe);
|
||||
}
|
||||
else if constexpr (std::is_integral<T1>::value) {
|
||||
expected = pe;
|
||||
}
|
||||
if (dest_buf[idx] != expected) {
|
||||
std::cerr << "Data validation error at idx " << idx << std::endl;
|
||||
std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
|
||||
<< ", Expected " << expected << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+14
-8
@@ -28,16 +28,16 @@
|
||||
|
||||
#include "tester.hpp"
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/************* *****************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
template <typename T1>
|
||||
class FcollectTester : public Tester {
|
||||
class TeamFcollectTester : public Tester {
|
||||
public:
|
||||
explicit FcollectTester(
|
||||
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2);
|
||||
virtual ~FcollectTester();
|
||||
explicit TeamFcollectTester(TesterArguments args);
|
||||
virtual ~TeamFcollectTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
@@ -55,10 +55,16 @@ class FcollectTester : public Tester {
|
||||
T1 *dest_buf;
|
||||
|
||||
private:
|
||||
std::function<void(T1 &, T1 &)> init_buf;
|
||||
std::function<std::pair<bool, std::string>(const T1 &, T1)> verify_buf;
|
||||
int my_pe = 0;
|
||||
int n_pes = 0;
|
||||
/**
|
||||
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
|
||||
* The default value for the maximum number of teams is 40.
|
||||
*/
|
||||
int num_teams = 39;
|
||||
rocshmem_team_t *team_fcollect_world_dup;
|
||||
};
|
||||
|
||||
#include "fcollect_tester.cpp"
|
||||
#include "team_fcollect_tester.cpp"
|
||||
|
||||
#endif
|
||||
@@ -30,14 +30,12 @@
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "alltoall_tester.hpp"
|
||||
#include "amo_bitwise_tester.hpp"
|
||||
#include "amo_extended_tester.hpp"
|
||||
#include "amo_standard_tester.hpp"
|
||||
#include "barrier_all_tester.hpp"
|
||||
#include "empty_tester.hpp"
|
||||
#include "extended_primitives.hpp"
|
||||
#include "fcollect_tester.hpp"
|
||||
#include "ping_all_tester.hpp"
|
||||
#include "ping_pong_tester.hpp"
|
||||
#include "primitive_mr_tester.hpp"
|
||||
@@ -47,9 +45,11 @@
|
||||
#include "signaling_operations_tester.hpp"
|
||||
#include "swarm_tester.hpp"
|
||||
#include "sync_tester.hpp"
|
||||
#include "team_alltoall_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
#include "team_ctx_infra_tester.hpp"
|
||||
#include "team_ctx_primitive_tester.hpp"
|
||||
#include "team_fcollect_tester.hpp"
|
||||
#include "team_reduction_tester.hpp"
|
||||
#include "wave_level_primitives.hpp"
|
||||
|
||||
@@ -162,85 +162,37 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) {
|
||||
std::cout << "Team Broadcast Test ###" << std::endl;
|
||||
}
|
||||
testers.push_back(new TeamBroadcastTester<long>(
|
||||
args,
|
||||
[](long& f1, long& f2) {
|
||||
f1 = 1;
|
||||
f2 = 2;
|
||||
},
|
||||
[rank](long v) {
|
||||
long expected_val;
|
||||
/**
|
||||
* The verification routine here requires that the
|
||||
* PE_root value is 0 which denotes that the
|
||||
* sending processing element is rank 0.
|
||||
*
|
||||
* The difference in expected values arises from
|
||||
* the specification for broadcast where the
|
||||
* PE_root processing element does not copy the
|
||||
* contents from its own source to dest during
|
||||
* the broadcast.
|
||||
*/
|
||||
if (rank == 0) {
|
||||
expected_val = 2;
|
||||
} else {
|
||||
expected_val = 1;
|
||||
}
|
||||
|
||||
return (v == expected_val)
|
||||
? std::make_pair(true, "")
|
||||
: std::make_pair(
|
||||
false, "Rank " + std::to_string(rank) + ", Got " +
|
||||
std::to_string(v) + ", Expect " +
|
||||
std::to_string(expected_val));
|
||||
}));
|
||||
testers.push_back(new TeamBroadcastTester<int64_t>(args));
|
||||
testers.push_back(new TeamBroadcastTester<int>(args));
|
||||
testers.push_back(new TeamBroadcastTester<long long>(args));
|
||||
testers.push_back(new TeamBroadcastTester<float>(args));
|
||||
testers.push_back(new TeamBroadcastTester<double>(args));
|
||||
testers.push_back(new TeamBroadcastTester<char>(args));
|
||||
testers.push_back(new TeamBroadcastTester<unsigned char>(args));
|
||||
return testers;
|
||||
case AllToAllTestType:
|
||||
case TeamAllToAllTestType:
|
||||
if (rank == 0) {
|
||||
std::cout << "Alltoall Test ###" << std::endl;
|
||||
}
|
||||
testers.push_back(new AlltoallTester<int64_t>(
|
||||
args,
|
||||
[rank](int64_t& f1, int64_t& f2, int64_t dest_pe) {
|
||||
const long SRC_SHIFT = 16;
|
||||
// Make value for each src, dst pair unique
|
||||
// by shifting src by SRC_SHIFT bits
|
||||
f1 = (rank << SRC_SHIFT) + dest_pe;
|
||||
f2 = -1;
|
||||
},
|
||||
[rank](int64_t v, int64_t src_pe) {
|
||||
const long SRC_SHIFT = 16;
|
||||
// See if we obtained unique value
|
||||
long expected_val = (src_pe << SRC_SHIFT) + rank;
|
||||
|
||||
return (v == expected_val)
|
||||
? std::make_pair(true, "")
|
||||
: std::make_pair(
|
||||
false, "Rank " + std::to_string(rank) + ", Got " +
|
||||
std::to_string(v) + ", Expect " +
|
||||
std::to_string(expected_val));
|
||||
}));
|
||||
testers.push_back(new TeamAlltoallTester<int64_t>(args));
|
||||
testers.push_back(new TeamAlltoallTester<int>(args));
|
||||
testers.push_back(new TeamAlltoallTester<long long>(args));
|
||||
testers.push_back(new TeamAlltoallTester<float>(args));
|
||||
testers.push_back(new TeamAlltoallTester<double>(args));
|
||||
testers.push_back(new TeamAlltoallTester<char>(args));
|
||||
testers.push_back(new TeamAlltoallTester<unsigned char>(args));
|
||||
return testers;
|
||||
case FCollectTestType:
|
||||
case TeamFCollectTestType:
|
||||
if (rank == 0) {
|
||||
std::cout << "Fcollect Test ###" << std::endl;
|
||||
}
|
||||
testers.push_back(new FcollectTester<int64_t>(
|
||||
args,
|
||||
[rank](int64_t& f1, int64_t& f2) {
|
||||
f1 = rank;
|
||||
f2 = -1;
|
||||
},
|
||||
[rank](int64_t v, int64_t src_pe) {
|
||||
int64_t expected_val = src_pe;
|
||||
|
||||
return (v == expected_val)
|
||||
? std::make_pair(true, "")
|
||||
: std::make_pair(
|
||||
false, "Rank " + std::to_string(rank) + ", Got " +
|
||||
std::to_string(v) + ", Expect " +
|
||||
std::to_string(expected_val));
|
||||
}));
|
||||
testers.push_back(new TeamFcollectTester<int64_t>(args));
|
||||
testers.push_back(new TeamFcollectTester<int>(args));
|
||||
testers.push_back(new TeamFcollectTester<long long>(args));
|
||||
testers.push_back(new TeamFcollectTester<float>(args));
|
||||
testers.push_back(new TeamFcollectTester<double>(args));
|
||||
testers.push_back(new TeamFcollectTester<char>(args));
|
||||
testers.push_back(new TeamFcollectTester<unsigned char>(args));
|
||||
return testers;
|
||||
case AMO_FAddTestType:
|
||||
if (rank == 0) std::cout << "AMO Fetch_Add ###" << std::endl;
|
||||
@@ -525,7 +477,7 @@ bool Tester::peLaunchesKernel() {
|
||||
*/
|
||||
is_launcher = is_launcher || (_type == TeamReductionTestType) ||
|
||||
(_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) ||
|
||||
(_type == AllToAllTestType) || (_type == FCollectTestType) ||
|
||||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
|
||||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
|
||||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType);
|
||||
|
||||
@@ -53,8 +53,8 @@ enum TestType {
|
||||
SyncAllTestType = 16,
|
||||
SyncTestType = 17,
|
||||
CollectTestType = 18,
|
||||
FCollectTestType = 19,
|
||||
AllToAllTestType = 20,
|
||||
TeamFCollectTestType = 19,
|
||||
TeamAllToAllTestType = 20,
|
||||
AllToAllsTestType = 21,
|
||||
ShmemPtrTestType = 22,
|
||||
PTestType = 23,
|
||||
|
||||
@@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case AMO_IncTestType:
|
||||
case AMO_FetchTestType:
|
||||
case BarrierAllTestType:
|
||||
case SyncAllTestType:
|
||||
case SyncTestType:
|
||||
case ShmemPtrTestType:
|
||||
min_msg_size = 8;
|
||||
@@ -97,6 +98,11 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case RandomAccessTestType:
|
||||
min_msg_size = 4;
|
||||
break;
|
||||
case TeamFCollectTestType:
|
||||
case TeamAllToAllTestType:
|
||||
case TeamBroadcastTestType:
|
||||
min_msg_size = 8;
|
||||
break;
|
||||
case TeamCtxInfraTestType:
|
||||
max_msg_size = min_msg_size;
|
||||
break;
|
||||
@@ -137,8 +143,8 @@ void TesterArguments::get_rocshmem_arguments() {
|
||||
|
||||
TestType type = (TestType)algorithm;
|
||||
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
|
||||
(type != SyncTestType) && (type != AllToAllTestType) &&
|
||||
(type != FCollectTestType) && (type != TeamReductionTestType) &&
|
||||
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
|
||||
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType)) {
|
||||
if (numprocs != 2) {
|
||||
if (myid == 0) {
|
||||
|
||||
Ссылка в новой задаче
Block a user