bugfix: do not dereference ctx during create_ctx if we did run out (#83)

Este commit está contenido en:
Aurelien Bouteiller
2025-04-16 10:37:44 -04:00
cometido por GitHub
padre f6ef19f5a9
commit 9befbe8293
Se han modificado 2 ficheros con 11 adiciones y 3 borrados
+6 -2
Ver fichero
@@ -289,7 +289,9 @@ __device__ int rocshmem_wg_ctx_create(long option, rocshmem_ctx_t *ctx) {
if (get_flat_block_id() == 0) {
ctx->team_opaque = reinterpret_cast<TeamInfo *>(ROCSHMEM_CTX_DEFAULT.team_opaque);
result = device_backend_proxy->create_ctx(option, ctx);
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(option);
if(result) {
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(option);
}
}
__syncthreads();
return result == true ? 0 : -1;
@@ -308,7 +310,9 @@ __device__ int rocshmem_wg_team_create_ctx(rocshmem_team_t team, long options,
TeamInfo *info_wrt_world = team_obj->tinfo_wrt_world;
ctx->team_opaque = info_wrt_world;
result = device_backend_proxy->create_ctx(options, ctx);
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
if(result) {
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
}
}
__syncthreads();
@@ -51,9 +51,12 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type,
*/
rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx1);
assert (nullptr != ctx1.ctx_opaque);
rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx2);
assert (nullptr != ctx2.ctx_opaque);
rocshmem_wg_ctx_destroy(&ctx1);
rocshmem_wg_team_create_ctx(team[0], ctx_type, &ctx3);
assert (nullptr != ctx3.ctx_opaque);
__syncthreads();
@@ -73,6 +76,7 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type,
*/
for (int team_i = 0; team_i < NUM_TEAMS; team_i++) {
rocshmem_wg_team_create_ctx(team[team_i], ctx_type, &ctx[team_i]);
assert (nullptr != ctx.ctx_opaque);
}
if (ctx[0].team_opaque == ctx[NUM_TEAMS - 1].team_opaque) {
@@ -133,7 +137,7 @@ void TeamCtxInfraTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
sizeof(rocshmem_team_t) * NUM_TEAMS, hipMemcpyHostToDevice));
hipLaunchKernelGGL(TeamCtxInfraTest, gridSize, blockSize, shared_bytes,
stream, _shmem_context, teams_on_device);
stream, _shmem_context, teams_on_device);
CHECK_HIP(hipFree(teams_on_device));
}