Update RMA functional tests (#50)
* Update primitive tests for multi-workgroup support
* Update workgroup primitive tests for multi-workgroup support
* Update workfront primitive tests for multi-workgroup support
* Update team based primitive tests for multi-workgroup support
* Update RMA functional tests to capture timing after quiet call
- Modified RMA functional tests to record the time after a `quiet` call in thread, wavefront, and workgroup RMA calls.
* Improve error handling and memory management
- Replaced `cout` with `cerr` for improved error reporting.
- Ensured all allocated memory is freed when `rocshmem_malloc` fails.
* Update start time in primitive tests and latency calculations
- Modified primitive tests to capture the earliest start time.
- Updated latency calculations in functional tests.
* Remove `GetSwarmTester`
* Update start time in team primitive tests
* Invoke quiet call from a single thread within a block on a rocshmem context
[ROCm/rocshmem commit: aa3121a967]
Bu işleme şunda yer alıyor:
işlemeyi yapan:
GitHub
ebeveyn
9b187a2e44
işleme
e16bb62767
@@ -26,61 +26,60 @@ declare -A TEST_NUMBERS=(
|
||||
["getnbi"]="1"
|
||||
["put"]="2"
|
||||
["putnbi"]="3"
|
||||
["getswarm"]="4"
|
||||
["amo_fadd"]="5"
|
||||
["amo_finc"]="6"
|
||||
["amo_fetch"]="7"
|
||||
["amo_fcswap"]="8"
|
||||
["amo_add"]="9"
|
||||
["amo_inc"]="10"
|
||||
["amo_cswap"]="11"
|
||||
["init"]="12"
|
||||
["pingpong"]="13"
|
||||
["randomaccess"]="14"
|
||||
["barrierall"]="15"
|
||||
["syncall"]="16"
|
||||
["sync"]="17"
|
||||
["collect"]="18"
|
||||
["fcollect"]="19"
|
||||
["alltoall"]="20"
|
||||
["alltoalls"]="21"
|
||||
["shmemptr"]="22"
|
||||
["p"]="23"
|
||||
["g"]="24"
|
||||
["wgget"]="25"
|
||||
["wggetnbi"]="26"
|
||||
["wgput"]="27"
|
||||
["wgputnbi"]="28"
|
||||
["waveget"]="29"
|
||||
["wavegetnbi"]="30"
|
||||
["waveput"]="31"
|
||||
["waveputnbi"]="32"
|
||||
["teambroadcast"]="33"
|
||||
["teamreduction"]="34"
|
||||
["teamctxget"]="35"
|
||||
["teamctxgetnbi"]="36"
|
||||
["teamctxput"]="37"
|
||||
["teamctxputnbi"]="38"
|
||||
["teamctxinfra"]="39"
|
||||
["putnbimr"]="40"
|
||||
["amo_set"]="41"
|
||||
["amo_swap"]="42"
|
||||
["amo_fetchand"]="43"
|
||||
["amo_fetchor"]="44"
|
||||
["amo_fetchxor"]="45"
|
||||
["amo_and"]="46"
|
||||
["amo_or"]="47"
|
||||
["amo_xor"]="48"
|
||||
["pingall"]="49"
|
||||
["putsignal"]="50"
|
||||
["wgputsignal"]="51"
|
||||
["waveputsignal"]="52"
|
||||
["putsignalnbi"]="53"
|
||||
["wgputsignalnbi"]="54"
|
||||
["waveputsignalnbi"]="55"
|
||||
["signalfetch"]="56"
|
||||
["wgsignalfetch"]="57"
|
||||
["wavesignalfetch"]="58"
|
||||
["amo_fadd"]="4"
|
||||
["amo_finc"]="5"
|
||||
["amo_fetch"]="6"
|
||||
["amo_fcswap"]="7"
|
||||
["amo_add"]="8"
|
||||
["amo_inc"]="9"
|
||||
["amo_cswap"]="10"
|
||||
["init"]="11"
|
||||
["pingpong"]="12"
|
||||
["randomaccess"]="13"
|
||||
["barrierall"]="14"
|
||||
["syncall"]="15"
|
||||
["sync"]="16"
|
||||
["collect"]="17"
|
||||
["fcollect"]="18"
|
||||
["alltoall"]="19"
|
||||
["alltoalls"]="20"
|
||||
["shmemptr"]="21"
|
||||
["p"]="22"
|
||||
["g"]="23"
|
||||
["wgget"]="24"
|
||||
["wggetnbi"]="25"
|
||||
["wgput"]="26"
|
||||
["wgputnbi"]="27"
|
||||
["waveget"]="28"
|
||||
["wavegetnbi"]="29"
|
||||
["waveput"]="30"
|
||||
["waveputnbi"]="31"
|
||||
["teambroadcast"]="32"
|
||||
["teamreduction"]="33"
|
||||
["teamctxget"]="34"
|
||||
["teamctxgetnbi"]="35"
|
||||
["teamctxput"]="36"
|
||||
["teamctxputnbi"]="37"
|
||||
["teamctxinfra"]="38"
|
||||
["putnbimr"]="39"
|
||||
["amo_set"]="40"
|
||||
["amo_swap"]="41"
|
||||
["amo_fetchand"]="42"
|
||||
["amo_fetchor"]="43"
|
||||
["amo_fetchxor"]="44"
|
||||
["amo_and"]="45"
|
||||
["amo_or"]="46"
|
||||
["amo_xor"]="47"
|
||||
["pingall"]="48"
|
||||
["putsignal"]="49"
|
||||
["wgputsignal"]="50"
|
||||
["waveputsignal"]="51"
|
||||
["putsignalnbi"]="52"
|
||||
["wgputsignalnbi"]="53"
|
||||
["waveputsignalnbi"]="54"
|
||||
["signalfetch"]="55"
|
||||
["wgsignalfetch"]="56"
|
||||
["wavesignalfetch"]="57"
|
||||
)
|
||||
|
||||
ExecTest() {
|
||||
@@ -159,7 +158,8 @@ TestRMA() {
|
||||
ExecTest "waveput" 2 2 128 1048576
|
||||
ExecTest "waveput" 2 16 128 8
|
||||
|
||||
ExecTest "teamctxput" 2 1 1 1048576
|
||||
ExecTest "teamctxput" 2 4 128 1024
|
||||
ExecTest "teamctxput" 2 16 256 1024
|
||||
|
||||
ExecTest "get" 2 1 1 1048576
|
||||
ExecTest "get" 2 1 1024 512
|
||||
@@ -177,7 +177,8 @@ TestRMA() {
|
||||
ExecTest "waveget" 2 2 128 1048576
|
||||
ExecTest "waveget" 2 16 128 8
|
||||
|
||||
ExecTest "teamctxget" 2 1 1 1048576
|
||||
ExecTest "teamctxget" 2 4 128 1024
|
||||
ExecTest "teamctxget" 2 16 256 1024
|
||||
|
||||
ExecTest "g" 2 1 1 1048576
|
||||
ExecTest "g" 2 1 1024 512
|
||||
@@ -211,7 +212,8 @@ TestRMA() {
|
||||
ExecTest "waveputnbi" 2 2 128 1048576
|
||||
ExecTest "waveputnbi" 2 16 128 8
|
||||
|
||||
ExecTest "teamctxputnbi" 2 1 1 1048576
|
||||
ExecTest "teamctxputnbi" 2 4 128 1024
|
||||
ExecTest "teamctxputnbi" 2 16 256 1024
|
||||
|
||||
ExecTest "getnbi" 2 1 1 1048576
|
||||
ExecTest "getnbi" 2 1 1024 512
|
||||
@@ -229,7 +231,8 @@ TestRMA() {
|
||||
ExecTest "wavegetnbi" 2 2 128 1048576
|
||||
ExecTest "wavegetnbi" 2 16 128 8
|
||||
|
||||
ExecTest "teamctxgetnbi" 2 1 1 1048576
|
||||
ExecTest "teamctxgetnbi" 2 4 128 1024
|
||||
ExecTest "teamctxgetnbi" 2 16 256 1024
|
||||
}
|
||||
|
||||
TestAMO() {
|
||||
|
||||
@@ -50,14 +50,13 @@ target_sources(
|
||||
amo_bitwise_tester.cpp
|
||||
amo_extended_tester.cpp
|
||||
amo_standard_tester.cpp
|
||||
swarm_tester.cpp
|
||||
random_access_tester.cpp
|
||||
shmem_ptr_tester.cpp
|
||||
signaling_operations_tester.cpp
|
||||
signaling_operations_tester.hpp
|
||||
extended_primitives.cpp
|
||||
workgroup_primitives.cpp
|
||||
empty_tester.cpp
|
||||
wave_level_primitives.cpp
|
||||
wavefront_primitives.cpp
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
|
||||
@@ -30,43 +30,65 @@ using namespace rocshmem;
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size, TestType type,
|
||||
ShmemContextType ctx_type) {
|
||||
long long int *end_time, char *source,
|
||||
char *dest, int size, TestType type,
|
||||
ShmemContextType ctx_type, int wf_size) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
int t_id = get_flat_block_id();
|
||||
int wf_id = t_id / wf_size;
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
/**
|
||||
* Shared array to capture the start time for each wavefront
|
||||
* Max threads per block = 1024, wavefront size = 64 (in most GPUs)
|
||||
* Maximum array size required = 1024/64 = 16
|
||||
*/
|
||||
__shared__ long long int wf_start_time[16];
|
||||
|
||||
/**
|
||||
* Calculate start index for each thread within the grid
|
||||
*/
|
||||
uint64_t offset = size * get_flat_id();
|
||||
source += offset;
|
||||
dest += offset;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
__syncthreads();
|
||||
start_time[wg_id] = wall_clock64();
|
||||
__syncthreads();
|
||||
// Ensures all RMA calls from the skip loops are completed
|
||||
if(is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
__syncthreads();
|
||||
// Capture the start time of each wavefront to identify the earliest one
|
||||
wf_start_time[wf_id] = wall_clock64();
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case GetTestType:
|
||||
rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case GetNBITestType:
|
||||
rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case PutTestType:
|
||||
rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case PutNBITestType:
|
||||
rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case PTestType:
|
||||
for (int s = 0; s < size; s++) {
|
||||
char val = s_buf[s];
|
||||
rocshmem_ctx_char_p(ctx, &r_buf[s], val, 1);
|
||||
char val = source[s];
|
||||
rocshmem_ctx_char_p(ctx, &dest[s], val, 1);
|
||||
}
|
||||
break;
|
||||
case GTestType:
|
||||
for (int s = 0; s < size; s++) {
|
||||
char ret = rocshmem_ctx_char_g(ctx, &s_buf[s], 1);
|
||||
r_buf[s] = ret;
|
||||
char ret = rocshmem_ctx_char_g(ctx, &source[s], 1);
|
||||
dest[s] = ret;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@@ -74,12 +96,28 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
|
||||
}
|
||||
}
|
||||
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
__syncthreads();
|
||||
if(is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
|
||||
/**
|
||||
* End time of the last wavefront is recorded by overwriting
|
||||
* the value previously set by earlier wavefronts.
|
||||
*/
|
||||
end_time[wg_id] = wall_clock64();
|
||||
|
||||
// Find the earliest start time
|
||||
int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1;
|
||||
for (int i = num_wfs / 2; i > 0; i >>= 1 ) {
|
||||
if(t_id < i) {
|
||||
wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
if (t_id == 0) {
|
||||
start_time[wg_id] = wf_start_time[0];
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
@@ -90,18 +128,35 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
PrimitiveTester::PrimitiveTester(TesterArguments args) : Tester(args) {
|
||||
s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
|
||||
r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
|
||||
size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs;
|
||||
source = (char *)rocshmem_malloc(buff_size);
|
||||
dest = (char *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source == nullptr || dest == nullptr) {
|
||||
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
|
||||
if (source) {
|
||||
rocshmem_free(source);
|
||||
}
|
||||
if (dest) {
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < buff_size; i++) {
|
||||
source[i] = static_cast<char>('a' + i % 26);
|
||||
}
|
||||
}
|
||||
|
||||
PrimitiveTester::~PrimitiveTester() {
|
||||
rocshmem_free(s_buf);
|
||||
rocshmem_free(r_buf);
|
||||
rocshmem_free(source);
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
|
||||
void PrimitiveTester::resetBuffers(uint64_t size) {
|
||||
memset(s_buf, '0', args.max_msg_size * args.wg_size);
|
||||
memset(r_buf, '1', args.max_msg_size * args.wg_size);
|
||||
size_t buff_size = size * args.wg_size * args.num_wgs;
|
||||
memset(dest, '1', buff_size);
|
||||
}
|
||||
|
||||
void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
@@ -109,11 +164,11 @@ void PrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(PrimitiveTest, gridSize, blockSize, shared_bytes, stream,
|
||||
loop, args.skip, start_time, end_time, s_buf, r_buf,
|
||||
size, _type, _shmem_context);
|
||||
loop, args.skip, start_time, end_time, source, dest,
|
||||
size, _type, _shmem_context, wf_size);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop;
|
||||
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
|
||||
num_timed_msgs = loop * gridSize.x * blockSize.x;
|
||||
}
|
||||
|
||||
void PrimitiveTester::verifyResults(uint64_t size) {
|
||||
@@ -123,10 +178,12 @@ void PrimitiveTester::verifyResults(uint64_t size) {
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
|
||||
size_t buff_size = size * args.wg_size * args.num_wgs;
|
||||
for (uint64_t i = 0; i < buff_size; i++) {
|
||||
if (dest[i] != source[i]) {
|
||||
std::cerr << "Data validation error at idx " << i << std::endl;
|
||||
std::cerr << " Got " << dest[i] << ", Expected "
|
||||
<< source[i] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ class PrimitiveTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
char *s_buf = nullptr;
|
||||
char *r_buf = nullptr;
|
||||
char *source = nullptr;
|
||||
char *dest = nullptr;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "swarm_tester.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void GetSwarmTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size, ShmemContextType ctx_type) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
int provided;
|
||||
rocshmem_wg_init_thread(ROCSHMEM_THREAD_MULTIPLE, &provided);
|
||||
assert(provided == ROCSHMEM_THREAD_MULTIPLE);
|
||||
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int index = hipThreadIdx_x * size;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
rocshmem_ctx_getmem(ctx, &r_buf[index], &s_buf[index], size, 1);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// atomicAdd((unsigned long long *)&timer[hipBlockIdx_x],
|
||||
// rocshmem_timer() - start);
|
||||
|
||||
end_time[wg_id] = wall_clock64();
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
rocshmem_wg_finalize();
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
GetSwarmTester::GetSwarmTester(TesterArguments args) : PrimitiveTester(args) {}
|
||||
|
||||
GetSwarmTester::~GetSwarmTester() {}
|
||||
|
||||
void GetSwarmTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(GetSwarmTest, gridSize, blockSize, shared_bytes, stream,
|
||||
loop, args.skip, start_time, end_time, s_buf, r_buf, size,
|
||||
_shmem_context);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
|
||||
num_timed_msgs = loop * gridSize.x * blockSize.x;
|
||||
}
|
||||
|
||||
void GetSwarmTester::verifyResults(uint64_t size) {
|
||||
if (args.myid == 0) {
|
||||
for (uint64_t i = 0; i < size * args.wg_size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef _SWARM_TESTER_HPP_
|
||||
#define _SWARM_TESTER_HPP_
|
||||
|
||||
#include "primitive_tester.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void GetSwarmTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size);
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class GetSwarmTester : public PrimitiveTester {
|
||||
public:
|
||||
explicit GetSwarmTester(TesterArguments args);
|
||||
virtual ~GetSwarmTester();
|
||||
|
||||
protected:
|
||||
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) override;
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -32,43 +32,83 @@ rocshmem_team_t team_primitive_world_dup;
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size, TestType type,
|
||||
ShmemContextType ctx_type,
|
||||
long long int *end_time, char *source,
|
||||
char *dest, int size, TestType type,
|
||||
ShmemContextType ctx_type, int wf_size,
|
||||
rocshmem_team_t team) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
int t_id = get_flat_block_id();
|
||||
int wf_id = t_id / wf_size;
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_team_create_ctx(team, ctx_type, &ctx);
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
/**
|
||||
* Shared array to capture the start time for each wavefront
|
||||
* Max threads per block = 1024, wavefront size = 64 (in most GPUs)
|
||||
* Maximum array size required = 1024/64 = 16
|
||||
*/
|
||||
__shared__ long long int wf_start_time[16];
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
switch (type) {
|
||||
case TeamCtxGetTestType:
|
||||
rocshmem_ctx_getmem(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case TeamCtxGetNBITestType:
|
||||
rocshmem_ctx_getmem_nbi(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case TeamCtxPutTestType:
|
||||
rocshmem_ctx_putmem(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case TeamCtxPutNBITestType:
|
||||
rocshmem_ctx_putmem_nbi(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
/**
|
||||
* Calculate start index for each thread within the grid
|
||||
*/
|
||||
uint64_t offset = size * get_flat_id();
|
||||
source += offset;
|
||||
dest += offset;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
__syncthreads();
|
||||
// Ensures all RMA calls from the skip loops are completed
|
||||
if(is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
__syncthreads();
|
||||
// Capture the start time of each wavefront to identify the earliest one
|
||||
wf_start_time[wf_id] = wall_clock64();
|
||||
}
|
||||
switch (type) {
|
||||
case TeamCtxGetTestType:
|
||||
rocshmem_ctx_getmem(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case TeamCtxGetNBITestType:
|
||||
rocshmem_ctx_getmem_nbi(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case TeamCtxPutTestType:
|
||||
rocshmem_ctx_putmem(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case TeamCtxPutNBITestType:
|
||||
rocshmem_ctx_putmem_nbi(ctx, dest, source, size, 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if(is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
|
||||
end_time[wg_id] = wall_clock64();
|
||||
/**
|
||||
* End time of the last wavefront is recorded by overwriting
|
||||
* the value previously set by earlier wavefronts.
|
||||
*/
|
||||
end_time[wg_id] = wall_clock64();
|
||||
|
||||
// Find the earliest start time
|
||||
int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1;
|
||||
for (int i = num_wfs / 2; i > 0; i >>= 1 ) {
|
||||
if(t_id < i) {
|
||||
wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (t_id == 0) {
|
||||
start_time[wg_id] = wf_start_time[0];
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
@@ -80,18 +120,35 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti
|
||||
*****************************************************************************/
|
||||
TeamCtxPrimitiveTester::TeamCtxPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
|
||||
r_buf = (char *)rocshmem_malloc(args.max_msg_size * args.wg_size);
|
||||
size_t buff_size = args.max_msg_size * args.wg_size * args.num_wgs;
|
||||
source = (char *)rocshmem_malloc(buff_size);
|
||||
dest = (char *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source == nullptr || dest == nullptr) {
|
||||
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
|
||||
if (source) {
|
||||
rocshmem_free(source);
|
||||
}
|
||||
if (dest) {
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < buff_size; i++) {
|
||||
source[i] = static_cast<char>('a' + i % 26);
|
||||
}
|
||||
}
|
||||
|
||||
TeamCtxPrimitiveTester::~TeamCtxPrimitiveTester() {
|
||||
rocshmem_free(s_buf);
|
||||
rocshmem_free(r_buf);
|
||||
rocshmem_free(source);
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
|
||||
void TeamCtxPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
memset(s_buf, '0', args.max_msg_size * args.wg_size);
|
||||
memset(r_buf, '1', args.max_msg_size * args.wg_size);
|
||||
size_t buff_size = size * args.wg_size * args.num_wgs;
|
||||
memset(dest, '1', buff_size);
|
||||
}
|
||||
|
||||
void TeamCtxPrimitiveTester::preLaunchKernel() {
|
||||
@@ -107,12 +164,12 @@ void TeamCtxPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(TeamCtxPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time, s_buf,
|
||||
r_buf, size, _type, _shmem_context,
|
||||
stream, loop, args.skip, start_time, end_time, source,
|
||||
dest, size, _type, _shmem_context, wf_size,
|
||||
team_primitive_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
|
||||
num_timed_msgs = loop * gridSize.x * blockSize.x;
|
||||
}
|
||||
|
||||
void TeamCtxPrimitiveTester::postLaunchKernel() {
|
||||
@@ -124,10 +181,12 @@ void TeamCtxPrimitiveTester::verifyResults(uint64_t size) {
|
||||
(_type == TeamCtxGetTestType || _type == TeamCtxGetNBITestType) ? 0 : 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
|
||||
size_t buff_size = size * args.wg_size * args.num_wgs;
|
||||
for (uint64_t i = 0; i < buff_size; i++) {
|
||||
if (dest[i] != source[i]) {
|
||||
std::cerr << "Data validation error at idx " << i << std::endl;
|
||||
std::cerr << " Got " << dest[i] << ", Expected "
|
||||
<< source[i] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,8 +45,8 @@ class TeamCtxPrimitiveTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
char *s_buf = nullptr;
|
||||
char *r_buf = nullptr;
|
||||
char *source = nullptr;
|
||||
char *dest = nullptr;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -35,7 +35,6 @@
|
||||
#include "amo_standard_tester.hpp"
|
||||
#include "barrier_all_tester.hpp"
|
||||
#include "empty_tester.hpp"
|
||||
#include "extended_primitives.hpp"
|
||||
#include "ping_all_tester.hpp"
|
||||
#include "ping_pong_tester.hpp"
|
||||
#include "primitive_mr_tester.hpp"
|
||||
@@ -43,7 +42,6 @@
|
||||
#include "random_access_tester.hpp"
|
||||
#include "shmem_ptr_tester.hpp"
|
||||
#include "signaling_operations_tester.hpp"
|
||||
#include "swarm_tester.hpp"
|
||||
#include "sync_tester.hpp"
|
||||
#include "team_alltoall_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
@@ -51,7 +49,8 @@
|
||||
#include "team_ctx_primitive_tester.hpp"
|
||||
#include "team_fcollect_tester.hpp"
|
||||
#include "team_reduction_tester.hpp"
|
||||
#include "wave_level_primitives.hpp"
|
||||
#include "wavefront_primitives.hpp"
|
||||
#include "workgroup_primitives.hpp"
|
||||
|
||||
Tester::Tester(TesterArguments args) : args(args) {
|
||||
_type = (TestType)args.algorithm;
|
||||
@@ -66,6 +65,16 @@ Tester::Tester(TesterArguments args) : args(args) {
|
||||
CHECK_HIP(hipDeviceGetAttribute(&wall_clk_rate,
|
||||
hipDeviceAttributeWallClockRate, device_id));
|
||||
num_timers = args.num_wgs;
|
||||
switch (_type) {
|
||||
case WAVEGetTestType:
|
||||
case WAVEGetNBITestType:
|
||||
case WAVEPutTestType:
|
||||
case WAVEPutNBITestType:
|
||||
num_timers = args.num_wgs * num_warps;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
CHECK_HIP(hipMalloc((void**)&timer, sizeof(long long int) * num_timers));
|
||||
CHECK_HIP(hipMalloc((void**)&start_time, sizeof(long long int) * num_timers));
|
||||
CHECK_HIP(hipMalloc((void**)&end_time, sizeof(long long int) * num_timers));
|
||||
@@ -137,10 +146,6 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "G Test ###" << std::endl;
|
||||
testers.push_back(new PrimitiveTester(args));
|
||||
return testers;
|
||||
case GetSwarmTestType:
|
||||
if (rank == 0) std::cout << "Get Swarm ###" << std::endl;
|
||||
testers.push_back(new GetSwarmTester(args));
|
||||
return testers;
|
||||
case TeamReductionTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "All-to-All Team-based Reduction ###" << std::endl;
|
||||
@@ -309,22 +314,22 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
case WGGetTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Blocking WG level Gets ###" << std::endl;
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
testers.push_back(new WorkGroupPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGGetNBITestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Non-Blocking WG level Gets ###" << std::endl;
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
testers.push_back(new WorkGroupPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Blocking WG level Puts ###" << std::endl;
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
testers.push_back(new WorkGroupPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutNBITestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Non-Blocking WG level Puts ###" << std::endl;
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
testers.push_back(new WorkGroupPrimitiveTester(args));
|
||||
return testers;
|
||||
case PutNBIMRTestType:
|
||||
if (rank == 0)
|
||||
@@ -334,22 +339,22 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
case WAVEGetTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Blocking WAVE level Gets ###" << std::endl;
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
testers.push_back(new WaveFrontPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEGetNBITestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Non-Blocking WAVE level Gets ###" << std::endl;
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
testers.push_back(new WaveFrontPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutTestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Blocking WAVE level Puts ###" << std::endl;
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
testers.push_back(new WaveFrontPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutNBITestType:
|
||||
if (rank == 0)
|
||||
std::cout << "Non-Blocking WAVE level Puts ###" << std::endl;
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
testers.push_back(new WaveFrontPrimitiveTester(args));
|
||||
return testers;
|
||||
case PutSignalTestType:
|
||||
if (rank == 0) std::cout << "Putmem Signal ###" << std::endl;
|
||||
@@ -495,18 +500,21 @@ void Tester::print(uint64_t size) {
|
||||
*/
|
||||
uint64_t total_size = size * num_timed_msgs;
|
||||
double timer_avg = timerAvgInMicroseconds();
|
||||
double latency_avg = timer_avg / num_timed_msgs;
|
||||
double avg_msg_rate = num_timed_msgs / (timer_avg / 1e6);
|
||||
|
||||
double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time);
|
||||
double time_s = time_us / 1e6;
|
||||
|
||||
double latency_avg = time_us / num_timed_msgs;
|
||||
|
||||
double avg_msg_rate = num_timed_msgs / time_s;
|
||||
|
||||
double bandwidth_avg_gbs =
|
||||
static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30);
|
||||
|
||||
float total_kern_time_ms;
|
||||
CHECK_HIP(hipEventElapsedTime(&total_kern_time_ms, start_event, stop_event));
|
||||
float total_kern_time_s = total_kern_time_ms / 1000;
|
||||
|
||||
double time_us = gpuCyclesToMicroseconds(max_end_time - min_start_time);
|
||||
double time_s = time_us / 1e6;
|
||||
double bandwidth_avg_gbs =
|
||||
static_cast<double>(total_size * bw_factor) / time_s / pow(2, 30);
|
||||
|
||||
int field_width = 20;
|
||||
int float_precision = 2;
|
||||
|
||||
|
||||
@@ -38,61 +38,60 @@ enum TestType {
|
||||
GetNBITestType = 1,
|
||||
PutTestType = 2,
|
||||
PutNBITestType = 3,
|
||||
GetSwarmTestType = 4,
|
||||
AMO_FAddTestType = 5,
|
||||
AMO_FIncTestType = 6,
|
||||
AMO_FetchTestType = 7,
|
||||
AMO_FCswapTestType = 8,
|
||||
AMO_AddTestType = 9,
|
||||
AMO_IncTestType = 10,
|
||||
AMO_CswapTestType = 11,
|
||||
InitTestType = 12,
|
||||
PingPongTestType = 13,
|
||||
RandomAccessTestType = 14,
|
||||
BarrierAllTestType = 15,
|
||||
SyncAllTestType = 16,
|
||||
SyncTestType = 17,
|
||||
CollectTestType = 18,
|
||||
TeamFCollectTestType = 19,
|
||||
TeamAllToAllTestType = 20,
|
||||
AllToAllsTestType = 21,
|
||||
ShmemPtrTestType = 22,
|
||||
PTestType = 23,
|
||||
GTestType = 24,
|
||||
WGGetTestType = 25,
|
||||
WGGetNBITestType = 26,
|
||||
WGPutTestType = 27,
|
||||
WGPutNBITestType = 28,
|
||||
WAVEGetTestType = 29,
|
||||
WAVEGetNBITestType = 30,
|
||||
WAVEPutTestType = 31,
|
||||
WAVEPutNBITestType = 32,
|
||||
TeamBroadcastTestType = 33,
|
||||
TeamReductionTestType = 34,
|
||||
TeamCtxGetTestType = 35,
|
||||
TeamCtxGetNBITestType = 36,
|
||||
TeamCtxPutTestType = 37,
|
||||
TeamCtxPutNBITestType = 38,
|
||||
TeamCtxInfraTestType = 39,
|
||||
PutNBIMRTestType = 40,
|
||||
AMO_SetTestType = 41,
|
||||
AMO_SwapTestType = 42,
|
||||
AMO_FetchAndTestType = 43,
|
||||
AMO_FetchOrTestType = 44,
|
||||
AMO_FetchXorTestType = 45,
|
||||
AMO_AndTestType = 46,
|
||||
AMO_OrTestType = 47,
|
||||
AMO_XorTestType = 48,
|
||||
PingAllTestType = 49,
|
||||
PutSignalTestType = 50,
|
||||
WGPutSignalTestType = 51,
|
||||
WAVEPutSignalTestType = 52,
|
||||
PutSignalNBITestType = 53,
|
||||
WGPutSignalNBITestType = 54,
|
||||
WAVEPutSignalNBITestType = 55,
|
||||
SignalFetchTestType = 56,
|
||||
WGSignalFetchTestType = 57,
|
||||
WAVESignalFetchTestType = 58,
|
||||
AMO_FAddTestType = 4,
|
||||
AMO_FIncTestType = 5,
|
||||
AMO_FetchTestType = 6,
|
||||
AMO_FCswapTestType = 7,
|
||||
AMO_AddTestType = 8,
|
||||
AMO_IncTestType = 9,
|
||||
AMO_CswapTestType = 10,
|
||||
InitTestType = 11,
|
||||
PingPongTestType = 12,
|
||||
RandomAccessTestType = 13,
|
||||
BarrierAllTestType = 14,
|
||||
SyncAllTestType = 15,
|
||||
SyncTestType = 16,
|
||||
CollectTestType = 17,
|
||||
TeamFCollectTestType = 18,
|
||||
TeamAllToAllTestType = 19,
|
||||
AllToAllsTestType = 20,
|
||||
ShmemPtrTestType = 21,
|
||||
PTestType = 22,
|
||||
GTestType = 23,
|
||||
WGGetTestType = 24,
|
||||
WGGetNBITestType = 25,
|
||||
WGPutTestType = 26,
|
||||
WGPutNBITestType = 27,
|
||||
WAVEGetTestType = 28,
|
||||
WAVEGetNBITestType = 29,
|
||||
WAVEPutTestType = 30,
|
||||
WAVEPutNBITestType = 31,
|
||||
TeamBroadcastTestType = 32,
|
||||
TeamReductionTestType = 33,
|
||||
TeamCtxGetTestType = 34,
|
||||
TeamCtxGetNBITestType = 35,
|
||||
TeamCtxPutTestType = 36,
|
||||
TeamCtxPutNBITestType = 37,
|
||||
TeamCtxInfraTestType = 38,
|
||||
PutNBIMRTestType = 39,
|
||||
AMO_SetTestType = 40,
|
||||
AMO_SwapTestType = 41,
|
||||
AMO_FetchAndTestType = 42,
|
||||
AMO_FetchOrTestType = 43,
|
||||
AMO_FetchXorTestType = 44,
|
||||
AMO_AndTestType = 45,
|
||||
AMO_OrTestType = 46,
|
||||
AMO_XorTestType = 47,
|
||||
PingAllTestType = 48,
|
||||
PutSignalTestType = 49,
|
||||
WGPutSignalTestType = 50,
|
||||
WAVEPutSignalTestType = 51,
|
||||
PutSignalNBITestType = 52,
|
||||
WGPutSignalNBITestType = 53,
|
||||
WAVEPutSignalNBITestType = 54,
|
||||
SignalFetchTestType = 55,
|
||||
WGSignalFetchTestType = 56,
|
||||
WAVESignalFetchTestType = 57,
|
||||
};
|
||||
|
||||
enum OpType { PutType = 0, GetType = 1 };
|
||||
|
||||
@@ -109,16 +109,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case PutNBIMRTestType:
|
||||
min_msg_size = max_msg_size;
|
||||
break;
|
||||
case WAVEGetTestType:
|
||||
case WAVEGetNBITestType:
|
||||
case WAVEPutTestType:
|
||||
case WAVEPutNBITestType:
|
||||
case WGGetTestType:
|
||||
case WGGetNBITestType:
|
||||
case WGPutTestType:
|
||||
case WGPutNBITestType:
|
||||
min_msg_size = 4;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include <cstdint>
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
class TesterArguments {
|
||||
public:
|
||||
|
||||
+61
-43
@@ -20,7 +20,7 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "wave_level_primitives.hpp"
|
||||
#include "wavefront_primitives.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
@@ -31,54 +31,56 @@ using namespace rocshmem;
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void WaveLevelPrimitiveTest(int loop, int skip,
|
||||
__global__ void WaveFrontPrimitiveTest(int loop, int skip,
|
||||
long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size, TestType type,
|
||||
ShmemContextType ctx_type, int wf_size) {
|
||||
long long int *end_time, char *source,
|
||||
char *dest, int size, TestType type,
|
||||
ShmemContextType ctx_type,
|
||||
int wf_size) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
/**
|
||||
* Calculate start index for each wavefront for tiled version
|
||||
* If the number of wavefronts is greater than 1, this kernel performs a
|
||||
* tiled functional test
|
||||
*/
|
||||
// Calculate start index for each wavefront
|
||||
int wf_id = get_flat_block_id() / wf_size;
|
||||
int wg_offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size);
|
||||
int idx = wf_id * size + wg_offset;
|
||||
s_buf += idx;
|
||||
r_buf += idx;
|
||||
int wg_offset = wg_id * ((get_flat_block_size() - 1 ) / wf_size + 1);
|
||||
int idx = wf_id + wg_offset;
|
||||
int offset = size * idx;
|
||||
source += offset;
|
||||
dest += offset;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
// Ensures all RMA calls from the skip loops are completed
|
||||
if(is_thread_zero_in_wave()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
__syncthreads();
|
||||
start_time[idx] = wall_clock64();
|
||||
}
|
||||
switch (type) {
|
||||
case WAVEGetTestType:
|
||||
rocshmem_ctx_getmem_wave(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem_wave(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WAVEGetNBITestType:
|
||||
rocshmem_ctx_getmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem_nbi_wave(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WAVEPutTestType:
|
||||
rocshmem_ctx_putmem_wave(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem_wave(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WAVEPutNBITestType:
|
||||
rocshmem_ctx_putmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem_nbi_wave(ctx, dest, source, size, 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[hipBlockIdx_x] = wall_clock64();
|
||||
if (is_thread_zero_in_wave()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
end_time[idx] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
@@ -88,48 +90,64 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip,
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args)
|
||||
WaveFrontPrimitiveTester::WaveFrontPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = static_cast<int*>(
|
||||
rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps));
|
||||
r_buf = static_cast<int*>(
|
||||
rocshmem_malloc(args.max_msg_size * args.num_wgs * num_warps));
|
||||
size_t buff_size = args.max_msg_size * args.num_wgs * num_warps;
|
||||
source = (char *)rocshmem_malloc(buff_size);
|
||||
dest = (char *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source == nullptr || dest == nullptr) {
|
||||
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
|
||||
if (source) {
|
||||
rocshmem_free(source);
|
||||
}
|
||||
if (dest) {
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < buff_size; i++) {
|
||||
source[i] = static_cast<char>('a' + i % 26);
|
||||
}
|
||||
}
|
||||
|
||||
WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
|
||||
rocshmem_free(s_buf);
|
||||
rocshmem_free(r_buf);
|
||||
WaveFrontPrimitiveTester::~WaveFrontPrimitiveTester() {
|
||||
rocshmem_free(source);
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
num_elems = (size * args.num_wgs * num_warps) / sizeof(int);
|
||||
std::iota(s_buf, s_buf + num_elems, 0);
|
||||
memset(r_buf, 0, size * args.num_wgs * num_warps);
|
||||
void WaveFrontPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
size_t buff_size = size * args.num_wgs * num_warps;
|
||||
memset(dest, '1', buff_size);
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
void WaveFrontPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
hipLaunchKernelGGL(WaveFrontPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time,
|
||||
(char*)s_buf, (char*)r_buf, size, _type, _shmem_context,
|
||||
source, dest, size, _type, _shmem_context,
|
||||
wf_size);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x * num_warps;
|
||||
num_timed_msgs = loop * gridSize.x * num_warps;
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
|
||||
void WaveFrontPrimitiveTester::verifyResults(uint64_t size) {
|
||||
int check_id = (_type == WAVEGetTestType || _type == WAVEGetNBITestType)
|
||||
? 0
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
if (r_buf[i] != i) {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
|
||||
size_t buff_size = size * args.num_wgs * num_warps;
|
||||
for (size_t i = 0; i < buff_size; i++) {
|
||||
if (dest[i] != source[i]) {
|
||||
std::cerr << "Data validation error at idx " << i << std::endl;
|
||||
std::cerr << " Got " << dest[i] << ", Expected "
|
||||
<< source[i] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
+5
-7
@@ -24,15 +24,14 @@
|
||||
#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class WaveLevelPrimitiveTester : public Tester {
|
||||
class WaveFrontPrimitiveTester : public Tester {
|
||||
public:
|
||||
explicit WaveLevelPrimitiveTester(TesterArguments args);
|
||||
virtual ~WaveLevelPrimitiveTester();
|
||||
explicit WaveFrontPrimitiveTester(TesterArguments args);
|
||||
virtual ~WaveFrontPrimitiveTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
@@ -42,9 +41,8 @@ class WaveLevelPrimitiveTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
int *s_buf = nullptr;
|
||||
int *r_buf = nullptr;
|
||||
int num_elems = 0;
|
||||
char *source = nullptr;
|
||||
char *dest = nullptr;
|
||||
};
|
||||
|
||||
#endif
|
||||
+55
-37
@@ -20,7 +20,7 @@
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "extended_primitives.hpp"
|
||||
#include "workgroup_primitives.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
@@ -31,51 +31,51 @@ using namespace rocshmem;
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void ExtendedPrimitiveTest(int loop, int skip,
|
||||
__global__ void WorkGroupPrimitiveTest(int loop, int skip,
|
||||
long long int *start_time,
|
||||
long long int *end_time, char *s_buf,
|
||||
char *r_buf, int size, TestType type,
|
||||
long long int *end_time, char *source,
|
||||
char *dest, int size, TestType type,
|
||||
ShmemContextType ctx_type) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int wg_id = get_flat_grid_id();
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
/**
|
||||
* Calculate start index for each work group for tiled version
|
||||
* If the number of work groups is greater than 1, this kernel performs a
|
||||
* tiled functional test
|
||||
*/
|
||||
uint64_t idx = size * get_flat_grid_id();
|
||||
s_buf += idx;
|
||||
r_buf += idx;
|
||||
// Calculate start index for each work group
|
||||
uint64_t offset = size * wg_id;
|
||||
source += offset;
|
||||
dest += offset;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
// Ensures all RMA calls from the skip loops are completed
|
||||
if (is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
}
|
||||
__syncthreads();
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case WGGetTestType:
|
||||
rocshmem_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem_wg(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WGGetNBITestType:
|
||||
rocshmem_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_getmem_nbi_wg(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WGPutTestType:
|
||||
rocshmem_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem_wg(ctx, dest, source, size, 1);
|
||||
break;
|
||||
case WGPutNBITestType:
|
||||
rocshmem_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
|
||||
rocshmem_ctx_putmem_nbi_wg(ctx, dest, source, size, 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
if (is_thread_zero_in_block()) {
|
||||
rocshmem_ctx_quiet(ctx);
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
@@ -86,45 +86,63 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip,
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args)
|
||||
WorkGroupPrimitiveTester::WorkGroupPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = static_cast<int*>(rocshmem_malloc(args.max_msg_size * args.num_wgs));
|
||||
r_buf = static_cast<int*>(rocshmem_malloc(args.max_msg_size * args.num_wgs));
|
||||
size_t buff_size = args.max_msg_size * args.num_wgs;
|
||||
source = (char *)rocshmem_malloc(buff_size);
|
||||
dest = (char *)rocshmem_malloc(buff_size);
|
||||
|
||||
if (source == nullptr || dest == nullptr) {
|
||||
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
|
||||
std::cerr << "source: " << source << ", dest: " << dest << std::endl;
|
||||
if (source) {
|
||||
rocshmem_free(source);
|
||||
}
|
||||
if (dest) {
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
rocshmem_global_exit(1);
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < buff_size; i++) {
|
||||
source[i] = static_cast<char>('a' + i % 26);
|
||||
}
|
||||
}
|
||||
|
||||
ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
|
||||
rocshmem_free(s_buf);
|
||||
rocshmem_free(r_buf);
|
||||
WorkGroupPrimitiveTester::~WorkGroupPrimitiveTester() {
|
||||
rocshmem_free(source);
|
||||
rocshmem_free(dest);
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
num_elems = (size * args.num_wgs) / sizeof(int);
|
||||
std::iota(s_buf, s_buf + num_elems, 0);
|
||||
memset(r_buf, 0, size * args.num_wgs);
|
||||
void WorkGroupPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
size_t buff_size = size * args.num_wgs;
|
||||
memset(dest, '1', buff_size);
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
void WorkGroupPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
hipLaunchKernelGGL(WorkGroupPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time,
|
||||
(char*)s_buf, (char*)r_buf, size, _type, _shmem_context);
|
||||
source, dest, size, _type, _shmem_context);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
|
||||
void WorkGroupPrimitiveTester::verifyResults(uint64_t size) {
|
||||
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType)
|
||||
? 0
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
if (r_buf[i] != i) {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
|
||||
size_t buff_size = size * args.num_wgs;
|
||||
for (size_t i = 0; i < buff_size; i++) {
|
||||
if (dest[i] != source[i]) {
|
||||
std::cerr << "Data validation error at idx " << i << std::endl;
|
||||
std::cerr << " Got " << dest[i] << ", Expected "
|
||||
<< source[i] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
+5
-7
@@ -24,15 +24,14 @@
|
||||
#define _EXTENDED_PRIMITIVES_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class ExtendedPrimitiveTester : public Tester {
|
||||
class WorkGroupPrimitiveTester : public Tester {
|
||||
public:
|
||||
explicit ExtendedPrimitiveTester(TesterArguments args);
|
||||
virtual ~ExtendedPrimitiveTester();
|
||||
explicit WorkGroupPrimitiveTester(TesterArguments args);
|
||||
virtual ~WorkGroupPrimitiveTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
@@ -42,9 +41,8 @@ class ExtendedPrimitiveTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
int *s_buf = nullptr;
|
||||
int *r_buf = nullptr;
|
||||
int num_elems = 0;
|
||||
char *source = nullptr;
|
||||
char *dest = nullptr;
|
||||
};
|
||||
|
||||
#endif
|
||||
Yeni konuda referans
Bir kullanıcı engelle