diff --git a/include/hip/hip_runtime_api.h b/include/hip/hip_runtime_api.h index a45a1ee27e..a2bfed5c69 100644 --- a/include/hip/hip_runtime_api.h +++ b/include/hip/hip_runtime_api.h @@ -181,6 +181,7 @@ typedef enum hipError_t { hipErrorSharedObjectSymbolNotFound = 302, hipErrorSharedObjectInitFailed = 303, hipErrorOperatingSystem = 304, + hipErrorSetOnActiveProcess = 305, hipErrorInvalidHandle = 400, hipErrorNotFound = 500, hipErrorIllegalAddress = 700, diff --git a/src/hip_device.cpp b/src/hip_device.cpp index 1cfdaa619d..0f2c2e2753 100644 --- a/src/hip_device.cpp +++ b/src/hip_device.cpp @@ -175,6 +175,24 @@ hipError_t hipDeviceReset(void) return ihipLogStatus(hipSuccess); } +hipError_t ihipDeviceSetState(void) +{ + hipError_t e = hipErrorInvalidContext; + auto *ctx = ihipGetTlsDefaultCtx(); + + if (ctx) { + ihipDevice_t *deviceHandle = ctx->getWriteableDevice(); + if(deviceHandle->_state == 0) + { + deviceHandle->_state = 1; + } + e = hipSuccess; + } + + return ihipLogStatus(e); +} + + hipError_t ihipDeviceGetAttribute(int* pi, hipDeviceAttribute_t attr, int device) { hipError_t e = hipSuccess; @@ -289,29 +307,35 @@ hipError_t hipSetDeviceFlags( unsigned int flags) // TODO : does this really OR in the flags or replaces previous flags: // TODO : Review error handling behavior for this function, it often returns ErrorSetOnActiveProcess if (ctx) { - ctx->_ctxFlags = ctx->_ctxFlags | flags; - if (flags & hipDeviceScheduleMask) { - switch (hipDeviceScheduleMask) { - case hipDeviceScheduleAuto: - case hipDeviceScheduleSpin: - case hipDeviceScheduleYield: - case hipDeviceScheduleBlockingSync: - e = hipSuccess; - break; - default: - e = hipSuccess; // TODO - should this be error? Map to Auto? - //e = hipErrorInvalidValue; - break; + auto *deviceHandle = ctx->getDevice(); + if(deviceHandle->_state == 0) + { + ctx->_ctxFlags = ctx->_ctxFlags | flags; + if (flags & hipDeviceScheduleMask) { + switch (hipDeviceScheduleMask) { + case hipDeviceScheduleAuto: + case hipDeviceScheduleSpin: + case hipDeviceScheduleYield: + case hipDeviceScheduleBlockingSync: + e = hipSuccess; + break; + default: + e = hipSuccess; // TODO - should this be error? Map to Auto? + //e = hipErrorInvalidValue; + break; + } } - } - unsigned supportedFlags = hipDeviceScheduleMask | hipDeviceMapHost | hipDeviceLmemResizeToMax; + unsigned supportedFlags = hipDeviceScheduleMask | hipDeviceMapHost | hipDeviceLmemResizeToMax; - if (flags & (~supportedFlags)) { - e = hipErrorInvalidValue; - } - } else { - e = hipErrorInvalidDevice; + if (flags & (~supportedFlags)) { + e = hipErrorInvalidValue; + } + } else { + e = hipErrorSetOnActiveProcess; + } + } else { + e = hipErrorInvalidDevice; } return ihipLogStatus(e); diff --git a/src/hip_hcc.cpp b/src/hip_hcc.cpp index a4ef2b392b..d760ade15d 100644 --- a/src/hip_hcc.cpp +++ b/src/hip_hcc.cpp @@ -482,7 +482,8 @@ void ihipCtxCriticalBase_t::addStream(ihipStream_t *stream) //================================================================================================= ihipDevice_t::ihipDevice_t(unsigned deviceId, unsigned deviceCnt, hc::accelerator &acc) : _deviceId(deviceId), - _acc(acc) + _acc(acc), + _state(0) { hsa_agent_t *agent = static_cast (acc.get_hsa_agent()); if (agent) { @@ -865,6 +866,7 @@ void ihipCtx_t::locked_reset() // Reset will remove peer mapping so don't need to do this explicitly. // FIXME - This is clearly a non-const action! Is this a context reset or a device reset - maybe should reference count? ihipDevice_t *device = getWriteableDevice(); + device->_state = 0; am_memtracker_reset(device->_acc); }; @@ -1553,6 +1555,7 @@ const char *ihipErrorString(hipError_t hip_error) case hipErrorSharedObjectSymbolNotFound : return "hipErrorSharedObjectSymbolNotFound"; case hipErrorSharedObjectInitFailed : return "hipErrorSharedObjectInitFailed"; case hipErrorOperatingSystem : return "hipErrorOperatingSystem"; + case hipErrorSetOnActiveProcess : return "hipErrorSetOnActiveProcess"; case hipErrorInvalidHandle : return "hipErrorInvalidHandle"; case hipErrorNotFound : return "hipErrorNotFound"; case hipErrorIllegalAddress : return "hipErrorIllegalAddress"; diff --git a/src/hip_hcc.h b/src/hip_hcc.h index ed85f1494c..8a4d457cb1 100644 --- a/src/hip_hcc.h +++ b/src/hip_hcc.h @@ -204,7 +204,8 @@ extern void recordApiTrace(std::string *fullStr, const std::string &apiStr); #define HIP_INIT()\ std::call_once(hip_initialized, ihipInit);\ ihipCtxStackUpdate(); - +#define HIP_SET_DEVICE()\ + ihipDeviceSetState(); // This macro should be called at the beginning of every HIP API. // It initialies the hip runtime (exactly once), and @@ -566,6 +567,8 @@ public: ihipCtx_t *_primaryCtx; + int _state; //1 if device is set otherwise 0 + private: hipError_t initProperties(hipDeviceProp_t* prop); }; @@ -703,6 +706,7 @@ extern ihipCtx_t *ihipGetTlsDefaultCtx(); extern void ihipSetTlsDefaultCtx(ihipCtx_t *ctx); extern hipError_t ihipSynchronize(void); extern void ihipCtxStackUpdate(); +extern hipError_t ihipDeviceSetState(); extern ihipDevice_t *ihipGetDevice(int); ihipCtx_t * ihipGetPrimaryCtx(unsigned deviceIndex); diff --git a/src/hip_memory.cpp b/src/hip_memory.cpp index 7e1a1738a6..74578e9b4b 100644 --- a/src/hip_memory.cpp +++ b/src/hip_memory.cpp @@ -105,7 +105,7 @@ hipError_t hipHostGetDevicePointer(void **devicePointer, void *hostPointer, unsi hipError_t hipMalloc(void** ptr, size_t sizeBytes) { HIP_INIT_API(ptr, sizeBytes); - + HIP_SET_DEVICE(); hipError_t hip_status = hipSuccess; // return NULL pointer when malloc size is 0 if (sizeBytes == 0) @@ -161,7 +161,7 @@ hipError_t hipMalloc(void** ptr, size_t sizeBytes) hipError_t hipHostMalloc(void** ptr, size_t sizeBytes, unsigned int flags) { HIP_INIT_API(ptr, sizeBytes, flags); - + HIP_SET_DEVICE(); hipError_t hip_status = hipSuccess; auto ctx = ihipGetTlsDefaultCtx(); @@ -233,7 +233,7 @@ hipError_t hipHostAlloc(void** ptr, size_t sizeBytes, unsigned int flags) hipError_t hipMallocPitch(void** ptr, size_t* pitch, size_t width, size_t height) { HIP_INIT_API(ptr, pitch, width, height); - + HIP_SET_DEVICE(); hipError_t hip_status = hipSuccess; if(width == 0 || height == 0) @@ -285,7 +285,7 @@ hipError_t hipMallocArray(hipArray** array, const hipChannelFormatDesc* desc, size_t width, size_t height, unsigned int flags) { HIP_INIT_API(array, desc, width, height, flags); - + HIP_SET_DEVICE(); hipError_t hip_status = hipSuccess; auto ctx = ihipGetTlsDefaultCtx();