Refactor RO backend data structures (#49)
- Remove hdp and ipc pointers from BlockHandle, align RO stats with RO contexts - Add run commands for `rocshmem_g` and `rocshmem_p` API tests in driver.sh - Allocate rocshmem API return buffers based on number of device contexts. - Associate status flag address with blocking calls and remove threadId dependency - Associated the status flag address with each blocking call request to notify the GPU thread. - Removed dependency on threadId for determining the appropriate status flag index. - Move status flag buffer allocation to backend. - Initialize allocated memeory to zero
Bu işleme şunda yer alıyor:
işlemeyi yapan:
GitHub
ebeveyn
96424a59a8
işleme
df4ad2c04d
@@ -179,6 +179,20 @@ TestRMA() {
|
||||
|
||||
ExecTest "teamctxget" 2 1 1 1048576
|
||||
|
||||
ExecTest "g" 2 1 1 1048576
|
||||
ExecTest "g" 2 1 1024 512
|
||||
ExecTest "g" 2 8 1 1048576
|
||||
ExecTest "g" 2 16 128 8
|
||||
ExecTest "g" 2 32 256 512
|
||||
ExecTest "g" 2 64 1024 8
|
||||
|
||||
ExecTest "p" 2 1 1 1048576
|
||||
ExecTest "p" 2 1 1024 512
|
||||
ExecTest "p" 2 8 1 1048576
|
||||
ExecTest "p" 2 16 128 8
|
||||
ExecTest "p" 2 32 256 512
|
||||
ExecTest "p" 2 64 1024 8
|
||||
|
||||
################################ Non-Blocking ################################
|
||||
|
||||
ExecTest "putnbi" 2 1 1 1048576
|
||||
|
||||
@@ -51,7 +51,7 @@ class DeviceProxy {
|
||||
/*
|
||||
* Default memory provided by the allocation to recognizable bytes.
|
||||
*/
|
||||
memset(static_cast<void*>(temp), 0xBC, size_bytes);
|
||||
memset(static_cast<void*>(temp), 0, size_bytes);
|
||||
|
||||
/*
|
||||
* Pass the memory into a unique ptr for tracking.
|
||||
|
||||
@@ -36,10 +36,8 @@ struct BackendRegister {
|
||||
std::atomic<bool> worker_thread_exit{false};
|
||||
bool *needs_quiet{nullptr};
|
||||
bool *needs_blocking{nullptr};
|
||||
char *g_ret{nullptr};
|
||||
HdpPolicy *hdp_policy{nullptr};
|
||||
WindowInfo **heap_window_info{nullptr};
|
||||
atomic_ret_t *atomic_ret{nullptr};
|
||||
SymmetricHeap *heap_ptr{nullptr};
|
||||
};
|
||||
|
||||
|
||||
@@ -64,7 +64,15 @@ ROBackend::ROBackend(MPI_Comm comm)
|
||||
|
||||
max_wg_size_ = device_props.maxThreadsPerBlock;
|
||||
|
||||
queue_ = Queue(maximum_num_contexts_, max_wg_size_, queue_size_);
|
||||
size_t num_buff_elems = maximum_num_contexts_ * max_wg_size_;
|
||||
|
||||
g_ret_buffer_ = RetBufferProxyT(num_buff_elems);
|
||||
|
||||
atomic_ret_buffer_ = RetBufferProxyT(num_buff_elems);
|
||||
|
||||
status_ = StatusProxyT(num_buff_elems);
|
||||
|
||||
queue_ = Queue(maximum_num_contexts_, queue_size_);
|
||||
|
||||
transport_ = new MPITransport(comm, &queue_);
|
||||
num_pes = transport_->getNumPes();
|
||||
@@ -87,10 +95,6 @@ ROBackend::ROBackend(MPI_Comm comm)
|
||||
|
||||
initIPC();
|
||||
|
||||
init_g_ret(&heap, transport_->get_world_comm(), maximum_num_contexts_, &bp->g_ret);
|
||||
|
||||
allocate_atomic_region(&bp->atomic_ret, maximum_num_contexts_);
|
||||
|
||||
transport_->initTransport(maximum_num_contexts_, &backend_proxy);
|
||||
|
||||
host_interface = transport_->host_interface;
|
||||
@@ -107,14 +111,17 @@ ROBackend::ROBackend(MPI_Comm comm)
|
||||
reinterpret_cast<rocshmem_team_t>(team_world_proxy_->get());
|
||||
|
||||
default_block_handle_proxy_ = DefaultBlockHandleProxyT(
|
||||
bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get());
|
||||
g_ret_buffer_.get(),
|
||||
atomic_ret_buffer_.get(), &queue_,
|
||||
status_.get());
|
||||
|
||||
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
|
||||
|
||||
default_context_proxy_ = DefaultContextProxyT(this, tinfo);
|
||||
|
||||
block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_,
|
||||
&ipcImpl, hdp_proxy_.get(),
|
||||
maximum_num_contexts_);
|
||||
block_handle_proxy_ = BlockHandleProxyT(g_ret_buffer_.get(),
|
||||
atomic_ret_buffer_.get(), &queue_,
|
||||
max_wg_size_, status_.get(), maximum_num_contexts_);
|
||||
setup_ctxs();
|
||||
|
||||
worker_thread = std::thread(&ROBackend::ro_net_poll, this);
|
||||
@@ -187,7 +194,7 @@ void ROBackend::ctx_destroy(Context *ctx) {
|
||||
void ROBackend::reset_backend_stats() {
|
||||
auto *bp{backend_proxy.get()};
|
||||
|
||||
for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) {
|
||||
for (size_t i{0}; i < maximum_num_contexts_; i++) {
|
||||
bp->profiler[i].resetStats();
|
||||
}
|
||||
}
|
||||
@@ -208,7 +215,7 @@ void ROBackend::dump_backend_stats() {
|
||||
|
||||
auto *bp{backend_proxy.get()};
|
||||
|
||||
for (size_t i{0}; i < MAX_NUM_BLOCKS; i++) {
|
||||
for (size_t i{0}; i < maximum_num_contexts_; i++) {
|
||||
// Average latency as perceived from a thread
|
||||
const ROStats &prof{bp->profiler[i]};
|
||||
us_wait_slot += prof.getStat(WAITING_ON_SLOT) / gpu_frequency_mhz;
|
||||
|
||||
@@ -55,7 +55,9 @@ class ROHostContext;
|
||||
* the host (which is an inversion of the normal behavior).
|
||||
*/
|
||||
class ROBackend : public Backend {
|
||||
const unsigned MAX_NUM_BLOCKS{65536};
|
||||
using RetBufferProxyT = DeviceProxy<HIPAllocator, uint64_t>;
|
||||
using StatusProxyT =
|
||||
DeviceProxy<HIPDefaultFinegrainedAllocator, char>;
|
||||
|
||||
public:
|
||||
/**
|
||||
@@ -270,6 +272,25 @@ class ROBackend : public Backend {
|
||||
* @brief Number of MPI windows used for device contexts in RO Backend
|
||||
*/
|
||||
size_t num_windows_{32};
|
||||
|
||||
/**
|
||||
* @brief Return buffer for rocshmem_g API
|
||||
*/
|
||||
RetBufferProxyT g_ret_buffer_;
|
||||
|
||||
/**
|
||||
* @brief Return buffer for rocshmem atomic return APIs
|
||||
*/
|
||||
RetBufferProxyT atomic_ret_buffer_;
|
||||
|
||||
/**
|
||||
* This buffer is used by the GPU to wait on a blocking operation. The initial
|
||||
* value is 0. When a GPU enqueues a blocking operation, it waits for this
|
||||
* value to resolve to 1, which is set by the CPU when the blocking
|
||||
* operation completes. The GPU then resets status back to zero. There is
|
||||
* a separate status variable for each work-item in a RO Context
|
||||
*/
|
||||
StatusProxyT status_;
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -38,10 +38,8 @@ struct BlockHandle {
|
||||
volatile uint64_t write_index{};
|
||||
volatile uint64_t *host_read_index{};
|
||||
volatile char *status{nullptr};
|
||||
char *g_ret{nullptr};
|
||||
atomic_ret_t atomic_ret{};
|
||||
IpcImpl ipc{};
|
||||
HdpPolicy *hdp{};
|
||||
void *g_ret{nullptr};
|
||||
void *atomic_ret{nullptr};
|
||||
volatile uint64_t lock{};
|
||||
};
|
||||
|
||||
@@ -52,9 +50,8 @@ class DefaultBlockHandleProxy {
|
||||
public:
|
||||
DefaultBlockHandleProxy() = default;
|
||||
|
||||
DefaultBlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue,
|
||||
IpcImpl *ipc_policy, HdpPolicy *hdp_policy,
|
||||
size_t num_elems = 1)
|
||||
DefaultBlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue,
|
||||
volatile char *status, size_t num_elems = 1)
|
||||
: proxy_{num_elems} {
|
||||
|
||||
// TODO(bpotter): create a default queue for this queue descriptor
|
||||
@@ -66,13 +63,9 @@ class DefaultBlockHandleProxy {
|
||||
block_handle->read_index = queue_descriptor->read_index;
|
||||
block_handle->write_index = queue_descriptor->write_index;
|
||||
block_handle->host_read_index = &queue_descriptor->read_index;
|
||||
block_handle->status = queue_descriptor->status;
|
||||
block_handle->status = status;
|
||||
block_handle->g_ret = g_ret;
|
||||
block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr;
|
||||
block_handle->atomic_ret.atomic_counter = 0;
|
||||
block_handle->ipc.ipc_bases = ipc_policy->ipc_bases;
|
||||
block_handle->ipc.shm_size = ipc_policy->shm_size;
|
||||
block_handle->hdp = hdp_policy;
|
||||
block_handle->atomic_ret = atomic_ret;
|
||||
block_handle->lock = 0;
|
||||
}
|
||||
|
||||
@@ -99,27 +92,24 @@ class BlockHandleProxy {
|
||||
public:
|
||||
BlockHandleProxy() = default;
|
||||
|
||||
BlockHandleProxy(char *g_ret, atomic_ret_t *atomic_ret, Queue *queue,
|
||||
IpcImpl *ipc_policy, HdpPolicy *hdp_policy,
|
||||
size_t max_blocks)
|
||||
BlockHandleProxy(void *g_ret, void *atomic_ret, Queue *queue, size_t offset,
|
||||
volatile char *status, size_t max_blocks)
|
||||
: proxy_{max_blocks} {
|
||||
|
||||
for (size_t i{0}; i < max_blocks; i++) {
|
||||
auto queue_descriptor{queue->descriptor(i)};
|
||||
auto block_handle{&proxy_.get()[i]};
|
||||
size_t block_offset{i * offset};
|
||||
block_handle->profiler.resetStats();
|
||||
block_handle->queue = queue->elements(i);
|
||||
block_handle->queue_size = queue->size();
|
||||
block_handle->read_index = queue_descriptor->read_index;
|
||||
block_handle->write_index = queue_descriptor->write_index;
|
||||
block_handle->host_read_index = &queue_descriptor->read_index;
|
||||
block_handle->status = queue_descriptor->status;
|
||||
block_handle->g_ret = g_ret;
|
||||
block_handle->atomic_ret.atomic_base_ptr = atomic_ret->atomic_base_ptr;
|
||||
block_handle->atomic_ret.atomic_counter = 0;
|
||||
block_handle->ipc.ipc_bases = ipc_policy->ipc_bases;
|
||||
block_handle->ipc.shm_size = ipc_policy->shm_size;
|
||||
block_handle->hdp = hdp_policy;
|
||||
block_handle->status = status + block_offset;
|
||||
block_handle->g_ret = reinterpret_cast<uint64_t*>(g_ret) + block_offset;
|
||||
block_handle->atomic_ret = reinterpret_cast<uint64_t*>(atomic_ret) +
|
||||
block_offset;
|
||||
block_handle->lock = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ __device__ void ROContext::putmem(void *dest, const void *source, size_t nelems,
|
||||
}
|
||||
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ __device__ void ROContext::getmem(void *dest, const void *source, size_t nelems,
|
||||
}
|
||||
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,18 +134,21 @@ __device__ void ROContext::getmem_nbi(void *dest, const void *source,
|
||||
|
||||
__device__ void ROContext::fence() {
|
||||
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
|
||||
true, get_status_flag());
|
||||
}
|
||||
|
||||
__device__ void ROContext::fence(int pe) {
|
||||
// TODO(khamidou): need to check if per pe has any special handling
|
||||
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
|
||||
true, get_status_flag());
|
||||
}
|
||||
|
||||
__device__ void ROContext::quiet() {
|
||||
build_queue_element(RO_NET_QUIET, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true);
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
|
||||
true, get_status_flag());
|
||||
}
|
||||
|
||||
__device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
|
||||
@@ -163,7 +166,7 @@ __device__ void ROContext::barrier_all() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true);
|
||||
block_handle, true, get_status_flag());
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -172,7 +175,7 @@ __device__ void ROContext::sync_all() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_BARRIER_ALL, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true);
|
||||
block_handle, true, get_status_flag());
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -182,7 +185,7 @@ __device__ void ROContext::sync(rocshmem_team_t team) {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
|
||||
true);
|
||||
true, get_status_flag());
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -195,7 +198,7 @@ __device__ void ROContext::ctx_destroy() {
|
||||
|
||||
build_queue_element(RO_NET_FINALIZE, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true);
|
||||
block_handle, true, get_status_flag());
|
||||
|
||||
int buffer_id = ro_net_win_id;
|
||||
backend->queue_.descriptor(buffer_id)->write_index = block_handle->write_index;
|
||||
@@ -219,7 +222,7 @@ __device__ void ROContext::putmem_wg(void *dest, const void *source,
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -237,7 +240,7 @@ __device__ void ROContext::getmem_wg(void *dest, const void *source,
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -291,7 +294,7 @@ __device__ void ROContext::putmem_wave(void *dest, const void *source,
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,7 +312,7 @@ __device__ void ROContext::getmem_wave(void *dest, const void *source,
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true);
|
||||
ro_net_win_id, block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -578,7 +581,8 @@ __device__ void build_queue_element(
|
||||
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
|
||||
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
|
||||
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
|
||||
bool blocking, ROCSHMEM_OP op, ro_net_types datatype) {
|
||||
bool blocking, volatile char *status, ROCSHMEM_OP op,
|
||||
ro_net_types datatype) {
|
||||
auto write_slot{next_write_slot(handle)};
|
||||
auto queue_element = &handle->queue[write_slot];
|
||||
|
||||
@@ -587,6 +591,9 @@ __device__ void build_queue_element(
|
||||
queue_element->ol1.size = size;
|
||||
queue_element->dst = dst;
|
||||
queue_element->ro_net_win_id = ro_net_win_id;
|
||||
if(blocking) {
|
||||
queue_element->status = status;
|
||||
}
|
||||
|
||||
if (type == RO_NET_P) {
|
||||
memcpy(&queue_element->src, src, size);
|
||||
@@ -594,9 +601,6 @@ __device__ void build_queue_element(
|
||||
queue_element->src = src;
|
||||
}
|
||||
|
||||
auto threadId {get_flat_id()};
|
||||
queue_element->threadId = threadId;
|
||||
|
||||
if (type == RO_NET_AMO_FOP) {
|
||||
queue_element->op = op;
|
||||
queue_element->datatype = datatype;
|
||||
@@ -655,19 +659,31 @@ __device__ void build_queue_element(
|
||||
if (blocking) {
|
||||
int network_status{0};
|
||||
do {
|
||||
refresh_volatile_sbyte(&network_status, &handle->status[threadId]);
|
||||
refresh_volatile_sbyte(&network_status, queue_element->status);
|
||||
} while (network_status == 0);
|
||||
|
||||
handle->status[threadId] = 0;
|
||||
*(queue_element->status) = 0;
|
||||
__threadfence();
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint64_t *ROContext::get_unused_atomic() {
|
||||
auto index{atomicAdd(&block_handle->atomic_ret.atomic_counter, 1)};
|
||||
index = index % max_nb_atomic;
|
||||
auto atomic_base_ptr{block_handle->atomic_ret.atomic_base_ptr};
|
||||
return &atomic_base_ptr[index];
|
||||
__device__ uint64_t *ROContext::get_atomic_ret_buf() {
|
||||
uint64_t *atomic_base_ptr{
|
||||
reinterpret_cast<uint64_t*>(block_handle->atomic_ret)};
|
||||
int thread_id{get_flat_block_id()};
|
||||
return &atomic_base_ptr[thread_id];
|
||||
}
|
||||
|
||||
__device__ uint64_t *ROContext::get_g_ret_buf() {
|
||||
uint64_t *g_ret{reinterpret_cast<uint64_t*>(block_handle->g_ret)};
|
||||
int thread_id{get_flat_block_id()};
|
||||
return &g_ret[thread_id];
|
||||
}
|
||||
|
||||
__device__ volatile char *ROContext::get_status_flag() {
|
||||
volatile char* status{block_handle->status};
|
||||
int thread_id{get_flat_block_id()};
|
||||
return &status[thread_id];
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -34,7 +34,7 @@ __device__ void build_queue_element(
|
||||
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
|
||||
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
|
||||
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
|
||||
bool blocking, ROCSHMEM_OP op = ROCSHMEM_SUM,
|
||||
bool blocking, volatile char *status = nullptr, ROCSHMEM_OP op = ROCSHMEM_SUM,
|
||||
ro_net_types datatype = RO_NET_INT);
|
||||
|
||||
class ROContext : public Context {
|
||||
@@ -251,7 +251,11 @@ class ROContext : public Context {
|
||||
__device__ uint64_t signal_fetch_wave(const uint64_t *sig_addr);
|
||||
|
||||
private:
|
||||
__device__ uint64_t *get_unused_atomic();
|
||||
__device__ uint64_t *get_atomic_ret_buf();
|
||||
|
||||
__device__ uint64_t *get_g_ret_buf();
|
||||
|
||||
__device__ volatile char *get_status_flag();
|
||||
|
||||
BlockHandle *block_handle{nullptr};
|
||||
|
||||
|
||||
@@ -120,7 +120,8 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
|
||||
|
||||
build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast<T *>(source),
|
||||
nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm,
|
||||
ro_net_win_id, block_handle, true, Op, GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
Op, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
return ROCSHMEM_SUCCESS;
|
||||
@@ -137,8 +138,8 @@ __device__ void ROContext::to_all(T *dest, const T *source, int nreduce,
|
||||
|
||||
build_queue_element(RO_NET_TO_ALL, dest, const_cast<T *>(source), nreduce,
|
||||
PE_start, logPE_stride, PE_size, 0, pWrk, pSync,
|
||||
(MPI_Comm)NULL, ro_net_win_id, block_handle, true, Op,
|
||||
GetROType<T>::Type);
|
||||
(MPI_Comm)NULL, ro_net_win_id, block_handle, true,
|
||||
get_status_flag(), Op, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -166,7 +167,7 @@ __device__ void ROContext::p(T *dest, T value, int pe) {
|
||||
} else {
|
||||
build_queue_element(RO_NET_P, dest, &value, sizeof(T), pe, 0, 0, 0, nullptr,
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true);
|
||||
block_handle, true, get_status_flag());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,12 +180,7 @@ __device__ T ROContext::g(const T *source, int pe) {
|
||||
ipcImpl_.ipcCopy(&dest, ipcImpl_.ipc_bases[pe] + L_offset, sizeof(T));
|
||||
return dest;
|
||||
} else {
|
||||
int thread_id{get_flat_block_id()};
|
||||
int block_size{get_flat_block_size()};
|
||||
int offset{get_flat_grid_id() * block_size + thread_id};
|
||||
|
||||
char *base_dest{block_handle->g_ret};
|
||||
char *dest{&base_dest[offset * sizeof(int64_t)]};
|
||||
auto dest{get_g_ret_buf()};
|
||||
get<T>(reinterpret_cast<T *>(dest), source, 1, pe);
|
||||
return *(reinterpret_cast<T *>(dest));
|
||||
}
|
||||
@@ -206,12 +202,13 @@ __device__ void ROContext::get_nbi(T *dest, const T *source, size_t nelems,
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_fetch_cas(void *dst, T value, T cond, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FCAS, dst, reinterpret_cast<T *>(source),
|
||||
value, pe, 0, 0, 0,
|
||||
reinterpret_cast<void *>(static_cast<long long>(cond)),
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
|
||||
true, get_status_flag(), ROCSHMEM_SUM,
|
||||
GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -223,11 +220,11 @@ __device__ void ROContext::amo_cas(void *dst, T value, T cond, int pe) {
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_fetch_add(void *dst, T value, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<T *>(source), value,
|
||||
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true, ROCSHMEM_SUM,
|
||||
GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -239,11 +236,11 @@ __device__ void ROContext::amo_add(void *dst, T value, int pe) {
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_swap(void *dst, T value, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
|
||||
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true, ROCSHMEM_REPLACE,
|
||||
GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
ROCSHMEM_REPLACE, GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -255,11 +252,11 @@ __device__ void ROContext::amo_set(void *dst, T value, int pe) {
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_fetch_and(void *dst, T value, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
|
||||
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true, ROCSHMEM_AND,
|
||||
GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
ROCSHMEM_AND, GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -271,11 +268,11 @@ __device__ void ROContext::amo_and(void *dst, T value, int pe) {
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_fetch_or(void *dst, T value, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
|
||||
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true, ROCSHMEM_OR,
|
||||
GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
ROCSHMEM_OR, GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -287,11 +284,11 @@ __device__ void ROContext::amo_or(void *dst, T value, int pe) {
|
||||
|
||||
template <typename T>
|
||||
__device__ T ROContext::amo_fetch_xor(void *dst, T value, int pe) {
|
||||
auto source{get_unused_atomic()};
|
||||
auto source{get_atomic_ret_buf()};
|
||||
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
|
||||
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
|
||||
ro_net_win_id, block_handle, true, ROCSHMEM_XOR,
|
||||
GetROType<T>::Type);
|
||||
ro_net_win_id, block_handle, true, get_status_flag(),
|
||||
ROCSHMEM_XOR, GetROType<T>::Type);
|
||||
__threadfence();
|
||||
return *source;
|
||||
}
|
||||
@@ -314,7 +311,7 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest,
|
||||
build_queue_element(RO_NET_TEAM_BROADCAST, dest, const_cast<T *>(source),
|
||||
nelems, 0, 0, 0, pe_root, nullptr, nullptr,
|
||||
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -332,7 +329,7 @@ __device__ void ROContext::broadcast(T *dest, const T *source, int nelems,
|
||||
build_queue_element(RO_NET_BROADCAST, dest, const_cast<T *>(source), nelems,
|
||||
pe_start, log_pe_stride, pe_size, pe_root, nullptr,
|
||||
p_sync, (MPI_Comm)NULL, ro_net_win_id, block_handle, true,
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -350,7 +347,7 @@ __device__ void ROContext::alltoall(rocshmem_team_t team, T *dest,
|
||||
build_queue_element(RO_NET_ALLTOALL, dest, const_cast<T *>(source), nelems, 0,
|
||||
0, 0, 0, team_obj->ata_buffer, nullptr,
|
||||
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -368,7 +365,7 @@ __device__ void ROContext::fcollect(rocshmem_team_t team, T *dest,
|
||||
build_queue_element(RO_NET_FCOLLECT, dest, const_cast<T *>(source), nelems, 0,
|
||||
0, 0, 0, team_obj->ata_buffer, nullptr,
|
||||
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
|
||||
ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
get_status_flag(), ROCSHMEM_SUM, GetROType<T>::Type);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
case RO_NET_PUT:
|
||||
putMem(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF("Received PUT dst %p src %p size %lu pe %d win_id %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id);
|
||||
@@ -109,7 +109,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
|
||||
putMem(next_element.dst, source_buffer, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, true, true);
|
||||
next_element.status, true, true);
|
||||
DPRINTF("Received P dst %p value %p pe %d\n", next_element.dst,
|
||||
next_element.src, next_element.PE);
|
||||
break;
|
||||
@@ -117,14 +117,14 @@ void MPITransport::submitRequestsToMPI() {
|
||||
case RO_NET_GET:
|
||||
getMem(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF("Received GET dst %p src %p size %lu pe %d\n", next_element.dst,
|
||||
next_element.src, next_element.ol1.size, next_element.PE);
|
||||
break;
|
||||
case RO_NET_PUT_NBI:
|
||||
putMem(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, false);
|
||||
next_element.status, false);
|
||||
DPRINTF("Received PUT NBI dst %p src %p size %lu pe %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE);
|
||||
@@ -132,7 +132,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
case RO_NET_GET_NBI:
|
||||
getMem(next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, false);
|
||||
next_element.status, false);
|
||||
DPRINTF("Received GET NBI dst %p src %p size %lu pe %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.PE);
|
||||
@@ -141,7 +141,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
amoFOP(next_element.dst, next_element.src,
|
||||
const_cast<unsigned long long *>(&next_element.ol1.atomic_value),
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, true,
|
||||
next_element.status, true,
|
||||
static_cast<ROCSHMEM_OP>(next_element.op),
|
||||
static_cast<ro_net_types>(next_element.datatype));
|
||||
DPRINTF("Received AMO dst %p src %p Val %llu pe %d\n", next_element.dst,
|
||||
@@ -151,7 +151,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
amoFCAS(next_element.dst, next_element.src,
|
||||
const_cast<unsigned long long *>(&next_element.ol1.atomic_value),
|
||||
next_element.PE, next_element.ro_net_win_id, queue_idx,
|
||||
next_element.threadId, true,
|
||||
next_element.status, true,
|
||||
const_cast<void **>(&next_element.ol2.pWrk),
|
||||
static_cast<ro_net_types>(next_element.datatype));
|
||||
DPRINTF("Received F_CSWAP dst %p src %p Val %llu pe %d cond %ld\n",
|
||||
@@ -165,7 +165,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.team_comm,
|
||||
static_cast<ROCSHMEM_OP>(next_element.op),
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF("Received FLOAT_SUM_TEAM_REDUCE dst %p src %p size %lu team %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
@@ -177,7 +177,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.PE_size, next_element.ol2.pWrk, next_element.pSync,
|
||||
static_cast<ROCSHMEM_OP>(next_element.op),
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF(
|
||||
"Received FLOAT_SUM_TO_ALL dst %p src %p size %lu "
|
||||
"PE_start %d, logPE_stride %d, PE_size %d, pWrk %p, pSync %p\n",
|
||||
@@ -190,7 +190,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.ro_net_win_id, queue_idx,
|
||||
next_element.team_comm, next_element.PE_root,
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF(
|
||||
"Received TEAM_BROADCAST dst %p src %p size %lu "
|
||||
"team %d, PE_root %d \n",
|
||||
@@ -203,7 +203,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.PE, next_element.logPE_stride,
|
||||
next_element.PE_size, next_element.PE_root, next_element.pSync,
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF(
|
||||
"Received BROADCAST dst %p src %p size %lu PE_start %d, "
|
||||
"logPE_stride %d, PE_size %d, PE_root %d, pSync %p\n",
|
||||
@@ -216,7 +216,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
|
||||
next_element.ol2.pWrk,
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF("Received ALLTOALL dst %p src %p size %lu team %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
@@ -226,26 +226,26 @@ void MPITransport::submitRequestsToMPI() {
|
||||
next_element.ro_net_win_id, queue_idx, next_element.team_comm,
|
||||
next_element.ol2.pWrk,
|
||||
static_cast<ro_net_types>(next_element.datatype),
|
||||
next_element.threadId, true);
|
||||
next_element.status, true);
|
||||
DPRINTF("Received FCOLLECT dst %p src %p size %lu team %d\n",
|
||||
next_element.dst, next_element.src, next_element.ol1.size,
|
||||
next_element.team_comm);
|
||||
break;
|
||||
case RO_NET_BARRIER_ALL:
|
||||
barrier(queue_idx, next_element.threadId, true, ro_net_comm_world);
|
||||
barrier(queue_idx, next_element.status, true, ro_net_comm_world);
|
||||
DPRINTF("Received Barrier_all\n");
|
||||
break;
|
||||
case RO_NET_SYNC:
|
||||
barrier(queue_idx, next_element.threadId, true, next_element.team_comm);
|
||||
barrier(queue_idx, next_element.status, true, next_element.team_comm);
|
||||
DPRINTF("Received Sync\n");
|
||||
break;
|
||||
case RO_NET_FENCE:
|
||||
case RO_NET_QUIET:
|
||||
quiet(queue_idx, next_element.threadId);
|
||||
quiet(queue_idx, next_element.status);
|
||||
DPRINTF("Received FENCE/QUIET\n");
|
||||
break;
|
||||
case RO_NET_FINALIZE:
|
||||
quiet(queue_idx, next_element.threadId);
|
||||
quiet(queue_idx, next_element.status);
|
||||
DPRINTF("Received Finalize\n");
|
||||
break;
|
||||
default:
|
||||
@@ -256,7 +256,7 @@ void MPITransport::submitRequestsToMPI() {
|
||||
}
|
||||
|
||||
void MPITransport::initTransport(int num_queues, BackendProxyT *proxy) {
|
||||
waiting_quiet.resize(num_queues, std::vector<int>());
|
||||
waiting_quiet.resize(num_queues, std::vector<volatile char *>());
|
||||
outstanding.resize(num_queues, 0);
|
||||
transport_up = false;
|
||||
|
||||
@@ -333,13 +333,13 @@ void MPITransport::global_exit(int status) {
|
||||
MPI_Abort(ro_net_comm_world, status);
|
||||
}
|
||||
|
||||
void MPITransport::barrier(int blockId, int threadId, bool blocking,
|
||||
MPI_Comm team) {
|
||||
void MPITransport::barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team) {
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Ibarrier(team, &request));
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
outstanding[blockId]++;
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
MPI_Op MPITransport::get_mpi_op(ROCSHMEM_OP op) {
|
||||
@@ -397,10 +397,10 @@ static MPI_Datatype convertType(ro_net_types type) {
|
||||
}
|
||||
|
||||
void MPITransport::reduction(void *dst, void *src, int size, int pe,
|
||||
int win_id, int blockId, int start, int logPstride,
|
||||
int sizePE, void *pWrk, long *pSync,
|
||||
ROCSHMEM_OP op, ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int win_id, int contextId, int start,
|
||||
int logPstride, int sizePE, void *pWrk,
|
||||
long *pSync, ROCSHMEM_OP op, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
MPI_Request request{};
|
||||
MPI_Op mpi_op{get_mpi_op(op)};
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
@@ -413,14 +413,15 @@ void MPITransport::reduction(void *dst, void *src, int size, int pe,
|
||||
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
|
||||
}
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
outstanding[blockId]++;
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::broadcast(void *dst, void *src, int size, int pe,
|
||||
int win_id, int blockId, int start, int logPstride,
|
||||
int sizePE, int root, long *pSync,
|
||||
ro_net_types type, int threadId, bool blocking) {
|
||||
int win_id, int contextId, int start,
|
||||
int logPstride, int sizePE, int root, long *pSync,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
MPI_Comm comm{createComm(start, 1 << logPstride, sizePE)};
|
||||
|
||||
int new_rank{};
|
||||
@@ -437,15 +438,15 @@ void MPITransport::broadcast(void *dst, void *src, int size, int pe,
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
outstanding[blockId]++;
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char* status,
|
||||
bool blocking) {
|
||||
MPI_Request request{};
|
||||
|
||||
MPI_Op mpi_op{get_mpi_op(op)};
|
||||
@@ -459,15 +460,15 @@ void MPITransport::team_reduction(void *dst, void *src, int size, int win_id,
|
||||
NET_CHECK(MPI_Iallreduce(src, dst, size, mpi_type, mpi_op, comm, &request));
|
||||
}
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
outstanding[blockId]++;
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, int root,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team, int root,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
MPI_Comm comm{team};
|
||||
int new_rank{};
|
||||
MPI_Comm_rank(comm, &new_rank);
|
||||
@@ -482,14 +483,15 @@ void MPITransport::team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
MPI_Request request{};
|
||||
NET_CHECK(MPI_Ibcast(data, size, mpi_type, root, comm, &request));
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
outstanding[blockId]++;
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId, bool blocking) {
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int pe_size{};
|
||||
NET_CHECK(MPI_Comm_size(team, &pe_size));
|
||||
|
||||
@@ -502,24 +504,24 @@ void MPITransport::alltoall(void *dst, void *src, int size, int win_id,
|
||||
#ifdef A2A_HEURISTICS
|
||||
if ((pe_size >= 8 || type_size * size < 2048) &&
|
||||
num_clust * clust_size == pe_size) {
|
||||
return alltoall_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type,
|
||||
threadId, blocking);
|
||||
return alltoall_gcen(dst, src, size, win_id, contextId, team, ata_buffptr,
|
||||
type, status, blocking);
|
||||
} else if (size <= 512) {
|
||||
#endif // A2A_HEURISTICS
|
||||
return alltoall_mpi(dst, src, size, blockId, team, ata_buffptr, type,
|
||||
threadId, blocking);
|
||||
return alltoall_mpi(dst, src, size, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
#ifdef A2A_HEURISTICS
|
||||
} else {
|
||||
return alltoall_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr,
|
||||
type, threadId, blocking);
|
||||
return alltoall_broadcast(dst, src, size, win_id, contextId, team,
|
||||
ata_buffptr, type, status, blocking);
|
||||
}
|
||||
#endif // A2A_HEURISTICS
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
|
||||
int win_id, int blockId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) {
|
||||
int win_id, int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
MPI_Comm comm{team};
|
||||
@@ -563,26 +565,26 @@ void MPITransport::alltoall_broadcast(void *dst, void *src, int size,
|
||||
NET_CHECK(MPI_Waitall(pe_size, pe_req.data(), MPI_STATUSES_IGNORE));
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
barrier(blockId, threadId, blocking, comm);
|
||||
barrier(contextId, status, blocking, comm);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
void MPITransport::alltoall_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int new_rank{};
|
||||
NET_CHECK(MPI_Comm_rank(team, &new_rank));
|
||||
int pe_size{};
|
||||
NET_CHECK(MPI_Comm_size(team, &pe_size));
|
||||
MPI_Datatype mpi_type{convertType(type)};
|
||||
NET_CHECK(MPI_Alltoall(src, size, mpi_type, dst, size, mpi_type, team));
|
||||
quiet(blockId, threadId);
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
int new_rank{};
|
||||
@@ -664,14 +666,14 @@ void MPITransport::alltoall_gcen(void *dst, void *src, int size, int win_id,
|
||||
MPI_Comm comm_ring{createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust)};
|
||||
|
||||
barrier(blockId, threadId, false, comm_cluster);
|
||||
barrier(blockId, threadId, blocking, comm_ring);
|
||||
barrier(contextId, status, false, comm_cluster);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric alltoall with in-place blocking synchronization
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
@@ -759,12 +761,13 @@ void MPITransport::alltoall_gcen2(void *dst, void *src, int size, int win_id,
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(blockId, threadId, blocking, comm_ring);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId, bool blocking) {
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
int pe_size, type_size;
|
||||
MPI_Comm comm = team;
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
@@ -780,21 +783,21 @@ void MPITransport::fcollect(void *dst, void *src, int size, int win_id,
|
||||
// In most cases the MPI implementation is optimal
|
||||
// But it crashes for > 512 messages
|
||||
if (size <= 512) {
|
||||
fcollect_mpi(dst, src, size, blockId, team, ata_buffptr, type,
|
||||
threadId, blocking);
|
||||
fcollect_mpi(dst, src, size, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
} else if (num_clust * clust_size == pe_size) {
|
||||
fcollect_gcen(dst, src, size, win_id, blockId, team, ata_buffptr, type,
|
||||
threadId, blocking);
|
||||
fcollect_gcen(dst, src, size, win_id, contextId, team, ata_buffptr, type,
|
||||
status, blocking);
|
||||
} else {
|
||||
fcollect_broadcast(dst, src, size, win_id, blockId, team, ata_buffptr,
|
||||
type, threadId, blocking);
|
||||
fcollect_broadcast(dst, src, size, win_id, contextId, team, ata_buffptr,
|
||||
type, status, blocking);
|
||||
}
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
|
||||
int win_id, int blockId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) {
|
||||
int win_id, int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// Broadcast implementation of fcollect
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
@@ -837,13 +840,13 @@ void MPITransport::fcollect_broadcast(void *dst, void *src, int size,
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
// Now wait for completion
|
||||
barrier(blockId, threadId, blocking, comm);
|
||||
barrier(contextId, status, blocking, comm);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
void MPITransport::fcollect_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) {
|
||||
// MPI's implementation of fcollect
|
||||
int new_rank, pe_size;
|
||||
|
||||
@@ -852,13 +855,13 @@ void MPITransport::fcollect_mpi(void *dst, void *src, int size, int blockId,
|
||||
NET_CHECK(MPI_Comm_rank(comm, &new_rank));
|
||||
NET_CHECK(MPI_Comm_size(comm, &pe_size));
|
||||
NET_CHECK(MPI_Allgather(src, size, mpi_type, dst, size, mpi_type, comm));
|
||||
quiet(blockId, threadId);
|
||||
quiet(contextId, status);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric implementation of fcollect
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
@@ -938,14 +941,14 @@ void MPITransport::fcollect_gcen(void *dst, void *src, int size, int win_id,
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(blockId, threadId, false, comm_cluster);
|
||||
barrier(blockId, threadId, blocking, comm_ring);
|
||||
barrier(contextId, status, false, comm_cluster);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) {
|
||||
int contextId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) {
|
||||
// GPU-centric implementation with in-place, blocking synchronization
|
||||
auto *bp{backend_proxy->get()};
|
||||
int new_rank, pe_size;
|
||||
@@ -1027,12 +1030,12 @@ void MPITransport::fcollect_gcen2(void *dst, void *src, int size, int win_id,
|
||||
MPI_Comm comm_ring = createComm(world_ranks[new_rank % clust_size],
|
||||
stride * clust_size, num_clust);
|
||||
// Now wait for completion
|
||||
barrier(blockId, threadId, blocking, comm_ring);
|
||||
barrier(contextId, status, blocking, comm_ring);
|
||||
}
|
||||
|
||||
void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int blockId, int threadId, bool blocking,
|
||||
bool inline_data) {
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
bool inline_data) {
|
||||
queue->flush_hdp();
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
@@ -1047,14 +1050,14 @@ void MPITransport::putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_all(bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
|
||||
outstanding[blockId]++;
|
||||
outstanding[contextId]++;
|
||||
}
|
||||
|
||||
void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
int blockId, int threadId, bool blocking,
|
||||
ROCSHMEM_OP op, ro_net_types type) {
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
ROCSHMEM_OP op, ro_net_types type) {
|
||||
queue->flush_hdp();
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
@@ -1069,14 +1072,14 @@ void MPITransport::amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
queue->notify(blockId, threadId);
|
||||
queue->notify(status);
|
||||
|
||||
queue->sfence_flush_hdp();
|
||||
}
|
||||
|
||||
void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
|
||||
int win_id, int blockId, int threadId, bool blocking,
|
||||
void *cond, ro_net_types type) {
|
||||
int win_id, int contextId, volatile char *status,
|
||||
bool blocking, void *cond, ro_net_types type) {
|
||||
queue->flush_hdp();
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
@@ -1091,14 +1094,15 @@ void MPITransport::amoFCAS(void *dst, void *src, void *val, int pe,
|
||||
// though it should be in the progress loop.
|
||||
NET_CHECK(MPI_Win_flush_local(pe, bp->heap_window_info[win_id]->get_win()));
|
||||
|
||||
queue->notify(blockId, threadId);
|
||||
queue->notify(status);
|
||||
|
||||
queue->sfence_flush_hdp();
|
||||
}
|
||||
|
||||
void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int blockId, int threadId, bool blocking) {
|
||||
outstanding[blockId]++;
|
||||
int contextId, volatile char *status,
|
||||
bool blocking) {
|
||||
outstanding[contextId]++;
|
||||
|
||||
auto *bp{backend_proxy->get()};
|
||||
MPI_Request request{};
|
||||
@@ -1106,7 +1110,7 @@ void MPITransport::getMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
dst, size, MPI_CHAR, pe, bp->heap_window_info[win_id]->get_offset(src),
|
||||
size, MPI_CHAR, bp->heap_window_info[win_id]->get_win(), &request));
|
||||
|
||||
requests.push_back({request, {threadId, blockId, blocking}});
|
||||
requests.push_back({request, {status, contextId, blocking}});
|
||||
}
|
||||
|
||||
std::unique_ptr<MPI_Request[]> MPITransport::raw_requests() {
|
||||
@@ -1139,20 +1143,20 @@ void MPITransport::progress() {
|
||||
auto *bp{backend_proxy->get()};
|
||||
for (int i{0}; i < outcount; i++) {
|
||||
int index{testsome_indices[i]};
|
||||
int blockId{requests[index].properties.blockId};
|
||||
int threadId{requests[index].properties.threadId};
|
||||
int contextId{requests[index].properties.contextId};
|
||||
volatile char *status{requests[index].properties.status};
|
||||
|
||||
if (blockId != -1) {
|
||||
outstanding[blockId]--;
|
||||
if (contextId != -1) {
|
||||
outstanding[contextId]--;
|
||||
DPRINTF(
|
||||
"Finished op for blockId %d at threadId %d "
|
||||
"Finished op for contextId %d at status addr %p "
|
||||
"(%d requests outstanding)\n",
|
||||
blockId, threadId, outstanding[blockId]);
|
||||
contextId, status, outstanding[contextId]);
|
||||
}
|
||||
|
||||
if (requests[index].properties.blocking) {
|
||||
if (blockId != -1) {
|
||||
queue->notify(blockId, threadId);
|
||||
if (contextId != -1) {
|
||||
queue->notify(status);
|
||||
}
|
||||
queue->sfence_flush_hdp();
|
||||
}
|
||||
@@ -1163,14 +1167,14 @@ void MPITransport::progress() {
|
||||
|
||||
// If the GPU has requested a quiet, notify it of completion when
|
||||
// all outstanding requests are complete.
|
||||
if (!outstanding[blockId] && !waiting_quiet[blockId].empty()) {
|
||||
for (const auto threadId : waiting_quiet[blockId]) {
|
||||
DPRINTF("Finished Quiet for blockId %d at threadId %d\n", blockId,
|
||||
threadId);
|
||||
queue->notify(blockId, threadId);
|
||||
if (!outstanding[contextId] && !waiting_quiet[contextId].empty()) {
|
||||
for (const auto status : waiting_quiet[contextId]) {
|
||||
DPRINTF("Finished Quiet for contextId %d at status addr %p\n", contextId,
|
||||
status);
|
||||
queue->notify(status);
|
||||
}
|
||||
|
||||
waiting_quiet[blockId].clear();
|
||||
waiting_quiet[contextId].clear();
|
||||
|
||||
queue->sfence_flush_hdp();
|
||||
}
|
||||
@@ -1185,15 +1189,15 @@ void MPITransport::progress() {
|
||||
}
|
||||
}
|
||||
|
||||
void MPITransport::quiet(int blockId, int threadId) {
|
||||
void MPITransport::quiet(int contextId, volatile char *status) {
|
||||
auto *bp{backend_proxy->get()};
|
||||
|
||||
if (!outstanding[blockId]) {
|
||||
DPRINTF("Finished Quiet immediately for blockId %d at threadId %d\n", blockId,
|
||||
threadId);
|
||||
queue->notify(blockId, threadId);
|
||||
if (!outstanding[contextId]) {
|
||||
DPRINTF("Finished Quiet immediately for contextId %d at status addr %p\n",
|
||||
contextId, status);
|
||||
queue->notify(status);
|
||||
} else {
|
||||
waiting_quiet[blockId].emplace_back(threadId);
|
||||
waiting_quiet[contextId].emplace_back(status);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ namespace rocshmem {
|
||||
class HostInterface;
|
||||
|
||||
class MPITransport : public Transport {
|
||||
public:
|
||||
public:
|
||||
explicit MPITransport(MPI_Comm com, Queue* queue);
|
||||
|
||||
virtual ~MPITransport();
|
||||
@@ -46,87 +46,92 @@ class MPITransport : public Transport {
|
||||
void finalizeTransport() override;
|
||||
|
||||
void createNewTeam(ROBackend *backend, Team *parent_team,
|
||||
TeamInfo *team_info_wrt_parent,
|
||||
TeamInfo *team_info_wrt_world, int num_pes,
|
||||
int my_pe_in_new_team, MPI_Comm team_comm,
|
||||
rocshmem_team_t *new_team) override;
|
||||
TeamInfo *team_info_wrt_parent,
|
||||
TeamInfo *team_info_wrt_world, int num_pes,
|
||||
int my_pe_in_new_team, MPI_Comm team_comm,
|
||||
rocshmem_team_t *new_team) override;
|
||||
|
||||
void barrier(int blockId, int threadId, bool blocking,
|
||||
MPI_Comm team) override;
|
||||
void barrier(int contextId, volatile char *status, bool blocking,
|
||||
MPI_Comm team) override;
|
||||
|
||||
void reduction(void *dst, void *src, int size, int pe, int win_id,
|
||||
int blockId, int start, int logPstride, int sizePE, void *pWrk,
|
||||
long *pSync, ROCSHMEM_OP op, ro_net_types type,
|
||||
int threadId, bool blocking) override;
|
||||
int contextId, int start, int logPstride, int sizePE,
|
||||
void *pWrk, long *pSync, ROCSHMEM_OP op, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void team_reduction(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, ROCSHMEM_OP op, ro_net_types type,
|
||||
int threadId, bool blocking) override;
|
||||
void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) override;
|
||||
|
||||
void broadcast(void *dst, void *src, int size, int pe, int win_id,
|
||||
int blockId, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type, int threadId,
|
||||
bool blocking) override;
|
||||
int contextId, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void team_broadcast(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, int PE_root, ro_net_types type,
|
||||
int threadId, bool blocking) override;
|
||||
void team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, int PE_root,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) override;
|
||||
|
||||
void alltoall(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) override;
|
||||
void alltoall(void *dst, void *src, int size, int win_id, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void alltoall_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId, bool blocking);
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking);
|
||||
|
||||
void alltoall_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type, int threadId,
|
||||
bool blocking);
|
||||
void alltoall_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void alltoall_gcen(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking);
|
||||
void alltoall_gcen(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void alltoall_gcen2(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking);
|
||||
void alltoall_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void fcollect(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) override;
|
||||
void fcollect(void *dst, void *src, int size, int win_id, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) override;
|
||||
|
||||
void fcollect_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int blockId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, int threadId, bool blocking);
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking);
|
||||
|
||||
void fcollect_mpi(void *dst, void *src, int size, int blockId, MPI_Comm team,
|
||||
void *ata_buffptr, ro_net_types type, int threadId,
|
||||
bool blocking);
|
||||
void fcollect_mpi(void *dst, void *src, int size, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void fcollect_gcen(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking);
|
||||
void fcollect_gcen(void *dst, void *src, int size, int win_id, int contextId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking);
|
||||
|
||||
void fcollect_gcen2(void *dst, void *src, int size, int win_id, int blockId,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking);
|
||||
void fcollect_gcen2(void *dst, void *src, int size, int win_id,
|
||||
int contextId, MPI_Comm team, void *ata_buffptr,
|
||||
ro_net_types type, volatile char *status, bool blocking);
|
||||
|
||||
void putMem(void *dst, void *src, int size, int pe, int win_id, int blockId,
|
||||
int threadId, bool blocking, bool inline_data = false) override;
|
||||
void putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
bool inline_data = false) override;
|
||||
|
||||
void amoFOP(void *dst, void *src, void *val, int pe, int win_id, int blockId,
|
||||
int threadId, bool blocking, ROCSHMEM_OP op,
|
||||
ro_net_types type) override;
|
||||
void amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking,
|
||||
ROCSHMEM_OP op, ro_net_types type) override;
|
||||
|
||||
void amoFCAS(void *dst, void *src, void *val, int pe, int win_id, int blockId,
|
||||
int threadId, bool blocking, void *cond,
|
||||
ro_net_types type) override;
|
||||
void amoFCAS(void *dst, void *src, void *val, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking, void *cond,
|
||||
ro_net_types type) override;
|
||||
|
||||
void getMem(void *dst, void *src, int size, int pe, int win_id, int blockId,
|
||||
int threadId, bool blocking) override;
|
||||
void getMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int contextId, volatile char *status, bool blocking) override;
|
||||
|
||||
void quiet(int blockId, int threadId) override;
|
||||
void quiet(int contextId, volatile char *status) override;
|
||||
|
||||
void progress() override;
|
||||
|
||||
@@ -142,7 +147,7 @@ class MPITransport : public Transport {
|
||||
|
||||
HostInterface *host_interface{nullptr};
|
||||
|
||||
private:
|
||||
private:
|
||||
struct CommKey {
|
||||
CommKey(int _start, int _logPstride, int _size)
|
||||
: start(_start), logPstride(_logPstride), size(_size) {}
|
||||
@@ -160,23 +165,23 @@ class MPITransport : public Transport {
|
||||
};
|
||||
|
||||
struct RequestProperties {
|
||||
RequestProperties(int _threadId, int _blockId, bool _blocking, void *_src,
|
||||
bool _inline_data)
|
||||
: threadId(_threadId),
|
||||
blockId(_blockId),
|
||||
RequestProperties(volatile char *_status, int _contextId, bool _blocking,
|
||||
void *_src, bool _inline_data)
|
||||
: status(_status),
|
||||
contextId(_contextId),
|
||||
blocking(_blocking),
|
||||
src(_src),
|
||||
inline_data(_inline_data) {}
|
||||
|
||||
RequestProperties(int _threadId, int _blockId, bool _blocking)
|
||||
: threadId(_threadId),
|
||||
blockId(_blockId),
|
||||
RequestProperties(volatile char *_status, int _contextId, bool _blocking)
|
||||
: status(_status),
|
||||
contextId(_contextId),
|
||||
blocking(_blocking),
|
||||
src(nullptr),
|
||||
inline_data(false) {}
|
||||
|
||||
int threadId{-1};
|
||||
int blockId{-1};
|
||||
volatile char* status{nullptr};
|
||||
int contextId{-1};
|
||||
bool blocking{};
|
||||
void *src{nullptr};
|
||||
bool inline_data{};
|
||||
@@ -202,7 +207,7 @@ class MPITransport : public Transport {
|
||||
// Unordered vector of in-flight MPI Requests. Can complete out of order.
|
||||
std::vector<Request> requests{};
|
||||
|
||||
std::vector<std::vector<int> > waiting_quiet{};
|
||||
std::vector<std::vector<volatile char *> > waiting_quiet{};
|
||||
|
||||
std::vector<int> outstanding{};
|
||||
|
||||
|
||||
@@ -33,12 +33,11 @@ Queue::Queue() {
|
||||
}
|
||||
}
|
||||
|
||||
Queue::Queue(size_t max_queues, size_t max_wg_size, size_t queue_size)
|
||||
Queue::Queue(size_t max_queues, size_t queue_size)
|
||||
: max_queues_{max_queues},
|
||||
max_wg_size_{max_wg_size},
|
||||
queue_size_{queue_size},
|
||||
queue_proxy_{max_queues, queue_size},
|
||||
queue_desc_proxy_{max_queues, max_wg_size} {
|
||||
queue_desc_proxy_{max_queues} {
|
||||
|
||||
gpu_queue = true;
|
||||
char *value{nullptr};
|
||||
@@ -101,8 +100,8 @@ void Queue::sfence_flush_hdp() {
|
||||
}
|
||||
}
|
||||
|
||||
void Queue::notify(int blockId, int threadId) {
|
||||
descriptor(blockId)->status[threadId] = 1;
|
||||
void Queue::notify(volatile char* status) {
|
||||
*status = 1;
|
||||
}
|
||||
|
||||
uint64_t Queue::size() {
|
||||
|
||||
@@ -35,7 +35,7 @@ class Queue {
|
||||
public:
|
||||
Queue();
|
||||
|
||||
Queue(size_t max_queues, size_t max_threads_per_block, size_t queue_size);
|
||||
Queue(size_t max_queues, size_t queue_size);
|
||||
|
||||
bool process(uint64_t queue_index, MPITransport* transport);
|
||||
|
||||
@@ -47,7 +47,7 @@ class Queue {
|
||||
|
||||
void sfence_flush_hdp();
|
||||
|
||||
void notify(int blockId, int threadId);
|
||||
void notify(volatile char *status);
|
||||
|
||||
uint64_t size();
|
||||
|
||||
@@ -72,8 +72,6 @@ class Queue {
|
||||
|
||||
size_t max_queues_{};
|
||||
|
||||
size_t max_wg_size_{};
|
||||
|
||||
size_t queue_size_{};
|
||||
};
|
||||
|
||||
|
||||
@@ -45,38 +45,22 @@ typedef struct queue_desc {
|
||||
*/
|
||||
uint64_t write_index;
|
||||
char padding2[56];
|
||||
/**
|
||||
* This bit is used by the GPU to wait on a blocking operation. The initial
|
||||
* value is 0. When a GPU enqueues a blocking operation, it waits for this
|
||||
* value to resolve to 1, which is set by the CPU when the blocking
|
||||
* operation completes. The GPU then resets status back to zero. There is
|
||||
* a separate status variable for each work-item in a work-group
|
||||
*/
|
||||
char *status;
|
||||
char padding3[56];
|
||||
} __attribute__((__aligned__(64))) queue_desc_t;
|
||||
|
||||
template <typename ALLOCATOR>
|
||||
class QueueDescProxy {
|
||||
using ProxyT = DeviceProxy<ALLOCATOR, queue_desc_t>;
|
||||
using ProxyStatusT = DeviceProxy<ALLOCATOR, char>;
|
||||
|
||||
public:
|
||||
QueueDescProxy() = default;
|
||||
|
||||
QueueDescProxy(size_t max_queues, size_t max_threads_per_queue)
|
||||
: max_queues_{max_queues}, max_threads_per_queue_{max_threads_per_queue},
|
||||
max_threads_{max_queues * max_threads_per_queue}, proxy_{max_queues},
|
||||
proxy_status_{max_queues * max_threads_per_queue} {
|
||||
auto *status{proxy_status_.get()};
|
||||
size_t status_bytes{sizeof(char) * max_threads_};
|
||||
memset(status, 0, status_bytes);
|
||||
QueueDescProxy(size_t max_queues)
|
||||
: max_queues_{max_queues}, proxy_{max_queues} {
|
||||
|
||||
auto *queue_descs{proxy_.get()};
|
||||
for (size_t i{0}; i < max_queues_; i++) {
|
||||
queue_descs[i].read_index = 0;
|
||||
queue_descs[i].write_index = 0;
|
||||
queue_descs[i].status = status + i * max_threads_per_queue_;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,13 +77,7 @@ class QueueDescProxy {
|
||||
private:
|
||||
ProxyT proxy_{};
|
||||
|
||||
ProxyStatusT proxy_status_{};
|
||||
|
||||
size_t max_queues_{};
|
||||
|
||||
size_t max_threads_per_queue_{};
|
||||
|
||||
size_t max_threads_{};
|
||||
};
|
||||
|
||||
using QueueDescProxyT = QueueDescProxy<HIPDefaultFinegrainedAllocator>;
|
||||
|
||||
@@ -57,13 +57,13 @@ typedef struct queue_element {
|
||||
void *src{nullptr};
|
||||
void *dst{nullptr};
|
||||
int ro_net_win_id{-1};
|
||||
int threadId{-1};
|
||||
int logPE_stride{-1};
|
||||
int PE_size{-1};
|
||||
long *pSync{nullptr};
|
||||
int op{-1};
|
||||
int datatype{-1};
|
||||
int PE_root{-1};
|
||||
volatile char *status{nullptr};
|
||||
MPI_Comm team_comm{};
|
||||
union {
|
||||
size_t size;
|
||||
|
||||
@@ -45,60 +45,61 @@ class Transport {
|
||||
virtual void finalizeTransport() = 0;
|
||||
|
||||
virtual void createNewTeam(ROBackend *backend_handle, Team *parent_team,
|
||||
TeamInfo *team_info_wrt_parent,
|
||||
TeamInfo *team_info_wrt_world, int num_pes,
|
||||
int my_pe_in_new_team, MPI_Comm team_comm,
|
||||
rocshmem_team_t *new_team) = 0;
|
||||
TeamInfo *team_info_wrt_parent,
|
||||
TeamInfo *team_info_wrt_world, int num_pes,
|
||||
int my_pe_in_new_team, MPI_Comm team_comm,
|
||||
rocshmem_team_t *new_team) = 0;
|
||||
|
||||
virtual void barrier(int wg_id, int threadId, bool blocking,
|
||||
MPI_Comm team) = 0;
|
||||
virtual void barrier(int wg_id, volatile char *status, bool blocking,
|
||||
MPI_Comm team) = 0;
|
||||
|
||||
virtual void reduction(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
void *pWrk, long *pSync, ROCSHMEM_OP op,
|
||||
ro_net_types type, int threadId, bool blocking) = 0;
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
void *pWrk, long *pSync, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) = 0;
|
||||
|
||||
virtual void team_reduction(void *dst, void *src, int size, int win_id,
|
||||
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) = 0;
|
||||
int wg_id, MPI_Comm team, ROCSHMEM_OP op,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) = 0;
|
||||
|
||||
virtual void broadcast(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type,
|
||||
int threadId, bool blocking) = 0;
|
||||
int wg_id, int start, int logPstride, int sizePE,
|
||||
int PE_root, long *pSync, ro_net_types type,
|
||||
volatile char *status, bool blocking) = 0;
|
||||
|
||||
virtual void team_broadcast(void *dst, void *src, int size, int win_id,
|
||||
int wg_id, MPI_Comm team, int PE_root,
|
||||
ro_net_types type, int threadId,
|
||||
bool blocking) = 0;
|
||||
int wg_id, MPI_Comm team, int PE_root,
|
||||
ro_net_types type, volatile char *status,
|
||||
bool blocking) = 0;
|
||||
|
||||
virtual void alltoall(void *dst, void *src, int size, int win_id, int wg_id,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) = 0;
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) = 0;
|
||||
|
||||
virtual void fcollect(void *dst, void *src, int size, int win_id, int wg_id,
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
int threadId, bool blocking) = 0;
|
||||
MPI_Comm team, void *ata_buffptr, ro_net_types type,
|
||||
volatile char *status, bool blocking) = 0;
|
||||
|
||||
virtual void putMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int threadId, bool blocking,
|
||||
bool inline_data = false) = 0;
|
||||
int wg_id, volatile char *status, bool blocking,
|
||||
bool inline_data = false) = 0;
|
||||
|
||||
virtual void getMem(void *dst, void *src, int size, int pe, int win_id,
|
||||
int wg_id, int threadId, bool blocking) = 0;
|
||||
int wg_id, volatile char *status, bool blocking) = 0;
|
||||
|
||||
virtual void amoFOP(void *dst, void *src, void *val, int pe, int win_id,
|
||||
int wg_id, int threadId, bool blocking, ROCSHMEM_OP op,
|
||||
ro_net_types type) = 0;
|
||||
int wg_id, volatile char *status, bool blocking,
|
||||
ROCSHMEM_OP op, ro_net_types type) = 0;
|
||||
|
||||
virtual void amoFCAS(void *dst, void *src, void *val, int pe, int win_id,
|
||||
int wg_id, int threadId, bool blocking, void *cond,
|
||||
ro_net_types type) = 0;
|
||||
int wg_id, volatile char *status, bool blocking,
|
||||
void *cond, ro_net_types type) = 0;
|
||||
|
||||
virtual bool readyForFinalize() = 0;
|
||||
|
||||
virtual void quiet(int wg_id, int threadId) = 0;
|
||||
virtual void quiet(int wg_id, volatile char *status) = 0;
|
||||
|
||||
virtual void progress() = 0;
|
||||
|
||||
|
||||
Yeni konuda referans
Bir kullanıcı engelle