From 258d264ecc239fb709e430dfbc37f48907307290 Mon Sep 17 00:00:00 2001 From: Yiltan Date: Wed, 10 Dec 2025 11:43:15 -0500 Subject: [PATCH] Add default context alltoall API (#350) [ROCm/rocshmem commit: fddbe7b15dbe66c7c5b05138e2d13a660ee21fbc] --- projects/rocshmem/CHANGELOG.md | 4 ++ projects/rocshmem/docs/api/coll.rst | 1 + .../include/rocshmem/rocshmem_COLL.hpp | 38 +++++++++++++++++++ projects/rocshmem/src/rocshmem_gpu.cpp | 35 ++++++++++++----- 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/projects/rocshmem/CHANGELOG.md b/projects/rocshmem/CHANGELOG.md index 360aef5d56..d901cf84b8 100644 --- a/projects/rocshmem/CHANGELOG.md +++ b/projects/rocshmem/CHANGELOG.md @@ -1,4 +1,8 @@ # Changelog for rocSHMEM +## Unreleased - rocSHMEM 3.x.x for ROCm 7.x.x +### Added +* Added new APIs: + * `rocshmem_TYPENAME_alltoall_wg` ## rocSHMEM 3.2.0 for ROCm 7.2.0 ### Added diff --git a/projects/rocshmem/docs/api/coll.rst b/projects/rocshmem/docs/api/coll.rst index 43f4cb4c1a..63709b6a96 100644 --- a/projects/rocshmem/docs/api/coll.rst +++ b/projects/rocshmem/docs/api/coll.rst @@ -85,6 +85,7 @@ These APIs should be called from only one thread/wavefront/workgroup within the ROSHMEM_ALLTOALL ---------------- +.. cpp:function:: __device__ void rocshmem_TYPENAME_alltoall_wg(rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems) .. cpp:function:: __device__ void rocshmem_ctx_TYPENAME_alltoall_wg(rocshmem_ctx_t ctx, rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems) :param team: The team participating in the collective. 
diff --git a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp index 0d13f082bb..d9fbb4291d 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp @@ -95,6 +95,44 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_ulonglong_alltoall_wg( rocshmem_ctx_t ctx, rocshmem_team_t team, unsigned long long *dest, const unsigned long long *source, int nelems); +__device__ ATTR_NO_INLINE void rocshmem_float_alltoall_wg( + rocshmem_team_t team, float *dest, const float *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_double_alltoall_wg( + rocshmem_team_t team, double *dest, const double *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_char_alltoall_wg( + rocshmem_team_t team, char *dest, const char *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_schar_alltoall_wg( + rocshmem_team_t team, signed char *dest, const signed char *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_short_alltoall_wg( + rocshmem_team_t team, short *dest, const short *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_int_alltoall_wg( + rocshmem_team_t team, int *dest, const int *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_long_alltoall_wg( + rocshmem_team_t team, long *dest, const long *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_longlong_alltoall_wg( + rocshmem_team_t team, long long *dest, const long long *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_uchar_alltoall_wg( + rocshmem_team_t team, unsigned char *dest, const unsigned char *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_ushort_alltoall_wg( + rocshmem_team_t team, unsigned short *dest, const unsigned short *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_uint_alltoall_wg( + rocshmem_team_t team, unsigned int *dest, const unsigned int *source, int nelems); 
+ +__device__ ATTR_NO_INLINE void rocshmem_ulong_alltoall_wg( + rocshmem_team_t team, unsigned long *dest, const unsigned long *source, int nelems); + +__device__ ATTR_NO_INLINE void rocshmem_ulonglong_alltoall_wg( + rocshmem_team_t team, unsigned long long *dest, const unsigned long long *source, int nelems); /** * @name SHMEM_BROADCAST diff --git a/projects/rocshmem/src/rocshmem_gpu.cpp b/projects/rocshmem/src/rocshmem_gpu.cpp index 721c164b31..404e68b8be 100644 --- a/projects/rocshmem/src/rocshmem_gpu.cpp +++ b/projects/rocshmem/src/rocshmem_gpu.cpp @@ -534,15 +534,24 @@ __device__ void rocshmem_broadcast_wg(rocshmem_ctx_t ctx, } template -__device__ void rocshmem_alltoall_wg(rocshmem_ctx_t ctx, - rocshmem_team_t team, T *dest, - const T *source, int nelem) { - GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n", - ctx.ctx_opaque, team, dest, source, nelem); +__device__ void rocshmem_ctx_alltoall_wg(rocshmem_ctx_t ctx, + rocshmem_team_t team, T *dest, + const T *source, int nelem) { + GPU_DPRINTF("Function: rocshmem_ctx_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n", + ctx.ctx_opaque, team, dest, source, nelem); get_internal_ctx(ctx)->alltoall(team, dest, source, nelem); } +template +__device__ void rocshmem_alltoall_wg(rocshmem_team_t team, T *dest, + const T *source, int nelem) { + GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n", + ROCSHMEM_CTX_DEFAULT.ctx_opaque, team, dest, source, nelem); + /* No ctx parameter here: log and use ROCSHMEM_CTX_DEFAULT. */ + get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->alltoall(team, dest, source, nelem); +} + template __device__ void rocshmem_fcollect_wg(rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, @@ -666,8 +675,8 @@ __global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team, // Call device alltoall function with created context and provided team // Using char type since size is in bytes (1 byte per element) - rocshmem_alltoall_wg(ctx, team, (char *) dest, - (const char *) 
source, (int) size); + rocshmem_ctx_alltoall_wg(ctx, team, (char *) dest, + (const char *) source, (int) size); if (ctx_result == 0) { rocshmem_wg_ctx_destroy(&ctx); @@ -1219,9 +1228,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team, template __device__ void rocshmem_broadcast_wg( \ rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \ int nelem, int pe_root); \ - template __device__ void rocshmem_alltoall_wg( \ + template __device__ void rocshmem_ctx_alltoall_wg( \ rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \ int nelem); \ + template __device__ void rocshmem_alltoall_wg( \ + rocshmem_team_t team, T * dest, const T *source, \ + int nelem); \ template __device__ void rocshmem_fcollect_wg( \ rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \ int nelem); \ @@ -1536,7 +1548,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team, __device__ void rocshmem_ctx_##TNAME##_alltoall_wg( \ rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \ int nelem) { \ - rocshmem_alltoall_wg(ctx, team, dest, source, nelem); \ + rocshmem_ctx_alltoall_wg(ctx, team, dest, source, nelem); \ + } \ + __device__ void rocshmem_##TNAME##_alltoall_wg( \ + rocshmem_team_t team, T *dest, const T *source, \ + int nelem) { \ + rocshmem_alltoall_wg(team, dest, source, nelem); \ } \ __device__ void rocshmem_ctx_##TNAME##_fcollect_wg( \ rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \