2
0

Simplify and remove stride based access of managed varaible test (#2677)

Este cometimento está contido em:
Jatin Chaudhary
2026-01-27 10:48:49 +00:00
cometido por GitHub
ascendente 1b55de002a
cometimento c4a9567492
+23 -22
Ver ficheiro
@@ -25,38 +25,37 @@
#include <hip_test_checkers.hh>
#define N 1048576
__managed__ float A[N]; // Accessible by ALL CPU and GPU functions !!!
__managed__ float B[N];
__managed__ int x = 0;
__managed__ float m_A[N]; // Accessible by ALL CPU and GPU functions !!!
__managed__ float m_B[N];
__managed__ int m_X = 0;
__global__ void add(const float* A, float* B) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < N; i += stride) B[i] = A[i] + B[i];
static __global__ void managed_add(size_t size) {
size_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < size) {
m_B[i] += m_A[i];
}
}
__global__ void GPU_func() { x++; }
static __global__ void managed_inc() { atomicAdd(&m_X, 1.0f); }
TEST_CASE("Unit_hipManagedKeyword_SingleGpu") {
for (int i = 0; i < N; i++) {
A[i] = 1.0f;
B[i] = 2.0f;
for (size_t i = 0; i < N; i++) {
m_A[i] = 1.0f;
m_B[i] = 2.0f;
}
int blockSize = 256;
int numBlocks = (N + blockSize - 1) / blockSize;
dim3 dimGrid(numBlocks, 1, 1);
dim3 dimBlock(blockSize, 1, 1);
hipLaunchKernelGGL(add, dimGrid, dimBlock, 0, 0, static_cast<const float*>(A),
static_cast<float*>(B));
int numBlocks = N / blockSize;
HIP_CHECK(hipGetLastError());
managed_add<<<numBlocks, blockSize>>>(N);
HIP_CHECK(hipDeviceSynchronize());
HIP_CHECK(hipGetLastError());
float maxError = 0.0f;
for (int i = 0; i < N; i++) maxError = fmax(maxError, fabs(B[i] - 3.0f));
REQUIRE(maxError == 0.0f);
for (size_t i = 0; i < N; i++) {
INFO("Reading output from managed variable: Index: " << i << " output: " << m_B[i]);
REQUIRE(3.0f == m_B[i]);
}
}
TEST_CASE("Unit_hipManagedKeyword_MultiGpu") {
@@ -74,8 +73,10 @@ TEST_CASE("Unit_hipManagedKeyword_MultiGpu") {
for (int i = 0; i < numDevices; i++) {
HIP_CHECK(hipSetDevice(i));
GPU_func<<<1, 1>>>();
managed_inc<<<1, 1>>>();
HIP_CHECK(hipDeviceSynchronize());
}
REQUIRE(x == numDevices);
INFO("Inc counter should match the device count: " << m_X << " Device count: " << numDevices);
REQUIRE(m_X == numDevices);
}