diff --git a/projects/rocshmem/tests/functional_tests/alltoall_tester.cpp b/projects/rocshmem/tests/functional_tests/alltoall_tester.cpp deleted file mode 100644 index e4ff9b7e14..0000000000 --- a/projects/rocshmem/tests/functional_tests/alltoall_tester.cpp +++ /dev/null @@ -1,165 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - *****************************************************************************/ - -using namespace rocshmem; - -/* Declare the template with a generic implementation */ -template -__device__ void wg_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, - const T *source, int nelem) { - return; -} - -/* Define templates to call rocSHMEM */ -#define ALLTOALL_DEF_GEN(T, TNAME) \ - template <> \ - __device__ void wg_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, \ - T * dest, const T *source, int nelem) { \ - rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \ - } - -ALLTOALL_DEF_GEN(float, float) -ALLTOALL_DEF_GEN(double, double) -ALLTOALL_DEF_GEN(char, char) -// ALLTOALL_DEF_GEN(long double, longdouble) -ALLTOALL_DEF_GEN(signed char, schar) -ALLTOALL_DEF_GEN(short, short) -ALLTOALL_DEF_GEN(int, int) -ALLTOALL_DEF_GEN(long, long) -ALLTOALL_DEF_GEN(long long, longlong) -ALLTOALL_DEF_GEN(unsigned char, uchar) -ALLTOALL_DEF_GEN(unsigned short, ushort) -ALLTOALL_DEF_GEN(unsigned int, uint) -ALLTOALL_DEF_GEN(unsigned long, ulong) -ALLTOALL_DEF_GEN(unsigned long long, ulonglong) - -rocshmem_team_t team_alltoall_world_dup; - -/****************************************************************************** - * DEVICE TEST KERNEL - *****************************************************************************/ -template -__global__ void AlltoallTest(int loop, int skip, long long int *start_time, - long long int *end_time, T1 *source_buf, - T1 *dest_buf, int size, ShmemContextType ctx_type, - rocshmem_team_t team) { - __shared__ rocshmem_ctx_t ctx; - int wg_id = get_flat_grid_id(); - - rocshmem_wg_init(); - rocshmem_wg_ctx_create(ctx_type, &ctx); - - int n_pes = rocshmem_ctx_n_pes(ctx); - - __syncthreads(); - - for (int i = 0; i < loop + skip; i++) { - if (i == skip && hipThreadIdx_x == 0) { - start_time[wg_id] = wall_clock64(); - } - wg_alltoall(ctx, team, - dest_buf, // T* dest - source_buf, // const T* source - size); // int nelement - } - - __syncthreads(); - - if (hipThreadIdx_x == 0) { - end_time[wg_id] = wall_clock64(); - } - - rocshmem_wg_ctx_destroy(&ctx); - rocshmem_wg_finalize(); -} - -/****************************************************************************** - * HOST TESTER CLASS METHODS - *****************************************************************************/ -template -AlltoallTester::AlltoallTester( - TesterArguments args, std::function f1, - std::function(const T1 &, T1)> f2) - : Tester(args), init_buf{f1}, verify_buf{f2} { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes); - dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes); -} - -template -AlltoallTester::~AlltoallTester() { - rocshmem_free(source_buf); - rocshmem_free(dest_buf); -} - -template -void AlltoallTester::preLaunchKernel() { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - bw_factor = sizeof(T1) * n_pes; - - team_alltoall_world_dup = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &team_alltoall_world_dup); -} - -template -void AlltoallTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, - uint64_t size) { - size_t shared_bytes = 0; - - hipLaunchKernelGGL(AlltoallTest, gridSize, blockSize, shared_bytes, - stream, loop, args.skip, start_time, end_time, source_buf, - dest_buf, size, _shmem_context, team_alltoall_world_dup); - - num_msgs = loop + args.skip; - num_timed_msgs = loop; -} - -template -void AlltoallTester::postLaunchKernel() { - rocshmem_team_destroy(team_alltoall_world_dup); -} - -template -void AlltoallTester::resetBuffers(uint64_t size) { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - for (int i = 0; i < n_pes; i++) { - for (uint64_t j = 0; j < size; j++) { - init_buf(source_buf[i * size + j], dest_buf[i * size + j], (T1)i); - } - } -} - -template -void AlltoallTester::verifyResults(uint64_t size) { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - for (int i = 0; i < n_pes; i++) { - for (uint64_t j = 0; j < size; j++) { - auto r = verify_buf(dest_buf[i * size + j], i); - if (r.first == false) { - fprintf(stderr, "Data validation error at idx %lu\n", j); - fprintf(stderr, "%s.\n", r.second.c_str()); - exit(-1); - } - } - } -} diff --git a/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp b/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp index 01bb2bd2a4..49f011a3d5 100644 --- a/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp @@ -44,7 +44,13 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time, __syncthreads(); - rocshmem_ctx_wg_barrier_all(ctx); + /** + * The function `rocshmem_ctx_wg_barrier_all` should be called from only + * one group within the grid to avoid unintended behavior. + */ + if (is_block_zero_in_grid()) { + rocshmem_ctx_wg_barrier_all(ctx); + } } __syncthreads(); @@ -70,7 +76,7 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, start_time, end_time); - num_msgs = (loop + args.skip) * gridSize.x; + num_msgs = loop + args.skip; num_timed_msgs = loop; } diff --git a/projects/rocshmem/tests/functional_tests/fcollect_tester.cpp b/projects/rocshmem/tests/functional_tests/fcollect_tester.cpp deleted file mode 100644 index 3f0efa3445..0000000000 --- a/projects/rocshmem/tests/functional_tests/fcollect_tester.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - *****************************************************************************/ - -using namespace rocshmem; - -rocshmem_team_t team_fcollect_world_dup; - -/* Declare the template with a generic implementation */ -template -__device__ void wg_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, - const T *source, int nelems) { - return; -} - -/* Define templates to call rocSHMEM */ -#define FCOLLECT_DEF_GEN(T, TNAME) \ - template <> \ - __device__ void wg_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team, \ - T * dest, const T *source, int nelem) { \ - rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \ - } - -FCOLLECT_DEF_GEN(float, float) -FCOLLECT_DEF_GEN(double, double) -FCOLLECT_DEF_GEN(char, char) -// FCOLLECT_DEF_GEN(long double, longdouble) -FCOLLECT_DEF_GEN(signed char, schar) -FCOLLECT_DEF_GEN(short, short) -FCOLLECT_DEF_GEN(int, int) -FCOLLECT_DEF_GEN(long, long) -FCOLLECT_DEF_GEN(long long, longlong) -FCOLLECT_DEF_GEN(unsigned char, uchar) -FCOLLECT_DEF_GEN(unsigned short, ushort) -FCOLLECT_DEF_GEN(unsigned int, uint) -FCOLLECT_DEF_GEN(unsigned long, ulong) -FCOLLECT_DEF_GEN(unsigned long long, ulonglong) - -/****************************************************************************** - * DEVICE TEST KERNEL - *****************************************************************************/ -template -__global__ void FcollectTest(int loop, int skip, long long int *start_time, - long long int *end_time, T1 *source_buf, - T1 *dest_buf, int size, ShmemContextType ctx_type, - rocshmem_team_t team) { - __shared__ rocshmem_ctx_t ctx; - int wg_id = get_flat_grid_id(); - - rocshmem_wg_init(); - rocshmem_wg_ctx_create(ctx_type, &ctx); - - int n_pes = rocshmem_ctx_n_pes(ctx); - __syncthreads(); - - for (int i = 0; i < loop + skip; i++) { - if (i == skip && hipThreadIdx_x == 0) { - start_time[wg_id] = wall_clock64(); - } - wg_fcollect(ctx, team, - dest_buf, // T* dest - source_buf, // const T* source - size); // int nelement - } - - __syncthreads(); - - if (hipThreadIdx_x == 0) { - end_time[wg_id] = wall_clock64(); - } - - rocshmem_wg_ctx_destroy(&ctx); - rocshmem_wg_finalize(); -} - -/****************************************************************************** - * HOST TESTER CLASS METHODS - *****************************************************************************/ -template -FcollectTester::FcollectTester( - TesterArguments args, std::function f1, - std::function(const T1 &, T1)> f2) - : Tester(args), init_buf{f1}, verify_buf{f2} { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1)); - dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1) * n_pes); -} - -template -FcollectTester::~FcollectTester() { - rocshmem_free(source_buf); - rocshmem_free(dest_buf); -} - -template -void FcollectTester::preLaunchKernel() { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - bw_factor = sizeof(T1) * n_pes; - - team_fcollect_world_dup = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &team_fcollect_world_dup); -} - -template -void FcollectTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, - uint64_t size) { - size_t shared_bytes = 0; - - hipLaunchKernelGGL(FcollectTest, gridSize, blockSize, shared_bytes, - stream, loop, args.skip, start_time, end_time, source_buf, - dest_buf, size, _shmem_context, team_fcollect_world_dup); - - num_msgs = loop + args.skip; - num_timed_msgs = loop; -} - -template -void FcollectTester::postLaunchKernel() { - rocshmem_team_destroy(team_fcollect_world_dup); -} - -template -void FcollectTester::resetBuffers(uint64_t size) { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - for (int i = 0; i < n_pes; i++) { - for (uint64_t j = 0; j < size; j++) { - // Note: This is redundant work, - // source is being reinitialized multiple times - init_buf(source_buf[j], dest_buf[i * size + j]); - } - } -} - -template -void FcollectTester::verifyResults(uint64_t size) { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - for (int i = 0; i < n_pes; i++) { - for (uint64_t j = 0; j < size; j++) { - auto r = verify_buf(dest_buf[i * size + j], i); - if (r.first == false) { - fprintf(stderr, "Data validation error at idx %lu\n", j); - fprintf(stderr, "%s.\n", r.second.c_str()); - // exit(-1); - return; - } - } - } -} diff --git a/projects/rocshmem/tests/functional_tests/sync_tester.cpp b/projects/rocshmem/tests/functional_tests/sync_tester.cpp index a74ae78d8e..1843cccc35 100644 --- a/projects/rocshmem/tests/functional_tests/sync_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/sync_tester.cpp @@ -22,17 +22,12 @@ #include "sync_tester.hpp" -#include - -using namespace rocshmem; -rocshmem_team_t team_sync_world_dup; - /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ __global__ void SyncTest(int loop, int skip, long long int *start_time, long long int *end_time, TestType type, - ShmemContextType ctx_type, rocshmem_team_t team) { + ShmemContextType ctx_type, rocshmem_team_t *teams) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); @@ -47,10 +42,16 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time, __syncthreads(); switch (type) { case SyncAllTestType: - rocshmem_ctx_wg_sync_all(ctx); + /** + * The function `rocshmem_ctx_wg_sync_all` should be called from only + * one group within the grid to avoid unintended behavior. + */ + if (is_block_zero_in_grid()) { + rocshmem_ctx_wg_sync_all(ctx); + } break; case SyncTestType: - rocshmem_ctx_wg_team_sync(ctx, team); + rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]); break; default: break; @@ -69,28 +70,60 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time, /****************************************************************************** * HOST TESTER CLASS METHODS *****************************************************************************/ -SyncTester::SyncTester(TesterArguments args) : Tester(args) {} +SyncTester::SyncTester(TesterArguments args) : Tester(args) { -SyncTester::~SyncTester() {} + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } + + CHECK_HIP(hipMalloc(&team_sync_world_dup, + sizeof(rocshmem_team_t) * num_teams)); +} + +SyncTester::~SyncTester() { + CHECK_HIP(hipFree(team_sync_world_dup)); +} void SyncTester::resetBuffers(uint64_t size) {} +void SyncTester::preLaunchKernel() { + int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + for (int team_i = 0; team_i < num_teams; team_i++) { + team_sync_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_sync_world_dup[team_i]); + if (team_sync_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + printf("Team %d is invalid!\n", team_i); + abort(); + } + } +} + void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - team_sync_world_dup = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &team_sync_world_dup); - hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, start_time, end_time, _type, _shmem_context, team_sync_world_dup); - num_msgs = (loop + args.skip) * gridSize.x; + num_msgs = loop + args.skip; num_timed_msgs = loop; + + if(_type == SyncTestType) { + num_msgs *= gridSize.x; + num_timed_msgs *= gridSize.x; + } +} + +void SyncTester::postLaunchKernel() { + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_sync_world_dup[team_i]); + } } void SyncTester::verifyResults(uint64_t size) {} diff --git a/projects/rocshmem/tests/functional_tests/sync_tester.hpp b/projects/rocshmem/tests/functional_tests/sync_tester.hpp index 764fd5e27f..fdd7d6f6ca 100644 --- a/projects/rocshmem/tests/functional_tests/sync_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/sync_tester.hpp @@ -23,8 +23,12 @@ #ifndef _SYNC_TESTER_HPP_ #define _SYNC_TESTER_HPP_ +#include + #include "tester.hpp" +using namespace rocshmem; + /****************************************************************************** * HOST TESTER CLASS *****************************************************************************/ @@ -36,10 +40,22 @@ class SyncTester : public Tester { protected: virtual void resetBuffers(uint64_t size) override; + virtual void preLaunchKernel() override; + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) override; + virtual void postLaunchKernel() override; + virtual void verifyResults(uint64_t size) override; + + private: + /** + * This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1. + * The default value for the maximum number of teams is 40. + */ + int num_teams = 39; + rocshmem_team_t *team_sync_world_dup; }; #endif diff --git a/projects/rocshmem/tests/functional_tests/team_alltoall_tester.cpp b/projects/rocshmem/tests/functional_tests/team_alltoall_tester.cpp new file mode 100644 index 0000000000..fd3cc28e00 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_alltoall_tester.cpp @@ -0,0 +1,222 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +/* Declare the template with a generic implementation */ +template +__device__ void wg_team_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team, + T *dest, const T *source, int nelem) { + return; +} + +/* Define templates to call rocSHMEM */ +#define TEAM_ALLTOALL_DEF_GEN(T, TNAME) \ + template <> \ + __device__ void wg_team_alltoall(rocshmem_ctx_t ctx, rocshmem_team_t team,\ + T * dest, const T *source, int nelem) { \ + rocshmem_ctx_##TNAME##_wg_alltoall(ctx, team, dest, source, nelem); \ + } + +TEAM_ALLTOALL_DEF_GEN(float, float) +TEAM_ALLTOALL_DEF_GEN(double, double) +TEAM_ALLTOALL_DEF_GEN(char, char) +// TEAM_ALLTOALL_DEF_GEN(long double, longdouble) +TEAM_ALLTOALL_DEF_GEN(signed char, schar) +TEAM_ALLTOALL_DEF_GEN(short, short) +TEAM_ALLTOALL_DEF_GEN(int, int) +TEAM_ALLTOALL_DEF_GEN(long, long) +TEAM_ALLTOALL_DEF_GEN(long long, longlong) +TEAM_ALLTOALL_DEF_GEN(unsigned char, uchar) +TEAM_ALLTOALL_DEF_GEN(unsigned short, ushort) +TEAM_ALLTOALL_DEF_GEN(unsigned int, uint) +TEAM_ALLTOALL_DEF_GEN(unsigned long, ulong) +TEAM_ALLTOALL_DEF_GEN(unsigned long long, ulonglong) + +/****************************************************************************** + * DEVICE TEST KERNEL + *****************************************************************************/ +template +__global__ void TeamAlltoallTest(int loop, int skip, long long int *start_time, + long long int *end_time, T1 *source_buf, + T1 *dest_buf, int num_elems, + ShmemContextType ctx_type, + rocshmem_team_t *teams) { + __shared__ rocshmem_ctx_t ctx; + int wg_id = get_flat_grid_id(); + + rocshmem_wg_init(); + rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx); + + int n_pes = rocshmem_ctx_n_pes(ctx); + + source_buf += wg_id * n_pes * num_elems; + dest_buf += wg_id * n_pes * num_elems; + + __syncthreads(); + + for (int i = 0; i < loop + skip; i++) { + if (i == skip && hipThreadIdx_x == 0) { + start_time[wg_id] = wall_clock64(); + } + wg_team_alltoall(ctx, teams[wg_id], + dest_buf, // T* dest + source_buf, // const T* source + num_elems); // int nelement + } + + __syncthreads(); + + if (hipThreadIdx_x == 0) { + end_time[wg_id] = wall_clock64(); + } + + rocshmem_wg_ctx_destroy(&ctx); + rocshmem_wg_finalize(); +} + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +template +TeamAlltoallTester::TeamAlltoallTester(TesterArguments args) + : Tester(args){ + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + // Number of elements per work group + int num_elems_wg = (args.max_msg_size / sizeof(T1)) * n_pes; + // Total number of elements in the GPU kernel + int total_elems = num_elems_wg * args.num_wgs; + int buff_size = total_elems * sizeof(T1); + + source_buf = (T1 *)rocshmem_malloc(buff_size); + dest_buf = (T1 *)rocshmem_malloc(buff_size); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cout << "Error allocating memory from symmetric heap" << std::endl; + std::cout << "source: " << source_buf + << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } + + CHECK_HIP(hipMalloc(&team_alltoall_world_dup, + sizeof(rocshmem_team_t) * num_teams)); +} + +template +TeamAlltoallTester::~TeamAlltoallTester() { + rocshmem_free(source_buf); + rocshmem_free(dest_buf); + CHECK_HIP(hipFree(team_alltoall_world_dup)); +} + +template +void TeamAlltoallTester::preLaunchKernel() { + bw_factor = n_pes; + + for (int team_i = 0; team_i < num_teams; team_i++) { + team_alltoall_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_alltoall_world_dup[team_i]); + if (team_alltoall_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + std::cout << "Team " << team_i << " is invalid!" << std::endl; + abort(); + } + } +} + +template +void TeamAlltoallTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, uint64_t size) { + size_t shared_bytes = 0; + + int num_elems = size / sizeof(T1); + + hipLaunchKernelGGL(TeamAlltoallTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, start_time, end_time, + source_buf, dest_buf, num_elems, _shmem_context, + team_alltoall_world_dup); + + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; +} + +template +void TeamAlltoallTester::postLaunchKernel() { + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_alltoall_world_dup[team_i]); + } +} + +template +void TeamAlltoallTester::resetBuffers(uint64_t size) { + + int num_elems = size / sizeof(T1); + int buff_size = num_elems * sizeof(T1) * args.num_wgs * n_pes; + int idx = 0; + + for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) { + for(int pe = 0; pe < n_pes; pe++) { + for(int i = 0; i < num_elems; i++) { + idx = (wg_id * n_pes + pe) * num_elems + i; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + source_buf[idx] = static_cast('a' + my_pe + pe + wg_id); + } + else if constexpr (std::is_floating_point::value) { + source_buf[idx] = static_cast(3.14 + my_pe + pe + wg_id); + } + else if constexpr (std::is_integral::value) { + source_buf[idx] = static_cast(my_pe + pe + wg_id); + } + } + } + } + + memset(dest_buf, -1, buff_size); +} + +template +void TeamAlltoallTester::verifyResults(uint64_t size) { + int num_elems = size / sizeof(T1); + int idx = 0; + + for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) { + for(int pe = 0; pe < n_pes; pe++) { + for(int i = 0; i < num_elems; i++) { + idx = (wg_id * n_pes + pe) * num_elems + i; + if (dest_buf[idx] != source_buf[idx]) { + std::cerr << "Data validation error at idx " << idx << std::endl; + std::cerr << "PE " << my_pe << " Got " << dest_buf[idx] + << ", Expected " << source_buf[idx] << std::endl; + exit(-1); + } + } + } + } +} diff --git a/projects/rocshmem/tests/functional_tests/alltoall_tester.hpp b/projects/rocshmem/tests/functional_tests/team_alltoall_tester.hpp similarity index 80% rename from projects/rocshmem/tests/functional_tests/alltoall_tester.hpp rename to projects/rocshmem/tests/functional_tests/team_alltoall_tester.hpp index 2e08b3b1ac..262b04efb0 100644 --- a/projects/rocshmem/tests/functional_tests/alltoall_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/team_alltoall_tester.hpp @@ -28,16 +28,16 @@ #include "tester.hpp" +using namespace rocshmem; + /************* ***************************************************************** * HOST TESTER CLASS *****************************************************************************/ template -class AlltoallTester : public Tester { +class TeamAlltoallTester : public Tester { public: - explicit AlltoallTester( - TesterArguments args, std::function f1, - std::function(const T1 &, T1)> f2); - virtual ~AlltoallTester(); + explicit TeamAlltoallTester(TesterArguments args); + virtual ~TeamAlltoallTester(); protected: virtual void resetBuffers(uint64_t size) override; @@ -51,14 +51,21 @@ class AlltoallTester : public Tester { virtual void verifyResults(uint64_t size) override; - T1 *source_buf; - T1 *dest_buf; + T1 *source_buf = nullptr; + T1 *dest_buf = nullptr; private: - std::function init_buf; - std::function(const T1 &, T1)> verify_buf; + int my_pe = 0; + int n_pes = 0; + + /** + * This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1. + * The default value for the maximum number of teams is 40. + */ + int num_teams = 39; + rocshmem_team_t *team_alltoall_world_dup; }; -#include "alltoall_tester.cpp" +#include "team_alltoall_tester.cpp" #endif diff --git a/projects/rocshmem/tests/functional_tests/team_broadcast_tester.cpp b/projects/rocshmem/tests/functional_tests/team_broadcast_tester.cpp index 68c981e797..43d5a7ac4c 100644 --- a/projects/rocshmem/tests/functional_tests/team_broadcast_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/team_broadcast_tester.cpp @@ -20,8 +20,6 @@ * IN THE SOFTWARE. *****************************************************************************/ -using namespace rocshmem; - /* Declare the template with a generic implementation */ template __device__ void wg_team_broadcast(rocshmem_ctx_t ctx, rocshmem_team_t team, @@ -65,28 +63,29 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time, long long int *end_time, T1 *source_buf, T1 *dest_buf, int size, ShmemContextType ctx_type, - rocshmem_team_t team) { + rocshmem_team_t *teams) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); rocshmem_wg_init(); - rocshmem_wg_ctx_create(ctx_type, &ctx); + rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx); int n_pes = rocshmem_ctx_n_pes(ctx); + source_buf += wg_id * size; + dest_buf += wg_id * size; __syncthreads(); - for (int i = 0; i < loop; i++) { + for (int i = 0; i < loop + skip; i++) { if (i == skip && hipThreadIdx_x == 0) { start_time[wg_id] = wall_clock64(); } - wg_team_broadcast(ctx, team, - dest_buf, // T* dest - source_buf, // const T* source - size, // int nelement - 0); // int PE_root - rocshmem_ctx_wg_barrier_all(ctx); + wg_team_broadcast(ctx, teams[wg_id], + dest_buf, // T* dest + source_buf, // const T* source + size, // int nelement + 0); // int PE_root } __syncthreads(); @@ -103,27 +102,51 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time, * HOST TESTER CLASS METHODS *****************************************************************************/ template -TeamBroadcastTester::TeamBroadcastTester( - TesterArguments args, std::function f1, - std::function(const T1 &)> f2) - : Tester(args), init_buf{f1}, verify_buf{f2} { - source_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1)); - dest_buf = (T1 *)rocshmem_malloc(args.max_msg_size * sizeof(T1)); +TeamBroadcastTester::TeamBroadcastTester(TesterArguments args) + : Tester(args){ + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + // Total number of elements in src buffer + int total_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs ; + int buff_size = total_elems * sizeof(T1); + + source_buf = (T1 *)rocshmem_malloc(buff_size); + dest_buf = (T1 *)rocshmem_malloc(buff_size); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cout << "Error allocating memory from symmetric heap" << std::endl; + std::cout << "source: " << source_buf << ", dest: " << dest_buf << std::endl; + rocshmem_global_exit(1); + } + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } + + CHECK_HIP(hipMalloc(&team_bcast_world_dup, + sizeof(rocshmem_team_t) * num_teams)); } template TeamBroadcastTester::~TeamBroadcastTester() { rocshmem_free(source_buf); rocshmem_free(dest_buf); + CHECK_HIP(hipFree(team_bcast_world_dup)); } template void TeamBroadcastTester::preLaunchKernel() { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); - - team_bcast_world_dup = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &team_bcast_world_dup); + for (int team_i = 0; team_i < num_teams; team_i++) { + team_bcast_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_bcast_world_dup[team_i]); + if (team_bcast_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + printf("Team %d is invalid!\n", team_i); + abort(); + } + } } template @@ -131,34 +154,90 @@ void TeamBroadcastTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(TeamBroadcastTest, gridSize, blockSize, shared_bytes, - stream, loop, args.skip, start_time, end_time, source_buf, - dest_buf, size, _shmem_context, team_bcast_world_dup); + int num_elems = size / sizeof(T1); - num_msgs = loop + args.skip; - num_timed_msgs = loop; + hipLaunchKernelGGL(TeamBroadcastTest, gridSize, blockSize, + shared_bytes, stream, loop, args.skip, + start_time, end_time, source_buf, dest_buf, + num_elems, _shmem_context, team_bcast_world_dup); + + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; } template void TeamBroadcastTester::postLaunchKernel() { - rocshmem_team_destroy(team_bcast_world_dup); + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_bcast_world_dup[team_i]); + } } template void TeamBroadcastTester::resetBuffers(uint64_t size) { - for (uint64_t i = 0; i < args.max_msg_size; i++) { - init_buf(source_buf[i], dest_buf[i]); + + int num_elems = size / sizeof(T1); + int buff_size = num_elems * sizeof(T1) * args.num_wgs; + int idx = 0; + + for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) { + for (int i = 0; i < num_elems; i++) { + idx = wg_id * num_elems + i; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + source_buf[idx] = static_cast('a' + n_pes + wg_id); + dest_buf[idx] = static_cast('a' + wg_id); + } + else if constexpr (std::is_floating_point::value) { + source_buf[idx] = static_cast(3.14 + n_pes + wg_id); + dest_buf[idx] = static_cast(3.14 + wg_id); + } + else if constexpr (std::is_integral::value) { + source_buf[idx] = static_cast(n_pes + wg_id); + dest_buf[idx] = static_cast(wg_id); + } + } } } template void TeamBroadcastTester::verifyResults(uint64_t size) { - for (uint64_t i = 0; i < size; i++) { - auto r = verify_buf(dest_buf[i]); - if (r.first == false) { - fprintf(stderr, "Data validation error at idx %lu\n", i); - fprintf(stderr, "%s.\n", r.second.c_str()); - exit(-1); + + int num_elems = size / sizeof(T1); + int idx = 0; + T1 expected; + + /** + * The verification routine here requires that the + * PE_root value is 0 which denotes that the + * sending processing element is rank 0. + * + * The difference in expected values arises from + * the specification for broadcast where the + * PE_root processing element does not copy the + * contents from its own source to dest during + * the broadcast. + */ + for (int wg_id = 0; wg_id < args.num_wgs; wg_id++) { + for (int i = 0; i < num_elems; i++) { + idx = wg_id * num_elems + i; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + expected = static_cast('a' + wg_id + (my_pe ? n_pes : 0)); + } + else if constexpr (std::is_floating_point::value) { + expected = static_cast(3.14 + wg_id + (my_pe ? n_pes : 0)); + } + else if constexpr (std::is_integral::value) { + expected = static_cast(wg_id + (my_pe ? n_pes : 0)); + } + if (dest_buf[idx] != expected) { + std::cerr << "Data validation error at idx " << idx << std::endl; + std::cerr << "PE " << my_pe << " Got " << dest_buf[idx] + << ", Expected " << expected << std::endl; + exit(-1); + } } } } diff --git a/projects/rocshmem/tests/functional_tests/team_broadcast_tester.hpp b/projects/rocshmem/tests/functional_tests/team_broadcast_tester.hpp index c2bd005b52..cebb09a61d 100644 --- a/projects/rocshmem/tests/functional_tests/team_broadcast_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/team_broadcast_tester.hpp @@ -28,15 +28,15 @@ #include "tester.hpp" +using namespace rocshmem; + /************* ***************************************************************** * HOST TESTER CLASS *****************************************************************************/ template class TeamBroadcastTester : public Tester { public: - explicit TeamBroadcastTester( - TesterArguments args, std::function f1, - std::function(const T1 &)> f2); + explicit TeamBroadcastTester(TesterArguments args); virtual ~TeamBroadcastTester(); protected: @@ -55,8 +55,14 @@ class TeamBroadcastTester : public Tester { T1 *dest_buf; private: - std::function init_buf; - std::function(const T1 &)> verify_buf; + int my_pe = 0; + int n_pes = 0; + /** + * This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1. + * The default value for the maximum number of teams is 40. + */ + int num_teams = 39; + rocshmem_team_t *team_bcast_world_dup; }; #include "team_broadcast_tester.cpp" diff --git a/projects/rocshmem/tests/functional_tests/team_fcollect_tester.cpp b/projects/rocshmem/tests/functional_tests/team_fcollect_tester.cpp new file mode 100644 index 0000000000..4f8b85e693 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_fcollect_tester.cpp @@ -0,0 +1,232 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +/* Declare the template with a generic implementation */ +template +__device__ void wg_team_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team, + T *dest, const T *source, int nelems) { + return; +} + +/* Define templates to call rocSHMEM */ +#define TEAM_FCOLLECT_DEF_GEN(T, TNAME) \ + template <> \ + __device__ void wg_team_fcollect(rocshmem_ctx_t ctx, rocshmem_team_t team,\ + T * dest, const T *source, int nelem) { \ + rocshmem_ctx_##TNAME##_wg_fcollect(ctx, team, dest, source, nelem); \ + } + +TEAM_FCOLLECT_DEF_GEN(float, float) +TEAM_FCOLLECT_DEF_GEN(double, double) +TEAM_FCOLLECT_DEF_GEN(char, char) +// TEAM_FCOLLECT_DEF_GEN(long double, longdouble) +TEAM_FCOLLECT_DEF_GEN(signed char, schar) +TEAM_FCOLLECT_DEF_GEN(short, short) +TEAM_FCOLLECT_DEF_GEN(int, int) +TEAM_FCOLLECT_DEF_GEN(long, long) +TEAM_FCOLLECT_DEF_GEN(long long, longlong) +TEAM_FCOLLECT_DEF_GEN(unsigned char, uchar) +TEAM_FCOLLECT_DEF_GEN(unsigned short, ushort) +TEAM_FCOLLECT_DEF_GEN(unsigned int, uint) +TEAM_FCOLLECT_DEF_GEN(unsigned long, ulong) +TEAM_FCOLLECT_DEF_GEN(unsigned long long, ulonglong) + +/****************************************************************************** + * DEVICE TEST KERNEL + *****************************************************************************/ +template +__global__ void TeamFcollectTest(int loop, int skip, long long int *start_time, + long long int *end_time, T1 *source_buf, + T1 *dest_buf, int num_elems, + ShmemContextType ctx_type, + rocshmem_team_t *teams) { + __shared__ rocshmem_ctx_t ctx; + int wg_id = get_flat_grid_id(); + + rocshmem_wg_init(); + rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx); + + int n_pes = rocshmem_ctx_n_pes(ctx); + source_buf += wg_id * num_elems; + dest_buf += wg_id * num_elems * n_pes; + + __syncthreads(); + + for (int i = 0; i < loop + skip; i++) { + if (i == skip && hipThreadIdx_x == 0) { + start_time[wg_id] = wall_clock64(); + } + wg_team_fcollect(ctx, teams[wg_id], + dest_buf, // T* dest + source_buf, // const T* source + num_elems); // int nelement + } + + __syncthreads(); + + if (hipThreadIdx_x == 0) { + end_time[wg_id] = wall_clock64(); + } + + rocshmem_wg_ctx_destroy(&ctx); + rocshmem_wg_finalize(); +} + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +template +TeamFcollectTester::TeamFcollectTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + // Total number of elements in src buffer + int total_elems = (args.max_msg_size / sizeof(T1)) * args.num_wgs ; + int buff_size = total_elems * sizeof(T1); + + source_buf = (T1 *)rocshmem_malloc(buff_size); + dest_buf = (T1 *)rocshmem_malloc(buff_size * n_pes); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cout << "Error allocating memory from symmetric heap" << std::endl; + std::cout << "source: " << source_buf + << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + for (int i = 0; i < total_elems; ++i) { + source_buf[i] = static_cast('a' + my_pe); + } + } + else if constexpr (std::is_floating_point::value) { + for (int i = 0; i < total_elems; ++i) { + source_buf[i] = static_cast(3.14 + my_pe); + } + } + else if constexpr (std::is_integral::value) { + for (int i = 0; i < total_elems; i++) { + source_buf[i] = static_cast(my_pe); + } + } + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } + + CHECK_HIP(hipMalloc(&team_fcollect_world_dup, + sizeof(rocshmem_team_t) * num_teams)); +} + +template +TeamFcollectTester::~TeamFcollectTester() { + rocshmem_free(source_buf); + rocshmem_free(dest_buf); + CHECK_HIP(hipFree(team_fcollect_world_dup)); +} + +template +void TeamFcollectTester::preLaunchKernel() { + bw_factor = n_pes; + + for (int team_i = 0; team_i < num_teams; team_i++) { + team_fcollect_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_fcollect_world_dup[team_i]); + if (team_fcollect_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + std::cout << "Team " << team_i << " is invalid!" << std::endl; + abort(); + } + } +} + +template +void TeamFcollectTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, uint64_t size) { + size_t shared_bytes = 0; + + int num_elems = size / sizeof(T1); + + int my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + hipLaunchKernelGGL(TeamFcollectTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, start_time, end_time, + source_buf, dest_buf, num_elems, _shmem_context, + team_fcollect_world_dup); + + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; +} + +template +void TeamFcollectTester::postLaunchKernel() { + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_fcollect_world_dup[team_i]); + } +} + +template +void TeamFcollectTester::resetBuffers(uint64_t size) { + int num_elems = (size / sizeof(T1)); + int buff_size = num_elems * sizeof(T1) * args.num_wgs * n_pes; + + memset(dest_buf, -1, buff_size); +} + +template +void TeamFcollectTester::verifyResults(uint64_t size) { + + int num_elems = size / sizeof(T1); + int idx = 0; + T1 expected; + + for(int wg_id = 0; wg_id < args.num_wgs; wg_id++) { + for(int pe = 0; pe < n_pes; pe++) { + for(int i = 0; i < num_elems; i++) { + idx = (wg_id * n_pes + pe) * num_elems + i; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + expected = static_cast('a' + pe); + } + else if constexpr (std::is_floating_point::value) { + expected = static_cast(3.14 + pe); + } + else if constexpr (std::is_integral::value) { + expected = pe; + } + if (dest_buf[idx] != expected) { + std::cerr << "Data validation error at idx " << idx << std::endl; + std::cerr << "PE " << my_pe << " Got " << dest_buf[idx] + << ", Expected " << expected << std::endl; + exit(-1); + } + } + } + } +} diff --git a/projects/rocshmem/tests/functional_tests/fcollect_tester.hpp b/projects/rocshmem/tests/functional_tests/team_fcollect_tester.hpp similarity index 82% rename from projects/rocshmem/tests/functional_tests/fcollect_tester.hpp rename to projects/rocshmem/tests/functional_tests/team_fcollect_tester.hpp index d7cb73a860..c695069c7a 100644 --- a/projects/rocshmem/tests/functional_tests/fcollect_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/team_fcollect_tester.hpp @@ -28,16 +28,16 @@ #include "tester.hpp" +using namespace rocshmem; + /************* ***************************************************************** * HOST TESTER CLASS *****************************************************************************/ template -class FcollectTester : public Tester { +class TeamFcollectTester : public Tester { public: - explicit FcollectTester( - TesterArguments args, std::function f1, - std::function(const T1 &, T1)> f2); - virtual ~FcollectTester(); + explicit TeamFcollectTester(TesterArguments args); + virtual ~TeamFcollectTester(); protected: virtual void resetBuffers(uint64_t size) override; @@ -55,10 +55,16 @@ class FcollectTester : public Tester { T1 *dest_buf; private: - std::function init_buf; - std::function(const T1 &, T1)> verify_buf; + int my_pe = 0; + int n_pes = 0; + /** + * This constant should equal ROCSHMEM_MAX_NUM_TEAMS - 1. + * The default value for the maximum number of teams is 40. + */ + int num_teams = 39; + rocshmem_team_t *team_fcollect_world_dup; }; -#include "fcollect_tester.cpp" +#include "team_fcollect_tester.cpp" #endif diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 4a7ae9327f..223d7d9d5b 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -30,14 +30,12 @@ #include #include -#include "alltoall_tester.hpp" #include "amo_bitwise_tester.hpp" #include "amo_extended_tester.hpp" #include "amo_standard_tester.hpp" #include "barrier_all_tester.hpp" #include "empty_tester.hpp" #include "extended_primitives.hpp" -#include "fcollect_tester.hpp" #include "ping_all_tester.hpp" #include "ping_pong_tester.hpp" #include "primitive_mr_tester.hpp" @@ -47,9 +45,11 @@ #include "signaling_operations_tester.hpp" #include "swarm_tester.hpp" #include "sync_tester.hpp" +#include "team_alltoall_tester.hpp" #include "team_broadcast_tester.hpp" #include "team_ctx_infra_tester.hpp" #include "team_ctx_primitive_tester.hpp" +#include "team_fcollect_tester.hpp" #include "team_reduction_tester.hpp" #include "wave_level_primitives.hpp" @@ -162,85 +162,37 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) { std::cout << "Team Broadcast Test ###" << std::endl; } - testers.push_back(new TeamBroadcastTester( - args, - [](long& f1, long& f2) { - f1 = 1; - f2 = 2; - }, - [rank](long v) { - long expected_val; - /** - * The verification routine here requires that the - * PE_root value is 0 which denotes that the - * sending processing element is rank 0. - * - * The difference in expected values arises from - * the specification for broadcast where the - * PE_root processing element does not copy the - * contents from its own source to dest during - * the broadcast. - */ - if (rank == 0) { - expected_val = 2; - } else { - expected_val = 1; - } - - return (v == expected_val) - ? std::make_pair(true, "") - : std::make_pair( - false, "Rank " + std::to_string(rank) + ", Got " + - std::to_string(v) + ", Expect " + - std::to_string(expected_val)); - })); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); + testers.push_back(new TeamBroadcastTester(args)); return testers; - case AllToAllTestType: + case TeamAllToAllTestType: if (rank == 0) { std::cout << "Alltoall Test ###" << std::endl; } - testers.push_back(new AlltoallTester( - args, - [rank](int64_t& f1, int64_t& f2, int64_t dest_pe) { - const long SRC_SHIFT = 16; - // Make value for each src, dst pair unique - // by shifting src by SRC_SHIFT bits - f1 = (rank << SRC_SHIFT) + dest_pe; - f2 = -1; - }, - [rank](int64_t v, int64_t src_pe) { - const long SRC_SHIFT = 16; - // See if we obtained unique value - long expected_val = (src_pe << SRC_SHIFT) + rank; - - return (v == expected_val) - ? std::make_pair(true, "") - : std::make_pair( - false, "Rank " + std::to_string(rank) + ", Got " + - std::to_string(v) + ", Expect " + - std::to_string(expected_val)); - })); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); + testers.push_back(new TeamAlltoallTester(args)); return testers; - case FCollectTestType: + case TeamFCollectTestType: if (rank == 0) { std::cout << "Fcollect Test ###" << std::endl; } - testers.push_back(new FcollectTester( - args, - [rank](int64_t& f1, int64_t& f2) { - f1 = rank; - f2 = -1; - }, - [rank](int64_t v, int64_t src_pe) { - int64_t expected_val = src_pe; - - return (v == expected_val) - ? std::make_pair(true, "") - : std::make_pair( - false, "Rank " + std::to_string(rank) + ", Got " + - std::to_string(v) + ", Expect " + - std::to_string(expected_val)); - })); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); + testers.push_back(new TeamFcollectTester(args)); return testers; case AMO_FAddTestType: if (rank == 0) std::cout << "AMO Fetch_Add ###" << std::endl; @@ -525,7 +477,7 @@ bool Tester::peLaunchesKernel() { */ is_launcher = is_launcher || (_type == TeamReductionTestType) || (_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) || - (_type == AllToAllTestType) || (_type == FCollectTestType) || + (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || (_type == PingPongTestType) || (_type == BarrierAllTestType) || (_type == SyncTestType) || (_type == SyncAllTestType) || (_type == RandomAccessTestType) || (_type == PingAllTestType); diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index cb5c1ab333..29308e731f 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -53,8 +53,8 @@ enum TestType { SyncAllTestType = 16, SyncTestType = 17, CollectTestType = 18, - FCollectTestType = 19, - AllToAllTestType = 20, + TeamFCollectTestType = 19, + TeamAllToAllTestType = 20, AllToAllsTestType = 21, ShmemPtrTestType = 22, PTestType = 23, diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index a9fe9d0482..086c68450c 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -85,6 +85,7 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case AMO_IncTestType: case AMO_FetchTestType: case BarrierAllTestType: + case SyncAllTestType: case SyncTestType: case ShmemPtrTestType: min_msg_size = 8; @@ -97,6 +98,11 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case RandomAccessTestType: min_msg_size = 4; break; + case TeamFCollectTestType: + case TeamAllToAllTestType: + case TeamBroadcastTestType: + min_msg_size = 8; + break; case TeamCtxInfraTestType: max_msg_size = min_msg_size; break; @@ -137,8 +143,8 @@ void TesterArguments::get_rocshmem_arguments() { TestType type = (TestType)algorithm; if ((type != BarrierAllTestType) && (type != SyncAllTestType) && - (type != SyncTestType) && (type != AllToAllTestType) && - (type != FCollectTestType) && (type != TeamReductionTestType) && + (type != SyncTestType) && (type != TeamAllToAllTestType) && + (type != TeamFCollectTestType) && (type != TeamReductionTestType) && (type != TeamBroadcastTestType) && (type != PingAllTestType)) { if (numprocs != 2) { if (myid == 0) {