From 257610bdc58e028ff3b7651949c09ba38d1b6993 Mon Sep 17 00:00:00 2001 From: Yiltan Hassan Temucin Date: Thu, 6 Feb 2025 08:16:04 -0600 Subject: [PATCH] Validate signal after put signal operations [ROCm/rocshmem commit: 8d74c7b73e39da89d39f09a02795bf5711f2354c] --- .../scripts/functional_tests/driver.sh | 3 + .../signaling_operations_tester.cpp | 98 +++++++++++++------ .../signaling_operations_tester.hpp | 2 + .../tests/functional_tests/tester.cpp | 18 ++-- 4 files changed, 84 insertions(+), 37 deletions(-) diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 8594bff74e..2e2e835090 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -270,15 +270,18 @@ TestSigOps() { ExecTest "putsignal" 2 2 32 1048576 ExecTest "wgputsignal" 2 2 32 1048576 ExecTest "waveputsignal" 2 1 32 1048576 + ExecTest "waveputsignal" 2 1 64 1048576 ExecTest "putsignalnbi" 2 1 1 1048576 ExecTest "putsignalnbi" 2 2 32 1048576 ExecTest "wgputsignalnbi" 2 2 32 1048576 ExecTest "waveputsignalnbi" 2 1 32 1048576 + ExecTest "waveputsignalnbi" 2 1 64 1048576 ExecTest "signalfetch" 2 1 1 ExecTest "wgsignalfetch" 2 2 32 ExecTest "wavesignalfetch" 2 1 32 + ExecTest "wavesignalfetch" 2 1 64 } TestColl() { diff --git a/projects/rocshmem/tests/functional_tests/signaling_operations_tester.cpp b/projects/rocshmem/tests/functional_tests/signaling_operations_tester.cpp index b87bb72d6d..0a5ac18e5e 100644 --- a/projects/rocshmem/tests/functional_tests/signaling_operations_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/signaling_operations_tester.cpp @@ -31,14 +31,13 @@ using namespace rocshmem; *****************************************************************************/ __global__ void PutmemSignalTest(int loop, int skip, uint64_t *timer, char *s_buf, char *r_buf, int size, uint64_t *sig_addr, - TestType type, ShmemContextType ctx_type) { + TestType type, ShmemContextType ctx_type, int sig_op) { __shared__ rocshmem_ctx_t ctx; rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); uint64_t start; - uint64_t signal = 0; - int sig_op = ROCSHMEM_SIGNAL_SET; + uint64_t signal = 1; for (int i = 0; i < loop + skip; i++) { if (i == skip) { @@ -128,6 +127,11 @@ SignalingOperationsTester::SignalingOperationsTester(TesterArguments args) : Tes CHECK_HIP(hipMallocManaged(&fetched_value, sizeof(uint64_t), hipMemAttachHost)); } +SignalingOperationsTester::SignalingOperationsTester(TesterArguments args, int signal_op) + : SignalingOperationsTester(args) { + sig_op = signal_op; +} + SignalingOperationsTester::~SignalingOperationsTester() { rocshmem_free(s_buf); rocshmem_free(r_buf); @@ -146,6 +150,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int uint64_t size) { size_t shared_bytes = 0; + if ((_type == SignalFetchTestType) || (_type == WAVESignalFetchTestType) || (_type == WGSignalFetchTestType)) { @@ -154,7 +159,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int } else { hipLaunchKernelGGL(PutmemSignalTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, timer, s_buf, r_buf, size, sig_addr, - _type, _shmem_context); + _type, _shmem_context, sig_op); } num_msgs = (loop + args.skip) * gridSize.x; @@ -162,35 +167,66 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int } void SignalingOperationsTester::verifyResults(uint64_t size) { - int check_data_id = (_type == PutSignalTestType || - _type == PutSignalNBITestType || - _type == WAVEPutSignalTestType || - _type == WAVEPutSignalNBITestType || - _type == WGPutSignalTestType || - _type == WGPutSignalNBITestType) - ? 1 : -1; // do not check if it doesn't match a test - - int check_fetched_value_id = (_type == SignalFetchTestType || - _type == WAVESignalFetchTestType || - _type == WGSignalFetchTestType) - ? 0 : -1; // do not check if it doesn't match a test - - if (args.myid == check_data_id) { - for (uint64_t i = 0; i < size; i++) { - if (r_buf[i] != '0') { - fprintf(stderr, "Data validation error at idx %lu\n", i); - fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0'); + if (_type == SignalFetchTestType || + _type == WAVESignalFetchTestType || + _type == WGSignalFetchTestType) { + if (0 == args.myid) { + uint64_t value = *fetched_value; + uint64_t expected_value = (args.myid + 123); + if (value != expected_value) { + fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value); exit(-1); } + return; + } + } else { + if (1 == args.myid) { + // Validate Data + for (uint64_t i = 0; i < size; i++) { + if (r_buf[i] != '0') { + fprintf(stderr, "Data validation error at idx %lu\n", i); + fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0'); + exit(-1); + } + } + // Validate Signal + if (ROCSHMEM_SIGNAL_SET == sig_op) { + uint64_t expected_value = 1; + uint64_t value = *sig_addr; + + if (value != expected_value) { + fprintf(stderr, "ROCSHMEM_SIGNAL_SET Value %lu, Expected %lu\n", value, expected_value); + exit(-1); + } + } else if (ROCSHMEM_SIGNAL_ADD == sig_op) { + uint64_t value = *sig_addr; + uint64_t expected_value = (args.myid + 123); // Initial Value + uint64_t num_waves = 1; + + switch (_type) { + case PutSignalTestType: + case PutSignalNBITestType: + expected_value += ((args.skip + args.loop) * args.wg_size * args.num_wgs); + break; + case WGPutSignalTestType: + case WGPutSignalNBITestType: + expected_value += ((args.skip + args.loop) * args.num_wgs); + break; + case WAVEPutSignalTestType: + case WAVEPutSignalNBITestType: + num_waves = max(1, (args.num_wgs / __AMDGCN_WAVEFRONT_SIZE__)); + expected_value += ((args.skip + args.loop) * num_waves); + break; + default: + fprintf(stderr, "Invalid Test\n"); + exit(-1); + } + + if (value != expected_value) { + fprintf(stderr, "ROCSHMEM_SIGNAL_ADD Value %lu, Expected %lu\n", value, expected_value); + exit(-1); + } + } } } - - if (args.myid == check_fetched_value_id) { - uint64_t value = *fetched_value; - uint64_t expected_value = (args.myid + 123); - if (value != expected_value) { - fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value); - exit(-1); - } - } } diff --git a/projects/rocshmem/tests/functional_tests/signaling_operations_tester.hpp b/projects/rocshmem/tests/functional_tests/signaling_operations_tester.hpp index acff9ea057..df47fa054a 100644 --- a/projects/rocshmem/tests/functional_tests/signaling_operations_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/signaling_operations_tester.hpp @@ -31,6 +31,7 @@ class SignalingOperationsTester : public Tester { public: explicit SignalingOperationsTester(TesterArguments args); + explicit SignalingOperationsTester(TesterArguments args, int signal_op); virtual ~SignalingOperationsTester(); protected: @@ -41,6 +42,7 @@ class SignalingOperationsTester : public Tester { virtual void verifyResults(uint64_t size) override; + int sig_op; char *s_buf = nullptr; char *r_buf = nullptr; uint64_t *sig_addr; diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 18efa2f983..c7f06be3e3 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -422,27 +422,33 @@ std::vector Tester::create(TesterArguments args) { return testers; case PutSignalTestType: if (rank == 0) std::cout << "Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case WGPutSignalTestType: if (rank == 0) std::cout << "WG Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case WAVEPutSignalTestType: if (rank == 0) std::cout << "Wave Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case PutSignalNBITestType: if (rank == 0) std::cout << "Non-Blocking Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case WGPutSignalNBITestType: if (rank == 0) std::cout << "Non-Blocking WG Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case WAVEPutSignalNBITestType: if (rank == 0) std::cout << "Non-Blocking Wave Putmem Signal ###" << std::endl; - testers.push_back(new SignalingOperationsTester(args)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET)); + testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD)); return testers; case SignalFetchTestType: if (rank == 0) std::cout << "Signal Fetch ###" << std::endl;