From 5f0f2f6e85fe9e1d48656d6ef8e6dec38aa5f652 Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Fri, 27 Sep 2024 21:25:41 +0000 Subject: [PATCH] ipc/to_all: add direct allreduce algorithm add a simple version of an allreduce algorithm as a starting point. [ROCm/rocshmem commit: ba21cb7b85729f4aa839dfde9ae9a101502ee7e7] --- .../src/gpu_ib/context_ib_tmpl_device.hpp | 72 +------------ .../rocshmem/src/ipc/context_ipc_device.hpp | 6 ++ .../src/ipc/context_ipc_tmpl_device.hpp | 84 ++++++++++++++- projects/rocshmem/src/roc_shmem_calc.hpp | 100 ++++++++++++++++++ 4 files changed, 189 insertions(+), 73 deletions(-) create mode 100644 projects/rocshmem/src/roc_shmem_calc.hpp diff --git a/projects/rocshmem/src/gpu_ib/context_ib_tmpl_device.hpp b/projects/rocshmem/src/gpu_ib/context_ib_tmpl_device.hpp index a3f463a342..c575cb0832 100644 --- a/projects/rocshmem/src/gpu_ib/context_ib_tmpl_device.hpp +++ b/projects/rocshmem/src/gpu_ib/context_ib_tmpl_device.hpp @@ -29,80 +29,10 @@ #include "gpu_ib_team.hpp" #include "queue_pair.hpp" #include "../util.hpp" +#include "../roc_shmem_calc.hpp" namespace rocshmem { -// clang-format off -NOWARN(-Wunused-parameter, -template -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - static_assert(true, "Unimplemented gpu_ib collective."); - } -}; -) -// clang-format on - -/****************************************************************************** - ************************** TEMPLATE SPECIALIZATIONS ************************** - *****************************************************************************/ -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] += src[i]; - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] = max(dst[i], src[i]); - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] = min(dst[i], src[i]); - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] *= src[i]; - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] &= src[i]; - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] |= src[i]; - } -}; - -template <> -struct OpWrap { - template - __device__ static void Calc(T *src, T *dst, int i) { - dst[i] ^= src[i]; - } -}; - template __device__ void compute_reduce(T *src, T *dst, int size, int wg_id, int wg_size) { diff --git a/projects/rocshmem/src/ipc/context_ipc_device.hpp b/projects/rocshmem/src/ipc/context_ipc_device.hpp index 2a7aaebd62..6bf8885d6c 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.hpp @@ -239,6 +239,12 @@ class IPCContext : public Context { __device__ void internal_atomic_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); + template + __device__ void internal_direct_allreduce(T *dst, const T *src, + int nelems, int PE_start, int + logPE_stride, int PE_size, + T *pWrk, long *pSync); + //internal functions used by collectives routines to write/read to //work/sync buffers __device__ void internal_putmem(void *dest, const void *source, diff --git a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp index 91bdbd45e7..33bb91a1f1 100644 --- a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp @@ -28,6 +28,7 @@ #include "context_ipc_device.hpp" #include "../util.hpp" #include "ipc_team.hpp" +#include "../roc_shmem_calc.hpp" namespace rocshmem { @@ -152,11 +153,88 @@ __device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) { } // Collectives +template +__device__ void compute_reduce(T *src, T *dst, int size, int wg_id, + int wg_size) { + for (size_t i = wg_id; i < size; i += wg_size) { + OpWrap::Calc(src, dst, i); + } + __syncthreads(); +} + +template +__device__ void IPCContext::internal_direct_allreduce( + T *dst, const T *src, int nelems, int PE_start, int logPE_stride, + int PE_size, T *pWrk, + long *pSync) { // NOLINT(runtime/int) + + int stride = 1 << logPE_stride; + int finish = PE_start + stride * PE_size; + int pe = my_pe; + + int wg_id = get_flat_block_id(); + int wg_size = get_flat_block_size(); + int64_t flag_val = 1; + + for (int i = wg_id; i < nelems; i += wg_size) { + dst[i] = src[i]; + } + __syncthreads(); + + for (int i = PE_start; i < finish; i += stride) { + if (i != pe) { + putmem_nbi_wg(&pWrk[pe * nelems], reinterpret_cast(src), + nelems * sizeof(T), i); + + if (is_thread_zero_in_block()) { + fence(); + put_nbi(&pSync[pe], &flag_val, 1, i); + } + } + } + threadfence_system(); + __syncthreads(); + + // Do the compute and pSync reset in parallel. + for (int i = PE_start; i < finish; i += stride) { + if (i != pe) { + // Wait for leader thread to see that the buffer is ready. + if (is_thread_zero_in_block()) { + wait_until(&pSync[i], ROC_SHMEM_CMP_EQ, flag_val); + } + __syncthreads(); + + T *ptr = &pWrk[i * nelems]; + compute_reduce(ptr, dst, nelems, wg_id, wg_size); + threadfence_system(); + } + } + + __syncthreads(); + + for (int i = wg_id; i < num_pes; i += wg_size) { + pSync[i] = ROC_SHMEM_SYNC_VALUE; + } + threadfence_system(); + __syncthreads(); +} + template __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest, const T *source, int nreduce) { - //to_all(dest, source, nreduce, pe_start, log_pe_stride, pe_size, pWrk, - // p_sync); + IPCTeam *team_obj = reinterpret_cast(team); + + /** + * Ensure that the stride is a multiple of 2 for GPU_IB. + */ + int log_pe_stride = static_cast(team_obj->tinfo_wrt_world->log_stride); + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_size = team_obj->tinfo_wrt_world->size; + long *p_sync = team_obj->barrier_pSync; + T *pWrk = reinterpret_cast(team_obj->pWrk); + + to_all(dest, source, nreduce, pe_start, log_pe_stride, pe_size, pWrk, + p_sync); } template @@ -164,6 +242,8 @@ __device__ void IPCContext::to_all(T *dest, const T *source, int nreduce, int PE_start, int logPE_stride, int PE_size, T *pWrk, long *pSync) { // NOLINT(runtime/int) + internal_direct_allreduce(dest, source, nreduce, PE_start, logPE_stride, + PE_size, pWrk, pSync); } template diff --git a/projects/rocshmem/src/roc_shmem_calc.hpp b/projects/rocshmem/src/roc_shmem_calc.hpp new file mode 100644 index 0000000000..5420cde3ea --- /dev/null +++ b/projects/rocshmem/src/roc_shmem_calc.hpp @@ -0,0 +1,100 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef LIBRARY_SRC_ROC_SHMEM_CALC_HPP_ +#define LIBRARY_SRC_ROC_SHMEM_CALC_HPP_ + +namespace rocshmem { + +// clang-format off +NOWARN(-Wunused-parameter, +template +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + static_assert(true, "Unimplemented ipc collective."); + } +}; +) +// clang-format on + +/****************************************************************************** + ************************** TEMPLATE SPECIALIZATIONS ************************** + *****************************************************************************/ +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] += src[i]; + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] = max(dst[i], src[i]); + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] = min(dst[i], src[i]); + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] *= src[i]; + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] &= src[i]; + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] |= src[i]; + } +}; + +template <> +struct OpWrap { + template + __device__ static void Calc(T *src, T *dst, int i) { + dst[i] ^= src[i]; + } +}; + +} +#endif // LIBRARY_SRC_ROC_SHMEM_CALC_HPP_