diff --git a/include/hcc_detail/hip_runtime_api.h b/include/hcc_detail/hip_runtime_api.h index 6bd2e2e982..029682a341 100644 --- a/include/hcc_detail/hip_runtime_api.h +++ b/include/hcc_detail/hip_runtime_api.h @@ -1274,6 +1274,15 @@ hipError_t hipDeviceGetName(char *name,int len,hipDevice_t device); */ hipError_t hipDeviceGetPCIBusId (int *pciBusId,int len,hipDevice_t device); +/** + * @brief Returns the total amount of memory on the device. + * @param [out] bytes + * @param [in] device + * + * @returns #hipSuccess, #hipErrorInavlidDevice + */ +hipError_t hipDeviceTotalMem (size_t *bytes,hipDevice_t device); + /** * @brief Returns the approximate HIP driver version. * diff --git a/include/nvcc_detail/hip_runtime_api.h b/include/nvcc_detail/hip_runtime_api.h index 6001226acf..e2bcf41dd9 100644 --- a/include/nvcc_detail/hip_runtime_api.h +++ b/include/nvcc_detail/hip_runtime_api.h @@ -645,6 +645,11 @@ inline static hipError_t hipDeviceGetPCIBusId (int *pciBusId,int len,hipDevice_t return hipCUResultTohipError(cuDeviceGetPCIBusId((char*)pciBusId,len,device)); } +inline static hipError_t hipDeviceTotalMem (size_t *bytes,hipDevice_t device) +{ + return hipCUResultTohipError(cuDeviceTotalMem(bytes,device)); +} + inline static hipError_t hipModuleLoad(hipModule_t *module, const char* fname) { return hipCUResultTohipError(cuModuleLoad(module, fname)); diff --git a/src/hip_device.cpp b/src/hip_device.cpp index d05f5bc69e..c3acbadbff 100644 --- a/src/hip_device.cpp +++ b/src/hip_device.cpp @@ -354,3 +354,10 @@ hipError_t hipDeviceGetPCIBusId (int *pciBusId,int len,hipDevice_t device) return ihipLogStatus(e); } +hipError_t hipDeviceTotalMem (size_t *bytes,hipDevice_t device) +{ + HIP_INIT_API(bytes, device); + hipError_t e = hipSuccess; + *bytes= device->_props.totalGlobalMem; + return ihipLogStatus(e); +}