Merge pull request #145 from gargrahul/context_mgmt_changes
Updated context management logic:
Dieser Commit ist enthalten in:
@@ -31,6 +31,7 @@ THE SOFTWARE.
|
||||
|
||||
// Stack of contexts
|
||||
thread_local std::stack<ihipCtx_t *> tls_ctxStack;
|
||||
thread_local bool tls_getPrimaryCtx = true;
|
||||
|
||||
void ihipCtxStackUpdate()
|
||||
{
|
||||
@@ -65,6 +66,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device)
|
||||
*ctx = ictx;
|
||||
ihipSetTlsDefaultCtx(*ctx);
|
||||
tls_ctxStack.push(*ctx);
|
||||
tls_getPrimaryCtx = false;
|
||||
deviceCrit->addContext(ictx);
|
||||
}
|
||||
|
||||
@@ -93,8 +95,7 @@ hipError_t hipDriverGetVersion(int *driverVersion)
|
||||
hipError_t e = hipSuccess;
|
||||
if (driverVersion) {
|
||||
*driverVersion = 4;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
@@ -107,8 +108,7 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion)
|
||||
hipError_t e = hipSuccess;
|
||||
if (runtimeVersion) {
|
||||
*runtimeVersion = HIP_VERSION_PATCH;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
@@ -124,9 +124,7 @@ hipError_t hipCtxDestroy(hipCtx_t ctx)
|
||||
if(primaryCtx== ctx)
|
||||
{
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
if(currentCtx == ctx) {
|
||||
//need to destroy the ctx associated with calling thread
|
||||
tls_ctxStack.pop();
|
||||
@@ -146,19 +144,21 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx)
|
||||
{
|
||||
HIP_INIT_API(ctx);
|
||||
hipError_t e = hipSuccess;
|
||||
ihipCtx_t* tempCtx;
|
||||
*ctx = ihipGetTlsDefaultCtx();
|
||||
ihipCtx_t* currentCtx = ihipGetTlsDefaultCtx();
|
||||
auto deviceHandle = currentCtx->getDevice();
|
||||
*ctx = currentCtx;
|
||||
|
||||
if(!tls_ctxStack.empty()) {
|
||||
tls_ctxStack.pop();
|
||||
}
|
||||
|
||||
if(!tls_ctxStack.empty()) {
|
||||
tempCtx= tls_ctxStack.top();
|
||||
}
|
||||
else {
|
||||
tempCtx = nullptr;
|
||||
currentCtx= tls_ctxStack.top();
|
||||
} else {
|
||||
currentCtx = deviceHandle->_primaryCtx;
|
||||
}
|
||||
|
||||
ihipSetTlsDefaultCtx(tempCtx); //TOD0 - Shall check for NULL?
|
||||
ihipSetTlsDefaultCtx(currentCtx); //TOD0 - Shall check for NULL?
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
|
||||
@@ -169,8 +169,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx)
|
||||
if(ctx != NULL) { //TODO- is this check needed?
|
||||
ihipSetTlsDefaultCtx(ctx);
|
||||
tls_ctxStack.push(ctx);
|
||||
}
|
||||
else {
|
||||
tls_getPrimaryCtx = false;
|
||||
} else {
|
||||
e = hipErrorInvalidContext;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
@@ -180,12 +180,11 @@ hipError_t hipCtxGetCurrent(hipCtx_t* ctx)
|
||||
{
|
||||
HIP_INIT_API(ctx);
|
||||
hipError_t e = hipSuccess;
|
||||
if(!tls_ctxStack.empty()) {
|
||||
if((tls_getPrimaryCtx) || tls_ctxStack.empty()) {
|
||||
*ctx = ihipGetTlsDefaultCtx();
|
||||
} else {
|
||||
*ctx= tls_ctxStack.top();
|
||||
}
|
||||
else {
|
||||
*ctx = NULL;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
|
||||
@@ -195,10 +194,10 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx)
|
||||
hipError_t e = hipSuccess;
|
||||
if(ctx == NULL) {
|
||||
tls_ctxStack.pop();
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
ihipSetTlsDefaultCtx(ctx);
|
||||
tls_ctxStack.push(ctx);
|
||||
tls_getPrimaryCtx = false;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
}
|
||||
@@ -213,8 +212,7 @@ hipError_t hipCtxGetDevice(hipDevice_t *device)
|
||||
if(ctx == nullptr) {
|
||||
e = hipErrorInvalidContext;
|
||||
// TODO *device = nullptr;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
auto deviceHandle = ctx->getDevice();
|
||||
*device = deviceHandle->_deviceId;
|
||||
}
|
||||
|
||||
@@ -146,6 +146,7 @@ hipError_t hipSetDevice(int deviceId)
|
||||
return ihipLogStatus(hipErrorInvalidDevice);
|
||||
} else {
|
||||
ihipSetTlsDefaultCtx(ihipGetPrimaryCtx(deviceId));
|
||||
tls_getPrimaryCtx = true;
|
||||
return ihipLogStatus(hipSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +114,7 @@ private:
|
||||
//Extern tls
|
||||
extern thread_local hipError_t tls_lastHipError;
|
||||
extern thread_local TidInfo tls_tidInfo;
|
||||
extern thread_local bool tls_getPrimaryCtx;
|
||||
|
||||
extern std::vector<ProfTrigger> g_dbStartTriggers;
|
||||
extern std::vector<ProfTrigger> g_dbStopTriggers;
|
||||
|
||||
In neuem Issue referenzieren
Einen Benutzer sperren