From 9532e084fc286c7a543db555abf1d2355d40af57 Mon Sep 17 00:00:00 2001 From: avinashkethineedi Date: Sat, 7 Sep 2024 16:06:36 -0700 Subject: [PATCH] Add tiled version of puts and gets at wavefront level to the functional test suite * Implemented tiled version of put*_wave and get*_wave functions * Maintain single kernel that supports both tiled and untiled versions * Disable IPC in the default RO build script [ROCm/rocshmem commit: b6d31ac7ef13187fe12736b683c5eb51411c4941] --- .../rocshmem/scripts/build_configs/ro_net | 3 +- .../scripts/functional_tests/driver.sh | 20 +++++-- .../functional_tests/extended_primitives.cpp | 60 ++----------------- .../tests/functional_tests/tester.cpp | 41 ++++++++++--- .../tests/functional_tests/tester.hpp | 3 + .../functional_tests/tester_arguments.cpp | 4 -- .../functional_tests/tester_arguments.hpp | 1 - .../wave_level_primitives.cpp | 27 +++++---- .../wave_level_primitives.hpp | 1 + 9 files changed, 78 insertions(+), 82 deletions(-) diff --git a/projects/rocshmem/scripts/build_configs/ro_net b/projects/rocshmem/scripts/build_configs/ro_net index 17809fa0a9..c1d09194f7 100755 --- a/projects/rocshmem/scripts/build_configs/ro_net +++ b/projects/rocshmem/scripts/build_configs/ro_net @@ -17,8 +17,9 @@ cmake \ -DDEBUG=OFF \ -DPROFILE=OFF \ -DUSE_GPU_IB=OFF \ + -DUSE_RO=ON \ -DUSE_DC=OFF \ - -DUSE_IPC=ON \ + -DUSE_IPC=OFF \ -DUSE_THREADS=ON \ -DUSE_WF_COAL=OFF \ -DUSE_COHERENT_HEAP=ON \ diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 0acfd26090..5e6bc53ae3 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -68,16 +68,16 @@ case $2 in ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log check wg_putnbi_n2_w1_z1_1MB echo "wg_get_tiled_n2_w1_z1_1MB" - ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > 
$3/wg_get_tiled_n2_w1_z1_1MB.log + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 28 > $3/wg_get_tiled_n2_w1_z1_1MB.log check wg_get_tiled_n2_w1_z1_1MB echo "wg_getnbi_tiled_n2_w1_z1_1MB" - ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log check wg_getnbi_tiled_n2_w1_z1_1MB echo "wg_put_tiled_n2_w1_z1_1MB" - ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 30 > $3/wg_put_tiled_n2_w1_z1_1MB.log check wg_put_tiled_n2_w1_z1_1MB echo "wg_putnbi_tiled_n2_w1_z1_1MB" - ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log check wg_putnbi_tiled_n2_w1_z1_1MB echo "wave_get_n2_w1_z1_1MB" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z1_1MB.log @@ -91,6 +91,18 @@ case $2 in echo "wave_putnbi_n2_w1_z1_1MB" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 35 > $3/wave_putnbi_n2_w1_z1_1MB.log check wave_putnbi_n2_w1_z1_1MB + echo "wave_get_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 32 > $3/wave_get_tiled_n2_w1_z1_1MB.log + check wave_get_tiled_n2_w1_z1_1MB + echo "wave_getnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 33 > $3/wave_getnbi_tiled_n2_w1_z1_1MB.log + check wave_getnbi_tiled_n2_w1_z1_1MB + echo "wave_put_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 34 > $3/wave_put_tiled_n2_w1_z1_1MB.log + check wave_put_tiled_n2_w1_z1_1MB + echo 
"wave_putnbi_tiled_n2_w1_z1_1MB" + ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 35 > $3/wave_putnbi_tiled_n2_w1_z1_1MB.log + check wave_putnbi_tiled_n2_w1_z1_1MB echo "amofadd_n2_w1_z1" ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -a 6 > $3/amofadd_n2_w1_z1.log check amofadd_n2_w1_z1 diff --git a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp b/projects/rocshmem/tests/functional_tests/extended_primitives.cpp index 8eb569cd32..cc73d9c502 100644 --- a/projects/rocshmem/tests/functional_tests/extended_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/extended_primitives.cpp @@ -37,47 +37,6 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer, roc_shmem_wg_init(); roc_shmem_wg_ctx_create(ctx_type, &ctx); - uint64_t start; - - for (int i = 0; i < loop + skip; i++) { - if (i == skip) start = roc_shmem_timer(); - - switch (type) { - case WGGetTestType: - roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1); - break; - case WGGetNBITestType: - roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1); - break; - case WGPutTestType: - roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1); - break; - case WGPutNBITestType: - roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1); - break; - default: - break; - } - } - - roc_shmem_ctx_quiet(ctx); - - if (hipThreadIdx_x == 0) { - timer[hipBlockIdx_x] = roc_shmem_timer() - start; - } - - roc_shmem_wg_ctx_destroy(&ctx); - roc_shmem_wg_finalize(); -} - -__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer, - char *s_buf, char *r_buf, int size, - TestType type, - ShmemContextType ctx_type) { - __shared__ roc_shmem_ctx_t ctx; - roc_shmem_wg_init(); - roc_shmem_wg_ctx_create(ctx_type, &ctx); - uint64_t start; uint64_t idx = size * get_flat_grid_id(); s_buf += idx; @@ -119,8 +78,8 @@ __global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer, 
*****************************************************************************/ ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); - r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); + s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs); + r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs); } ExtendedPrimitiveTester::~ExtendedPrimitiveTester() { @@ -129,24 +88,17 @@ ExtendedPrimitiveTester::~ExtendedPrimitiveTester() { } void ExtendedPrimitiveTester::resetBuffers(uint64_t size) { - memset(s_buf, '0', args.max_msg_size * args.wg_size); - memset(r_buf, '1', args.max_msg_size * args.wg_size); + memset(s_buf, '0', size * args.num_wgs); + memset(r_buf, '1', size * args.num_wgs); } void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - if (args.tiled){ - hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes, + hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, timer, s_buf, r_buf, size, _type, _shmem_context); - } - else { - hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes, - stream, loop, args.skip, timer, s_buf, r_buf, size, _type, - _shmem_context); - } num_msgs = (loop + args.skip) * gridSize.x; num_timed_msgs = loop * gridSize.x; @@ -158,7 +110,7 @@ void ExtendedPrimitiveTester::verifyResults(uint64_t size) { : 1; if (args.myid == check_id) { - for (int i = 0; i < size; i++) { + for (int i = 0; i < size * args.num_wgs; i++) { if (r_buf[i] != '0') { fprintf(stderr, "Data validation error at idx %d\n", i); fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0'); diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 22ab55f4e1..cbd1ba89a5 100644 --- 
a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -57,6 +57,9 @@ Tester::Tester(TesterArguments args) : args(args) { _type = (TestType)args.algorithm; _shmem_context = args.shmem_context; + CHECK_HIP(hipGetDevice(&device_id)); + CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id)); + num_warps = args.wg_size / deviceProps.warpSize; CHECK_HIP(hipStreamCreate(&stream)); CHECK_HIP(hipEventCreate(&start_event)); CHECK_HIP(hipEventCreate(&stop_event)); @@ -470,28 +473,32 @@ std::vector Tester::create(TesterArguments args) { return testers; case WGGetTestType: if (rank == 0) { - if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl; + if (args.num_wgs > 1) + std::cout << "Tiled Blocking WG level Gets***" << std::endl; else std::cout << "Blocking WG level Gets***" << std::endl; } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGGetNBITestType: if (rank == 0) { - if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl; + if (args.num_wgs > 1) + std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl; else std::cout << "Non-Blocking WG level Gets***" << std::endl; } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutTestType: if (rank == 0) { - if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl; + if (args.num_wgs > 1) + std::cout << "Tiled Blocking WG level Puts***" << std::endl; else std::cout << "Blocking WG level Puts***" << std::endl; } testers.push_back(new ExtendedPrimitiveTester(args)); return testers; case WGPutNBITestType: if (rank == 0) { - if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl; + if (args.num_wgs > 1) + std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl; else std::cout << "Non-Blocking WG level Puts***" << std::endl; } testers.push_back(new ExtendedPrimitiveTester(args)); @@ -502,19 +509,35 @@ std::vector 
Tester::create(TesterArguments args) { testers.push_back(new PrimitiveMRTester(args)); return testers; case WAVEGetTestType: - if (rank == 0) std::cout << "WAVE Blocking Gets***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1 || args.wg_size / 64 > 1) + std::cout << "Tiled Blocking WAVE level Gets***" << std::endl; + else std::cout << "Blocking WAVE level Gets***" << std::endl; + } testers.push_back(new WaveLevelPrimitiveTester(args)); return testers; case WAVEGetNBITestType: - if (rank == 0) std::cout << "WAVE Non-Blocking Gets***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1 || args.wg_size / 64 > 1) + std::cout << "Tiled Non-Blocking WAVE level Gets***" << std::endl; + else std::cout << "Non-Blocking WAVE level Gets***" << std::endl; + } testers.push_back(new WaveLevelPrimitiveTester(args)); return testers; case WAVEPutTestType: - if (rank == 0) std::cout << "WAVE Blocking Puts***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1 || args.wg_size / 64 > 1) + std::cout << "Tiled Blocking WAVE level Puts***" << std::endl; + else std::cout << "Blocking WAVE level Puts***" << std::endl; + } testers.push_back(new WaveLevelPrimitiveTester(args)); return testers; case WAVEPutNBITestType: - if (rank == 0) std::cout << "WAVE Non-Blocking Puts***" << std::endl; + if (rank == 0) { + if (args.num_wgs > 1 || args.wg_size / 64 > 1) + std::cout << "Tiled Non-Blocking WAVE level Puts***" << std::endl; + else std::cout << "Non-Blocking WAVE level Puts***" << std::endl; + } testers.push_back(new WaveLevelPrimitiveTester(args)); return testers; default: @@ -684,6 +707,8 @@ uint64_t Tester::gpuCyclesToMicroseconds(uint64_t cycles) { uint64_t Tester::timerAvgInMicroseconds() { uint64_t sum = 0; + //TODO: Modify the calculation for the Tiled version of puts and gets at + // wavefront level (bpotter/avinash) for (int i = 0; i < args.num_wgs; i++) { sum += gpuCyclesToMicroseconds(timer[i]); } diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp 
b/projects/rocshmem/tests/functional_tests/tester.hpp index 14a91435e5..e0482868e8 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -117,7 +117,9 @@ class Tester { int num_msgs = 0; int num_timed_msgs = 0; + int num_warps = 0; int bw_factor = 1; + int device_id = 0; TesterArguments args; @@ -125,6 +127,7 @@ class Tester { ShmemContextType _shmem_context = 8; // SHMEM_CTX_WP_PRIVATE hipStream_t stream; + hipDeviceProp_t deviceProps; uint64_t *timer = nullptr; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 014d5d1388..88e4b12cd6 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -60,9 +60,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { } else if (arg == "-x") { i++; shmem_context = atoi(argv[i]); - } else if (arg == "-ti") { - i++; - tiled = (atoi(argv[i]) == 1); } else { show_usage(argv[0]); exit(-1); @@ -122,7 +119,6 @@ void TesterArguments::show_usage(std::string executable_name) { std::cout << "\t-o \n"; std::cout << "\t-ta \n"; std::cout << "\t-x \n"; - std::cout << "\t-ti \n"; } void TesterArguments::get_rocshmem_arguments() { diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp index 810d00d825..88ff6a5537 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp @@ -58,7 +58,6 @@ class TesterArguments { unsigned coal_coef = 64; unsigned op_type = 0; unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE; - bool tiled = false; /** * Arguments obtained from rocshmem diff --git a/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp 
b/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp index 6cb2ee8a80..bb7c299ef9 100644 --- a/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/wave_level_primitives.cpp @@ -27,17 +27,22 @@ using namespace rocshmem; /****************************************************************************** - * DEVICE TEST KERNEL + * DEVICE TEST KERNELS *****************************************************************************/ __global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer, char *s_buf, char *r_buf, int size, - TestType type, - ShmemContextType ctx_type) { + TestType type, ShmemContextType ctx_type, + int wf_size) { __shared__ roc_shmem_ctx_t ctx; roc_shmem_wg_init(); roc_shmem_wg_ctx_create(ctx_type, &ctx); uint64_t start; + int wf_id = get_flat_block_id() / wf_size; + int offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size); + int idx = wf_id * size + offset; + s_buf += idx; + r_buf += idx; for (int i = 0; i < loop + skip; i++) { if (i == skip) start = roc_shmem_timer(); @@ -75,8 +80,10 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer, *****************************************************************************/ WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args) : Tester(args) { - s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); - r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size); + s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs + * num_warps); + r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs + * num_warps); } WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() { @@ -85,8 +92,8 @@ WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() { } void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) { - memset(s_buf, '0', args.max_msg_size * args.wg_size); - memset(r_buf, '1', args.max_msg_size * args.wg_size); + 
memset(s_buf, '0', size * args.num_wgs * num_warps); + memset(r_buf, '1', size * args.num_wgs * num_warps); } void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, @@ -95,10 +102,10 @@ void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize, hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, timer, s_buf, r_buf, size, _type, - _shmem_context); + _shmem_context, deviceProps.warpSize); num_msgs = (loop + args.skip) * gridSize.x; - num_timed_msgs = loop; + num_timed_msgs = loop * gridSize.x; } void WaveLevelPrimitiveTester::verifyResults(uint64_t size) { @@ -107,7 +114,7 @@ void WaveLevelPrimitiveTester::verifyResults(uint64_t size) { : 1; if (args.myid == check_id) { - for (int i = 0; i < size; i++) { + for (int i = 0; i < size * args.num_wgs * num_warps; i++) { if (r_buf[i] != '0') { fprintf(stderr, "Data validation error at idx %d\n", i); fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0'); diff --git a/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp b/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp index 8c59d81350..3b8a98541d 100644 --- a/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp +++ b/projects/rocshmem/tests/functional_tests/wave_level_primitives.hpp @@ -24,6 +24,7 @@ #define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_ #include "tester.hpp" +#include "../src/util.hpp" /****************************************************************************** * HOST TESTER CLASS