From 8f135af15648cf3633a18dbafce2f6450c8de316 Mon Sep 17 00:00:00 2001
From: Yiltan
Date: Mon, 28 Apr 2025 16:06:05 -0400
Subject: [PATCH] Check RMA functional test data in GPU kernel (#91)

[ROCm/rocshmem commit: c81722c33960a4be584fd04d7c0bb1ce7b541aa3]
---
 .../functional_tests/primitive_tester.cpp     | 30 ++++++++---
 .../tests/functional_tests/tester.cpp         |  5 ++
 .../tests/functional_tests/tester.hpp         |  4 ++
 .../verify_results_kernels.hpp                | 53 +++++++++++++++++++
 .../functional_tests/wavefront_primitives.cpp | 30 ++++++++---
 .../functional_tests/workgroup_primitives.cpp | 30 ++++++++---
 6 files changed, 134 insertions(+), 18 deletions(-)
 create mode 100644 projects/rocshmem/tests/functional_tests/verify_results_kernels.hpp

diff --git a/projects/rocshmem/tests/functional_tests/primitive_tester.cpp b/projects/rocshmem/tests/functional_tests/primitive_tester.cpp
index f5e3fcf550..adecd00dfd 100644
--- a/projects/rocshmem/tests/functional_tests/primitive_tester.cpp
+++ b/projects/rocshmem/tests/functional_tests/primitive_tester.cpp
@@ -181,13 +181,31 @@ void PrimitiveTester::verifyResults(uint64_t size) {
 
   if (args.myid == check_id) {
     size_t buff_size = size * args.wg_size * args.num_wgs;
-    for (uint64_t i = 0; i < buff_size; i++) {
-      if (dest[i] != source[i]) {
-        std::cerr << "Data validation error at idx " << i << std::endl;
-        std::cerr << " Got " << dest[i] << ", Expected "
-                  << source[i] << std::endl;
-        exit(-1);
+
+    // An empty buffer has nothing to compare; skipping the launch also avoids
+    // dividing by a zero workgroup size below.
+    if (buff_size > 0) {
+      size_t verify_wg_size = std::min((size_t) 1024, buff_size);
+      // Round up so a trailing partial workgroup is still checked; the kernel
+      // discards threads whose index is past buff_size.
+      size_t verify_num_wgs = (buff_size + verify_wg_size - 1) / verify_wg_size;
+
+      hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
+                         source, dest, buff_size, verification_error);
+      CHECK_HIP(hipStreamSynchronize(stream));
+    }
+
+    if (*verification_error) {
+      // Slow path: rescan on the host only to report the first bad index.
+      for (uint64_t i = 0; i < buff_size; i++) {
+        if (dest[i] != source[i]) {
+          std::cerr << "Data validation error at idx " << i << std::endl;
+          std::cerr << " Got " << dest[i] << ", Expected "
+                    << source[i] << std::endl;
+          exit(-1);
+        }
       }
+      *verification_error = false;
     }
   }
 }
diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp
index b698714168..cccb164c99 100644
--- a/projects/rocshmem/tests/functional_tests/tester.cpp
+++ b/projects/rocshmem/tests/functional_tests/tester.cpp
@@ -83,6 +83,9 @@ Tester::Tester(TesterArguments args) : args(args) {
   CHECK_HIP(hipMalloc((void**)&timer, sizeof(long long int) * num_timers));
   CHECK_HIP(hipMalloc((void**)&start_time, sizeof(long long int) * num_timers));
   CHECK_HIP(hipMalloc((void**)&end_time, sizeof(long long int) * num_timers));
+  // Pinned host memory: written by the verification kernel, read by the host.
+  CHECK_HIP(hipHostMalloc((void**)&verification_error, sizeof(bool)));
+  *verification_error = false;
 }
 
 Tester::~Tester() {
@@ -92,6 +95,8 @@ Tester::~Tester() {
   CHECK_HIP(hipEventDestroy(stop_event));
   CHECK_HIP(hipEventDestroy(start_event));
   CHECK_HIP(hipStreamDestroy(stream));
+  // Memory from hipHostMalloc must be released with hipHostFree, not hipFree.
+  CHECK_HIP(hipHostFree(verification_error));
 }
 
 std::vector Tester::create(TesterArguments args) {
diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp
index ebb0913a07..120616f298 100644
--- a/projects/rocshmem/tests/functional_tests/tester.hpp
+++ b/projects/rocshmem/tests/functional_tests/tester.hpp
@@ -31,6 +31,7 @@
 
 #include "tester_arguments.hpp"
 #include "../src/util.hpp"
+#include "verify_results_kernels.hpp"
 
 /******************************************************************************
  * TESTER CLASS TYPES
@@ -162,6 +163,9 @@ class Tester {
   long long int max_end_time = 0;
   uint32_t num_timers = 0;
 
+  // Host-pinned flag set by the verification kernel on a data mismatch.
+  bool *verification_error = nullptr;
+
  private:
   bool _print_header = 1;
   void print(uint64_t size);
diff --git a/projects/rocshmem/tests/functional_tests/verify_results_kernels.hpp b/projects/rocshmem/tests/functional_tests/verify_results_kernels.hpp
new file mode 100644
index 0000000000..ecc527a790
--- /dev/null
+++ b/projects/rocshmem/tests/functional_tests/verify_results_kernels.hpp
@@ -0,0 +1,53 @@
+/******************************************************************************
+ * Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ *****************************************************************************/
+
+#ifndef _VERIFY_RESULTS_KERNELS_HPP_
+#define _VERIFY_RESULTS_KERNELS_HPP_
+
+namespace rocshmem {
+
+/**
+ * Compares dest against source element-wise, one thread per byte, and sets
+ * *verification_error to true if any pair differs. The flag is never cleared
+ * here; the caller resets it. Expects a one-dimensional launch with at least
+ * buf_size threads in total.
+ */
+static __global__ void verify_results_kernel_char(char *source, char *dest, size_t buf_size,
+                                                  bool *verification_error) {
+  // Index in 64 bits so buffers larger than INT_MAX elements do not wrap.
+  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  // Threads past the end of the buffer (partial last workgroup) do nothing.
+  if (idx >= buf_size) {
+    return;
+  }
+
+  if (dest[idx] != source[idx]) {
+    *verification_error = true;
+  }
+}
+
+}  // namespace rocshmem
+
+#endif /* _VERIFY_RESULTS_KERNELS_HPP_ */
diff --git a/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp
index e2744b263c..4a208ffc40 100644
--- a/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp
+++ b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp
@@ -145,13 +145,31 @@ void WaveFrontPrimitiveTester::verifyResults(uint64_t size) {
 
   if (args.myid == check_id) {
     size_t buff_size = size * args.num_wgs * num_warps;
-    for (size_t i = 0; i < buff_size; i++) {
-      if (dest[i] != source[i]) {
-        std::cerr << "Data validation error at idx " << i << std::endl;
-        std::cerr << " Got " << dest[i] << ", Expected "
-                  << source[i] << std::endl;
-        exit(-1);
+
+    // An empty buffer has nothing to compare; skipping the launch also avoids
+    // dividing by a zero workgroup size below.
+    if (buff_size > 0) {
+      size_t verify_wg_size = std::min((size_t) 1024, buff_size);
+      // Round up so a trailing partial workgroup is still checked; the kernel
+      // discards threads whose index is past buff_size.
+      size_t verify_num_wgs = (buff_size + verify_wg_size - 1) / verify_wg_size;
+
+      hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
+                         source, dest, buff_size, verification_error);
+      CHECK_HIP(hipStreamSynchronize(stream));
+    }
+
+    if (*verification_error) {
+      // Slow path: rescan on the host only to report the first bad index.
+      for (size_t i = 0; i < buff_size; i++) {
+        if (dest[i] != source[i]) {
+          std::cerr << "Data validation error at idx " << i << std::endl;
+          std::cerr << " Got " << dest[i] << ", Expected "
+                    << source[i] << std::endl;
+          exit(-1);
+        }
       }
+      *verification_error = false;
     }
   }
 }
diff --git a/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp b/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp
index a19954f45c..27fa8c961f 100644
--- a/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp
+++ b/projects/rocshmem/tests/functional_tests/workgroup_primitives.cpp
@@ -140,13 +140,31 @@ void WorkGroupPrimitiveTester::verifyResults(uint64_t size) {
 
   if (args.myid == check_id) {
     size_t buff_size = size * args.num_wgs;
-    for (size_t i = 0; i < buff_size; i++) {
-      if (dest[i] != source[i]) {
-        std::cerr << "Data validation error at idx " << i << std::endl;
-        std::cerr << " Got " << dest[i] << ", Expected "
-                  << source[i] << std::endl;
-        exit(-1);
+
+    // An empty buffer has nothing to compare; skipping the launch also avoids
+    // dividing by a zero workgroup size below.
+    if (buff_size > 0) {
+      size_t verify_wg_size = std::min((size_t) 1024, buff_size);
+      // Round up so a trailing partial workgroup is still checked; the kernel
+      // discards threads whose index is past buff_size.
+      size_t verify_num_wgs = (buff_size + verify_wg_size - 1) / verify_wg_size;
+
+      hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
+                         source, dest, buff_size, verification_error);
+      CHECK_HIP(hipStreamSynchronize(stream));
+    }
+
+    if (*verification_error) {
+      // Slow path: rescan on the host only to report the first bad index.
+      for (size_t i = 0; i < buff_size; i++) {
+        if (dest[i] != source[i]) {
+          std::cerr << "Data validation error at idx " << i << std::endl;
+          std::cerr << " Got " << dest[i] << ", Expected "
+                    << source[i] << std::endl;
+          exit(-1);
+        }
       }
+      *verification_error = false;
     }
   }
 }