diff --git a/src/reverse_offload/commands_types.hpp b/src/reverse_offload/commands_types.hpp index 6fd7836a10..404bafc3ca 100644 --- a/src/reverse_offload/commands_types.hpp +++ b/src/reverse_offload/commands_types.hpp @@ -53,6 +53,7 @@ enum ro_net_types { RO_NET_DOUBLE, RO_NET_INT, RO_NET_LONG, + RO_NET_UNSIGNED_LONG, RO_NET_LONG_LONG, RO_NET_SHORT, RO_NET_LONG_DOUBLE diff --git a/src/reverse_offload/context_ro_device.cpp b/src/reverse_offload/context_ro_device.cpp index 21de4b9b2f..8b82a97578 100644 --- a/src/reverse_offload/context_ro_device.cpp +++ b/src/reverse_offload/context_ro_device.cpp @@ -21,6 +21,7 @@ *****************************************************************************/ #include "context_ro_device.hpp" +#include "context_ro_tmpl_device.hpp" #include #include @@ -348,6 +349,113 @@ __device__ void ROContext::getmem_nbi_wave(void *dest, const void *source, } } +__device__ void ROContext::putmem_signal(void *dest, const void *source, size_t nelems, + uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* calls putmem and fence, then applies sig_op (SET or ADD) to sig_addr on pe */ + putmem(dest, source, nelems, pe); + fence(); + + switch (sig_op) { + case ROCSHMEM_SIGNAL_SET: + amo_set(static_cast(sig_addr), signal, pe); + break; + case ROCSHMEM_SIGNAL_ADD: + amo_add(static_cast(sig_addr), signal, pe); + break; + default: + DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op); + break; + } +} + +__device__ void ROContext::putmem_signal_wg(void *dest, const void *source, size_t nelems, + uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* calls putmem_wg and fence; only the thread for which is_thread_zero_in_block() holds updates the signal */ + putmem_wg(dest, source, nelems, pe); + fence(); + + if (is_thread_zero_in_block()) { + switch (sig_op) { + case ROCSHMEM_SIGNAL_SET: + amo_set(static_cast(sig_addr), signal, pe); + break; + case ROCSHMEM_SIGNAL_ADD: + amo_add(static_cast(sig_addr), signal, pe); + break; + default: + DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op); + break; + } + } +} + +__device__ void ROContext::putmem_signal_wave(void *dest, const void *source, size_t nelems, 
uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* calls putmem_wave and fence; only the thread for which is_thread_zero_in_wave() holds updates the signal */ + putmem_wave(dest, source, nelems, pe); + fence(); + + if (is_thread_zero_in_wave()) { + switch (sig_op) { + case ROCSHMEM_SIGNAL_SET: + amo_set(static_cast(sig_addr), signal, pe); + break; + case ROCSHMEM_SIGNAL_ADD: + amo_add(static_cast(sig_addr), signal, pe); + break; + default: + DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op); + break; + } + } +} + +__device__ void ROContext::putmem_signal_nbi(void *dest, const void *source, size_t nelems, + uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* forwards to the blocking putmem_signal */ + putmem_signal(dest, source, nelems, sig_addr, signal, sig_op, pe); +} + +__device__ void ROContext::putmem_signal_nbi_wg(void *dest, const void *source, size_t nelems, + uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* forwards to the blocking putmem_signal_wg */ + putmem_signal_wg(dest, source, nelems, sig_addr, signal, sig_op, pe); +} + +__device__ void ROContext::putmem_signal_nbi_wave(void *dest, const void *source, size_t nelems, + uint64_t *sig_addr, uint64_t signal, int sig_op, + int pe) { /* forwards to the blocking putmem_signal_wave */ + putmem_signal_wave(dest, source, nelems, sig_addr, signal, sig_op, pe); +} + +__device__ uint64_t ROContext::signal_fetch(const uint64_t *sig_addr) { /* returns the result of a fetch-add of 0 on sig_addr, issued to my_pe */ + uint64_t *dst = const_cast(sig_addr); + return amo_fetch_add(static_cast(dst), 0, my_pe); +} + +__device__ uint64_t ROContext::signal_fetch_wg(const uint64_t *sig_addr) { /* one thread fetches the signal and hands it to the rest of the block through shared memory; assumes every thread of the block calls this function, since __syncthreads() is a block-wide barrier */ + __shared__ uint64_t value; + if (is_thread_zero_in_block()) { + uint64_t *dst = const_cast(sig_addr); + value = amo_fetch_add(static_cast(dst), 0, my_pe); + } + __syncthreads(); /* barrier rather than a fence: readers must wait until the fetching thread has stored value */ + const uint64_t result = value; + __syncthreads(); /* keeps a later call from overwriting value before every thread has read it */ + return result; +} + +__device__ uint64_t ROContext::signal_fetch_wave(const uint64_t *sig_addr) { /* one thread fetches the signal; the value held by lane 0 is then returned by every calling lane */ + uint64_t value = 0; /* defined in every lane so __shfl never reads an uninitialized register */ + if (is_thread_zero_in_wave()) { + uint64_t *dst = const_cast(sig_addr); + value = amo_fetch_add(static_cast(dst), 0, my_pe); + } + __threadfence_block(); + value = __shfl(value, 0); /* take the value from lane 0 */ + return value; +} + __device__ uint64_t number_active_lanes() { return __popcll(__ballot(1)); } diff --git 
a/src/reverse_offload/context_ro_device.hpp b/src/reverse_offload/context_ro_device.hpp index 3413ded9e9..ebdfb26d9e 100644 --- a/src/reverse_offload/context_ro_device.hpp +++ b/src/reverse_offload/context_ro_device.hpp @@ -229,6 +229,27 @@ class ROContext : public Context { template __device__ void get_nbi_wave(T *dest, const T *source, size_t nelems, int pe); +#define RO_CONTEXT_PUT_SIGNAL_DEC(SUFFIX) \ + template \ + __device__ void put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \ + uint64_t *sig_addr, uint64_t signal, int sig_op, \ + int pe); \ + \ + __device__ void putmem_signal##SUFFIX(void *dest, const void *source, size_t nelems, \ + uint64_t *sig_addr, uint64_t signal, int sig_op, \ + int pe); + + RO_CONTEXT_PUT_SIGNAL_DEC() + RO_CONTEXT_PUT_SIGNAL_DEC(_wg) + RO_CONTEXT_PUT_SIGNAL_DEC(_wave) + RO_CONTEXT_PUT_SIGNAL_DEC(_nbi) + RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wg) + RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wave) + + __device__ uint64_t signal_fetch(const uint64_t *sig_addr); + __device__ uint64_t signal_fetch_wg(const uint64_t *sig_addr); + __device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr); + private: __device__ uint64_t *get_unused_atomic(); diff --git a/src/reverse_offload/context_ro_tmpl_device.hpp b/src/reverse_offload/context_ro_tmpl_device.hpp index c8622bfe6f..7e292f8076 100644 --- a/src/reverse_offload/context_ro_tmpl_device.hpp +++ b/src/reverse_offload/context_ro_tmpl_device.hpp @@ -61,7 +61,7 @@ struct GetROType { template <> struct GetROType { - static constexpr ro_net_types Type{RO_NET_LONG}; + static constexpr ro_net_types Type{RO_NET_UNSIGNED_LONG}; }; template <> @@ -433,6 +433,25 @@ __device__ void ROContext::get_nbi_wave(T *dest, const T *source, size_t nelems, getmem_nbi_wave(dest, source, size, pe); } +#define RO_CONTEXT_PUT_SIGNAL_DEF(SUFFIX) \ + template \ + __device__ void ROContext::put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \ + uint64_t *sig_addr, uint64_t signal, int sig_op, \ + int pe) { \ 
putmem_signal##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); \ + } \ + \ + template \ + __device__ void ROContext::put_signal_nbi##SUFFIX(T *dest, const T *source, size_t nelems, \ + uint64_t *sig_addr, uint64_t signal, int sig_op, \ + int pe) { \ + putmem_signal_nbi##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); /* goes through the _nbi entry point, as get_nbi_wave does with getmem_nbi_wave */ \ + } + +RO_CONTEXT_PUT_SIGNAL_DEF() +RO_CONTEXT_PUT_SIGNAL_DEF(_wg) +RO_CONTEXT_PUT_SIGNAL_DEF(_wave) + } // namespace rocshmem #endif // LIBRARY_SRC_REVERSE_OFFLOAD_RO_NET_GPU_TEMPLATES_HPP_ diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index ee057e3153..fc2284186d 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -376,6 +376,8 @@ static MPI_Datatype convertType(ro_net_types type) { return MPI_INT; case RO_NET_LONG: return MPI_LONG; + case RO_NET_UNSIGNED_LONG: + return MPI_UNSIGNED_LONG; case RO_NET_LONG_LONG: return MPI_LONG_LONG; case RO_NET_SHORT: