diff --git a/include/hcc_detail/hip_runtime_api.h b/include/hcc_detail/hip_runtime_api.h index ae10a89df9..0d75e394d7 100644 --- a/include/hcc_detail/hip_runtime_api.h +++ b/include/hcc_detail/hip_runtime_api.h @@ -1063,6 +1063,7 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx); hipError_t hipCtxGetCurrent(hipCtx_t* ctx); +hipError_t hipCtxGetDevice(hipDevice_t *device); // TODO-ctx /** diff --git a/src/hip_context.cpp b/src/hip_context.cpp index eb47b9bcd2..9eb65fae39 100644 --- a/src/hip_context.cpp +++ b/src/hip_context.cpp @@ -152,3 +152,17 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) return ihipLogStatus(e); } +hipError_t hipCtxGetDevice(hipDevice_t *device) +{ + hipError_t e = hipSuccess; + + ihipCtx_t *ctx = ihipGetTlsDefaultCtx(); + + if(ctx == nullptr) { + e = hipErrorInvalidContext; + } + else { + *device = (ihipDevice_t*)ctx->getDevice(); + } + return ihipLogStatus(e); +} diff --git a/tests/src/context/hipCtx_simple.cpp b/tests/src/context/hipCtx_simple.cpp index 7d634d36fe..882cf44f6d 100644 --- a/tests/src/context/hipCtx_simple.cpp +++ b/tests/src/context/hipCtx_simple.cpp @@ -30,13 +30,14 @@ int main(int argc, char *argv[]) HIPCHECK(hipInit(0)); hipDevice_t device; + hipDevice_t device1; hipCtx_t ctx; hipCtx_t ctx1; HIPCHECK(hipDeviceGetFromId(&device, 0)); HIPCHECK(hipCtxCreate(&ctx, 0, device)); HIPCHECK(hipCtxGetCurrent(&ctx1)); - + HIPCHECK(hipCtxGetDevice(&device1)); HIPCHECK(hipCtxPopCurrent(&ctx1)); HIPCHECK(hipCtxGetCurrent(&ctx1));