Merge pull request #45 from Yiltan/to_all_reduce
Fixed Function Signature for `to_all` APIs
[ROCm/rocshmem commit: 958575d8a4]
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
** hipcc -c -fgpu-rdc -x hip rocshmem_allreduce_test.cc -I/opt/rocm/include
|
||||
** hipcc -c -fgpu-rdc -x hip rocshmem_allreduce_test.cc -I/opt/rocm/include
|
||||
** -I$ROCSHMEM_INSTALL_DIR/include -I$OPENMPI_UCX_INSTALL_DIR/include/
|
||||
** hipcc -fgpu-rdc --hip-link rocshmem_allreduce_test.o -o rocshmem_allreduce_test
|
||||
** $ROCSHMEM_INSTALL_DIR/lib/librocshmem.a $OPENMPI_UCX_INSTALL_DIR/lib/libmpi.so
|
||||
** hipcc -fgpu-rdc --hip-link rocshmem_allreduce_test.o -o rocshmem_allreduce_test
|
||||
** $ROCSHMEM_INSTALL_DIR/lib/librocshmem.a $OPENMPI_UCX_INSTALL_DIR/lib/libmpi.so
|
||||
** -L/opt/rocm/lib -lamdhip64 -lhsa-runtime64
|
||||
**
|
||||
** ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 8 ./rocshmem_allreduce_test
|
||||
@@ -34,7 +34,7 @@ __global__ void allreduce_test(int *source, int *dest, size_t nelem,
|
||||
roc_shmem_wg_ctx_create(ctx_type, &ctx);
|
||||
int num_pes = roc_shmem_ctx_n_pes(ctx);
|
||||
|
||||
roc_shmem_ctx_int_sum_wg_to_all(ctx, team, dest, source, nelem);
|
||||
roc_shmem_ctx_int_sum_wg_reduce(ctx, team, dest, source, nelem);
|
||||
|
||||
roc_shmem_ctx_quiet(ctx);
|
||||
__syncthreads();
|
||||
@@ -114,7 +114,7 @@ int main (int argc, char **argv)
|
||||
|
||||
bool pass = check_recvbuf(dest, nelem, my_pe, npes);
|
||||
printf("Test %s \t nelem %d %s\n", argv[0], nelem, pass ? "[PASS]" : "[FAIL]");
|
||||
|
||||
|
||||
roc_shmem_free(source);
|
||||
roc_shmem_free(dest);
|
||||
|
||||
|
||||
@@ -49,6 +49,12 @@ namespace rocshmem {
|
||||
#define ATTR_NO_INLINE
|
||||
#endif
|
||||
|
||||
|
||||
enum ROC_SHMEM_STATUS {
|
||||
ROC_SHMEM_SUCCESS = 0,
|
||||
ROC_SHMEM_ERROR = 1,
|
||||
};
|
||||
|
||||
enum ROC_SHMEM_OP {
|
||||
ROC_SHMEM_SUM,
|
||||
ROC_SHMEM_MAX,
|
||||
@@ -837,14 +843,10 @@ __device__ ATTR_NO_INLINE void roc_shmem_threadfence_system();
|
||||
* MACRO DECLARE SHMEM_REDUCTION APIs
|
||||
*/
|
||||
#define REDUCTION_API_GEN(T, TNAME, Op_API) \
|
||||
__device__ ATTR_NO_INLINE void roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all( \
|
||||
__device__ ATTR_NO_INLINE int roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
|
||||
int nreduce); \
|
||||
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
|
||||
roc_shmem_ctx_t ctx, T *dest, const T *source, int nreduce, \
|
||||
int PE_start, int logPE_stride, int PE_size, T *pWrk, \
|
||||
long *pSync); /* NOLINT */ \
|
||||
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
|
||||
__host__ int roc_shmem_ctx_##TNAME##_##Op_API##_reduce( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
|
||||
int nreduce);
|
||||
|
||||
|
||||
@@ -27,4 +27,4 @@ cmake \
|
||||
-DUSE_HOST_SIDE_HDP_FLUSH=OFF\
|
||||
$src_path
|
||||
cmake --build . --parallel 8
|
||||
cmake --install .
|
||||
cmake --install .
|
||||
|
||||
@@ -192,8 +192,7 @@ class Context {
|
||||
long* pSync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
|
||||
int nreduce);
|
||||
__device__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__device__ void put(T* dest, const T* source, size_t nelems, int pe);
|
||||
@@ -361,8 +360,7 @@ class Context {
|
||||
long* pSync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
|
||||
int nreduce);
|
||||
__host__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__host__ void wait_until(T *ivars, int cmp, T val);
|
||||
|
||||
@@ -80,17 +80,17 @@ __device__ void Context::to_all(T *dest, const T *source, int nreduce,
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void Context::to_all(roc_shmem_team_t team, T *dest, const T *source,
|
||||
__device__ int Context::reduce(roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce) {
|
||||
if (nreduce == 0) {
|
||||
return;
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
ctxStats.incStat(NUM_TO_ALL);
|
||||
}
|
||||
|
||||
DISPATCH(to_all<PAIR(T, Op)>(team, dest, source, nreduce));
|
||||
DISPATCH_RET(reduce<PAIR(T, Op)>(team, dest, source, nreduce));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -222,15 +222,15 @@ __host__ void Context::to_all(T *dest, const T *source, int nreduce,
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void Context::to_all(roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce) { // NOLINT(runtime/int)
|
||||
__host__ int Context::reduce(roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce) { // NOLINT(runtime/int)
|
||||
if (nreduce == 0) {
|
||||
return;
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
ctxHostStats.incStat(NUM_HOST_TO_ALL);
|
||||
|
||||
HOST_DISPATCH(to_all<PAIR(T, Op)>(team, dest, source, nreduce));
|
||||
HOST_DISPATCH_RET(reduce<PAIR(T, Op)>(team, dest, source, nreduce));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -207,8 +207,7 @@ class HostInterface {
|
||||
long* p_sync); // NOLINT(runtime/int)
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void to_all(roc_shmem_team_t team, T* dest, const T* source,
|
||||
int nreduce);
|
||||
__host__ int reduce(roc_shmem_team_t team, T* dest, const T* source, int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__host__ void wait_until(T *ivars, int cmp, T val,
|
||||
|
||||
@@ -376,9 +376,9 @@ __host__ void HostInterface::to_all(T* dest, const T* source, int nreduce,
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void HostInterface::to_all(roc_shmem_team_t team, T* dest,
|
||||
__host__ int HostInterface::reduce(roc_shmem_team_t team, T* dest,
|
||||
const T* source, int nreduce) {
|
||||
DPRINTF("Function: Team-based host_to_all\n");
|
||||
DPRINTF("Function: Team-based host_reduce\n");
|
||||
|
||||
/*
|
||||
* Get the MPI communicator of this team
|
||||
@@ -388,7 +388,7 @@ __host__ void HostInterface::to_all(roc_shmem_team_t team, T* dest,
|
||||
|
||||
to_all_internal<T, Op>(mpi_comm, dest, source, nreduce);
|
||||
|
||||
return;
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -122,8 +122,7 @@ class IPCContext : public Context {
|
||||
|
||||
// Collectives
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void to_all(roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce);
|
||||
__device__ int reduce(roc_shmem_team_t team, T *dest, const T *source, int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__device__ void broadcast(roc_shmem_team_t team, T *dest, const T *source,
|
||||
|
||||
@@ -95,8 +95,7 @@ class IPCHostContext : public Context {
|
||||
long *p_sync);
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void to_all(roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce);
|
||||
__host__ int reduce(roc_shmem_team_t team, T *dest, const T *source, int nreduce);
|
||||
|
||||
template <typename T>
|
||||
__host__ void wait_until(T *ivars, int cmp, T val);
|
||||
|
||||
@@ -151,7 +151,7 @@ __device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) {
|
||||
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), cond,
|
||||
value);
|
||||
}
|
||||
|
||||
|
||||
// Collectives
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void compute_reduce(T *src, T *dst, int size, int wg_id,
|
||||
@@ -346,8 +346,8 @@ __device__ void IPCContext::internal_ring_allreduce(
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
__device__ int IPCContext::reduce(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
@@ -361,6 +361,7 @@ __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
|
||||
internal_to_all<T, Op>(dest, source, nreduce, pe_start, stride, pe_size, pWrk,
|
||||
p_sync);
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
|
||||
@@ -109,9 +109,9 @@ __host__ void IPCHostContext::to_all(T *dest, const T *source, int nreduce,
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void IPCHostContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
__host__ int IPCHostContext::reduce(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
host_interface->to_all<T, Op>(team, dest, source, nreduce);
|
||||
return host_interface->reduce<T, Op>(team, dest, source, nreduce);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -734,13 +734,13 @@ __host__ void roc_shmem_to_all([[maybe_unused]] roc_shmem_ctx_t ctx, T *dest,
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__host__ void roc_shmem_to_all([[maybe_unused]] roc_shmem_ctx_t ctx,
|
||||
__host__ int roc_shmem_reduce([[maybe_unused]] roc_shmem_ctx_t ctx,
|
||||
roc_shmem_team_t team, T *dest, const T *source,
|
||||
int nreduce) {
|
||||
DPRINTF("Host function: Team-based roc_shmem_to_all\n");
|
||||
DPRINTF("Host function: Team-based roc_shmem_reduce\n");
|
||||
|
||||
get_internal_ctx(ROC_SHMEM_HOST_CTX_DEFAULT)
|
||||
->to_all<T, Op>(team, dest, source, nreduce);
|
||||
return get_internal_ctx(ROC_SHMEM_HOST_CTX_DEFAULT)
|
||||
->reduce<T, Op>(team, dest, source, nreduce);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -821,7 +821,7 @@ __host__ int roc_shmem_test(T *ivars, int cmp, T val) {
|
||||
template __host__ void roc_shmem_to_all<T, Op>( \
|
||||
roc_shmem_ctx_t ctx, T * dest, const T *source, int nreduce, \
|
||||
int PE_start, int logPE_stride, int PE_size, T *pWrk, long *pSync); \
|
||||
template __host__ void roc_shmem_to_all<T, Op>( \
|
||||
template __host__ int roc_shmem_reduce<T, Op>( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T * dest, const T *source, \
|
||||
int nreduce);
|
||||
|
||||
@@ -977,10 +977,10 @@ __host__ int roc_shmem_test(T *ivars, int cmp, T val) {
|
||||
roc_shmem_to_all<T, Op>(ctx, dest, source, nreduce, PE_start, \
|
||||
logPE_stride, PE_size, pWrk, pSync); \
|
||||
} \
|
||||
__host__ void roc_shmem_ctx_##TNAME##_##Op_API##_to_all( \
|
||||
__host__ int roc_shmem_ctx_##TNAME##_##Op_API##_reduce( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
|
||||
int nreduce) { \
|
||||
roc_shmem_to_all<T, Op>(ctx, team, dest, source, nreduce); \
|
||||
return roc_shmem_reduce<T, Op>(ctx, team, dest, source, nreduce); \
|
||||
}
|
||||
|
||||
#define ARITH_REDUCTION_DEF_GEN(T, TNAME) \
|
||||
|
||||
@@ -430,11 +430,11 @@ __device__ void *roc_shmem_ptr(const void *dest, int pe) {
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void roc_shmem_wg_to_all(roc_shmem_ctx_t ctx, roc_shmem_team_t team,
|
||||
T *dest, const T *source, int nreduce) {
|
||||
GPU_DPRINTF("Function: roc_shmem_to_all\n");
|
||||
__device__ int roc_shmem_wg_reduce(roc_shmem_ctx_t ctx, roc_shmem_team_t team,
|
||||
T *dest, const T *source, int nreduce) {
|
||||
GPU_DPRINTF("Function: roc_shmem_reduce\n");
|
||||
|
||||
get_internal_ctx(ctx)->to_all<T, Op>(team, dest, source, nreduce);
|
||||
return get_internal_ctx(ctx)->reduce<T, Op>(team, dest, source, nreduce);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -864,7 +864,7 @@ __device__ int roc_shmem_team_translate_pe(roc_shmem_team_t src_team,
|
||||
* Template generator for reductions
|
||||
*/
|
||||
#define REDUCTION_GEN(T, Op) \
|
||||
template __device__ void roc_shmem_wg_to_all<T, Op>( \
|
||||
template __device__ int roc_shmem_wg_reduce<T, Op>( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T * dest, const T *source, \
|
||||
int nreduce);
|
||||
|
||||
@@ -1072,10 +1072,10 @@ __device__ int roc_shmem_team_translate_pe(roc_shmem_team_t src_team,
|
||||
**/
|
||||
|
||||
#define REDUCTION_DEF_GEN(T, TNAME, Op_API, Op) \
|
||||
__device__ void roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all( \
|
||||
__device__ int roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce( \
|
||||
roc_shmem_ctx_t ctx, roc_shmem_team_t team, T *dest, const T *source, \
|
||||
int nreduce) { \
|
||||
roc_shmem_wg_to_all<T, Op>(ctx, team, dest, source, nreduce); \
|
||||
return roc_shmem_wg_reduce<T, Op>(ctx, team, dest, source, nreduce); \
|
||||
}
|
||||
|
||||
#define ARITH_REDUCTION_DEF_GEN(T, TNAME) \
|
||||
|
||||
@@ -24,19 +24,19 @@ using namespace rocshmem;
|
||||
|
||||
/* Declare the template with a generic implementation */
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void wg_team_to_all(roc_shmem_ctx_t ctx, roc_shmem_team_t, T *dest,
|
||||
__device__ int wg_team_reduce(roc_shmem_ctx_t ctx, roc_shmem_team_t, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
return;
|
||||
return ROC_SHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
/* Define templates to call ROC_SHMEM */
|
||||
#define TEAM_REDUCTION_DEF_GEN(T, TNAME, Op_API, Op) \
|
||||
template <> \
|
||||
__device__ void wg_team_to_all<T, Op>(roc_shmem_ctx_t ctx, \
|
||||
roc_shmem_team_t team, T * dest, \
|
||||
const T *source, int nreduce) { \
|
||||
roc_shmem_ctx_##TNAME##_##Op_API##_wg_to_all(ctx, team, dest, source, \
|
||||
nreduce); \
|
||||
__device__ int wg_team_reduce<T, Op>(roc_shmem_ctx_t ctx, \
|
||||
roc_shmem_team_t team, T * dest, \
|
||||
const T *source, int nreduce) { \
|
||||
return roc_shmem_ctx_##TNAME##_##Op_API##_wg_reduce(ctx, team, dest, \
|
||||
source, nreduce); \
|
||||
}
|
||||
|
||||
#define TEAM_ARITH_REDUCTION_DEF_GEN(T, TNAME) \
|
||||
@@ -91,7 +91,7 @@ __global__ void TeamReductionTest(int loop, int skip, uint64_t *timer,
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start = roc_shmem_timer();
|
||||
}
|
||||
wg_team_to_all<T1, T2>(ctx, team, r_buf, s_buf, size);
|
||||
wg_team_reduce<T1, T2>(ctx, team, r_buf, s_buf, size);
|
||||
roc_shmem_ctx_wg_barrier_all(ctx);
|
||||
}
|
||||
|
||||
|
||||
Reference in new issue
Block a user