Co-authored-by: Aurelien Bouteiller <aurelien.bouteiller@amd.com>

[ROCm/rocshmem commit: c3eeae473b]
Этот коммит содержится в:
Yiltan
2025-10-20 11:42:39 -04:00
коммит произвёл GitHub
родитель 6bc1cc63ae
Коммит 92a7904656
14 изменённых файлов: 79 добавлений и 2 удалений
+2
Просмотреть файл
@@ -9,6 +9,8 @@
* Mellanox MLX5 (IB and RoCE ConnectX-7)
* Added new APIs:
* `rocshmem_get_device_ctx`
* `rocshmem_ctx_pe_quiet`
* `rocshmem_pe_quiet`
### Changed
* The following APIs have been deprecated:
+15
Просмотреть файл
@@ -34,3 +34,18 @@ ROCSHMEM_QUIET
**Description:**
This routine completes all previous operations posted to this context.
ROCSHMEM_PE_QUIET
--------------
.. cpp:function:: __device__ void rocshmem_ctx_pe_quiet(shmem_ctx_t ctx, const int *target_pes, size_t npes)
.. cpp:function:: __device__ void rocshmem_pe_quiet(const int *target_pes, size_t npes)
:param ctx: Context with which to perform this operation.
:param target_pes: Address of target PE array where the operations need to be completed
:param npes: The number of PEs in the target PE array
:returns: None.
**Description:**
This routine completes all previous operations posted to this context
for the PEs in the `target_pes` array.
+17
Просмотреть файл
@@ -498,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_quiet();
/**
* @brief Completes all previous operations posted to this context for PEs in the
* `target_pes` array.
*
* @param[in] ctx Context with which to perform this operation.
*
* @param[in] target_pes Address of target PE array where the operations need to be completed.
*
* @param[in] npes The number of PEs in the target PE array.
*
* @return void.
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_pe_quiet(rocshmem_ctx_t ctx, const int *target_pes, size_t npes);
__device__ ATTR_NO_INLINE void rocshmem_pe_quiet(const int *target_pes, size_t npes);
/**
* @brief Query the total number of PEs.
*
+1
Просмотреть файл
@@ -169,6 +169,7 @@ void Backend::dump_stats() {
device_stats.getStat(NUM_GET_NBI_WAVE));
printf("Fences %llu\n", device_stats.getStat(NUM_FENCE));
printf("Quiets %llu\n", device_stats.getStat(NUM_QUIET));
printf("PE Quiets %llu\n", device_stats.getStat(NUM_PE_QUIET));
printf("ToAll %llu\n", device_stats.getStat(NUM_TO_ALL));
printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL));
printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE));
+3 -1
Просмотреть файл
@@ -136,6 +136,8 @@ class Context {
__device__ void quiet();
__device__ void pe_quiet(size_t pe);
__device__ void* shmem_ptr(const void* dest, int pe);
__device__ void barrier_all();
@@ -477,7 +479,7 @@ class Context {
* @brief Duplicated local copy of backend's type
*/
BackendType btype;
/**
* @brief Stats common to all types of device contexts.
*/
+6
Просмотреть файл
@@ -139,6 +139,12 @@ __device__ void Context::quiet() {
DISPATCH(quiet());
}
__device__ void Context::pe_quiet(size_t pe) {
ctxStats.incStat(NUM_PE_QUIET);
DISPATCH(pe_quiet(pe));
}
__device__ void* Context::shmem_ptr(const void* dest, int pe) {
ctxStats.incStat(NUM_SHMEM_PTR);
+4
Просмотреть файл
@@ -177,6 +177,10 @@ __device__ void GDAContext::quiet() {
}
}
__device__ void GDAContext::pe_quiet(size_t pe) {
qps[pe].quiet();
}
__device__ void *GDAContext::shmem_ptr(const void *dest, int pe) {
void *ret = nullptr;
int local_pe{-1};
+2
Просмотреть файл
@@ -60,6 +60,8 @@ class GDAContext : public Context {
__device__ void quiet();
__device__ void pe_quiet(size_t pe);
__device__ void *shmem_ptr(const void *dest, int pe);
__device__ void barrier_all();
+4
Просмотреть файл
@@ -96,6 +96,10 @@ __device__ void IPCContext::quiet() {
fence();
}
__device__ void IPCContext::pe_quiet(size_t pe) {
fence(pe);
}
__device__ void *IPCContext::shmem_ptr(const void *dest, int pe) {
void *ret = nullptr;
void *dst = const_cast<void *>(dest);
+2
Просмотреть файл
@@ -59,6 +59,8 @@ class IPCContext : public Context {
__device__ void quiet();
__device__ void pe_quiet(size_t pe);
__device__ void *shmem_ptr(const void *dest, int pe);
__device__ void barrier_all();
+5
Просмотреть файл
@@ -161,6 +161,11 @@ __device__ void ROContext::quiet() {
true, get_status_flag(), is_default_ctx);
}
__device__ void ROContext::pe_quiet(size_t pe) {
// TODO: Optimize
quiet();
}
__device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
void *ret = nullptr;
int local_pe{-1};
+2
Просмотреть файл
@@ -61,6 +61,8 @@ class ROContext : public Context {
__device__ void quiet();
__device__ void pe_quiet(size_t pe);
__device__ void *shmem_ptr(const void *dest, int pe);
__device__ void barrier_all();
+15 -1
Просмотреть файл
@@ -107,7 +107,7 @@ __device__ void rocshmem_wg_finalize() {}
/******************************************************************************
* These host APIs use Device side symbol - ROCSHMEM_CTX_DEFAULT so it needs
* These host APIs use Device side symbol - ROCSHMEM_CTX_DEFAULT so it needs
* to stay here to avoid getting pulled into other places in compilation
******************************************************************************/
@@ -188,6 +188,10 @@ __device__ void rocshmem_quiet() {
rocshmem_ctx_quiet(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_pe_quiet(const int *target_pes, size_t npes) {
rocshmem_ctx_pe_quiet(ROCSHMEM_CTX_DEFAULT, target_pes, npes);
}
__device__ void rocshmem_threadfence_system() {
rocshmem_ctx_threadfence_system(ROCSHMEM_CTX_DEFAULT);
}
@@ -485,6 +489,16 @@ __device__ void rocshmem_ctx_quiet(rocshmem_ctx_t ctx) {
get_internal_ctx(ctx)->quiet();
}
__device__ void rocshmem_ctx_pe_quiet(rocshmem_ctx_t ctx, const int *target_pes, size_t npes) {
GPU_DPRINTF("Function: %s (ctx=%zd)\n", __FUNC__, ctx.ctx_opaque);
ContextTy *internal_ctx = get_internal_ctx(ctx);
for (int i = 0; i < npes; i++) {
internal_ctx->pe_quiet(translate_pe(ctx, target_pes[i]));
}
}
__device__ void *rocshmem_ptr(const void *dest, int pe) {
GPU_DPRINTF("Function: rocshmem_ptr (dest=%p, pe=%d w%d\n",
dest, pe, pe);
+1
Просмотреть файл
@@ -43,6 +43,7 @@ enum rocshmem_stats {
NUM_GET_NBI,
NUM_FENCE,
NUM_QUIET,
NUM_PE_QUIET,
NUM_TO_ALL,
NUM_BARRIER_ALL,
NUM_BARRIER_ALL_WAVE,