Updated context management logic:
1) hipSetDevice sets a flag so that next call to hipCtxGetCurrent returns primary context on current device
2) hipCtxGetCurrent returns primary context on current device if TLS context stack is empty
3) hipCtxPopCurrent falls back to primary context on current device as default
4) hipCtxPushCurrent, hipCtxSetCurrent and hipCtxCreate reset the flag set in hipSetDevice
[ROCm/clr commit: c4e9323877]
This commit is contained in:
@@ -31,6 +31,7 @@ THE SOFTWARE.
|
||||
|
||||
// Stack of contexts
|
||||
thread_local std::stack<ihipCtx_t *> tls_ctxStack;
|
||||
thread_local bool tls_getPrimaryCtx = true;
|
||||
|
||||
void ihipCtxStackUpdate()
|
||||
{
|
||||
@@ -65,6 +66,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device)
|
||||
*ctx = ictx;
|
||||
ihipSetTlsDefaultCtx(*ctx);
|
||||
tls_ctxStack.push(*ctx);
|
||||
tls_getPrimaryCtx = false;
|
||||
deviceCrit->addContext(ictx);
|
||||
}
|
||||
|
||||
@@ -93,8 +95,7 @@ hipError_t hipDriverGetVersion(int *driverVersion)
|
||||
hipError_t e = hipSuccess;
|
||||
if (driverVersion) {
|
||||
*driverVersion = 4;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
@@ -107,8 +108,7 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion)
|
||||
hipError_t e = hipSuccess;
|
||||
if (runtimeVersion) {
|
||||
*runtimeVersion = HIP_VERSION_PATCH;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
@@ -124,9 +124,7 @@ hipError_t hipCtxDestroy(hipCtx_t ctx)
|
||||
if(primaryCtx== ctx)
|
||||
{
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
if(currentCtx == ctx) {
|
||||
//need to destroy the ctx associated with calling thread
|
||||
tls_ctxStack.pop();
|
||||
@@ -146,19 +144,21 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx)
|
||||
{
|
||||
HIP_INIT_API(ctx);
|
||||
hipError_t e = hipSuccess;
|
||||
ihipCtx_t* tempCtx;
|
||||
*ctx = ihipGetTlsDefaultCtx();
|
||||
ihipCtx_t* currentCtx = ihipGetTlsDefaultCtx();
|
||||
auto deviceHandle = currentCtx->getDevice();
|
||||
*ctx = currentCtx;
|
||||
|
||||
if(!tls_ctxStack.empty()) {
|
||||
tls_ctxStack.pop();
|
||||
}
|
||||
|
||||
if(!tls_ctxStack.empty()) {
|
||||
tempCtx= tls_ctxStack.top();
|
||||
}
|
||||
else {
|
||||
tempCtx = nullptr;
|
||||
currentCtx= tls_ctxStack.top();
|
||||
} else {
|
||||
currentCtx = deviceHandle->_primaryCtx;
|
||||
}
|
||||
|
||||
ihipSetTlsDefaultCtx(tempCtx); //TOD0 - Shall check for NULL?
|
||||
ihipSetTlsDefaultCtx(currentCtx); //TOD0 - Shall check for NULL?
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
|
||||
@@ -169,8 +169,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx)
|
||||
if(ctx != NULL) { //TODO- is this check needed?
|
||||
ihipSetTlsDefaultCtx(ctx);
|
||||
tls_ctxStack.push(ctx);
|
||||
}
|
||||
else {
|
||||
tls_getPrimaryCtx = false;
|
||||
} else {
|
||||
e = hipErrorInvalidContext;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
@@ -180,12 +180,11 @@ hipError_t hipCtxGetCurrent(hipCtx_t* ctx)
|
||||
{
|
||||
HIP_INIT_API(ctx);
|
||||
hipError_t e = hipSuccess;
|
||||
if(!tls_ctxStack.empty()) {
|
||||
if((tls_getPrimaryCtx) || tls_ctxStack.empty()) {
|
||||
*ctx = ihipGetTlsDefaultCtx();
|
||||
} else {
|
||||
*ctx= tls_ctxStack.top();
|
||||
}
|
||||
else {
|
||||
*ctx = NULL;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
|
||||
@@ -195,10 +194,10 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx)
|
||||
hipError_t e = hipSuccess;
|
||||
if(ctx == NULL) {
|
||||
tls_ctxStack.pop();
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
ihipSetTlsDefaultCtx(ctx);
|
||||
tls_ctxStack.push(ctx);
|
||||
tls_getPrimaryCtx = false;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
@@ -213,8 +212,7 @@ hipError_t hipCtxGetDevice(hipDevice_t *device)
|
||||
if(ctx == nullptr) {
|
||||
e = hipErrorInvalidContext;
|
||||
// TODO *device = nullptr;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
auto deviceHandle = ctx->getDevice();
|
||||
*device = deviceHandle->_deviceId;
|
||||
}
|
||||
|
||||
@@ -146,6 +146,7 @@ hipError_t hipSetDevice(int deviceId)
|
||||
return ihipLogStatus(hipErrorInvalidDevice);
|
||||
} else {
|
||||
ihipSetTlsDefaultCtx(ihipGetPrimaryCtx(deviceId));
|
||||
tls_getPrimaryCtx = true;
|
||||
return ihipLogStatus(hipSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +114,7 @@ private:
|
||||
//Extern tls
|
||||
extern thread_local hipError_t tls_lastHipError;
|
||||
extern thread_local TidInfo tls_tidInfo;
|
||||
extern thread_local bool tls_getPrimaryCtx;
|
||||
|
||||
extern std::vector<ProfTrigger> g_dbStartTriggers;
|
||||
extern std::vector<ProfTrigger> g_dbStopTriggers;
|
||||
|
||||
Reference in New Issue
Block a user