diff --git a/src/rocshmem_gpu.cpp b/src/rocshmem_gpu.cpp index 8d7aeb29ac..0809c42096 100644 --- a/src/rocshmem_gpu.cpp +++ b/src/rocshmem_gpu.cpp @@ -289,7 +289,9 @@ __device__ int rocshmem_wg_ctx_create(long option, rocshmem_ctx_t *ctx) { if (get_flat_block_id() == 0) { ctx->team_opaque = reinterpret_cast(ROCSHMEM_CTX_DEFAULT.team_opaque); result = device_backend_proxy->create_ctx(option, ctx); - reinterpret_cast(ctx->ctx_opaque)->setFence(option); + if(result) { + reinterpret_cast(ctx->ctx_opaque)->setFence(option); + } } __syncthreads(); return result == true ? 0 : -1; @@ -308,7 +310,9 @@ __device__ int rocshmem_wg_team_create_ctx(rocshmem_team_t team, long options, TeamInfo *info_wrt_world = team_obj->tinfo_wrt_world; ctx->team_opaque = info_wrt_world; result = device_backend_proxy->create_ctx(options, ctx); - reinterpret_cast(ctx->ctx_opaque)->setFence(options); + if(result) { + reinterpret_cast(ctx->ctx_opaque)->setFence(options); + } } __syncthreads(); diff --git a/tests/functional_tests/team_ctx_infra_tester.cpp b/tests/functional_tests/team_ctx_infra_tester.cpp index 3c05b1ba15..1a973c55b8 100644 --- a/tests/functional_tests/team_ctx_infra_tester.cpp +++ b/tests/functional_tests/team_ctx_infra_tester.cpp @@ -51,9 +51,12 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type, */ rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx1); + assert (nullptr != ctx1.ctx_opaque); rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx2); + assert (nullptr != ctx2.ctx_opaque); rocshmem_wg_ctx_destroy(&ctx1); rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx3); + assert (nullptr != ctx3.ctx_opaque); __syncthreads(); @@ -73,6 +76,7 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type, */ for (int team_i = 0; team_i < NUM_TEAMS; team_i++) { rocshmem_wg_team_create_ctx(team[team_i], ctx_type, &ctx[team_i]); + assert (nullptr != ctx.ctx_opaque); } if (ctx[0].team_opaque == ctx[NUM_TEAMS - 1].team_opaque) { @@ -133,7 +137,7 @@ void TeamCtxInfraTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, sizeof(rocshmem_team_t) * NUM_TEAMS, hipMemcpyHostToDevice)); hipLaunchKernelGGL(TeamCtxInfraTest, gridSize, blockSize, shared_bytes, - stream, _shmem_context, teams_on_device); + stream, _shmem_context, teams_on_device); CHECK_HIP(hipFree(teams_on_device)); }