diff --git a/scripts/functional_tests/driver.sh b/scripts/functional_tests/driver.sh index cdd83e949a..92c4952f80 100755 --- a/scripts/functional_tests/driver.sh +++ b/scripts/functional_tests/driver.sh @@ -179,6 +179,20 @@ TestRMA() { ExecTest "teamctxget" 2 1 1 1048576 + ExecTest "g" 2 1 1 1048576 + ExecTest "g" 2 1 1024 512 + ExecTest "g" 2 8 1 1048576 + ExecTest "g" 2 16 128 8 + ExecTest "g" 2 32 256 512 + ExecTest "g" 2 64 1024 8 + + ExecTest "p" 2 1 1 1048576 + ExecTest "p" 2 1 1024 512 + ExecTest "p" 2 8 1 1048576 + ExecTest "p" 2 16 128 8 + ExecTest "p" 2 32 256 512 + ExecTest "p" 2 64 1024 8 + ################################ Non-Blocking ################################ ExecTest "putnbi" 2 1 1 1048576 diff --git a/src/device_proxy.hpp b/src/device_proxy.hpp index 243c82c931..667ae3615b 100644 --- a/src/device_proxy.hpp +++ b/src/device_proxy.hpp @@ -51,7 +51,7 @@ class DeviceProxy { /* * Default memory provided by the allocation to recognizable bytes. */ - memset(static_cast(temp), 0xBC, size_bytes); + memset(static_cast(temp), 0, size_bytes); /* * Pass the memory into a unique ptr for tracking. diff --git a/src/reverse_offload/backend_proxy.hpp b/src/reverse_offload/backend_proxy.hpp index b1d80dc413..247bf92e3e 100644 --- a/src/reverse_offload/backend_proxy.hpp +++ b/src/reverse_offload/backend_proxy.hpp @@ -36,10 +36,8 @@ struct BackendRegister { std::atomic worker_thread_exit{false}; bool *needs_quiet{nullptr}; bool *needs_blocking{nullptr}; - char *g_ret{nullptr}; HdpPolicy *hdp_policy{nullptr}; WindowInfo **heap_window_info{nullptr}; - atomic_ret_t *atomic_ret{nullptr}; SymmetricHeap *heap_ptr{nullptr}; }; diff --git a/src/reverse_offload/backend_ro.cpp b/src/reverse_offload/backend_ro.cpp index 32675077d3..9fc5fcd503 100644 --- a/src/reverse_offload/backend_ro.cpp +++ b/src/reverse_offload/backend_ro.cpp @@ -64,7 +64,15 @@ ROBackend::ROBackend(MPI_Comm comm) max_wg_size_ = device_props.maxThreadsPerBlock; - queue_ = Queue(maximum_num_contexts_, max_wg_size_, queue_size_); + size_t num_buff_elems = maximum_num_contexts_ * max_wg_size_; + + g_ret_buffer_ = RetBufferProxyT(num_buff_elems); + + atomic_ret_buffer_ = RetBufferProxyT(num_buff_elems); + + status_ = StatusProxyT(num_buff_elems); + + queue_ = Queue(maximum_num_contexts_, queue_size_); transport_ = new MPITransport(comm, &queue_); num_pes = transport_->getNumPes(); @@ -87,10 +95,6 @@ ROBackend::ROBackend(MPI_Comm comm) initIPC(); - init_g_ret(&heap, transport_->get_world_comm(), maximum_num_contexts_, &bp->g_ret); - - allocate_atomic_region(&bp->atomic_ret, maximum_num_contexts_); - transport_->initTransport(maximum_num_contexts_, &backend_proxy); host_interface = transport_->host_interface; @@ -107,14 +111,17 @@ ROBackend::ROBackend(MPI_Comm comm) reinterpret_cast(team_world_proxy_->get()); default_block_handle_proxy_ = DefaultBlockHandleProxyT( - bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get()); + g_ret_buffer_.get(), + atomic_ret_buffer_.get(), &queue_, + status_.get()); TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world; + default_context_proxy_ = DefaultContextProxyT(this, tinfo); - block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_, - &ipcImpl, hdp_proxy_.get(), - maximum_num_contexts_); + block_handle_proxy_ = BlockHandleProxyT(g_ret_buffer_.get(), + atomic_ret_buffer_.get(), &queue_, + max_wg_size_, status_.get(), maximum_num_contexts_); setup_ctxs(); worker_thread = std::thread(&ROBackend::ro_net_poll, this); @@ -187,7 +194,7 @@ void ROBackend::ctx_destroy(Context *ctx) { void ROBackend::reset_backend_stats() { auto *bp{backend_proxy.get()}; - for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) { + for (size_t i{0}; i < maximum_num_contexts_; i++) { bp->profiler[i].resetStats(); } } @@ -208,7 +215,7 @@ void ROBackend::dump_backend_stats() { auto *bp{backend_proxy.get()}; - for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) { + for (size_t i{0}; i < maximum_num_contexts_; i++) { // Average latency as perceived from a thread const ROStats &prof{bp->profiler[i]}; us_wait_slot += prof.getStat(WAITING_ON_SLOT) / gpu_frequency_mhz; diff --git a/src/reverse_offload/backend_ro.hpp b/src/reverse_offload/backend_ro.hpp index 2000c1792b..e4907c1c84 100644 --- a/src/reverse_offload/backend_ro.hpp +++ b/src/reverse_offload/backend_ro.hpp @@ -55,7 +55,9 @@ class ROHostContext; * the host (which is an inversion of the normal behavior). */ class ROBackend : public Backend { - const unsigned MAX_NUM_BLOCKS{65536}; + using RetBufferProxyT = DeviceProxy; + using StatusProxyT = + DeviceProxy; public: /** @@ -270,6 +272,25 @@ class ROBackend : public Backend { * @brief Number of MPI windows used for device contexts in RO Backend */ size_t num_windows_{32}; + + /** + * @brief Return buffer for rocshmem_g API + */ + RetBufferProxyT g_ret_buffer_; + + /** + * @brief Return buffer for rocshmem atomic return APIs + */ + RetBufferProxyT atomic_ret_buffer_; + + /** + * This buffer is used by the GPU to wait on a blocking operation. The initial + * value is 0. When a GPU enqueues a blocking operation, it waits for this + * value to resolve to 1, which is set by the CPU when the blocking + * operation completes. The GPU then resets status back to zero. There is + * a separate status variable for each work-item in a RO Context + */ + StatusProxyT status_; }; } // namespace rocshmem diff --git a/src/reverse_offload/block_handle.hpp b/src/reverse_offload/block_handle.hpp index 711db9d04a..19174938c6 100644 --- a/src/reverse_offload/block_handle.hpp +++ b/src/reverse_offload/block_handle.hpp @@ -38,10 +38,8 @@ struct BlockHandle { volatile uint64_t write_index{}; volatile uint64_t *host_read_index{}; volatile char *status{nullptr}; - char *g_ret{nullptr}; - atomic_ret_t atomic_ret{}; - IpcImpl ipc{}; - HdpPolicy *hdp{}; + void *g_ret{nullptr}; + void *atomic_ret{nullptr}; volatile uint64_t lock{}; }; @@ -52,9 +50,8 @@ class DefaultBlockHandleProxy { public: DefaultBlockHandleProxy() = default; - DefaultBlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue, - IpcImpl *ipc_policy, HdpPolicy *hdp_policy, - size_t num_elems = 1) + DefaultBlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue, + volatile char *status, size_t num_elems = 1) : proxy_{num_elems} { // TODO(bpotter): create a default queue for this queue descriptor @@ -66,13 +63,9 @@ class DefaultBlockHandleProxy { block_handle->read_index = queue_descriptor->read_index; block_handle->write_index = queue_descriptor->write_index; block_handle->host_read_index = &queue_descriptor->read_index; - block_handle->status = queue_descriptor->status; + block_handle->status = status; block_handle->g_ret = g_ret; - block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr; - block_handle->atomic_ret.atomic_counter = 0; - block_handle->ipc.ipc_bases = ipc_policy->ipc_bases; - block_handle->ipc.shm_size = ipc_policy->shm_size; - block_handle->hdp = hdp_policy; + block_handle->atomic_ret = atomic_ret; block_handle->lock = 0; } @@ -99,27 +92,24 @@ class BlockHandleProxy { public: BlockHandleProxy() = default; - BlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue, - IpcImpl *ipc_policy, HdpPolicy *hdp_policy, - size_t max_blocks) + BlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue, size_t offset, + volatile char *status, size_t max_blocks) : proxy_{max_blocks} { for (size_t i{0}; i < max_blocks; i++) { auto queue_descriptor{queue->descriptor(i)}; auto block_handle{&proxy_.get()[i]}; + size_t block_offset{i * offset}; block_handle->profiler.resetStats(); block_handle->queue = queue->elements(i); block_handle->queue_size = queue->size(); block_handle->read_index = queue_descriptor->read_index; block_handle->write_index = queue_descriptor->write_index; block_handle->host_read_index = &queue_descriptor->read_index; - block_handle->status = queue_descriptor->status; - block_handle->g_ret = g_ret; - block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr; - block_handle->atomic_ret.atomic_counter = 0; - block_handle->ipc.ipc_bases = ipc_policy->ipc_bases; - block_handle->ipc.shm_size = ipc_policy->shm_size; - block_handle->hdp = hdp_policy; + block_handle->status = status + block_offset; + block_handle->g_ret = reinterpret_cast(g_ret) + block_offset; + block_handle->atomic_ret = reinterpret_cast(atomic_ret) + + block_offset; block_handle->lock = 0; } } diff --git a/src/reverse_offload/context_ro_device.cpp b/src/reverse_offload/context_ro_device.cpp index 5c7b82ee89..fe5162e67e 100644 --- a/src/reverse_offload/context_ro_device.cpp +++ b/src/reverse_offload/context_ro_device.cpp @@ -71,7 +71,7 @@ __device__ void ROContext::putmem(void *dest, const void *source, size_t nelems, } build_queue_element(RO_NET_PUT, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } @@ -90,7 +90,7 @@ __device__ void ROContext::getmem(void *dest, const void *source, size_t nelems, } build_queue_element(RO_NET_GET, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } @@ -134,18 +134,21 @@ __device__ void ROContext::getmem_nbi(void *dest, const void *source, __device__ void ROContext::fence() { build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, - nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true); + nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, + true, get_status_flag()); } __device__ void ROContext::fence(int pe) { // TODO(khamidou): need to check if per pe has any special handling build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, - nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true); + nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, + true, get_status_flag()); } __device__ void ROContext::quiet() { build_queue_element(RO_NET_QUIET, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, - nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true); + nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, + true, get_status_flag()); } __device__ void *ROContext::shmem_ptr(const void *dest, int pe) { @@ -163,7 +166,7 @@ __device__ void ROContext::barrier_all() { if (is_thread_zero_in_block()) { build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, - block_handle, true); + block_handle, true, get_status_flag()); } __syncthreads(); } @@ -172,7 +175,7 @@ __device__ void ROContext::sync_all() { if (is_thread_zero_in_block()) { build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, - block_handle, true); + block_handle, true, get_status_flag()); } __syncthreads(); } @@ -182,7 +185,7 @@ __device__ void ROContext::sync(rocshmem_team_t team) { if (is_thread_zero_in_block()) { build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, - true); + true, get_status_flag()); } __syncthreads(); } @@ -195,7 +198,7 @@ __device__ void ROContext::ctx_destroy() { build_queue_element(RO_NET_FINALIZE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, - block_handle, true); + block_handle, true, get_status_flag()); int buffer_id = ro_net_win_id; backend->queue_.descriptor(buffer_id)->write_index = block_handle->write_index; @@ -219,7 +222,7 @@ __device__ void ROContext::putmem_wg(void *dest, const void *source, if (is_thread_zero_in_block()) { build_queue_element(RO_NET_PUT, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } __syncthreads(); @@ -237,7 +240,7 @@ __device__ void ROContext::getmem_wg(void *dest, const void *source, if (is_thread_zero_in_block()) { build_queue_element(RO_NET_GET, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } __syncthreads(); @@ -291,7 +294,7 @@ __device__ void ROContext::putmem_wave(void *dest, const void *source, if (is_thread_zero_in_wave()) { build_queue_element(RO_NET_PUT, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } } @@ -309,7 +312,7 @@ __device__ void ROContext::getmem_wave(void *dest, const void *source, if (is_thread_zero_in_wave()) { build_queue_element(RO_NET_GET, dest, const_cast(source), nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true); + ro_net_win_id, block_handle, true, get_status_flag()); } } } @@ -578,7 +581,8 @@ __device__ void build_queue_element( ro_net_cmds type, void *dst, void *src, size_t size, int pe, int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync, MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle, - bool blocking, ROCSHMEM_OP op, ro_net_types datatype) { + bool blocking, volatile char *status, ROCSHMEM_OP op, + ro_net_types datatype) { auto write_slot{next_write_slot(handle)}; auto queue_element = &handle->queue[write_slot]; @@ -587,6 +591,9 @@ __device__ void build_queue_element( queue_element->ol1.size = size; queue_element->dst = dst; queue_element->ro_net_win_id = ro_net_win_id; + if(blocking) { + queue_element->status = status; + } if (type == RO_NET_P) { memcpy(&queue_element->src, src, size); @@ -594,9 +601,6 @@ __device__ void build_queue_element( queue_element->src = src; } - auto threadId {get_flat_id()}; - queue_element->threadId = threadId; - if (type == RO_NET_AMO_FOP) { queue_element->op = op; queue_element->datatype = datatype; @@ -655,19 +659,31 @@ __device__ void build_queue_element( if (blocking) { int network_status{0}; do { - refresh_volatile_sbyte(&network_status, &handle->status[threadId]); + refresh_volatile_sbyte(&network_status, queue_element->status); } while (network_status == 0); - handle->status[threadId] = 0; + *(queue_element->status) = 0; __threadfence(); } } -__device__ uint64_t *ROContext::get_unused_atomic() { - auto index{atomicAdd(&block_handle->atomic_ret.atomic_counter, 1)}; - index = index % max_nb_atomic; - auto atomic_base_ptr{block_handle->atomic_ret.atomic_base_ptr}; - return &atomic_base_ptr[index]; +__device__ uint64_t *ROContext::get_atomic_ret_buf() { + uint64_t *atomic_base_ptr{ + reinterpret_cast(block_handle->atomic_ret)}; + int thread_id{get_flat_block_id()}; + return &atomic_base_ptr[thread_id]; +} + +__device__ uint64_t *ROContext::get_g_ret_buf() { + uint64_t *g_ret{reinterpret_cast(block_handle->g_ret)}; + int thread_id{get_flat_block_id()}; + return &g_ret[thread_id]; +} + +__device__ volatile char *ROContext::get_status_flag() { + volatile char* status{block_handle->status}; + int thread_id{get_flat_block_id()}; + return &status[thread_id]; } } // namespace rocshmem diff --git a/src/reverse_offload/context_ro_device.hpp b/src/reverse_offload/context_ro_device.hpp index ebdfb26d9e..4cae7a09aa 100644 --- a/src/reverse_offload/context_ro_device.hpp +++ b/src/reverse_offload/context_ro_device.hpp @@ -34,7 +34,7 @@ __device__ void build_queue_element( ro_net_cmds type, void *dst, void *src, size_t size, int pe, int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync, MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle, - bool blocking, ROCSHMEM_OP op = ROCSHMEM_SUM, + bool blocking, volatile char *status = nullptr, ROCSHMEM_OP op = ROCSHMEM_SUM, ro_net_types datatype = RO_NET_INT); class ROContext : public Context { @@ -251,7 +251,11 @@ class ROContext : public Context { __device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr); private: - __device__ uint64_t *get_unused_atomic(); + __device__ uint64_t *get_atomic_ret_buf(); + + __device__ uint64_t *get_g_ret_buf(); + + __device__ volatile char *get_status_flag(); BlockHandle *block_handle{nullptr}; diff --git a/src/reverse_offload/context_ro_tmpl_device.hpp b/src/reverse_offload/context_ro_tmpl_device.hpp index 5a24e0430f..9aff6e32f4 100644 --- a/src/reverse_offload/context_ro_tmpl_device.hpp +++ b/src/reverse_offload/context_ro_tmpl_device.hpp @@ -120,7 +120,8 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest, build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast(source), nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm, - ro_net_win_id, block_handle, true, Op, GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + Op, GetROType::Type); __syncthreads(); return ROCSHMEM_SUCCESS; @@ -137,8 +138,8 @@ __device__ void ROContext::to_all(T *dest, const T *source, int nreduce, build_queue_element(RO_NET_TO_ALL, dest, const_cast(source), nreduce, PE_start, logPE_stride, PE_size, 0, pWrk, pSync, - (MPI_Comm)NULL, ro_net_win_id, block_handle, true, Op, - GetROType::Type); + (MPI_Comm)NULL, ro_net_win_id, block_handle, true, + get_status_flag(), Op, GetROType::Type); __syncthreads(); } @@ -166,7 +167,7 @@ __device__ void ROContext::p(T *dest, T value, int pe) { } else { build_queue_element(RO_NET_P, dest, &value, sizeof(T), pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, - block_handle, true); + block_handle, true, get_status_flag()); } } @@ -179,12 +180,7 @@ __device__ T ROContext::g(const T *source, int pe) { ipcImpl_.ipcCopy(&dest, ipcImpl_.ipc_bases[pe] + L_offset, sizeof(T)); return dest; } else { - int thread_id{get_flat_block_id()}; - int block_size{get_flat_block_size()}; - int offset{get_flat_grid_id() * block_size + thread_id}; - - char *base_dest{block_handle->g_ret}; - char *dest{&base_dest[offset * sizeof(int64_t)]}; + auto dest{get_g_ret_buf()}; get(reinterpret_cast(dest), source, 1, pe); return *(reinterpret_cast(dest)); } @@ -206,12 +202,13 @@ __device__ void ROContext::get_nbi(T *dest, const T *source, size_t nelems, template __device__ T ROContext::amo_fetch_cas(void *dst, T value, T cond, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FCAS, dst, reinterpret_cast(source), value, pe, 0, 0, 0, reinterpret_cast(static_cast(cond)), - nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true, - ROCSHMEM_SUM, GetROType::Type); + nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, + true, get_status_flag(), ROCSHMEM_SUM, + GetROType::Type); __threadfence(); return *source; } @@ -223,11 +220,11 @@ __device__ void ROContext::amo_cas(void *dst, T value, T cond, int pe) { template __device__ T ROContext::amo_fetch_add(void *dst, T value, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast(source), value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true, ROCSHMEM_SUM, - GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + ROCSHMEM_SUM, GetROType::Type); __threadfence(); return *source; } @@ -239,11 +236,11 @@ __device__ void ROContext::amo_add(void *dst, T value, int pe) { template __device__ T ROContext::amo_swap(void *dst, T value, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast(source), value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true, ROCSHMEM_REPLACE, - GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + ROCSHMEM_REPLACE, GetROType::Type); __threadfence(); return *source; } @@ -255,11 +252,11 @@ __device__ void ROContext::amo_set(void *dst, T value, int pe) { template __device__ T ROContext::amo_fetch_and(void *dst, T value, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast(source), value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true, ROCSHMEM_AND, - GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + ROCSHMEM_AND, GetROType::Type); __threadfence(); return *source; } @@ -271,11 +268,11 @@ __device__ void ROContext::amo_and(void *dst, T value, int pe) { template __device__ T ROContext::amo_fetch_or(void *dst, T value, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast(source), value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true, ROCSHMEM_OR, - GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + ROCSHMEM_OR, GetROType::Type); __threadfence(); return *source; } @@ -287,11 +284,11 @@ __device__ void ROContext::amo_or(void *dst, T value, int pe) { template __device__ T ROContext::amo_fetch_xor(void *dst, T value, int pe) { - auto source{get_unused_atomic()}; + auto source{get_atomic_ret_buf()}; build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast(source), value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, - ro_net_win_id, block_handle, true, ROCSHMEM_XOR, - GetROType::Type); + ro_net_win_id, block_handle, true, get_status_flag(), + ROCSHMEM_XOR, GetROType::Type); __threadfence(); return *source; } @@ -314,7 +311,7 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest, build_queue_element(RO_NET_TEAM_BROADCAST, dest, const_cast(source), nelems, 0, 0, 0, pe_root, nullptr, nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, true, - ROCSHMEM_SUM, GetROType::Type); + get_status_flag(), ROCSHMEM_SUM, GetROType::Type); __syncthreads(); } @@ -332,7 +329,7 @@ __device__ void ROContext::broadcast(T *dest, const T *source, int nelems, build_queue_element(RO_NET_BROADCAST, dest, const_cast(source), nelems, pe_start, log_pe_stride, pe_size, pe_root, nullptr, p_sync, (MPI_Comm)NULL, ro_net_win_id, block_handle, true, - ROCSHMEM_SUM, GetROType::Type); + get_status_flag(), ROCSHMEM_SUM, GetROType::Type); __syncthreads(); } @@ -350,7 +347,7 @@ __device__ void ROContext::alltoall(rocshmem_team_t team, T *dest, build_queue_element(RO_NET_ALLTOALL, dest, const_cast(source), nelems, 0, 0, 0, 0, team_obj->ata_buffer, nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, true, - ROCSHMEM_SUM, GetROType::Type); + get_status_flag(), ROCSHMEM_SUM, GetROType::Type); __syncthreads(); } @@ -368,7 +365,7 @@ __device__ void ROContext::fcollect(rocshmem_team_t team, T *dest, build_queue_element(RO_NET_FCOLLECT, dest, const_cast(source), nelems, 0, 0, 0, 0, team_obj->ata_buffer, nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, true, - ROCSHMEM_SUM, GetROType::Type); + get_status_flag(), ROCSHMEM_SUM, GetROType::Type); __syncthreads(); } diff --git a/src/reverse_offload/mpi_transport.cpp b/src/reverse_offload/mpi_transport.cpp index dc678d5bbc..e6dc7aa1a9 100644 --- a/src/reverse_offload/mpi_transport.cpp +++ b/src/reverse_offload/mpi_transport.cpp @@ -94,7 +94,7 @@ void MPITransport::submitRequestsToMPI() { case RO_NET_PUT: putMem(next_element.dst, next_element.src, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, true); + next_element.status, true); DPRINTF("Received PUT dst %p src %p size %lu pe %d win_id %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id); @@ -109,7 +109,7 @@ void MPITransport::submitRequestsToMPI() { putMem(next_element.dst, source_buffer, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, true, true); + next_element.status, true, true); DPRINTF("Received P dst %p value %p pe %d\n", next_element.dst, next_element.src, next_element.PE); break; @@ -117,14 +117,14 @@ void MPITransport::submitRequestsToMPI() { case RO_NET_GET: getMem(next_element.dst, next_element.src, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, true); + next_element.status, true); DPRINTF("Received GET dst %p src %p size %lu pe %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.PE); break; case RO_NET_PUT_NBI: putMem(next_element.dst, next_element.src, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, false); + next_element.status, false); DPRINTF("Received PUT NBI dst %p src %p size %lu pe %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.PE); @@ -132,7 +132,7 @@ void MPITransport::submitRequestsToMPI() { case RO_NET_GET_NBI: getMem(next_element.dst, next_element.src, next_element.ol1.size, next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, false); + next_element.status, false); DPRINTF("Received GET NBI dst %p src %p size %lu pe %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.PE); @@ -141,7 +141,7 @@ void MPITransport::submitRequestsToMPI() { amoFOP(next_element.dst, next_element.src, const_cast(&next_element.ol1.atomic_value), next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, true, + next_element.status, true, static_cast(next_element.op), static_cast(next_element.datatype)); DPRINTF("Received AMO dst %p src %p Val %llu pe %d\n", next_element.dst, @@ -151,7 +151,7 @@ void MPITransport::submitRequestsToMPI() { amoFCAS(next_element.dst, next_element.src, const_cast(&next_element.ol1.atomic_value), next_element.PE, next_element.ro_net_win_id, queue_idx, - next_element.threadId, true, + next_element.status, true, const_cast(&next_element.ol2.pWrk), static_cast(next_element.datatype)); DPRINTF("Received F_CSWAP dst %p src %p Val %llu pe %d cond %ld\n", @@ -165,7 +165,7 @@ void MPITransport::submitRequestsToMPI() { next_element.team_comm, static_cast(next_element.op), static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.team_comm); @@ -177,7 +177,7 @@ void MPITransport::submitRequestsToMPI() { next_element.PE_size, next_element.ol2.pWrk, next_element.pSync, static_cast(next_element.op), static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF( "Received FLOAT_SUM_TO_ALL dst %p src %p size %lu " "PE_start %d, logPE_stride %d, PE_size %d, pWrk %p, pSync %p\n", @@ -190,7 +190,7 @@ void MPITransport::submitRequestsToMPI() { next_element.ro_net_win_id, queue_idx, next_element.team_comm, next_element.PE_root, static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF( "Received TEAM_BROADCAST dst %p src %p size %lu " "team %d, PE_root %d \n", @@ -203,7 +203,7 @@ void MPITransport::submitRequestsToMPI() { next_element.PE, next_element.logPE_stride, next_element.PE_size, next_element.PE_root, next_element.pSync, static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF( "Received BROADCAST dst %p src %p size %lu PE_start %d, " "logPE_stride %d, PE_size %d, PE_root %d, pSync %p\n", @@ -216,7 +216,7 @@ void MPITransport::submitRequestsToMPI() { next_element.ro_net_win_id, queue_idx, next_element.team_comm, next_element.ol2.pWrk, static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF("Received ALLTOALL dst %p src %p size %lu team %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.team_comm); @@ -226,26 +226,26 @@ void MPITransport::submitRequestsToMPI() { next_element.ro_net_win_id, queue_idx, next_element.team_comm, next_element.ol2.pWrk, static_cast(next_element.datatype), - next_element.threadId, true); + next_element.status, true); DPRINTF("Received FCOLLECT dst %p src %p size %lu team %d\n", next_element.dst, next_element.src, next_element.ol1.size, next_element.team_comm); break; case RO_NET_BARRIER_ALL: - barrier(queue_idx, next_element.threadId, true, ro_net_comm_world); + barrier(queue_idx, next_element.status, true, ro_net_comm_world); DPRINTF("Received Barrier_all\n"); break; case RO_NET_SYNC: - barrier(queue_idx, next_element.threadId, true, next_element.team_comm); + barrier(queue_idx, next_element.status, true, next_element.team_comm); DPRINTF("Received Sync\n"); break; case RO_NET_FENCE: case RO_NET_QUIET: - quiet(queue_idx, next_element.threadId); + quiet(queue_idx, next_element.status); DPRINTF("Received FENCE/QUIET\n"); break; case RO_NET_FINALIZE: - quiet(queue_idx, next_element.threadId); + quiet(queue_idx, next_element.status); DPRINTF("Received Finalize\n"); break; default: @@ -256,7 +256,7 @@ void MPITransport::submitRequestsToMPI() { } void MPITransport::initTransport(int num_queues, BackendProxyT *proxy) { - waiting_quiet.resize(num_queues, std::vector()); + waiting_quiet.resize(num_queues, std::vector()); outstanding.resize(num_queues, 0); transport_up = false; @@ -333,13 +333,13 @@ void MPITransport::global_exit(int status) { MPI_Abort(ro_net_comm_world, status); } -void MPITransport::barrier(int blockId, int threadId, bool blocking, - MPI_Comm team) { +void MPITransport::barrier(int contextId, volatile char *status, bool blocking, + MPI_Comm team) { MPI_Request request{}; NET_CHECK(MPI_Ibarrier(team, &request)); - requests.push_back({request, {threadId, blockId, blocking}}); - outstanding[blockId]++; + requests.push_back({request, {status, contextId, blocking}}); + outstanding[contextId]++; } MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) { @@ -397,10 +397,10 @@ static MPI_Datatype convertType(ro_net_types type) { } void MPITransport::reduction(void *dst, void *src, int size, int pe, - int win_id, int blockId, int start, int logPstride, - int sizePE, void *pWrk, long *pSync, - ROCSHMEM_OP op, ro_net_types type, int threadId, - bool blocking) { + int win_id, int contextId, int start, + int logPstride, int sizePE, void *pWrk, + long *pSync, ROCSHMEM_OP op, ro_net_types type, + volatile char *status, bool blocking) { MPI_Request request{}; MPI_Op mpi_op{get_mpi_op(op)}; MPI_Datatype mpi_type{convertType(type)}; @@ -413,14 +413,15 @@ void MPITransport::reduction(void *dst, void *src, int size, int pe, NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request)); } - requests.push_back({request, {threadId, blockId, blocking}}); - outstanding[blockId]++; + requests.push_back({request, {status, contextId, blocking}}); + outstanding[contextId]++; } void MPITransport::broadcast(void *dst, void *src, int size, int pe, - int win_id, int blockId, int start, int logPstride, - int sizePE, int root, long *pSync, - ro_net_types type, int threadId, bool blocking) { + int win_id, int contextId, int start, + int logPstride, int sizePE, int root, long *pSync, + ro_net_types type, volatile char *status, + bool blocking) { MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)}; int new_rank{}; @@ -437,15 +438,15 @@ void MPITransport::broadcast(void *dst, void *src, int size, int pe, MPI_Datatype mpi_type{convertType(type)}; NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request)); - requests.push_back({request, {threadId, blockId, blocking}}); + requests.push_back({request, {status, contextId, blocking}}); - outstanding[blockId]++; + outstanding[contextId]++; } void MPITransport::team_reduction(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, ROCSHMEM_OP op, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, ROCSHMEM_OP op, + ro_net_types type, volatile char* status, + bool blocking) { MPI_Request request{}; MPI_Op mpi_op{get_mpi_op(op)}; @@ -459,15 +460,15 @@ void MPITransport::team_reduction(void *dst, void *src, int size, int win_id, NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request)); } - requests.push_back({request, {threadId, blockId, blocking}}); + requests.push_back({request, {status, contextId, blocking}}); - outstanding[blockId]++; + outstanding[contextId]++; } void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, int root, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, int root, + ro_net_types type, volatile char *status, + bool blocking) { MPI_Comm comm{team}; int new_rank{}; MPI_Comm_rank(comm, &new_rank); @@ -482,14 +483,15 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id, MPI_Request request{}; NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request)); - requests.push_back({request, {threadId, blockId, blocking}}); + requests.push_back({request, {status, contextId, blocking}}); - outstanding[blockId]++; + outstanding[contextId]++; } void MPITransport::alltoall(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, bool blocking) { + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking) { int pe_size{}; NET_CHECK(MPI_Comm_size(team, &pe_size)); @@ -502,24 +504,24 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id, #ifdef A2A_HEURISTICS if ((pe_size >= 8 || type_size * size < 2048) && num_clust * clust_size == pe_size) { - return alltoall_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type, - threadId, blocking); + return alltoall_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, + type, status, blocking); } else if (size <= 512) { #endif // A2A_HEURISTICS - return alltoall_mpi(dst, src, size, blockId, team, ata_buffptr, type, - threadId, blocking); + return alltoall_mpi(dst, src, size, contextId, team, ata_buffptr, type, + status, blocking); #ifdef A2A_HEURISTICS } else { - return alltoall_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr, - type, threadId, blocking); + return alltoall_broadcast(dst, src, size, win_id, contextId, team, + ata_buffptr, type, status, blocking); } #endif // A2A_HEURISTICS } void MPITransport::alltoall_broadcast(void *dst, void *src, int size, - int win_id, int blockId, MPI_Comm team, - void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) { + int win_id, int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { auto *bp{backend_proxy->get()}; MPI_Comm comm{team}; @@ -563,26 +565,26 @@ void MPITransport::alltoall_broadcast(void *dst, void *src, int size, NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE)); NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); - barrier(blockId, threadId, blocking, comm); + barrier(contextId, status, blocking, comm); } -void MPITransport::alltoall_mpi(void *dst, void *src, int size, int blockId, - MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { +void MPITransport::alltoall_mpi(void *dst, void *src, int size, int contextId, + MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking) { int new_rank{}; NET_CHECK(MPI_Comm_rank(team, &new_rank)); int pe_size{}; NET_CHECK(MPI_Comm_size(team, &pe_size)); MPI_Datatype mpi_type{convertType(type)}; NET_CHECK(MPI_Alltoall(src, size, mpi_type, dst, size, mpi_type, team)); - quiet(blockId, threadId); + quiet(contextId, status); } void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { auto *bp{backend_proxy->get()}; int new_rank{}; @@ -664,14 +666,14 @@ void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id, MPI_Comm comm_ring{createComm(world_ranks[new_rank % clust_size], stride * clust_size, num_clust)}; - barrier(blockId, threadId, false, comm_cluster); - barrier(blockId, threadId, blocking, comm_ring); + barrier(contextId, status, false, comm_cluster); + barrier(contextId, status, blocking, comm_ring); } void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { // GPU-centric alltoall with in-place blocking synchronization auto *bp{backend_proxy->get()}; int new_rank, pe_size; @@ -759,12 +761,13 @@ void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id, MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size], stride * clust_size, num_clust); // Now wait for completion - barrier(blockId, threadId, blocking, comm_ring); + barrier(contextId, status, blocking, comm_ring); } void MPITransport::fcollect(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, bool blocking) { + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking) { int pe_size, type_size; MPI_Comm comm = team; NET_CHECK(MPI_Comm_size(comm, &pe_size)); @@ -780,21 +783,21 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id, // In most cases the MPI implementation is optimal // But it crashes for > 512 messages if (size <= 512) { - fcollect_mpi(dst, src, size, blockId, team, ata_buffptr, type, - threadId, blocking); + fcollect_mpi(dst, src, size, contextId, team, ata_buffptr, type, + status, blocking); } else if (num_clust * clust_size == pe_size) { - fcollect_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type, - threadId, blocking); + fcollect_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, type, + status, blocking); } else { - fcollect_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr, - type, threadId, blocking); + fcollect_broadcast(dst, src, size, win_id, contextId, team, ata_buffptr, + type, status, blocking); } } void MPITransport::fcollect_broadcast(void *dst, void *src, int size, - int win_id, int blockId, MPI_Comm team, - void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) { + int win_id, int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { // Broadcast implementation of fcollect auto *bp{backend_proxy->get()}; int new_rank, pe_size; @@ -837,13 +840,13 @@ void MPITransport::fcollect_broadcast(void *dst, void *src, int size, NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); // Now wait for completion - barrier(blockId, threadId, blocking, comm); + barrier(contextId, status, blocking, comm); } -void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId, - MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { +void MPITransport::fcollect_mpi(void *dst, void *src, int size, int contextId, + MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking) { // MPI's implementation of fcollect int new_rank, pe_size; @@ -852,13 +855,13 @@ void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId, NET_CHECK(MPI_Comm_rank(comm, &new_rank)); NET_CHECK(MPI_Comm_size(comm, &pe_size)); NET_CHECK(MPI_Allgather(src, size, mpi_type, dst, size, mpi_type, comm)); - quiet(blockId, threadId); + quiet(contextId, status); } void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { // GPU-centric implementation of fcollect auto *bp{backend_proxy->get()}; int new_rank, pe_size; @@ -938,14 +941,14 @@ void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id, MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size], stride * clust_size, num_clust); // Now wait for completion - barrier(blockId, threadId, false, comm_cluster); - barrier(blockId, threadId, blocking, comm_ring); + barrier(contextId, status, false, comm_cluster); + barrier(contextId, status, blocking, comm_ring); } void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, - bool blocking) { + int contextId, MPI_Comm team, + void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) { // GPU-centric implementation with in-place, blocking synchronization auto *bp{backend_proxy->get()}; int new_rank, pe_size; @@ -1027,12 +1030,12 @@ void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id, MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size], stride * clust_size, num_clust); // Now wait for completion - barrier(blockId, threadId, blocking, comm_ring); + barrier(contextId, status, blocking, comm_ring); } void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id, - int blockId, int threadId, bool blocking, - bool inline_data) { + int contextId, volatile char *status, bool blocking, + bool inline_data) { queue->flush_hdp(); auto *bp{backend_proxy->get()}; @@ -1047,14 +1050,14 @@ void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id, // though it should be in the progress loop. NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win())); - requests.push_back({request, {threadId, blockId, blocking}}); + requests.push_back({request, {status, contextId, blocking}}); - outstanding[blockId]++; + outstanding[contextId]++; } void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id, - int blockId, int threadId, bool blocking, - ROCSHMEM_OP op, ro_net_types type) { + int contextId, volatile char *status, bool blocking, + ROCSHMEM_OP op, ro_net_types type) { queue->flush_hdp(); auto *bp{backend_proxy->get()}; @@ -1069,14 +1072,14 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id, // though it should be in the progress loop. NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); - queue->notify(blockId, threadId); + queue->notify(status); queue->sfence_flush_hdp(); } void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe, - int win_id, int blockId, int threadId, bool blocking, - void *cond, ro_net_types type) { + int win_id, int contextId, volatile char *status, + bool blocking, void *cond, ro_net_types type) { queue->flush_hdp(); auto *bp{backend_proxy->get()}; @@ -1091,14 +1094,15 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe, // though it should be in the progress loop. NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win())); - queue->notify(blockId, threadId); + queue->notify(status); queue->sfence_flush_hdp(); } void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id, - int blockId, int threadId, bool blocking) { - outstanding[blockId]++; + int contextId, volatile char *status, + bool blocking) { + outstanding[contextId]++; auto *bp{backend_proxy->get()}; MPI_Request request{}; @@ -1106,7 +1110,7 @@ void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id, dst, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(src), size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request)); - requests.push_back({request, {threadId, blockId, blocking}}); + requests.push_back({request, {status, contextId, blocking}}); } std::unique_ptr MPITransport::raw_requests() { @@ -1139,20 +1143,20 @@ void MPITransport::progress() { auto *bp{backend_proxy->get()}; for (int i{0}; i < outcount; i++) { int index{testsome_indices[i]}; - int blockId{requests[index].properties.blockId}; - int threadId{requests[index].properties.threadId}; + int contextId{requests[index].properties.contextId}; + volatile char *status{requests[index].properties.status}; - if (blockId != -1) { - outstanding[blockId]--; + if (contextId != -1) { + outstanding[contextId]--; DPRINTF( - "Finished op for blockId %d at threadId %d " + "Finished op for contextId %d at status addr %p " "(%d requests outstanding)\n", - blockId, threadId, outstanding[blockId]); + contextId, status, outstanding[contextId]); } if (requests[index].properties.blocking) { - if (blockId != -1) { - queue->notify(blockId, threadId); + if (contextId != -1) { + queue->notify(status); } queue->sfence_flush_hdp(); } @@ -1163,14 +1167,14 @@ void MPITransport::progress() { // If the GPU has requested a quiet, notify it of completion when // all outstanding requests are complete. - if (!outstanding[blockId] && !waiting_quiet[blockId].empty()) { - for (const auto threadId : waiting_quiet[blockId]) { - DPRINTF("Finished Quiet for blockId %d at threadId %d\n", blockId, - threadId); - queue->notify(blockId, threadId); + if (!outstanding[contextId] && !waiting_quiet[contextId].empty()) { + for (const auto status : waiting_quiet[contextId]) { + DPRINTF("Finished Quiet for contextId %d at status addr %p\n", contextId, + status); + queue->notify(status); } - waiting_quiet[blockId].clear(); + waiting_quiet[contextId].clear(); queue->sfence_flush_hdp(); } @@ -1185,15 +1189,15 @@ void MPITransport::progress() { } } -void MPITransport::quiet(int blockId, int threadId) { +void MPITransport::quiet(int contextId, volatile char *status) { auto *bp{backend_proxy->get()}; - if (!outstanding[blockId]) { - DPRINTF("Finished Quiet immediately for blockId %d at threadId %d\n", blockId, - threadId); - queue->notify(blockId, threadId); + if (!outstanding[contextId]) { + DPRINTF("Finished Quiet immediately for contextId %d at status addr %p\n", + contextId, status); + queue->notify(status); } else { - waiting_quiet[blockId].emplace_back(threadId); + waiting_quiet[contextId].emplace_back(status); } } diff --git a/src/reverse_offload/mpi_transport.hpp b/src/reverse_offload/mpi_transport.hpp index fdd8e5a133..4912db81cd 100644 --- a/src/reverse_offload/mpi_transport.hpp +++ b/src/reverse_offload/mpi_transport.hpp @@ -36,7 +36,7 @@ namespace rocshmem { class HostInterface; class MPITransport : public Transport { - public: +public: explicit MPITransport(MPI_Comm com, Queue* queue); virtual ~MPITransport(); @@ -46,87 +46,92 @@ class MPITransport : public Transport { void finalizeTransport() override; void createNewTeam(ROBackend *backend, Team *parent_team, - TeamInfo *team_info_wrt_parent, - TeamInfo *team_info_wrt_world, int num_pes, - int my_pe_in_new_team, MPI_Comm team_comm, - rocshmem_team_t *new_team) override; + TeamInfo *team_info_wrt_parent, + TeamInfo *team_info_wrt_world, int num_pes, + int my_pe_in_new_team, MPI_Comm team_comm, + rocshmem_team_t *new_team) override; - void barrier(int blockId, int threadId, bool blocking, - MPI_Comm team) override; + void barrier(int contextId, volatile char *status, bool blocking, + MPI_Comm team) override; void reduction(void *dst, void *src, int size, int pe, int win_id, - int blockId, int start, int logPstride, int sizePE, void *pWrk, - long *pSync, ROCSHMEM_OP op, ro_net_types type, - int threadId, bool blocking) override; + int contextId, int start, int logPstride, int sizePE, + void *pWrk, long *pSync, ROCSHMEM_OP op, ro_net_types type, + volatile char *status, bool blocking) override; - void team_reduction(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, ROCSHMEM_OP op, ro_net_types type, - int threadId, bool blocking) override; + void team_reduction(void *dst, void *src, int size, int win_id, + int contextId, MPI_Comm team, ROCSHMEM_OP op, + ro_net_types type, volatile char *status, + bool blocking) override; void broadcast(void *dst, void *src, int size, int pe, int win_id, - int blockId, int start, int logPstride, int sizePE, - int PE_root, long *pSync, ro_net_types type, int threadId, - bool blocking) override; + int contextId, int start, int logPstride, int sizePE, + int PE_root, long *pSync, ro_net_types type, + volatile char *status, bool blocking) override; - void team_broadcast(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, int PE_root, ro_net_types type, - int threadId, bool blocking) override; + void team_broadcast(void *dst, void *src, int size, int win_id, + int contextId, MPI_Comm team, int PE_root, + ro_net_types type, volatile char *status, + bool blocking) override; - void alltoall(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) override; + void alltoall(void *dst, void *src, int size, int win_id, int contextId, + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) override; void alltoall_broadcast(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, bool blocking); + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking); - void alltoall_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team, - void *ata_buffptr, ro_net_types type, int threadId, - bool blocking); + void alltoall_mpi(void *dst, void *src, int size, int contextId, + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking); - void alltoall_gcen(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking); + void alltoall_gcen(void *dst, void *src, int size, int win_id, + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, bool blocking); - void alltoall_gcen2(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking); + void alltoall_gcen2(void *dst, void *src, int size, int win_id, + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, bool blocking); - void fcollect(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) override; + void fcollect(void *dst, void *src, int size, int win_id, int contextId, + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) override; void fcollect_broadcast(void *dst, void *src, int size, int win_id, - int blockId, MPI_Comm team, void *ata_buffptr, - ro_net_types type, int threadId, bool blocking); + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, + bool blocking); - void fcollect_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team, - void *ata_buffptr, ro_net_types type, int threadId, - bool blocking); + void fcollect_mpi(void *dst, void *src, int size, int contextId, + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking); - void fcollect_gcen(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking); + void fcollect_gcen(void *dst, void *src, int size, int win_id, int contextId, + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking); - void fcollect_gcen2(void *dst, void *src, int size, int win_id, int blockId, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking); + void fcollect_gcen2(void *dst, void *src, int size, int win_id, + int contextId, MPI_Comm team, void *ata_buffptr, + ro_net_types type, volatile char *status, bool blocking); - void putMem(void *dst, void *src, int size, int pe, int win_id, int blockId, - int threadId, bool blocking, bool inline_data = false) override; + void putMem(void *dst, void *src, int size, int pe, int win_id, + int contextId, volatile char *status, bool blocking, + bool inline_data = false) override; - void amoFOP(void *dst, void *src, void *val, int pe, int win_id, int blockId, - int threadId, bool blocking, ROCSHMEM_OP op, - ro_net_types type) override; + void amoFOP(void *dst, void *src, void *val, int pe, int win_id, + int contextId, volatile char *status, bool blocking, + ROCSHMEM_OP op, ro_net_types type) override; - void amoFCAS(void *dst, void *src, void *val, int pe, int win_id, int blockId, - int threadId, bool blocking, void *cond, - ro_net_types type) override; + void amoFCAS(void *dst, void *src, void *val, int pe, int win_id, + int contextId, volatile char *status, bool blocking, void *cond, + ro_net_types type) override; - void getMem(void *dst, void *src, int size, int pe, int win_id, int blockId, - int threadId, bool blocking) override; + void getMem(void *dst, void *src, int size, int pe, int win_id, + int contextId, volatile char *status, bool blocking) override; - void quiet(int blockId, int threadId) override; + void quiet(int contextId, volatile char *status) override; void progress() override; @@ -142,7 +147,7 @@ class MPITransport : public Transport { HostInterface *host_interface{nullptr}; - private: +private: struct CommKey { CommKey(int _start, int _logPstride, int _size) : start(_start), logPstride(_logPstride), size(_size) {} @@ -160,23 +165,23 @@ class MPITransport : public Transport { }; struct RequestProperties { - RequestProperties(int _threadId, int _blockId, bool _blocking, void *_src, - bool _inline_data) - : threadId(_threadId), - blockId(_blockId), + RequestProperties(volatile char *_status, int _contextId, bool _blocking, + void *_src, bool _inline_data) + : status(_status), + contextId(_contextId), blocking(_blocking), src(_src), inline_data(_inline_data) {} - RequestProperties(int _threadId, int _blockId, bool _blocking) - : threadId(_threadId), - blockId(_blockId), + RequestProperties(volatile char *_status, int _contextId, bool _blocking) + : status(_status), + contextId(_contextId), blocking(_blocking), src(nullptr), inline_data(false) {} - int threadId{-1}; - int blockId{-1}; + volatile char* status{nullptr}; + int contextId{-1}; bool blocking{}; void *src{nullptr}; bool inline_data{}; @@ -202,7 +207,7 @@ class MPITransport : public Transport { // Unordered vector of in-flight MPI Requests. Can complete out of order. std::vector requests{}; - std::vector > waiting_quiet{}; + std::vector > waiting_quiet{}; std::vector outstanding{}; diff --git a/src/reverse_offload/queue.cpp b/src/reverse_offload/queue.cpp index b6e37b48c1..76e4cea2e1 100644 --- a/src/reverse_offload/queue.cpp +++ b/src/reverse_offload/queue.cpp @@ -33,12 +33,11 @@ Queue::Queue() { } } -Queue::Queue(size_t max_queues, size_t max_wg_size, size_t queue_size) +Queue::Queue(size_t max_queues, size_t queue_size) : max_queues_{max_queues}, - max_wg_size_{max_wg_size}, queue_size_{queue_size}, queue_proxy_{max_queues, queue_size}, - queue_desc_proxy_{max_queues, max_wg_size} { + queue_desc_proxy_{max_queues} { gpu_queue = true; char *value{nullptr}; @@ -101,8 +100,8 @@ void Queue::sfence_flush_hdp() { } } -void Queue::notify(int blockId, int threadId) { - descriptor(blockId)->status[threadId] = 1; +void Queue::notify(volatile char* status) { + *status = 1; } uint64_t Queue::size() { diff --git a/src/reverse_offload/queue.hpp b/src/reverse_offload/queue.hpp index 3ea41a39c7..eb110c761b 100644 --- a/src/reverse_offload/queue.hpp +++ b/src/reverse_offload/queue.hpp @@ -35,7 +35,7 @@ class Queue { public: Queue(); - Queue(size_t max_queues, size_t max_threads_per_block, size_t queue_size); + Queue(size_t max_queues, size_t queue_size); bool process(uint64_t queue_index, MPITransport* transport); @@ -47,7 +47,7 @@ class Queue { void sfence_flush_hdp(); - void notify(int blockId, int threadId); + void notify(volatile char *status); uint64_t size(); @@ -72,8 +72,6 @@ class Queue { size_t max_queues_{}; - size_t max_wg_size_{}; - size_t queue_size_{}; }; diff --git a/src/reverse_offload/queue_desc_proxy.hpp b/src/reverse_offload/queue_desc_proxy.hpp index f5aa66c461..3b301ed88c 100644 --- a/src/reverse_offload/queue_desc_proxy.hpp +++ b/src/reverse_offload/queue_desc_proxy.hpp @@ -45,38 +45,22 @@ typedef struct queue_desc { */ uint64_t write_index; char padding2[56]; - /** - * This bit is used by the GPU to wait on a blocking operation. The initial - * value is 0. When a GPU enqueues a blocking operation, it waits for this - * value to resolve to 1, which is set by the CPU when the blocking - * operation completes. The GPU then resets status back to zero. There is - * a separate status variable for each work-item in a work-group - */ - char *status; - char padding3[56]; } __attribute__((__aligned__(64))) queue_desc_t; template class QueueDescProxy { using ProxyT = DeviceProxy; - using ProxyStatusT = DeviceProxy; public: QueueDescProxy() = default; - QueueDescProxy(size_t max_queues, size_t max_threads_per_queue) - : max_queues_{max_queues}, max_threads_per_queue_{max_threads_per_queue}, - max_threads_{max_queues * max_threads_per_queue}, proxy_{max_queues}, - proxy_status_{max_queues * max_threads_per_queue} { - auto *status{proxy_status_.get()}; - size_t status_bytes{sizeof(char) * max_threads_}; - memset(status, 0, status_bytes); + QueueDescProxy(size_t max_queues) + : max_queues_{max_queues}, proxy_{max_queues} { auto *queue_descs{proxy_.get()}; for (size_t i{0}; i < max_queues_; i++) { queue_descs[i].read_index = 0; queue_descs[i].write_index = 0; - queue_descs[i].status = status + i * max_threads_per_queue_; } } @@ -93,13 +77,7 @@ class QueueDescProxy { private: ProxyT proxy_{}; - ProxyStatusT proxy_status_{}; - size_t max_queues_{}; - - size_t max_threads_per_queue_{}; - - size_t max_threads_{}; }; using QueueDescProxyT = QueueDescProxy; diff --git a/src/reverse_offload/queue_proxy.hpp b/src/reverse_offload/queue_proxy.hpp index ce3bda6fa0..c5780560bf 100644 --- a/src/reverse_offload/queue_proxy.hpp +++ b/src/reverse_offload/queue_proxy.hpp @@ -57,13 +57,13 @@ typedef struct queue_element { void *src{nullptr}; void *dst{nullptr}; int ro_net_win_id{-1}; - int threadId{-1}; int logPE_stride{-1}; int PE_size{-1}; long *pSync{nullptr}; int op{-1}; int datatype{-1}; int PE_root{-1}; + volatile char *status{nullptr}; MPI_Comm team_comm{}; union { size_t size; diff --git a/src/reverse_offload/transport.hpp b/src/reverse_offload/transport.hpp index 108bdcbf8f..fc815c6120 100644 --- a/src/reverse_offload/transport.hpp +++ b/src/reverse_offload/transport.hpp @@ -45,60 +45,61 @@ class Transport { virtual void finalizeTransport() = 0; virtual void createNewTeam(ROBackend *backend_handle, Team *parent_team, - TeamInfo *team_info_wrt_parent, - TeamInfo *team_info_wrt_world, int num_pes, - int my_pe_in_new_team, MPI_Comm team_comm, - rocshmem_team_t *new_team) = 0; + TeamInfo *team_info_wrt_parent, + TeamInfo *team_info_wrt_world, int num_pes, + int my_pe_in_new_team, MPI_Comm team_comm, + rocshmem_team_t *new_team) = 0; - virtual void barrier(int wg_id, int threadId, bool blocking, - MPI_Comm team) = 0; + virtual void barrier(int wg_id, volatile char *status, bool blocking, + MPI_Comm team) = 0; virtual void reduction(void *dst, void *src, int size, int pe, int win_id, - int wg_id, int start, int logPstride, int sizePE, - void *pWrk, long *pSync, ROCSHMEM_OP op, - ro_net_types type, int threadId, bool blocking) = 0; + int wg_id, int start, int logPstride, int sizePE, + void *pWrk, long *pSync, ROCSHMEM_OP op, + ro_net_types type, volatile char *status, + bool blocking) = 0; virtual void team_reduction(void *dst, void *src, int size, int win_id, - int wg_id, MPI_Comm team, ROCSHMEM_OP op, - ro_net_types type, int threadId, - bool blocking) = 0; + int wg_id, MPI_Comm team, ROCSHMEM_OP op, + ro_net_types type, volatile char *status, + bool blocking) = 0; virtual void broadcast(void *dst, void *src, int size, int pe, int win_id, - int wg_id, int start, int logPstride, int sizePE, - int PE_root, long *pSync, ro_net_types type, - int threadId, bool blocking) = 0; + int wg_id, int start, int logPstride, int sizePE, + int PE_root, long *pSync, ro_net_types type, + volatile char *status, bool blocking) = 0; virtual void team_broadcast(void *dst, void *src, int size, int win_id, - int wg_id, MPI_Comm team, int PE_root, - ro_net_types type, int threadId, - bool blocking) = 0; + int wg_id, MPI_Comm team, int PE_root, + ro_net_types type, volatile char *status, + bool blocking) = 0; virtual void alltoall(void *dst, void *src, int size, int win_id, int wg_id, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) = 0; + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) = 0; virtual void fcollect(void *dst, void *src, int size, int win_id, int wg_id, - MPI_Comm team, void *ata_buffptr, ro_net_types type, - int threadId, bool blocking) = 0; + MPI_Comm team, void *ata_buffptr, ro_net_types type, + volatile char *status, bool blocking) = 0; virtual void putMem(void *dst, void *src, int size, int pe, int win_id, - int wg_id, int threadId, bool blocking, - bool inline_data = false) = 0; + int wg_id, volatile char *status, bool blocking, + bool inline_data = false) = 0; virtual void getMem(void *dst, void *src, int size, int pe, int win_id, - int wg_id, int threadId, bool blocking) = 0; + int wg_id, volatile char *status, bool blocking) = 0; virtual void amoFOP(void *dst, void *src, void *val, int pe, int win_id, - int wg_id, int threadId, bool blocking, ROCSHMEM_OP op, - ro_net_types type) = 0; + int wg_id, volatile char *status, bool blocking, + ROCSHMEM_OP op, ro_net_types type) = 0; virtual void amoFCAS(void *dst, void *src, void *val, int pe, int win_id, - int wg_id, int threadId, bool blocking, void *cond, - ro_net_types type) = 0; + int wg_id, volatile char *status, bool blocking, + void *cond, ro_net_types type) = 0; virtual bool readyForFinalize() = 0; - virtual void quiet(int wg_id, int threadId) = 0; + virtual void quiet(int wg_id, volatile char *status) = 0; virtual void progress() = 0;