From 3d26792831f43d2a99ec73f7db9eaa7ae2e47d31 Mon Sep 17 00:00:00 2001 From: avinashkethineedi Date: Wed, 4 Sep 2024 12:55:10 -0700 Subject: [PATCH] Add tiled version of puts and gets at the workgroup level to the functional test suite [ROCm/rocshmem commit: d226922733dc081187824c7ca90a7e035ea6e773] --- .../scripts/functional_tests/driver.sh | 12 ++++ .../functional_tests/extended_primitives.cpp | 60 +++++++++++++++++-- .../functional_tests/extended_primitives.hpp | 1 + .../tests/functional_tests/tester.cpp | 20 +++++-- .../functional_tests/tester_arguments.cpp | 4 ++ .../functional_tests/tester_arguments.hpp | 1 + 6 files changed, 89 insertions(+), 9 deletions(-) diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 556681ec66..0acfd26090 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -67,6 +67,18 @@ case $2 in echo "wg_putnbi_n2_w1_z1_1MB" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log check wg_putnbi_n2_w1_z1_1MB + echo "wg_get_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > $3/wg_get_tiled_n2_w1_z1_1MB.log + check wg_get_tiled_n2_w1_z1_1MB + echo "wg_getnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log + check wg_getnbi_tiled_n2_w1_z1_1MB + echo "wg_put_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log + check wg_put_tiled_n2_w1_z1_1MB + echo "wg_putnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log + check wg_putnbi_tiled_n2_w1_z1_1MB echo "wave_get_n2_w1_z1_1MB" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 
32 > $3/wave_get_n2_w1_z1_1MB.log check wave_get_n2_w1_z1_1MB diff --git a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp b/projects/rocshmem/tests/functional_tests/extended_primitives.cpp index affb6b3e14..8eb569cd32 100644 --- a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/extended_primitives.cpp @@ -27,7 +27,7 @@ using namespace rocshmem; /****************************************************************************** - * DEVICE TEST KERNEL + * DEVICE TEST KERNELS *****************************************************************************/ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer, char *s_buf, char *r_buf, int size, @@ -70,6 +70,50 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer, roc_shmem_wg_finalize(); } +__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer, + char *s_buf, char *r_buf, int size, + TestType type, + ShmemContextType ctx_type) { + __shared__ roc_shmem_ctx_t ctx; + roc_shmem_wg_init(); + roc_shmem_wg_ctx_create(ctx_type, &ctx); + + uint64_t start; + uint64_t idx = size * get_flat_grid_id(); + s_buf += idx; + r_buf += idx; + + for (int i = 0; i < loop + skip; i++) { + if (i == skip) start = roc_shmem_timer(); + + switch (type) { + case WGGetTestType: + roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1); + break; + case WGGetNBITestType: + roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1); + break; + case WGPutTestType: + roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1); + break; + case WGPutNBITestType: + roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1); + break; + default: + break; + } + } + + roc_shmem_ctx_quiet(ctx); + + if (hipThreadIdx_x == 0) { + timer[hipBlockIdx_x] = roc_shmem_timer() - start; + } + + roc_shmem_wg_ctx_destroy(&ctx); + roc_shmem_wg_finalize(); +} + /****************************************************************************** * HOST 
TESTER CLASS METHODS *****************************************************************************/ @@ -93,17 +137,23 @@ void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, + if (args.tiled){ + hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes, stream, loop, args.skip, timer, s_buf, r_buf, size, _type, _shmem_context); + } + else { + hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, timer, s_buf, r_buf, size, _type, + _shmem_context); + } num_msgs = (loop + args.skip) * gridSize.x; - num_timed_msgs = loop; + num_timed_msgs = loop * gridSize.x; } void ExtendedPrimitiveTester::verifyResults(uint64_t size) { - int check_id = (_type == WGGetTestType || _type == WGGetNBITestType || - _type == WAVEGetTestType) + int check_id = (_type == WGGetTestType || _type == WGGetNBITestType) ? 
0 : 1; diff --git a/projects/rocshmem/tests/functional_tests/extended_primitives.hpp b/projects/rocshmem/tests/functional_tests/extended_primitives.hpp index dbad6a48f3..33fdac2f8f 100644 --- a/projects/rocshmem/tests/functional_tests/extended_primitives.hpp +++ b/projects/rocshmem/tests/functional_tests/extended_primitives.hpp @@ -24,6 +24,7 @@ #define _EXTENDED_PRIMITIVES_HPP_ #include "tester.hpp" +#include "../src/util.hpp" /****************************************************************************** * HOST TESTER CLASS diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 58d40cb4f2..22ab55f4e1 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -469,19 +469,31 @@ std::vector Tester::create(TesterArguments args) { testers.push_back(new ShmemPtrTester(args)); return testers; case WGGetTestType: - if (rank == 0) std::cout << "Blocking WG level Gets***" << std::endl; + if (rank == 0) { + if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl; + else std::cout << "Blocking WG level Gets***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGGetNBITestType: - if (rank == 0) std::cout << "Non-Blocking WG level Gets***" << std::endl; + if (rank == 0) { + if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl; + else std::cout << "Non-Blocking WG level Gets***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutTestType: - if (rank == 0) std::cout << "Blocking WG level Puts***" << std::endl; + if (rank == 0) { + if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl; + else std::cout << "Blocking WG level Puts***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutNBITestType: - if (rank == 0) std::cout << "Non-Blocking WG 
level Puts***" << std::endl; + if (rank == 0) { + if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl; + else std::cout << "Non-Blocking WG level Puts***" << std::endl; + } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case PutNBIMRTestType: diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 88e4b12cd6..014d5d1388 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -60,6 +60,9 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { } else if (arg == "-x") { i++; shmem_context = atoi(argv[i]); + } else if (arg == "-ti") { + i++; + tiled = (atoi(argv[i]) == 1); } else { show_usage(argv[0]); exit(-1); @@ -119,6 +122,7 @@ void TesterArguments::show_usage(std::string executable_name) { std::cout << "\t-o \n"; std::cout << "\t-ta \n"; std::cout << "\t-x \n"; + std::cout << "\t-ti \n"; } void TesterArguments::get_rocshmem_arguments() { diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp index 88ff6a5537..810d00d825 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp @@ -58,6 +58,7 @@ class TesterArguments { unsigned coal_coef = 64; unsigned op_type = 0; unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE; + bool tiled = false; /** * Arguments obtained from rocshmem