diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 92c4952f80..57c73e0c2a 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -26,61 +26,60 @@ declare -A TEST_NUMBERS=( ["getnbi"]="1" ["put"]="2" ["putnbi"]="3" - ["getswarm"]="4" - ["amo_fadd"]="5" - ["amo_finc"]="6" - ["amo_fetch"]="7" - ["amo_fcswap"]="8" - ["amo_add"]="9" - ["amo_inc"]="10" - ["amo_cswap"]="11" - ["init"]="12" - ["pingpong"]="13" - ["randomaccess"]="14" - ["barrierall"]="15" - ["syncall"]="16" - ["sync"]="17" - ["collect"]="18" - ["fcollect"]="19" - ["alltoall"]="20" - ["alltoalls"]="21" - ["shmemptr"]="22" - ["p"]="23" - ["g"]="24" - ["wgget"]="25" - ["wggetnbi"]="26" - ["wgput"]="27" - ["wgputnbi"]="28" - ["waveget"]="29" - ["wavegetnbi"]="30" - ["waveput"]="31" - ["waveputnbi"]="32" - ["teambroadcast"]="33" - ["teamreduction"]="34" - ["teamctxget"]="35" - ["teamctxgetnbi"]="36" - ["teamctxput"]="37" - ["teamctxputnbi"]="38" - ["teamctxinfra"]="39" - ["putnbimr"]="40" - ["amo_set"]="41" - ["amo_swap"]="42" - ["amo_fetchand"]="43" - ["amo_fetchor"]="44" - ["amo_fetchxor"]="45" - ["amo_and"]="46" - ["amo_or"]="47" - ["amo_xor"]="48" - ["pingall"]="49" - ["putsignal"]="50" - ["wgputsignal"]="51" - ["waveputsignal"]="52" - ["putsignalnbi"]="53" - ["wgputsignalnbi"]="54" - ["waveputsignalnbi"]="55" - ["signalfetch"]="56" - ["wgsignalfetch"]="57" - ["wavesignalfetch"]="58" + ["amo_fadd"]="4" + ["amo_finc"]="5" + ["amo_fetch"]="6" + ["amo_fcswap"]="7" + ["amo_add"]="8" + ["amo_inc"]="9" + ["amo_cswap"]="10" + ["init"]="11" + ["pingpong"]="12" + ["randomaccess"]="13" + ["barrierall"]="14" + ["syncall"]="15" + ["sync"]="16" + ["collect"]="17" + ["fcollect"]="18" + ["alltoall"]="19" + ["alltoalls"]="20" + ["shmemptr"]="21" + ["p"]="22" + ["g"]="23" + ["wgget"]="24" + ["wggetnbi"]="25" + ["wgput"]="26" + ["wgputnbi"]="27" + ["waveget"]="28" + 
["wavegetnbi"]="29" + ["waveput"]="30" + ["waveputnbi"]="31" + ["teambroadcast"]="32" + ["teamreduction"]="33" + ["teamctxget"]="34" + ["teamctxgetnbi"]="35" + ["teamctxput"]="36" + ["teamctxputnbi"]="37" + ["teamctxinfra"]="38" + ["putnbimr"]="39" + ["amo_set"]="40" + ["amo_swap"]="41" + ["amo_fetchand"]="42" + ["amo_fetchor"]="43" + ["amo_fetchxor"]="44" + ["amo_and"]="45" + ["amo_or"]="46" + ["amo_xor"]="47" + ["pingall"]="48" + ["putsignal"]="49" + ["wgputsignal"]="50" + ["waveputsignal"]="51" + ["putsignalnbi"]="52" + ["wgputsignalnbi"]="53" + ["waveputsignalnbi"]="54" + ["signalfetch"]="55" + ["wgsignalfetch"]="56" + ["wavesignalfetch"]="57" ) ExecTest() { @@ -159,7 +158,8 @@ TestRMA() { ExecTest "waveput" 2 2 128 1048576 ExecTest "waveput" 2 16 128 8 - ExecTest "teamctxput" 2 1 1 1048576 + ExecTest "teamctxput" 2 4 128 1024 + ExecTest "teamctxput" 2 16 256 1024 ExecTest "get" 2 1 1 1048576 ExecTest "get" 2 1 1024 512 @@ -177,7 +177,8 @@ TestRMA() { ExecTest "waveget" 2 2 128 1048576 ExecTest "waveget" 2 16 128 8 - ExecTest "teamctxget" 2 1 1 1048576 + ExecTest "teamctxget" 2 4 128 1024 + ExecTest "teamctxget" 2 16 256 1024 ExecTest "g" 2 1 1 1048576 ExecTest "g" 2 1 1024 512 @@ -211,7 +212,8 @@ TestRMA() { ExecTest "waveputnbi" 2 2 128 1048576 ExecTest "waveputnbi" 2 16 128 8 - ExecTest "teamctxputnbi" 2 1 1 1048576 + ExecTest "teamctxputnbi" 2 4 128 1024 + ExecTest "teamctxputnbi" 2 16 256 1024 ExecTest "getnbi" 2 1 1 1048576 ExecTest "getnbi" 2 1 1024 512 @@ -229,7 +231,8 @@ TestRMA() { ExecTest "wavegetnbi" 2 2 128 1048576 ExecTest "wavegetnbi" 2 16 128 8 - ExecTest "teamctxgetnbi" 2 1 1 1048576 + ExecTest "teamctxgetnbi" 2 4 128 1024 + ExecTest "teamctxgetnbi" 2 16 256 1024 } TestAMO() { diff --git a/projects/rocshmem/tests/functional_tests/CMakeLists.txt b/projects/rocshmem/tests/functional_tests/CMakeLists.txt index fcf6a63ba3..6c991e989b 100644 --- a/projects/rocshmem/tests/functional_tests/CMakeLists.txt +++ 
b/projects/rocshmem/tests/functional_tests/CMakeLists.txt @@ -50,14 +50,13 @@ target_sources( amo_bitwise_tester.cpp amo_extended_tester.cpp amo_standard_tester.cpp - swarm_tester.cpp random_access_tester.cpp shmem_ptr_tester.cpp signaling_operations_tester.cpp signaling_operations_tester.hpp - extended_primitives.cpp + workgroup_primitives.cpp empty_tester.cpp - wave_level_primitives.cpp + wavefront_primitives.cpp ) ############################################################################### diff --git a/projects/rocshmem/tests/functional_tests/primitive_tester.cpp b/projects/rocshmem/tests/functional_tests/primitive_tester.cpp index 135733b2af..d98011bd39 100644 --- a/projects/rocshmem/tests/functional_tests/primitive_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/primitive_tester.cpp @@ -30,43 +30,65 @@ using namespace rocshmem; * DEVICE TEST KERNEL *****************************************************************************/ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size, TestType type, - ShmemContextType ctx_type) { + long long int *end_time, char *source, + char *dest, int size, TestType type, + ShmemContextType ctx_type, int wf_size) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); + int t_id = get_flat_block_id(); + int wf_id = t_id / wf_size; rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); + /** + * Shared array to capture the start time for each wavefront + * Max threads per block = 1024, wavefront size = 32 or 64 (passed in as wf_size) + * Maximum array size required = 1024/32 = 32 (sized for the smaller wavefront) + */ + __shared__ long long int wf_start_time[32]; + + /** + * Calculate start index for each thread within the grid (widened to 64 bits before multiplying) + */ + uint64_t offset = static_cast<uint64_t>(size) * get_flat_id(); + source += offset; + dest += offset; + for (int i = 0; i < loop + skip; i++) { if (i == skip) { - __syncthreads(); - start_time[wg_id] = wall_clock64(); + __syncthreads(); + // Ensures all RMA 
calls from the skip loops are completed + if(is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); + } + __syncthreads(); + // Capture the start time of each wavefront to identify the earliest one + wf_start_time[wf_id] = wall_clock64(); } switch (type) { case GetTestType: - rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_getmem(ctx, dest, source, size, 1); break; case GetNBITestType: - rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1); break; case PutTestType: - rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem(ctx, dest, source, size, 1); break; case PutNBITestType: - rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1); break; case PTestType: for (int s = 0; s < size; s++) { - char val = s_buf[s]; - rocshmem_ctx_char_p(ctx, &r_buf[s], val, 1); + char val = source[s]; + rocshmem_ctx_char_p(ctx, &dest[s], val, 1); } break; case GTestType: for (int s = 0; s < size; s++) { - char ret = rocshmem_ctx_char_g(ctx, &s_buf[s], 1); - r_buf[s] = ret; + char ret = rocshmem_ctx_char_g(ctx, &source[s], 1); + dest[s] = ret; } break; default: @@ -74,12 +96,28 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time, } } - rocshmem_ctx_quiet(ctx); + __syncthreads(); + if(is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); + } + /** + * End time of the last wavefront is recorded by overwriting + * the value previously set by earlier wavefronts. 
+ */ + end_time[wg_id] = wall_clock64(); + + // Find the earliest start time: thread 0 scans every wavefront slot so non-power-of-two wavefront counts are covered + int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1; + for (int i = 1; t_id == 0 && i < num_wfs; i++) { + if (wf_start_time[i] < wf_start_time[0]) { + wf_start_time[0] = wf_start_time[i]; + } + } __syncthreads(); - if (hipThreadIdx_x == 0) { - end_time[wg_id] = wall_clock64(); + if (t_id == 0) { + start_time[wg_id] = wf_start_time[0]; } rocshmem_wg_ctx_destroy(&ctx); @@ -90,18 +128,35 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time, * HOST TESTER CLASS METHODS *****************************************************************************/ PrimitiveTester::PrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size); - r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size); + size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs; + source = (char *)rocshmem_malloc(buff_size); + dest = (char *)rocshmem_malloc(buff_size); + + if (source == nullptr || dest == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << static_cast<const void *>(source) << ", dest: " << static_cast<const void *>(dest) << std::endl; + if (source) { + rocshmem_free(source); + } + if (dest) { + rocshmem_free(dest); + } + rocshmem_global_exit(1); + } + + for(size_t i = 0; i < buff_size; i++) { + source[i] = static_cast<char>('a' + i % 26); + } } PrimitiveTester::~PrimitiveTester() { - rocshmem_free(s_buf); - rocshmem_free(r_buf); + rocshmem_free(source); + rocshmem_free(dest); } void PrimitiveTester::resetBuffers(uint64_t size) { - memset(s_buf, '0', args.max_msg_size * args.wg_size); - memset(r_buf, '1', args.max_msg_size * args.wg_size); + size_t buff_size = size * args.wg_size * args.num_wgs; + memset(dest, '1', buff_size); } void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, @@ -109,11 +164,11 @@ void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int 
loop, size_t shared_bytes = 0; hipLaunchKernelGGL(PrimitiveTest, gridSize, blockSize, shared_bytes, stream, - loop, args.skip, start_time, end_time, s_buf, r_buf, - size, _type, _shmem_context); + loop, args.skip, start_time, end_time, source, dest, + size, _type, _shmem_context, wf_size); - num_msgs = (loop + args.skip) * gridSize.x; - num_timed_msgs = loop; + num_msgs = (loop + args.skip) * gridSize.x * blockSize.x; + num_timed_msgs = loop * gridSize.x * blockSize.x; } void PrimitiveTester::verifyResults(uint64_t size) { @@ -123,10 +178,12 @@ void PrimitiveTester::verifyResults(uint64_t size) { : 1; if (args.myid == check_id) { - for (uint64_t i = 0; i < size; i++) { - if (r_buf[i] != '0') { - fprintf(stderr, "Data validation error at idx %lu\n", i); - fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0'); + size_t buff_size = size * args.wg_size * args.num_wgs; + for (uint64_t i = 0; i < buff_size; i++) { + if (dest[i] != source[i]) { + std::cerr << "Data validation error at idx " << i << std::endl; + std::cerr << " Got " << dest[i] << ", Expected " + << source[i] << std::endl; exit(-1); } } diff --git a/projects/rocshmem/tests/functional_tests/primitive_tester.hpp b/projects/rocshmem/tests/functional_tests/primitive_tester.hpp index 14f98f0330..ea0df7440f 100644 --- a/projects/rocshmem/tests/functional_tests/primitive_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/primitive_tester.hpp @@ -41,8 +41,8 @@ class PrimitiveTester : public Tester { virtual void verifyResults(uint64_t size) override; - char *s_buf = nullptr; - char *r_buf = nullptr; + char *source = nullptr; + char *dest = nullptr; }; #endif diff --git a/projects/rocshmem/tests/functional_tests/swarm_tester.cpp b/projects/rocshmem/tests/functional_tests/swarm_tester.cpp deleted file mode 100644 index 831febbfa0..0000000000 --- a/projects/rocshmem/tests/functional_tests/swarm_tester.cpp +++ /dev/null @@ -1,95 +0,0 @@ 
-/****************************************************************************** - * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. 
- *****************************************************************************/ - -#include "swarm_tester.hpp" - -#include - -using namespace rocshmem; - -/****************************************************************************** - * DEVICE TEST KERNEL - *****************************************************************************/ -__global__ void GetSwarmTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size, ShmemContextType ctx_type) { - __shared__ rocshmem_ctx_t ctx; - int wg_id = get_flat_grid_id(); - - int provided; - rocshmem_wg_init_thread(ROCSHMEM_THREAD_MULTIPLE, &provided); - assert(provided == ROCSHMEM_THREAD_MULTIPLE); - - rocshmem_wg_ctx_create(ctx_type, &ctx); - - __syncthreads(); - - int index = hipThreadIdx_x * size; - - for (int i = 0; i < loop + skip; i++) { - if (i == skip) { - start_time[wg_id] = wall_clock64(); - } - rocshmem_ctx_getmem(ctx, &r_buf[index], &s_buf[index], size, 1); - - __syncthreads(); - } - - // atomicAdd((unsigned long long *)&timer[hipBlockIdx_x], - // rocshmem_timer() - start); - - end_time[wg_id] = wall_clock64(); - - rocshmem_wg_ctx_destroy(&ctx); - rocshmem_wg_finalize(); -} - -/****************************************************************************** - * HOST TESTER CLASS METHODS - *****************************************************************************/ -GetSwarmTester::GetSwarmTester(TesterArguments args) : PrimitiveTester(args) {} - -GetSwarmTester::~GetSwarmTester() {} - -void GetSwarmTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, - uint64_t size) { - size_t shared_bytes = 0; - - hipLaunchKernelGGL(GetSwarmTest, gridSize, blockSize, shared_bytes, stream, - loop, args.skip, start_time, end_time, s_buf, r_buf, size, - _shmem_context); - - num_msgs = (loop + args.skip) * gridSize.x * blockSize.x; - num_timed_msgs = loop * gridSize.x * blockSize.x; -} - -void GetSwarmTester::verifyResults(uint64_t size) { - if (args.myid == 0) { 
- for (uint64_t i = 0; i < size * args.wg_size; i++) { - if (r_buf[i] != '0') { - fprintf(stderr, "Data validation error at idx %lu\n", i); - fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0'); - exit(-1); - } - } - } -} diff --git a/projects/rocshmem/tests/functional_tests/swarm_tester.hpp b/projects/rocshmem/tests/functional_tests/swarm_tester.hpp deleted file mode 100644 index e9c1bf75f7..0000000000 --- a/projects/rocshmem/tests/functional_tests/swarm_tester.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. 
- *****************************************************************************/ - -#ifndef _SWARM_TESTER_HPP_ -#define _SWARM_TESTER_HPP_ - -#include "primitive_tester.hpp" - -/****************************************************************************** - * DEVICE TEST KERNEL - *****************************************************************************/ -__global__ void GetSwarmTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size); - -/****************************************************************************** - * HOST TESTER CLASS - *****************************************************************************/ -class GetSwarmTester : public PrimitiveTester { - public: - explicit GetSwarmTester(TesterArguments args); - virtual ~GetSwarmTester(); - - protected: - virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, - uint64_t size) override; - - virtual void verifyResults(uint64_t size) override; -}; - -#endif diff --git a/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.cpp b/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.cpp index 45c6549e9c..5633fb22ba 100644 --- a/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.cpp @@ -32,43 +32,83 @@ rocshmem_team_t team_primitive_world_dup; * DEVICE TEST KERNEL *****************************************************************************/ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size, TestType type, - ShmemContextType ctx_type, + long long int *end_time, char *source, + char *dest, int size, TestType type, + ShmemContextType ctx_type, int wf_size, rocshmem_team_t team) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); + int t_id = get_flat_block_id(); + int wf_id = t_id / wf_size; 
rocshmem_wg_init(); rocshmem_wg_team_create_ctx(team, ctx_type, &ctx); - if (hipThreadIdx_x == 0) { + /** + * Shared array to capture the start time for each wavefront + * Max threads per block = 1024, wavefront size = 32 or 64 (passed in as wf_size) + * Maximum array size required = 1024/32 = 32 (sized for the smaller wavefront) + */ + __shared__ long long int wf_start_time[32]; - for (int i = 0; i < loop + skip; i++) { - if (i == skip) { - start_time[wg_id] = wall_clock64(); - } - switch (type) { - case TeamCtxGetTestType: - rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1); - break; - case TeamCtxGetNBITestType: - rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1); - break; - case TeamCtxPutTestType: - rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1); - break; - case TeamCtxPutNBITestType: - rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1); - break; - default: - break; + /** + * Calculate start index for each thread within the grid (widened to 64 bits before multiplying) + */ + uint64_t offset = static_cast<uint64_t>(size) * get_flat_id(); + source += offset; + dest += offset; + + for (int i = 0; i < loop + skip; i++) { + if (i == skip) { + __syncthreads(); + // Ensures all RMA calls from the skip loops are completed + if(is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); } + __syncthreads(); + // Capture the start time of each wavefront to identify the earliest one + wf_start_time[wf_id] = wall_clock64(); } + switch (type) { + case TeamCtxGetTestType: + rocshmem_ctx_getmem(ctx, dest, source, size, 1); + break; + case TeamCtxGetNBITestType: + rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1); + break; + case TeamCtxPutTestType: + rocshmem_ctx_putmem(ctx, dest, source, size, 1); + break; + case TeamCtxPutNBITestType: + rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1); + break; + default: + break; + } + } + __syncthreads(); + if(is_thread_zero_in_block()) { rocshmem_ctx_quiet(ctx); + } - end_time[wg_id] = wall_clock64(); + /** + * End time of the last wavefront is recorded by overwriting + * the value previously set by earlier wavefronts. 
+ */ + end_time[wg_id] = wall_clock64(); + + // Find the earliest start time: thread 0 scans every wavefront slot so non-power-of-two wavefront counts are covered + int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1; + for (int i = 1; t_id == 0 && i < num_wfs; i++) { + if (wf_start_time[i] < wf_start_time[0]) { + wf_start_time[0] = wf_start_time[i]; + } + } + __syncthreads(); + + if (t_id == 0) { + start_time[wg_id] = wf_start_time[0]; } rocshmem_wg_ctx_destroy(&ctx); @@ -80,18 +120,35 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti *****************************************************************************/ TeamCtxPrimitiveTester::TeamCtxPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size); - r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size); + size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs; + source = (char *)rocshmem_malloc(buff_size); + dest = (char *)rocshmem_malloc(buff_size); + + if (source == nullptr || dest == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << static_cast<const void *>(source) << ", dest: " << static_cast<const void *>(dest) << std::endl; + if (source) { + rocshmem_free(source); + } + if (dest) { + rocshmem_free(dest); + } + rocshmem_global_exit(1); + } + + for(size_t i = 0; i < buff_size; i++) { + source[i] = static_cast<char>('a' + i % 26); + } } TeamCtxPrimitiveTester::~TeamCtxPrimitiveTester() { - rocshmem_free(s_buf); - rocshmem_free(r_buf); + rocshmem_free(source); + rocshmem_free(dest); } void TeamCtxPrimitiveTester::resetBuffers(uint64_t size) { - memset(s_buf, '0', args.max_msg_size * args.wg_size); - memset(r_buf, '1', args.max_msg_size * args.wg_size); + size_t buff_size = size * args.wg_size * args.num_wgs; + memset(dest, '1', buff_size); } void TeamCtxPrimitiveTester::preLaunchKernel() { @@ -107,12 +164,12 @@ void TeamCtxPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, size_t shared_bytes = 0; hipLaunchKernelGGL(TeamCtxPrimitiveTest, gridSize, 
blockSize, shared_bytes, - stream, loop, args.skip, start_time, end_time, s_buf, - r_buf, size, _type, _shmem_context, + stream, loop, args.skip, start_time, end_time, source, + dest, size, _type, _shmem_context, wf_size, team_primitive_world_dup); - num_msgs = (loop + args.skip) * gridSize.x; - num_timed_msgs = loop * gridSize.x; + num_msgs = (loop + args.skip) * gridSize.x * blockSize.x; + num_timed_msgs = loop * gridSize.x * blockSize.x; } void TeamCtxPrimitiveTester::postLaunchKernel() { @@ -124,10 +181,12 @@ void TeamCtxPrimitiveTester::verifyResults(uint64_t size) { (_type == TeamCtxGetTestType || _type == TeamCtxGetNBITestType) ? 0 : 1; if (args.myid == check_id) { - for (uint64_t i = 0; i < size; i++) { - if (r_buf[i] != '0') { - fprintf(stderr, "Data validation error at idx %lu\n", i); - fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0'); + size_t buff_size = size * args.wg_size * args.num_wgs; + for (uint64_t i = 0; i < buff_size; i++) { + if (dest[i] != source[i]) { + std::cerr << "Data validation error at idx " << i << std::endl; + std::cerr << " Got " << dest[i] << ", Expected " + << source[i] << std::endl; exit(-1); } } diff --git a/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.hpp b/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.hpp index fa4c6ac369..59d917f0da 100644 --- a/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/team_ctx_primitive_tester.hpp @@ -45,8 +45,8 @@ class TeamCtxPrimitiveTester : public Tester { virtual void verifyResults(uint64_t size) override; - char *s_buf = nullptr; - char *r_buf = nullptr; + char *source = nullptr; + char *dest = nullptr; }; #endif diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 223d7d9d5b..65e0ea11b6 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ 
b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -35,7 +35,6 @@ #include "amo_standard_tester.hpp" #include "barrier_all_tester.hpp" #include "empty_tester.hpp" -#include "extended_primitives.hpp" #include "ping_all_tester.hpp" #include "ping_pong_tester.hpp" #include "primitive_mr_tester.hpp" @@ -43,7 +42,6 @@ #include "random_access_tester.hpp" #include "shmem_ptr_tester.hpp" #include "signaling_operations_tester.hpp" -#include "swarm_tester.hpp" #include "sync_tester.hpp" #include "team_alltoall_tester.hpp" #include "team_broadcast_tester.hpp" @@ -51,7 +49,8 @@ #include "team_ctx_primitive_tester.hpp" #include "team_fcollect_tester.hpp" #include "team_reduction_tester.hpp" -#include "wave_level_primitives.hpp" +#include "wavefront_primitives.hpp" +#include "workgroup_primitives.hpp" Tester::Tester(TesterArguments args) : args(args) { _type = (TestType)args.algorithm; @@ -66,6 +65,16 @@ Tester::Tester(TesterArguments args) : args(args) { CHECK_HIP(hipDeviceGetAttribute(&wall_clk_rate, hipDeviceAttributeWallClockRate, device_id)); num_timers = args.num_wgs; + switch (_type) { + case WAVEGetTestType: + case WAVEGetNBITestType: + case WAVEPutTestType: + case WAVEPutNBITestType: + num_timers = args.num_wgs * num_warps; + break; + default: + break; + } CHECK_HIP(hipMalloc((void**)&timer, sizeof(long long int) * num_timers)); CHECK_HIP(hipMalloc((void**)&start_time, sizeof(long long int) * num_timers)); CHECK_HIP(hipMalloc((void**)&end_time, sizeof(long long int) * num_timers)); @@ -137,10 +146,6 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "G Test ###" << std::endl; testers.push_back(new PrimitiveTester(args)); return testers; - case GetSwarmTestType: - if (rank == 0) std::cout << "Get Swarm ###" << std::endl; - testers.push_back(new GetSwarmTester(args)); - return testers; case TeamReductionTestType: if (rank == 0) std::cout << "All-to-All Team-based Reduction ###" << std::endl; @@ -309,22 +314,22 @@ std::vector 
Tester::create(TesterArguments args) { case WGGetTestType: if (rank == 0) std::cout << "Blocking WG level Gets ###" << std::endl; - testers.push_back(new ExtendedPrimitiveTester(args)); + testers.push_back(new WorkGroupPrimitiveTester(args)); return testers; case WGGetNBITestType: if (rank == 0) std::cout << "Non-Blocking WG level Gets ###" << std::endl; - testers.push_back(new ExtendedPrimitiveTester(args)); + testers.push_back(new WorkGroupPrimitiveTester(args)); return testers; case WGPutTestType: if (rank == 0) std::cout << "Blocking WG level Puts ###" << std::endl; - testers.push_back(new ExtendedPrimitiveTester(args)); + testers.push_back(new WorkGroupPrimitiveTester(args)); return testers; case WGPutNBITestType: if (rank == 0) std::cout << "Non-Blocking WG level Puts ###" << std::endl; - testers.push_back(new ExtendedPrimitiveTester(args)); + testers.push_back(new WorkGroupPrimitiveTester(args)); return testers; case PutNBIMRTestType: if (rank == 0) @@ -334,22 +339,22 @@ std::vector Tester::create(TesterArguments args) { case WAVEGetTestType: if (rank == 0) std::cout << "Blocking WAVE level Gets ###" << std::endl; - testers.push_back(new WaveLevelPrimitiveTester(args)); + testers.push_back(new WaveFrontPrimitiveTester(args)); return testers; case WAVEGetNBITestType: if (rank == 0) std::cout << "Non-Blocking WAVE level Gets ###" << std::endl; - testers.push_back(new WaveLevelPrimitiveTester(args)); + testers.push_back(new WaveFrontPrimitiveTester(args)); return testers; case WAVEPutTestType: if (rank == 0) std::cout << "Blocking WAVE level Puts ###" << std::endl; - testers.push_back(new WaveLevelPrimitiveTester(args)); + testers.push_back(new WaveFrontPrimitiveTester(args)); return testers; case WAVEPutNBITestType: if (rank == 0) std::cout << "Non-Blocking WAVE level Puts ###" << std::endl; - testers.push_back(new WaveLevelPrimitiveTester(args)); + testers.push_back(new WaveFrontPrimitiveTester(args)); return testers; case PutSignalTestType: if (rank == 0) 
std::cout << "Putmem Signal ###" << std::endl; @@ -495,18 +500,21 @@ void Tester::print(uint64_t size) { */ uint64_t total_size = size * num_timed_msgs; double timer_avg = timerAvgInMicroseconds(); - double latency_avg = timer_avg / num_timed_msgs; - double avg_msg_rate = num_timed_msgs / (timer_avg / 1e6); + + double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time); + double time_s = time_us / 1e6; + + double latency_avg = time_us / num_timed_msgs; + + double avg_msg_rate = num_timed_msgs / time_s; + + double bandwidth_avg_gbs = + static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30); float total_kern_time_ms; CHECK_HIP(hipEventElapsedTime(&total_kern_time_ms, start_event, stop_event)); float total_kern_time_s = total_kern_time_ms / 1000; - double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time); - double time_s = time_us / 1e6; - double bandwidth_avg_gbs = - static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30); - int field_width = 20; int float_precision = 2; diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index 29308e731f..80a107668b 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -38,61 +38,60 @@ enum TestType { GetNBITestType = 1, PutTestType = 2, PutNBITestType = 3, - GetSwarmTestType = 4, - AMO_FAddTestType = 5, - AMO_FIncTestType = 6, - AMO_FetchTestType = 7, - AMO_FCswapTestType = 8, - AMO_AddTestType = 9, - AMO_IncTestType = 10, - AMO_CswapTestType = 11, - InitTestType = 12, - PingPongTestType = 13, - RandomAccessTestType = 14, - BarrierAllTestType = 15, - SyncAllTestType = 16, - SyncTestType = 17, - CollectTestType = 18, - TeamFCollectTestType = 19, - TeamAllToAllTestType = 20, - AllToAllsTestType = 21, - ShmemPtrTestType = 22, - PTestType = 23, - GTestType = 24, - WGGetTestType = 25, - WGGetNBITestType = 26, - WGPutTestType = 27, - WGPutNBITestType = 28, - 
WAVEGetTestType = 29, - WAVEGetNBITestType = 30, - WAVEPutTestType = 31, - WAVEPutNBITestType = 32, - TeamBroadcastTestType = 33, - TeamReductionTestType = 34, - TeamCtxGetTestType = 35, - TeamCtxGetNBITestType = 36, - TeamCtxPutTestType = 37, - TeamCtxPutNBITestType = 38, - TeamCtxInfraTestType = 39, - PutNBIMRTestType = 40, - AMO_SetTestType = 41, - AMO_SwapTestType = 42, - AMO_FetchAndTestType = 43, - AMO_FetchOrTestType = 44, - AMO_FetchXorTestType = 45, - AMO_AndTestType = 46, - AMO_OrTestType = 47, - AMO_XorTestType = 48, - PingAllTestType = 49, - PutSignalTestType = 50, - WGPutSignalTestType = 51, - WAVEPutSignalTestType = 52, - PutSignalNBITestType = 53, - WGPutSignalNBITestType = 54, - WAVEPutSignalNBITestType = 55, - SignalFetchTestType = 56, - WGSignalFetchTestType = 57, - WAVESignalFetchTestType = 58, + AMO_FAddTestType = 4, + AMO_FIncTestType = 5, + AMO_FetchTestType = 6, + AMO_FCswapTestType = 7, + AMO_AddTestType = 8, + AMO_IncTestType = 9, + AMO_CswapTestType = 10, + InitTestType = 11, + PingPongTestType = 12, + RandomAccessTestType = 13, + BarrierAllTestType = 14, + SyncAllTestType = 15, + SyncTestType = 16, + CollectTestType = 17, + TeamFCollectTestType = 18, + TeamAllToAllTestType = 19, + AllToAllsTestType = 20, + ShmemPtrTestType = 21, + PTestType = 22, + GTestType = 23, + WGGetTestType = 24, + WGGetNBITestType = 25, + WGPutTestType = 26, + WGPutNBITestType = 27, + WAVEGetTestType = 28, + WAVEGetNBITestType = 29, + WAVEPutTestType = 30, + WAVEPutNBITestType = 31, + TeamBroadcastTestType = 32, + TeamReductionTestType = 33, + TeamCtxGetTestType = 34, + TeamCtxGetNBITestType = 35, + TeamCtxPutTestType = 36, + TeamCtxPutNBITestType = 37, + TeamCtxInfraTestType = 38, + PutNBIMRTestType = 39, + AMO_SetTestType = 40, + AMO_SwapTestType = 41, + AMO_FetchAndTestType = 42, + AMO_FetchOrTestType = 43, + AMO_FetchXorTestType = 44, + AMO_AndTestType = 45, + AMO_OrTestType = 46, + AMO_XorTestType = 47, + PingAllTestType = 48, + PutSignalTestType = 49, + 
WGPutSignalTestType = 50, + WAVEPutSignalTestType = 51, + PutSignalNBITestType = 52, + WGPutSignalNBITestType = 53, + WAVEPutSignalNBITestType = 54, + SignalFetchTestType = 55, + WGSignalFetchTestType = 56, + WAVESignalFetchTestType = 57, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 086c68450c..c7258b9f61 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -109,16 +109,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case PutNBIMRTestType: min_msg_size = max_msg_size; break; - case WAVEGetTestType: - case WAVEGetNBITestType: - case WAVEPutTestType: - case WAVEPutNBITestType: - case WGGetTestType: - case WGGetNBITestType: - case WGPutTestType: - case WGPutNBITestType: - min_msg_size = 4; - break; default: break; } diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp index 6c1cda15a3..4a6d82e728 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp @@ -27,6 +27,7 @@ #include #include #include +#include class TesterArguments { public: diff --git a/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp similarity index 55% rename from projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp rename to projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp index d935834d24..a0a95f0b10 100644 --- a/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp @@ -20,7 +20,7 @@ * IN THE SOFTWARE. 
*****************************************************************************/ -#include "wave_level_primitives.hpp" +#include "wavefront_primitives.hpp" #include @@ -31,54 +31,56 @@ using namespace rocshmem; /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ -__global__ void WaveLevelPrimitiveTest(int loop, int skip, +__global__ void WaveFrontPrimitiveTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size, TestType type, - ShmemContextType ctx_type, int wf_size) { + long long int *end_time, char *source, + char *dest, int size, TestType type, + ShmemContextType ctx_type, + int wf_size) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); - /** - * Calculate start index for each wavefront for tiled version - * If the number of wavefronts is greater than 1, this kernel performs a - * tiled functional test - */ + // Calculate start index for each wavefront int wf_id = get_flat_block_id() / wf_size; - int wg_offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size); - int idx = wf_id * size + wg_offset; - s_buf += idx; - r_buf += idx; + int wg_offset = wg_id * ((get_flat_block_size() - 1 ) / wf_size + 1); + int idx = wf_id + wg_offset; + int offset = size * idx; + source += offset; + dest += offset; for (int i = 0; i < loop + skip; i++) { if (i == skip) { - start_time[wg_id] = wall_clock64(); + // Ensures all RMA calls from the skip loops are completed + if(is_thread_zero_in_wave()) { + rocshmem_ctx_quiet(ctx); + } + __syncthreads(); + start_time[idx] = wall_clock64(); } switch (type) { case WAVEGetTestType: - rocshmem_ctx_getmem_wave(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_getmem_wave(ctx, dest, source, size, 1); break; case WAVEGetNBITestType: - rocshmem_ctx_getmem_nbi_wave(ctx, r_buf, s_buf, 
size, 1); + rocshmem_ctx_getmem_nbi_wave(ctx, dest, source, size, 1); break; case WAVEPutTestType: - rocshmem_ctx_putmem_wave(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem_wave(ctx, dest, source, size, 1); break; case WAVEPutNBITestType: - rocshmem_ctx_putmem_nbi_wave(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem_nbi_wave(ctx, dest, source, size, 1); break; default: break; } } - rocshmem_ctx_quiet(ctx); - - if (hipThreadIdx_x == 0) { - end_time[hipBlockIdx_x] = wall_clock64(); + if (is_thread_zero_in_wave()) { + rocshmem_ctx_quiet(ctx); + end_time[idx] = wall_clock64(); } rocshmem_wg_ctx_destroy(&ctx); @@ -88,48 +90,64 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip, /****************************************************************************** * HOST TESTER CLASS METHODS *****************************************************************************/ -WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args) +WaveFrontPrimitiveTester::WaveFrontPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = static_cast( - rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps)); - r_buf = static_cast( - rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps)); + size_t buff_size = args.max_msg_size * args.num_wgs * num_warps; + source = (char *)rocshmem_malloc(buff_size); + dest = (char *)rocshmem_malloc(buff_size); + + if (source == nullptr || dest == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source << ", dest: " << dest << std::endl; + if (source) { + rocshmem_free(source); + } + if (dest) { + rocshmem_free(dest); + } + rocshmem_global_exit(1); + } + + for(size_t i = 0; i < buff_size; i++) { + source[i] = static_cast('a' + i % 26); + } } -WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() { - rocshmem_free(s_buf); - rocshmem_free(r_buf); +WaveFrontPrimitiveTester::~WaveFrontPrimitiveTester() { + rocshmem_free(source); + 
rocshmem_free(dest); } -void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) { - num_elems = (size * args.num_wgs * num_warps) / sizeof(int); - std::iota(s_buf, s_buf + num_elems, 0); - memset(r_buf, 0, size * args.num_wgs * num_warps); +void WaveFrontPrimitiveTester::resetBuffers(uint64_t size) { + size_t buff_size = size * args.num_wgs * num_warps; + memset(dest, '1', buff_size); } -void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, +void WaveFrontPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes, + hipLaunchKernelGGL(WaveFrontPrimitiveTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, start_time, end_time, - (char*)s_buf, (char*)r_buf, size, _type, _shmem_context, + source, dest, size, _type, _shmem_context, wf_size); num_msgs = (loop + args.skip) * gridSize.x * num_warps; num_timed_msgs = loop * gridSize.x * num_warps; } -void WaveLevelPrimitiveTester::verifyResults(uint64_t size) { +void WaveFrontPrimitiveTester::verifyResults(uint64_t size) { int check_id = (_type == WAVEGetTestType || _type == WAVEGetNBITestType) ? 
0 : 1; if (args.myid == check_id) { - for (int i = 0; i < num_elems; i++) { - if (r_buf[i] != i) { - fprintf(stderr, "Data validation error at idx %d\n", i); - fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i); + size_t buff_size = size * args.num_wgs * num_warps; + for (size_t i = 0; i < buff_size; i++) { + if (dest[i] != source[i]) { + std::cerr << "Data validation error at idx " << i << std::endl; + std::cerr << " Got " << dest[i] << ", Expected " + << source[i] << std::endl; exit(-1); } } diff --git a/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp b/projects/rocshmem/tests/functional_tests/wavefront_primitives.hpp similarity index 88% rename from projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp rename to projects/rocshmem/tests/functional_tests/wavefront_primitives.hpp index af65ac0a30..e1c2923087 100644 --- a/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp +++ b/projects/rocshmem/tests/functional_tests/wavefront_primitives.hpp @@ -24,15 +24,14 @@ #define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_ #include "tester.hpp" -#include "../src/util.hpp" /****************************************************************************** * HOST TESTER CLASS *****************************************************************************/ -class WaveLevelPrimitiveTester : public Tester { +class WaveFrontPrimitiveTester : public Tester { public: - explicit WaveLevelPrimitiveTester(TesterArguments args); - virtual ~WaveLevelPrimitiveTester(); + explicit WaveFrontPrimitiveTester(TesterArguments args); + virtual ~WaveFrontPrimitiveTester(); protected: virtual void resetBuffers(uint64_t size) override; @@ -42,9 +41,8 @@ class WaveLevelPrimitiveTester : public Tester { virtual void verifyResults(uint64_t size) override; - int *s_buf = nullptr; - int *r_buf = nullptr; - int num_elems = 0; + char *source = nullptr; + char *dest = nullptr; }; #endif diff --git a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp 
b/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp similarity index 57% rename from projects/rocshmem/tests/functional_tests/extended_primitives.cpp rename to projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp index 26a3ec50a1..8d455e1001 100644 --- a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp @@ -20,7 +20,7 @@ * IN THE SOFTWARE. *****************************************************************************/ -#include "extended_primitives.hpp" +#include "workgroup_primitives.hpp" #include @@ -31,51 +31,51 @@ using namespace rocshmem; /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ -__global__ void ExtendedPrimitiveTest(int loop, int skip, +__global__ void WorkGroupPrimitiveTest(int loop, int skip, long long int *start_time, - long long int *end_time, char *s_buf, - char *r_buf, int size, TestType type, + long long int *end_time, char *source, + char *dest, int size, TestType type, ShmemContextType ctx_type) { __shared__ rocshmem_ctx_t ctx; int wg_id = get_flat_grid_id(); rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); - /** - * Calculate start index for each work group for tiled version - * If the number of work groups is greater than 1, this kernel performs a - * tiled functional test - */ - uint64_t idx = size * get_flat_grid_id(); - s_buf += idx; - r_buf += idx; + // Calculate start index for each work group + uint64_t offset = size * wg_id; + source += offset; + dest += offset; for (int i = 0; i < loop + skip; i++) { if (i == skip) { + // Ensures all RMA calls from the skip loops are completed + if (is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); + } + __syncthreads(); start_time[wg_id] = wall_clock64(); } switch (type) { case WGGetTestType: - rocshmem_ctx_getmem_wg(ctx, r_buf, s_buf, 
size, 1); + rocshmem_ctx_getmem_wg(ctx, dest, source, size, 1); break; case WGGetNBITestType: - rocshmem_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_getmem_nbi_wg(ctx, dest, source, size, 1); break; case WGPutTestType: - rocshmem_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem_wg(ctx, dest, source, size, 1); break; case WGPutNBITestType: - rocshmem_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1); + rocshmem_ctx_putmem_nbi_wg(ctx, dest, source, size, 1); break; default: break; } } - rocshmem_ctx_quiet(ctx); - - if (hipThreadIdx_x == 0) { + if (is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); end_time[wg_id] = wall_clock64(); } @@ -86,45 +86,63 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, /****************************************************************************** * HOST TESTER CLASS METHODS *****************************************************************************/ -ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args) +WorkGroupPrimitiveTester::WorkGroupPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = static_cast(rocshmem_malloc(args.max_msg_size * args.num_wgs)); - r_buf = static_cast(rocshmem_malloc(args.max_msg_size * args.num_wgs)); + size_t buff_size = args.max_msg_size * args.num_wgs; + source = (char *)rocshmem_malloc(buff_size); + dest = (char *)rocshmem_malloc(buff_size); + + if (source == nullptr || dest == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source << ", dest: " << dest << std::endl; + if (source) { + rocshmem_free(source); + } + if (dest) { + rocshmem_free(dest); + } + rocshmem_global_exit(1); + } + + for(size_t i = 0; i < buff_size; i++) { + source[i] = static_cast('a' + i % 26); + } } -ExtendedPrimitiveTester::~ExtendedPrimitiveTester() { - rocshmem_free(s_buf); - rocshmem_free(r_buf); +WorkGroupPrimitiveTester::~WorkGroupPrimitiveTester() { + rocshmem_free(source); + 
rocshmem_free(dest); } -void ExtendedPrimitiveTester::resetBuffers(uint64_t size) { - num_elems = (size * args.num_wgs) / sizeof(int); - std::iota(s_buf, s_buf + num_elems, 0); - memset(r_buf, 0, size * args.num_wgs); +void WorkGroupPrimitiveTester::resetBuffers(uint64_t size) { + size_t buff_size = size * args.num_wgs; + memset(dest, '1', buff_size); } -void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, +void WorkGroupPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, + hipLaunchKernelGGL(WorkGroupPrimitiveTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, start_time, end_time, - (char*)s_buf, (char*)r_buf, size, _type, _shmem_context); + source, dest, size, _type, _shmem_context); num_msgs = (loop + args.skip) * gridSize.x; num_timed_msgs = loop * gridSize.x; } -void ExtendedPrimitiveTester::verifyResults(uint64_t size) { +void WorkGroupPrimitiveTester::verifyResults(uint64_t size) { int check_id = (_type == WGGetTestType || _type == WGGetNBITestType) ? 
0 : 1; if (args.myid == check_id) { - for (int i = 0; i < num_elems; i++) { - if (r_buf[i] != i) { - fprintf(stderr, "Data validation error at idx %d\n", i); - fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i); + size_t buff_size = size * args.num_wgs; + for (size_t i = 0; i < buff_size; i++) { + if (dest[i] != source[i]) { + std::cerr << "Data validation error at idx " << i << std::endl; + std::cerr << " Got " << dest[i] << ", Expected " + << source[i] << std::endl; exit(-1); } } diff --git a/projects/rocshmem/tests/functional_tests/extended_primitives.hpp b/projects/rocshmem/tests/functional_tests/workgroup_primitives.hpp similarity index 88% rename from projects/rocshmem/tests/functional_tests/extended_primitives.hpp rename to projects/rocshmem/tests/functional_tests/workgroup_primitives.hpp index 76225f07b0..b3a68d66fe 100644 --- a/projects/rocshmem/tests/functional_tests/extended_primitives.hpp +++ b/projects/rocshmem/tests/functional_tests/workgroup_primitives.hpp @@ -24,15 +24,14 @@ #define _EXTENDED_PRIMITIVES_HPP_ #include "tester.hpp" -#include "../src/util.hpp" /****************************************************************************** * HOST TESTER CLASS *****************************************************************************/ -class ExtendedPrimitiveTester : public Tester { +class WorkGroupPrimitiveTester : public Tester { public: - explicit ExtendedPrimitiveTester(TesterArguments args); - virtual ~ExtendedPrimitiveTester(); + explicit WorkGroupPrimitiveTester(TesterArguments args); + virtual ~WorkGroupPrimitiveTester(); protected: virtual void resetBuffers(uint64_t size) override; @@ -42,9 +41,8 @@ class ExtendedPrimitiveTester : public Tester { virtual void verifyResults(uint64_t size) override; - int *s_buf = nullptr; - int *r_buf = nullptr; - int num_elems = 0; + char *source = nullptr; + char *dest = nullptr; }; #endif