Add default context alltoall API (#350)
[ROCm/rocshmem commit: fddbe7b15d]
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
# Changelog for rocSHMEM
|
||||
## Unreleased - rocSHMEM 3.x.x for ROCm 7.x.x
|
||||
### Added
|
||||
* Added new APIs:
|
||||
* `rocshmem_TYPENAME_alltoall_wg`
|
||||
|
||||
## rocSHMEM 3.2.0 for ROCm 7.2.0
|
||||
### Added
|
||||
|
||||
@@ -85,6 +85,7 @@ These APIs should be called from only one thread/wavefront/workgroup within the
|
||||
ROSHMEM_ALLTOALL
|
||||
----------------
|
||||
|
||||
.. cpp:function:: __device__ void rocshmem_TYPENAME_alltoall_wg(rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems)
|
||||
.. cpp:function:: __device__ void rocshmem_ctx_TYPENAME_alltoall_wg(rocshmem_ctx_t ctx, rocshmem_team_t team, TYPE *dest, const TYPE *source, int nelems)
|
||||
|
||||
:param team: The team participating in the collective.
|
||||
|
||||
@@ -95,6 +95,44 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_ulonglong_alltoall_wg(
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, unsigned long long *dest,
|
||||
const unsigned long long *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_float_alltoall_wg(
|
||||
rocshmem_team_t team, float *dest, const float *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_double_alltoall_wg(
|
||||
rocshmem_team_t team, double *dest, const double *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_char_alltoall_wg(
|
||||
rocshmem_team_t team, char *dest, const char *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_schar_alltoall_wg(
|
||||
rocshmem_team_t team, signed char *dest, const signed char *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_short_alltoall_wg(
|
||||
rocshmem_team_t team, short *dest, const short *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_int_alltoall_wg(
|
||||
rocshmem_team_t team, int *dest, const int *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_long_alltoall_wg(
|
||||
rocshmem_team_t team, long *dest, const long *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_longlong_alltoall_wg(
|
||||
rocshmem_team_t team, long long *dest, const long long *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_uchar_alltoall_wg(
|
||||
rocshmem_team_t team, unsigned char *dest, const unsigned char *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ushort_alltoall_wg(
|
||||
rocshmem_team_t team, unsigned short *dest, const unsigned short *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_uint_alltoall_wg(
|
||||
rocshmem_team_t team, unsigned int *dest, const unsigned int *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ulong_alltoall_wg(
|
||||
rocshmem_team_t team, unsigned long *dest, const unsigned long *source, int nelems);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ulonglong_alltoall_wg(
|
||||
rocshmem_team_t team, unsigned long long *dest, const unsigned long long *source, int nelems);
|
||||
|
||||
/**
|
||||
* @name SHMEM_BROADCAST
|
||||
|
||||
@@ -534,15 +534,24 @@ __device__ void rocshmem_broadcast_wg(rocshmem_ctx_t ctx,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rocshmem_alltoall_wg(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelem) {
|
||||
GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
|
||||
ctx.ctx_opaque, team, dest, source, nelem);
|
||||
__device__ void rocshmem_ctx_alltoall_wg(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelem) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
|
||||
ctx.ctx_opaque, team, dest, source, nelem);
|
||||
|
||||
get_internal_ctx(ctx)->alltoall<T>(team, dest, source, nelem);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rocshmem_alltoall_wg(rocshmem_team_t team, T *dest,
|
||||
const T *source, int nelem) {
|
||||
GPU_DPRINTF("Function: rocshmem_alltoall_wg (ctx=%zd, team=%zd, dest=%p, source=%p, nelem=%d\n",
|
||||
ctx.ctx_opaque, team, dest, source, nelem);
|
||||
|
||||
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->alltoall<T>(team, dest, source, nelem);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rocshmem_fcollect_wg(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team, T *dest,
|
||||
@@ -666,8 +675,8 @@ __global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team,
|
||||
|
||||
// Call device alltoall function with created context and provided team
|
||||
// Using char type since size is in bytes (1 byte per element)
|
||||
rocshmem_alltoall_wg<char>(ctx, team, (char *) dest,
|
||||
(const char *) source, (int) size);
|
||||
rocshmem_ctx_alltoall_wg<char>(ctx, team, (char *) dest,
|
||||
(const char *) source, (int) size);
|
||||
|
||||
if (ctx_result == 0) {
|
||||
rocshmem_wg_ctx_destroy(&ctx);
|
||||
@@ -1219,9 +1228,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
|
||||
template __device__ void rocshmem_broadcast_wg<T>( \
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
|
||||
int nelem, int pe_root); \
|
||||
template __device__ void rocshmem_alltoall_wg<T>( \
|
||||
template __device__ void rocshmem_ctx_alltoall_wg<T>( \
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
|
||||
int nelem); \
|
||||
template __device__ void rocshmem_alltoall_wg<T>( \
|
||||
rocshmem_team_t team, T * dest, const T *source, \
|
||||
int nelem); \
|
||||
template __device__ void rocshmem_fcollect_wg<T>( \
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, T * dest, const T *source, \
|
||||
int nelem); \
|
||||
@@ -1536,7 +1548,12 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
|
||||
__device__ void rocshmem_ctx_##TNAME##_alltoall_wg( \
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \
|
||||
int nelem) { \
|
||||
rocshmem_alltoall_wg<T>(ctx, team, dest, source, nelem); \
|
||||
rocshmem_ctx_alltoall_wg<T>(ctx, team, dest, source, nelem); \
|
||||
} \
|
||||
__device__ void rocshmem_##TNAME##_alltoall_wg( \
|
||||
rocshmem_team_t team, T *dest, const T *source, \
|
||||
int nelem) { \
|
||||
rocshmem_alltoall_wg<T>(team, dest, source, nelem); \
|
||||
} \
|
||||
__device__ void rocshmem_ctx_##TNAME##_fcollect_wg( \
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, T *dest, const T *source, \
|
||||
|
||||
Reference in New Issue
Block a user