From bae16413119ea33fca509d64e761f44b5bf73f68 Mon Sep 17 00:00:00 2001 From: Yiltan Hassan Temucin Date: Mon, 3 Feb 2025 09:55:56 -0800 Subject: [PATCH] Fix sigops functional test - Ensure quiet is called on the correct context --- .../signaling_operations_tester.cpp | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/tests/functional_tests/signaling_operations_tester.cpp b/tests/functional_tests/signaling_operations_tester.cpp index b15f383755..b87bb72d6d 100644 --- a/tests/functional_tests/signaling_operations_tester.cpp +++ b/tests/functional_tests/signaling_operations_tester.cpp @@ -29,10 +29,9 @@ using namespace rocshmem; /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ -__global__ void SignalingOperationsTest(int loop, int skip, uint64_t *timer, char *s_buf, - char *r_buf, int size, uint64_t *sig_addr, - uint64_t *fetched_value, - TestType type, ShmemContextType ctx_type) { +__global__ void PutmemSignalTest(int loop, int skip, uint64_t *timer, char *s_buf, + char *r_buf, int size, uint64_t *sig_addr, + TestType type, ShmemContextType ctx_type) { __shared__ rocshmem_ctx_t ctx; rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); @@ -66,15 +65,6 @@ __global__ void SignalingOperationsTest(int loop, int skip, uint64_t *timer, cha case WAVEPutSignalNBITestType: rocshmem_ctx_putmem_signal_nbi_wave(ctx, r_buf, s_buf, size, sig_addr, signal, sig_op, 1); break; - case SignalFetchTestType: - *fetched_value = rocshmem_signal_fetch(sig_addr); - break; - case WGSignalFetchTestType: - *fetched_value = rocshmem_signal_fetch_wg(sig_addr); - break; - case WAVESignalFetchTestType: - *fetched_value = rocshmem_signal_fetch_wave(sig_addr); - break; default: break; } @@ -92,6 +82,42 @@ __global__ void SignalingOperationsTest(int loop, int skip, uint64_t *timer, cha rocshmem_wg_finalize(); } +__global__ void SignalFetchTest(int loop, int skip, uint64_t *timer, uint64_t *sig_addr, + uint64_t *fetched_value, TestType type) { + rocshmem_wg_init(); + + uint64_t start; + + for (int i = 0; i < loop + skip; i++) { + if (i == skip) { + __syncthreads(); + start = rocshmem_timer(); + } + + switch (type) { + case SignalFetchTestType: + *fetched_value = rocshmem_signal_fetch(sig_addr); + break; + case WGSignalFetchTestType: + *fetched_value = rocshmem_signal_fetch_wg(sig_addr); + break; + case WAVESignalFetchTestType: + *fetched_value = rocshmem_signal_fetch_wave(sig_addr); + break; + default: + break; + } + } + + __syncthreads(); + + if (hipThreadIdx_x == 0) { + timer[hipBlockIdx_x] = rocshmem_timer() - start; + } + + rocshmem_wg_finalize(); +} + /****************************************************************************** * HOST TESTER CLASS METHODS *****************************************************************************/ @@ -120,9 +146,16 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(SignalingOperationsTest, gridSize, blockSize, shared_bytes, stream, - loop, args.skip, timer, s_buf, r_buf, size, sig_addr, fetched_value, - _type, _shmem_context); + if ((_type == SignalFetchTestType) || + (_type == WAVESignalFetchTestType) || + (_type == WGSignalFetchTestType)) { + hipLaunchKernelGGL(SignalFetchTest, gridSize, blockSize, shared_bytes, stream, + loop, args.skip, timer, sig_addr, fetched_value, _type); + } else { + hipLaunchKernelGGL(PutmemSignalTest, gridSize, blockSize, shared_bytes, stream, + loop, args.skip, timer, s_buf, r_buf, size, sig_addr, + _type, _shmem_context); + } num_msgs = (loop + args.skip) * gridSize.x; num_timed_msgs = loop;