Add sync method to notifier class

[ROCm/rocshmem commit: 359d6be797]
This commit is contained in:
Brandon Potter
2024-08-16 10:45:33 -07:00
parent 9b0e4dc05d
commit aed0da61d0
4 changed files with 135 additions and 20 deletions
-10
View File
@@ -132,16 +132,6 @@ void threadfence() {
}
}
// Execution barrier for the requested memory scope.
//
// Only memory_scope_workgroup is handled: it maps onto __syncthreads().
// Any other scope still compiles, but hits assert(false) when executed.
template <rocshmem_memory_scope s>
__device__
void sync() {
  if constexpr (s != memory_scope_workgroup) {
    assert(false);
  } else {
    __syncthreads();
  }
}
} // namespace atomic
} // namespace detail
} // namespace rocshmem
+44 -6
View File
@@ -31,13 +31,14 @@ namespace rocshmem {
template<detail::atomic::rocshmem_memory_scope scope>
class Notifier {
public:
/**
 * @brief Atomically read the 64-bit notifier value.
 *
 * @return Contents of value_, loaded at this notifier's memory scope with
 *         the orders_ member.
 */
__device__ uint64_t load() {
  return detail::atomic::load<uint64_t, scope>(&value_, orders_);
}
/**
 * @brief Atomically write the 64-bit notifier value.
 *
 * @param val Value written to value_ at this notifier's memory scope with
 *            the orders_ member.
 */
__device__ void store(uint64_t val) {
  detail::atomic::store<uint64_t, scope>(&value_, val, orders_);
}
__device__ void fence() {
@@ -45,19 +46,56 @@ class Notifier {
}
/**
 * @brief Execution barrier for the threads covered by the notifier's scope.
 *
 * Behaviour per scope:
 *  - memory_scope_thread / memory_scope_wavefront: returns immediately.
 *  - memory_scope_workgroup: a single __syncthreads().
 *  - memory_scope_system: not implemented; hits assert(false).
 *  - any remaining scope (the agent-scope tests exercise this path): a
 *    grid-wide barrier built on the count_ and signal_ members.
 *
 * The grid-wide path contains __syncthreads() and a spin-wait, so it must be
 * reached by every thread of every workgroup in the launch.
 * NOTE(review): waiting workgroups spin until the last one arrives, which
 * presumably requires all workgroups of the grid to be resident on the
 * device at the same time - confirm against the launch configuration.
 */
__device__ void sync() {
  if constexpr (scope == detail::atomic::memory_scope_thread ||
                scope == detail::atomic::memory_scope_wavefront) {
    return;
  } else if constexpr (scope == detail::atomic::memory_scope_workgroup) {
    __syncthreads();
  } else if constexpr (scope == detail::atomic::memory_scope_system) {
    assert(false);
  } else {
    // Value signal_ will hold once this barrier round completes. Read with
    // the same atomic load the spin loop uses, since other workgroups
    // update signal_ concurrently.
    const uint32_t done =
        detail::atomic::load<uint32_t, scope>(&signal_, orders_) + 1;

    // Every thread must have taken its snapshot before thread 0 registers
    // this workgroup's arrival.
    __syncthreads();

    // One thread per workgroup talks to the shared counters.
    const bool executor{!threadIdx.x && !threadIdx.y && !threadIdx.z};
    uint32_t arrived = 0;
    if (executor) {
      arrived = detail::atomic::fetch_add<uint32_t, uint32_t, scope>(
          &count_, 1, orders_);
      detail::atomic::threadfence<scope>();
    }
    __syncthreads();

    // fetch_add returns the previous count, so the last workgroup to arrive
    // observes (number of workgroups - 1).
    const uint32_t last_arrival = (gridDim.x * gridDim.y * gridDim.z) - 1;
    if (executor && arrived == last_arrival) {
      // Re-arm the counter for the next round before releasing the waiters.
      detail::atomic::store<uint32_t, scope>(&count_, 0, orders_);
      detail::atomic::threadfence<scope>();
      detail::atomic::fetch_add<uint32_t, uint32_t, scope>(
          &signal_, 1, orders_);
      detail::atomic::threadfence<scope>();
    }

    // All threads wait until the last workgroup bumps signal_.
    while (detail::atomic::load<uint32_t, scope>(&signal_, orders_) != done) {
      ;
    }
  }
}
private:
detail::atomic::rocshmem_memory_orders orders;
detail::atomic::rocshmem_memory_orders orders_{};
uint64_t value_{};
uint32_t signal_ {};
uint32_t count_ {};
};
template <typename ALLOCATOR, detail::atomic::rocshmem_memory_scope scope>
class NotifierProxy {
using ProxyT = DeviceProxy<ALLOCATOR, Notifier<scope>, 1>;
using ProxyT = DeviceProxy<ALLOCATOR, Notifier<scope>>;
public:
__host__ __device__ Notifier<scope>* get() { return proxy_.get(); }
@@ -55,3 +55,71 @@ TEST_F(NotifierBlockTestFixture, run_all_threads_once_512_1) {
// Block fixture: 1024 threads per block, a single block
// (arguments are x_block_dim, x_grid_dim).
TEST_F(NotifierBlockTestFixture, run_all_threads_once_1024_1) {
run_all_threads_once(1024, 1);
}
// Agent-scope notifier, single-block launches.
// Test names and arguments follow (x_block_dim, x_grid_dim); this group keeps
// the grid at one block and sweeps the block size from 1 to 1024 threads.
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_1) {
run_all_threads_once(1, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_2_1) {
run_all_threads_once(2, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_64_1) {
run_all_threads_once(64, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_128_1) {
run_all_threads_once(128, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_256_1) {
run_all_threads_once(256, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_512_1) {
run_all_threads_once(512, 1);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_1) {
run_all_threads_once(1024, 1);
}
// Agent-scope notifier, multi-block launches.
// Each grid size (2, 4, 8, 32, 64 blocks) is run with a block size of 1
// thread and of 1024 threads; arguments are (x_block_dim, x_grid_dim).
// NOTE(review): these rely on Notifier::sync() spinning across blocks, so
// presumably every block of the grid must be resident at once - confirm the
// larger configurations fit on the target device.
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_2) {
run_all_threads_once(1, 2);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_2) {
run_all_threads_once(1024, 2);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_4) {
run_all_threads_once(1, 4);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_4) {
run_all_threads_once(1024, 4);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_8) {
run_all_threads_once(1, 8);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_8) {
run_all_threads_once(1024, 8);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_32) {
run_all_threads_once(1, 32);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_32) {
run_all_threads_once(1024, 32);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_64) {
run_all_threads_once(1, 64);
}
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_64) {
run_all_threads_once(1024, 64);
}
@@ -53,7 +53,7 @@ template <typename NotifierT>
__global__
void
all_threads_once(uint8_t* raw_memory,
Notifier<detail::atomic::memory_scope_workgroup> * notifier) {
NotifierT * notifier) {
if (!get_flat_id()) {
notifier->store(NOTIFIER_OFFSET);
notifier->fence();
@@ -125,13 +125,11 @@ class NotifierBlockTestFixture : public NotifierBase {
void
run_all_threads_once(uint32_t x_block_dim,
uint32_t x_grid_dim) {
new (notifier_.get()) NotifierT();
const dim3 block(x_block_dim, 1, 1);
const dim3 grid(x_grid_dim, 1, 1);
all_threads_once<NotifierT><<<grid, block>>>(raw_memory_, notifier_.get());
CHECK_HIP(hipStreamSynchronize(nullptr));
verify(x_block_dim * x_grid_dim);
}
@@ -141,6 +139,27 @@ class NotifierBlockTestFixture : public NotifierBase {
NotifierProxyT notifier_ {};
};
class NotifierAgentTestFixture : public NotifierBase {
using NotifierT = Notifier<detail::atomic::memory_scope_agent>;
using NotifierProxyT = NotifierProxy<HIPAllocator, detail::atomic::memory_scope_agent>;
public:
void
run_all_threads_once(uint32_t x_block_dim,
uint32_t x_grid_dim) {
new (notifier_.get()) NotifierT();
const dim3 block(x_block_dim, 1, 1);
const dim3 grid(x_grid_dim, 1, 1);
all_threads_once<NotifierT><<<grid, block>>>(raw_memory_, notifier_.get());
CHECK_HIP(hipStreamSynchronize(nullptr));
verify(x_block_dim * x_grid_dim);
}
/**
* @brief Used to broadcast base offset for writing.
*/
NotifierProxyT notifier_ {};
};
} // namespace rocshmem