Add default context alltoall API (#350)

[ROCm/rocshmem commit: fddbe7b15d]
Šī revīzija ir iekļauta:
Yiltan
2025-12-10 11:43:15 -05:00
revīziju iesūtīja GitHub
vecāks 972893bab2
revīzija 258d264ecc
4 mainīti faili ar 69 papildinājumiem un 9 dzēšanām
+4
Parādīt failu
@@ -1,4 +1,8 @@
# Changelog for rocSHMEM
## Unreleased - rocSHMEM 3.x.x for ROCm 7.x.x
### Added
* Added new APIs:
* `rocshmem_TYPENAME_alltoall_wg`
## rocSHMEM 3.2.0 for ROCm 7.2.0
### Added
@@ -85,6 +85,7 @@ These APIs should be called from only one thread/wavefront/workgroup within the
ROSHMEM_ALLTOALL
----------------
.. cpp:function:: __device__ void rocshmem_TYPENAME_alltoall_wg(rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems)
.. cpp:function:: __device__ void rocshmem_ctx_TYPENAME_alltoall_wg(rocshmem_ctx_t ctx, rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems)
:param team: The team participating in the collective.
@@ -95,6 +95,44 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_ulonglong_alltoall_wg(
rocshmem_ctx_t ctx, rocshmem_team_t team, unsigned long long *dest,
const unsigned long long *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_float_alltoall_wg(
rocshmem_team_t team, float *dest, const float *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_double_alltoall_wg(
rocshmem_team_t team, double *dest, const double *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_char_alltoall_wg(
rocshmem_team_t team, char *dest, const char *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_schar_alltoall_wg(
rocshmem_team_t team, signed char *dest, const signed char *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_short_alltoall_wg(
rocshmem_team_t team, short *dest, const short *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_int_alltoall_wg(
rocshmem_team_t team, int *dest, const int *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_long_alltoall_wg(
rocshmem_team_t team, long *dest, const long *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_longlong_alltoall_wg(
rocshmem_team_t team, long long *dest, const long long *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_uchar_alltoall_wg(
rocshmem_team_t team, unsigned char *dest, const unsigned char *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_ushort_alltoall_wg(
rocshmem_team_t team, unsigned short *dest, const unsigned short *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_uint_alltoall_wg(
rocshmem_team_t team, unsigned int *dest, const unsigned int *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_ulong_alltoall_wg(
rocshmem_team_t team, unsigned long *dest, const unsigned long *source, int nelems);
__device__ ATTR_NO_INLINE void rocshmem_ulonglong_alltoall_wg(
rocshmem_team_t team, unsigned long long *dest, const unsigned long long *source, int nelems);
/**
* @name SHMEM_BROADCAST
+26 -9
Parādīt failu
@@ -534,15 +534,24 @@ __device__ void rocshmem_broadcast_wg(rocshmem_ctx_t ctx,
}
template <typename T>
__device__ void rocshmem_alltoall_wg(rocshmem_ctx_t ctx,
rocshmem_team_t team, T *dest,
const T *source, int nelem) {
GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
ctx.ctx_opaque, team, dest, source, nelem);
__device__ void rocshmem_ctx_alltoall_wg(rocshmem_ctx_t ctx,
rocshmem_team_t team, T *dest,
const T *source, int nelem) {
GPU_DPRINTF("Function: rocshmem_ctx_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
ctx.ctx_opaque, team, dest, source, nelem);
get_internal_ctx(ctx)->alltoall<T>(team, dest, source, nelem);
}
template <typename T>
__device__ void rocshmem_alltoall_wg(rocshmem_team_t team, T *dest,
const T *source, int nelem) {
GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
ctx.ctx_opaque, team, dest, source, nelem);
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->alltoall<T>(team, dest, source, nelem);
}
template <typename T>
__device__ void rocshmem_fcollect_wg(rocshmem_ctx_t ctx,
rocshmem_team_t team, T *dest,
@@ -666,8 +675,8 @@ __global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
// Call device alltoall function with created context and provided team
// Using char type since size is in bytes (1 byte per element)
rocshmem_alltoall_wg<char>(ctx, team, (char *) dest,
(const char *) source, (int) size);
rocshmem_ctx_alltoall_wg<char>(ctx, team, (char *) dest,
(const char *) source, (int) size);
if (ctx_result == 0) {
rocshmem_wg_ctx_destroy(&ctx);
@@ -1219,9 +1228,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
template __device__ void rocshmem_broadcast_wg<T>( \
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
int nelem, int pe_root); \
template __device__ void rocshmem_alltoall_wg<T>( \
template __device__ void rocshmem_ctx_alltoall_wg<T>( \
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
int nelem); \
template __device__ void rocshmem_alltoall_wg<T>( \
rocshmem_team_t team, T * dest, const T *source, \
int nelem); \
template __device__ void rocshmem_fcollect_wg<T>( \
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
int nelem); \
@@ -1536,7 +1548,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
__device__ void rocshmem_ctx_##TNAME##_alltoall_wg( \
rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \
int nelem) { \
rocshmem_alltoall_wg<T>(ctx, team, dest, source, nelem); \
rocshmem_ctx_alltoall_wg<T>(ctx, team, dest, source, nelem); \
} \
__device__ void rocshmem_##TNAME##_alltoall_wg( \
rocshmem_team_t team, T *dest, const T *source, \
int nelem) { \
rocshmem_alltoall_wg<T>(team, dest, source, nelem); \
} \
__device__ void rocshmem_ctx_##TNAME##_fcollect_wg( \
rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \