Merge pull request #50 from avinashkethineedi/teams_interface
Update collective APIs to use teams interface
[ROCm/rocshmem commit: 3edf881b40]
Этот коммит содержится в:
@@ -196,11 +196,6 @@ class IPCContext : public Context {
|
||||
char* g_ret;
|
||||
|
||||
//internal functions used by collective operations
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_to_all(T *dest, const T *source, int nreduce, int PE_start,
|
||||
int stride, int PE_size, T *pWrk,
|
||||
long *pSync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T>
|
||||
__device__ void internal_broadcast(T *dest, const T *source, int nelems, int pe_root,
|
||||
int pe_start, int stride, int pe_size,
|
||||
@@ -234,12 +229,10 @@ class IPCContext : public Context {
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_direct_allreduce(T *dst, const T *src,
|
||||
int nelems, int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk, long *pSync);
|
||||
int nelems, IPCTeam *team_obj);
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_ring_allreduce(T *dst, const T *src,
|
||||
int nelems, int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk, long *pSync,
|
||||
int nelems, IPCTeam *team_obj,
|
||||
int n_seg, int seg_size, int chunk_size);
|
||||
|
||||
//internal functions used by collectives routines to write/read to
|
||||
|
||||
@@ -101,13 +101,9 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
|
||||
__device__ void IPCContext::sync(roc_shmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2.
|
||||
*/
|
||||
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_stride = (1 << log_pe_stride);
|
||||
int pe_stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_size = team_obj->num_pes;
|
||||
|
||||
internal_sync(pe, pe_start, pe_stride, pe_size, barrier_sync);
|
||||
|
||||
@@ -164,9 +164,13 @@ __device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_direct_allreduce(
|
||||
T *dst, const T *src, int nelems, int PE_start, int stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
T *dst, const T *src, int nelems, IPCTeam *team_obj) { // NOLINT(runtime/int)
|
||||
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
int PE_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int PE_size = team_obj->tinfo_wrt_world->size;
|
||||
long *pSync = team_obj->barrier_pSync;
|
||||
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
|
||||
|
||||
int finish = PE_start + stride * PE_size;
|
||||
int pe = my_pe;
|
||||
@@ -276,12 +280,20 @@ __device__ void IPCContext::internal_direct_allreduce(
|
||||
*/
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_ring_allreduce(
|
||||
T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
|
||||
[[maybe_unused]] int stride, [[maybe_unused]] int PE_size, T *pWrk,
|
||||
long *pSync, // NOLINT(runtime/int)
|
||||
T *dst, const T *src, int nelems, IPCTeam *team_obj, // NOLINT(runtime/int)
|
||||
int n_seg, int seg_size, int chunk_size) {
|
||||
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
int PE_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int PE_size = team_obj->tinfo_wrt_world->size;
|
||||
long *pSync = team_obj->barrier_pSync;
|
||||
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
|
||||
int my_pe_in_team = team_obj->my_pe;
|
||||
|
||||
int off_seg, off_send, off_recv;
|
||||
int send_pe = (my_pe + 1) % num_pes;
|
||||
int send_pe = (my_pe_in_team + 1) % PE_size;
|
||||
// send_pe is relative to team, convert it relative to team world
|
||||
send_pe = team_obj->get_pe_in_world(send_pe);
|
||||
long wait_val; // NOLINT(runtime/int)
|
||||
|
||||
int wg_size = get_flat_block_size();
|
||||
@@ -295,9 +307,9 @@ __device__ void IPCContext::internal_ring_allreduce(
|
||||
for (size_t seg = 0; seg < n_seg; seg++) {
|
||||
off_seg = seg * seg_size;
|
||||
// Loop 2 in the algorithm above
|
||||
for (int iter = 0; iter < num_pes - 1; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
for (int iter = 0; iter < PE_size - 1; iter++) {
|
||||
off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size);
|
||||
off_recv = (((my_pe_in_team - iter + 2 * PE_size) % PE_size) * chunk_size);
|
||||
|
||||
internal_putmem_wg(reinterpret_cast<void *>(&pWrk[off_send]),
|
||||
reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
@@ -319,8 +331,8 @@ __device__ void IPCContext::internal_ring_allreduce(
|
||||
}
|
||||
|
||||
// Loop 2 in the example above
|
||||
for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
for (size_t iter = PE_size - 1; iter < 2 * PE_size - 2; iter++) {
|
||||
off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size);
|
||||
putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
chunk_size * sizeof(T), send_pe);
|
||||
@@ -350,76 +362,58 @@ __device__ int IPCContext::reduce(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2 for GPU_IB.
|
||||
*/
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_size = team_obj->tinfo_wrt_world->size;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
|
||||
int PE_size = team_obj->tinfo_wrt_world->size;
|
||||
|
||||
internal_to_all<T, Op>(dest, source, nreduce, pe_start, stride, pe_size, pWrk,
|
||||
p_sync);
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_to_all(T *dest, const T *source, int nreduce,
|
||||
int PE_start, int stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
size_t direct_pWrk = num_pes * nreduce;
|
||||
size_t direct_pSync = num_pes;
|
||||
size_t ring_pSync = 2 * num_pes;
|
||||
size_t direct_pWrk = PE_size * nreduce;
|
||||
size_t direct_pSync = PE_size;
|
||||
size_t ring_pSync = 2 * PE_size;
|
||||
size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE);
|
||||
size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE;
|
||||
|
||||
if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) {
|
||||
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, stride,
|
||||
PE_size, pWrk, pSync);
|
||||
internal_direct_allreduce<T, Op>(dest, source, nreduce, team_obj);
|
||||
} else {
|
||||
if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) {
|
||||
size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE;
|
||||
// integer division truncating value
|
||||
int chunk_size = ring_pWrk / num_pes;
|
||||
int seg_size = chunk_size * num_pes;
|
||||
int chunk_size = ring_pWrk / PE_size;
|
||||
int seg_size = chunk_size * PE_size;
|
||||
|
||||
// integer division truncating value
|
||||
int n_seg = nreduce / seg_size;
|
||||
// integer division rounding up
|
||||
int n_seg_up = (nreduce + (seg_size -1)) / seg_size;
|
||||
int n_seg_up = (nreduce - 1) / seg_size + 1;
|
||||
// recalculate chunk_size
|
||||
chunk_size = seg_size / num_pes;
|
||||
chunk_size = seg_size / PE_size;
|
||||
if (n_seg == 0) {
|
||||
n_seg = 1;
|
||||
}
|
||||
internal_ring_allreduce<T, Op>(dest, source, nreduce, PE_start,
|
||||
stride, PE_size, pWrk, pSync, n_seg,
|
||||
internal_ring_allreduce<T, Op>(dest, source, nreduce, team_obj, n_seg,
|
||||
seg_size, chunk_size);
|
||||
if (n_seg_up > n_seg) {
|
||||
T *p_dst = (dest + (n_seg * seg_size));
|
||||
const T *p_src = (source + (n_seg * seg_size));
|
||||
int p_count = nreduce - (n_seg * seg_size);
|
||||
int p_chunk = p_count / num_pes;
|
||||
int p_chunk = p_count / PE_size;
|
||||
|
||||
internal_ring_allreduce<T, Op>(p_dst, p_src, p_count, PE_start, stride,
|
||||
PE_size, pWrk, pSync, 1, (p_chunk * num_pes), p_chunk);
|
||||
internal_ring_allreduce<T, Op>(p_dst, p_src, p_count, team_obj, 1,
|
||||
(p_chunk * PE_size), p_chunk);
|
||||
|
||||
if ((p_chunk * num_pes) < p_count) {
|
||||
if ((p_chunk * PE_size) < p_count) {
|
||||
// Final elements need to use direct_allreduce
|
||||
p_count -= (p_chunk * num_pes);
|
||||
p_dst += (p_chunk * num_pes);
|
||||
const T *p_src2 = p_src + (p_chunk * num_pes);
|
||||
p_count -= (p_chunk * PE_size);
|
||||
p_dst += (p_chunk * PE_size);
|
||||
const T *p_src2 = p_src + (p_chunk * PE_size);
|
||||
|
||||
internal_direct_allreduce<T, Op>(p_dst, p_src2, p_count, PE_start, stride,
|
||||
PE_size, pWrk, pSync);
|
||||
internal_direct_allreduce<T, Op>(p_dst, p_src2, p_count, team_obj);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n");
|
||||
return ROC_SHMEM_ERROR;
|
||||
}
|
||||
}
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -449,9 +443,6 @@ __device__ void IPCContext::broadcast(roc_shmem_team_t team, T *dst,
|
||||
const T *src, int nelems, int pe_root) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2 .
|
||||
*/
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_size = team_obj->tinfo_wrt_world->size;
|
||||
@@ -490,13 +481,9 @@ __device__ void IPCContext::alltoall_linear(roc_shmem_team_t team, T *dst,
|
||||
const T *src, int nelems) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2
|
||||
*/
|
||||
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_size = team_obj->num_pes;
|
||||
int stride = 1 << log_pe_stride;
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
long *pSync = team_obj->alltoall_pSync;
|
||||
int my_pe_in_team = team_obj->my_pe;
|
||||
|
||||
@@ -523,13 +510,9 @@ __device__ void IPCContext::fcollect_linear(roc_shmem_team_t team, T *dst,
|
||||
const T *src, int nelems) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2.
|
||||
*/
|
||||
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_size = team_obj->num_pes;
|
||||
int stride = 1 << log_pe_stride;
|
||||
int stride = team_obj->tinfo_wrt_world->stride;
|
||||
long *pSync = team_obj->alltoall_pSync;
|
||||
int my_pe_in_team = team_obj->my_pe;
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user