Update RMA functional tests (#50)

* Update primitive tests for multi-workgroup support

* Update workgroup primitive tests for multi-workgroup support

* Update workfront primitive tests for multi-workgroup support

* Update team based primitive tests for multi-workgroup support

* Update RMA functional tests to capture timing after quiet call
   - Modified RMA functional tests to record the time after a `quiet` call in thread, wavefront, and workgroup RMA calls.

* Improve error handling and memory management
   - Replaced `cout` with `cerr` for improved error reporting.
   - Ensured all allocated memory is freed when `rocshmem_malloc` fails.

* Update start time in primitive tests and latency calculations
   - Modified primitive tests to capture the earliest start time.
   - Updated latency calculations in functional tests.

* Remove `GetSwarmTester`

* Update start time in team primitive tests

* Invoke quiet call from a single thread within a block on a rocshmem context

[ROCm/rocshmem commit: aa3121a967]
Bu işleme şunda yer alıyor:
Avinash Kethineedi
2025-03-18 14:39:57 -05:00
işlemeyi yapan: GitHub
ebeveyn 9b187a2e44
işleme e16bb62767
16 değiştirilmiş dosya ile 463 ekleme ve 460 silme
+62 -59
Dosyayı Görüntüle
@@ -26,61 +26,60 @@ declare -A TEST_NUMBERS=(
["getnbi"]="1"
["put"]="2"
["putnbi"]="3"
["getswarm"]="4"
["amo_fadd"]="5"
["amo_finc"]="6"
["amo_fetch"]="7"
["amo_fcswap"]="8"
["amo_add"]="9"
["amo_inc"]="10"
["amo_cswap"]="11"
["init"]="12"
["pingpong"]="13"
["randomaccess"]="14"
["barrierall"]="15"
["syncall"]="16"
["sync"]="17"
["collect"]="18"
["fcollect"]="19"
["alltoall"]="20"
["alltoalls"]="21"
["shmemptr"]="22"
["p"]="23"
["g"]="24"
["wgget"]="25"
["wggetnbi"]="26"
["wgput"]="27"
["wgputnbi"]="28"
["waveget"]="29"
["wavegetnbi"]="30"
["waveput"]="31"
["waveputnbi"]="32"
["teambroadcast"]="33"
["teamreduction"]="34"
["teamctxget"]="35"
["teamctxgetnbi"]="36"
["teamctxput"]="37"
["teamctxputnbi"]="38"
["teamctxinfra"]="39"
["putnbimr"]="40"
["amo_set"]="41"
["amo_swap"]="42"
["amo_fetchand"]="43"
["amo_fetchor"]="44"
["amo_fetchxor"]="45"
["amo_and"]="46"
["amo_or"]="47"
["amo_xor"]="48"
["pingall"]="49"
["putsignal"]="50"
["wgputsignal"]="51"
["waveputsignal"]="52"
["putsignalnbi"]="53"
["wgputsignalnbi"]="54"
["waveputsignalnbi"]="55"
["signalfetch"]="56"
["wgsignalfetch"]="57"
["wavesignalfetch"]="58"
["amo_fadd"]="4"
["amo_finc"]="5"
["amo_fetch"]="6"
["amo_fcswap"]="7"
["amo_add"]="8"
["amo_inc"]="9"
["amo_cswap"]="10"
["init"]="11"
["pingpong"]="12"
["randomaccess"]="13"
["barrierall"]="14"
["syncall"]="15"
["sync"]="16"
["collect"]="17"
["fcollect"]="18"
["alltoall"]="19"
["alltoalls"]="20"
["shmemptr"]="21"
["p"]="22"
["g"]="23"
["wgget"]="24"
["wggetnbi"]="25"
["wgput"]="26"
["wgputnbi"]="27"
["waveget"]="28"
["wavegetnbi"]="29"
["waveput"]="30"
["waveputnbi"]="31"
["teambroadcast"]="32"
["teamreduction"]="33"
["teamctxget"]="34"
["teamctxgetnbi"]="35"
["teamctxput"]="36"
["teamctxputnbi"]="37"
["teamctxinfra"]="38"
["putnbimr"]="39"
["amo_set"]="40"
["amo_swap"]="41"
["amo_fetchand"]="42"
["amo_fetchor"]="43"
["amo_fetchxor"]="44"
["amo_and"]="45"
["amo_or"]="46"
["amo_xor"]="47"
["pingall"]="48"
["putsignal"]="49"
["wgputsignal"]="50"
["waveputsignal"]="51"
["putsignalnbi"]="52"
["wgputsignalnbi"]="53"
["waveputsignalnbi"]="54"
["signalfetch"]="55"
["wgsignalfetch"]="56"
["wavesignalfetch"]="57"
)
ExecTest() {
@@ -159,7 +158,8 @@ TestRMA() {
ExecTest "waveput" 2 2 128 1048576
ExecTest "waveput" 2 16 128 8
ExecTest "teamctxput" 2 1 1 1048576
ExecTest "teamctxput" 2 4 128 1024
ExecTest "teamctxput" 2 16 256 1024
ExecTest "get" 2 1 1 1048576
ExecTest "get" 2 1 1024 512
@@ -177,7 +177,8 @@ TestRMA() {
ExecTest "waveget" 2 2 128 1048576
ExecTest "waveget" 2 16 128 8
ExecTest "teamctxget" 2 1 1 1048576
ExecTest "teamctxget" 2 4 128 1024
ExecTest "teamctxget" 2 16 256 1024
ExecTest "g" 2 1 1 1048576
ExecTest "g" 2 1 1024 512
@@ -211,7 +212,8 @@ TestRMA() {
ExecTest "waveputnbi" 2 2 128 1048576
ExecTest "waveputnbi" 2 16 128 8
ExecTest "teamctxputnbi" 2 1 1 1048576
ExecTest "teamctxputnbi" 2 4 128 1024
ExecTest "teamctxputnbi" 2 16 256 1024
ExecTest "getnbi" 2 1 1 1048576
ExecTest "getnbi" 2 1 1024 512
@@ -229,7 +231,8 @@ TestRMA() {
ExecTest "wavegetnbi" 2 2 128 1048576
ExecTest "wavegetnbi" 2 16 128 8
ExecTest "teamctxgetnbi" 2 1 1 1048576
ExecTest "teamctxgetnbi" 2 4 128 1024
ExecTest "teamctxgetnbi" 2 16 256 1024
}
TestAMO() {
+2 -3
Dosyayı Görüntüle
@@ -50,14 +50,13 @@ target_sources(
amo_bitwise_tester.cpp
amo_extended_tester.cpp
amo_standard_tester.cpp
swarm_tester.cpp
random_access_tester.cpp
shmem_ptr_tester.cpp
signaling_operations_tester.cpp
signaling_operations_tester.hpp
extended_primitives.cpp
workgroup_primitives.cpp
empty_tester.cpp
wave_level_primitives.cpp
wavefront_primitives.cpp
)
###############################################################################
+87 -30
Dosyayı Görüntüle
@@ -30,43 +30,65 @@ using namespace rocshmem;
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size, TestType type,
ShmemContextType ctx_type) {
long long int *end_time, char *source,
char *dest, int size, TestType type,
ShmemContextType ctx_type, int wf_size) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
* Shared array to capture the start time for each wavefront
* Max threads per block = 1024, wavefront size = 64 (in most GPUs)
* Maximum array size required = 1024/64 = 16
*/
__shared__ long long int wf_start_time[16];
/**
* Calculate start index for each thread within the grid
*/
uint64_t offset = size * get_flat_id();
source += offset;
dest += offset;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
__syncthreads();
start_time[wg_id] = wall_clock64();
__syncthreads();
// Ensures all RMA calls from the skip loops are completed
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
__syncthreads();
// Capture the start time of each wavefront to identify the earliest one
wf_start_time[wf_id] = wall_clock64();
}
switch (type) {
case GetTestType:
rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem(ctx, dest, source, size, 1);
break;
case GetNBITestType:
rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1);
break;
case PutTestType:
rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem(ctx, dest, source, size, 1);
break;
case PutNBITestType:
rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1);
break;
case PTestType:
for (int s = 0; s < size; s++) {
char val = s_buf[s];
rocshmem_ctx_char_p(ctx, &r_buf[s], val, 1);
char val = source[s];
rocshmem_ctx_char_p(ctx, &dest[s], val, 1);
}
break;
case GTestType:
for (int s = 0; s < size; s++) {
char ret = rocshmem_ctx_char_g(ctx, &s_buf[s], 1);
r_buf[s] = ret;
char ret = rocshmem_ctx_char_g(ctx, &source[s], 1);
dest[s] = ret;
}
break;
default:
@@ -74,12 +96,28 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
}
}
rocshmem_ctx_quiet(ctx);
__syncthreads();
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
/**
* End time of the last wavefront is recorded by overwriting
* the value previously set by earlier wavefronts.
*/
end_time[wg_id] = wall_clock64();
// Find the earliest start time
int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1;
for (int i = num_wfs / 2; i > 0; i >>= 1 ) {
if(t_id < i) {
wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]);
}
}
__syncthreads();
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
if (t_id == 0) {
start_time[wg_id] = wf_start_time[0];
}
rocshmem_wg_ctx_destroy(&ctx);
@@ -90,18 +128,35 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
* HOST TESTER CLASS METHODS
*****************************************************************************/
PrimitiveTester::PrimitiveTester(TesterArguments args) : Tester(args) {
s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs;
source = (char *)rocshmem_malloc(buff_size);
dest = (char *)rocshmem_malloc(buff_size);
if (source == nullptr || dest == nullptr) {
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
if (source) {
rocshmem_free(source);
}
if (dest) {
rocshmem_free(dest);
}
rocshmem_global_exit(1);
}
for(size_t i = 0; i < buff_size; i++) {
source[i] = static_cast<char>('a' + i % 26);
}
}
PrimitiveTester::~PrimitiveTester() {
rocshmem_free(s_buf);
rocshmem_free(r_buf);
rocshmem_free(source);
rocshmem_free(dest);
}
void PrimitiveTester::resetBuffers(uint64_t size) {
memset(s_buf, '0', args.max_msg_size * args.wg_size);
memset(r_buf, '1', args.max_msg_size * args.wg_size);
size_t buff_size = size * args.wg_size * args.num_wgs;
memset(dest, '1', buff_size);
}
void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
@@ -109,11 +164,11 @@ void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
size_t shared_bytes = 0;
hipLaunchKernelGGL(PrimitiveTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time, s_buf, r_buf,
size, _type, _shmem_context);
loop, args.skip, start_time, end_time, source, dest,
size, _type, _shmem_context, wf_size);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop;
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
num_timed_msgs = loop * gridSize.x * blockSize.x;
}
void PrimitiveTester::verifyResults(uint64_t size) {
@@ -123,10 +178,12 @@ void PrimitiveTester::verifyResults(uint64_t size) {
: 1;
if (args.myid == check_id) {
for (uint64_t i = 0; i < size; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
size_t buff_size = size * args.wg_size * args.num_wgs;
for (uint64_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
+2 -2
Dosyayı Görüntüle
@@ -41,8 +41,8 @@ class PrimitiveTester : public Tester {
virtual void verifyResults(uint64_t size) override;
char *s_buf = nullptr;
char *r_buf = nullptr;
char *source = nullptr;
char *dest = nullptr;
};
#endif
-95
Dosyayı Görüntüle
@@ -1,95 +0,0 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#include "swarm_tester.hpp"
#include <rocshmem/rocshmem.hpp>
using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void GetSwarmTest(int loop, int skip, long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size, ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
int provided;
rocshmem_wg_init_thread(ROCSHMEM_THREAD_MULTIPLE, &provided);
assert(provided == ROCSHMEM_THREAD_MULTIPLE);
rocshmem_wg_ctx_create(ctx_type, &ctx);
__syncthreads();
int index = hipThreadIdx_x * size;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
start_time[wg_id] = wall_clock64();
}
rocshmem_ctx_getmem(ctx, &r_buf[index], &s_buf[index], size, 1);
__syncthreads();
}
// atomicAdd((unsigned long long *)&timer[hipBlockIdx_x],
// rocshmem_timer() - start);
end_time[wg_id] = wall_clock64();
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
GetSwarmTester::GetSwarmTester(TesterArguments args) : PrimitiveTester(args) {}
GetSwarmTester::~GetSwarmTester() {}
void GetSwarmTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(GetSwarmTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time, s_buf, r_buf, size,
_shmem_context);
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
num_timed_msgs = loop * gridSize.x * blockSize.x;
}
void GetSwarmTester::verifyResults(uint64_t size) {
if (args.myid == 0) {
for (uint64_t i = 0; i < size * args.wg_size; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
exit(-1);
}
}
}
}
-50
Dosyayı Görüntüle
@@ -1,50 +0,0 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef _SWARM_TESTER_HPP_
#define _SWARM_TESTER_HPP_
#include "primitive_tester.hpp"
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void GetSwarmTest(int loop, int skip, long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size);
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
class GetSwarmTester : public PrimitiveTester {
public:
explicit GetSwarmTester(TesterArguments args);
virtual ~GetSwarmTester();
protected:
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) override;
virtual void verifyResults(uint64_t size) override;
};
#endif
+97 -38
Dosyayı Görüntüle
@@ -32,43 +32,83 @@ rocshmem_team_t team_primitive_world_dup;
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size, TestType type,
ShmemContextType ctx_type,
long long int *end_time, char *source,
char *dest, int size, TestType type,
ShmemContextType ctx_type, int wf_size,
rocshmem_team_t team) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(team, ctx_type, &ctx);
if (hipThreadIdx_x == 0) {
/**
* Shared array to capture the start time for each wavefront
* Max threads per block = 1024, wavefront size = 64 (in most GPUs)
* Maximum array size required = 1024/64 = 16
*/
__shared__ long long int wf_start_time[16];
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
start_time[wg_id] = wall_clock64();
}
switch (type) {
case TeamCtxGetTestType:
rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1);
break;
case TeamCtxGetNBITestType:
rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1);
break;
case TeamCtxPutTestType:
rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1);
break;
case TeamCtxPutNBITestType:
rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1);
break;
default:
break;
/**
* Calculate start index for each thread within the grid
*/
uint64_t offset = size * get_flat_id();
source += offset;
dest += offset;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
__syncthreads();
// Ensures all RMA calls from the skip loops are completed
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
__syncthreads();
// Capture the start time of each wavefront to identify the earliest one
wf_start_time[wf_id] = wall_clock64();
}
switch (type) {
case TeamCtxGetTestType:
rocshmem_ctx_getmem(ctx, dest, source, size, 1);
break;
case TeamCtxGetNBITestType:
rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1);
break;
case TeamCtxPutTestType:
rocshmem_ctx_putmem(ctx, dest, source, size, 1);
break;
case TeamCtxPutNBITestType:
rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1);
break;
default:
break;
}
}
__syncthreads();
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
end_time[wg_id] = wall_clock64();
/**
* End time of the last wavefront is recorded by overwriting
* the value previously set by earlier wavefronts.
*/
end_time[wg_id] = wall_clock64();
// Find the earliest start time
int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1;
for (int i = num_wfs / 2; i > 0; i >>= 1 ) {
if(t_id < i) {
wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]);
}
}
__syncthreads();
if (t_id == 0) {
start_time[wg_id] = wf_start_time[0];
}
rocshmem_wg_ctx_destroy(&ctx);
@@ -80,18 +120,35 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti
*****************************************************************************/
TeamCtxPrimitiveTester::TeamCtxPrimitiveTester(TesterArguments args)
: Tester(args) {
s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs;
source = (char *)rocshmem_malloc(buff_size);
dest = (char *)rocshmem_malloc(buff_size);
if (source == nullptr || dest == nullptr) {
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
if (source) {
rocshmem_free(source);
}
if (dest) {
rocshmem_free(dest);
}
rocshmem_global_exit(1);
}
for(size_t i = 0; i < buff_size; i++) {
source[i] = static_cast<char>('a' + i % 26);
}
}
TeamCtxPrimitiveTester::~TeamCtxPrimitiveTester() {
rocshmem_free(s_buf);
rocshmem_free(r_buf);
rocshmem_free(source);
rocshmem_free(dest);
}
void TeamCtxPrimitiveTester::resetBuffers(uint64_t size) {
memset(s_buf, '0', args.max_msg_size * args.wg_size);
memset(r_buf, '1', args.max_msg_size * args.wg_size);
size_t buff_size = size * args.wg_size * args.num_wgs;
memset(dest, '1', buff_size);
}
void TeamCtxPrimitiveTester::preLaunchKernel() {
@@ -107,12 +164,12 @@ void TeamCtxPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
size_t shared_bytes = 0;
hipLaunchKernelGGL(TeamCtxPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time, s_buf,
r_buf, size, _type, _shmem_context,
stream, loop, args.skip, start_time, end_time, source,
dest, size, _type, _shmem_context, wf_size,
team_primitive_world_dup);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
num_timed_msgs = loop * gridSize.x * blockSize.x;
}
void TeamCtxPrimitiveTester::postLaunchKernel() {
@@ -124,10 +181,12 @@ void TeamCtxPrimitiveTester::verifyResults(uint64_t size) {
(_type == TeamCtxGetTestType || _type == TeamCtxGetNBITestType) ? 0 : 1;
if (args.myid == check_id) {
for (uint64_t i = 0; i < size; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
size_t buff_size = size * args.wg_size * args.num_wgs;
for (uint64_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
+2 -2
Dosyayı Görüntüle
@@ -45,8 +45,8 @@ class TeamCtxPrimitiveTester : public Tester {
virtual void verifyResults(uint64_t size) override;
char *s_buf = nullptr;
char *r_buf = nullptr;
char *source = nullptr;
char *dest = nullptr;
};
#endif
+30 -22
Dosyayı Görüntüle
@@ -35,7 +35,6 @@
#include "amo_standard_tester.hpp"
#include "barrier_all_tester.hpp"
#include "empty_tester.hpp"
#include "extended_primitives.hpp"
#include "ping_all_tester.hpp"
#include "ping_pong_tester.hpp"
#include "primitive_mr_tester.hpp"
@@ -43,7 +42,6 @@
#include "random_access_tester.hpp"
#include "shmem_ptr_tester.hpp"
#include "signaling_operations_tester.hpp"
#include "swarm_tester.hpp"
#include "sync_tester.hpp"
#include "team_alltoall_tester.hpp"
#include "team_broadcast_tester.hpp"
@@ -51,7 +49,8 @@
#include "team_ctx_primitive_tester.hpp"
#include "team_fcollect_tester.hpp"
#include "team_reduction_tester.hpp"
#include "wave_level_primitives.hpp"
#include "wavefront_primitives.hpp"
#include "workgroup_primitives.hpp"
Tester::Tester(TesterArguments args) : args(args) {
_type = (TestType)args.algorithm;
@@ -66,6 +65,16 @@ Tester::Tester(TesterArguments args) : args(args) {
CHECK_HIP(hipDeviceGetAttribute(&wall_clk_rate,
hipDeviceAttributeWallClockRate, device_id));
num_timers = args.num_wgs;
switch (_type) {
case WAVEGetTestType:
case WAVEGetNBITestType:
case WAVEPutTestType:
case WAVEPutNBITestType:
num_timers = args.num_wgs * num_warps;
break;
default:
break;
}
CHECK_HIP(hipMalloc((void**)&timer, sizeof(long long int) * num_timers));
CHECK_HIP(hipMalloc((void**)&start_time, sizeof(long long int) * num_timers));
CHECK_HIP(hipMalloc((void**)&end_time, sizeof(long long int) * num_timers));
@@ -137,10 +146,6 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "G Test ###" << std::endl;
testers.push_back(new PrimitiveTester(args));
return testers;
case GetSwarmTestType:
if (rank == 0) std::cout << "Get Swarm ###" << std::endl;
testers.push_back(new GetSwarmTester(args));
return testers;
case TeamReductionTestType:
if (rank == 0)
std::cout << "All-to-All Team-based Reduction ###" << std::endl;
@@ -309,22 +314,22 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
case WGGetTestType:
if (rank == 0)
std::cout << "Blocking WG level Gets ###" << std::endl;
testers.push_back(new ExtendedPrimitiveTester(args));
testers.push_back(new WorkGroupPrimitiveTester(args));
return testers;
case WGGetNBITestType:
if (rank == 0)
std::cout << "Non-Blocking WG level Gets ###" << std::endl;
testers.push_back(new ExtendedPrimitiveTester(args));
testers.push_back(new WorkGroupPrimitiveTester(args));
return testers;
case WGPutTestType:
if (rank == 0)
std::cout << "Blocking WG level Puts ###" << std::endl;
testers.push_back(new ExtendedPrimitiveTester(args));
testers.push_back(new WorkGroupPrimitiveTester(args));
return testers;
case WGPutNBITestType:
if (rank == 0)
std::cout << "Non-Blocking WG level Puts ###" << std::endl;
testers.push_back(new ExtendedPrimitiveTester(args));
testers.push_back(new WorkGroupPrimitiveTester(args));
return testers;
case PutNBIMRTestType:
if (rank == 0)
@@ -334,22 +339,22 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
case WAVEGetTestType:
if (rank == 0)
std::cout << "Blocking WAVE level Gets ###" << std::endl;
testers.push_back(new WaveLevelPrimitiveTester(args));
testers.push_back(new WaveFrontPrimitiveTester(args));
return testers;
case WAVEGetNBITestType:
if (rank == 0)
std::cout << "Non-Blocking WAVE level Gets ###" << std::endl;
testers.push_back(new WaveLevelPrimitiveTester(args));
testers.push_back(new WaveFrontPrimitiveTester(args));
return testers;
case WAVEPutTestType:
if (rank == 0)
std::cout << "Blocking WAVE level Puts ###" << std::endl;
testers.push_back(new WaveLevelPrimitiveTester(args));
testers.push_back(new WaveFrontPrimitiveTester(args));
return testers;
case WAVEPutNBITestType:
if (rank == 0)
std::cout << "Non-Blocking WAVE level Puts ###" << std::endl;
testers.push_back(new WaveLevelPrimitiveTester(args));
testers.push_back(new WaveFrontPrimitiveTester(args));
return testers;
case PutSignalTestType:
if (rank == 0) std::cout << "Putmem Signal ###" << std::endl;
@@ -495,18 +500,21 @@ void Tester::print(uint64_t size) {
*/
uint64_t total_size = size * num_timed_msgs;
double timer_avg = timerAvgInMicroseconds();
double latency_avg = timer_avg / num_timed_msgs;
double avg_msg_rate = num_timed_msgs / (timer_avg / 1e6);
double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time);
double time_s = time_us / 1e6;
double latency_avg = time_us / num_timed_msgs;
double avg_msg_rate = num_timed_msgs / time_s;
double bandwidth_avg_gbs =
static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30);
float total_kern_time_ms;
CHECK_HIP(hipEventElapsedTime(&total_kern_time_ms, start_event, stop_event));
float total_kern_time_s = total_kern_time_ms / 1000;
double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time);
double time_s = time_us / 1e6;
double bandwidth_avg_gbs =
static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30);
int field_width = 20;
int float_precision = 2;
+54 -55
Dosyayı Görüntüle
@@ -38,61 +38,60 @@ enum TestType {
GetNBITestType = 1,
PutTestType = 2,
PutNBITestType = 3,
GetSwarmTestType = 4,
AMO_FAddTestType = 5,
AMO_FIncTestType = 6,
AMO_FetchTestType = 7,
AMO_FCswapTestType = 8,
AMO_AddTestType = 9,
AMO_IncTestType = 10,
AMO_CswapTestType = 11,
InitTestType = 12,
PingPongTestType = 13,
RandomAccessTestType = 14,
BarrierAllTestType = 15,
SyncAllTestType = 16,
SyncTestType = 17,
CollectTestType = 18,
TeamFCollectTestType = 19,
TeamAllToAllTestType = 20,
AllToAllsTestType = 21,
ShmemPtrTestType = 22,
PTestType = 23,
GTestType = 24,
WGGetTestType = 25,
WGGetNBITestType = 26,
WGPutTestType = 27,
WGPutNBITestType = 28,
WAVEGetTestType = 29,
WAVEGetNBITestType = 30,
WAVEPutTestType = 31,
WAVEPutNBITestType = 32,
TeamBroadcastTestType = 33,
TeamReductionTestType = 34,
TeamCtxGetTestType = 35,
TeamCtxGetNBITestType = 36,
TeamCtxPutTestType = 37,
TeamCtxPutNBITestType = 38,
TeamCtxInfraTestType = 39,
PutNBIMRTestType = 40,
AMO_SetTestType = 41,
AMO_SwapTestType = 42,
AMO_FetchAndTestType = 43,
AMO_FetchOrTestType = 44,
AMO_FetchXorTestType = 45,
AMO_AndTestType = 46,
AMO_OrTestType = 47,
AMO_XorTestType = 48,
PingAllTestType = 49,
PutSignalTestType = 50,
WGPutSignalTestType = 51,
WAVEPutSignalTestType = 52,
PutSignalNBITestType = 53,
WGPutSignalNBITestType = 54,
WAVEPutSignalNBITestType = 55,
SignalFetchTestType = 56,
WGSignalFetchTestType = 57,
WAVESignalFetchTestType = 58,
AMO_FAddTestType = 4,
AMO_FIncTestType = 5,
AMO_FetchTestType = 6,
AMO_FCswapTestType = 7,
AMO_AddTestType = 8,
AMO_IncTestType = 9,
AMO_CswapTestType = 10,
InitTestType = 11,
PingPongTestType = 12,
RandomAccessTestType = 13,
BarrierAllTestType = 14,
SyncAllTestType = 15,
SyncTestType = 16,
CollectTestType = 17,
TeamFCollectTestType = 18,
TeamAllToAllTestType = 19,
AllToAllsTestType = 20,
ShmemPtrTestType = 21,
PTestType = 22,
GTestType = 23,
WGGetTestType = 24,
WGGetNBITestType = 25,
WGPutTestType = 26,
WGPutNBITestType = 27,
WAVEGetTestType = 28,
WAVEGetNBITestType = 29,
WAVEPutTestType = 30,
WAVEPutNBITestType = 31,
TeamBroadcastTestType = 32,
TeamReductionTestType = 33,
TeamCtxGetTestType = 34,
TeamCtxGetNBITestType = 35,
TeamCtxPutTestType = 36,
TeamCtxPutNBITestType = 37,
TeamCtxInfraTestType = 38,
PutNBIMRTestType = 39,
AMO_SetTestType = 40,
AMO_SwapTestType = 41,
AMO_FetchAndTestType = 42,
AMO_FetchOrTestType = 43,
AMO_FetchXorTestType = 44,
AMO_AndTestType = 45,
AMO_OrTestType = 46,
AMO_XorTestType = 47,
PingAllTestType = 48,
PutSignalTestType = 49,
WGPutSignalTestType = 50,
WAVEPutSignalTestType = 51,
PutSignalNBITestType = 52,
WGPutSignalNBITestType = 53,
WAVEPutSignalNBITestType = 54,
SignalFetchTestType = 55,
WGSignalFetchTestType = 56,
WAVESignalFetchTestType = 57,
};
enum OpType { PutType = 0, GetType = 1 };
-10
Dosyayı Görüntüle
@@ -109,16 +109,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case PutNBIMRTestType:
min_msg_size = max_msg_size;
break;
case WAVEGetTestType:
case WAVEGetNBITestType:
case WAVEPutTestType:
case WAVEPutNBITestType:
case WGGetTestType:
case WGGetNBITestType:
case WGPutTestType:
case WGPutNBITestType:
min_msg_size = 4;
break;
default:
break;
}
+1
Dosyayı Görüntüle
@@ -27,6 +27,7 @@
#include <cstdint>
#include <rocshmem/rocshmem.hpp>
#include <string>
#include <iostream>
class TesterArguments {
public:
@@ -20,7 +20,7 @@
* IN THE SOFTWARE.
*****************************************************************************/
#include "wave_level_primitives.hpp"
#include "wavefront_primitives.hpp"
#include <rocshmem/rocshmem.hpp>
@@ -31,54 +31,56 @@ using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void WaveLevelPrimitiveTest(int loop, int skip,
__global__ void WaveFrontPrimitiveTest(int loop, int skip,
long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size, TestType type,
ShmemContextType ctx_type, int wf_size) {
long long int *end_time, char *source,
char *dest, int size, TestType type,
ShmemContextType ctx_type,
int wf_size) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
* Calculate start index for each wavefront for tiled version
* If the number of wavefronts is greater than 1, this kernel performs a
* tiled functional test
*/
// Calculate start index for each wavefront
int wf_id = get_flat_block_id() / wf_size;
int wg_offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size);
int idx = wf_id * size + wg_offset;
s_buf += idx;
r_buf += idx;
int wg_offset = wg_id * ((get_flat_block_size() - 1 ) / wf_size + 1);
int idx = wf_id + wg_offset;
int offset = size * idx;
source += offset;
dest += offset;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
start_time[wg_id] = wall_clock64();
// Ensures all RMA calls from the skip loops are completed
if(is_thread_zero_in_wave()) {
rocshmem_ctx_quiet(ctx);
}
__syncthreads();
start_time[idx] = wall_clock64();
}
switch (type) {
case WAVEGetTestType:
rocshmem_ctx_getmem_wave(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem_wave(ctx, dest, source, size, 1);
break;
case WAVEGetNBITestType:
rocshmem_ctx_getmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem_nbi_wave(ctx, dest, source, size, 1);
break;
case WAVEPutTestType:
rocshmem_ctx_putmem_wave(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem_wave(ctx, dest, source, size, 1);
break;
case WAVEPutNBITestType:
rocshmem_ctx_putmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem_nbi_wave(ctx, dest, source, size, 1);
break;
default:
break;
}
}
rocshmem_ctx_quiet(ctx);
if (hipThreadIdx_x == 0) {
end_time[hipBlockIdx_x] = wall_clock64();
if (is_thread_zero_in_wave()) {
rocshmem_ctx_quiet(ctx);
end_time[idx] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
@@ -88,48 +90,64 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip,
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args)
WaveFrontPrimitiveTester::WaveFrontPrimitiveTester(TesterArguments args)
: Tester(args) {
s_buf = static_cast<int*>(
rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps));
r_buf = static_cast<int*>(
rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps));
size_t buff_size = args.max_msg_size * args.num_wgs * num_warps;
source = (char *)rocshmem_malloc(buff_size);
dest = (char *)rocshmem_malloc(buff_size);
if (source == nullptr || dest == nullptr) {
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
if (source) {
rocshmem_free(source);
}
if (dest) {
rocshmem_free(dest);
}
rocshmem_global_exit(1);
}
for(size_t i = 0; i < buff_size; i++) {
source[i] = static_cast<char>('a' + i % 26);
}
}
WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
rocshmem_free(s_buf);
rocshmem_free(r_buf);
WaveFrontPrimitiveTester::~WaveFrontPrimitiveTester() {
rocshmem_free(source);
rocshmem_free(dest);
}
void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) {
num_elems = (size * args.num_wgs * num_warps) / sizeof(int);
std::iota(s_buf, s_buf + num_elems, 0);
memset(r_buf, 0, size * args.num_wgs * num_warps);
void WaveFrontPrimitiveTester::resetBuffers(uint64_t size) {
size_t buff_size = size * args.num_wgs * num_warps;
memset(dest, '1', buff_size);
}
void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
void WaveFrontPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes,
hipLaunchKernelGGL(WaveFrontPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time,
(char*)s_buf, (char*)r_buf, size, _type, _shmem_context,
source, dest, size, _type, _shmem_context,
wf_size);
num_msgs = (loop + args.skip) * gridSize.x * num_warps;
num_timed_msgs = loop * gridSize.x * num_warps;
}
void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
void WaveFrontPrimitiveTester::verifyResults(uint64_t size) {
int check_id = (_type == WAVEGetTestType || _type == WAVEGetNBITestType)
? 0
: 1;
if (args.myid == check_id) {
for (int i = 0; i < num_elems; i++) {
if (r_buf[i] != i) {
fprintf(stderr, "Data validation error at idx %d\n", i);
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
size_t buff_size = size * args.num_wgs * num_warps;
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
@@ -24,15 +24,14 @@
#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
#include "tester.hpp"
#include "../src/util.hpp"
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
class WaveLevelPrimitiveTester : public Tester {
class WaveFrontPrimitiveTester : public Tester {
public:
explicit WaveLevelPrimitiveTester(TesterArguments args);
virtual ~WaveLevelPrimitiveTester();
explicit WaveFrontPrimitiveTester(TesterArguments args);
virtual ~WaveFrontPrimitiveTester();
protected:
virtual void resetBuffers(uint64_t size) override;
@@ -42,9 +41,8 @@ class WaveLevelPrimitiveTester : public Tester {
virtual void verifyResults(uint64_t size) override;
int *s_buf = nullptr;
int *r_buf = nullptr;
int num_elems = 0;
char *source = nullptr;
char *dest = nullptr;
};
#endif
@@ -20,7 +20,7 @@
* IN THE SOFTWARE.
*****************************************************************************/
#include "extended_primitives.hpp"
#include "workgroup_primitives.hpp"
#include <rocshmem/rocshmem.hpp>
@@ -31,51 +31,51 @@ using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void ExtendedPrimitiveTest(int loop, int skip,
__global__ void WorkGroupPrimitiveTest(int loop, int skip,
long long int *start_time,
long long int *end_time, char *s_buf,
char *r_buf, int size, TestType type,
long long int *end_time, char *source,
char *dest, int size, TestType type,
ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
* Calculate start index for each work group for tiled version
* If the number of work groups is greater than 1, this kernel performs a
* tiled functional test
*/
uint64_t idx = size * get_flat_grid_id();
s_buf += idx;
r_buf += idx;
// Calculate start index for each work group
uint64_t offset = size * wg_id;
source += offset;
dest += offset;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
// Ensures all RMA calls from the skip loops are completed
if (is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
__syncthreads();
start_time[wg_id] = wall_clock64();
}
switch (type) {
case WGGetTestType:
rocshmem_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem_wg(ctx, dest, source, size, 1);
break;
case WGGetNBITestType:
rocshmem_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_getmem_nbi_wg(ctx, dest, source, size, 1);
break;
case WGPutTestType:
rocshmem_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem_wg(ctx, dest, source, size, 1);
break;
case WGPutNBITestType:
rocshmem_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
rocshmem_ctx_putmem_nbi_wg(ctx, dest, source, size, 1);
break;
default:
break;
}
}
rocshmem_ctx_quiet(ctx);
if (hipThreadIdx_x == 0) {
if (is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
end_time[wg_id] = wall_clock64();
}
@@ -86,45 +86,63 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip,
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args)
WorkGroupPrimitiveTester::WorkGroupPrimitiveTester(TesterArguments args)
: Tester(args) {
s_buf = static_cast<int*>(rocshmem_malloc(args.max_msg_size * args.num_wgs));
r_buf = static_cast<int*>(rocshmem_malloc(args.max_msg_size * args.num_wgs));
size_t buff_size = args.max_msg_size * args.num_wgs;
source = (char *)rocshmem_malloc(buff_size);
dest = (char *)rocshmem_malloc(buff_size);
if (source == nullptr || dest == nullptr) {
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
if (source) {
rocshmem_free(source);
}
if (dest) {
rocshmem_free(dest);
}
rocshmem_global_exit(1);
}
for(size_t i = 0; i < buff_size; i++) {
source[i] = static_cast<char>('a' + i % 26);
}
}
ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
rocshmem_free(s_buf);
rocshmem_free(r_buf);
WorkGroupPrimitiveTester::~WorkGroupPrimitiveTester() {
rocshmem_free(source);
rocshmem_free(dest);
}
void ExtendedPrimitiveTester::resetBuffers(uint64_t size) {
num_elems = (size * args.num_wgs) / sizeof(int);
std::iota(s_buf, s_buf + num_elems, 0);
memset(r_buf, 0, size * args.num_wgs);
void WorkGroupPrimitiveTester::resetBuffers(uint64_t size) {
size_t buff_size = size * args.num_wgs;
memset(dest, '1', buff_size);
}
void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
void WorkGroupPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
hipLaunchKernelGGL(WorkGroupPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time,
(char*)s_buf, (char*)r_buf, size, _type, _shmem_context);
source, dest, size, _type, _shmem_context);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}
void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
void WorkGroupPrimitiveTester::verifyResults(uint64_t size) {
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType)
? 0
: 1;
if (args.myid == check_id) {
for (int i = 0; i < num_elems; i++) {
if (r_buf[i] != i) {
fprintf(stderr, "Data validation error at idx %d\n", i);
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
size_t buff_size = size * args.num_wgs;
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
@@ -24,15 +24,14 @@
#define _EXTENDED_PRIMITIVES_HPP_
#include "tester.hpp"
#include "../src/util.hpp"
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
class ExtendedPrimitiveTester : public Tester {
class WorkGroupPrimitiveTester : public Tester {
public:
explicit ExtendedPrimitiveTester(TesterArguments args);
virtual ~ExtendedPrimitiveTester();
explicit WorkGroupPrimitiveTester(TesterArguments args);
virtual ~WorkGroupPrimitiveTester();
protected:
virtual void resetBuffers(uint64_t size) override;
@@ -42,9 +41,8 @@ class ExtendedPrimitiveTester : public Tester {
virtual void verifyResults(uint64_t size) override;
int *s_buf = nullptr;
int *r_buf = nullptr;
int num_elems = 0;
char *source = nullptr;
char *dest = nullptr;
};
#endif