ipc/to_all: add direct allreduce algorithm
add a simple version of an allreduce algorithm as a starting point.
[ROCm/rocshmem commit: ba21cb7b85]
This commit is contained in:
@@ -29,80 +29,10 @@
|
||||
#include "gpu_ib_team.hpp"
|
||||
#include "queue_pair.hpp"
|
||||
#include "../util.hpp"
|
||||
#include "../roc_shmem_calc.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
// clang-format off
|
||||
NOWARN(-Wunused-parameter,
|
||||
template <ROC_SHMEM_OP Op>
|
||||
struct OpWrap {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
static_assert(true, "Unimplemented gpu_ib collective.");
|
||||
}
|
||||
};
|
||||
)
|
||||
// clang-format on
|
||||
|
||||
/******************************************************************************
|
||||
************************** TEMPLATE SPECIALIZATIONS **************************
|
||||
*****************************************************************************/
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_SUM> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] += src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_MAX> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] = max(dst[i], src[i]);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_MIN> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] = min(dst[i], src[i]);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_PROD> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] *= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_AND> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] &= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_OR> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] |= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_XOR> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] ^= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
|
||||
int wg_size) {
|
||||
|
||||
@@ -239,6 +239,12 @@ class IPCContext : public Context {
|
||||
__device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
|
||||
int n_pes, int64_t *pSync);
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_direct_allreduce(T *dst, const T *src,
|
||||
int nelems, int PE_start, int
|
||||
logPE_stride, int PE_size,
|
||||
T *pWrk, long *pSync);
|
||||
|
||||
//internal functions used by collectives routines to write/read to
|
||||
//work/sync buffers
|
||||
__device__ void internal_putmem(void *dest, const void *source,
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "context_ipc_device.hpp"
|
||||
#include "../util.hpp"
|
||||
#include "ipc_team.hpp"
|
||||
#include "../roc_shmem_calc.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
@@ -152,11 +153,88 @@ __device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) {
|
||||
}
|
||||
|
||||
// Collectives
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
|
||||
int wg_size) {
|
||||
for (size_t i = wg_id; i < size; i += wg_size) {
|
||||
OpWrap<Op>::Calc(src, dst, i);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_direct_allreduce(
|
||||
T *dst, const T *src, int nelems, int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
|
||||
int stride = 1 << logPE_stride;
|
||||
int finish = PE_start + stride * PE_size;
|
||||
int pe = my_pe;
|
||||
|
||||
int wg_id = get_flat_block_id();
|
||||
int wg_size = get_flat_block_size();
|
||||
int64_t flag_val = 1;
|
||||
|
||||
for (int i = wg_id; i < nelems; i += wg_size) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = PE_start; i < finish; i += stride) {
|
||||
if (i != pe) {
|
||||
putmem_nbi_wg(&pWrk[pe * nelems], reinterpret_cast<const void *>(src),
|
||||
nelems * sizeof(T), i);
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
fence();
|
||||
put_nbi(&pSync[pe], &flag_val, 1, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
threadfence_system();
|
||||
__syncthreads();
|
||||
|
||||
// Do the compute and pSync reset in parallel.
|
||||
for (int i = PE_start; i < finish; i += stride) {
|
||||
if (i != pe) {
|
||||
// Wait for leader thread to see that the buffer is ready.
|
||||
if (is_thread_zero_in_block()) {
|
||||
wait_until(&pSync[i], ROC_SHMEM_CMP_EQ, flag_val);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
T *ptr = &pWrk[i * nelems];
|
||||
compute_reduce<T, Op>(ptr, dst, nelems, wg_id, wg_size);
|
||||
threadfence_system();
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = wg_id; i < num_pes; i += wg_size) {
|
||||
pSync[i] = ROC_SHMEM_SYNC_VALUE;
|
||||
}
|
||||
threadfence_system();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
//to_all<T, Op>(dest, source, nreduce, pe_start, log_pe_stride, pe_size, pWrk,
|
||||
// p_sync);
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
* Ensure that the stride is a multiple of 2 for GPU_IB.
|
||||
*/
|
||||
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_size = team_obj->tinfo_wrt_world->size;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
|
||||
|
||||
to_all<T, Op>(dest, source, nreduce, pe_start, log_pe_stride, pe_size, pWrk,
|
||||
p_sync);
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
@@ -164,6 +242,8 @@ __device__ void IPCContext::to_all(T *dest, const T *source, int nreduce,
|
||||
int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, logPE_stride,
|
||||
PE_size, pWrk, pSync);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef LIBRARY_SRC_ROC_SHMEM_CALC_HPP_
|
||||
#define LIBRARY_SRC_ROC_SHMEM_CALC_HPP_
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
// clang-format off
|
||||
NOWARN(-Wunused-parameter,
|
||||
template <ROC_SHMEM_OP Op>
|
||||
struct OpWrap {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
static_assert(true, "Unimplemented ipc collective.");
|
||||
}
|
||||
};
|
||||
)
|
||||
// clang-format on
|
||||
|
||||
/******************************************************************************
|
||||
************************** TEMPLATE SPECIALIZATIONS **************************
|
||||
*****************************************************************************/
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_SUM> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] += src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_MAX> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] = max(dst[i], src[i]);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_MIN> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] = min(dst[i], src[i]);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_PROD> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] *= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_AND> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] &= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_OR> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] |= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OpWrap<ROC_SHMEM_XOR> {
|
||||
template <typename T>
|
||||
__device__ static void Calc(T *src, T *dst, int i) {
|
||||
dst[i] ^= src[i];
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
#endif // LIBRARY_SRC_ROC_SHMEM_CALC_HPP_
|
||||
Reference in New Issue
Block a user