2
0

Merge pull request #38 from Yiltan/ro/implement-sigops

Implements Signalling Operations for RO
Este cometimento está contido em:
Yiltan
2025-02-10 15:10:07 -05:00
cometido por GitHub
ascendente 94144f4460 f1c25f7e19
cometimento 495cd6970b
5 ficheiros modificados com 150 adições e 1 eliminações
+1
Ver ficheiro
@@ -53,6 +53,7 @@ enum ro_net_types {
RO_NET_DOUBLE,
RO_NET_INT,
RO_NET_LONG,
RO_NET_UNSIGNED_LONG,
RO_NET_LONG_LONG,
RO_NET_SHORT,
RO_NET_LONG_DOUBLE
+106
Ver ficheiro
@@ -21,6 +21,7 @@
*****************************************************************************/
#include "context_ro_device.hpp"
#include "context_ro_tmpl_device.hpp"
#include <hip/hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
@@ -348,6 +349,111 @@ __device__ void ROContext::getmem_nbi_wave(void *dest, const void *source,
}
}
/**
 * Put-with-signal, single-thread variant.
 *
 * Copies `nelems` bytes from `source` to `dest` on PE `pe` via putmem(),
 * issues fence(), and then updates the signal word at `sig_addr` on the
 * same PE.
 *
 * @param dest     Destination address of the payload.
 * @param source   Source address of the payload.
 * @param nelems   Payload size, forwarded unchanged to putmem().
 * @param sig_addr Address of the 64-bit signal word to update.
 * @param signal   Operand applied to the signal word.
 * @param sig_op   ROCSHMEM_SIGNAL_SET or ROCSHMEM_SIGNAL_ADD; any other
 *                 value only produces a debug message.
 * @param pe       Target processing element.
 */
__device__ void ROContext::putmem_signal(void *dest, const void *source, size_t nelems,
                                         uint64_t *sig_addr, uint64_t signal, int sig_op,
                                         int pe) {
  // Payload first, then a fence between the data transfer and the signal.
  putmem(dest, source, nelems, pe);
  fence();

  void *const sig_word = static_cast<void*>(sig_addr);
  if (sig_op == ROCSHMEM_SIGNAL_SET) {
    amo_set<uint64_t>(sig_word, signal, pe);
  } else if (sig_op == ROCSHMEM_SIGNAL_ADD) {
    amo_add<uint64_t>(sig_word, signal, pe);
  } else {
    DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
  }
}
/**
 * Put-with-signal, block-cooperative variant.
 *
 * Every calling thread runs putmem_wg() followed by fence(); afterwards only
 * the thread for which is_thread_zero_in_block() holds performs the signal
 * update, so the signal word is modified once per call.
 *
 * @param dest     Destination address of the payload.
 * @param source   Source address of the payload.
 * @param nelems   Payload size, forwarded unchanged to putmem_wg().
 * @param sig_addr Address of the 64-bit signal word to update.
 * @param signal   Operand applied to the signal word.
 * @param sig_op   ROCSHMEM_SIGNAL_SET or ROCSHMEM_SIGNAL_ADD; any other
 *                 value only produces a debug message.
 * @param pe       Target processing element.
 */
__device__ void ROContext::putmem_signal_wg(void *dest, const void *source, size_t nelems,
                                            uint64_t *sig_addr, uint64_t signal, int sig_op,
                                            int pe) {
  putmem_wg(dest, source, nelems, pe);
  fence();

  // Only one thread per block issues the signal operation.
  if (!is_thread_zero_in_block()) {
    return;
  }

  void *const sig_word = static_cast<void*>(sig_addr);
  if (sig_op == ROCSHMEM_SIGNAL_SET) {
    amo_set<uint64_t>(sig_word, signal, pe);
  } else if (sig_op == ROCSHMEM_SIGNAL_ADD) {
    amo_add<uint64_t>(sig_word, signal, pe);
  } else {
    DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
  }
}
/**
 * Put-with-signal, wave-cooperative variant.
 *
 * Every calling thread runs putmem_wave() followed by fence(); afterwards
 * only the thread for which is_thread_zero_in_wave() holds performs the
 * signal update, so the signal word is modified once per wave.
 *
 * @param dest     Destination address of the payload.
 * @param source   Source address of the payload.
 * @param nelems   Payload size, forwarded unchanged to putmem_wave().
 * @param sig_addr Address of the 64-bit signal word to update.
 * @param signal   Operand applied to the signal word.
 * @param sig_op   ROCSHMEM_SIGNAL_SET or ROCSHMEM_SIGNAL_ADD; any other
 *                 value only produces a debug message.
 * @param pe       Target processing element.
 */
__device__ void ROContext::putmem_signal_wave(void *dest, const void *source, size_t nelems,
                                              uint64_t *sig_addr, uint64_t signal, int sig_op,
                                              int pe) {
  putmem_wave(dest, source, nelems, pe);
  fence();

  // Only one thread per wave issues the signal operation.
  if (!is_thread_zero_in_wave()) {
    return;
  }

  void *const sig_word = static_cast<void*>(sig_addr);
  if (sig_op == ROCSHMEM_SIGNAL_SET) {
    amo_set<uint64_t>(sig_word, signal, pe);
  } else if (sig_op == ROCSHMEM_SIGNAL_ADD) {
    amo_add<uint64_t>(sig_word, signal, pe);
  } else {
    DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
  }
}
// Non-blocking-interface entry point for put-with-signal (single-thread
// variant). It currently forwards every argument unchanged to
// putmem_signal(), so it behaves exactly like that routine.
__device__ void ROContext::putmem_signal_nbi(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
// Non-blocking-interface entry point for put-with-signal (_wg variant).
// It currently forwards every argument unchanged to putmem_signal_wg(),
// so it behaves exactly like that routine.
__device__ void ROContext::putmem_signal_nbi_wg(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal_wg(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
// Non-blocking-interface entry point for put-with-signal (_wave variant).
// It currently forwards every argument unchanged to putmem_signal_wave(),
// so it behaves exactly like that routine.
__device__ void ROContext::putmem_signal_nbi_wave(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal_wave(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
/**
 * Reads the current value of a signal word.
 *
 * Implemented as an atomic fetch-add of 0 targeting `my_pe`, which returns
 * the stored value without modifying it.
 *
 * @param sig_addr Address of the 64-bit signal word to read.
 * @return The value returned by the fetch-add.
 */
__device__ uint64_t ROContext::signal_fetch(const uint64_t *sig_addr) {
  // The atomic API takes a mutable void*, so constness is cast away here;
  // adding 0 leaves the signal word unchanged.
  return amo_fetch_add<uint64_t>(
      static_cast<void*>(const_cast<uint64_t*>(sig_addr)), 0, my_pe);
}
/**
 * Reads the current value of a signal word on behalf of a whole thread block.
 *
 * The thread for which is_thread_zero_in_block() holds performs an atomic
 * fetch-add of 0 targeting `my_pe` and publishes the result through shared
 * memory; every thread of the block returns that same value.
 *
 * Must be called by all threads of the block, because it contains
 * block-wide barriers.
 *
 * @param sig_addr Address of the 64-bit signal word to read.
 * @return The fetched signal value, identical in every thread of the block.
 */
__device__ uint64_t ROContext::signal_fetch_wg(const uint64_t *sig_addr) {
  __shared__ uint64_t value;

  if (is_thread_zero_in_block()) {
    uint64_t *dst = const_cast<uint64_t*>(sig_addr);
    value = amo_fetch_add<uint64_t>(static_cast<void*>(dst), 0, my_pe);
  }

  // A memory fence alone does not make the other threads wait for the
  // write above; a barrier is required before anyone reads `value`.
  __syncthreads();
  const uint64_t result = value;

  // Second barrier: keep a back-to-back call from overwriting `value`
  // while slower threads are still reading the current one.
  __syncthreads();
  return result;
}
/**
 * Reads the current value of a signal word on behalf of a wave.
 *
 * The thread for which is_thread_zero_in_wave() holds performs an atomic
 * fetch-add of 0 targeting `my_pe`; the result is then broadcast from lane 0
 * to the calling lanes with a shuffle.
 *
 * NOTE(review): the broadcast reads lane 0, so this assumes the lane selected
 * by is_thread_zero_in_wave() is lane 0 — confirm against that helper.
 *
 * @param sig_addr Address of the 64-bit signal word to read.
 * @return The value held by lane 0 after the fetch.
 */
__device__ uint64_t ROContext::signal_fetch_wave(const uint64_t *sig_addr) {
  // Initialised so that lanes which do not perform the fetch never hand an
  // indeterminate value to the shuffle below.
  uint64_t value = 0;

  if (is_thread_zero_in_wave()) {
    uint64_t *dst = const_cast<uint64_t*>(sig_addr);
    value = amo_fetch_add<uint64_t>(static_cast<void*>(dst), 0, my_pe);
  }

  __threadfence_block();

  // Every lane takes lane 0's copy.
  value = __shfl(value, 0);
  return value;
}
/**
 * Counts the lanes participating in the current ballot.
 *
 * @return Number of set bits in the mask produced by `__ballot(1)`.
 */
__device__ uint64_t number_active_lanes() {
  // Each lane that executes the ballot votes "1", so the mask has one bit
  // set per participating lane.
  const auto active_mask = __ballot(1);
  return __popcll(active_mask);
}
+21
Ver ficheiro
@@ -229,6 +229,27 @@ class ROContext : public Context {
template <typename T>
__device__ void get_nbi_wave(T *dest, const T *source, size_t nelems, int pe);
// Declares the put-with-signal members for one API variant named by SUFFIX:
//   - put_signal<SUFFIX>:    typed template taking T* buffers and an element
//                            count;
//   - putmem_signal<SUFFIX>: untyped overload taking void* buffers.
// Both take the signal word address, the signal operand, the signal
// operation selector and the target PE.
// The macro is expanded below for the plain, _wg and _wave variants and for
// their _nbi counterparts.
#define RO_CONTEXT_PUT_SIGNAL_DEC(SUFFIX) \
template <typename T> \
__device__ void put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe); \
\
__device__ void putmem_signal##SUFFIX(void *dest, const void *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe);
RO_CONTEXT_PUT_SIGNAL_DEC()
RO_CONTEXT_PUT_SIGNAL_DEC(_wg)
RO_CONTEXT_PUT_SIGNAL_DEC(_wave)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wg)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wave)
// Return the value of the 64-bit signal word at sig_addr.
// signal_fetch is the single-thread form; the _wg and _wave forms are the
// block- and wave-level counterparts.
__device__ uint64_t signal_fetch(const uint64_t *sig_addr);
__device__ uint64_t signal_fetch_wg(const uint64_t *sig_addr);
__device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr);
private:
__device__ uint64_t *get_unused_atomic();
+20 -1
Ver ficheiro
@@ -61,7 +61,7 @@ struct GetROType<unsigned int> {
// Maps `unsigned long` to its own network type tag. It must not share
// RO_NET_LONG with the signed type: convertType() translates
// RO_NET_UNSIGNED_LONG to MPI_UNSIGNED_LONG and RO_NET_LONG to MPI_LONG.
template <>
struct GetROType<unsigned long> {
  static constexpr ro_net_types Type{RO_NET_UNSIGNED_LONG};
};
template <>
@@ -433,6 +433,25 @@ __device__ void ROContext::get_nbi_wave(T *dest, const T *source, size_t nelems,
getmem_nbi_wave(dest, source, size, pe);
}
// Defines the typed put-with-signal templates for one API variant named by
// SUFFIX. Each template converts the element count into a byte count
// (nelems * sizeof(T)) and forwards to the matching untyped member:
//   put_signal<SUFFIX>      -> putmem_signal<SUFFIX>
//   put_signal_nbi<SUFFIX>  -> putmem_signal_nbi<SUFFIX>
// The nbi template goes through the nbi untyped entry point, not the
// blocking one, so the typed and untyped nbi APIs always share one
// implementation.
#define RO_CONTEXT_PUT_SIGNAL_DEF(SUFFIX) \
template <typename T> \
__device__ void ROContext::put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \
                                              uint64_t *sig_addr, uint64_t signal, int sig_op, \
                                              int pe) { \
  putmem_signal##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); \
} \
\
template <typename T> \
__device__ void ROContext::put_signal_nbi##SUFFIX(T *dest, const T *source, size_t nelems, \
                                                  uint64_t *sig_addr, uint64_t signal, int sig_op, \
                                                  int pe) { \
  putmem_signal_nbi##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); \
}

RO_CONTEXT_PUT_SIGNAL_DEF()
RO_CONTEXT_PUT_SIGNAL_DEF(_wg)
RO_CONTEXT_PUT_SIGNAL_DEF(_wave)
} // namespace rocshmem
#endif // LIBRARY_SRC_REVERSE_OFFLOAD_RO_NET_GPU_TEMPLATES_HPP_
+2
Ver ficheiro
@@ -376,6 +376,8 @@ static MPI_Datatype convertType(ro_net_types type) {
return MPI_INT;
case RO_NET_LONG:
return MPI_LONG;
case RO_NET_UNSIGNED_LONG:
return MPI_UNSIGNED_LONG;
case RO_NET_LONG_LONG:
return MPI_LONG_LONG;
case RO_NET_SHORT: