Merge pull request #40 from avinashkethineedi/functional_tests/puts_gets
Functional tests {wave, wg} puts and gets
Этот коммит содержится в:
@@ -55,6 +55,54 @@ case $2 in
|
||||
echo "putnbi_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -s 1048576 -a 3 > $3/putnbi_n2_w1_z1_1MB.log
|
||||
check putnbi_n2_w1_z1_1MB
|
||||
# Functional tests for the work-group (wg_*) and wavefront (wave_*) level
# gets/puts with 1 MB messages. The label given to echo/check is kept identical
# to the log file name and to the -w/-z values actually passed to the tester,
# so a reported failure points at the right log.

# Work-group level get / getnbi / put / putnbi (-a 28..31), -w 1.
echo "wg_get_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 > $3/wg_get_n2_w1_z64_1MB.log
check wg_get_n2_w1_z64_1MB
echo "wg_getnbi_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_n2_w1_z64_1MB.log
check wg_getnbi_n2_w1_z64_1MB
echo "wg_put_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 > $3/wg_put_n2_w1_z64_1MB.log
check wg_put_n2_w1_z64_1MB
echo "wg_putnbi_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z64_1MB.log
check wg_putnbi_n2_w1_z64_1MB

# Tiled work-group level variants: same algorithms, run with -w 2.
echo "wg_get_tiled_n2_w2_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 28 > $3/wg_get_tiled_n2_w2_z64_1MB.log
check wg_get_tiled_n2_w2_z64_1MB
echo "wg_getnbi_tiled_n2_w2_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_tiled_n2_w2_z64_1MB.log
check wg_getnbi_tiled_n2_w2_z64_1MB
echo "wg_put_tiled_n2_w2_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 30 > $3/wg_put_tiled_n2_w2_z64_1MB.log
check wg_put_tiled_n2_w2_z64_1MB
echo "wg_putnbi_tiled_n2_w2_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_tiled_n2_w2_z64_1MB.log
check wg_putnbi_tiled_n2_w2_z64_1MB

# Wavefront level get / getnbi / put / putnbi (-a 32..35), -w 1 -z 64.
echo "wave_get_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z64_1MB.log
check wave_get_n2_w1_z64_1MB
echo "wave_getnbi_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 33 > $3/wave_getnbi_n2_w1_z64_1MB.log
check wave_getnbi_n2_w1_z64_1MB
echo "wave_put_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 34 > $3/wave_put_n2_w1_z64_1MB.log
check wave_put_n2_w1_z64_1MB
echo "wave_putnbi_n2_w1_z64_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 35 > $3/wave_putnbi_n2_w1_z64_1MB.log
check wave_putnbi_n2_w1_z64_1MB

# Tiled wavefront level variants: run with -w 2 -z 128.
echo "wave_get_tiled_n2_w2_z128_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 32 > $3/wave_get_tiled_n2_w2_z128_1MB.log
check wave_get_tiled_n2_w2_z128_1MB
echo "wave_getnbi_tiled_n2_w2_z128_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 33 > $3/wave_getnbi_tiled_n2_w2_z128_1MB.log
check wave_getnbi_tiled_n2_w2_z128_1MB
echo "wave_put_tiled_n2_w2_z128_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 34 > $3/wave_put_tiled_n2_w2_z128_1MB.log
check wave_put_tiled_n2_w2_z128_1MB
echo "wave_putnbi_tiled_n2_w2_z128_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 35 > $3/wave_putnbi_tiled_n2_w2_z128_1MB.log
check wave_putnbi_tiled_n2_w2_z128_1MB
|
||||
echo "amofadd_n2_w1_z1"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -a 6 > $3/amofadd_n2_w1_z1.log
|
||||
check amofadd_n2_w1_z1
|
||||
@@ -97,6 +145,54 @@ case $2 in
|
||||
echo "putnbi_n2_w16_z128_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 3 > $3/putnbi_n2_w16_z128_8B.log
|
||||
check putnbi_n2_w16_z128_8B
|
||||
# Work-group level get / getnbi / put / putnbi (-a 28..31), 8-byte messages,
# run with -w 1 -z 64. Each test echoes its label, runs the tester under
# mpirun with output redirected to $3/<label>.log, then calls check <label>.
echo "wg_get_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 28 > $3/wg_get_n2_w1_z64_8B.log
|
||||
check wg_get_n2_w1_z64_8B
|
||||
echo "wg_getnbi_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 29 > $3/wg_getnbi_n2_w1_z64_8B.log
|
||||
check wg_getnbi_n2_w1_z64_8B
|
||||
echo "wg_put_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 30 > $3/wg_put_n2_w1_z64_8B.log
|
||||
check wg_put_n2_w1_z64_8B
|
||||
echo "wg_putnbi_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 31 > $3/wg_putnbi_n2_w1_z64_8B.log
|
||||
check wg_putnbi_n2_w1_z64_8B
|
||||
# Tiled work-group level variants: same -a values, run with -w 16 and
# ROC_SHMEM_MAX_NUM_CONTEXTS=16.
echo "wg_get_tiled_n2_w16_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 28 > $3/wg_get_tiled_n2_w16_z64_8B.log
|
||||
check wg_get_tiled_n2_w16_z64_8B
|
||||
echo "wg_getnbi_tiled_n2_w16_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 29 > $3/wg_getnbi_tiled_n2_w16_z64_8B.log
|
||||
check wg_getnbi_tiled_n2_w16_z64_8B
|
||||
echo "wg_put_tiled_n2_w16_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 30 > $3/wg_put_tiled_n2_w16_z64_8B.log
|
||||
check wg_put_tiled_n2_w16_z64_8B
|
||||
echo "wg_putnbi_tiled_n2_w16_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 31 > $3/wg_putnbi_tiled_n2_w16_z64_8B.log
|
||||
check wg_putnbi_tiled_n2_w16_z64_8B
|
||||
# Wavefront level get / getnbi / put / putnbi (-a 32..35), 8-byte messages,
# run with -w 1 -z 64.
echo "wave_get_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 32 > $3/wave_get_n2_w1_z64_8B.log
|
||||
check wave_get_n2_w1_z64_8B
|
||||
echo "wave_getnbi_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 33 > $3/wave_getnbi_n2_w1_z64_8B.log
|
||||
check wave_getnbi_n2_w1_z64_8B
|
||||
echo "wave_put_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 34 > $3/wave_put_n2_w1_z64_8B.log
|
||||
check wave_put_n2_w1_z64_8B
|
||||
echo "wave_putnbi_n2_w1_z64_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 35 > $3/wave_putnbi_n2_w1_z64_8B.log
|
||||
check wave_putnbi_n2_w1_z64_8B
|
||||
# Tiled wavefront level variants: run with -w 16 -z 128 and
# ROC_SHMEM_MAX_NUM_CONTEXTS=16.
echo "wave_get_tiled_n2_w16_z128_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 32 > $3/wave_get_tiled_n2_w16_z128_8B.log
|
||||
check wave_get_tiled_n2_w16_z128_8B
|
||||
echo "wave_getnbi_tiled_n2_w16_z128_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 33 > $3/wave_getnbi_tiled_n2_w16_z128_8B.log
|
||||
check wave_getnbi_tiled_n2_w16_z128_8B
|
||||
echo "wave_put_tiled_n2_w16_z128_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 34 > $3/wave_put_tiled_n2_w16_z128_8B.log
|
||||
check wave_put_tiled_n2_w16_z128_8B
|
||||
echo "wave_putnbi_tiled_n2_w16_z128_8B"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 35 > $3/wave_putnbi_tiled_n2_w16_z128_8B.log
|
||||
check wave_putnbi_tiled_n2_w16_z128_8B
|
||||
echo "amofadd_n2_w8_z1"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=8 mpirun -np 2 $1 -w 8 -z 1 -a 6 > $3/amofadd_n2_w8_z1.log
|
||||
check amofadd_n2_w8_z1
|
||||
|
||||
@@ -55,6 +55,7 @@ target_sources(
|
||||
shmem_ptr_tester.cpp
|
||||
extended_primitives.cpp
|
||||
empty_tester.cpp
|
||||
wave_level_primitives.cpp
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
|
||||
#include <roc_shmem/roc_shmem.hpp>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
@@ -37,7 +39,15 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
roc_shmem_wg_init();
|
||||
roc_shmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
/**
|
||||
* Calculate start index for each work group for tiled version
|
||||
* If the number of work groups is greater than 1, this kernel performs a
|
||||
* tiled functional test
|
||||
*/
|
||||
uint64_t start;
|
||||
uint64_t idx = size * get_flat_grid_id();
|
||||
s_buf += idx;
|
||||
r_buf += idx;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) start = roc_shmem_timer();
|
||||
@@ -75,8 +85,8 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
*****************************************************************************/
|
||||
ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
s_buf = static_cast<int*>(roc_shmem_malloc(args.max_msg_size * args.num_wgs));
|
||||
r_buf = static_cast<int*>(roc_shmem_malloc(args.max_msg_size * args.num_wgs));
|
||||
}
|
||||
|
||||
ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
|
||||
@@ -85,8 +95,9 @@ ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
memset(s_buf, '0', args.max_msg_size * args.wg_size);
|
||||
memset(r_buf, '1', args.max_msg_size * args.wg_size);
|
||||
num_elems = (size * args.num_wgs) / sizeof(int);
|
||||
std::iota(s_buf, s_buf + num_elems, 0);
|
||||
memset(r_buf, 0, size * args.num_wgs);
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
@@ -94,24 +105,23 @@ void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
stream, loop, args.skip, timer, (char*)s_buf,
|
||||
(char*)r_buf, size, _type, _shmem_context);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
|
||||
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType ||
|
||||
_type == WAVEGetTestType)
|
||||
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType)
|
||||
? 0
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
if (r_buf[i] != i) {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0');
|
||||
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#define _EXTENDED_PRIMITIVES_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
@@ -41,8 +42,9 @@ class ExtendedPrimitiveTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
char *s_buf = nullptr;
|
||||
char *r_buf = nullptr;
|
||||
int *s_buf = nullptr;
|
||||
int *r_buf = nullptr;
|
||||
int num_elems = 0;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -52,10 +52,14 @@
|
||||
#include "team_ctx_infra_tester.hpp"
|
||||
#include "team_ctx_primitive_tester.hpp"
|
||||
#include "team_reduction_tester.hpp"
|
||||
#include "wave_level_primitives.hpp"
|
||||
|
||||
Tester::Tester(TesterArguments args) : args(args) {
|
||||
_type = (TestType)args.algorithm;
|
||||
_shmem_context = args.shmem_context;
|
||||
CHECK_HIP(hipGetDevice(&device_id));
|
||||
CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id));
|
||||
num_warps = (args.wg_size - 1) / deviceProps.warpSize + 1;
|
||||
CHECK_HIP(hipStreamCreate(&stream));
|
||||
CHECK_HIP(hipEventCreate(&start_event));
|
||||
CHECK_HIP(hipEventCreate(&stop_event));
|
||||
@@ -72,6 +76,11 @@ Tester::~Tester() {
|
||||
std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
int rank = args.myid;
|
||||
std::vector<Tester*> testers;
|
||||
hipDeviceProp_t deviceProps;
|
||||
int device_id, numWarps;
|
||||
CHECK_HIP(hipGetDevice(&device_id));
|
||||
CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id));
|
||||
numWarps = (args.wg_size - 1) / deviceProps.warpSize + 1;
|
||||
|
||||
if (rank == 0) std::cout << "*** Creating Test: ";
|
||||
|
||||
@@ -468,19 +477,35 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
testers.push_back(new ShmemPtrTester(args));
|
||||
return testers;
|
||||
case WGGetTestType:
|
||||
if (rank == 0) std::cout << "Blocking WG level Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGGetNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking WG level Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutTestType:
|
||||
if (rank == 0) std::cout << "Blocking WG level Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking WG level Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case PutNBIMRTestType:
|
||||
@@ -488,6 +513,38 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
std::cout << "Non-Blocking Put message rate***" << std::endl;
|
||||
testers.push_back(new PrimitiveMRTester(args));
|
||||
return testers;
|
||||
case WAVEGetTestType:
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || numWarps > 1)
|
||||
std::cout << "Tiled Blocking WAVE level Gets***" << std::endl;
|
||||
else std::cout << "Blocking WAVE level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEGetNBITestType:
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || numWarps > 1)
|
||||
std::cout << "Tiled Non-Blocking WAVE level Gets***" << std::endl;
|
||||
else std::cout << "Non-Blocking WAVE level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutTestType:
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || numWarps > 1)
|
||||
std::cout << "Tiled Blocking WAVE level Puts***" << std::endl;
|
||||
else std::cout << "Blocking WAVE level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutNBITestType:
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || numWarps > 1)
|
||||
std::cout << "Tiled Non-Blocking WAVE level Puts***" << std::endl;
|
||||
else std::cout << "Non-Blocking WAVE level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
default:
|
||||
if (rank == 0) std::cout << "Unknown***" << std::endl;
|
||||
testers.push_back(new PrimitiveTester(args));
|
||||
@@ -560,10 +617,32 @@ void Tester::execute() {
|
||||
// data validation
|
||||
verifyResults(size);
|
||||
|
||||
/**
|
||||
* Adjust size for *_wg and *_wave functions
|
||||
*/
|
||||
uint64_t size_ = size;
|
||||
TestType type = (TestType)args.algorithm;
|
||||
switch (type) {
|
||||
case WAVEGetTestType:
|
||||
case WAVEGetNBITestType:
|
||||
case WAVEPutTestType:
|
||||
case WAVEPutNBITestType:
|
||||
size_ *= (args.num_wgs * num_warps);
|
||||
break;
|
||||
case WGGetTestType:
|
||||
case WGGetNBITestType:
|
||||
case WGPutTestType:
|
||||
case WGPutNBITestType:
|
||||
size_ *= args.num_wgs;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
if (_type != TeamCtxInfraTestType) {
|
||||
print(size);
|
||||
print(size_);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -655,6 +734,10 @@ uint64_t Tester::gpuCyclesToMicroseconds(uint64_t cycles) {
|
||||
uint64_t Tester::timerAvgInMicroseconds() {
|
||||
uint64_t sum = 0;
|
||||
|
||||
/**
|
||||
* TODO: (bpotter/avinash) Modify the calculation for the Tiled version of
|
||||
* puts and gets at wavefront level
|
||||
*/
|
||||
for (int i = 0; i < args.num_wgs; i++) {
|
||||
sum += gpuCyclesToMicroseconds(timer[i]);
|
||||
}
|
||||
|
||||
@@ -117,7 +117,9 @@ class Tester {
|
||||
|
||||
int num_msgs = 0;
|
||||
int num_timed_msgs = 0;
|
||||
int num_warps = 0;
|
||||
int bw_factor = 1;
|
||||
int device_id = 0;
|
||||
|
||||
TesterArguments args;
|
||||
|
||||
@@ -125,6 +127,7 @@ class Tester {
|
||||
ShmemContextType _shmem_context = 8; // SHMEM_CTX_WP_PRIVATE
|
||||
|
||||
hipStream_t stream;
|
||||
hipDeviceProp_t deviceProps;
|
||||
|
||||
uint64_t *timer = nullptr;
|
||||
|
||||
|
||||
@@ -103,6 +103,16 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case PutNBIMRTestType:
|
||||
min_msg_size = max_msg_size;
|
||||
break;
|
||||
case WAVEGetTestType:
|
||||
case WAVEGetNBITestType:
|
||||
case WAVEPutTestType:
|
||||
case WAVEPutNBITestType:
|
||||
case WGGetTestType:
|
||||
case WGGetNBITestType:
|
||||
case WGPutTestType:
|
||||
case WGPutNBITestType:
|
||||
min_msg_size = 4;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "wave_level_primitives.hpp"
|
||||
|
||||
#include <roc_shmem/roc_shmem.hpp>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
/**
 * Device kernel for the wavefront-level put/get functional tests.
 *
 * Every wavefront operates on its own `size`-byte tile of s_buf / r_buf, so
 * launching more than one wavefront turns this into a tiled test. Each
 * iteration issues one roc_shmemx_ctx_{get,put}mem[_nbi]_wave call selected by
 * `type`, passing 1 as the final (PE) argument. The first `skip` iterations
 * are warm-up and are excluded from the timing.
 *
 * @param loop     Number of timed iterations.
 * @param skip     Number of untimed warm-up iterations executed first.
 * @param timer    Output array; thread 0 of each block stores the elapsed
 *                 roc_shmem_timer() ticks at index hipBlockIdx_x.
 * @param s_buf    Base of the source buffer, addressed in bytes.
 * @param r_buf    Base of the destination buffer, addressed in bytes.
 * @param size     Bytes moved per wavefront per call.
 * @param type     Which wavefront-level primitive to exercise.
 * @param ctx_type Options forwarded to roc_shmem_wg_ctx_create().
 * @param wf_size  Wavefront width in threads (the host passes
 *                 deviceProps.warpSize).
 */
__global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer,
                                       char *s_buf, char *r_buf, int size,
                                       TestType type, ShmemContextType ctx_type,
                                       int wf_size) {
  __shared__ roc_shmem_ctx_t ctx;
  roc_shmem_wg_init();
  roc_shmem_wg_ctx_create(ctx_type, &ctx);

  /**
   * Calculate start index for each wavefront for tiled version
   * If the number of wavefronts is greater than 1, this kernel performs a
   * tiled functional test
   */
  // Zero-initialised so the timer store below never reads an indeterminate
  // value when the loop does not reach iteration `skip` (loop == 0).
  uint64_t start = 0;

  // Offsets are computed in 64 bits (as ExtendedPrimitiveTest does) so that
  // size * wavefront-count cannot overflow a 32-bit int.
  const uint64_t tile_bytes = static_cast<uint64_t>(size);
  const uint64_t wf_id = get_flat_block_id() / wf_size;
  // Round up, matching the host-side num_warps = (wg_size - 1) / warpSize + 1
  // used to size and verify the buffers; flooring here would make tiles of
  // neighbouring work-groups overlap when the block size is not a multiple of
  // the wavefront size.
  const uint64_t block_threads = get_flat_block_size();
  const uint64_t wfs_per_wg = (block_threads + wf_size - 1) / wf_size;
  const uint64_t wg_offset = tile_bytes * get_flat_grid_id() * wfs_per_wg;
  const uint64_t idx = wf_id * tile_bytes + wg_offset;
  s_buf += idx;
  r_buf += idx;

  for (int i = 0; i < loop + skip; i++) {
    // Start the clock once the warm-up iterations are done.
    if (i == skip) start = roc_shmem_timer();

    switch (type) {
      case WAVEGetTestType:
        roc_shmemx_ctx_getmem_wave(ctx, r_buf, s_buf, size, 1);
        break;
      case WAVEGetNBITestType:
        roc_shmemx_ctx_getmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
        break;
      case WAVEPutTestType:
        roc_shmemx_ctx_putmem_wave(ctx, r_buf, s_buf, size, 1);
        break;
      case WAVEPutNBITestType:
        roc_shmemx_ctx_putmem_nbi_wave(ctx, r_buf, s_buf, size, 1);
        break;
      default:
        break;
    }
  }

  roc_shmem_ctx_quiet(ctx);

  // One timing sample per block, recorded by its first thread.
  if (hipThreadIdx_x == 0) {
    timer[hipBlockIdx_x] = roc_shmem_timer() - start;
  }

  roc_shmem_wg_ctx_destroy(&ctx);
  roc_shmem_wg_finalize();
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
/**
 * Allocates the source and destination buffers with roc_shmem_malloc().
 *
 * Each buffer holds one max_msg_size tile per wavefront, i.e.
 * max_msg_size * num_wgs * num_warps bytes; num_warps is the value the
 * Tester base constructor has already computed.
 */
WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args)
    : Tester(args) {
  const auto buffer_bytes = args.max_msg_size * args.num_wgs * num_warps;

  s_buf = static_cast<int*>(roc_shmem_malloc(buffer_bytes));
  r_buf = static_cast<int*>(roc_shmem_malloc(buffer_bytes));
}
|
||||
|
||||
/**
 * Returns both buffers obtained in the constructor via roc_shmem_free(),
 * source buffer first.
 */
WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
  roc_shmem_free(s_buf);  // source tiles
  roc_shmem_free(r_buf);  // destination tiles
}
|
||||
|
||||
/**
 * Prepares the buffers for a run with `size` bytes per wavefront.
 *
 * The active region spans size * num_wgs * num_warps bytes. The source is
 * filled with the int sequence 0, 1, 2, ... and the destination is zeroed, so
 * verifyResults() can compare each received element against its own index.
 * num_elems records how many ints the active region contains.
 */
void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) {
  const auto active_bytes = size * args.num_wgs * num_warps;

  num_elems = active_bytes / sizeof(int);
  memset(r_buf, 0, active_bytes);
  std::iota(s_buf, s_buf + num_elems, 0);
}
|
||||
|
||||
/**
 * Launches WaveLevelPrimitiveTest on `stream` and records message counts.
 *
 * The int buffers are handed to the kernel as byte pointers because the
 * kernel offsets them in bytes. Every wavefront of every block sends one
 * message per iteration, hence the gridSize.x * num_warps factor in the
 * counters.
 *
 * @param gridSize  Grid dimensions for the launch.
 * @param blockSize Block dimensions for the launch.
 * @param loop      Number of timed iterations (args.skip warm-ups are added).
 * @param size      Bytes transferred per wavefront per call.
 */
void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
                                            int loop, uint64_t size) {
  size_t shared_bytes = 0;
  char *src_bytes = reinterpret_cast<char*>(s_buf);
  char *dst_bytes = reinterpret_cast<char*>(r_buf);

  hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes,
                     stream, loop, args.skip, timer, src_bytes, dst_bytes,
                     size, _type, _shmem_context, deviceProps.warpSize);

  num_timed_msgs = loop * gridSize.x * num_warps;
  num_msgs = (loop + args.skip) * gridSize.x * num_warps;
}
|
||||
|
||||
void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
|
||||
int check_id = (_type == WAVEGetTestType || _type == WAVEGetNBITestType)
|
||||
? 0
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < num_elems; i++) {
|
||||
if (r_buf[i] != i) {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
|
||||
#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
*****************************************************************************/
|
||||
class WaveLevelPrimitiveTester : public Tester {
|
||||
public:
|
||||
explicit WaveLevelPrimitiveTester(TesterArguments args);
|
||||
virtual ~WaveLevelPrimitiveTester();
|
||||
|
||||
protected:
|
||||
virtual void resetBuffers(uint64_t size) override;
|
||||
|
||||
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
uint64_t size) override;
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
int *s_buf = nullptr;
|
||||
int *r_buf = nullptr;
|
||||
int num_elems = 0;
|
||||
};
|
||||
|
||||
#endif
|
||||
Ссылка в новой задаче
Block a user