Cleanup/wg init (#260)

* remove wg_init and wg_finalize from functional tests

* Remove wg_init and wg_finalize from examples

* deprecate wg_init/finalize

* Updated docs

* Typo in documentation

---------

Co-authored-by: Yiltan <yiltan@amd.com>

[ROCm/rocshmem commit: 6e7277b544]
This commit is contained in:
Aurelien Bouteiller
2025-10-07 14:34:18 -04:00
zatwierdzone przez GitHub
rodzic 192c549d40
commit 8837414042
32 zmienionych plików z 23 dodań i 68 usunięć
+9
Wyświetl plik
@@ -1,5 +1,14 @@
# Changelog for rocSHMEM
## rocSHMEM 3.x.x for ROCm 7.x.x
### Changed
* The following APIs have been deprecated:
* `rocshmem_wg_init`
* `rocshmem_wg_finalize`
* `rocshmem_wg_init_thread`
## rocSHMEM 3.0.0 for ROCm 7.0.0
### Added
+11 -3
Wyświetl plik
@@ -23,12 +23,16 @@ you must select the device that this PE is associated to by calling
`hipSetDevice
<https://rocm.docs.amd.com/projects/HIP/en/docs-6.0.0/doxygen/html/group___device.html#ga43c1e7f15925eeb762195ccb5e063eae>`_.
.. cpp:function:: __device__ void rocshmem_wg_init(void)
.. WARNING::
Routine `rocshmem_wg_init` has been deprecated.
.. cpp:function:: [[deprecated]] __device__ void rocshmem_wg_init(void)
:Parameters: None.
:returns: None.
**Description:**
This routine has been deprecated, please do not use.
This routine initializes device-side rocSHMEM resources.
It must be called before any threads in this work-group invoke other rocSHMEM functions.
It must be called collectively by all threads in the work-group.
@@ -43,12 +47,16 @@ ROCSHMEM_FINALIZE
**Description:**
This routine finalizes the rocSHMEM library.
.. cpp:function:: __device__ void rocshmem_wg_finalize(void)
.. WARNING::
Routine `rocshmem_wg_finalize` has been deprecated.
.. cpp:function:: [[deprecated]] __device__ void rocshmem_wg_finalize(void)
:Parameters: None.
:returns: None.
**Description:**
This routine has been deprecated, please do not use.
This routine finalizes device-side rocSHMEM resources.
It must be called before work-group completion if the work-group also called ``rocshmem_wg_init``.
It must be called collectively by all threads in the work-group.
@@ -65,7 +73,7 @@ ROCSHMEM_INIT_ATTR
**Description:**
This routine initializes the rocSHMEM runtime and underlying transport layer using
the provided mode and attributes.
The parameter ``flags`` can be either
The parameter ``flags`` can be either
``ROCSHMEM_INIT_WITH_UNIQUEID`` or ``ROCSHMEM_INIT_WITH_MPI_COMM``.
ROCSHMEM_GET_UNIQUEID
@@ -65,7 +65,6 @@ __global__ void allreduce_test(int *source, int *dest, size_t nelem,
__shared__ rocshmem_ctx_t ctx;
int64_t ctx_type = 0;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int num_pes = rocshmem_ctx_n_pes(ctx);
@@ -75,7 +74,6 @@ __global__ void allreduce_test(int *source, int *dest, size_t nelem,
__syncthreads();
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
static void init_sendbuf (int *source, int nelem, int my_pe)
@@ -65,7 +65,6 @@ __global__ void alltoall_test(int *source, int *dest, size_t nelem,
__shared__ rocshmem_ctx_t ctx;
int64_t ctx_type = 0;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int num_pes = rocshmem_ctx_n_pes(ctx);
@@ -75,7 +74,6 @@ __global__ void alltoall_test(int *source, int *dest, size_t nelem,
__syncthreads();
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
static void init_sendbuf (int *source, int nelem, int my_pe, int npes)
@@ -65,7 +65,6 @@ __global__ void broadcast_test(int *source, int *dest, size_t nelem,
__shared__ rocshmem_ctx_t ctx;
int64_t ctx_type = 0;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int num_pes = rocshmem_ctx_n_pes(ctx);
@@ -75,7 +74,6 @@ __global__ void broadcast_test(int *source, int *dest, size_t nelem,
__syncthreads();
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
static void init_sendbuf(int *source, int nelem, int my_pe)
@@ -62,7 +62,6 @@ using namespace rocshmem;
__global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
{
rocshmem_wg_init();
int threadId = blockIdx.x * blockDim.x + threadIdx.x;
if (threadId == 0) {
@@ -73,7 +72,6 @@ __global__ void simple_getmem_test(int *src, int *dst, size_t nelem)
}
__syncthreads();
rocshmem_wg_finalize();
}
#define MAX_ELEM 256
@@ -63,7 +63,6 @@ using namespace rocshmem;
__global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t nelem,
uint64_t *sig_addr, int my_pe, int dst_pe)
{
rocshmem_wg_init();
int threadId = blockIdx.x * blockDim.x + threadIdx.x;
@@ -78,7 +77,6 @@ __global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t
}
__syncthreads();
rocshmem_wg_finalize();
}
#define MAX_ELEM 256
@@ -374,7 +374,7 @@ __host__ void rocshmem_global_exit(int status);
*
* @return void.
*/
__device__ void rocshmem_wg_init();
[[deprecated]] __device__ void rocshmem_wg_init();
/**
* @brief Finalizes device-side rocSHMEM resources. Must be called before
@@ -384,7 +384,7 @@ __device__ void rocshmem_wg_init();
*
* @return void.
*/
__device__ void rocshmem_wg_finalize();
[[deprecated]] __device__ void rocshmem_wg_finalize();
/**
* @brief Initializes device-side rocSHMEM resources. Must be called before
@@ -400,7 +400,7 @@ __device__ void rocshmem_wg_finalize();
*
* @return void.
*/
__device__ void rocshmem_wg_init_thread(int requested, int *provided);
[[deprecated]] __device__ void rocshmem_wg_init_thread(int requested, int *provided);
/**
* @brief Query the thread mode used by the runtime.
@@ -254,7 +254,6 @@ void AMOBitwiseTester<T>::verifyResults(size_t size) {
int global_id = get_flat_id(); \
int n_threads = get_flat_grid_size(); \
int n_wgs = get_grid_num_blocks(); \
rocshmem_wg_init(); \
rocshmem_wg_ctx_create(ctx_type, &ctx); \
for (int i = 0; i < loop + skip; i++) { \
T *ptr = compute_target_ptr<T>(dest, addr_mode, wg_id, i, n_wgs); \
@@ -293,7 +292,6 @@ void AMOBitwiseTester<T>::verifyResults(size_t size) {
end_time[wg_id] = wall_clock64(); \
__syncthreads(); \
rocshmem_wg_ctx_destroy(&ctx); \
rocshmem_wg_finalize(); \
} \
template class AMOBitwiseTester<T>;
@@ -229,7 +229,6 @@ void AMOExtendedTester<T>::verifyResults(size_t /*size*/) {
int t_id = get_flat_block_id(); \
int n_threads = get_flat_grid_size(); \
int n_wgs = get_grid_num_blocks(); \
rocshmem_wg_init(); \
rocshmem_wg_ctx_create(ctx_type, &ctx); \
for (int i = 0; i < loop + skip; i++) { \
T *ptr = compute_target_ptr<T>(dest, addr_mode, wg_id, i, n_wgs); \
@@ -256,7 +255,6 @@ void AMOExtendedTester<T>::verifyResults(size_t /*size*/) {
end_time[wg_id] = wall_clock64(); \
__syncthreads(); \
rocshmem_wg_ctx_destroy(&ctx); \
rocshmem_wg_finalize(); \
} \
template class AMOExtendedTester<T>;
@@ -260,7 +260,6 @@ void AMOStandardTester<T>::verifyResults(size_t size) {
int t_id = get_flat_block_id(); \
int n_threads = get_flat_grid_size(); \
int n_wgs = get_grid_num_blocks(); \
rocshmem_wg_init(); \
rocshmem_wg_ctx_create(ctx_type, &ctx); \
for (int i = 0; i < loop + skip; i++) { \
T *ptr = compute_target_ptr<T>(dest, addr_mode, wg_id, i, n_wgs); \
@@ -294,7 +293,6 @@ void AMOStandardTester<T>::verifyResults(size_t size) {
end_time[wg_id] = wall_clock64(); \
__syncthreads(); \
rocshmem_wg_ctx_destroy(&ctx); \
rocshmem_wg_finalize(); \
} \
template class AMOStandardTester<T>;
@@ -39,7 +39,6 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
for (int i = 0; i < loop + skip; i++) {
if (hipThreadIdx_x == 0 && i == skip) {
@@ -84,7 +83,6 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -39,7 +39,6 @@
int wg_id = get_flat_grid_id();
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
/**
* Shared array to capture the start time for each wavefront
@@ -121,7 +120,6 @@
start_time[wg_id] = wf_start_time[0];
}
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -35,11 +35,9 @@ __global__ void EmptyTest(int loop, int skip, long long int *start_time,
long long int *end_time, int size, TestType type,
ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -37,7 +37,6 @@ __global__ void PingAllTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int pe = rocshmem_ctx_my_pe(ctx);
@@ -64,7 +63,6 @@ __global__ void PingAllTest(int loop, int skip, long long int *start_time,
rocshmem_ctx_quiet(ctx);
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -37,7 +37,6 @@ __global__ void PingPongTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int pe = rocshmem_ctx_my_pe(ctx);
@@ -65,7 +64,6 @@ __global__ void PingPongTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -37,7 +37,6 @@ __global__ void PrimitiveMRTest(int loop, long long int *start_time,
ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
if (hipThreadIdx_x == 0) {
@@ -57,7 +56,6 @@ __global__ void PrimitiveMRTest(int loop, long long int *start_time,
__syncthreads();
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -39,7 +39,6 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
int wg_id = get_flat_grid_id();
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
@@ -121,7 +120,6 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -62,7 +62,6 @@ __global__ void RandomAccessTest(int loop, int skip, long long int *start_time,
uint32_t *PE_bins, ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int pe = rocshmem_ctx_my_pe(ctx);
@@ -97,7 +96,6 @@ __global__ void RandomAccessTest(int loop, int skip, long long int *start_time,
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -39,7 +39,6 @@ __global__ void ShmemPtrTest(int loop, int skip, long long int *start_time,
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
@@ -111,7 +110,6 @@ __global__ void ShmemPtrTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -38,7 +38,6 @@ __global__ void PutmemSignalTest(int loop, int skip, long long int *start_time,
int sig_op) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
uint64_t signal = 1;
@@ -88,13 +87,11 @@ __global__ void PutmemSignalTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
__global__ void SignalFetchTest(int loop, int skip, long long int *start_time,
long long int *end_time, uint64_t *sig_addr,
uint64_t *fetched_value, TestType type) {
rocshmem_wg_init();
int wg_id = get_flat_grid_id();
@@ -125,7 +122,6 @@ __global__ void SignalFetchTest(int loop, int skip, long long int *start_time,
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -39,7 +39,6 @@ __global__ void SyncAllTest(int loop, int skip, long long int *start_time,
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
for (int i = 0; i < loop + skip; i++) {
if (hipThreadIdx_x == 0 && i == skip) {
@@ -84,7 +83,6 @@ __global__ void SyncAllTest(int loop, int skip, long long int *start_time,
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -36,7 +36,6 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
for (int i = 0; i < loop + skip; i++) {
@@ -69,7 +68,6 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -64,7 +64,6 @@ __global__ void TeamAlltoallTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
@@ -91,7 +90,6 @@ __global__ void TeamAlltoallTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -36,7 +36,6 @@ __global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
for (int i = 0; i < loop + skip; i++) {
@@ -69,7 +68,6 @@ __global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -69,7 +69,6 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
@@ -97,7 +96,6 @@ __global__ void TeamBroadcastTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -44,7 +44,6 @@ rocshmem_team_t team_world_dup[NUM_TEAMS];
int expected_pe, int expected_n_pes) {
__shared__ rocshmem_ctx_t ctx;
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(team, ctx_type, &ctx);
int num_pes = rocshmem_ctx_n_pes(ctx);
@@ -64,7 +63,6 @@ rocshmem_team_t team_world_dup[NUM_TEAMS];
rocshmem_ctx_quiet(ctx);
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
__global__ void TeamCtxInfraTest(ShmemContextType ctx_type,
@@ -72,7 +70,6 @@ rocshmem_team_t team_world_dup[NUM_TEAMS];
__shared__ rocshmem_ctx_t ctx1, ctx2, ctx3;
__shared__ rocshmem_ctx_t ctx[NUM_TEAMS];
rocshmem_wg_init();
/**
* Test 1: Assert team infos of different ctxs
@@ -131,7 +128,6 @@ rocshmem_team_t team_world_dup[NUM_TEAMS];
rocshmem_wg_ctx_destroy(&ctx[team_i]);
}
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -43,7 +43,6 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(team, ctx_type, &ctx);
/**
@@ -114,7 +113,6 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -64,7 +64,6 @@ __global__ void TeamFcollectTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
@@ -90,7 +89,6 @@ __global__ void TeamFcollectTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -83,7 +83,6 @@ __global__ void TeamReductionTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
@@ -104,7 +103,6 @@ __global__ void TeamReductionTest(int loop, int skip, long long int *start_time,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -42,7 +42,6 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
// Calculate start index for each wavefront
@@ -86,7 +85,6 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
@@ -40,7 +40,6 @@ __global__ void WorkGroupPrimitiveTest(int loop, int skip,
ShmemContextType ctx_type) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
// Calculate start index for each work group
@@ -82,7 +81,6 @@ __global__ void WorkGroupPrimitiveTest(int loop, int skip,
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************