Refactor RO backend data structures (#49)

- Remove hdp and ipc pointers from BlockHandle, align RO stats with RO contexts

- Add run commands for `rocshmem_g` and `rocshmem_p` API tests in driver.sh

- Allocate rocshmem API return buffers based on number of device contexts.

- Associate status flag address with blocking calls and remove threadId dependency
   - Associated the status flag address with each blocking call request to notify the GPU thread.
   - Removed dependency on threadId for determining the appropriate status flag index.

- Move status flag buffer allocation to backend.

- Initialize allocated memeory to zero
Bu işleme şunda yer alıyor:
Avinash Kethineedi
2025-03-14 10:49:44 -05:00
işlemeyi yapan: GitHub
ebeveyn 96424a59a8
işleme df4ad2c04d
16 değiştirilmiş dosya ile 389 ekleme ve 357 silme
+14
Dosyayı Görüntüle
@@ -179,6 +179,20 @@ TestRMA() {
ExecTest "teamctxget" 2 1 1 1048576
ExecTest "g" 2 1 1 1048576
ExecTest "g" 2 1 1024 512
ExecTest "g" 2 8 1 1048576
ExecTest "g" 2 16 128 8
ExecTest "g" 2 32 256 512
ExecTest "g" 2 64 1024 8
ExecTest "p" 2 1 1 1048576
ExecTest "p" 2 1 1024 512
ExecTest "p" 2 8 1 1048576
ExecTest "p" 2 16 128 8
ExecTest "p" 2 32 256 512
ExecTest "p" 2 64 1024 8
################################ Non-Blocking ################################
ExecTest "putnbi" 2 1 1 1048576
+1 -1
Dosyayı Görüntüle
@@ -51,7 +51,7 @@ class DeviceProxy {
/*
* Default memory provided by the allocation to recognizable bytes.
*/
memset(static_cast<void*>(temp), 0xBC, size_bytes);
memset(static_cast<void*>(temp), 0, size_bytes);
/*
* Pass the memory into a unique ptr for tracking.
-2
Dosyayı Görüntüle
@@ -36,10 +36,8 @@ struct BackendRegister {
std::atomic<bool> worker_thread_exit{false};
bool *needs_quiet{nullptr};
bool *needs_blocking{nullptr};
char *g_ret{nullptr};
HdpPolicy *hdp_policy{nullptr};
WindowInfo **heap_window_info{nullptr};
atomic_ret_t *atomic_ret{nullptr};
SymmetricHeap *heap_ptr{nullptr};
};
+18 -11
Dosyayı Görüntüle
@@ -64,7 +64,15 @@ ROBackend::ROBackend(MPI_Comm comm)
max_wg_size_ = device_props.maxThreadsPerBlock;
queue_ = Queue(maximum_num_contexts_, max_wg_size_, queue_size_);
size_t num_buff_elems = maximum_num_contexts_ * max_wg_size_;
g_ret_buffer_ = RetBufferProxyT(num_buff_elems);
atomic_ret_buffer_ = RetBufferProxyT(num_buff_elems);
status_ = StatusProxyT(num_buff_elems);
queue_ = Queue(maximum_num_contexts_, queue_size_);
transport_ = new MPITransport(comm, &queue_);
num_pes = transport_->getNumPes();
@@ -87,10 +95,6 @@ ROBackend::ROBackend(MPI_Comm comm)
initIPC();
init_g_ret(&heap, transport_->get_world_comm(), maximum_num_contexts_, &bp->g_ret);
allocate_atomic_region(&bp->atomic_ret, maximum_num_contexts_);
transport_->initTransport(maximum_num_contexts_, &backend_proxy);
host_interface = transport_->host_interface;
@@ -107,14 +111,17 @@ ROBackend::ROBackend(MPI_Comm comm)
reinterpret_cast<rocshmem_team_t>(team_world_proxy_->get());
default_block_handle_proxy_ = DefaultBlockHandleProxyT(
bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get());
g_ret_buffer_.get(),
atomic_ret_buffer_.get(), &queue_,
status_.get());
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
default_context_proxy_ = DefaultContextProxyT(this, tinfo);
block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_,
&ipcImpl, hdp_proxy_.get(),
maximum_num_contexts_);
block_handle_proxy_ = BlockHandleProxyT(g_ret_buffer_.get(),
atomic_ret_buffer_.get(), &queue_,
max_wg_size_, status_.get(), maximum_num_contexts_);
setup_ctxs();
worker_thread = std::thread(&ROBackend::ro_net_poll, this);
@@ -187,7 +194,7 @@ void ROBackend::ctx_destroy(Context *ctx) {
void ROBackend::reset_backend_stats() {
auto *bp{backend_proxy.get()};
for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) {
for (size_t i{0}; i < maximum_num_contexts_; i++) {
bp->profiler[i].resetStats();
}
}
@@ -208,7 +215,7 @@ void ROBackend::dump_backend_stats() {
auto *bp{backend_proxy.get()};
for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) {
for (size_t i{0}; i < maximum_num_contexts_; i++) {
// Average latency as perceived from a thread
const ROStats &prof{bp->profiler[i]};
us_wait_slot += prof.getStat(WAITING_ON_SLOT) / gpu_frequency_mhz;
+22 -1
Dosyayı Görüntüle
@@ -55,7 +55,9 @@ class ROHostContext;
* the host (which is an inversion of the normal behavior).
*/
class ROBackend : public Backend {
const unsigned MAX_NUM_BLOCKS{65536};
using RetBufferProxyT = DeviceProxy<HIPAllocator, uint64_t>;
using StatusProxyT =
DeviceProxy<HIPDefaultFinegrainedAllocator, char>;
public:
/**
@@ -270,6 +272,25 @@ class ROBackend : public Backend {
* @brief Number of MPI windows used for device contexts in RO Backend
*/
size_t num_windows_{32};
/**
* @brief Return buffer for rocshmem_g API
*/
RetBufferProxyT g_ret_buffer_;
/**
* @brief Return buffer for rocshmem atomic return APIs
*/
RetBufferProxyT atomic_ret_buffer_;
/**
* This buffer is used by the GPU to wait on a blocking operation. The initial
* value is 0. When a GPU enqueues a blocking operation, it waits for this
* value to resolve to 1, which is set by the CPU when the blocking
* operation completes. The GPU then resets status back to zero. There is
* a separate status variable for each work-item in a RO Context
*/
StatusProxyT status_;
};
} // namespace rocshmem
+13 -23
Dosyayı Görüntüle
@@ -38,10 +38,8 @@ struct BlockHandle {
volatile uint64_t write_index{};
volatile uint64_t *host_read_index{};
volatile char *status{nullptr};
char *g_ret{nullptr};
atomic_ret_t atomic_ret{};
IpcImpl ipc{};
HdpPolicy *hdp{};
void *g_ret{nullptr};
void *atomic_ret{nullptr};
volatile uint64_t lock{};
};
@@ -52,9 +50,8 @@ class DefaultBlockHandleProxy {
public:
DefaultBlockHandleProxy() = default;
DefaultBlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue,
IpcImpl *ipc_policy, HdpPolicy *hdp_policy,
size_t num_elems = 1)
DefaultBlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue,
volatile char *status, size_t num_elems = 1)
: proxy_{num_elems} {
// TODO(bpotter): create a default queue for this queue descriptor
@@ -66,13 +63,9 @@ class DefaultBlockHandleProxy {
block_handle->read_index = queue_descriptor->read_index;
block_handle->write_index = queue_descriptor->write_index;
block_handle->host_read_index = &queue_descriptor->read_index;
block_handle->status = queue_descriptor->status;
block_handle->status = status;
block_handle->g_ret = g_ret;
block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr;
block_handle->atomic_ret.atomic_counter = 0;
block_handle->ipc.ipc_bases = ipc_policy->ipc_bases;
block_handle->ipc.shm_size = ipc_policy->shm_size;
block_handle->hdp = hdp_policy;
block_handle->atomic_ret = atomic_ret;
block_handle->lock = 0;
}
@@ -99,27 +92,24 @@ class BlockHandleProxy {
public:
BlockHandleProxy() = default;
BlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue,
IpcImpl *ipc_policy, HdpPolicy *hdp_policy,
size_t max_blocks)
BlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue, size_t offset,
volatile char *status, size_t max_blocks)
: proxy_{max_blocks} {
for (size_t i{0}; i < max_blocks; i++) {
auto queue_descriptor{queue->descriptor(i)};
auto block_handle{&proxy_.get()[i]};
size_t block_offset{i * offset};
block_handle->profiler.resetStats();
block_handle->queue = queue->elements(i);
block_handle->queue_size = queue->size();
block_handle->read_index = queue_descriptor->read_index;
block_handle->write_index = queue_descriptor->write_index;
block_handle->host_read_index = &queue_descriptor->read_index;
block_handle->status = queue_descriptor->status;
block_handle->g_ret = g_ret;
block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr;
block_handle->atomic_ret.atomic_counter = 0;
block_handle->ipc.ipc_bases = ipc_policy->ipc_bases;
block_handle->ipc.shm_size = ipc_policy->shm_size;
block_handle->hdp = hdp_policy;
block_handle->status = status + block_offset;
block_handle->g_ret = reinterpret_cast<uint64_t*>(g_ret) + block_offset;
block_handle->atomic_ret = reinterpret_cast<uint64_t*>(atomic_ret) +
block_offset;
block_handle->lock = 0;
}
}
+40 -24
Dosyayı Görüntüle
@@ -71,7 +71,7 @@ __device__ void ROContext::putmem(void *dest, const void *source, size_t nelems,
}
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
@@ -90,7 +90,7 @@ __device__ void ROContext::getmem(void *dest, const void *source, size_t nelems,
}
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
@@ -134,18 +134,21 @@ __device__ void ROContext::getmem_nbi(void *dest, const void *source,
__device__ void ROContext::fence() {
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
true, get_status_flag());
}
__device__ void ROContext::fence(int pe) {
// TODO(khamidou): need to check if per pe has any special handling
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
true, get_status_flag());
}
__device__ void ROContext::quiet() {
build_queue_element(RO_NET_QUIET, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
true, get_status_flag());
}
__device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
@@ -163,7 +166,7 @@ __device__ void ROContext::barrier_all() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true);
block_handle, true, get_status_flag());
}
__syncthreads();
}
@@ -172,7 +175,7 @@ __device__ void ROContext::sync_all() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true);
block_handle, true, get_status_flag());
}
__syncthreads();
}
@@ -182,7 +185,7 @@ __device__ void ROContext::sync(rocshmem_team_t team) {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
true);
true, get_status_flag());
}
__syncthreads();
}
@@ -195,7 +198,7 @@ __device__ void ROContext::ctx_destroy() {
build_queue_element(RO_NET_FINALIZE, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true);
block_handle, true, get_status_flag());
int buffer_id = ro_net_win_id;
backend->queue_.descriptor(buffer_id)->write_index = block_handle->write_index;
@@ -219,7 +222,7 @@ __device__ void ROContext::putmem_wg(void *dest, const void *source,
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
__syncthreads();
@@ -237,7 +240,7 @@ __device__ void ROContext::getmem_wg(void *dest, const void *source,
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
__syncthreads();
@@ -291,7 +294,7 @@ __device__ void ROContext::putmem_wave(void *dest, const void *source,
if (is_thread_zero_in_wave()) {
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
}
@@ -309,7 +312,7 @@ __device__ void ROContext::getmem_wave(void *dest, const void *source,
if (is_thread_zero_in_wave()) {
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true);
ro_net_win_id, block_handle, true, get_status_flag());
}
}
}
@@ -578,7 +581,8 @@ __device__ void build_queue_element(
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
bool blocking, ROCSHMEM_OP op, ro_net_types datatype) {
bool blocking, volatile char *status, ROCSHMEM_OP op,
ro_net_types datatype) {
auto write_slot{next_write_slot(handle)};
auto queue_element = &handle->queue[write_slot];
@@ -587,6 +591,9 @@ __device__ void build_queue_element(
queue_element->ol1.size = size;
queue_element->dst = dst;
queue_element->ro_net_win_id = ro_net_win_id;
if(blocking) {
queue_element->status = status;
}
if (type == RO_NET_P) {
memcpy(&queue_element->src, src, size);
@@ -594,9 +601,6 @@ __device__ void build_queue_element(
queue_element->src = src;
}
auto threadId {get_flat_id()};
queue_element->threadId = threadId;
if (type == RO_NET_AMO_FOP) {
queue_element->op = op;
queue_element->datatype = datatype;
@@ -655,19 +659,31 @@ __device__ void build_queue_element(
if (blocking) {
int network_status{0};
do {
refresh_volatile_sbyte(&network_status, &handle->status[threadId]);
refresh_volatile_sbyte(&network_status, queue_element->status);
} while (network_status == 0);
handle->status[threadId] = 0;
*(queue_element->status) = 0;
__threadfence();
}
}
__device__ uint64_t *ROContext::get_unused_atomic() {
auto index{atomicAdd(&block_handle->atomic_ret.atomic_counter, 1)};
index = index % max_nb_atomic;
auto atomic_base_ptr{block_handle->atomic_ret.atomic_base_ptr};
return &atomic_base_ptr[index];
__device__ uint64_t *ROContext::get_atomic_ret_buf() {
uint64_t *atomic_base_ptr{
reinterpret_cast<uint64_t*>(block_handle->atomic_ret)};
int thread_id{get_flat_block_id()};
return &atomic_base_ptr[thread_id];
}
__device__ uint64_t *ROContext::get_g_ret_buf() {
uint64_t *g_ret{reinterpret_cast<uint64_t*>(block_handle->g_ret)};
int thread_id{get_flat_block_id()};
return &g_ret[thread_id];
}
__device__ volatile char *ROContext::get_status_flag() {
volatile char* status{block_handle->status};
int thread_id{get_flat_block_id()};
return &status[thread_id];
}
} // namespace rocshmem
+6 -2
Dosyayı Görüntüle
@@ -34,7 +34,7 @@ __device__ void build_queue_element(
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
bool blocking, ROCSHMEM_OP op = ROCSHMEM_SUM,
bool blocking, volatile char *status = nullptr, ROCSHMEM_OP op = ROCSHMEM_SUM,
ro_net_types datatype = RO_NET_INT);
class ROContext : public Context {
@@ -251,7 +251,11 @@ class ROContext : public Context {
__device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr);
private:
__device__ uint64_t *get_unused_atomic();
__device__ uint64_t *get_atomic_ret_buf();
__device__ uint64_t *get_g_ret_buf();
__device__ volatile char *get_status_flag();
BlockHandle *block_handle{nullptr};
+29 -32
Dosyayı Görüntüle
@@ -120,7 +120,8 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast<T *>(source),
nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm,
ro_net_win_id, block_handle, true, Op, GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
Op, GetROType<T>::Type);
__syncthreads();
return ROCSHMEM_SUCCESS;
@@ -137,8 +138,8 @@ __device__ void ROContext::to_all(T *dest, const T *source, int nreduce,
build_queue_element(RO_NET_TO_ALL, dest, const_cast<T *>(source), nreduce,
PE_start, logPE_stride, PE_size, 0, pWrk, pSync,
(MPI_Comm)NULL, ro_net_win_id, block_handle, true, Op,
GetROType<T>::Type);
(MPI_Comm)NULL, ro_net_win_id, block_handle, true,
get_status_flag(), Op, GetROType<T>::Type);
__syncthreads();
}
@@ -166,7 +167,7 @@ __device__ void ROContext::p(T *dest, T value, int pe) {
} else {
build_queue_element(RO_NET_P, dest, &value, sizeof(T), pe, 0, 0, 0, nullptr,
nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true);
block_handle, true, get_status_flag());
}
}
@@ -179,12 +180,7 @@ __device__ T ROContext::g(const T *source, int pe) {
ipcImpl_.ipcCopy(&dest, ipcImpl_.ipc_bases[pe] + L_offset, sizeof(T));
return dest;
} else {
int thread_id{get_flat_block_id()};
int block_size{get_flat_block_size()};
int offset{get_flat_grid_id() * block_size + thread_id};
char *base_dest{block_handle->g_ret};
char *dest{&base_dest[offset * sizeof(int64_t)]};
auto dest{get_g_ret_buf()};
get<T>(reinterpret_cast<T *>(dest), source, 1, pe);
return *(reinterpret_cast<T *>(dest));
}
@@ -206,12 +202,13 @@ __device__ void ROContext::get_nbi(T *dest, const T *source, size_t nelems,
template <typename T>
__device__ T ROContext::amo_fetch_cas(void *dst, T value, T cond, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FCAS, dst, reinterpret_cast<T *>(source),
value, pe, 0, 0, 0,
reinterpret_cast<void *>(static_cast<long long>(cond)),
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
ROCSHMEM_SUM, GetROType<T>::Type);
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
true, get_status_flag(), ROCSHMEM_SUM,
GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -223,11 +220,11 @@ __device__ void ROContext::amo_cas(void *dst, T value, T cond, int pe) {
template <typename T>
__device__ T ROContext::amo_fetch_add(void *dst, T value, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<T *>(source), value,
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true, ROCSHMEM_SUM,
GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
ROCSHMEM_SUM, GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -239,11 +236,11 @@ __device__ void ROContext::amo_add(void *dst, T value, int pe) {
template <typename T>
__device__ T ROContext::amo_swap(void *dst, T value, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true, ROCSHMEM_REPLACE,
GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
ROCSHMEM_REPLACE, GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -255,11 +252,11 @@ __device__ void ROContext::amo_set(void *dst, T value, int pe) {
template <typename T>
__device__ T ROContext::amo_fetch_and(void *dst, T value, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true, ROCSHMEM_AND,
GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
ROCSHMEM_AND, GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -271,11 +268,11 @@ __device__ void ROContext::amo_and(void *dst, T value, int pe) {
template <typename T>
__device__ T ROContext::amo_fetch_or(void *dst, T value, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true, ROCSHMEM_OR,
GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
ROCSHMEM_OR, GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -287,11 +284,11 @@ __device__ void ROContext::amo_or(void *dst, T value, int pe) {
template <typename T>
__device__ T ROContext::amo_fetch_xor(void *dst, T value, int pe) {
auto source{get_unused_atomic()};
auto source{get_atomic_ret_buf()};
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
ro_net_win_id, block_handle, true, ROCSHMEM_XOR,
GetROType<T>::Type);
ro_net_win_id, block_handle, true, get_status_flag(),
ROCSHMEM_XOR, GetROType<T>::Type);
__threadfence();
return *source;
}
@@ -314,7 +311,7 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest,
build_queue_element(RO_NET_TEAM_BROADCAST, dest, const_cast<T *>(source),
nelems, 0, 0, 0, pe_root, nullptr, nullptr,
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
ROCSHMEM_SUM, GetROType<T>::Type);
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
__syncthreads();
}
@@ -332,7 +329,7 @@ __device__ void ROContext::broadcast(T *dest, const T *source, int nelems,
build_queue_element(RO_NET_BROADCAST, dest, const_cast<T *>(source), nelems,
pe_start, log_pe_stride, pe_size, pe_root, nullptr,
p_sync, (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
ROCSHMEM_SUM, GetROType<T>::Type);
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
__syncthreads();
}
@@ -350,7 +347,7 @@ __device__ void ROContext::alltoall(rocshmem_team_t team, T *dest,
build_queue_element(RO_NET_ALLTOALL, dest, const_cast<T *>(source), nelems, 0,
0, 0, 0, team_obj->ata_buffer, nullptr,
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
ROCSHMEM_SUM, GetROType<T>::Type);
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
__syncthreads();
}
@@ -368,7 +365,7 @@ __device__ void ROContext::fcollect(rocshmem_team_t team, T *dest,
build_queue_element(RO_NET_FCOLLECT, dest, const_cast<T *>(source), nelems, 0,
0, 0, 0, team_obj->ata_buffer, nullptr,
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
ROCSHMEM_SUM, GetROType<T>::Type);
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
__syncthreads();
}
+132 -128
Dosyayı Görüntüle
@@ -94,7 +94,7 @@ void MPITransport::submitRequestsToMPI() {
case RO_NET_PUT:
putMem(next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, true);
next_element.status, true);
DPRINTF("Received PUT dst %p src %p size %lu pe %d win_id %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id);
@@ -109,7 +109,7 @@ void MPITransport::submitRequestsToMPI() {
putMem(next_element.dst, source_buffer, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, true, true);
next_element.status, true, true);
DPRINTF("Received P dst %p value %p pe %d\n", next_element.dst,
next_element.src, next_element.PE);
break;
@@ -117,14 +117,14 @@ void MPITransport::submitRequestsToMPI() {
case RO_NET_GET:
getMem(next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, true);
next_element.status, true);
DPRINTF("Received GET dst %p src %p size %lu pe %d\n", next_element.dst,
next_element.src, next_element.ol1.size, next_element.PE);
break;
case RO_NET_PUT_NBI:
putMem(next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, false);
next_element.status, false);
DPRINTF("Received PUT NBI dst %p src %p size %lu pe %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE);
@@ -132,7 +132,7 @@ void MPITransport::submitRequestsToMPI() {
case RO_NET_GET_NBI:
getMem(next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, false);
next_element.status, false);
DPRINTF("Received GET NBI dst %p src %p size %lu pe %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.PE);
@@ -141,7 +141,7 @@ void MPITransport::submitRequestsToMPI() {
amoFOP(next_element.dst, next_element.src,
const_cast<unsigned long long *>(&next_element.ol1.atomic_value),
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, true,
next_element.status, true,
static_cast<ROCSHMEM_OP>(next_element.op),
static_cast<ro_net_types>(next_element.datatype));
DPRINTF("Received AMO dst %p src %p Val %llu pe %d\n", next_element.dst,
@@ -151,7 +151,7 @@ void MPITransport::submitRequestsToMPI() {
amoFCAS(next_element.dst, next_element.src,
const_cast<unsigned long long *>(&next_element.ol1.atomic_value),
next_element.PE, next_element.ro_net_win_id, queue_idx,
next_element.threadId, true,
next_element.status, true,
const_cast<void **>(&next_element.ol2.pWrk),
static_cast<ro_net_types>(next_element.datatype));
DPRINTF("Received F_CSWAP dst %p src %p Val %llu pe %d cond %ld\n",
@@ -165,7 +165,7 @@ void MPITransport::submitRequestsToMPI() {
next_element.team_comm,
static_cast<ROCSHMEM_OP>(next_element.op),
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
@@ -177,7 +177,7 @@ void MPITransport::submitRequestsToMPI() {
next_element.PE_size, next_element.ol2.pWrk, next_element.pSync,
static_cast<ROCSHMEM_OP>(next_element.op),
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF(
"Received FLOAT_SUM_TO_ALL dst %p src %p size %lu "
"PE_start %d, logPE_stride %d, PE_size %d, pWrk %p, pSync %p\n",
@@ -190,7 +190,7 @@ void MPITransport::submitRequestsToMPI() {
next_element.ro_net_win_id, queue_idx,
next_element.team_comm, next_element.PE_root,
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF(
"Received TEAM_BROADCAST dst %p src %p size %lu "
"team %d, PE_root %d \n",
@@ -203,7 +203,7 @@ void MPITransport::submitRequestsToMPI() {
next_element.PE, next_element.logPE_stride,
next_element.PE_size, next_element.PE_root, next_element.pSync,
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF(
"Received BROADCAST dst %p src %p size %lu PE_start %d, "
"logPE_stride %d, PE_size %d, PE_root %d, pSync %p\n",
@@ -216,7 +216,7 @@ void MPITransport::submitRequestsToMPI() {
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
next_element.ol2.pWrk,
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF("Received ALLTOALL dst %p src %p size %lu team %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
@@ -226,26 +226,26 @@ void MPITransport::submitRequestsToMPI() {
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
next_element.ol2.pWrk,
static_cast<ro_net_types>(next_element.datatype),
next_element.threadId, true);
next_element.status, true);
DPRINTF("Received FCOLLECT dst %p src %p size %lu team %d\n",
next_element.dst, next_element.src, next_element.ol1.size,
next_element.team_comm);
break;
case RO_NET_BARRIER_ALL:
barrier(queue_idx, next_element.threadId, true, ro_net_comm_world);
barrier(queue_idx, next_element.status, true, ro_net_comm_world);
DPRINTF("Received Barrier_all\n");
break;
case RO_NET_SYNC:
barrier(queue_idx, next_element.threadId, true, next_element.team_comm);
barrier(queue_idx, next_element.status, true, next_element.team_comm);
DPRINTF("Received Sync\n");
break;
case RO_NET_FENCE:
case RO_NET_QUIET:
quiet(queue_idx, next_element.threadId);
quiet(queue_idx, next_element.status);
DPRINTF("Received FENCE/QUIET\n");
break;
case RO_NET_FINALIZE:
quiet(queue_idx, next_element.threadId);
quiet(queue_idx, next_element.status);
DPRINTF("Received Finalize\n");
break;
default:
@@ -256,7 +256,7 @@ void MPITransport::submitRequestsToMPI() {
}
void MPITransport::initTransport(int num_queues, BackendProxyT *proxy) {
waiting_quiet.resize(num_queues, std::vector<int>());
waiting_quiet.resize(num_queues, std::vector<volatile char *>());
outstanding.resize(num_queues, 0);
transport_up = false;
@@ -333,13 +333,13 @@ void MPITransport::global_exit(int status) {
MPI_Abort(ro_net_comm_world, status);
}
void MPITransport::barrier(int blockId, int threadId, bool blocking,
MPI_Comm team) {
void MPITransport::barrier(int contextId, volatile char *status, bool blocking,
MPI_Comm team) {
MPI_Request request{};
NET_CHECK(MPI_Ibarrier(team, &request));
requests.push_back({request, {threadId, blockId, blocking}});
outstanding[blockId]++;
requests.push_back({request, {status, contextId, blocking}});
outstanding[contextId]++;
}
MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) {
@@ -397,10 +397,10 @@ static MPI_Datatype convertType(ro_net_types type) {
}
void MPITransport::reduction(void *dst, void *src, int size, int pe,
int win_id, int blockId, int start, int logPstride,
int sizePE, void *pWrk, long *pSync,
ROCSHMEM_OP op, ro_net_types type, int threadId,
bool blocking) {
int win_id, int contextId, int start,
int logPstride, int sizePE, void *pWrk,
long *pSync, ROCSHMEM_OP op, ro_net_types type,
volatile char *status, bool blocking) {
MPI_Request request{};
MPI_Op mpi_op{get_mpi_op(op)};
MPI_Datatype mpi_type{convertType(type)};
@@ -413,14 +413,15 @@ void MPITransport::reduction(void *dst, void *src, int size, int pe,
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
}
requests.push_back({request, {threadId, blockId, blocking}});
outstanding[blockId]++;
requests.push_back({request, {status, contextId, blocking}});
outstanding[contextId]++;
}
void MPITransport::broadcast(void *dst, void *src, int size, int pe,
int win_id, int blockId, int start, int logPstride,
int sizePE, int root, long *pSync,
ro_net_types type, int threadId, bool blocking) {
int win_id, int contextId, int start,
int logPstride, int sizePE, int root, long *pSync,
ro_net_types type, volatile char *status,
bool blocking) {
MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};
int new_rank{};
@@ -437,15 +438,15 @@ void MPITransport::broadcast(void *dst, void *src, int size, int pe,
MPI_Datatype mpi_type{convertType(type)};
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
requests.push_back({request, {threadId, blockId, blocking}});
requests.push_back({request, {status, contextId, blocking}});
outstanding[blockId]++;
outstanding[contextId]++;
}
void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char* status,
bool blocking) {
MPI_Request request{};
MPI_Op mpi_op{get_mpi_op(op)};
@@ -459,15 +460,15 @@ void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
}
requests.push_back({request, {threadId, blockId, blocking}});
requests.push_back({request, {status, contextId, blocking}});
outstanding[blockId]++;
outstanding[contextId]++;
}
void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, int root,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team, int root,
ro_net_types type, volatile char *status,
bool blocking) {
MPI_Comm comm{team};
int new_rank{};
MPI_Comm_rank(comm, &new_rank);
@@ -482,14 +483,15 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
MPI_Request request{};
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
requests.push_back({request, {threadId, blockId, blocking}});
requests.push_back({request, {status, contextId, blocking}});
outstanding[blockId]++;
outstanding[contextId]++;
}
void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId, bool blocking) {
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
int pe_size{};
NET_CHECK(MPI_Comm_size(team, &pe_size));
@@ -502,24 +504,24 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
#ifdef A2A_HEURISTICS
if ((pe_size >= 8 || type_size * size < 2048) &&
num_clust * clust_size == pe_size) {
return alltoall_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type,
threadId, blocking);
return alltoall_gcen(dst, src, size, win_id, contextId, team, ata_buffptr,
type, status, blocking);
} else if (size <= 512) {
#endif // A2A_HEURISTICS
return alltoall_mpi(dst, src, size, blockId, team, ata_buffptr, type,
threadId, blocking);
return alltoall_mpi(dst, src, size, contextId, team, ata_buffptr, type,
status, blocking);
#ifdef A2A_HEURISTICS
} else {
return alltoall_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr,
type, threadId, blocking);
return alltoall_broadcast(dst, src, size, win_id, contextId, team,
ata_buffptr, type, status, blocking);
}
#endif // A2A_HEURISTICS
}
void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
int win_id, int blockId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) {
int win_id, int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
auto *bp{backend_proxy->get()};
MPI_Comm comm{team};
@@ -563,26 +565,26 @@ void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
barrier(blockId, threadId, blocking, comm);
barrier(contextId, status, blocking, comm);
}
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int blockId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
int new_rank{};
NET_CHECK(MPI_Comm_rank(team, &new_rank));
int pe_size{};
NET_CHECK(MPI_Comm_size(team, &pe_size));
MPI_Datatype mpi_type{convertType(type)};
NET_CHECK(MPI_Alltoall(src, size, mpi_type, dst, size, mpi_type, team));
quiet(blockId, threadId);
quiet(contextId, status);
}
void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
auto *bp{backend_proxy->get()};
int new_rank{};
@@ -664,14 +666,14 @@ void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
MPI_Comm comm_ring{createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust)};
barrier(blockId, threadId, false, comm_cluster);
barrier(blockId, threadId, blocking, comm_ring);
barrier(contextId, status, false, comm_cluster);
barrier(contextId, status, blocking, comm_ring);
}
void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric alltoall with in-place blocking synchronization
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
@@ -759,12 +761,13 @@ void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
barrier(blockId, threadId, blocking, comm_ring);
barrier(contextId, status, blocking, comm_ring);
}
void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId, bool blocking) {
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
int pe_size, type_size;
MPI_Comm comm = team;
NET_CHECK(MPI_Comm_size(comm, &pe_size));
@@ -780,21 +783,21 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
// In most cases the MPI implementation is optimal
// But it crashes for > 512 messages
if (size <= 512) {
fcollect_mpi(dst, src, size, blockId, team, ata_buffptr, type,
threadId, blocking);
fcollect_mpi(dst, src, size, contextId, team, ata_buffptr, type,
status, blocking);
} else if (num_clust * clust_size == pe_size) {
fcollect_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type,
threadId, blocking);
fcollect_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, type,
status, blocking);
} else {
fcollect_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr,
type, threadId, blocking);
fcollect_broadcast(dst, src, size, win_id, contextId, team, ata_buffptr,
type, status, blocking);
}
}
void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
int win_id, int blockId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) {
int win_id, int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// Broadcast implementation of fcollect
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
@@ -837,13 +840,13 @@ void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
// Now wait for completion
barrier(blockId, threadId, blocking, comm);
barrier(contextId, status, blocking, comm);
}
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking) {
// MPI's implementation of fcollect
int new_rank, pe_size;
@@ -852,13 +855,13 @@ void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId,
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
NET_CHECK(MPI_Comm_size(comm, &pe_size));
NET_CHECK(MPI_Allgather(src, size, mpi_type, dst, size, mpi_type, comm));
quiet(blockId, threadId);
quiet(contextId, status);
}
void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric implementation of fcollect
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
@@ -938,14 +941,14 @@ void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
barrier(blockId, threadId, false, comm_cluster);
barrier(blockId, threadId, blocking, comm_ring);
barrier(contextId, status, false, comm_cluster);
barrier(contextId, status, blocking, comm_ring);
}
void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId,
bool blocking) {
int contextId, MPI_Comm team,
void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) {
// GPU-centric implementation with in-place, blocking synchronization
auto *bp{backend_proxy->get()};
int new_rank, pe_size;
@@ -1027,12 +1030,12 @@ void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
stride * clust_size, num_clust);
// Now wait for completion
barrier(blockId, threadId, blocking, comm_ring);
barrier(contextId, status, blocking, comm_ring);
}
void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
int blockId, int threadId, bool blocking,
bool inline_data) {
int contextId, volatile char *status, bool blocking,
bool inline_data) {
queue->flush_hdp();
auto *bp{backend_proxy->get()};
@@ -1047,14 +1050,14 @@ void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
// though it should be in the progress loop.
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
requests.push_back({request, {threadId, blockId, blocking}});
requests.push_back({request, {status, contextId, blocking}});
outstanding[blockId]++;
outstanding[contextId]++;
}
void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
int blockId, int threadId, bool blocking,
ROCSHMEM_OP op, ro_net_types type) {
int contextId, volatile char *status, bool blocking,
ROCSHMEM_OP op, ro_net_types type) {
queue->flush_hdp();
auto *bp{backend_proxy->get()};
@@ -1069,14 +1072,14 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
// though it should be in the progress loop.
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
queue->notify(blockId, threadId);
queue->notify(status);
queue->sfence_flush_hdp();
}
void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
int win_id, int blockId, int threadId, bool blocking,
void *cond, ro_net_types type) {
int win_id, int contextId, volatile char *status,
bool blocking, void *cond, ro_net_types type) {
queue->flush_hdp();
auto *bp{backend_proxy->get()};
@@ -1091,14 +1094,15 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
// though it should be in the progress loop.
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
queue->notify(blockId, threadId);
queue->notify(status);
queue->sfence_flush_hdp();
}
void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id,
int blockId, int threadId, bool blocking) {
outstanding[blockId]++;
int contextId, volatile char *status,
bool blocking) {
outstanding[contextId]++;
auto *bp{backend_proxy->get()};
MPI_Request request{};
@@ -1106,7 +1110,7 @@ void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id,
dst, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(src),
size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request));
requests.push_back({request, {threadId, blockId, blocking}});
requests.push_back({request, {status, contextId, blocking}});
}
std::unique_ptr<MPI_Request[]> MPITransport::raw_requests() {
@@ -1139,20 +1143,20 @@ void MPITransport::progress() {
auto *bp{backend_proxy->get()};
for (int i{0}; i < outcount; i++) {
int index{testsome_indices[i]};
int blockId{requests[index].properties.blockId};
int threadId{requests[index].properties.threadId};
int contextId{requests[index].properties.contextId};
volatile char *status{requests[index].properties.status};
if (blockId != -1) {
outstanding[blockId]--;
if (contextId != -1) {
outstanding[contextId]--;
DPRINTF(
"Finished op for blockId %d at threadId %d "
"Finished op for contextId %d at status addr %p "
"(%d requests outstanding)\n",
blockId, threadId, outstanding[blockId]);
contextId, status, outstanding[contextId]);
}
if (requests[index].properties.blocking) {
if (blockId != -1) {
queue->notify(blockId, threadId);
if (contextId != -1) {
queue->notify(status);
}
queue->sfence_flush_hdp();
}
@@ -1163,14 +1167,14 @@ void MPITransport::progress() {
// If the GPU has requested a quiet, notify it of completion when
// all outstanding requests are complete.
if (!outstanding[blockId] && !waiting_quiet[blockId].empty()) {
for (const auto threadId : waiting_quiet[blockId]) {
DPRINTF("Finished Quiet for blockId %d at threadId %d\n", blockId,
threadId);
queue->notify(blockId, threadId);
if (!outstanding[contextId] && !waiting_quiet[contextId].empty()) {
for (const auto status : waiting_quiet[contextId]) {
DPRINTF("Finished Quiet for contextId %d at status addr %p\n", contextId,
status);
queue->notify(status);
}
waiting_quiet[blockId].clear();
waiting_quiet[contextId].clear();
queue->sfence_flush_hdp();
}
@@ -1185,15 +1189,15 @@ void MPITransport::progress() {
}
}
void MPITransport::quiet(int blockId, int threadId) {
void MPITransport::quiet(int contextId, volatile char *status) {
auto *bp{backend_proxy->get()};
if (!outstanding[blockId]) {
DPRINTF("Finished Quiet immediately for blockId %d at threadId %d\n", blockId,
threadId);
queue->notify(blockId, threadId);
if (!outstanding[contextId]) {
DPRINTF("Finished Quiet immediately for contextId %d at status addr %p\n",
contextId, status);
queue->notify(status);
} else {
waiting_quiet[blockId].emplace_back(threadId);
waiting_quiet[contextId].emplace_back(status);
}
}
+74 -69
Dosyayı Görüntüle
@@ -36,7 +36,7 @@ namespace rocshmem {
class HostInterface;
class MPITransport : public Transport {
public:
public:
explicit MPITransport(MPI_Comm com, Queue* queue);
virtual ~MPITransport();
@@ -46,87 +46,92 @@ class MPITransport : public Transport {
void finalizeTransport() override;
void createNewTeam(ROBackend *backend, Team *parent_team,
TeamInfo *team_info_wrt_parent,
TeamInfo *team_info_wrt_world, int num_pes,
int my_pe_in_new_team, MPI_Comm team_comm,
rocshmem_team_t *new_team) override;
TeamInfo *team_info_wrt_parent,
TeamInfo *team_info_wrt_world, int num_pes,
int my_pe_in_new_team, MPI_Comm team_comm,
rocshmem_team_t *new_team) override;
void barrier(int blockId, int threadId, bool blocking,
MPI_Comm team) override;
void barrier(int contextId, volatile char *status, bool blocking,
MPI_Comm team) override;
void reduction(void *dst, void *src, int size, int pe, int win_id,
int blockId, int start, int logPstride, int sizePE, void *pWrk,
long *pSync, ROCSHMEM_OP op, ro_net_types type,
int threadId, bool blocking) override;
int contextId, int start, int logPstride, int sizePE,
void *pWrk, long *pSync, ROCSHMEM_OP op, ro_net_types type,
volatile char *status, bool blocking) override;
void team_reduction(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, ROCSHMEM_OP op, ro_net_types type,
int threadId, bool blocking) override;
void team_reduction(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) override;
void broadcast(void *dst, void *src, int size, int pe, int win_id,
int blockId, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type, int threadId,
bool blocking) override;
int contextId, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type,
volatile char *status, bool blocking) override;
void team_broadcast(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, int PE_root, ro_net_types type,
int threadId, bool blocking) override;
void team_broadcast(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, int PE_root,
ro_net_types type, volatile char *status,
bool blocking) override;
void alltoall(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) override;
void alltoall(void *dst, void *src, int size, int win_id, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) override;
void alltoall_broadcast(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId, bool blocking);
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking);
void alltoall_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team,
void *ata_buffptr, ro_net_types type, int threadId,
bool blocking);
void alltoall_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
void alltoall_gcen(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking);
void alltoall_gcen(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
void alltoall_gcen2(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking);
void alltoall_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
void fcollect(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) override;
void fcollect(void *dst, void *src, int size, int win_id, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) override;
void fcollect_broadcast(void *dst, void *src, int size, int win_id,
int blockId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, int threadId, bool blocking);
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status,
bool blocking);
void fcollect_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team,
void *ata_buffptr, ro_net_types type, int threadId,
bool blocking);
void fcollect_mpi(void *dst, void *src, int size, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
void fcollect_gcen(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking);
void fcollect_gcen(void *dst, void *src, int size, int win_id, int contextId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking);
void fcollect_gcen2(void *dst, void *src, int size, int win_id, int blockId,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking);
void fcollect_gcen2(void *dst, void *src, int size, int win_id,
int contextId, MPI_Comm team, void *ata_buffptr,
ro_net_types type, volatile char *status, bool blocking);
void putMem(void *dst, void *src, int size, int pe, int win_id, int blockId,
int threadId, bool blocking, bool inline_data = false) override;
void putMem(void *dst, void *src, int size, int pe, int win_id,
int contextId, volatile char *status, bool blocking,
bool inline_data = false) override;
void amoFOP(void *dst, void *src, void *val, int pe, int win_id, int blockId,
int threadId, bool blocking, ROCSHMEM_OP op,
ro_net_types type) override;
void amoFOP(void *dst, void *src, void *val, int pe, int win_id,
int contextId, volatile char *status, bool blocking,
ROCSHMEM_OP op, ro_net_types type) override;
void amoFCAS(void *dst, void *src, void *val, int pe, int win_id, int blockId,
int threadId, bool blocking, void *cond,
ro_net_types type) override;
void amoFCAS(void *dst, void *src, void *val, int pe, int win_id,
int contextId, volatile char *status, bool blocking, void *cond,
ro_net_types type) override;
void getMem(void *dst, void *src, int size, int pe, int win_id, int blockId,
int threadId, bool blocking) override;
void getMem(void *dst, void *src, int size, int pe, int win_id,
int contextId, volatile char *status, bool blocking) override;
void quiet(int blockId, int threadId) override;
void quiet(int contextId, volatile char *status) override;
void progress() override;
@@ -142,7 +147,7 @@ class MPITransport : public Transport {
HostInterface *host_interface{nullptr};
private:
private:
struct CommKey {
CommKey(int _start, int _logPstride, int _size)
: start(_start), logPstride(_logPstride), size(_size) {}
@@ -160,23 +165,23 @@ class MPITransport : public Transport {
};
struct RequestProperties {
RequestProperties(int _threadId, int _blockId, bool _blocking, void *_src,
bool _inline_data)
: threadId(_threadId),
blockId(_blockId),
RequestProperties(volatile char *_status, int _contextId, bool _blocking,
void *_src, bool _inline_data)
: status(_status),
contextId(_contextId),
blocking(_blocking),
src(_src),
inline_data(_inline_data) {}
RequestProperties(int _threadId, int _blockId, bool _blocking)
: threadId(_threadId),
blockId(_blockId),
RequestProperties(volatile char *_status, int _contextId, bool _blocking)
: status(_status),
contextId(_contextId),
blocking(_blocking),
src(nullptr),
inline_data(false) {}
int threadId{-1};
int blockId{-1};
volatile char* status{nullptr};
int contextId{-1};
bool blocking{};
void *src{nullptr};
bool inline_data{};
@@ -202,7 +207,7 @@ class MPITransport : public Transport {
// Unordered vector of in-flight MPI Requests. Can complete out of order.
std::vector<Request> requests{};
std::vector<std::vector<int> > waiting_quiet{};
std::vector<std::vector<volatile char *> > waiting_quiet{};
std::vector<int> outstanding{};
+4 -5
Dosyayı Görüntüle
@@ -33,12 +33,11 @@ Queue::Queue() {
}
}
Queue::Queue(size_t max_queues, size_t max_wg_size, size_t queue_size)
Queue::Queue(size_t max_queues, size_t queue_size)
: max_queues_{max_queues},
max_wg_size_{max_wg_size},
queue_size_{queue_size},
queue_proxy_{max_queues, queue_size},
queue_desc_proxy_{max_queues, max_wg_size} {
queue_desc_proxy_{max_queues} {
gpu_queue = true;
char *value{nullptr};
@@ -101,8 +100,8 @@ void Queue::sfence_flush_hdp() {
}
}
void Queue::notify(int blockId, int threadId) {
descriptor(blockId)->status[threadId] = 1;
void Queue::notify(volatile char* status) {
*status = 1;
}
uint64_t Queue::size() {
+2 -4
Dosyayı Görüntüle
@@ -35,7 +35,7 @@ class Queue {
public:
Queue();
Queue(size_t max_queues, size_t max_threads_per_block, size_t queue_size);
Queue(size_t max_queues, size_t queue_size);
bool process(uint64_t queue_index, MPITransport* transport);
@@ -47,7 +47,7 @@ class Queue {
void sfence_flush_hdp();
void notify(int blockId, int threadId);
void notify(volatile char *status);
uint64_t size();
@@ -72,8 +72,6 @@ class Queue {
size_t max_queues_{};
size_t max_wg_size_{};
size_t queue_size_{};
};
+2 -24
Dosyayı Görüntüle
@@ -45,38 +45,22 @@ typedef struct queue_desc {
*/
uint64_t write_index;
char padding2[56];
/**
* This bit is used by the GPU to wait on a blocking operation. The initial
* value is 0. When a GPU enqueues a blocking operation, it waits for this
* value to resolve to 1, which is set by the CPU when the blocking
* operation completes. The GPU then resets status back to zero. There is
* a separate status variable for each work-item in a work-group
*/
char *status;
char padding3[56];
} __attribute__((__aligned__(64))) queue_desc_t;
template <typename ALLOCATOR>
class QueueDescProxy {
using ProxyT = DeviceProxy<ALLOCATOR, queue_desc_t>;
using ProxyStatusT = DeviceProxy<ALLOCATOR, char>;
public:
QueueDescProxy() = default;
QueueDescProxy(size_t max_queues, size_t max_threads_per_queue)
: max_queues_{max_queues}, max_threads_per_queue_{max_threads_per_queue},
max_threads_{max_queues * max_threads_per_queue}, proxy_{max_queues},
proxy_status_{max_queues * max_threads_per_queue} {
auto *status{proxy_status_.get()};
size_t status_bytes{sizeof(char) * max_threads_};
memset(status, 0, status_bytes);
QueueDescProxy(size_t max_queues)
: max_queues_{max_queues}, proxy_{max_queues} {
auto *queue_descs{proxy_.get()};
for (size_t i{0}; i < max_queues_; i++) {
queue_descs[i].read_index = 0;
queue_descs[i].write_index = 0;
queue_descs[i].status = status + i * max_threads_per_queue_;
}
}
@@ -93,13 +77,7 @@ class QueueDescProxy {
private:
ProxyT proxy_{};
ProxyStatusT proxy_status_{};
size_t max_queues_{};
size_t max_threads_per_queue_{};
size_t max_threads_{};
};
using QueueDescProxyT = QueueDescProxy<HIPDefaultFinegrainedAllocator>;
+1 -1
Dosyayı Görüntüle
@@ -57,13 +57,13 @@ typedef struct queue_element {
void *src{nullptr};
void *dst{nullptr};
int ro_net_win_id{-1};
int threadId{-1};
int logPE_stride{-1};
int PE_size{-1};
long *pSync{nullptr};
int op{-1};
int datatype{-1};
int PE_root{-1};
volatile char *status{nullptr};
MPI_Comm team_comm{};
union {
size_t size;
+31 -30
Dosyayı Görüntüle
@@ -45,60 +45,61 @@ class Transport {
virtual void finalizeTransport() = 0;
virtual void createNewTeam(ROBackend *backend_handle, Team *parent_team,
TeamInfo *team_info_wrt_parent,
TeamInfo *team_info_wrt_world, int num_pes,
int my_pe_in_new_team, MPI_Comm team_comm,
rocshmem_team_t *new_team) = 0;
TeamInfo *team_info_wrt_parent,
TeamInfo *team_info_wrt_world, int num_pes,
int my_pe_in_new_team, MPI_Comm team_comm,
rocshmem_team_t *new_team) = 0;
virtual void barrier(int wg_id, int threadId, bool blocking,
MPI_Comm team) = 0;
virtual void barrier(int wg_id, volatile char *status, bool blocking,
MPI_Comm team) = 0;
virtual void reduction(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int start, int logPstride, int sizePE,
void *pWrk, long *pSync, ROCSHMEM_OP op,
ro_net_types type, int threadId, bool blocking) = 0;
int wg_id, int start, int logPstride, int sizePE,
void *pWrk, long *pSync, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) = 0;
virtual void team_reduction(void *dst, void *src, int size, int win_id,
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, int threadId,
bool blocking) = 0;
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
ro_net_types type, volatile char *status,
bool blocking) = 0;
virtual void broadcast(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type,
int threadId, bool blocking) = 0;
int wg_id, int start, int logPstride, int sizePE,
int PE_root, long *pSync, ro_net_types type,
volatile char *status, bool blocking) = 0;
virtual void team_broadcast(void *dst, void *src, int size, int win_id,
int wg_id, MPI_Comm team, int PE_root,
ro_net_types type, int threadId,
bool blocking) = 0;
int wg_id, MPI_Comm team, int PE_root,
ro_net_types type, volatile char *status,
bool blocking) = 0;
virtual void alltoall(void *dst, void *src, int size, int win_id, int wg_id,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) = 0;
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) = 0;
virtual void fcollect(void *dst, void *src, int size, int win_id, int wg_id,
MPI_Comm team, void *ata_buffptr, ro_net_types type,
int threadId, bool blocking) = 0;
MPI_Comm team, void *ata_buffptr, ro_net_types type,
volatile char *status, bool blocking) = 0;
virtual void putMem(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int threadId, bool blocking,
bool inline_data = false) = 0;
int wg_id, volatile char *status, bool blocking,
bool inline_data = false) = 0;
virtual void getMem(void *dst, void *src, int size, int pe, int win_id,
int wg_id, int threadId, bool blocking) = 0;
int wg_id, volatile char *status, bool blocking) = 0;
virtual void amoFOP(void *dst, void *src, void *val, int pe, int win_id,
int wg_id, int threadId, bool blocking, ROCSHMEM_OP op,
ro_net_types type) = 0;
int wg_id, volatile char *status, bool blocking,
ROCSHMEM_OP op, ro_net_types type) = 0;
virtual void amoFCAS(void *dst, void *src, void *val, int pe, int win_id,
int wg_id, int threadId, bool blocking, void *cond,
ro_net_types type) = 0;
int wg_id, volatile char *status, bool blocking,
void *cond, ro_net_types type) = 0;
virtual bool readyForFinalize() = 0;
virtual void quiet(int wg_id, int threadId) = 0;
virtual void quiet(int wg_id, volatile char *status) = 0;
virtual void progress() = 0;