diff --git a/include/hip/hcc_detail/hip_runtime_api.h b/include/hip/hcc_detail/hip_runtime_api.h index afb99da6a3..ffb03d23d7 100644 --- a/include/hip/hcc_detail/hip_runtime_api.h +++ b/include/hip/hcc_detail/hip_runtime_api.h @@ -391,7 +391,7 @@ hipError_t hipSetDevice(int deviceId); * This device is used implicitly for HIP runtime APIs called by this thread. * hipGetDevice returns in * @p device the default device for the calling host thread. * - * @returns #hipSuccess + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue * * @see hipSetDevice, hipGetDevicesizeBytes */ diff --git a/src/hip_device.cpp b/src/hip_device.cpp index 6406b48c2d..403194483a 100644 --- a/src/hip_device.cpp +++ b/src/hip_device.cpp @@ -33,18 +33,16 @@ hipError_t hipGetDevice(int* deviceId) { HIP_INIT_API(hipGetDevice, deviceId); hipError_t e = hipSuccess; + if (deviceId == nullptr) + return ihipLogStatus(hipErrorInvalidValue); auto ctx = ihipGetTlsDefaultCtx(); - if (deviceId != nullptr) { - if (ctx == nullptr) { - e = hipErrorInvalidDevice; // TODO, check error code. - *deviceId = -1; - } else { - *deviceId = ctx->getDevice()->_deviceId; - } + if (ctx == nullptr) { + e = hipErrorInvalidDevice; // TODO, check error code. + *deviceId = -1; } else { - e = hipErrorInvalidValue; + *deviceId = ctx->getDevice()->_deviceId; } return ihipLogStatus(e);