diff --git a/scripts/functional_tests/driver.sh b/scripts/functional_tests/driver.sh index f65a1aeadd..e0e0d45fa6 100755 --- a/scripts/functional_tests/driver.sh +++ b/scripts/functional_tests/driver.sh @@ -55,6 +55,54 @@ case $2 in echo "putnbi_n2_w1_z1_1MB" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -s 1048576 -a 3 > $3/putnbi_n2_w1_z1_1MB.log check putnbi_n2_w1_z1_1MB + echo "wg_get_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 > $3/wg_get_n2_w1_z64_1MB.log + check wg_get_n2_w1_z1_1MB + echo "wg_getnbi_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_n2_w1_z64_1MB.log + check wg_getnbi_n2_w1_z1_1MB + echo "wg_put_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 > $3/wg_put_n2_w1_z64_1MB.log + check wg_put_n2_w1_z1_1MB + echo "wg_putnbi_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z64_1MB.log + check wg_putnbi_n2_w1_z1_1MB + echo "wg_get_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 28 > $3/wg_get_tiled_n2_w2_z64_1MB.log + check wg_get_tiled_n2_w1_z1_1MB + echo "wg_getnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_tiled_n2_w2_z64_1MB.log + check wg_getnbi_tiled_n2_w1_z1_1MB + echo "wg_put_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 30 > $3/wg_put_tiled_n2_w2_z64_1MB.log + check wg_put_tiled_n2_w1_z1_1MB + echo "wg_putnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_tiled_n2_w2_z64_1MB.log + check wg_putnbi_tiled_n2_w1_z1_1MB + echo "wave_get_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z64_1MB.log + check wave_get_n2_w1_z1_1MB + echo "wave_getnbi_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 33 > $3/wave_getnbi_n2_w1_z64_1MB.log + check wave_getnbi_n2_w1_z1_1MB + echo "wave_put_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 34 > $3/wave_put_n2_w1_z64_1MB.log + check wave_put_n2_w1_z1_1MB + echo "wave_putnbi_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 35 > $3/wave_putnbi_n2_w1_z64_1MB.log + check wave_putnbi_n2_w1_z1_1MB + echo "wave_get_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 32 > $3/wave_get_tiled_n2_w2_z128_1MB.log + check wave_get_tiled_n2_w1_z1_1MB + echo "wave_getnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 33 > $3/wave_getnbi_tiled_n2_w2_z128_1MB.log + check wave_getnbi_tiled_n2_w1_z1_1MB + echo "wave_put_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 34 > $3/wave_put_tiled_n2_w2_z128_1MB.log + check wave_put_tiled_n2_w1_z1_1MB + echo "wave_putnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 35 > $3/wave_putnbi_tiled_n2_w2_z128_1MB.log + check wave_putnbi_tiled_n2_w1_z1_1MB echo "amofadd_n2_w1_z1" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -a 6 > $3/amofadd_n2_w1_z1.log check amofadd_n2_w1_z1 @@ -97,6 +145,54 @@ case $2 in echo "putnbi_n2_w16_z128_8B" ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 3 > $3/putnbi_n2_w16_z128_8B.log check putnbi_n2_w16_z128_8B + echo "wg_get_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 28 > $3/wg_get_n2_w1_z64_8B.log + check wg_get_n2_w1_z64_8B + echo "wg_getnbi_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 29 > $3/wg_getnbi_n2_w1_z64_8B.log + check wg_getnbi_n2_w1_z64_8B + echo "wg_put_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 30 > $3/wg_put_n2_w1_z64_8B.log + check wg_put_n2_w1_z64_8B + echo "wg_putnbi_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 31 > $3/wg_putnbi_n2_w1_z64_8B.log + check wg_putnbi_n2_w1_z64_8B + echo "wg_get_tiled_n2_w16_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 28 > $3/wg_get_tiled_n2_w16_z64_8B.log + check wg_get_tiled_n2_w16_z64_8B + echo "wg_getnbi_tiled_n2_w16_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 29 > $3/wg_getnbi_tiled_n2_w16_z64_8B.log + check wg_getnbi_tiled_n2_w16_z64_8B + echo "wg_put_tiled_n2_w16_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 30 > $3/wg_put_tiled_n2_w16_z64_8B.log + check wg_put_tiled_n2_w16_z64_8B + echo "wg_putnbi_tiled_n2_w16_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 64 -s 8 -a 31 > $3/wg_putnbi_tiled_n2_w16_z64_8B.log + check wg_putnbi_tiled_n2_w16_z64_8B + echo "wave_get_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 32 > $3/wave_get_n2_w1_z64_8B.log + check wave_get_n2_w1_z64_8B + echo "wave_getnbi_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 33 > $3/wave_getnbi_n2_w1_z64_8B.log + check wave_getnbi_n2_w1_z64_8B + echo "wave_put_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 34 > $3/wave_put_n2_w1_z64_8B.log + check wave_put_n2_w1_z64_8B + echo "wave_putnbi_n2_w1_z64_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 8 -a 35 > $3/wave_putnbi_n2_w1_z64_8B.log + check wave_putnbi_n2_w1_z64_8B + echo "wave_get_tiled_n2_w16_z128_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 32 > $3/wave_get_tiled_n2_w16_z128_8B.log + check wave_get_tiled_n2_w16_z128_8B + echo "wave_getnbi_tiled_n2_w16_z128_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 33 > $3/wave_getnbi_tiled_n2_w16_z128_8B.log + check wave_getnbi_tiled_n2_w16_z128_8B + echo "wave_put_tiled_n2_w16_z128_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 34 > $3/wave_put_tiled_n2_w16_z128_8B.log + check wave_put_tiled_n2_w16_z128_8B + echo "wave_putnbi_tiled_n2_w16_z128_8B" + ROC_SHMEM_MAX_NUM_CONTEXTS=16 mpirun -np 2 $1 -w 16 -z 128 -s 8 -a 35 > $3/wave_putnbi_tiled_n2_w16_z128_8B.log + check wave_putnbi_tiled_n2_w16_z128_8B echo "amofadd_n2_w8_z1" ROC_SHMEM_MAX_NUM_CONTEXTS=8 mpirun -np 2 $1 -w 8 -z 1 -a 6 > $3/amofadd_n2_w8_z1.log check amofadd_n2_w8_z1 diff --git a/tests/functional_tests/CMakeLists.txt b/tests/functional_tests/CMakeLists.txt index a6176e402e..542462261d 100644 --- a/tests/functional_tests/CMakeLists.txt +++ b/tests/functional_tests/CMakeLists.txt @@ -55,6 +55,7 @@ target_sources( shmem_ptr_tester.cpp extended_primitives.cpp empty_tester.cpp + wave_level_primitives.cpp ) ############################################################################### diff --git a/tests/functional_tests/extended_primitives.cpp b/tests/functional_tests/extended_primitives.cpp index affb6b3e14..d6c6042faa 100644 --- a/tests/functional_tests/extended_primitives.cpp +++ b/tests/functional_tests/extended_primitives.cpp @@ -24,6 +24,8 @@ #include +#include + using namespace rocshmem; /****************************************************************************** @@ -37,7 +39,15 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer, roc_shmem_wg_init(); roc_shmem_wg_ctx_create(ctx_type, &ctx); + /** + * Calculate start index for each work group for tiled version + * If the number of work groups is greater than 1, this kernel performs a + * tiled functional test + */ uint64_t start; + uint64_t idx = size * get_flat_grid_id(); + s_buf += idx; + r_buf += idx; for (int i = 0; i < loop + skip; i++) { if (i == skip) start = roc_shmem_timer(); @@ -75,8 +85,8 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer, *****************************************************************************/ ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); - r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); + s_buf = static_cast(roc_shmem_malloc(args.max_msg_size * args.num_wgs)); + r_buf = static_cast(roc_shmem_malloc(args.max_msg_size * args.num_wgs)); } ExtendedPrimitiveTester::~ExtendedPrimitiveTester() { @@ -85,8 +95,9 @@ ExtendedPrimitiveTester::~ExtendedPrimitiveTester() { } void ExtendedPrimitiveTester::resetBuffers(uint64_t size) { - memset(s_buf, '0', args.max_msg_size * args.wg_size); - memset(r_buf, '1', args.max_msg_size * args.wg_size); + num_elems = (size * args.num_wgs) / sizeof(int); + std::iota(s_buf, s_buf + num_elems, 0); + memset(r_buf, 0, size * args.num_wgs); } void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, @@ -94,24 +105,23 @@ void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, size_t shared_bytes = 0; hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, - stream, loop, args.skip, timer, s_buf, r_buf, size, _type, - _shmem_context); + stream, loop, args.skip, timer, (char*)s_buf, + (char*)r_buf, size, _type, _shmem_context); num_msgs = (loop + args.skip) * gridSize.x; - num_timed_msgs = loop; + num_timed_msgs = loop * gridSize.x; } void ExtendedPrimitiveTester::verifyResults(uint64_t size) { - int check_id = (_type == WGGetTestType || _type == WGGetNBITestType || - _type == WAVEGetTestType) + int check_id = (_type == WGGetTestType || _type == WGGetNBITestType) ? 0 : 1; if (args.myid == check_id) { - for (int i = 0; i < size; i++) { - if (r_buf[i] != '0') { + for (int i = 0; i < num_elems; i++) { + if (r_buf[i] != i) { fprintf(stderr, "Data validation error at idx %d\n", i); - fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0'); + fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i); exit(-1); } } diff --git a/tests/functional_tests/extended_primitives.hpp b/tests/functional_tests/extended_primitives.hpp index dbad6a48f3..76225f07b0 100644 --- a/tests/functional_tests/extended_primitives.hpp +++ b/tests/functional_tests/extended_primitives.hpp @@ -24,6 +24,7 @@ #define _EXTENDED_PRIMITIVES_HPP_ #include "tester.hpp" +#include "../src/util.hpp" /****************************************************************************** * HOST TESTER CLASS @@ -41,8 +42,9 @@ class ExtendedPrimitiveTester : public Tester { virtual void verifyResults(uint64_t size) override; - char *s_buf = nullptr; - char *r_buf = nullptr; + int *s_buf = nullptr; + int *r_buf = nullptr; + int num_elems = 0; }; #endif diff --git a/tests/functional_tests/tester.cpp b/tests/functional_tests/tester.cpp index f35928eb53..e802c0b289 100644 --- a/tests/functional_tests/tester.cpp +++ b/tests/functional_tests/tester.cpp @@ -52,10 +52,14 @@ #include "team_ctx_infra_tester.hpp" #include "team_ctx_primitive_tester.hpp" #include "team_reduction_tester.hpp" +#include "wave_level_primitives.hpp" Tester::Tester(TesterArguments args) : args(args) { _type = (TestType)args.algorithm; _shmem_context = args.shmem_context; + CHECK_HIP(hipGetDevice(&device_id)); + CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id)); + num_warps = (args.wg_size - 1) / deviceProps.warpSize + 1; CHECK_HIP(hipStreamCreate(&stream)); CHECK_HIP(hipEventCreate(&start_event)); CHECK_HIP(hipEventCreate(&stop_event)); @@ -72,6 +76,11 @@ Tester::~Tester() { std::vector Tester::create(TesterArguments args) { int rank = args.myid; std::vector testers; + hipDeviceProp_t deviceProps; + int device_id, numWarps; + CHECK_HIP(hipGetDevice(&device_id)); + CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id)); + numWarps = (args.wg_size - 1) / deviceProps.warpSize + 1; if (rank == 0) std::cout << "*** Creating Test: "; @@ -468,19 +477,35 @@ std::vector Tester::create(TesterArguments args) { testers.push_back(new ShmemPtrTester(args)); return testers; case WGGetTestType: - if (rank == 0) std::cout << "Blocking WG level Gets***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1) + std::cout << "Tiled Blocking WG level Gets***" << std::endl; + else std::cout << "Blocking WG level Gets***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGGetNBITestType: - if (rank == 0) std::cout << "Non-Blocking WG level Gets***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1) + std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl; + else std::cout << "Non-Blocking WG level Gets***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutTestType: - if (rank == 0) std::cout << "Blocking WG level Puts***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1) + std::cout << "Tiled Blocking WG level Puts***" << std::endl; + else std::cout << "Blocking WG level Puts***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutNBITestType: - if (rank == 0) std::cout << "Non-Blocking WG level Puts***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1) + std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl; + else std::cout << "Non-Blocking WG level Puts***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case PutNBIMRTestType: @@ -488,6 +513,38 @@ std::vector Tester::create(TesterArguments args) { std::cout << "Non-Blocking Put message rate***" << std::endl; testers.push_back(new PrimitiveMRTester(args)); return testers; + case WAVEGetTestType: + if (rank == 0) { + if (args.num_wgs > 1 || numWarps > 1) + std::cout << "Tiled Blocking WAVE level Gets***" << std::endl; + else std::cout << "Blocking WAVE level Gets***" << std::endl; + } + testers.push_back(new WaveLevelPrimitiveTester(args)); + return testers; + case WAVEGetNBITestType: + if (rank == 0) { + if (args.num_wgs > 1 || numWarps > 1) + std::cout << "Tiled Non-Blocking WAVE level Gets***" << std::endl; + else std::cout << "Non-Blocking WAVE level Gets***" << std::endl; + } + testers.push_back(new WaveLevelPrimitiveTester(args)); + return testers; + case WAVEPutTestType: + if (rank == 0) { + if (args.num_wgs > 1 || numWarps > 1) + std::cout << "Tiled Blocking WAVE level Puts***" << std::endl; + else std::cout << "Blocking WAVE level Puts***" << std::endl; + } + testers.push_back(new WaveLevelPrimitiveTester(args)); + return testers; + case WAVEPutNBITestType: + if (rank == 0) { + if (args.num_wgs > 1 || numWarps > 1) + std::cout << "Tiled Non-Blocking WAVE level Puts***" << std::endl; + else std::cout << "Non-Blocking WAVE level Puts***" << std::endl; + } + testers.push_back(new WaveLevelPrimitiveTester(args)); + return testers; default: if (rank == 0) std::cout << "Unknown***" << std::endl; testers.push_back(new PrimitiveTester(args)); @@ -560,10 +617,32 @@ void Tester::execute() { // data validation verifyResults(size); + /** + * Adjust size for *_wg and *_wave functions + */ + uint64_t size_ = size; + TestType type = (TestType)args.algorithm; + switch (type) { + case WAVEGetTestType: + case WAVEGetNBITestType: + case WAVEPutTestType: + case WAVEPutNBITestType: + size_ *= (args.num_wgs * num_warps); + break; + case WGGetTestType: + case WGGetNBITestType: + case WGPutTestType: + case WGPutNBITestType: + size_ *= args.num_wgs; + break; + default: + break; + } + barrier(); if (_type != TeamCtxInfraTestType) { - print(size); + print(size_); } } } @@ -655,6 +734,10 @@ uint64_t Tester::gpuCyclesToMicroseconds(uint64_t cycles) { uint64_t Tester::timerAvgInMicroseconds() { uint64_t sum = 0; + /** + * TODO: (bpotter/avinash) Modify the calcuation for the Tiled version of + * puts and gets at wavefront level + */ for (int i = 0; i < args.num_wgs; i++) { sum += gpuCyclesToMicroseconds(timer[i]); } diff --git a/tests/functional_tests/tester.hpp b/tests/functional_tests/tester.hpp index 14a91435e5..e0482868e8 100644 --- a/tests/functional_tests/tester.hpp +++ b/tests/functional_tests/tester.hpp @@ -117,7 +117,9 @@ class Tester { int num_msgs = 0; int num_timed_msgs = 0; + int num_warps = 0; int bw_factor = 1; + int device_id = 0; TesterArguments args; @@ -125,6 +127,7 @@ class Tester { ShmemContextType _shmem_context = 8; // SHMEM_CTX_WP_PRIVATE hipStream_t stream; + hipDeviceProp_t deviceProps; uint64_t *timer = nullptr; diff --git a/tests/functional_tests/tester_arguments.cpp b/tests/functional_tests/tester_arguments.cpp index 88e4b12cd6..1167fafaeb 100644 --- a/tests/functional_tests/tester_arguments.cpp +++ b/tests/functional_tests/tester_arguments.cpp @@ -103,6 +103,16 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case PutNBIMRTestType: min_msg_size = max_msg_size; break; + case WAVEGetTestType: + case WAVEGetNBITestType: + case WAVEPutTestType: + case WAVEPutNBITestType: + case WGGetTestType: + case WGGetNBITestType: + case WGPutTestType: + case WGPutNBITestType: + min_msg_size = 4; + break; default: break; } diff --git a/tests/functional_tests/wave_level_primitives.cpp b/tests/functional_tests/wave_level_primitives.cpp new file mode 100644 index 0000000000..56b263ebf4 --- /dev/null +++ b/tests/functional_tests/wave_level_primitives.cpp @@ -0,0 +1,134 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "wave_level_primitives.hpp" + +#include + +#include + +using namespace rocshmem; + +/****************************************************************************** + * DEVICE TEST KERNEL + *****************************************************************************/ +__global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer, + char *s_buf, char *r_buf, int size, + TestType type, ShmemContextType ctx_type, + int wf_size) { + __shared__ roc_shmem_ctx_t ctx; + roc_shmem_wg_init(); + roc_shmem_wg_ctx_create(ctx_type, &ctx); + + /** + * Calculate start index for each wavefront for tiled version + * If the number of wavefronts is greater than 1, this kernel performs a + * tiled functional test + */ + uint64_t start; + int wf_id = get_flat_block_id() / wf_size; + int wg_offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size); + int idx = wf_id * size + wg_offset; + s_buf += idx; + r_buf += idx; + + for (int i = 0; i < loop + skip; i++) { + if (i == skip) start = roc_shmem_timer(); + + switch (type) { + case WAVEGetTestType: + roc_shmemx_ctx_getmem_wave(ctx, r_buf, s_buf, size, 1); + break; + case WAVEGetNBITestType: + roc_shmemx_ctx_getmem_nbi_wave(ctx, r_buf, s_buf, size, 1); + break; + case WAVEPutTestType: + roc_shmemx_ctx_putmem_wave(ctx, r_buf, s_buf, size, 1); + break; + case WAVEPutNBITestType: + roc_shmemx_ctx_putmem_nbi_wave(ctx, r_buf, s_buf, size, 1); + break; + default: + break; + } + } + + roc_shmem_ctx_quiet(ctx); + + if (hipThreadIdx_x == 0) { + timer[hipBlockIdx_x] = roc_shmem_timer() - start; + } + + roc_shmem_wg_ctx_destroy(&ctx); + roc_shmem_wg_finalize(); +} + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args) + : Tester(args) { + s_buf = static_cast( + roc_shmem_malloc(args.max_msg_size * args.num_wgs * num_warps)); + r_buf = static_cast( + roc_shmem_malloc(args.max_msg_size * args.num_wgs * num_warps)); +} + +WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() { + roc_shmem_free(s_buf); + roc_shmem_free(r_buf); +} + +void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) { + num_elems = (size * args.num_wgs * num_warps) / sizeof(int); + std::iota(s_buf, s_buf + num_elems, 0); + memset(r_buf, 0, size * args.num_wgs * num_warps); +} + +void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, uint64_t size) { + size_t shared_bytes = 0; + + hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, timer, (char*)s_buf, + (char*)r_buf, size, _type, _shmem_context, + deviceProps.warpSize); + + num_msgs = (loop + args.skip) * gridSize.x * num_warps; + num_timed_msgs = loop * gridSize.x * num_warps; +} + +void WaveLevelPrimitiveTester::verifyResults(uint64_t size) { + int check_id = (_type == WAVEGetTestType || _type == WAVEGetNBITestType) + ? 0 + : 1; + + if (args.myid == check_id) { + for (int i = 0; i < num_elems; i++) { + if (r_buf[i] != i) { + fprintf(stderr, "Data validation error at idx %d\n", i); + fprintf(stderr, "Got %d, Expected %d \n", r_buf[i], i); + exit(-1); + } + } + } +} diff --git a/tests/functional_tests/wave_level_primitives.hpp b/tests/functional_tests/wave_level_primitives.hpp new file mode 100644 index 0000000000..af65ac0a30 --- /dev/null +++ b/tests/functional_tests/wave_level_primitives.hpp @@ -0,0 +1,50 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _WAVE_LEVEL_PRIMITIVE_TEST_HPP_ +#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_ + +#include "tester.hpp" +#include "../src/util.hpp" + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class WaveLevelPrimitiveTester : public Tester { + public: + explicit WaveLevelPrimitiveTester(TesterArguments args); + virtual ~WaveLevelPrimitiveTester(); + + protected: + virtual void resetBuffers(uint64_t size) override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + uint64_t size) override; + + virtual void verifyResults(uint64_t size) override; + + int *s_buf = nullptr; + int *r_buf = nullptr; + int num_elems = 0; +}; + +#endif