Check RMA functional test data in GPU kernel (#91)

[ROCm/rocshmem commit: c81722c339]
This commit is contained in:
Yiltan
2025-04-28 16:06:05 -04:00
committad av GitHub
förälder 19e7b4798e
incheckning 8f135af156
6 ändrade filer med 100 tillägg och 18 borttagningar
@@ -181,13 +181,23 @@ void PrimitiveTester::verifyResults(uint64_t size) {
if (args.myid == check_id) {
size_t buff_size = size * args.wg_size * args.num_wgs;
for (uint64_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
size_t verify_wg_size = std::min((size_t) 1024, buff_size);
size_t verify_num_wgs = buff_size / verify_wg_size;
hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
source, dest, buff_size, verification_error);
CHECK_HIP(hipStreamSynchronize(stream));
if (*verification_error) {
for (uint64_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
*verification_error = false;
}
}
}
@@ -83,6 +83,8 @@ Tester::Tester(TesterArguments args) : args(args) {
CHECK_HIP(hipMalloc((void**)&timer, sizeof(long long int) * num_timers));
CHECK_HIP(hipMalloc((void**)&start_time, sizeof(long long int) * num_timers));
CHECK_HIP(hipMalloc((void**)&end_time, sizeof(long long int) * num_timers));
CHECK_HIP(hipHostMalloc((void**)&verification_error, sizeof(bool)));
*verification_error = false;
}
Tester::~Tester() {
@@ -92,6 +94,7 @@ Tester::~Tester() {
CHECK_HIP(hipEventDestroy(stop_event));
CHECK_HIP(hipEventDestroy(start_event));
CHECK_HIP(hipStreamDestroy(stream));
// verification_error is allocated with hipHostMalloc in the constructor, so it
// must be released with the matching host deallocator rather than hipFree.
CHECK_HIP(hipHostFree(verification_error));
}
std::vector<Tester*> Tester::create(TesterArguments args) {
@@ -31,6 +31,7 @@
#include "tester_arguments.hpp"
#include "../src/util.hpp"
#include "verify_results_kernels.hpp"
/******************************************************************************
* TESTER CLASS TYPES
@@ -162,6 +163,8 @@ class Tester {
long long int max_end_time = 0;
uint32_t num_timers = 0;
bool *verification_error;
private:
bool _print_header = 1;
void print(uint64_t size);
@@ -0,0 +1,46 @@
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef _VERIFY_RESULTS_KERNELS_HPP_
#define _VERIFY_RESULTS_KERNELS_HPP_
namespace rocshmem {

/**
 * @brief Byte-wise comparison of two buffers on the device.
 *
 * Each thread starts at its flat id and then advances by the total number of
 * launched threads, so every element in [0, buf_size) is inspected even when
 * the launch contains fewer threads than buf_size. Threads whose id is already
 * past the end do nothing.
 *
 * @param source             Buffer holding the expected bytes.
 * @param dest               Buffer holding the bytes to verify.
 * @param buf_size           Number of bytes to compare.
 * @param verification_error Flag set to true when any byte differs. The kernel
 *                           never clears it; the caller resets it between runs.
 *
 * @note get_flat_id() is not declared in this header; it must already be in
 *       scope at the point of inclusion.
 */
static __global__ void verify_results_kernel_char(char *source, char *dest, size_t buf_size,
                                                  bool *verification_error) {
  // Use an unsigned, pointer-sized index so the comparison against buf_size
  // is not a signed/unsigned mix.
  const size_t first = static_cast<size_t>(get_flat_id());

  // Total threads in the launch, used as the loop stride.
  // NOTE(review): assumes get_flat_id() numbers every launched thread exactly
  // once in [0, total_threads) -- confirm against its definition.
  const size_t total_threads = static_cast<size_t>(gridDim.x) * gridDim.y * gridDim.z *
                               blockDim.x * blockDim.y * blockDim.z;

  for (size_t idx = first; idx < buf_size; idx += total_threads) {
    if (dest[idx] != source[idx]) {
      // Several threads may store here concurrently; they all write the same
      // value, so the outcome does not depend on ordering.
      *verification_error = true;
      return;
    }
  }
}

}  // namespace rocshmem
#endif /* _VERIFY_RESULTS_KERNELS_HPP_ */
@@ -145,13 +145,23 @@ void WaveFrontPrimitiveTester::verifyResults(uint64_t size) {
if (args.myid == check_id) {
size_t buff_size = size * args.num_wgs * num_warps;
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
size_t verify_wg_size = std::min((size_t) 1024, buff_size);
size_t verify_num_wgs = buff_size / verify_wg_size;
hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
source, dest, buff_size, verification_error);
CHECK_HIP(hipStreamSynchronize(stream));
if (*verification_error) {
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
*verification_error = false;
}
}
}
@@ -140,13 +140,23 @@ void WorkGroupPrimitiveTester::verifyResults(uint64_t size) {
if (args.myid == check_id) {
size_t buff_size = size * args.num_wgs;
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
size_t verify_wg_size = std::min((size_t) 1024, buff_size);
size_t verify_num_wgs = buff_size / verify_wg_size;
hipLaunchKernelGGL(verify_results_kernel_char, verify_num_wgs, verify_wg_size, 0, stream,
source, dest, buff_size, verification_error);
CHECK_HIP(hipStreamSynchronize(stream));
if (*verification_error) {
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != source[i]) {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected "
<< source[i] << std::endl;
exit(-1);
}
}
*verification_error = false;
}
}
}