Merge pull request #145 from gargrahul/context_mgmt_changes

Updated context management logic:
Dieser Commit ist enthalten in:
Ben Sander
2017-08-09 17:47:45 -05:00
committet von GitHub
Commit dfc87a85bd
3 geänderte Dateien mit 24 neuen und 24 gelöschten Zeilen
+22 -24
Datei anzeigen
@@ -31,6 +31,7 @@ THE SOFTWARE.
// Stack of contexts
thread_local std::stack<ihipCtx_t *> tls_ctxStack;
thread_local bool tls_getPrimaryCtx = true;
void ihipCtxStackUpdate()
{
@@ -65,6 +66,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device)
*ctx = ictx;
ihipSetTlsDefaultCtx(*ctx);
tls_ctxStack.push(*ctx);
tls_getPrimaryCtx = false;
deviceCrit->addContext(ictx);
}
@@ -93,8 +95,7 @@ hipError_t hipDriverGetVersion(int *driverVersion)
hipError_t e = hipSuccess;
if (driverVersion) {
*driverVersion = 4;
}
else {
} else {
e = hipErrorInvalidValue;
}
@@ -107,8 +108,7 @@ hipError_t hipRuntimeGetVersion(int *runtimeVersion)
hipError_t e = hipSuccess;
if (runtimeVersion) {
*runtimeVersion = HIP_VERSION_PATCH;
}
else {
} else {
e = hipErrorInvalidValue;
}
@@ -124,9 +124,7 @@ hipError_t hipCtxDestroy(hipCtx_t ctx)
if(primaryCtx== ctx)
{
e = hipErrorInvalidValue;
}
else
{
} else {
if(currentCtx == ctx) {
//need to destroy the ctx associated with calling thread
tls_ctxStack.pop();
@@ -146,19 +144,21 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx)
{
HIP_INIT_API(ctx);
hipError_t e = hipSuccess;
ihipCtx_t* tempCtx;
*ctx = ihipGetTlsDefaultCtx();
ihipCtx_t* currentCtx = ihipGetTlsDefaultCtx();
auto deviceHandle = currentCtx->getDevice();
*ctx = currentCtx;
if(!tls_ctxStack.empty()) {
tls_ctxStack.pop();
}
if(!tls_ctxStack.empty()) {
tempCtx= tls_ctxStack.top();
}
else {
tempCtx = nullptr;
currentCtx= tls_ctxStack.top();
} else {
currentCtx = deviceHandle->_primaryCtx;
}
ihipSetTlsDefaultCtx(tempCtx); //TOD0 - Shall check for NULL?
ihipSetTlsDefaultCtx(currentCtx); //TOD0 - Shall check for NULL?
return ihipLogStatus(e);
}
@@ -169,8 +169,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx)
if(ctx != NULL) { //TODO- is this check needed?
ihipSetTlsDefaultCtx(ctx);
tls_ctxStack.push(ctx);
}
else {
tls_getPrimaryCtx = false;
} else {
e = hipErrorInvalidContext;
}
return ihipLogStatus(e);
@@ -180,12 +180,11 @@ hipError_t hipCtxGetCurrent(hipCtx_t* ctx)
{
HIP_INIT_API(ctx);
hipError_t e = hipSuccess;
if(!tls_ctxStack.empty()) {
if((tls_getPrimaryCtx) || tls_ctxStack.empty()) {
*ctx = ihipGetTlsDefaultCtx();
} else {
*ctx= tls_ctxStack.top();
}
else {
*ctx = NULL;
}
return ihipLogStatus(e);
}
@@ -195,10 +194,10 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx)
hipError_t e = hipSuccess;
if(ctx == NULL) {
tls_ctxStack.pop();
}
else {
} else {
ihipSetTlsDefaultCtx(ctx);
tls_ctxStack.push(ctx);
tls_getPrimaryCtx = false;
}
return ihipLogStatus(e);
}
@@ -213,8 +212,7 @@ hipError_t hipCtxGetDevice(hipDevice_t *device)
if(ctx == nullptr) {
e = hipErrorInvalidContext;
// TODO *device = nullptr;
}
else {
} else {
auto deviceHandle = ctx->getDevice();
*device = deviceHandle->_deviceId;
}
+1
Datei anzeigen
@@ -146,6 +146,7 @@ hipError_t hipSetDevice(int deviceId)
return ihipLogStatus(hipErrorInvalidDevice);
} else {
ihipSetTlsDefaultCtx(ihipGetPrimaryCtx(deviceId));
tls_getPrimaryCtx = true;
return ihipLogStatus(hipSuccess);
}
}
+1
Datei anzeigen
@@ -114,6 +114,7 @@ private:
//Extern tls
extern thread_local hipError_t tls_lastHipError;
extern thread_local TidInfo tls_tidInfo;
extern thread_local bool tls_getPrimaryCtx;
extern std::vector<ProfTrigger> g_dbStartTriggers;
extern std::vector<ProfTrigger> g_dbStopTriggers;