Merge pull request #45 from Yiltan/to_all_reduce

Fixed Function Signature for `to_all` APIs

[ROCm/rocshmem commit: 958575d8a4]
Этот коммит содержится в:
Yiltan
2024-11-06 10:54:36 -05:00
коммит произвёл GitHub
родитель b43ef2b45b 8df27a93be
Коммит bae2b2aece
15 изменённых файлов: 57 добавлений и 59 удалений
+5 -5
Просмотреть файл
@@ -1,8 +1,8 @@
/*
** hipcc -c -fgpu-rdc -x hip rocshmem_allreduce_test.cc -I/opt/rocm/include
** hipcc -c -fgpu-rdc -x hip rocshmem_allreduce_test.cc -I/opt/rocm/include
** -I$ROCHSMEM_INSTALL_DIR/include -I$OPENMPI_UCX_INSTALL_DIR/include/
** hipcc -fgpu-rdc --hip-link rocshmem_allreduce_test.o -o rocshmem_allreduce_test
** $ROCHSMEM_INSTALL_DIR/lib/librocshmem.a $OPENMPI_UCX_INSTALL_DIR/lib/libmpi.so
** hipcc -fgpu-rdc --hip-link rocshmem_allreduce_test.o -o rocshmem_allreduce_test
** $ROCHSMEM_INSTALL_DIR/lib/librocshmem.a $OPENMPI_UCX_INSTALL_DIR/lib/libmpi.so
** -L/opt/rocm/lib -lamdhip64 -lhsa-runtime64
**
** ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 8 ./rocshmem_allreduce_test
@@ -34,7 +34,7 @@ __global__ void allreduce_test(int *source, int *dest, size_t nelem,
roc_shmem_wg_ctx_create(ctx_type, &ctx);
int num_pes = roc_shmem_ctx_n_pes(ctx);
roc_shmem_ctx_int_sum_wg_to_all(ctx, team, dest, source, nelem);
roc_shmem_ctx_int_sum_wg_reduce(ctx, team, dest, source, nelem);
roc_shmem_ctx_quiet(ctx);
__syncthreads();
@@ -114,7 +114,7 @@ int main (int argc, char **argv)
bool pass = check_recvbuf(dest, nelem, my_pe, npes);
printf("Test %s \t nelem %d %s\n", argv[0], nelem, pass ? "[PASS]" : "[FAIL]");
roc_shmem_free(source);
roc_shmem_free(dest);
+8 -6
Просмотреть файл
@@ -49,6 +49,12 @@ namespace rocshmem {
#define ATTR_NO_INLINE
#endif
enum ROC_SHMEM_STATUS {
ROC_SHMEM_SUCCESS = 0,
ROC_SHMEM_ERROR = 1,
};
enum ROC_SHMEM_OP {
ROC_SHMEM_SUM,
ROC_SHMEM_MAX,
@@ -837,14 +843,10 @@ __device__ ATTR_NO_INLINE void roc_shmem_threadfence_system();
* MACRO DECLARE SHMEM_REDUCTION APIs
*/
#define REDUCTION_API_GEN(T, TNAME, Op_API) \
__device__ ATTR_NO_INLINE void roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all( \
__device__ ATTR_NO_INLINE int roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
int nreduce); \
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
roc_shmem_ctx_t ctx, T *dest, const T *source, int nreduce, \
int PE_start, int logPE_stride, int PE_size, T *pWrk, \
long *pSync); /* NOLINT */ \
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
__host__ int roc_shmem_ctx_##TNAME##_##Op_API##_reduce( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
int nreduce);
+1 -1
Просмотреть файл
@@ -27,4 +27,4 @@ cmake \
-DUSE_HOST_SIDE_HDP_FLUSH=OFF\
$src_path
cmake --build . --parallel 8
cmake --install .
cmake --install .
+2 -4
Просмотреть файл
@@ -192,8 +192,7 @@ class Context {
long* pSync); // NOLINT(runtime/int)
template <typename T, ROC_SHMEM_OP Op>
__device__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
int nreduce);
__device__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
template <typename T>
__device__ void put(T* dest, const T* source, size_t nelems, int pe);
@@ -361,8 +360,7 @@ class Context {
long* pSync); // NOLINT(runtime/int)
template <typename T, ROC_SHMEM_OP Op>
__host__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
int nreduce);
__host__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
template <typename T>
__host__ void wait_until(T *ivars, int cmp, T val);
+3 -3
Просмотреть файл
@@ -80,17 +80,17 @@ __device__ void Context::to_all(T *dest, const T *source, int nreduce,
}
template <typename T, ROC_SHMEM_OP Op>
__device__ void Context::to_all(roc_shmem_team_t team, T *dest, const T *source,
__device__ int Context::reduce(roc_shmem_team_t team, T *dest, const T *source,
int nreduce) {
if (nreduce == 0) {
return;
return ROC_SHMEM_SUCCESS;
}
if (is_thread_zero_in_block()) {
ctxStats.incStat(NUM_TO_ALL);
}
DISPATCH(to_all<PAIR(T, Op)>(team, dest, source, nreduce));
DISPATCH_RET(reduce<PAIR(T, Op)>(team, dest, source, nreduce));
}
template <typename T>
+4 -4
Просмотреть файл
@@ -222,15 +222,15 @@ __host__ void Context::to_all(T *dest, const T *source, int nreduce,
}
template <typename T, ROC_SHMEM_OP Op>
__host__ void Context::to_all(roc_shmem_team_t team, T *dest, const T *source,
int nreduce) { // NOLINT(runtime/int)
__host__ int Context::reduce(roc_shmem_team_t team, T *dest, const T *source,
int nreduce) { // NOLINT(runtime/int)
if (nreduce == 0) {
return;
return ROC_SHMEM_SUCCESS;
}
ctxHostStats.incStat(NUM_HOST_TO_ALL);
HOST_DISPATCH(to_all<PAIR(T, Op)>(team, dest, source, nreduce));
HOST_DISPATCH_RET(reduce<PAIR(T, Op)>(team, dest, source, nreduce));
}
template <typename T>
+1 -2
Просмотреть файл
@@ -207,8 +207,7 @@ class HostInterface {
long* p_sync); // NOLINT(runtime/int)
template <typename T, ROC_SHMEM_OP Op>
__host__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
int nreduce);
__host__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
template <typename T>
__host__ void wait_until(T *ivars, int cmp, T val,
+3 -3
Просмотреть файл
@@ -376,9 +376,9 @@ __host__ void HostInterface::to_all(T* dest, const T* source, int nreduce,
}
template <typename T, ROC_SHMEM_OP Op>
__host__ void HostInterface::to_all(roc_shmem_team_t team, T* dest,
__host__ int HostInterface::reduce(roc_shmem_team_t team, T* dest,
const T* source, int nreduce) {
DPRINTF("Function: Team-based host_to_all\n");
DPRINTF("Function: Team-based host_reduce\n");
/*
* Get the MPI communicator of this team
@@ -388,7 +388,7 @@ __host__ void HostInterface::to_all(roc_shmem_team_t team, T* dest,
to_all_internal<T, Op>(mpi_comm, dest, source, nreduce);
return;
return ROC_SHMEM_SUCCESS;
}
template <typename T>
+1 -2
Просмотреть файл
@@ -122,8 +122,7 @@ class IPCContext : public Context {
// Collectives
template <typename T, ROC_SHMEM_OP Op>
__device__ void to_all(roc_shmem_team_t team, T *dest, const T *source,
int nreduce);
__device__ int reduce(roc_shmem_team_t team, T *dest, const T *source, int nreduce);
template <typename T>
__device__ void broadcast(roc_shmem_team_t team, T *dest, const T *source,
+1 -2
Просмотреть файл
@@ -95,8 +95,7 @@ class IPCHostContext : public Context {
long *p_sync);
template <typename T, ROC_SHMEM_OP Op>
__host__ void to_all(roc_shmem_team_t team, T *dest, const T *source,
int nreduce);
__host__ int reduce(roc_shmem_team_t team, T *dest, const T *source, int nreduce);
template <typename T>
__host__ void wait_until(T *ivars, int cmp, T val);
+4 -3
Просмотреть файл
@@ -151,7 +151,7 @@ __device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) {
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), cond,
value);
}
// Collectives
template <typename T, ROC_SHMEM_OP Op>
__device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
@@ -346,8 +346,8 @@ __device__ void IPCContext::internal_ring_allreduce(
}
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
const T *source, int nreduce) {
__device__ int IPCContext::reduce(roc_shmem_team_t team, T *dest,
const T *source, int nreduce) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
@@ -361,6 +361,7 @@ __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
internal_to_all<T, Op>(dest, source, nreduce, pe_start, stride, pe_size, pWrk,
p_sync);
return ROC_SHMEM_SUCCESS;
}
template <typename T, ROC_SHMEM_OP Op>
+2 -2
Просмотреть файл
@@ -109,9 +109,9 @@ __host__ void IPCHostContext::to_all(T *dest, const T *source, int nreduce,
}
template <typename T, ROC_SHMEM_OP Op>
__host__ void IPCHostContext::to_all(roc_shmem_team_t team, T *dest,
__host__ int IPCHostContext::reduce(roc_shmem_team_t team, T *dest,
const T *source, int nreduce) {
host_interface->to_all<T, Op>(team, dest, source, nreduce);
return host_interface->reduce<T, Op>(team, dest, source, nreduce);
}
template <typename T>
+7 -7
Просмотреть файл
@@ -734,13 +734,13 @@ __host__ void roc_shmem_to_all([[maybe_unused]] roc_shmem_ctx_t ctx, T *dest,
}
template <typename T, ROC_SHMEM_OP Op>
__host__ void roc_shmem_to_all([[maybe_unused]] roc_shmem_ctx_t ctx,
__host__ int roc_shmem_reduce([[maybe_unused]] roc_shmem_ctx_t ctx,
roc_shmem_team_t team, T *dest, const T *source,
int nreduce) {
DPRINTF("Host function: Team-based roc_shmem_to_all\n");
DPRINTF("Host function: Team-based roc_shmem_reduce\n");
get_internal_ctx(ROC_SHMEM_HOST_CTX_DEFAULT)
->to_all<T, Op>(team, dest, source, nreduce);
return get_internal_ctx(ROC_SHMEM_HOST_CTX_DEFAULT)
->reduce<T, Op>(team, dest, source, nreduce);
}
template <typename T>
@@ -821,7 +821,7 @@ __host__ int roc_shmem_test(T *ivars, int cmp, T val) {
template __host__ void roc_shmem_to_all<T, Op>( \
roc_shmem_ctx_t ctx, T * dest, const T *source, int nreduce, \
int PE_start, int logPE_stride, int PE_size, T *pWrk, long *pSync); \
template __host__ void roc_shmem_to_all<T, Op>( \
template __host__ int roc_shmem_reduce<T, Op>( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T * dest, const T *source, \
int nreduce);
@@ -977,10 +977,10 @@ __host__ int roc_shmem_test(T *ivars, int cmp, T val) {
roc_shmem_to_all<T, Op>(ctx, dest, source, nreduce, PE_start, \
logPE_stride, PE_size, pWrk, pSync); \
} \
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
__host__ int roc_shmem_ctx_##TNAME##_##Op_API##_reduce( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
int nreduce) { \
roc_shmem_to_all<T, Op>(ctx, team, dest, source, nreduce); \
return roc_shmem_reduce<T, Op>(ctx, team, dest, source, nreduce); \
}
#define ARITH_REDUCTION_DEF_GEN(T, TNAME) \
+7 -7
Просмотреть файл
@@ -430,11 +430,11 @@ __device__ void *roc_shmem_ptr(const void *dest, int pe) {
}
template <typename T, ROC_SHMEM_OP Op>
__device__ void roc_shmem_wg_to_all(roc_shmem_ctx_t ctx, roc_shmem_team_t team,
T *dest, const T *source, int nreduce) {
GPU_DPRINTF("Function: roc_shmem_to_all\n");
__device__ int roc_shmem_wg_reduce(roc_shmem_ctx_t ctx, roc_shmem_team_t team,
T *dest, const T *source, int nreduce) {
GPU_DPRINTF("Function: roc_shmem_reduce\n");
get_internal_ctx(ctx)->to_all<T, Op>(team, dest, source, nreduce);
return get_internal_ctx(ctx)->reduce<T, Op>(team, dest, source, nreduce);
}
template <typename T>
@@ -864,7 +864,7 @@ __device__ int roc_shmem_team_translate_pe(roc_shmem_team_t src_team,
* Template generator for reductions
*/
#define REDUCTION_GEN(T, Op) \
template __device__ void roc_shmem_wg_to_all<T, Op>( \
template __device__ int roc_shmem_wg_reduce<T, Op>( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T * dest, const T *source, \
int nreduce);
@@ -1072,10 +1072,10 @@ __device__ int roc_shmem_team_translate_pe(roc_shmem_team_t src_team,
**/
#define REDUCTION_DEF_GEN(T, TNAME, Op_API, Op) \
__device__ void roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all( \
__device__ int roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce( \
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
int nreduce) { \
roc_shmem_wg_to_all<T, Op>(ctx, team, dest, source, nreduce); \
return roc_shmem_wg_reduce<T, Op>(ctx, team, dest, source, nreduce); \
}
#define ARITH_REDUCTION_DEF_GEN(T, TNAME) \
+8 -8
Просмотреть файл
@@ -24,19 +24,19 @@ using namespace rocshmem;
/* Declare the template with a generic implementation */
template <typename T, ROC_SHMEM_OP Op>
__device__ void wg_team_to_all(roc_shmem_ctx_t ctx, roc_shmem_team_t, T *dest,
__device__ int wg_team_reduce(roc_shmem_ctx_t ctx, roc_shmem_team_t, T *dest,
const T *source, int nreduce) {
return;
return ROC_SHMEM_SUCCESS;
}
/* Define templates to call ROC_SHMEM */
#define TEAM_REDUCTION_DEF_GEN(T, TNAME, Op_API, Op) \
template <> \
__device__ void wg_team_to_all<T, Op>(roc_shmem_ctx_t ctx, \
roc_shmem_team_t team, T * dest, \
const T *source, int nreduce) { \
roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all(ctx, team, dest, source, \
nreduce); \
__device__ int wg_team_reduce<T, Op>(roc_shmem_ctx_t ctx, \
roc_shmem_team_t team, T * dest, \
const T *source, int nreduce) { \
return roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce(ctx, team, dest, \
source, nreduce); \
}
#define TEAM_ARITH_REDUCTION_DEF_GEN(T, TNAME) \
@@ -91,7 +91,7 @@ __global__ void TeamReductionTest(int loop, int skip, uint64_t *timer,
if (i == skip && hipThreadIdx_x == 0) {
start = roc_shmem_timer();
}
wg_team_to_all<T1, T2>(ctx, team, r_buf, s_buf, size);
wg_team_reduce<T1, T2>(ctx, team, r_buf, s_buf, size);
roc_shmem_ctx_wg_barrier_all(ctx);
}