Merge pull request #50 from avinashkethineedi/teams_interface

Update collective APIs to use teams interface

[ROCm/rocshmem commit: 3edf881b40]
Этот коммит содержится в:
Avinash Kethineedi
2024-11-12 15:31:42 -06:00
коммит произвёл GitHub
родитель bae2b2aece e4e18e31bb
Коммит 772a1f7f3f
3 изменённых файлов: 49 добавлений и 77 удалений
+2 -9
Просмотреть файл
@@ -196,11 +196,6 @@ class IPCContext : public Context {
char* g_ret;
//internal functions used by collective operations
template <typename T, ROC_SHMEM_OP Op>
__device__ void internal_to_all(T *dest, const T *source, int nreduce, int PE_start,
int stride, int PE_size, T *pWrk,
long *pSync); // NOLINT(runtime/int)
template <typename T>
__device__ void internal_broadcast(T *dest, const T *source, int nelems, int pe_root,
int pe_start, int stride, int pe_size,
@@ -234,12 +229,10 @@ class IPCContext : public Context {
template <typename T, ROC_SHMEM_OP Op>
__device__ void internal_direct_allreduce(T *dst, const T *src,
int nelems, int PE_start, int logPE_stride,
int PE_size, T *pWrk, long *pSync);
int nelems, IPCTeam *team_obj);
template <typename T, ROC_SHMEM_OP Op>
__device__ void internal_ring_allreduce(T *dst, const T *src,
int nelems, int PE_start, int logPE_stride,
int PE_size, T *pWrk, long *pSync,
int nelems, IPCTeam *team_obj,
int n_seg, int seg_size, int chunk_size);
//internal functions used by collectives routines to write/read to
+1 -5
Просмотреть файл
@@ -101,13 +101,9 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
__device__ void IPCContext::sync(roc_shmem_team_t team) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
* Ensure that the stride is a multiple of 2.
*/
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
int pe = team_obj->my_pe_in_world;
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_stride = (1 << log_pe_stride);
int pe_stride = team_obj->tinfo_wrt_world->stride;
int pe_size = team_obj->num_pes;
internal_sync(pe, pe_start, pe_stride, pe_size, barrier_sync);
+46 -63
Просмотреть файл
@@ -164,9 +164,13 @@ __device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::internal_direct_allreduce(
T *dst, const T *src, int nelems, int PE_start, int stride,
int PE_size, T *pWrk,
long *pSync) { // NOLINT(runtime/int)
T *dst, const T *src, int nelems, IPCTeam *team_obj) { // NOLINT(runtime/int)
int stride = team_obj->tinfo_wrt_world->stride;
int PE_start = team_obj->tinfo_wrt_world->pe_start;
int PE_size = team_obj->tinfo_wrt_world->size;
long *pSync = team_obj->barrier_pSync;
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
int finish = PE_start + stride * PE_size;
int pe = my_pe;
@@ -276,12 +280,20 @@ __device__ void IPCContext::internal_direct_allreduce(
*/
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::internal_ring_allreduce(
T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
[[maybe_unused]] int stride, [[maybe_unused]] int PE_size, T *pWrk,
long *pSync, // NOLINT(runtime/int)
T *dst, const T *src, int nelems, IPCTeam *team_obj, // NOLINT(runtime/int)
int n_seg, int seg_size, int chunk_size) {
int stride = team_obj->tinfo_wrt_world->stride;
int PE_start = team_obj->tinfo_wrt_world->pe_start;
int PE_size = team_obj->tinfo_wrt_world->size;
long *pSync = team_obj->barrier_pSync;
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
int my_pe_in_team = team_obj->my_pe;
int off_seg, off_send, off_recv;
int send_pe = (my_pe + 1) % num_pes;
int send_pe = (my_pe_in_team + 1) % PE_size;
// send_pe is relative to team, convert it relative to team world
send_pe = team_obj->get_pe_in_world(send_pe);
long wait_val; // NOLINT(runtime/int)
int wg_size = get_flat_block_size();
@@ -295,9 +307,9 @@ __device__ void IPCContext::internal_ring_allreduce(
for (size_t seg = 0; seg < n_seg; seg++) {
off_seg = seg * seg_size;
// Loop 2 in the algorithm above
for (int iter = 0; iter < num_pes - 1; iter++) {
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size);
for (int iter = 0; iter < PE_size - 1; iter++) {
off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size);
off_recv = (((my_pe_in_team - iter + 2 * PE_size) % PE_size) * chunk_size);
internal_putmem_wg(reinterpret_cast<void *>(&pWrk[off_send]),
reinterpret_cast<void *>(&dst[off_send + off_seg]),
@@ -319,8 +331,8 @@ __device__ void IPCContext::internal_ring_allreduce(
}
// Loop 2 in the example above
for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) {
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
for (size_t iter = PE_size - 1; iter < 2 * PE_size - 2; iter++) {
off_send = (((my_pe_in_team + 1 - iter + 2 * PE_size) % PE_size) * chunk_size);
putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),
reinterpret_cast<void *>(&dst[off_send + off_seg]),
chunk_size * sizeof(T), send_pe);
@@ -350,76 +362,58 @@ __device__ int IPCContext::reduce(roc_shmem_team_t team, T *dest,
const T *source, int nreduce) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
* Ensure that the stride is a multiple of 2 for GPU_IB.
*/
int stride = team_obj->tinfo_wrt_world->stride;
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_size = team_obj->tinfo_wrt_world->size;
long *p_sync = team_obj->barrier_pSync;
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
int PE_size = team_obj->tinfo_wrt_world->size;
internal_to_all<T, Op>(dest, source, nreduce, pe_start, stride, pe_size, pWrk,
p_sync);
return ROC_SHMEM_SUCCESS;
}
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::internal_to_all(T *dest, const T *source, int nreduce,
int PE_start, int stride,
int PE_size, T *pWrk,
long *pSync) { // NOLINT(runtime/int)
size_t direct_pWrk = num_pes * nreduce;
size_t direct_pSync = num_pes;
size_t ring_pSync = 2 * num_pes;
size_t direct_pWrk = PE_size * nreduce;
size_t direct_pSync = PE_size;
size_t ring_pSync = 2 * PE_size;
size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE);
size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE;
if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) {
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, stride,
PE_size, pWrk, pSync);
internal_direct_allreduce<T, Op>(dest, source, nreduce, team_obj);
} else {
if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) {
size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE;
// integer division truncating value
int chunk_size = ring_pWrk / num_pes;
int seg_size = chunk_size * num_pes;
int chunk_size = ring_pWrk / PE_size;
int seg_size = chunk_size * PE_size;
// integer division truncating value
int n_seg = nreduce / seg_size;
// integer division rounding up
int n_seg_up = (nreduce + (seg_size -1)) / seg_size;
int n_seg_up = (nreduce - 1) / seg_size + 1;
// recalculate chunk_size
chunk_size = seg_size / num_pes;
chunk_size = seg_size / PE_size;
if (n_seg == 0) {
n_seg = 1;
}
internal_ring_allreduce<T, Op>(dest, source, nreduce, PE_start,
stride, PE_size, pWrk, pSync, n_seg,
internal_ring_allreduce<T, Op>(dest, source, nreduce, team_obj, n_seg,
seg_size, chunk_size);
if (n_seg_up > n_seg) {
T *p_dst = (dest + (n_seg * seg_size));
const T *p_src = (source + (n_seg * seg_size));
int p_count = nreduce - (n_seg * seg_size);
int p_chunk = p_count / num_pes;
int p_chunk = p_count / PE_size;
internal_ring_allreduce<T, Op>(p_dst, p_src, p_count, PE_start, stride,
PE_size, pWrk, pSync, 1, (p_chunk * num_pes), p_chunk);
internal_ring_allreduce<T, Op>(p_dst, p_src, p_count, team_obj, 1,
(p_chunk * PE_size), p_chunk);
if ((p_chunk * num_pes) < p_count) {
if ((p_chunk * PE_size) < p_count) {
// Final elements need to use direct_allreduce
p_count -= (p_chunk * num_pes);
p_dst += (p_chunk * num_pes);
const T *p_src2 = p_src + (p_chunk * num_pes);
p_count -= (p_chunk * PE_size);
p_dst += (p_chunk * PE_size);
const T *p_src2 = p_src + (p_chunk * PE_size);
internal_direct_allreduce<T, Op>(p_dst, p_src2, p_count, PE_start, stride,
PE_size, pWrk, pSync);
internal_direct_allreduce<T, Op>(p_dst, p_src2, p_count, team_obj);
}
}
} else {
GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n");
return ROC_SHMEM_ERROR;
}
}
return ROC_SHMEM_SUCCESS;
}
template <typename T>
@@ -449,9 +443,6 @@ __device__ void IPCContext::broadcast(roc_shmem_team_t team, T *dst,
const T *src, int nelems, int pe_root) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
* Ensure that the stride is a multiple of 2 .
*/
int stride = team_obj->tinfo_wrt_world->stride;
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_size = team_obj->tinfo_wrt_world->size;
@@ -490,13 +481,9 @@ __device__ void IPCContext::alltoall_linear(roc_shmem_team_t team, T *dst,
const T *src, int nelems) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
* Ensure that the stride is a multiple of 2
*/
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_size = team_obj->num_pes;
int stride = 1 << log_pe_stride;
int stride = team_obj->tinfo_wrt_world->stride;
long *pSync = team_obj->alltoall_pSync;
int my_pe_in_team = team_obj->my_pe;
@@ -523,13 +510,9 @@ __device__ void IPCContext::fcollect_linear(roc_shmem_team_t team, T *dst,
const T *src, int nelems) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
* Ensure that the stride is a multiple of 2.
*/
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_size = team_obj->num_pes;
int stride = 1 << log_pe_stride;
int stride = team_obj->tinfo_wrt_world->stride;
long *pSync = team_obj->alltoall_pSync;
int my_pe_in_team = team_obj->my_pe;