diff --git a/projects/rocshmem/src/atomic.hpp b/projects/rocshmem/src/atomic.hpp index 330bd82f86..eadacf7284 100644 --- a/projects/rocshmem/src/atomic.hpp +++ b/projects/rocshmem/src/atomic.hpp @@ -132,16 +132,6 @@ void threadfence() { } } -template -__device__ -void sync() { - if constexpr (s == memory_scope_workgroup) { - __syncthreads(); - } else { - assert(false); - } -} - } // namespace atomic } // namespace detail } // namespace rocshmem diff --git a/projects/rocshmem/src/memory/notifier.hpp b/projects/rocshmem/src/memory/notifier.hpp index 12f53a4c8b..f01f7f8301 100644 --- a/projects/rocshmem/src/memory/notifier.hpp +++ b/projects/rocshmem/src/memory/notifier.hpp @@ -31,13 +31,14 @@ namespace rocshmem { template class Notifier { + public: __device__ uint64_t load() { - return detail::atomic::load(&value_, orders); + return detail::atomic::load(&value_, orders_); } __device__ void store(uint64_t val) { - detail::atomic::store(&value_, val, orders); + detail::atomic::store(&value_, val, orders_); } __device__ void fence() { @@ -45,19 +46,56 @@ class Notifier { } __device__ void sync() { - detail::atomic::sync(); + if constexpr (scope == detail::atomic::memory_scope_thread || + scope == detail::atomic::memory_scope_wavefront) { + return; + } + if constexpr (scope == detail::atomic::memory_scope_workgroup) { + __syncthreads(); + return; + } + if constexpr (scope == detail::atomic::memory_scope_system) { + assert(false); + return; + } + + uint32_t done = signal_ + 1; + __syncthreads(); + + uint32_t retval = 0; + bool executor {!threadIdx.x && !threadIdx.y && !threadIdx.z}; + if (executor) { + retval = detail::atomic::fetch_add(&count_, 1, orders_); + detail::atomic::threadfence(); + } + __syncthreads(); + + if (retval == ((gridDim.x * gridDim.y * gridDim.z) - 1)) { + if (executor) { + detail::atomic::store(&count_, 0, orders_); + detail::atomic::threadfence(); + auto x = detail::atomic::fetch_add(&signal_, 1, orders_); + detail::atomic::threadfence(); + } + } + while (detail::atomic::load(&signal_, orders_) != done) { + ; + } } private: - - detail::atomic::rocshmem_memory_orders orders; + detail::atomic::rocshmem_memory_orders orders_{}; uint64_t value_{}; + + uint32_t signal_ {}; + + uint32_t count_ {}; }; template class NotifierProxy { - using ProxyT = DeviceProxy, 1>; + using ProxyT = DeviceProxy>; public: __host__ __device__ Notifier* get() { return proxy_.get(); } diff --git a/projects/rocshmem/tests/unit_tests/notifier_gtest.cpp b/projects/rocshmem/tests/unit_tests/notifier_gtest.cpp index 073c37a495..e6275b10a1 100644 --- a/projects/rocshmem/tests/unit_tests/notifier_gtest.cpp +++ b/projects/rocshmem/tests/unit_tests/notifier_gtest.cpp @@ -55,3 +55,71 @@ TEST_F(NotifierBlockTestFixture, run_all_threads_once_512_1) { TEST_F(NotifierBlockTestFixture, run_all_threads_once_1024_1) { run_all_threads_once(1024, 1); } + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_1) { + run_all_threads_once(1, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_2_1) { + run_all_threads_once(2, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_64_1) { + run_all_threads_once(64, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_128_1) { + run_all_threads_once(128, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_256_1) { + run_all_threads_once(256, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_512_1) { + run_all_threads_once(512, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_1) { + run_all_threads_once(1024, 1); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_2) { + run_all_threads_once(1, 2); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_2) { + run_all_threads_once(1024, 2); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_4) { + run_all_threads_once(1, 4); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_4) { + run_all_threads_once(1024, 4); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_8) { + run_all_threads_once(1, 8); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_8) { + run_all_threads_once(1024, 8); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_32) { + run_all_threads_once(1, 32); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_32) { + run_all_threads_once(1024, 32); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_64) { + run_all_threads_once(1, 64); +} + +TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_64) { + run_all_threads_once(1024, 64); +} diff --git a/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp b/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp index b82f47b1d9..e130159b2e 100644 --- a/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp +++ b/projects/rocshmem/tests/unit_tests/notifier_gtest.hpp @@ -53,7 +53,7 @@ template __global__ void all_threads_once(uint8_t* raw_memory, - Notifier * notifier) { + NotifierT * notifier) { if (!get_flat_id()) { notifier->store(NOTIFIER_OFFSET); notifier->fence(); @@ -125,13 +125,11 @@ class NotifierBlockTestFixture : public NotifierBase { void run_all_threads_once(uint32_t x_block_dim, uint32_t x_grid_dim) { + new (notifier_.get()) NotifierT(); const dim3 block(x_block_dim, 1, 1); const dim3 grid(x_grid_dim, 1, 1); - all_threads_once<<>>(raw_memory_, notifier_.get()); - CHECK_HIP(hipStreamSynchronize(nullptr)); - verify(x_block_dim * x_grid_dim); } @@ -141,6 +139,27 @@ class NotifierBlockTestFixture : public NotifierBase { NotifierProxyT notifier_ {}; }; +class NotifierAgentTestFixture : public NotifierBase { + using NotifierT = Notifier; + using NotifierProxyT = NotifierProxy; + + public: + void + run_all_threads_once(uint32_t x_block_dim, + uint32_t x_grid_dim) { + new (notifier_.get()) NotifierT(); + const dim3 block(x_block_dim, 1, 1); + const dim3 grid(x_grid_dim, 1, 1); + all_threads_once<<>>(raw_memory_, notifier_.get()); + CHECK_HIP(hipStreamSynchronize(nullptr)); + verify(x_block_dim * x_grid_dim); + } + + /** + * @brief Used to broadcast base offset for writing. + */ + NotifierProxyT notifier_ {}; +}; } // namespace rocshmem