diff --git a/projects/clr/hipamd/src/hip_context.cpp b/projects/clr/hipamd/src/hip_context.cpp index 11ef6d6da5..69d75e7f31 100644 --- a/projects/clr/hipamd/src/hip_context.cpp +++ b/projects/clr/hipamd/src/hip_context.cpp @@ -31,6 +31,7 @@ THE SOFTWARE. // Stack of contexts thread_local std::stack tls_ctxStack; +thread_local bool tls_getPrimaryCtx = true; void ihipCtxStackUpdate() { @@ -65,6 +66,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device) *ctx = ictx; ihipSetTlsDefaultCtx(*ctx); tls_ctxStack.push(*ctx); + tls_getPrimaryCtx = false; deviceCrit->addContext(ictx); } @@ -93,8 +95,7 @@ hipError_t hipDriverGetVersion(int *driverVersion) hipError_t e = hipSuccess; if (driverVersion) { *driverVersion = 4; - } - else { + } else { e = hipErrorInvalidValue; } @@ -107,8 +108,7 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion) hipError_t e = hipSuccess; if (runtimeVersion) { *runtimeVersion = HIP_VERSION_PATCH; - } - else { + } else { e = hipErrorInvalidValue; } @@ -124,9 +124,7 @@ hipError_t hipCtxDestroy(hipCtx_t ctx) if(primaryCtx== ctx) { e = hipErrorInvalidValue; - } - else - { + } else { if(currentCtx == ctx) { //need to destroy the ctx associated with calling thread tls_ctxStack.pop(); @@ -146,19 +144,21 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx) { HIP_INIT_API(ctx); hipError_t e = hipSuccess; - ihipCtx_t* tempCtx; - *ctx = ihipGetTlsDefaultCtx(); + ihipCtx_t* currentCtx = ihipGetTlsDefaultCtx(); + auto deviceHandle = currentCtx->getDevice(); + *ctx = currentCtx; + if(!tls_ctxStack.empty()) { tls_ctxStack.pop(); } + if(!tls_ctxStack.empty()) { - tempCtx= tls_ctxStack.top(); - } - else { - tempCtx = nullptr; + currentCtx= tls_ctxStack.top(); + } else { + currentCtx = deviceHandle->_primaryCtx; } - ihipSetTlsDefaultCtx(tempCtx); //TOD0 - Shall check for NULL? + ihipSetTlsDefaultCtx(currentCtx); //TOD0 - Shall check for NULL? return ihipLogStatus(e); } @@ -169,8 +169,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx) if(ctx != NULL) { //TODO- is this check needed? ihipSetTlsDefaultCtx(ctx); tls_ctxStack.push(ctx); - } - else { + tls_getPrimaryCtx = false; + } else { e = hipErrorInvalidContext; } return ihipLogStatus(e); @@ -180,12 +180,11 @@ hipError_t hipCtxGetCurrent(hipCtx_t* ctx) { HIP_INIT_API(ctx); hipError_t e = hipSuccess; - if(!tls_ctxStack.empty()) { + if((tls_getPrimaryCtx) || tls_ctxStack.empty()) { + *ctx = ihipGetTlsDefaultCtx(); + } else { *ctx= tls_ctxStack.top(); } - else { - *ctx = NULL; - } return ihipLogStatus(e); } @@ -195,10 +194,10 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) hipError_t e = hipSuccess; if(ctx == NULL) { tls_ctxStack.pop(); - } - else { + } else { ihipSetTlsDefaultCtx(ctx); tls_ctxStack.push(ctx); + tls_getPrimaryCtx = false; } return ihipLogStatus(e); } @@ -213,8 +212,7 @@ hipError_t hipCtxGetDevice(hipDevice_t *device) if(ctx == nullptr) { e = hipErrorInvalidContext; // TODO *device = nullptr; - } - else { + } else { auto deviceHandle = ctx->getDevice(); *device = deviceHandle->_deviceId; } diff --git a/projects/clr/hipamd/src/hip_device.cpp b/projects/clr/hipamd/src/hip_device.cpp index 5ff6dbf04d..1800c9369c 100644 --- a/projects/clr/hipamd/src/hip_device.cpp +++ b/projects/clr/hipamd/src/hip_device.cpp @@ -146,6 +146,7 @@ hipError_t hipSetDevice(int deviceId) return ihipLogStatus(hipErrorInvalidDevice); } else { ihipSetTlsDefaultCtx(ihipGetPrimaryCtx(deviceId)); + tls_getPrimaryCtx = true; return ihipLogStatus(hipSuccess); } } diff --git a/projects/clr/hipamd/src/hip_hcc_internal.h b/projects/clr/hipamd/src/hip_hcc_internal.h index 88c7eedda0..4cb85ffc19 100644 --- a/projects/clr/hipamd/src/hip_hcc_internal.h +++ b/projects/clr/hipamd/src/hip_hcc_internal.h @@ -114,6 +114,7 @@ private: //Extern tls extern thread_local hipError_t tls_lastHipError; extern thread_local TidInfo tls_tidInfo; +extern thread_local bool tls_getPrimaryCtx; extern std::vector g_dbStartTriggers; extern std::vector g_dbStopTriggers;