[RO] implemented signaling operations

[ROCm/rocshmem commit: f1c25f7e19]
Этот коммит содержится в:
Yiltan Hassan Temucin
2025-02-03 11:51:45 -08:00
родитель 90b8f191d6
Коммит c4f2ccd48f
3 изменённых файлов: 146 добавлений и 0 удалений
+106
Просмотреть файл
@@ -21,6 +21,7 @@
*****************************************************************************/
#include "context_ro_device.hpp"
#include "context_ro_tmpl_device.hpp"
#include <hip/hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
@@ -348,6 +349,111 @@ __device__ void ROContext::getmem_nbi_wave(void *dest, const void *source,
}
}
__device__ void ROContext::putmem_signal(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem(dest, source, nelems, pe);
fence();
switch (sig_op) {
case ROCSHMEM_SIGNAL_SET:
amo_set<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
case ROCSHMEM_SIGNAL_ADD:
amo_add<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
default:
DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
break;
}
}
__device__ void ROContext::putmem_signal_wg(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_wg(dest, source, nelems, pe);
fence();
if (is_thread_zero_in_block()) {
switch (sig_op) {
case ROCSHMEM_SIGNAL_SET:
amo_set<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
case ROCSHMEM_SIGNAL_ADD:
amo_add<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
default:
DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
break;
}
}
}
__device__ void ROContext::putmem_signal_wave(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_wave(dest, source, nelems, pe);
fence();
if (is_thread_zero_in_wave()) {
switch (sig_op) {
case ROCSHMEM_SIGNAL_SET:
amo_set<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
case ROCSHMEM_SIGNAL_ADD:
amo_add<uint64_t>(static_cast<void*>(sig_addr), signal, pe);
break;
default:
DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op);
break;
}
}
}
__device__ void ROContext::putmem_signal_nbi(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
__device__ void ROContext::putmem_signal_nbi_wg(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal_wg(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
__device__ void ROContext::putmem_signal_nbi_wave(void *dest, const void *source, size_t nelems,
uint64_t *sig_addr, uint64_t signal, int sig_op,
int pe) {
putmem_signal_wave(dest, source, nelems, sig_addr, signal, sig_op, pe);
}
__device__ uint64_t ROContext::signal_fetch(const uint64_t *sig_addr) {
uint64_t *dst = const_cast<uint64_t*>(sig_addr);
return amo_fetch_add<uint64_t>(static_cast<void*>(dst), 0, my_pe);
}
__device__ uint64_t ROContext::signal_fetch_wg(const uint64_t *sig_addr) {
__shared__ uint64_t value;
if (is_thread_zero_in_block()) {
uint64_t *dst = const_cast<uint64_t*>(sig_addr);
value = amo_fetch_add<uint64_t>(static_cast<void*>(dst), 0, my_pe);
}
__threadfence_block();
return value;
}
__device__ uint64_t ROContext::signal_fetch_wave(const uint64_t *sig_addr) {
uint64_t value;
if (is_thread_zero_in_wave()) {
uint64_t *dst = const_cast<uint64_t*>(sig_addr);
value = amo_fetch_add<uint64_t>(static_cast<void*>(dst), 0, my_pe);
}
__threadfence_block();
value = __shfl(value, 0);
return value;
}
__device__ uint64_t number_active_lanes() {
return __popcll(__ballot(1));
}
+21
Просмотреть файл
@@ -229,6 +229,27 @@ class ROContext : public Context {
template <typename T>
__device__ void get_nbi_wave(T *dest, const T *source, size_t nelems, int pe);
#define RO_CONTEXT_PUT_SIGNAL_DEC(SUFFIX) \
template <typename T> \
__device__ void put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe); \
\
__device__ void putmem_signal##SUFFIX(void *dest, const void *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe);
RO_CONTEXT_PUT_SIGNAL_DEC()
RO_CONTEXT_PUT_SIGNAL_DEC(_wg)
RO_CONTEXT_PUT_SIGNAL_DEC(_wave)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wg)
RO_CONTEXT_PUT_SIGNAL_DEC(_nbi_wave)
__device__ uint64_t signal_fetch(const uint64_t *sig_addr);
__device__ uint64_t signal_fetch_wg(const uint64_t *sig_addr);
__device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr);
private:
__device__ uint64_t *get_unused_atomic();
+19
Просмотреть файл
@@ -432,6 +432,25 @@ __device__ void ROContext::get_nbi_wave(T *dest, const T *source, size_t nelems,
getmem_nbi_wave(dest, source, size, pe);
}
#define RO_CONTEXT_PUT_SIGNAL_DEF(SUFFIX) \
template <typename T> \
__device__ void ROContext::put_signal##SUFFIX(T *dest, const T *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe) { \
putmem_signal##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); \
} \
\
template <typename T> \
__device__ void ROContext::put_signal_nbi##SUFFIX(T *dest, const T *source, size_t nelems, \
uint64_t *sig_addr, uint64_t signal, int sig_op, \
int pe) { \
putmem_signal##SUFFIX(dest, source, nelems * sizeof(T), sig_addr, signal, sig_op, pe); \
}
RO_CONTEXT_PUT_SIGNAL_DEF()
RO_CONTEXT_PUT_SIGNAL_DEF(_wg)
RO_CONTEXT_PUT_SIGNAL_DEF(_wave)
} // namespace rocshmem
#endif // LIBRARY_SRC_REVERSE_OFFLOAD_RO_NET_GPU_TEMPLATES_HPP_