diff --git a/projects/rocshmem/docs/api/ctx.rst b/projects/rocshmem/docs/api/ctx.rst index b9aa2dd52f..a32c823482 100644 --- a/projects/rocshmem/docs/api/ctx.rst +++ b/projects/rocshmem/docs/api/ctx.rst @@ -16,15 +16,19 @@ ROCSHMEM_CTX_CREATE :param team: Team handle to derive the context from. :param options: Options for context creation. Ignored in current design; use the value ``0``. - :param ctx: Context handle. + :param ctx: A handle to the newly created context. :returns: All threads returns ``0`` if the context was created successfully. - If any thread returns non-zero value, the operation fails and a higher number of - ``ROCSHMEM_MAX_NUM_CONTEXTS`` is required. + If any thread returns non-zero value, the operation fails, ctx is set to ``ROCSHMEM_CTX_INVALID`` and a + higher number of ``ROCSHMEM_MAX_NUM_CONTEXTS`` is required. **Description:** -This routine creates an OpenSHMEM context. By design, the context is private to the calling work-group. -It must be called collectively by all threads in the work-group. +This routine creates an rocSHMEM context. By design, the context is private to the calling work-group. +It must be called collectively by all threads in the work-group. If the context was created successfully, a value +of zero is returned and the context handle pointed to by ctx specifies a valid context; otherwise, a nonzero value +is returned and ctx is set to ``ROCSHMEM_CTX_INVALID``. An unsuccessful context creation call is not treated as an +error and the rocSHMEM library remains in a correct state. The creation call can be reattempted after additional +resources become available. ROCSHMEM_CTX_DESTROY -------------------- @@ -36,8 +40,8 @@ ROCSHMEM_CTX_DESTROY :returns: None. **Description:** -This routine destroys an rocSHMEM context. -It must be called collectively by all threads in the work-group. +This routine destroys an rocSHMEM context. It must be called collectively by all threads in the work-group. +If ctx has the value ``ROCSHMEM_CTX_INVALID``, no operation is performed. ROCSHMEM_GET_DEVICE_CTX ----------------------- diff --git a/projects/rocshmem/include/rocshmem/rocshmem_common.hpp b/projects/rocshmem/include/rocshmem/rocshmem_common.hpp index 288aa98292..2795aa1dfc 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_common.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_common.hpp @@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8; * @brief GPU side OpenSHMEM context created from each work-groups' * rocshmem_wg_handle_t */ -typedef struct { +typedef struct rocshmem_ctx{ void *ctx_opaque; void *team_opaque; + + __host__ __device__ bool operator==(const struct rocshmem_ctx& other) const { + return (ctx_opaque == other.ctx_opaque && + team_opaque == other.team_opaque); + } + + __host__ __device__ bool operator!=(const struct rocshmem_ctx& other) const { + return !(*this == other); + } } rocshmem_ctx_t; /** @@ -116,6 +125,14 @@ typedef struct { */ extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT; +/** + * A value corresponding to an invalid communication context. This value can be + * used to initialize or update context handles to indicate that they do not + * reference a valid context. When managed in this way, applications can use an + * equality comparison to test whether a given context handle references a + * valid context. + */ +extern __constant__ rocshmem_ctx_t ROCSHMEM_CTX_INVALID; /** * Used internally to set default context. */ diff --git a/projects/rocshmem/src/rocshmem_gpu.cpp b/projects/rocshmem/src/rocshmem_gpu.cpp index a7cbe47d08..b0eea85859 100644 --- a/projects/rocshmem/src/rocshmem_gpu.cpp +++ b/projects/rocshmem/src/rocshmem_gpu.cpp @@ -74,6 +74,8 @@ __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_D __constant__ Backend *device_backend_proxy; +__constant__ rocshmem_ctx_t ROCSHMEM_CTX_INVALID = {nullptr, nullptr}; + #if defined(ENABLE_IPC_BITCODE) typedef IPCContext ContextTy; #else @@ -324,6 +326,9 @@ __device__ int rocshmem_wg_ctx_create(long options, rocshmem_ctx_t *ctx) { if(result) { reinterpret_cast(ctx->ctx_opaque)->setFence(options); } + else { + *ctx = ROCSHMEM_CTX_INVALID; + } } __syncthreads(); return result == true ? 0 : -1; @@ -346,6 +351,9 @@ __device__ int rocshmem_wg_team_create_ctx(rocshmem_team_t team, long options, if(result) { reinterpret_cast(ctx->ctx_opaque)->setFence(options); } + else { + *ctx = ROCSHMEM_CTX_INVALID; + } } __syncthreads(); @@ -357,7 +365,7 @@ __device__ void rocshmem_wg_ctx_destroy( GPU_DPRINTF("Function: rocshmem_wg_ctx_destroy (ctx=%zd)\n", ctx->ctx_opaque); - if (get_flat_block_id() == 0) { + if (get_flat_block_id() == 0 && *ctx != ROCSHMEM_CTX_INVALID) { device_backend_proxy->destroy_ctx(ctx); } }