@@ -132,16 +132,6 @@ void threadfence() {
|
||||
}
|
||||
}
|
||||
|
||||
template <rocshmem_memory_scope s>
|
||||
__device__
|
||||
void sync() {
|
||||
if constexpr (s == memory_scope_workgroup) {
|
||||
__syncthreads();
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace atomic
|
||||
} // namespace detail
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -31,13 +31,14 @@ namespace rocshmem {
|
||||
|
||||
template<detail::atomic::rocshmem_memory_scope scope>
|
||||
class Notifier {
|
||||
|
||||
public:
|
||||
__device__ uint64_t load() {
|
||||
return detail::atomic::load<uint64_t, scope>(&value_, orders);
|
||||
return detail::atomic::load<uint64_t, scope>(&value_, orders_);
|
||||
}
|
||||
|
||||
__device__ void store(uint64_t val) {
|
||||
detail::atomic::store<uint64_t, scope>(&value_, val, orders);
|
||||
detail::atomic::store<uint64_t, scope>(&value_, val, orders_);
|
||||
}
|
||||
|
||||
__device__ void fence() {
|
||||
@@ -45,19 +46,56 @@ class Notifier {
|
||||
}
|
||||
|
||||
__device__ void sync() {
|
||||
detail::atomic::sync<scope>();
|
||||
if constexpr (scope == detail::atomic::memory_scope_thread ||
|
||||
scope == detail::atomic::memory_scope_wavefront) {
|
||||
return;
|
||||
}
|
||||
if constexpr (scope == detail::atomic::memory_scope_workgroup) {
|
||||
__syncthreads();
|
||||
return;
|
||||
}
|
||||
if constexpr (scope == detail::atomic::memory_scope_system) {
|
||||
assert(false);
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t done = signal_ + 1;
|
||||
__syncthreads();
|
||||
|
||||
uint32_t retval = 0;
|
||||
bool executor {!threadIdx.x && !threadIdx.y && !threadIdx.z};
|
||||
if (executor) {
|
||||
retval = detail::atomic::fetch_add<uint32_t, uint32_t, scope>(&count_, 1, orders_);
|
||||
detail::atomic::threadfence<scope>();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (retval == ((gridDim.x * gridDim.y * gridDim.z) - 1)) {
|
||||
if (executor) {
|
||||
detail::atomic::store<uint32_t, scope>(&count_, 0, orders_);
|
||||
detail::atomic::threadfence<scope>();
|
||||
auto x = detail::atomic::fetch_add<uint32_t, uint32_t, scope>(&signal_, 1, orders_);
|
||||
detail::atomic::threadfence<scope>();
|
||||
}
|
||||
}
|
||||
while (detail::atomic::load<uint32_t, scope>(&signal_, orders_) != done) {
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
detail::atomic::rocshmem_memory_orders orders;
|
||||
detail::atomic::rocshmem_memory_orders orders_{};
|
||||
|
||||
uint64_t value_{};
|
||||
|
||||
uint32_t signal_ {};
|
||||
|
||||
uint32_t count_ {};
|
||||
};
|
||||
|
||||
template <typename ALLOCATOR, detail::atomic::rocshmem_memory_scope scope>
|
||||
class NotifierProxy {
|
||||
using ProxyT = DeviceProxy<ALLOCATOR, Notifier<scope>, 1>;
|
||||
using ProxyT = DeviceProxy<ALLOCATOR, Notifier<scope>>;
|
||||
|
||||
public:
|
||||
__host__ __device__ Notifier<scope>* get() { return proxy_.get(); }
|
||||
|
||||
@@ -55,3 +55,71 @@ TEST_F(NotifierBlockTestFixture, run_all_threads_once_512_1) {
|
||||
TEST_F(NotifierBlockTestFixture, run_all_threads_once_1024_1) {
|
||||
run_all_threads_once(1024, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_1) {
|
||||
run_all_threads_once(1, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_2_1) {
|
||||
run_all_threads_once(2, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_64_1) {
|
||||
run_all_threads_once(64, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_128_1) {
|
||||
run_all_threads_once(128, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_256_1) {
|
||||
run_all_threads_once(256, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_512_1) {
|
||||
run_all_threads_once(512, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_1) {
|
||||
run_all_threads_once(1024, 1);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_2) {
|
||||
run_all_threads_once(1, 2);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_2) {
|
||||
run_all_threads_once(1024, 2);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_4) {
|
||||
run_all_threads_once(1, 4);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_4) {
|
||||
run_all_threads_once(1024, 4);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_8) {
|
||||
run_all_threads_once(1, 8);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_8) {
|
||||
run_all_threads_once(1024, 8);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_32) {
|
||||
run_all_threads_once(1, 32);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_32) {
|
||||
run_all_threads_once(1024, 32);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1_64) {
|
||||
run_all_threads_once(1, 64);
|
||||
}
|
||||
|
||||
TEST_F(NotifierAgentTestFixture, run_all_threads_once_1024_64) {
|
||||
run_all_threads_once(1024, 64);
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ template <typename NotifierT>
|
||||
__global__
|
||||
void
|
||||
all_threads_once(uint8_t* raw_memory,
|
||||
Notifier<detail::atomic::memory_scope_workgroup> * notifier) {
|
||||
NotifierT * notifier) {
|
||||
if (!get_flat_id()) {
|
||||
notifier->store(NOTIFIER_OFFSET);
|
||||
notifier->fence();
|
||||
@@ -125,13 +125,11 @@ class NotifierBlockTestFixture : public NotifierBase {
|
||||
void
|
||||
run_all_threads_once(uint32_t x_block_dim,
|
||||
uint32_t x_grid_dim) {
|
||||
new (notifier_.get()) NotifierT();
|
||||
const dim3 block(x_block_dim, 1, 1);
|
||||
const dim3 grid(x_grid_dim, 1, 1);
|
||||
|
||||
all_threads_once<NotifierT><<<grid, block>>>(raw_memory_, notifier_.get());
|
||||
|
||||
CHECK_HIP(hipStreamSynchronize(nullptr));
|
||||
|
||||
verify(x_block_dim * x_grid_dim);
|
||||
}
|
||||
|
||||
@@ -141,6 +139,27 @@ class NotifierBlockTestFixture : public NotifierBase {
|
||||
NotifierProxyT notifier_ {};
|
||||
};
|
||||
|
||||
class NotifierAgentTestFixture : public NotifierBase {
|
||||
using NotifierT = Notifier<detail::atomic::memory_scope_agent>;
|
||||
using NotifierProxyT = NotifierProxy<HIPAllocator, detail::atomic::memory_scope_agent>;
|
||||
|
||||
public:
|
||||
void
|
||||
run_all_threads_once(uint32_t x_block_dim,
|
||||
uint32_t x_grid_dim) {
|
||||
new (notifier_.get()) NotifierT();
|
||||
const dim3 block(x_block_dim, 1, 1);
|
||||
const dim3 grid(x_grid_dim, 1, 1);
|
||||
all_threads_once<NotifierT><<<grid, block>>>(raw_memory_, notifier_.get());
|
||||
CHECK_HIP(hipStreamSynchronize(nullptr));
|
||||
verify(x_block_dim * x_grid_dim);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Used to broadcast base offset for writing.
|
||||
*/
|
||||
NotifierProxyT notifier_ {};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
|
||||
Reference in New Issue
Block a user