From ef4a14e9478f15ff89c8f75546fc76763822fc06 Mon Sep 17 00:00:00 2001 From: Brandon Potter Date: Mon, 12 Aug 2024 11:29:31 -0700 Subject: [PATCH] Fix problems with Notifier [ROCm/rocshmem commit: 0c53a075f2faaf5717ac12d32473b59b68160c81] --- projects/rocshmem/src/atomic.hpp | 87 ++++++++++--------- projects/rocshmem/src/memory/notifier.hpp | 23 +++-- projects/rocshmem/src/memory/slab_heap.hpp | 2 +- .../tests/unit_tests/notifier_gtest.hpp | 4 +- 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/projects/rocshmem/src/atomic.hpp b/projects/rocshmem/src/atomic.hpp index 85a140f75a..7b1ce9300a 100644 --- a/projects/rocshmem/src/atomic.hpp +++ b/projects/rocshmem/src/atomic.hpp @@ -44,82 +44,91 @@ typedef enum rocshmem_memory_order { memory_order_seq_cst = __ATOMIC_SEQ_CST } rocshmem_memory_order; -template +struct rocshmem_memory_orders { + rocshmem_memory_order load {memory_order_acquire}; + rocshmem_memory_order store {memory_order_release}; + rocshmem_memory_order fence {memory_order_acq_rel}; + rocshmem_memory_order atomic {memory_order_acq_rel}; + rocshmem_memory_order weak_cas_success {memory_order_acq_rel}; + rocshmem_memory_order weak_cas_failure {memory_order_acq_rel}; + rocshmem_memory_order strong_cas_success {memory_order_acq_rel}; + rocshmem_memory_order strong_cas_failure {memory_order_acq_rel}; +}; + +template __host__ __device__ -T load(const T* address, rocshmem_memory_order order) { - return __hip_atomic_load(address, order, Scope); +T load(const T* address, rocshmem_memory_orders o) { + return __hip_atomic_load(address, o.load, s); } -template +template __host__ __device__ -void store(const T value, const T* address, rocshmem_memory_order order) { - return __hip_atomic_store(value, address, order, Scope); +void store(T* address, const T value, rocshmem_memory_orders o) { + return __hip_atomic_store(address, value, o.store, s); } -template +template __host__ __device__ -bool compare_exchange_weak(T& expected, T desired, 
rocshmem_memory_order success, - rocshmem_memory_order failure) { - return __hip_atomic_compare_exchange_weak(expected, desired, success, failure, Scope); +bool compare_exchange_weak(T& expected, T desired, rocshmem_memory_orders o) { + return __hip_atomic_compare_exchange_weak(expected, desired, o.weak_cas_success, o.weak_cas_failure, s); } -template +template __host__ __device__ -bool compare_exchange_strong(T& expected, T desired, rocshmem_memory_order success, - rocshmem_memory_order failure) { - return __hip_atomic_compare_exchange_strong(expected, desired, success, failure, Scope); +bool compare_exchange_strong(T& expected, T desired, rocshmem_memory_orders o) { + return __hip_atomic_compare_exchange_strong(expected, desired, o.strong_cas_success, o.strong_cas_failure, s); } -template +template __host__ __device__ -T fetch_add(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_add(obj, arg, order, Scope); +T fetch_add(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_add(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_sub(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_sub(obj, arg, order, Scope); +T fetch_sub(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_sub(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_and(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_and(obj, arg, order, Scope); +T fetch_and(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_and(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_or(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_or(obj, arg, order, Scope); +T fetch_or(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_or(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_xor(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_xor(obj, 
arg, order, Scope); +T fetch_xor(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_xor(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_max(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_max(obj, arg, order, Scope); +T fetch_max(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_max(obj, arg, o.atomic, s); } -template +template __host__ __device__ -T fetch_min(T* obj, U arg, rocshmem_memory_order order) { - return __hip_atomic_fetch_min(obj, arg, order, Scope); +T fetch_min(T* obj, U arg, rocshmem_memory_orders o) { + return __hip_atomic_fetch_min(obj, arg, o.atomic, s); } -template +template __device__ void thread_fence([[maybe_unused]] rocshmem_memory_order order) { - if constexpr (Scope == memory_scope_system) { + if constexpr (s == memory_scope_system) { __threadfence_system(); - } else if constexpr (Scope == memory_scope_agent) { + } else if constexpr (s == memory_scope_agent) { __threadfence(); - } else if constexpr (Scope == memory_scope_workgroup) { + } else if constexpr (s == memory_scope_workgroup) { __threadfence_block(); } } diff --git a/projects/rocshmem/src/memory/notifier.hpp b/projects/rocshmem/src/memory/notifier.hpp index d398110e9a..e6364a26cf 100644 --- a/projects/rocshmem/src/memory/notifier.hpp +++ b/projects/rocshmem/src/memory/notifier.hpp @@ -29,20 +29,15 @@ namespace rocshmem { -template +template class Notifier { -}; - -template -class Notifier { public: - __device__ uint64_t read() { return value_; } + __device__ uint64_t read() { + return detail::atomic::load(&value_, orders); + } __device__ void write(uint64_t val) { - if (is_thread_zero_in_block()) { - value_ = val; - } - publish(); + detail::atomic::store(&value_, val, orders); } __device__ void done() { __syncthreads(); } @@ -55,15 +50,17 @@ class Notifier { __syncthreads(); } + detail::atomic::rocshmem_memory_orders orders; + uint64_t value_{}; }; -template +template class NotifierProxy { 
- using ProxyT = DeviceProxy, 1>; + using ProxyT = DeviceProxy, 1>; public: - __host__ __device__ Notifier* get() { return proxy_.get(); } + __host__ __device__ Notifier* get() { return proxy_.get(); } private: ProxyT proxy_{}; diff --git a/projects/rocshmem/src/memory/slab_heap.hpp b/projects/rocshmem/src/memory/slab_heap.hpp index a3655b27c1..171332bee3 100644 --- a/projects/rocshmem/src/memory/slab_heap.hpp +++ b/projects/rocshmem/src/memory/slab_heap.hpp @@ -48,7 +48,7 @@ class SlabHeap { /** * @brief Helper type for notifier */ - using NOTIFIER_PROXY_T = NotifierProxy; + using NOTIFIER_PROXY_T = NotifierProxy; /** * @brief Helper type for notifier diff --git a/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp b/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp index ecb2e7a619..4a72212f5d 100644 --- a/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp +++ b/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp @@ -52,7 +52,7 @@ write_to_memory(uint8_t* raw_memory) { __global__ void all_threads_once(uint8_t* raw_memory, - Notifier* notifier) { + Notifier * notifier) { notifier->write(NOTIFIER_OFFSET); uint64_t offset_u64 {notifier->read()}; notifier->done(); @@ -65,7 +65,7 @@ all_threads_once(uint8_t* raw_memory, } class NotifierTestFixture : public ::testing::Test { - using NotifierProxyT = NotifierProxy; + using NotifierProxyT = NotifierProxy; public: NotifierTestFixture() {