Add multi work-group support for collective functional tests (#45)

- Added multi work-group support for the All-to-all, Fcollect, Broadcast, Barrier and Sync collective functional tests
- Renamed All-to-all and Fcollect tests to TeamAlltoAll and TeamFcollect

[ROCm/rocshmem commit: 57d60aa727]
Этот коммит содержится в:
Avinash Kethineedi
2025-02-19 10:31:53 -06:00
коммит произвёл GitHub
родитель e1ed36e58f
Коммит 65b4ff4c41
14 изменённых файлов: 719 добавлений и 486 удалений
-165
Просмотреть файл
@@ -1,165 +0,0 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
using namespace rocshmem;
/* Generic wg_alltoall<T>: a no-op fallback that is selected for any element
 * type T without an explicit specialization below. */
template <typename T>
__device__ void wg_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest,
const T *source, int nelem) {
return;
}
/* ALLTOALL_DEF_GEN(T, TNAME) emits the explicit specialization
 * wg_alltoall<T>, which forwards its arguments unchanged to
 * rocshmem_ctx_<TNAME>_wg_alltoall (the name is built by token pasting). */
#define ALLTOALL_DEF_GEN(T, TNAME) \
template <> \
__device__ void wg_alltoall<T>(rocshmem_ctx_t ctx, rocshmem_team_t team, \
T * dest, const T *source, int nelem) { \
rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \
}
/* One specialization per supported element type. */
ALLTOALL_DEF_GEN(float, float)
ALLTOALL_DEF_GEN(double, double)
ALLTOALL_DEF_GEN(char, char)
// long double is left on the no-op generic version (reason not visible here).
// ALLTOALL_DEF_GEN(long double, longdouble)
ALLTOALL_DEF_GEN(signed char, schar)
ALLTOALL_DEF_GEN(short, short)
ALLTOALL_DEF_GEN(int, int)
ALLTOALL_DEF_GEN(long, long)
ALLTOALL_DEF_GEN(long long, longlong)
ALLTOALL_DEF_GEN(unsigned char, uchar)
ALLTOALL_DEF_GEN(unsigned short, ushort)
ALLTOALL_DEF_GEN(unsigned int, uint)
ALLTOALL_DEF_GEN(unsigned long, ulong)
ALLTOALL_DEF_GEN(unsigned long long, ulonglong)
/* File-scope team handle shared by every AlltoallTester<T1> instantiation:
 * set in preLaunchKernel(), passed to the kernel in launchKernel() and
 * destroyed in postLaunchKernel(). */
rocshmem_team_t team_alltoall_world_dup;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
/**
 * Timed device kernel for the work-group all-to-all collective.
 *
 * Issues `loop + skip` consecutive wg_alltoall<T1> calls on `team`. Thread 0
 * of each block records wall_clock64() into start_time[] when the iteration
 * counter reaches `skip`, and into end_time[] after the loop; both arrays are
 * indexed by get_flat_grid_id().
 *
 * NOTE(review): every block passes the same team and the same source/dest
 * buffers to the collective - confirm this is intended for multi-block
 * launches.
 */
template <typename T1>
__global__ void AlltoallTest(int loop, int skip, long long int *start_time,
                             long long int *end_time, T1 *source_buf,
                             T1 *dest_buf, int size, ShmemContextType ctx_type,
                             rocshmem_team_t team) {
  __shared__ rocshmem_ctx_t ctx;

  const int timing_slot = get_flat_grid_id();

  rocshmem_wg_init();
  rocshmem_wg_ctx_create(ctx_type, &ctx);
  // Queried but not used afterwards; kept to preserve the original call.
  int n_pes = rocshmem_ctx_n_pes(ctx);
  __syncthreads();

  const bool records_time = (hipThreadIdx_x == 0);

  int iter = 0;
  while (iter < loop + skip) {
    // Start the clock once the warm-up iterations are over.
    if (records_time && iter == skip) {
      start_time[timing_slot] = wall_clock64();
    }
    wg_alltoall<T1>(ctx, team, dest_buf, source_buf, size);
    ++iter;
  }

  __syncthreads();

  if (records_time) {
    end_time[timing_slot] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * Builds the tester: stores the buffer-initialisation (f1) and verification
 * (f2) callbacks and allocates the source and destination buffers with
 * rocshmem_malloc. Each buffer holds args.max_msg_size * sizeof(T1) bytes per
 * PE of ROCSHMEM_TEAM_WORLD.
 */
template <typename T1>
AlltoallTester<T1>::AlltoallTester(
    TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
    std::function<std::pair<bool, std::string>(const T1 &, T1)> f2)
    : Tester(args), init_buf{f1}, verify_buf{f2} {
  const int world_size = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  // Both buffers have the same size: one max-size message slot per PE.
  const auto buffer_bytes = args.max_msg_size * sizeof(T1) * world_size;

  source_buf = static_cast<T1 *>(rocshmem_malloc(buffer_bytes));
  dest_buf = static_cast<T1 *>(rocshmem_malloc(buffer_bytes));
}
/* Releases the two buffers obtained with rocshmem_malloc in the constructor. */
template <typename T1>
AlltoallTester<T1>::~AlltoallTester() {
rocshmem_free(source_buf);
rocshmem_free(dest_buf);
}
/**
 * Prepares state for a kernel launch.
 *
 * Sets bw_factor to sizeof(T1) * n_pes and creates team_alltoall_world_dup
 * by splitting ROCSHMEM_TEAM_WORLD with start 0, stride 1 and size n_pes.
 * Aborts if the handle is still ROCSHMEM_TEAM_INVALID after the split, so
 * the kernel is never launched with an invalid team.
 */
template <typename T1>
void AlltoallTester<T1>::preLaunchKernel() {
  int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  bw_factor = sizeof(T1) * n_pes;

  team_alltoall_world_dup = ROCSHMEM_TEAM_INVALID;
  rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
                              &team_alltoall_world_dup);

  // The handle was reset above, so INVALID here means the split produced no
  // team.
  if (team_alltoall_world_dup == ROCSHMEM_TEAM_INVALID) {
    fprintf(stderr, "Alltoall: failed to duplicate ROCSHMEM_TEAM_WORLD\n");
    abort();
  }
}
/**
 * Launches AlltoallTest<T1> on `stream` with no dynamic shared memory and
 * records the message counters: num_msgs covers warm-up plus timed
 * iterations, num_timed_msgs only the timed ones.
 */
template <typename T1>
void AlltoallTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
                                      uint64_t size) {
  // Host-side bookkeeping; independent of the launch itself.
  num_timed_msgs = loop;
  num_msgs = loop + args.skip;

  constexpr size_t kDynamicSharedBytes = 0;
  hipLaunchKernelGGL(AlltoallTest<T1>, gridSize, blockSize,
                     kDynamicSharedBytes, stream, loop, args.skip, start_time,
                     end_time, source_buf, dest_buf, size, _shmem_context,
                     team_alltoall_world_dup);
}
/* Destroys the team that preLaunchKernel() created for this launch. */
template <typename T1>
void AlltoallTester<T1>::postLaunchKernel() {
rocshmem_team_destroy(team_alltoall_world_dup);
}
/**
 * Re-initialises both buffers through the init_buf callback. The buffers are
 * treated as n_pes consecutive blocks of `size` elements; the callback gets
 * the matching source and destination elements plus the block index
 * converted to T1.
 */
template <typename T1>
void AlltoallTester<T1>::resetBuffers(uint64_t size) {
  const int pe_count = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  for (int pe = 0; pe < pe_count; ++pe) {
    T1 *src_block = source_buf + pe * size;
    T1 *dst_block = dest_buf + pe * size;
    for (uint64_t k = 0; k < size; ++k) {
      init_buf(src_block[k], dst_block[k], (T1)pe);
    }
  }
}
/**
 * Checks every destination element with the verify_buf callback, passing the
 * index of the `size`-element block it belongs to. On the first failure the
 * in-block index and the callback's message are printed to stderr and the
 * process exits with status -1.
 */
template <typename T1>
void AlltoallTester<T1>::verifyResults(uint64_t size) {
  const int pe_count = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  for (int pe = 0; pe < pe_count; ++pe) {
    const T1 *received = dest_buf + pe * size;
    for (uint64_t k = 0; k < size; ++k) {
      const auto outcome = verify_buf(received[k], pe);
      if (outcome.first) {
        continue;
      }
      fprintf(stderr, "Data validation error at idx %lu\n", k);
      fprintf(stderr, "%s.\n", outcome.second.c_str());
      exit(-1);
    }
  }
}
+8 -2
Просмотреть файл
@@ -44,7 +44,13 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
__syncthreads();
rocshmem_ctx_wg_barrier_all(ctx);
/**
* The function `rocshmem_ctx_wg_barrier_all` should be called from only
* one group within the grid to avoid unintended behavior.
*/
if (is_block_zero_in_grid()) {
rocshmem_ctx_wg_barrier_all(ctx);
}
}
__syncthreads();
@@ -70,7 +76,7 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time);
num_msgs = (loop + args.skip) * gridSize.x;
num_msgs = loop + args.skip;
num_timed_msgs = loop;
}
-167
Просмотреть файл
@@ -1,167 +0,0 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
using namespace rocshmem;
/* File-scope team handle shared by every FcollectTester<T1> instantiation:
 * set in preLaunchKernel(), passed to the kernel in launchKernel() and
 * destroyed in postLaunchKernel(). */
rocshmem_team_t team_fcollect_world_dup;
/* Generic wg_fcollect<T>: a no-op fallback that is selected for any element
 * type T without an explicit specialization below. */
template <typename T>
__device__ void wg_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest,
const T *source, int nelems) {
return;
}
/* FCOLLECT_DEF_GEN(T, TNAME) emits the explicit specialization
 * wg_fcollect<T>, which forwards its arguments unchanged to
 * rocshmem_ctx_<TNAME>_wg_fcollect (the name is built by token pasting). */
#define FCOLLECT_DEF_GEN(T, TNAME) \
template <> \
__device__ void wg_fcollect<T>(rocshmem_ctx_t ctx, rocshmem_team_t team, \
T * dest, const T *source, int nelem) { \
rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \
}
/* One specialization per supported element type. */
FCOLLECT_DEF_GEN(float, float)
FCOLLECT_DEF_GEN(double, double)
FCOLLECT_DEF_GEN(char, char)
// long double is left on the no-op generic version (reason not visible here).
// FCOLLECT_DEF_GEN(long double, longdouble)
FCOLLECT_DEF_GEN(signed char, schar)
FCOLLECT_DEF_GEN(short, short)
FCOLLECT_DEF_GEN(int, int)
FCOLLECT_DEF_GEN(long, long)
FCOLLECT_DEF_GEN(long long, longlong)
FCOLLECT_DEF_GEN(unsigned char, uchar)
FCOLLECT_DEF_GEN(unsigned short, ushort)
FCOLLECT_DEF_GEN(unsigned int, uint)
FCOLLECT_DEF_GEN(unsigned long, ulong)
FCOLLECT_DEF_GEN(unsigned long long, ulonglong)
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
/**
 * Timed device kernel for the work-group fcollect collective.
 *
 * Performs `loop + skip` wg_fcollect<T1> calls on `team`. In every block,
 * thread 0 stores wall_clock64() into start_time[] at iteration `skip` and
 * into end_time[] once all iterations are done; the slot used is the value
 * returned by get_flat_grid_id().
 *
 * NOTE(review): all blocks hand the same team and buffers to the collective -
 * confirm this is intended when more than one block is launched.
 */
template <typename T1>
__global__ void FcollectTest(int loop, int skip, long long int *start_time,
                             long long int *end_time, T1 *source_buf,
                             T1 *dest_buf, int size, ShmemContextType ctx_type,
                             rocshmem_team_t team) {
  __shared__ rocshmem_ctx_t ctx;

  int slot = get_flat_grid_id();

  rocshmem_wg_init();
  rocshmem_wg_ctx_create(ctx_type, &ctx);
  // Result is not used below; the call is kept as in the original.
  int n_pes = rocshmem_ctx_n_pes(ctx);
  __syncthreads();

  for (int round = 0; round < loop + skip; ++round) {
    const bool warmup_done = (round == skip);
    if (warmup_done && hipThreadIdx_x == 0) {
      start_time[slot] = wall_clock64();
    }

    wg_fcollect<T1>(ctx, team, dest_buf, source_buf, size);
  }

  __syncthreads();

  if (hipThreadIdx_x == 0) {
    end_time[slot] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * Builds the tester: stores the initialisation (f1) and verification (f2)
 * callbacks and allocates the buffers with rocshmem_malloc. The source
 * buffer holds args.max_msg_size * sizeof(T1) bytes; the destination buffer
 * holds that amount once per PE of ROCSHMEM_TEAM_WORLD.
 */
template <typename T1>
FcollectTester<T1>::FcollectTester(
    TesterArguments args, std::function<void(T1 &, T1 &)> f1,
    std::function<std::pair<bool, std::string>(const T1 &, T1)> f2)
    : Tester(args), init_buf{f1}, verify_buf{f2} {
  const int world_size = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  const auto contribution_bytes = args.max_msg_size * sizeof(T1);

  source_buf = static_cast<T1 *>(rocshmem_malloc(contribution_bytes));
  dest_buf =
      static_cast<T1 *>(rocshmem_malloc(contribution_bytes * world_size));
}
/* Releases the two buffers obtained with rocshmem_malloc in the constructor. */
template <typename T1>
FcollectTester<T1>::~FcollectTester() {
rocshmem_free(source_buf);
rocshmem_free(dest_buf);
}
/**
 * Prepares state for a kernel launch.
 *
 * Sets bw_factor to sizeof(T1) * n_pes and creates team_fcollect_world_dup
 * by splitting ROCSHMEM_TEAM_WORLD with start 0, stride 1 and size n_pes.
 * Aborts if the handle is still ROCSHMEM_TEAM_INVALID after the split, so
 * the kernel is never launched with an invalid team.
 */
template <typename T1>
void FcollectTester<T1>::preLaunchKernel() {
  int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  bw_factor = sizeof(T1) * n_pes;

  team_fcollect_world_dup = ROCSHMEM_TEAM_INVALID;
  rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
                              &team_fcollect_world_dup);

  // The handle was reset above, so INVALID here means the split produced no
  // team.
  if (team_fcollect_world_dup == ROCSHMEM_TEAM_INVALID) {
    fprintf(stderr, "Fcollect: failed to duplicate ROCSHMEM_TEAM_WORLD\n");
    abort();
  }
}
/**
 * Launches FcollectTest<T1> on `stream` without dynamic shared memory and
 * updates the counters: num_msgs = loop + args.skip, num_timed_msgs = loop.
 */
template <typename T1>
void FcollectTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
                                      uint64_t size) {
  const size_t no_dynamic_smem = 0;

  hipLaunchKernelGGL(FcollectTest<T1>, gridSize, blockSize, no_dynamic_smem,
                     stream, loop, args.skip, start_time, end_time,
                     source_buf, dest_buf, size, _shmem_context,
                     team_fcollect_world_dup);

  // Warm-up iterations count as messages but are excluded from timing.
  num_timed_msgs = loop;
  num_msgs = num_timed_msgs + args.skip;
}
/* Destroys the team that preLaunchKernel() created for this launch. */
template <typename T1>
void FcollectTester<T1>::postLaunchKernel() {
rocshmem_team_destroy(team_fcollect_world_dup);
}
/**
 * Re-initialises the buffers through the init_buf callback. The destination
 * is walked as n_pes blocks of `size` elements; each destination element is
 * paired with the source element at the same in-block offset, so every
 * source element is re-initialised once per block (redundant, as the
 * original comment points out).
 */
template <typename T1>
void FcollectTester<T1>::resetBuffers(uint64_t size) {
  const int block_count = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  for (int block = 0; block < block_count; ++block) {
    T1 *dst_block = dest_buf + block * size;
    for (uint64_t offset = 0; offset < size; ++offset) {
      init_buf(source_buf[offset], dst_block[offset]);
    }
  }
}
/**
 * Checks every destination element with the verify_buf callback, passing the
 * index of the `size`-element block it belongs to. On the first failure the
 * in-block index and the callback's message go to stderr and the function
 * returns; the process is not terminated (the exit call is disabled).
 */
template <typename T1>
void FcollectTester<T1>::verifyResults(uint64_t size) {
  const int block_count = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  for (int block = 0; block < block_count; ++block) {
    for (uint64_t offset = 0; offset < size; ++offset) {
      const auto verdict = verify_buf(dest_buf[block * size + offset], block);
      if (!verdict.first) {
        fprintf(stderr, "Data validation error at idx %lu\n", offset);
        fprintf(stderr, "%s.\n", verdict.second.c_str());
        // exit(-1);
        return;
      }
    }
  }
}
+48 -15
Просмотреть файл
@@ -22,17 +22,12 @@
#include "sync_tester.hpp"
#include <rocshmem/rocshmem.hpp>
using namespace rocshmem;
rocshmem_team_t team_sync_world_dup;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void SyncTest(int loop, int skip, long long int *start_time,
long long int *end_time, TestType type,
ShmemContextType ctx_type, rocshmem_team_t team) {
ShmemContextType ctx_type, rocshmem_team_t *teams) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
@@ -47,10 +42,16 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
__syncthreads();
switch (type) {
case SyncAllTestType:
rocshmem_ctx_wg_sync_all(ctx);
/**
* The function `rocshmem_ctx_wg_sync_all` should be called from only
* one group within the grid to avoid unintended behavior.
*/
if (is_block_zero_in_grid()) {
rocshmem_ctx_wg_sync_all(ctx);
}
break;
case SyncTestType:
rocshmem_ctx_wg_team_sync(ctx, team);
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
break;
default:
break;
@@ -69,28 +70,60 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
SyncTester::SyncTester(TesterArguments args) : Tester(args) {}
SyncTester::SyncTester(TesterArguments args) : Tester(args) {
SyncTester::~SyncTester() {}
char* value{nullptr};
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
num_teams = atoi(value);
}
CHECK_HIP(hipMalloc(&team_sync_world_dup,
sizeof(rocshmem_team_t) * num_teams));
}
SyncTester::~SyncTester() {
CHECK_HIP(hipFree(team_sync_world_dup));
}
void SyncTester::resetBuffers(uint64_t size) {}
void SyncTester::preLaunchKernel() {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
for (int team_i = 0; team_i < num_teams; team_i++) {
team_sync_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_sync_world_dup[team_i]);
if (team_sync_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
printf("Team %d is invalid!\n", team_i);
abort();
}
}
}
void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) {
size_t shared_bytes = 0;
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
team_sync_world_dup = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_sync_world_dup);
hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop,
args.skip, start_time, end_time, _type, _shmem_context,
team_sync_world_dup);
num_msgs = (loop + args.skip) * gridSize.x;
num_msgs = loop + args.skip;
num_timed_msgs = loop;
if(_type == SyncTestType) {
num_msgs *= gridSize.x;
num_timed_msgs *= gridSize.x;
}
}
void SyncTester::postLaunchKernel() {
for (int team_i = 0; team_i < num_teams; team_i++) {
rocshmem_team_destroy(team_sync_world_dup[team_i]);
}
}
void SyncTester::verifyResults(uint64_t size) {}
+16
Просмотреть файл
@@ -23,8 +23,12 @@
#ifndef _SYNC_TESTER_HPP_
#define _SYNC_TESTER_HPP_
#include <rocshmem/rocshmem.hpp>
#include "tester.hpp"
using namespace rocshmem;
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
@@ -36,10 +40,22 @@ class SyncTester : public Tester {
protected:
virtual void resetBuffers(uint64_t size) override;
virtual void preLaunchKernel() override;
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) override;
virtual void postLaunchKernel() override;
virtual void verifyResults(uint64_t size) override;
private:
/**
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
* The default value for the maximum number of teams is 40.
*/
int num_teams = 39;
rocshmem_team_t *team_sync_world_dup;
};
#endif
+222
Просмотреть файл
@@ -0,0 +1,222 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
/* Generic wg_team_alltoall<T>: a no-op fallback that is selected for any
 * element type T without an explicit specialization below. */
template <typename T>
__device__ void wg_team_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team,
T *dest, const T *source, int nelem) {
return;
}
/* TEAM_ALLTOALL_DEF_GEN(T, TNAME) emits the explicit specialization
 * wg_team_alltoall<T>, which forwards its arguments unchanged to
 * rocshmem_ctx_<TNAME>_wg_alltoall (the name is built by token pasting). */
#define TEAM_ALLTOALL_DEF_GEN(T, TNAME) \
template <> \
__device__ void wg_team_alltoall<T>(rocshmem_ctx_t ctx, rocshmem_team_t team,\
T * dest, const T *source, int nelem) { \
rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \
}
/* One specialization per supported element type. */
TEAM_ALLTOALL_DEF_GEN(float, float)
TEAM_ALLTOALL_DEF_GEN(double, double)
TEAM_ALLTOALL_DEF_GEN(char, char)
// long double is left on the no-op generic version (reason not visible here).
// TEAM_ALLTOALL_DEF_GEN(long double, longdouble)
TEAM_ALLTOALL_DEF_GEN(signed char, schar)
TEAM_ALLTOALL_DEF_GEN(short, short)
TEAM_ALLTOALL_DEF_GEN(int, int)
TEAM_ALLTOALL_DEF_GEN(long, long)
TEAM_ALLTOALL_DEF_GEN(long long, longlong)
TEAM_ALLTOALL_DEF_GEN(unsigned char, uchar)
TEAM_ALLTOALL_DEF_GEN(unsigned short, ushort)
TEAM_ALLTOALL_DEF_GEN(unsigned int, uint)
TEAM_ALLTOALL_DEF_GEN(unsigned long, ulong)
TEAM_ALLTOALL_DEF_GEN(unsigned long long, ulonglong)
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
/**
 * Timed device kernel for the per-work-group team all-to-all test.
 *
 * Each block looks up its own team in `teams` (indexed by
 * get_flat_grid_id()), creates a context on it, and works on its own slice
 * of the buffers: n_pes * num_elems elements starting at
 * wg_id * n_pes * num_elems. It then runs `loop + skip` collective calls.
 * Thread 0 of the block records wall_clock64() into start_time[wg_id] at
 * iteration `skip` and into end_time[wg_id] after the loop.
 *
 * Preconditions visible from the indexing: `teams` must have an entry for
 * every launched block, and both buffers must hold a slice for every block.
 */
template <typename T1>
__global__ void TeamAlltoallTest(int loop, int skip, long long int *start_time,
                                 long long int *end_time, T1 *source_buf,
                                 T1 *dest_buf, int num_elems,
                                 ShmemContextType ctx_type,
                                 rocshmem_team_t *teams) {
  __shared__ rocshmem_ctx_t ctx;

  int wg_id = get_flat_grid_id();
  rocshmem_wg_init();
  rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);

  int n_pes = rocshmem_ctx_n_pes(ctx);

  // Compute the slice offset in size_t: the three-way product can exceed
  // INT_MAX for large messages or many work-groups, and a wrapped int would
  // move the pointers to the wrong place.
  const size_t wg_offset = static_cast<size_t>(wg_id) *
                           static_cast<size_t>(n_pes) *
                           static_cast<size_t>(num_elems);
  source_buf += wg_offset;
  dest_buf += wg_offset;

  __syncthreads();

  for (int i = 0; i < loop + skip; i++) {
    // Start timing only after the warm-up iterations.
    if (i == skip && hipThreadIdx_x == 0) {
      start_time[wg_id] = wall_clock64();
    }

    wg_team_alltoall<T1>(ctx, teams[wg_id],
                         dest_buf,    // T* dest
                         source_buf,  // const T* source
                         num_elems);  // int nelement
  }

  __syncthreads();

  if (hipThreadIdx_x == 0) {
    end_time[wg_id] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * Builds the tester.
 *
 * - Caches this PE's rank and the PE count of ROCSHMEM_TEAM_WORLD.
 * - Allocates source and destination buffers with rocshmem_malloc, sized
 *   for (max_msg_size / sizeof(T1)) elements per PE per work-group; calls
 *   rocshmem_global_exit(1) if either allocation returns nullptr.
 * - Takes the team count from the ROCSHMEM_MAX_NUM_TEAMS environment
 *   variable when it holds a positive integer (otherwise the in-class
 *   default is kept) and allocates the team-handle array with hipMalloc.
 *
 * NOTE(review): the header comment says num_teams should equal
 * ROCSHMEM_MAX_NUM_TEAMS - 1, but the environment value is used as-is here -
 * confirm which is intended.
 */
template <typename T1>
TeamAlltoallTester<T1>::TeamAlltoallTester(TesterArguments args)
    : Tester(args) {
  my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
  n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  // All size arithmetic is done in size_t; with int, large messages combined
  // with many PEs/work-groups overflow and yield a bogus allocation size.
  // Number of elements per work group
  const size_t num_elems_wg =
      (static_cast<size_t>(args.max_msg_size) / sizeof(T1)) *
      static_cast<size_t>(n_pes);

  // Total number of elements in the GPU kernel
  const size_t total_elems = num_elems_wg * static_cast<size_t>(args.num_wgs);
  const size_t buff_size = total_elems * sizeof(T1);

  source_buf = (T1 *)rocshmem_malloc(buff_size);
  dest_buf = (T1 *)rocshmem_malloc(buff_size);

  if (source_buf == nullptr || dest_buf == nullptr) {
    std::cout << "Error allocating memory from symmetric heap" << std::endl;
    std::cout << "source: " << source_buf
              << ", dest: " << dest_buf
              << std::endl;
    rocshmem_global_exit(1);
  }

  if (const char *value = getenv("ROCSHMEM_MAX_NUM_TEAMS")) {
    char *end = nullptr;
    const long requested = strtol(value, &end, 10);
    const bool parsed = (end != value);
    const bool fits_int = (requested == static_cast<long>(static_cast<int>(requested)));
    if (parsed && requested > 0 && fits_int) {
      num_teams = static_cast<int>(requested);
    } else {
      // A zero, negative or non-numeric value would lead to an empty or
      // absurdly large team array; keep the default instead.
      std::cout << "Ignoring invalid ROCSHMEM_MAX_NUM_TEAMS value '" << value
                << "', using " << num_teams << " teams" << std::endl;
    }
  }

  CHECK_HIP(hipMalloc(&team_alltoall_world_dup,
                      sizeof(rocshmem_team_t) * num_teams));
}
/* Releases the two rocshmem_malloc buffers and the hipMalloc'ed array of
 * team handles. */
template <typename T1>
TeamAlltoallTester<T1>::~TeamAlltoallTester() {
rocshmem_free(source_buf);
rocshmem_free(dest_buf);
CHECK_HIP(hipFree(team_alltoall_world_dup));
}
/**
 * Prepares state for a kernel launch: sets bw_factor to n_pes and fills
 * every slot of team_alltoall_world_dup with a team obtained by splitting
 * ROCSHMEM_TEAM_WORLD (start 0, stride 1, size n_pes). Aborts as soon as a
 * slot is still ROCSHMEM_TEAM_INVALID after its split.
 *
 * NOTE(review): the team array is written here from host code; its
 * allocation is outside this function - confirm the memory is
 * host-accessible.
 */
template <typename T1>
void TeamAlltoallTester<T1>::preLaunchKernel() {
  bw_factor = n_pes;

  int slot = 0;
  while (slot < num_teams) {
    rocshmem_team_t &dup = team_alltoall_world_dup[slot];

    dup = ROCSHMEM_TEAM_INVALID;
    rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
                                &dup);

    if (dup == ROCSHMEM_TEAM_INVALID) {
      std::cout << "Team " << slot << " is invalid!" << std::endl;
      abort();
    }
    ++slot;
  }
}
/**
 * Launches TeamAlltoallTest<T1> on `stream` and updates the message
 * counters, which scale with gridSize.x.
 *
 * `size` is a byte count; it is converted to an element count before being
 * passed to the kernel. The kernel indexes the team array by work-group, so
 * the launch is refused (abort) when the grid has more blocks than there are
 * teams; otherwise the kernel would read past the end of the array.
 *
 * NOTE(review): the check assumes get_flat_grid_id() spans all three grid
 * dimensions - confirm against its definition.
 */
template <typename T1>
void TeamAlltoallTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
                                          int loop, uint64_t size) {
  size_t shared_bytes = 0;

  int num_elems = size / sizeof(T1);

  const uint64_t launched_wgs =
      static_cast<uint64_t>(gridSize.x) * gridSize.y * gridSize.z;
  if (num_teams < 0 || launched_wgs > static_cast<uint64_t>(num_teams)) {
    std::cerr << "TeamAlltoall: grid has " << launched_wgs
              << " work-groups but only " << num_teams
              << " teams are available" << std::endl;
    abort();
  }

  hipLaunchKernelGGL(TeamAlltoallTest<T1>, gridSize, blockSize, shared_bytes,
                     stream, loop, args.skip, start_time, end_time,
                     source_buf, dest_buf, num_elems, _shmem_context,
                     team_alltoall_world_dup);

  // One collective per work-group per iteration.
  num_msgs = (loop + args.skip) * gridSize.x;
  num_timed_msgs = loop * gridSize.x;
}
/* Destroys every team stored in team_alltoall_world_dup (num_teams entries,
 * the same range that preLaunchKernel() fills). */
template <typename T1>
void TeamAlltoallTester<T1>::postLaunchKernel() {
for (int team_i = 0; team_i < num_teams; team_i++) {
rocshmem_team_destroy(team_alltoall_world_dup[team_i]);
}
}
/**
 * Re-initialises the buffers for a message of `size` bytes.
 *
 * The source buffer is laid out as [work-group][PE][element]. Every element
 * of block (wg_id, pe) gets a value derived from my_pe + pe + wg_id, offset
 * by 'a' for character types and by 3.14 for floating-point types. The
 * destination bytes covering the same extent are set to 0xFF.
 *
 * Sizes and indices are computed in size_t so that large messages with many
 * PEs/work-groups cannot overflow int (a negative int byte count would turn
 * into an enormous memset length).
 */
template <typename T1>
void TeamAlltoallTester<T1>::resetBuffers(uint64_t size) {
  const size_t num_elems = size / sizeof(T1);
  const size_t pes = static_cast<size_t>(n_pes);
  const size_t buff_size =
      num_elems * sizeof(T1) * static_cast<size_t>(args.num_wgs) * pes;

  for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
    for (int pe = 0; pe < n_pes; pe++) {
      // Start of the (wg_id, pe) block inside the source buffer.
      T1 *block =
          source_buf + (static_cast<size_t>(wg_id) * pes + pe) * num_elems;

      for (size_t i = 0; i < num_elems; i++) {
        if constexpr (std::is_same<T1, char>::value ||
                      std::is_same<T1, signed char>::value ||
                      std::is_same<T1, unsigned char>::value) {
          block[i] = static_cast<T1>('a' + my_pe + pe + wg_id);
        }
        else if constexpr (std::is_floating_point<T1>::value) {
          block[i] = static_cast<T1>(3.14 + my_pe + pe + wg_id);
        }
        else if constexpr (std::is_integral<T1>::value) {
          block[i] = static_cast<T1>(my_pe + pe + wg_id);
        }
      }
    }
  }

  memset(dest_buf, -1, buff_size);
}
/**
 * Verifies the result for a message of `size` bytes by comparing the
 * destination buffer with the source buffer element by element over the
 * [work-group][PE][element] layout. The fill pattern written by
 * resetBuffers() depends on my_pe + pe, which is unchanged when the sending
 * and receiving ranks are swapped, so the expected received value equals the
 * local source value at the same index.
 *
 * On the first mismatch the flat index, this PE's rank and both values are
 * printed to stderr and the process exits with status -1. The flat index is
 * computed in size_t to avoid int overflow on large buffers.
 */
template <typename T1>
void TeamAlltoallTester<T1>::verifyResults(uint64_t size) {
  const size_t num_elems = size / sizeof(T1);
  const size_t pes = static_cast<size_t>(n_pes);

  for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
    for (int pe = 0; pe < n_pes; pe++) {
      const size_t block_start =
          (static_cast<size_t>(wg_id) * pes + pe) * num_elems;

      for (size_t i = 0; i < num_elems; i++) {
        const size_t idx = block_start + i;
        if (dest_buf[idx] != source_buf[idx]) {
          std::cerr << "Data validation error at idx " << idx << std::endl;
          std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
                    << ", Expected " << source_buf[idx] << std::endl;
          exit(-1);
        }
      }
    }
  }
}
@@ -28,16 +28,16 @@
#include "tester.hpp"
using namespace rocshmem;
/************* *****************************************************************
* HOST TESTER CLASS
*****************************************************************************/
template <typename T1>
class AlltoallTester : public Tester {
class TeamAlltoallTester : public Tester {
public:
explicit AlltoallTester(
TesterArguments args, std::function<void(T1 &, T1 &, T1)> f1,
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2);
virtual ~AlltoallTester();
explicit TeamAlltoallTester(TesterArguments args);
virtual ~TeamAlltoallTester();
protected:
virtual void resetBuffers(uint64_t size) override;
@@ -51,14 +51,21 @@ class AlltoallTester : public Tester {
virtual void verifyResults(uint64_t size) override;
T1 *source_buf;
T1 *dest_buf;
T1 *source_buf = nullptr;
T1 *dest_buf = nullptr;
private:
std::function<void(T1 &, T1 &, T1)> init_buf;
std::function<std::pair<bool, std::string>(const T1 &, T1)> verify_buf;
int my_pe = 0;
int n_pes = 0;
/**
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
* The default value for the maximum number of teams is 40.
*/
int num_teams = 39;
rocshmem_team_t *team_alltoall_world_dup;
};
#include "alltoall_tester.cpp"
#include "team_alltoall_tester.cpp"
#endif
+115 -36
Просмотреть файл
@@ -20,8 +20,6 @@
* IN THE SOFTWARE.
*****************************************************************************/
using namespace rocshmem;
/* Declare the template with a generic implementation */
template <typename T>
__device__ void wg_team_broadcast(rocshmem_ctx_t ctx, rocshmem_team_t team,
@@ -65,28 +63,29 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
long long int *end_time, T1 *source_buf,
T1 *dest_buf, int size,
ShmemContextType ctx_type,
rocshmem_team_t team) {
rocshmem_team_t *teams) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
source_buf += wg_id * size;
dest_buf += wg_id * size;
__syncthreads();
for (int i = 0; i < loop; i++) {
for (int i = 0; i < loop + skip; i++) {
if (i == skip && hipThreadIdx_x == 0) {
start_time[wg_id] = wall_clock64();
}
wg_team_broadcast<T1>(ctx, team,
dest_buf, // T* dest
source_buf, // const T* source
size, // int nelement
0); // int PE_root
rocshmem_ctx_wg_barrier_all(ctx);
wg_team_broadcast<T1>(ctx, teams[wg_id],
dest_buf, // T* dest
source_buf, // const T* source
size, // int nelement
0); // int PE_root
}
__syncthreads();
@@ -103,27 +102,51 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
* HOST TESTER CLASS METHODS
*****************************************************************************/
template <typename T1>
TeamBroadcastTester<T1>::TeamBroadcastTester(
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
std::function<std::pair<bool, std::string>(const T1 &)> f2)
: Tester(args), init_buf{f1}, verify_buf{f2} {
source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1));
dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1));
TeamBroadcastTester<T1>::TeamBroadcastTester(TesterArguments args)
: Tester(args){
my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
// Total number of elements in src buffer
int total_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs ;
int buff_size = total_elems * sizeof(T1);
source_buf = (T1 *)rocshmem_malloc(buff_size);
dest_buf = (T1 *)rocshmem_malloc(buff_size);
if (source_buf == nullptr || dest_buf == nullptr) {
std::cout << "Error allocating memory from symmetric heap" << std::endl;
std::cout << "source: " << source_buf << ", dest: " << dest_buf << std::endl;
rocshmem_global_exit(1);
}
char* value{nullptr};
if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) {
num_teams = atoi(value);
}
CHECK_HIP(hipMalloc(&team_bcast_world_dup,
sizeof(rocshmem_team_t) * num_teams));
}
template <typename T1>
TeamBroadcastTester<T1>::~TeamBroadcastTester() {
rocshmem_free(source_buf);
rocshmem_free(dest_buf);
CHECK_HIP(hipFree(team_bcast_world_dup));
}
template <typename T1>
void TeamBroadcastTester<T1>::preLaunchKernel() {
int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);
team_bcast_world_dup = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_bcast_world_dup);
for (int team_i = 0; team_i < num_teams; team_i++) {
team_bcast_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
&team_bcast_world_dup[team_i]);
if (team_bcast_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
printf("Team %d is invalid!\n", team_i);
abort();
}
}
}
template <typename T1>
@@ -131,34 +154,90 @@ void TeamBroadcastTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(TeamBroadcastTest<T1>, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time, source_buf,
dest_buf, size, _shmem_context, team_bcast_world_dup);
int num_elems = size / sizeof(T1);
num_msgs = loop + args.skip;
num_timed_msgs = loop;
hipLaunchKernelGGL(TeamBroadcastTest<T1>, gridSize, blockSize,
shared_bytes, stream, loop, args.skip,
start_time, end_time, source_buf, dest_buf,
num_elems, _shmem_context, team_bcast_world_dup);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}
template <typename T1>
void TeamBroadcastTester<T1>::postLaunchKernel() {
rocshmem_team_destroy(team_bcast_world_dup);
for (int team_i = 0; team_i < num_teams; team_i++) {
rocshmem_team_destroy(team_bcast_world_dup[team_i]);
}
}
template <typename T1>
void TeamBroadcastTester<T1>::resetBuffers(uint64_t size) {
for (uint64_t i = 0; i < args.max_msg_size; i++) {
init_buf(source_buf[i], dest_buf[i]);
int num_elems = size / sizeof(T1);
int buff_size = num_elems * sizeof(T1) * args.num_wgs;
int idx = 0;
for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
for (int i = 0; i < num_elems; i++) {
idx = wg_id * num_elems + i;
if constexpr (std::is_same<T1, char>::value ||
std::is_same<T1, signed char>::value ||
std::is_same<T1, unsigned char>::value) {
source_buf[idx] = static_cast<T1>('a' + n_pes + wg_id);
dest_buf[idx] = static_cast<T1>('a' + wg_id);
}
else if constexpr (std::is_floating_point<T1>::value) {
source_buf[idx] = static_cast<T1>(3.14 + n_pes + wg_id);
dest_buf[idx] = static_cast<T1>(3.14 + wg_id);
}
else if constexpr (std::is_integral<T1>::value) {
source_buf[idx] = static_cast<T1>(n_pes + wg_id);
dest_buf[idx] = static_cast<T1>(wg_id);
}
}
}
}
template <typename T1>
void TeamBroadcastTester<T1>::verifyResults(uint64_t size) {
for (uint64_t i = 0; i < size; i++) {
auto r = verify_buf(dest_buf[i]);
if (r.first == false) {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "%s.\n", r.second.c_str());
exit(-1);
int num_elems = size / sizeof(T1);
int idx = 0;
T1 expected;
/**
* The verification routine here requires that the
* PE_root value is 0 which denotes that the
* sending processing element is rank 0.
*
* The difference in expected values arises from
* the specification for broadcast where the
* PE_root processing element does not copy the
* contents from its own source to dest during
* the broadcast.
*/
for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) {
for (int i = 0; i < num_elems; i++) {
idx = wg_id * num_elems + i;
if constexpr (std::is_same<T1, char>::value ||
std::is_same<T1, signed char>::value ||
std::is_same<T1, unsigned char>::value) {
expected = static_cast<T1>('a' + wg_id + (my_pe ? n_pes : 0));
}
else if constexpr (std::is_floating_point<T1>::value) {
expected = static_cast<T1>(3.14 + wg_id + (my_pe ? n_pes : 0));
}
else if constexpr (std::is_integral<T1>::value) {
expected = static_cast<T1>(wg_id + (my_pe ? n_pes : 0));
}
if (dest_buf[idx] != expected) {
std::cerr << "Data validation error at idx " << idx << std::endl;
std::cerr << "PE " << my_pe << " Got " << dest_buf[idx]
<< ", Expected " << expected << std::endl;
exit(-1);
}
}
}
}
+11 -5
Просмотреть файл
@@ -28,15 +28,15 @@
#include "tester.hpp"
using namespace rocshmem;
/************* *****************************************************************
* HOST TESTER CLASS
*****************************************************************************/
template <typename T1>
class TeamBroadcastTester : public Tester {
public:
explicit TeamBroadcastTester(
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
std::function<std::pair<bool, std::string>(const T1 &)> f2);
explicit TeamBroadcastTester(TesterArguments args);
virtual ~TeamBroadcastTester();
protected:
@@ -55,8 +55,14 @@ class TeamBroadcastTester : public Tester {
T1 *dest_buf;
private:
std::function<void(T1 &, T1 &)> init_buf;
std::function<std::pair<bool, std::string>(const T1 &)> verify_buf;
int my_pe = 0;
int n_pes = 0;
/**
* This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1.
* The default value for the maximum number of teams is 40.
*/
int num_teams = 39;
rocshmem_team_t *team_bcast_world_dup;
};
#include "team_broadcast_tester.cpp"
+232
Просмотреть файл
@@ -0,0 +1,232 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
/**
 * Generic wg_team_fcollect: the fallback used for any element type that has
 * no explicit specialization. It performs no communication and leaves both
 * buffers untouched.
 */
template <typename T>
__device__ void wg_team_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team,
                                 T *dest, const T *source, int nelems) {
  // Intentionally empty.
}
/*
 * Define templates to call rocSHMEM.
 *
 * TEAM_FCOLLECT_DEF_GEN(T, TNAME) emits an explicit specialization of
 * wg_team_fcollect<T> that forwards ctx, team, dest, source and nelem,
 * unchanged, to rocshmem_ctx_<TNAME>_wg_fcollect.
 *
 *   T     - C++ element type of the specialization.
 *   TNAME - type token pasted into the rocSHMEM function name.
 */
#define TEAM_FCOLLECT_DEF_GEN(T, TNAME) \
template <> \
__device__ void wg_team_fcollect<T>(rocshmem_ctx_t ctx, rocshmem_team_t team,\
T * dest, const T *source, int nelem) { \
rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \
}
/* One specialization per element type; the long double entry is disabled. */
TEAM_FCOLLECT_DEF_GEN(float, float)
TEAM_FCOLLECT_DEF_GEN(double, double)
TEAM_FCOLLECT_DEF_GEN(char, char)
// TEAM_FCOLLECT_DEF_GEN(long double, longdouble)
TEAM_FCOLLECT_DEF_GEN(signed char, schar)
TEAM_FCOLLECT_DEF_GEN(short, short)
TEAM_FCOLLECT_DEF_GEN(int, int)
TEAM_FCOLLECT_DEF_GEN(long, long)
TEAM_FCOLLECT_DEF_GEN(long long, longlong)
TEAM_FCOLLECT_DEF_GEN(unsigned char, uchar)
TEAM_FCOLLECT_DEF_GEN(unsigned short, ushort)
TEAM_FCOLLECT_DEF_GEN(unsigned int, uint)
TEAM_FCOLLECT_DEF_GEN(unsigned long, ulong)
TEAM_FCOLLECT_DEF_GEN(unsigned long long, ulonglong)
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
/**
 * Timed fcollect kernel; every work-group runs its own collective on its own
 * team and its own slice of the buffers.
 *
 * @param loop        Number of timed iterations.
 * @param skip        Number of leading iterations run before the start
 *                    timestamp is taken.
 * @param start_time  Per-work-group start timestamps, indexed by wg_id.
 * @param end_time    Per-work-group end timestamps, indexed by wg_id.
 * @param source_buf  Base of the source buffer; work-group w uses the
 *                    num_elems elements starting at w * num_elems.
 * @param dest_buf    Base of the destination buffer; work-group w uses the
 *                    region starting at w * num_elems * n_pes.
 * @param num_elems   Elements passed to each fcollect call.
 * @param ctx_type    Option value forwarded to rocshmem_wg_team_create_ctx.
 * @param teams       Team handles indexed by wg_id, so it must hold an entry
 *                    for every work-group that is launched.
 */
template <typename T1>
__global__ void TeamFcollectTest(int loop, int skip, long long int *start_time,
long long int *end_time, T1 *source_buf,
T1 *dest_buf, int num_elems,
ShmemContextType ctx_type,
rocshmem_team_t *teams) {
// A single context handle shared by all threads of the work-group.
__shared__ rocshmem_ctx_t ctx;
// NOTE(review): presumably the linearized work-group index within the grid;
// confirm against get_flat_grid_id().
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
// Advance both pointers to this work-group's private slice.
source_buf += wg_id * num_elems;
dest_buf += wg_id * num_elems * n_pes;
__syncthreads();
for (int i = 0; i < loop + skip; i++) {
// Thread 0 takes the start timestamp once the first `skip` iterations are
// done; it is never written when loop <= 0.
if (i == skip && hipThreadIdx_x == 0) {
start_time[wg_id] = wall_clock64();
}
wg_team_fcollect<T1>(ctx, teams[wg_id],
dest_buf, // T* dest
source_buf, // const T* source
num_elems); // int nelement
}
__syncthreads();
// Thread 0 takes the end timestamp after the whole work-group has left the
// loop.
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
/**
 * Allocates and initializes the buffers used by the fcollect test.
 *
 * The source buffer holds (max_msg_size / sizeof(T1)) elements for each of
 * the args.num_wgs work-groups; the destination buffer is n_pes times as
 * large. Every source element is filled with a value derived from my_pe
 * ('a' + my_pe, 3.14 + my_pe or my_pe, depending on T1). A device array with
 * num_teams team handles is allocated as well.
 *
 * @param args Parsed tester arguments; max_msg_size and num_wgs are read here.
 */
template <typename T1>
TeamFcollectTester<T1>::TeamFcollectTester(TesterArguments args)
    : Tester(args) {
  my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD);
  n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD);

  // Sizes stay in size_t: narrowing them to int overflows once
  // max_msg_size * num_wgs * n_pes reaches 2 GiB.
  const size_t total_elems =
      (static_cast<size_t>(args.max_msg_size) / sizeof(T1)) *
      static_cast<size_t>(args.num_wgs);
  const size_t src_bytes = total_elems * sizeof(T1);
  const size_t dest_bytes = src_bytes * static_cast<size_t>(n_pes);

  source_buf = (T1 *)rocshmem_malloc(src_bytes);
  dest_buf = (T1 *)rocshmem_malloc(dest_bytes);

  if (source_buf == nullptr || dest_buf == nullptr) {
    // Print the addresses through const void*: streaming a (possibly null)
    // char* / unsigned char* would be interpreted as a C string.
    std::cout << "Error allocating memory from symmetric heap" << std::endl;
    std::cout << "source: " << static_cast<const void *>(source_buf)
              << ", dest: " << static_cast<const void *>(dest_buf)
              << std::endl;
    rocshmem_global_exit(1);
  }

  if constexpr (std::is_same<T1, char>::value ||
                std::is_same<T1, signed char>::value ||
                std::is_same<T1, unsigned char>::value) {
    for (size_t i = 0; i < total_elems; ++i) {
      source_buf[i] = static_cast<T1>('a' + my_pe);
    }
  }
  else if constexpr (std::is_floating_point<T1>::value) {
    for (size_t i = 0; i < total_elems; ++i) {
      source_buf[i] = static_cast<T1>(3.14 + my_pe);
    }
  }
  else if constexpr (std::is_integral<T1>::value) {
    for (size_t i = 0; i < total_elems; ++i) {
      source_buf[i] = static_cast<T1>(my_pe);
    }
  }

  // An explicit ROCSHMEM_MAX_NUM_TEAMS overrides the default team count.
  // Values that do not parse to a positive number are ignored so that they
  // cannot turn into a zero-sized or wrapped-around allocation below.
  // NOTE(review): the value is used as-is, not reduced by one; confirm this
  // matches the number of teams the library can actually create.
  if (const char *env_teams = getenv("ROCSHMEM_MAX_NUM_TEAMS")) {
    const int requested = atoi(env_teams);
    if (requested > 0) {
      num_teams = requested;
    }
  }

  CHECK_HIP(hipMalloc(&team_fcollect_world_dup,
                      sizeof(rocshmem_team_t) * num_teams));
}
/**
 * Releases what the constructor allocated: the two rocshmem_malloc buffers
 * first, then the hipMalloc'd array of team handles.
 */
template <typename T1>
TeamFcollectTester<T1>::~TeamFcollectTester() {
  // Buffers obtained from rocshmem_malloc.
  rocshmem_free(source_buf);
  rocshmem_free(dest_buf);

  // Team handle array obtained from hipMalloc.
  CHECK_HIP(hipFree(team_fcollect_world_dup));
}
/**
 * Prepares a launch: sets bw_factor to n_pes and fills every slot of
 * team_fcollect_world_dup with a team split from ROCSHMEM_TEAM_WORLD using
 * start 0, stride 1 and size n_pes. Aborts if any slot is still
 * ROCSHMEM_TEAM_INVALID after the split.
 */
template <typename T1>
void TeamFcollectTester<T1>::preLaunchKernel() {
  bw_factor = n_pes;

  int slot = 0;
  while (slot < num_teams) {
    rocshmem_team_t &dup = team_fcollect_world_dup[slot];

    // Start from a known-invalid handle so a failed split is detectable.
    dup = ROCSHMEM_TEAM_INVALID;
    rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0,
                                &dup);

    if (dup == ROCSHMEM_TEAM_INVALID) {
      std::cout << "Team " << slot << " is invalid!" << std::endl;
      abort();
    }
    ++slot;
  }
}
/**
 * Launches TeamFcollectTest<T1> on `stream` and records the message counts
 * for this launch.
 *
 * @param gridSize  Grid dimensions; gridSize.x scales the message counters.
 * @param blockSize Work-group dimensions.
 * @param loop      Number of timed iterations (args.skip more are run first).
 * @param size      Message size in bytes per work-group.
 */
template <typename T1>
void TeamFcollectTester<T1>::launchKernel(dim3 gridSize, dim3 blockSize,
                                          int loop, uint64_t size) {
  size_t shared_bytes = 0;
  int num_elems = size / sizeof(T1);

  // The previous local my_pe / n_pes copies were never read and only
  // shadowed the members of the same name, so they are gone.
  hipLaunchKernelGGL(TeamFcollectTest<T1>, gridSize, blockSize, shared_bytes,
                     stream, loop, args.skip, start_time, end_time,
                     source_buf, dest_buf, num_elems, _shmem_context,
                     team_fcollect_world_dup);

  // One fcollect per work-group per iteration.
  num_msgs = (loop + args.skip) * gridSize.x;
  num_timed_msgs = loop * gridSize.x;
}
/**
 * Destroys, in index order, every team stored in team_fcollect_world_dup.
 */
template <typename T1>
void TeamFcollectTester<T1>::postLaunchKernel() {
  int team_idx = 0;
  while (team_idx < num_teams) {
    rocshmem_team_destroy(team_fcollect_world_dup[team_idx]);
    ++team_idx;
  }
}
/**
 * Overwrites the part of dest_buf used for the given message size with 0xFF
 * bytes so that stale results from an earlier run cannot pass verification.
 *
 * @param size Message size in bytes per work-group.
 */
template <typename T1>
void TeamFcollectTester<T1>::resetBuffers(uint64_t size) {
  const size_t num_elems = size / sizeof(T1);

  // Keep the byte count in size_t: narrowing it to int goes negative above
  // 2 GiB and then becomes an enormous length once memset widens it again.
  const size_t dest_bytes = num_elems * sizeof(T1) *
                            static_cast<size_t>(args.num_wgs) *
                            static_cast<size_t>(n_pes);
  memset(dest_buf, -1, dest_bytes);
}
/**
 * Validates dest_buf on the host after the fcollect.
 *
 * For each work-group, dest_buf is expected to hold n_pes consecutive chunks
 * of size / sizeof(T1) elements; every element of the chunk for PE p must
 * equal the value derived from p ('a' + p, 3.14 + p or p, depending on T1).
 * The first mismatch is reported on std::cerr and the process exits with
 * status -1.
 *
 * @param size Message size in bytes per work-group.
 */
template <typename T1>
void TeamFcollectTester<T1>::verifyResults(uint64_t size) {
  const int elems_per_pe = size / sizeof(T1);

  for (int wg = 0; wg < args.num_wgs; ++wg) {
    for (int src_pe = 0; src_pe < n_pes; ++src_pe) {
      // The expected value depends only on the contributing PE.
      T1 want;
      if constexpr (std::is_same<T1, char>::value ||
                    std::is_same<T1, signed char>::value ||
                    std::is_same<T1, unsigned char>::value) {
        want = static_cast<T1>('a' + src_pe);
      }
      else if constexpr (std::is_floating_point<T1>::value) {
        want = static_cast<T1>(3.14 + src_pe);
      }
      else if constexpr (std::is_integral<T1>::value) {
        want = src_pe;
      }

      // First element of the chunk contributed by src_pe to work-group wg.
      const int chunk_start = (wg * n_pes + src_pe) * elems_per_pe;
      for (int k = 0; k < elems_per_pe; ++k) {
        const int pos = chunk_start + k;
        if (dest_buf[pos] == want) {
          continue;
        }
        std::cerr << "Data validation error at idx " << pos << std::endl;
        std::cerr << "PE " << my_pe << " Got " << dest_buf[pos]
                  << ", Expected " << want << std::endl;
        exit(-1);
      }
    }
  }
}
@@ -28,16 +28,16 @@
#include "tester.hpp"
using namespace rocshmem;
/************* *****************************************************************
* HOST TESTER CLASS
*****************************************************************************/
template <typename T1>
class FcollectTester : public Tester {
class TeamFcollectTester : public Tester {
public:
explicit FcollectTester(
TesterArguments args, std::function<void(T1 &, T1 &)> f1,
std::function<std::pair<bool, std::string>(const T1 &, T1)> f2);
virtual ~FcollectTester();
explicit TeamFcollectTester(TesterArguments args);
virtual ~TeamFcollectTester();
protected:
virtual void resetBuffers(uint64_t size) override;
@@ -55,10 +55,16 @@ class FcollectTester : public Tester {
T1 *dest_buf;
private:
std::function<void(T1 &, T1 &)> init_buf;
std::function<std::pair<bool, std::string>(const T1 &, T1)> verify_buf;
int my_pe = 0;
int n_pes = 0;
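/**
* Default number of duplicated teams. It should equal
* ROCSHMEM_MAX_NUM_TEAMS - 1; the default maximum number of teams is 40.
* This is not a constant: the constructor replaces it with the value of the
* ROCSHMEM_MAX_NUM_TEAMS environment variable when that variable is set.
*/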
int num_teams = 39;
rocshmem_team_t *team_fcollect_world_dup;
};
#include "fcollect_tester.cpp"
#include "team_fcollect_tester.cpp"
#endif
+26 -74
Просмотреть файл
@@ -30,14 +30,12 @@
#include <rocshmem/rocshmem.hpp>
#include <vector>
#include "alltoall_tester.hpp"
#include "amo_bitwise_tester.hpp"
#include "amo_extended_tester.hpp"
#include "amo_standard_tester.hpp"
#include "barrier_all_tester.hpp"
#include "empty_tester.hpp"
#include "extended_primitives.hpp"
#include "fcollect_tester.hpp"
#include "ping_all_tester.hpp"
#include "ping_pong_tester.hpp"
#include "primitive_mr_tester.hpp"
@@ -47,9 +45,11 @@
#include "signaling_operations_tester.hpp"
#include "swarm_tester.hpp"
#include "sync_tester.hpp"
#include "team_alltoall_tester.hpp"
#include "team_broadcast_tester.hpp"
#include "team_ctx_infra_tester.hpp"
#include "team_ctx_primitive_tester.hpp"
#include "team_fcollect_tester.hpp"
#include "team_reduction_tester.hpp"
#include "wave_level_primitives.hpp"
@@ -162,85 +162,37 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) {
std::cout << "Team Broadcast Test ###" << std::endl;
}
testers.push_back(new TeamBroadcastTester<long>(
args,
[](long& f1, long& f2) {
f1 = 1;
f2 = 2;
},
[rank](long v) {
long expected_val;
/**
* The verification routine here requires that the
* PE_root value is 0 which denotes that the
* sending processing element is rank 0.
*
* The difference in expected values arises from
* the specification for broadcast where the
* PE_root processing element does not copy the
* contents from its own source to dest during
* the broadcast.
*/
if (rank == 0) {
expected_val = 2;
} else {
expected_val = 1;
}
return (v == expected_val)
? std::make_pair(true, "")
: std::make_pair(
false, "Rank " + std::to_string(rank) + ", Got " +
std::to_string(v) + ", Expect " +
std::to_string(expected_val));
}));
testers.push_back(new TeamBroadcastTester<int64_t>(args));
testers.push_back(new TeamBroadcastTester<int>(args));
testers.push_back(new TeamBroadcastTester<long long>(args));
testers.push_back(new TeamBroadcastTester<float>(args));
testers.push_back(new TeamBroadcastTester<double>(args));
testers.push_back(new TeamBroadcastTester<char>(args));
testers.push_back(new TeamBroadcastTester<unsigned char>(args));
return testers;
case AllToAllTestType:
case TeamAllToAllTestType:
if (rank == 0) {
std::cout << "Alltoall Test ###" << std::endl;
}
testers.push_back(new AlltoallTester<int64_t>(
args,
[rank](int64_t& f1, int64_t& f2, int64_t dest_pe) {
const long SRC_SHIFT = 16;
// Make value for each src, dst pair unique
// by shifting src by SRC_SHIFT bits
f1 = (rank << SRC_SHIFT) + dest_pe;
f2 = -1;
},
[rank](int64_t v, int64_t src_pe) {
const long SRC_SHIFT = 16;
// See if we obtained unique value
long expected_val = (src_pe << SRC_SHIFT) + rank;
return (v == expected_val)
? std::make_pair(true, "")
: std::make_pair(
false, "Rank " + std::to_string(rank) + ", Got " +
std::to_string(v) + ", Expect " +
std::to_string(expected_val));
}));
testers.push_back(new TeamAlltoallTester<int64_t>(args));
testers.push_back(new TeamAlltoallTester<int>(args));
testers.push_back(new TeamAlltoallTester<long long>(args));
testers.push_back(new TeamAlltoallTester<float>(args));
testers.push_back(new TeamAlltoallTester<double>(args));
testers.push_back(new TeamAlltoallTester<char>(args));
testers.push_back(new TeamAlltoallTester<unsigned char>(args));
return testers;
case FCollectTestType:
case TeamFCollectTestType:
if (rank == 0) {
std::cout << "Fcollect Test ###" << std::endl;
}
testers.push_back(new FcollectTester<int64_t>(
args,
[rank](int64_t& f1, int64_t& f2) {
f1 = rank;
f2 = -1;
},
[rank](int64_t v, int64_t src_pe) {
int64_t expected_val = src_pe;
return (v == expected_val)
? std::make_pair(true, "")
: std::make_pair(
false, "Rank " + std::to_string(rank) + ", Got " +
std::to_string(v) + ", Expect " +
std::to_string(expected_val));
}));
testers.push_back(new TeamFcollectTester<int64_t>(args));
testers.push_back(new TeamFcollectTester<int>(args));
testers.push_back(new TeamFcollectTester<long long>(args));
testers.push_back(new TeamFcollectTester<float>(args));
testers.push_back(new TeamFcollectTester<double>(args));
testers.push_back(new TeamFcollectTester<char>(args));
testers.push_back(new TeamFcollectTester<unsigned char>(args));
return testers;
case AMO_FAddTestType:
if (rank == 0) std::cout << "AMO Fetch_Add ###" << std::endl;
@@ -525,7 +477,7 @@ bool Tester::peLaunchesKernel() {
*/
is_launcher = is_launcher || (_type == TeamReductionTestType) ||
(_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) ||
(_type == AllToAllTestType) || (_type == FCollectTestType) ||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
(_type == RandomAccessTestType) || (_type == PingAllTestType);
+2 -2
Просмотреть файл
@@ -53,8 +53,8 @@ enum TestType {
SyncAllTestType = 16,
SyncTestType = 17,
CollectTestType = 18,
FCollectTestType = 19,
AllToAllTestType = 20,
TeamFCollectTestType = 19,
TeamAllToAllTestType = 20,
AllToAllsTestType = 21,
ShmemPtrTestType = 22,
PTestType = 23,
+8 -2
Просмотреть файл
@@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case AMO_IncTestType:
case AMO_FetchTestType:
case BarrierAllTestType:
case SyncAllTestType:
case SyncTestType:
case ShmemPtrTestType:
min_msg_size = 8;
@@ -97,6 +98,11 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case RandomAccessTestType:
min_msg_size = 4;
break;
case TeamFCollectTestType:
case TeamAllToAllTestType:
case TeamBroadcastTestType:
min_msg_size = 8;
break;
case TeamCtxInfraTestType:
max_msg_size = min_msg_size;
break;
@@ -137,8 +143,8 @@ void TesterArguments::get_rocshmem_arguments() {
TestType type = (TestType)algorithm;
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
(type != SyncTestType) && (type != AllToAllTestType) &&
(type != FCollectTestType) && (type != TeamReductionTestType) &&
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
(type != TeamBroadcastTestType) && (type != PingAllTestType)) {
if (numprocs != 2) {
if (myid == 0) {