diff --git a/projects/hip-tests/catch/unit/device/hipDeviceTotalMem.cc b/projects/hip-tests/catch/unit/device/hipDeviceTotalMem.cc index d5b153ae28..6812b37873 100644 --- a/projects/hip-tests/catch/unit/device/hipDeviceTotalMem.cc +++ b/projects/hip-tests/catch/unit/device/hipDeviceTotalMem.cc @@ -37,13 +37,13 @@ TEST_CASE("Unit_hipDeviceTotalMem_NegTst") { #endif // Scenario 1 SECTION("bytes is nullptr") { - REQUIRE_FALSE(hipDeviceTotalMem(nullptr, 0) == hipSuccess); + HIP_CHECK_ERROR(hipDeviceTotalMem(nullptr, 0), hipErrorInvalidValue); } size_t totMem; // Scenario 2 SECTION("device is -1") { - REQUIRE_FALSE(hipDeviceTotalMem(&totMem, -1) == hipSuccess); + HIP_CHECK_ERROR(hipDeviceTotalMem(&totMem, -1), hipErrorInvalidDevice); } // Scenario 3 @@ -51,14 +51,14 @@ TEST_CASE("Unit_hipDeviceTotalMem_NegTst") { int numDevices; HIP_CHECK(hipGetDeviceCount(&numDevices)); size_t totMem; - REQUIRE_FALSE(hipDeviceTotalMem(&totMem, numDevices) == hipSuccess); + HIP_CHECK_ERROR(hipDeviceTotalMem(&totMem, numDevices), hipErrorInvalidDevice); } } // Scenario 4 TEST_CASE("Unit_hipDeviceTotalMem_ValidateTotalMem") { size_t totMem; - int numDevices; + int numDevices = 0; HIP_CHECK(hipGetDeviceCount(&numDevices)); REQUIRE(numDevices != 0); @@ -70,5 +70,30 @@ TEST_CASE("Unit_hipDeviceTotalMem_ValidateTotalMem") { HIP_CHECK(hipDeviceGet(&device, devNo)); HIP_CHECK(hipGetDeviceProperties(&prop, device)); HIP_CHECK(hipDeviceTotalMem(&totMem, device)); - REQUIRE_FALSE(totMem != prop.totalGlobalMem); + + size_t free = 0, total = 0; + HIP_CHECK(hipMemGetInfo(&free, &total)); + + REQUIRE(totMem == prop.totalGlobalMem); + REQUIRE(total == totMem); +} + +TEST_CASE("Unit_hipDeviceTotalMem_NonSelectedDevice") { + auto deviceCount = HipTest::getDeviceCount(); + if (deviceCount < 2) { + HipTest::HIP_SKIP_TEST("Multi Device Test, will not run on single gpu systems. Skipping."); + return; + } + + for (int i = 1; i < deviceCount; i++) { + HIP_CHECK(hipSetDevice(i - 1)); + hipDevice_t device; + HIP_CHECK(hipDeviceGet(&device, i)); + + size_t totMem = 0; + hipDeviceProp_t prop; + HIP_CHECK(hipDeviceTotalMem(&totMem, device)); + HIP_CHECK(hipGetDeviceProperties(&prop, device)); + REQUIRE(totMem == prop.totalGlobalMem); + } }