diff --git a/hipamd/src/hip_mempool.cpp b/hipamd/src/hip_mempool.cpp index 8309d19e7c..29f8af99c6 100644 --- a/hipamd/src/hip_mempool.cpp +++ b/hipamd/src/hip_mempool.cpp @@ -44,6 +44,12 @@ hipError_t hipDeviceSetMemPool(int device, hipMemPool_t mem_pool) { if ((mem_pool == nullptr) || (device >= g_devices.size())) { HIP_RETURN(hipErrorInvalidValue); } + + auto poolDevice = reinterpret_cast(mem_pool)->Device(); + if (poolDevice->deviceId() != device) { + HIP_RETURN(hipErrorInvalidDevice); + } + g_devices[device]->SetCurrentMemoryPool(reinterpret_cast(mem_pool)); HIP_RETURN(hipSuccess); }