From 687fac97dc04b0e5ffe449f3709bb2ab37e98536 Mon Sep 17 00:00:00 2001 From: Rahul Garg Date: Tue, 8 Aug 2017 07:02:22 +0530 Subject: [PATCH] Updated context management logic: 1) hipSetDevice sets a flag so that next call to hipCtxGetCurrent returns primary context on current device 2) hipCtxGetCurrent returns primary context on current device if TLS context stack is empty 3) hipCtxPopCurrent falls back to primary context on current device as default 4) hipCtxPushCurrent, hipCtxSetCurrent and hipCtxCreate reset the flag set in hipSetDevice [ROCm/clr commit: c4e93238772e5699ee1dc97269950d4165b63079] --- projects/clr/hipamd/src/hip_context.cpp | 46 +++++++++++----------- projects/clr/hipamd/src/hip_device.cpp | 1 + projects/clr/hipamd/src/hip_hcc_internal.h | 1 + 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/projects/clr/hipamd/src/hip_context.cpp b/projects/clr/hipamd/src/hip_context.cpp index 11ef6d6da5..69d75e7f31 100644 --- a/projects/clr/hipamd/src/hip_context.cpp +++ b/projects/clr/hipamd/src/hip_context.cpp @@ -31,6 +31,7 @@ THE SOFTWARE. // Stack of contexts thread_local std::stack tls_ctxStack; +thread_local bool tls_getPrimaryCtx = true; void ihipCtxStackUpdate() { @@ -65,6 +66,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device) *ctx = ictx; ihipSetTlsDefaultCtx(*ctx); tls_ctxStack.push(*ctx); + tls_getPrimaryCtx = false; deviceCrit->addContext(ictx); } @@ -93,8 +95,7 @@ hipError_t hipDriverGetVersion(int *driverVersion) hipError_t e = hipSuccess; if (driverVersion) { *driverVersion = 4; - } - else { + } else { e = hipErrorInvalidValue; } @@ -107,8 +108,7 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion) hipError_t e = hipSuccess; if (runtimeVersion) { *runtimeVersion = HIP_VERSION_PATCH; - } - else { + } else { e = hipErrorInvalidValue; } @@ -124,9 +124,7 @@ hipError_t hipCtxDestroy(hipCtx_t ctx) if(primaryCtx== ctx) { e = hipErrorInvalidValue; - } - else - { + } else { if(currentCtx == ctx) { //need to destroy the ctx associated with calling thread tls_ctxStack.pop(); @@ -146,19 +144,21 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx) { HIP_INIT_API(ctx); hipError_t e = hipSuccess; - ihipCtx_t* tempCtx; - *ctx = ihipGetTlsDefaultCtx(); + ihipCtx_t* currentCtx = ihipGetTlsDefaultCtx(); + auto deviceHandle = currentCtx->getDevice(); + *ctx = currentCtx; + if(!tls_ctxStack.empty()) { tls_ctxStack.pop(); } + if(!tls_ctxStack.empty()) { - tempCtx= tls_ctxStack.top(); - } - else { - tempCtx = nullptr; + currentCtx= tls_ctxStack.top(); + } else { + currentCtx = deviceHandle->_primaryCtx; } - ihipSetTlsDefaultCtx(tempCtx); //TOD0 - Shall check for NULL? + ihipSetTlsDefaultCtx(currentCtx); //TOD0 - Shall check for NULL? return ihipLogStatus(e); } @@ -169,8 +169,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx) if(ctx != NULL) { //TODO- is this check needed? ihipSetTlsDefaultCtx(ctx); tls_ctxStack.push(ctx); - } - else { + tls_getPrimaryCtx = false; + } else { e = hipErrorInvalidContext; } return ihipLogStatus(e); @@ -180,12 +180,11 @@ hipError_t hipCtxGetCurrent(hipCtx_t* ctx) { HIP_INIT_API(ctx); hipError_t e = hipSuccess; - if(!tls_ctxStack.empty()) { + if((tls_getPrimaryCtx) || tls_ctxStack.empty()) { + *ctx = ihipGetTlsDefaultCtx(); + } else { *ctx= tls_ctxStack.top(); } - else { - *ctx = NULL; - } return ihipLogStatus(e); } @@ -195,10 +194,10 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) hipError_t e = hipSuccess; if(ctx == NULL) { tls_ctxStack.pop(); - } - else { + } else { ihipSetTlsDefaultCtx(ctx); tls_ctxStack.push(ctx); + tls_getPrimaryCtx = false; } return ihipLogStatus(e); } @@ -213,8 +212,7 @@ hipError_t hipCtxGetDevice(hipDevice_t *device) if(ctx == nullptr) { e = hipErrorInvalidContext; // TODO *device = nullptr; - } - else { + } else { auto deviceHandle = ctx->getDevice(); *device = deviceHandle->_deviceId; } diff --git a/projects/clr/hipamd/src/hip_device.cpp b/projects/clr/hipamd/src/hip_device.cpp index 5ff6dbf04d..1800c9369c 100644 --- a/projects/clr/hipamd/src/hip_device.cpp +++ b/projects/clr/hipamd/src/hip_device.cpp @@ -146,6 +146,7 @@ hipError_t hipSetDevice(int deviceId) return ihipLogStatus(hipErrorInvalidDevice); } else { ihipSetTlsDefaultCtx(ihipGetPrimaryCtx(deviceId)); + tls_getPrimaryCtx = true; return ihipLogStatus(hipSuccess); } } diff --git a/projects/clr/hipamd/src/hip_hcc_internal.h b/projects/clr/hipamd/src/hip_hcc_internal.h index 88c7eedda0..4cb85ffc19 100644 --- a/projects/clr/hipamd/src/hip_hcc_internal.h +++ b/projects/clr/hipamd/src/hip_hcc_internal.h @@ -114,6 +114,7 @@ private: //Extern tls extern thread_local hipError_t tls_lastHipError; extern thread_local TidInfo tls_tidInfo; +extern thread_local bool tls_getPrimaryCtx; extern std::vector g_dbStartTriggers; extern std::vector g_dbStopTriggers;