diff --git a/hipamd/api/hip/hip_context.cpp b/hipamd/api/hip/hip_context.cpp index 2a67898bf3..983dc9b13c 100644 --- a/hipamd/api/hip/hip_context.cpp +++ b/hipamd/api/hip/hip_context.cpp @@ -21,13 +21,14 @@ THE SOFTWARE. */ #include - #include "hip_internal.hpp" #include "platform/runtime.hpp" #include "utils/versions.hpp" - +#include thread_local amd::Context* g_context = nullptr; +thread_local std::stack g_ctxtStack; + std::vector g_devices; hipError_t hipInit(unsigned int flags) @@ -65,6 +66,9 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device) *ctx = reinterpret_cast(g_devices[device]); + // Increment ref count for device primary context + g_devices[device]->retain(); + return hipSuccess; } @@ -72,7 +76,17 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) { HIP_INIT_API(ctx); - g_context = reinterpret_cast(ctx); + if (ctx == nullptr) { + if(!g_ctxtStack.empty()) { + g_ctxtStack.pop(); + } + } else { + g_context = reinterpret_cast(as_amd(ctx)); + if(!g_ctxtStack.empty()) { + g_ctxtStack.pop(); + } + g_ctxtStack.push(g_context); + } return hipSuccess; } @@ -98,3 +112,63 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion) return hipSuccess; } + +hipError_t hipCtxDestroy(hipCtx_t ctx) +{ + HIP_INIT_API(ctx); + + amd::Context* amdContext = reinterpret_cast(as_amd(ctx)); + if (amdContext == nullptr) { + return hipErrorInvalidValue; + } + + // Need to remove the ctx of calling thread if its the top one + if (g_context == amdContext) { + g_ctxtStack.pop(); + } + + // Remove context from global context list + for (unsigned int i = 0; i < g_devices.size(); i++) { + if (g_devices[i] == amdContext) { + // Decrement ref count for device primary context + amdContext->release(); + } + } + + return hipSuccess; +} + + +hipError_t hipCtxPopCurrent(hipCtx_t* ctx) +{ + HIP_INIT_API(ctx); + + amd::Context* amdContext = reinterpret_cast(as_amd(ctx)); + if (amdContext == nullptr) { + return hipErrorInvalidContext; + } + + if (!g_ctxtStack.empty()) { + amdContext = g_ctxtStack.top(); + g_ctxtStack.pop(); + } else { + return hipErrorInvalidContext; + } + + return hipSuccess; +} + +hipError_t hipCtxPushCurrent(hipCtx_t ctx) +{ + HIP_INIT_API(ctx); + + amd::Context* amdContext = reinterpret_cast(as_amd(ctx)); + if (amdContext == nullptr) { + return hipErrorInvalidContext; + } + + g_context = amdContext; + g_ctxtStack.push(g_context); + + return hipSuccess; +} \ No newline at end of file