SWDEV-558836, SWDEV-558837 - Add hipMemSetMemPool and hipMemGetMemPoo… (#1349)

* SWDEV-558836, SWDEV-558837 - Add hipMemSetMemPool and hipMemGetMemPool implementation

* Add managed allocation type for mem pools

* Update rocprofiler-sdk with APis declaration
Este commit está contenido en:
vstojilj
2026-01-27 18:45:28 +01:00
cometido por GitHub
padre 324a864bc4
commit 9a8942a89c
Se han modificado 21 ficheros con 517 adiciones y 17 borrados
+2
Ver fichero
@@ -28,6 +28,8 @@ Full documentation for HIP is available at [rocm.docs.amd.com](https://rocm.docs
- `hipLibraryGetKernelCount` gets kernel count in library
- `hipStreamCopyAttributes` copies attributes from source stream to destination stream
- `hipOccupancyAvailableDynamicSMemPerBlock` returns dynamic shared memory available per block when launching numBlocks blocks on CU.
- `hipMemSetMemPool` Sets the current memory pool for a memory location and allocation type
- `hipMemGetMemPool` Gets the current memory pool for a memory location and of a particular allocation type
* New HIP flags
- `hipMemLocationTypeHost`, enables handling virtual memory management in host memory location, in addition to device memory.
- Support for flags in `hipGetProcAddress`, enables searching for the per-thread version symbols.
@@ -63,7 +63,7 @@
#define HIP_API_TABLE_STEP_VERSION 0
#define HIP_COMPILER_API_TABLE_STEP_VERSION 0
#define HIP_TOOLS_API_TABLE_STEP_VERSION 0
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 21
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 22
// HIP API interface
// HIP compiler dispatch functions
@@ -1119,6 +1119,10 @@ typedef hipError_t (*t_hipExtSetLoggingParams)(size_t log_level, size_t log_size
typedef hipError_t (*t_hipKernelGetParamInfo)(hipKernel_t kernel, size_t paramIndex,
size_t* paramOffset, size_t* paramSize);
typedef hipError_t (*t_hipMemSetMemPool)(hipMemLocation* location, hipMemAllocationType type,
hipMemPool_t pool);
typedef hipError_t (*t_hipMemGetMemPool)(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type);
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
// HIP_COMPILER_API_TABLE_STEP_VERSION == 0
@@ -1700,7 +1704,7 @@ struct HipDispatchTable {
t_hipLibraryEnumerateKernels hipLibraryEnumerateKernels_fn;
t_hipKernelGetLibrary hipKernelGetLibrary_fn;
t_hipKernelGetName hipKernelGetName_fn;
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 18
t_hipOccupancyAvailableDynamicSMemPerBlock hipOccupancyAvailableDynamicSMemPerBlock_fn;
@@ -1715,8 +1719,12 @@ struct HipDispatchTable {
t_hipExtEnableLogging hipExtEnableLogging_fn;
t_hipExtSetLoggingParams hipExtSetLoggingParams_fn;
// DO NOT EDIT ABOVE!
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 22
t_hipMemSetMemPool hipMemSetMemPool_fn;
t_hipMemGetMemPool hipMemGetMemPool_fn;
// DO NOT EDIT ABOVE!
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 23
// ******************************************************************************************* //
//
@@ -471,7 +471,9 @@ enum hip_api_id_t {
HIP_API_ID_hipExtDisableLogging = 451,
HIP_API_ID_hipExtEnableLogging = 452,
HIP_API_ID_hipExtSetLoggingParams = 453,
HIP_API_ID_LAST = 453,
HIP_API_ID_hipMemSetMemPool = 454,
HIP_API_ID_hipMemGetMemPool = 455,
HIP_API_ID_LAST = 455,
HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -782,6 +784,7 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipMemGetAllocationPropertiesFromHandle: return "hipMemGetAllocationPropertiesFromHandle";
case HIP_API_ID_hipMemGetHandleForAddressRange: return "hipMemGetHandleForAddressRange";
case HIP_API_ID_hipMemGetInfo: return "hipMemGetInfo";
case HIP_API_ID_hipMemGetMemPool: return "hipMemGetMemPool";
case HIP_API_ID_hipMemImportFromShareableHandle: return "hipMemImportFromShareableHandle";
case HIP_API_ID_hipMemMap: return "hipMemMap";
case HIP_API_ID_hipMemMapArrayAsync: return "hipMemMapArrayAsync";
@@ -804,6 +807,7 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipMemRelease: return "hipMemRelease";
case HIP_API_ID_hipMemRetainAllocationHandle: return "hipMemRetainAllocationHandle";
case HIP_API_ID_hipMemSetAccess: return "hipMemSetAccess";
case HIP_API_ID_hipMemSetMemPool: return "hipMemSetMemPool";
case HIP_API_ID_hipMemUnmap: return "hipMemUnmap";
case HIP_API_ID_hipMemcpy: return "hipMemcpy";
case HIP_API_ID_hipMemcpy2D: return "hipMemcpy2D";
@@ -1229,6 +1233,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipMemGetAllocationPropertiesFromHandle", name) == 0) return HIP_API_ID_hipMemGetAllocationPropertiesFromHandle;
if (strcmp("hipMemGetHandleForAddressRange", name) == 0) return HIP_API_ID_hipMemGetHandleForAddressRange;
if (strcmp("hipMemGetInfo", name) == 0) return HIP_API_ID_hipMemGetInfo;
if (strcmp("hipMemGetMemPool", name) == 0) return HIP_API_ID_hipMemGetMemPool;
if (strcmp("hipMemImportFromShareableHandle", name) == 0) return HIP_API_ID_hipMemImportFromShareableHandle;
if (strcmp("hipMemMap", name) == 0) return HIP_API_ID_hipMemMap;
if (strcmp("hipMemMapArrayAsync", name) == 0) return HIP_API_ID_hipMemMapArrayAsync;
@@ -1251,6 +1256,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipMemRelease", name) == 0) return HIP_API_ID_hipMemRelease;
if (strcmp("hipMemRetainAllocationHandle", name) == 0) return HIP_API_ID_hipMemRetainAllocationHandle;
if (strcmp("hipMemSetAccess", name) == 0) return HIP_API_ID_hipMemSetAccess;
if (strcmp("hipMemSetMemPool", name) == 0) return HIP_API_ID_hipMemSetMemPool;
if (strcmp("hipMemUnmap", name) == 0) return HIP_API_ID_hipMemUnmap;
if (strcmp("hipMemcpy", name) == 0) return HIP_API_ID_hipMemcpy;
if (strcmp("hipMemcpy2D", name) == 0) return HIP_API_ID_hipMemcpy2D;
@@ -3012,6 +3018,13 @@ typedef struct hip_api_data_s {
size_t* total;
size_t total__val;
} hipMemGetInfo;
struct {
hipMemPool_t* pool;
hipMemPool_t pool__val;
hipMemLocation* location;
hipMemLocation location__val;
hipMemAllocationType type;
} hipMemGetMemPool;
struct {
hipMemGenericAllocationHandle_t* handle;
hipMemGenericAllocationHandle_t handle__val;
@@ -3143,6 +3156,12 @@ typedef struct hip_api_data_s {
hipMemAccessDesc desc__val;
size_t count;
} hipMemSetAccess;
struct {
hipMemLocation* location;
hipMemLocation location__val;
hipMemAllocationType type;
hipMemPool_t pool;
} hipMemSetMemPool;
struct {
void* ptr;
size_t size;
@@ -5670,6 +5689,12 @@ typedef struct hip_api_data_s {
cb_data.args.hipMemGetInfo.free = (size_t*)free; \
cb_data.args.hipMemGetInfo.total = (size_t*)total; \
};
// hipMemGetMemPool[('hipMemPool_t*', 'pool'), ('hipMemLocation*', 'location'), ('hipMemAllocationType', 'type')]
#define INIT_hipMemGetMemPool_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipMemGetMemPool.pool = (hipMemPool_t*)pool; \
cb_data.args.hipMemGetMemPool.location = (hipMemLocation*)location; \
cb_data.args.hipMemGetMemPool.type = (hipMemAllocationType)type; \
};
// hipMemImportFromShareableHandle[('hipMemGenericAllocationHandle_t*', 'handle'), ('void*', 'osHandle'), ('hipMemAllocationHandleType', 'shHandleType')]
#define INIT_hipMemImportFromShareableHandle_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipMemImportFromShareableHandle.handle = (hipMemGenericAllocationHandle_t*)handle; \
@@ -5806,6 +5831,12 @@ typedef struct hip_api_data_s {
cb_data.args.hipMemSetAccess.desc = (const hipMemAccessDesc*)desc; \
cb_data.args.hipMemSetAccess.count = (size_t)count; \
};
// hipMemSetMemPool[('hipMemLocation*', 'location'), ('hipMemAllocationType', 'type'), ('hipMemPool_t', 'pool')]
#define INIT_hipMemSetMemPool_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipMemSetMemPool.location = (hipMemLocation*)location; \
cb_data.args.hipMemSetMemPool.type = (hipMemAllocationType)type; \
cb_data.args.hipMemSetMemPool.pool = (hipMemPool_t)pool; \
};
// hipMemUnmap[('void*', 'ptr'), ('size_t', 'size')]
#define INIT_hipMemUnmap_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipMemUnmap.ptr = (void*)ptr; \
@@ -7939,6 +7970,11 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
if (data->args.hipMemGetInfo.free) data->args.hipMemGetInfo.free__val = *(data->args.hipMemGetInfo.free);
if (data->args.hipMemGetInfo.total) data->args.hipMemGetInfo.total__val = *(data->args.hipMemGetInfo.total);
break;
// hipMemGetMemPool[('hipMemPool_t*', 'pool'), ('hipMemLocation*', 'location'), ('hipMemAllocationType', 'type')]
case HIP_API_ID_hipMemGetMemPool:
if (data->args.hipMemGetMemPool.pool) data->args.hipMemGetMemPool.pool__val = *(data->args.hipMemGetMemPool.pool);
if (data->args.hipMemGetMemPool.location) data->args.hipMemGetMemPool.location__val = *(data->args.hipMemGetMemPool.location);
break;
// hipMemImportFromShareableHandle[('hipMemGenericAllocationHandle_t*', 'handle'), ('void*', 'osHandle'), ('hipMemAllocationHandleType', 'shHandleType')]
case HIP_API_ID_hipMemImportFromShareableHandle:
if (data->args.hipMemImportFromShareableHandle.handle) data->args.hipMemImportFromShareableHandle.handle__val = *(data->args.hipMemImportFromShareableHandle.handle);
@@ -8022,6 +8058,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
case HIP_API_ID_hipMemSetAccess:
if (data->args.hipMemSetAccess.desc) data->args.hipMemSetAccess.desc__val = *(data->args.hipMemSetAccess.desc);
break;
// hipMemSetMemPool[('hipMemLocation*', 'location'), ('hipMemAllocationType', 'type'), ('hipMemPool_t', 'pool')]
case HIP_API_ID_hipMemSetMemPool:
if (data->args.hipMemSetMemPool.location) data->args.hipMemSetMemPool.location__val = *(data->args.hipMemSetMemPool.location);
break;
// hipMemUnmap[('void*', 'ptr'), ('size_t', 'size')]
case HIP_API_ID_hipMemUnmap:
break;
@@ -10747,6 +10787,15 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
else { oss << ", total="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemGetInfo.total__val); }
oss << ")";
break;
case HIP_API_ID_hipMemGetMemPool:
oss << "hipMemGetMemPool(";
if (data->args.hipMemGetMemPool.pool == NULL) oss << "pool=NULL";
else { oss << "pool="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemGetMemPool.pool__val); }
if (data->args.hipMemGetMemPool.location == NULL) oss << ", location=NULL";
else { oss << ", location="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemGetMemPool.location__val); }
oss << ", type="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemGetMemPool.type);
oss << ")";
break;
case HIP_API_ID_hipMemImportFromShareableHandle:
oss << "hipMemImportFromShareableHandle(";
if (data->args.hipMemImportFromShareableHandle.handle == NULL) oss << "handle=NULL";
@@ -10922,6 +10971,14 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
oss << ", count="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemSetAccess.count);
oss << ")";
break;
case HIP_API_ID_hipMemSetMemPool:
oss << "hipMemSetMemPool(";
if (data->args.hipMemSetMemPool.location == NULL) oss << "location=NULL";
else { oss << "location="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemSetMemPool.location__val); }
oss << ", type="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemSetMemPool.type);
oss << ", pool="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemSetMemPool.pool);
oss << ")";
break;
case HIP_API_ID_hipMemUnmap:
oss << "hipMemUnmap(";
oss << "ptr="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemUnmap.ptr);
+2
Ver fichero
@@ -526,3 +526,5 @@ hipKernelGetParamInfo
hipExtDisableLogging
hipExtEnableLogging
hipExtSetLoggingParams
hipMemSetMemPool
hipMemGetMemPool
+10 -2
Ver fichero
@@ -888,6 +888,9 @@ hipError_t hipKernelGetParamInfo(hipKernel_t kernel, size_t paramIndex, size_t*
hipError_t hipExtDisableLogging();
hipError_t hipExtEnableLogging();
hipError_t hipExtSetLoggingParams(size_t log_level, size_t log_size, size_t log_mask);
hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type, hipMemPool_t pool);
hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type);
} // namespace hip
namespace hip {
@@ -1438,6 +1441,8 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipExtDisableLogging_fn = hip::hipExtDisableLogging;
ptrDispatchTable->hipExtEnableLogging_fn = hip::hipExtEnableLogging;
ptrDispatchTable->hipExtSetLoggingParams_fn = hip::hipExtSetLoggingParams;
ptrDispatchTable->hipMemSetMemPool_fn = hip::hipMemSetMemPool;
ptrDispatchTable->hipMemGetMemPool_fn = hip::hipMemGetMemPool;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -2124,15 +2129,18 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipKernelGetParamInfo_fn, 507);
HIP_ENFORCE_ABI(HipDispatchTable, hipExtDisableLogging_fn, 508);
HIP_ENFORCE_ABI(HipDispatchTable, hipExtEnableLogging_fn, 509);
HIP_ENFORCE_ABI(HipDispatchTable, hipExtSetLoggingParams_fn, 510);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 22
HIP_ENFORCE_ABI(HipDispatchTable, hipMemSetMemPool_fn, 511);
HIP_ENFORCE_ABI(HipDispatchTable, hipMemGetMemPool_fn, 512);
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
//
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 511)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 513)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 21,
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 22,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
"pointers and then update this check so it is true");
#endif
+18
Ver fichero
@@ -79,6 +79,20 @@ bool Device::Create() {
// Current is default pool after device creation
current_mem_pool_ = default_mem_pool_;
// Create managed memory pool
hipMemPoolProps props = {.allocType = hipMemAllocationTypeManaged,
.handleTypes = hipMemHandleTypeNone,
.location = {.type = hipMemLocationTypeDevice, .id = deviceId_},
.win32SecurityAttributes = nullptr,
.maxSize = 0,
.reserved = {}};
default_managed_mem_pool_ = new MemoryPool(this, &props);
if (default_managed_mem_pool_ == nullptr) {
return false;
}
current_managed_mem_pool_ = default_managed_mem_pool_;
return true;
}
@@ -331,6 +345,10 @@ Device::~Device() {
graph_mem_pool_->release();
}
if (default_managed_mem_pool_ != nullptr) {
default_managed_mem_pool_->release();
}
if (null_stream_ != nullptr) {
hip::Stream::Destroy(null_stream_);
}
+2
Ver fichero
@@ -648,6 +648,8 @@ global:
hipExtDisableLogging;
hipExtEnableLogging;
hipExtSetLoggingParams;
hipMemSetMemPool;
hipMemGetMemPool;
local:
*;
} hip_7.1;
+24 -9
Ver fichero
@@ -541,21 +541,27 @@ public:
MemoryPool* default_mem_pool_; //!< Default memory pool for this device
MemoryPool* current_mem_pool_;
MemoryPool* graph_mem_pool_; //!< Memory pool, associated with graphs for this device
MemoryPool* current_managed_mem_pool_; //!< Memory pool for managed allocations
MemoryPool* default_managed_mem_pool_; //!< Memory pool for managed allocations
std::set<MemoryPool*> mem_pools_;
// Tracking Objects
ObjectRegistry<hipGraphicsResource_t> registeredGraphicsResources_; //!< Track registered graphics resources
ObjectRegistry<hipGraphicsResource_t> mappedGraphicsResources_; //!< Track mapped graphics resources
public:
Device(amd::Context* ctx, int devId): context_(ctx),
deviceId_(devId),
flags_(hipDeviceScheduleSpin),
isActive_(false),
default_mem_pool_(nullptr),
current_mem_pool_(nullptr),
graph_mem_pool_(nullptr)
{ assert(ctx != nullptr); }
public:
Device(amd::Context* ctx, int devId)
: context_(ctx),
deviceId_(devId),
flags_(hipDeviceScheduleSpin),
isActive_(false),
default_mem_pool_(nullptr),
current_mem_pool_(nullptr),
graph_mem_pool_(nullptr),
default_managed_mem_pool_(nullptr),
current_managed_mem_pool_(nullptr) {
assert(ctx != nullptr);
}
~Device();
bool Create();
@@ -619,6 +625,15 @@ public:
/// Get the graph memory pool on the device
MemoryPool* GetGraphMemoryPool() const { return graph_mem_pool_; }
/// Set managed memory pool on the device
void SetCurrentManagedMemoryPool(MemoryPool* pool) { current_managed_mem_pool_ = pool; }
/// Get managed memory pool on the device
MemoryPool* GetCurrentManagedMemoryPool() const { return current_managed_mem_pool_; }
/// Get default managed memory pool on the device
MemoryPool* GetDefaultManagedMemoryPool() const { return default_managed_mem_pool_; }
/// Add memory pool to the device
void AddMemoryPool(MemoryPool* pool);
+83 -1
Ver fichero
@@ -309,7 +309,8 @@ hipError_t hipMemPoolCreate(hipMemPool_t* mem_pool, const hipMemPoolProps* pool_
HIP_RETURN(hipErrorInvalidValue);
}
// validate hipMemAllocationType value
if (pool_props->allocType != hipMemAllocationTypePinned) {
if (pool_props->allocType != hipMemAllocationTypePinned &&
pool_props->allocType != hipMemAllocationTypeManaged) {
HIP_RETURN(hipErrorInvalidValue);
}
// Make sure the pool creation occurs on a valid device
@@ -359,6 +360,11 @@ hipError_t hipMemPoolDestroy(hipMemPool_t mem_pool) {
device->SetCurrentMemoryPool(device->GetDefaultMemoryPool());
}
// Same for managed pool
if (hip_mem_pool == device->GetCurrentManagedMemoryPool()) {
device->SetCurrentManagedMemoryPool(device->GetDefaultManagedMemoryPool());
}
hip_mem_pool->release();
HIP_RETURN(hipSuccess);
@@ -488,4 +494,80 @@ hipError_t hipMemPoolImportPointer(void** ptr, hipMemPool_t mem_pool,
mpool->retain();
HIP_RETURN(hipSuccess);
}
// ================================================================================================
hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type,
hipMemPool_t pool) {
HIP_INIT_API(hipMemSetMemPool, location, type, pool);
CHECK_STREAM_CAPTURE_SUPPORTED();
if (location == nullptr || pool == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
// Only device pools can be created
if (location->type != hipMemLocationTypeDevice) {
HIP_RETURN(hipErrorInvalidValue);
}
if (type != hipMemAllocationTypePinned && type != hipMemAllocationTypeManaged) {
HIP_RETURN(hipErrorInvalidValue);
}
if (location->id >= g_devices.size()) {
HIP_RETURN(hipErrorInvalidValue);
}
auto mem_pool = reinterpret_cast<hip::MemoryPool*>(pool);
if (!IsMemPoolValid(mem_pool)) {
HIP_RETURN(hipErrorInvalidValue);
}
// Location and type must match pool's location and allocation type
if ((location->id != mem_pool->Device()->deviceId()) ||
(type != mem_pool->Properties().allocType)) {
HIP_RETURN(hipErrorInvalidValue);
}
if (type == hipMemAllocationTypePinned) {
g_devices[location->id]->SetCurrentMemoryPool(mem_pool);
} else {
// Pool set for managed allocation type can't be implicitly used for allocation, but it can be
// retrieved with hipMemGetMemPool
g_devices[location->id]->SetCurrentManagedMemoryPool(mem_pool);
}
HIP_RETURN(hipSuccess);
}
// ================================================================================================
hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type) {
HIP_INIT_API(hipMemGetMemPool, pool, location, type);
if ((pool == nullptr) || (location == nullptr)) {
HIP_RETURN(hipErrorInvalidValue);
}
if (location->type != hipMemLocationTypeDevice) {
HIP_RETURN(hipErrorInvalidValue);
}
if (type != hipMemAllocationTypePinned && type != hipMemAllocationTypeManaged) {
HIP_RETURN(hipErrorInvalidValue);
}
if (location->id >= g_devices.size()) {
HIP_RETURN(hipErrorInvalidValue);
}
if (type == hipMemAllocationTypePinned) {
*pool = reinterpret_cast<hipMemPool_t>(g_devices[location->id]->GetCurrentMemoryPool());
} else {
*pool = reinterpret_cast<hipMemPool_t>(g_devices[location->id]->GetCurrentManagedMemoryPool());
}
HIP_RETURN(hipSuccess);
}
} // namespace hip
@@ -3168,4 +3168,16 @@ hipError_t hipExtSetLoggingParams(size_t log_level, size_t log_size, size_t log_
TRY;
return hip::GetHipDispatchTable()->hipExtSetLoggingParams_fn(log_level, log_size, log_mask);
CATCH;
}
hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type,
hipMemPool_t pool) {
TRY;
return hip::GetHipDispatchTable()->hipMemSetMemPool_fn(location, type, pool);
CATCH;
}
hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type) {
TRY;
return hip::GetHipDispatchTable()->hipMemGetMemPool_fn(pool, location, type);
CATCH;
}
@@ -223,7 +223,9 @@ set(TEST_SRC
hipMemGetAddressRange.cc
hipMallocMipmappedArray.cc
hipFreeMipmappedArray.cc
hipHostAlloc.cc)
hipHostAlloc.cc
hipMemSetMemPool.cc
hipMemGetMemPool.cc)
if(HIP_PLATFORM MATCHES "amd")
set(TEST_SRC
@@ -0,0 +1,91 @@
/*
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANNTY OF ANY KIND, EXPRESS OR
IMPLIED, INNCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANNY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER INN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR INN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#include <hip_test_common.hh>
#include "mempool_common.hh"
/**
* @addtogroup hipMemGetMemPool hipMemGetMemPool
* @{
* @ingroup MemoryTest
* `hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type)` -
* Gets the current memory pool for the location and allocation type.
*/
TEST_CASE("Unit_hipMemGetMemPool_Negative") {
int dev;
HIP_CHECK(hipGetDevice(&dev));
hipMemPool_t pool;
hipMemLocation location{};
location.id = dev;
location.type = hipMemLocationTypeDevice;
SECTION("Invalid pool") {
HIP_CHECK_ERROR(hipMemGetMemPool(nullptr, &location, hipMemAllocationTypePinned),
hipErrorInvalidValue);
}
SECTION("Invalid location") {
HIP_CHECK_ERROR(hipMemGetMemPool(&pool, nullptr, hipMemAllocationTypePinned),
hipErrorInvalidValue);
location.id = -1;
HIP_CHECK_ERROR(hipMemGetMemPool(&pool, &location, hipMemAllocationTypePinned),
hipErrorInvalidValue);
location.id = dev;
location.type = hipMemLocationTypeNone;
HIP_CHECK_ERROR(hipMemGetMemPool(&pool, &location, hipMemAllocationTypePinned),
hipErrorInvalidValue);
}
SECTION("Invalid allocation type") {
HIP_CHECK_ERROR(hipMemGetMemPool(&pool, &location, hipMemAllocationTypeInvalid),
hipErrorInvalidValue);
}
}
TEST_CASE("Unit_hipMemGetMemPool_Basic") {
int dev;
HIP_CHECK(hipGetDevice(&dev));
auto alloc_type = GENERATE(hipMemAllocationTypePinned, hipMemAllocationTypeManaged);
hipMemPool_t mem_pool, curr_mem_pool;
hipMemPoolProps prop{};
prop.allocType = alloc_type;
prop.location.id = dev;
prop.location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemPoolCreate(&mem_pool, &prop));
hipMemLocation location{};
location.id = dev;
location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemGetMemPool(&curr_mem_pool, &location, alloc_type));
REQUIRE(curr_mem_pool != nullptr);
HIP_CHECK(hipMemSetMemPool(&location, alloc_type, mem_pool));
HIP_CHECK(hipMemGetMemPool(&curr_mem_pool, &location, alloc_type));
REQUIRE(curr_mem_pool == mem_pool);
HIP_CHECK(hipMemPoolDestroy(mem_pool));
}
@@ -0,0 +1,137 @@
/*
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANNTY OF ANY KIND, EXPRESS OR
IMPLIED, INNCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANNY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER INN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR INN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#include <hip_test_common.hh>
#include "mempool_common.hh"
/**
* @addtogroup hipMemSetMemPool hipMemSetMemPool
* @{
* @ingroup MemoryTest
* `hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type,
hipMemPool_t pool)` -
* Sets the current memory pool for the location and allocation type.
*/
TEST_CASE("Unit_hipMemSetMemPool_Negative") {
int dev;
HIP_CHECK(hipGetDevice(&dev));
checkMempoolSupported(dev);
hipMemPool_t mem_pool;
hipMemPoolProps prop{};
prop.allocType = hipMemAllocationTypePinned;
prop.location.id = dev;
prop.location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemPoolCreate(&mem_pool, &prop));
hipMemLocation location{};
location.id = dev;
location.type = hipMemLocationTypeDevice;
SECTION("Invalid location") {
HIP_CHECK_ERROR(hipMemSetMemPool(nullptr, hipMemAllocationTypePinned, mem_pool),
hipErrorInvalidValue);
location.id = -1;
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypePinned, mem_pool),
hipErrorInvalidValue);
location.id = dev;
location.type = hipMemLocationTypeNone;
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypePinned, mem_pool),
hipErrorInvalidValue);
}
SECTION("Invalid pool") {
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypePinned, nullptr),
hipErrorInvalidValue);
// Pool device and location device do not match
int dev_cnt = 0;
HIP_CHECK(hipGetDeviceCount(&dev_cnt));
if (dev_cnt > 1) {
hipMemPool_t mem_pool2;
prop.allocType = hipMemAllocationTypePinned;
prop.location.id = dev + 1;
prop.location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemPoolCreate(&mem_pool2, &prop));
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypePinned, mem_pool2),
hipErrorInvalidValue);
HIP_CHECK(hipMemPoolDestroy(mem_pool2));
}
}
SECTION("Using destroyed pool") {
// Create a temporary pool
hipMemPool_t temp_pool;
prop.allocType = hipMemAllocationTypePinned;
prop.location.id = dev;
prop.location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemPoolCreate(&temp_pool, &prop));
// Destroy it
HIP_CHECK(hipMemPoolDestroy(temp_pool));
// Try to set the destroyed pool - should fail
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypePinned, temp_pool),
hipErrorInvalidValue);
}
SECTION("Invalid allocation type") {
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypeInvalid, mem_pool),
hipErrorInvalidValue);
// Different than the one pool got created for
HIP_CHECK_ERROR(hipMemSetMemPool(&location, hipMemAllocationTypeManaged, mem_pool),
hipErrorInvalidValue);
}
HIP_CHECK(hipMemPoolDestroy(mem_pool));
}
TEST_CASE("Unit_hipMemSetMemPool_Basic") {
int num_devices;
HIP_CHECK(hipGetDeviceCount(&num_devices));
auto alloc_type = GENERATE(hipMemAllocationTypePinned, hipMemAllocationTypeManaged);
for (int dev = 0; dev < num_devices; dev++) {
checkMempoolSupported(dev);
HIP_CHECK(hipSetDevice(dev));
hipMemPool_t mem_pool, curr_mem_pool;
hipMemPoolProps prop{};
prop.allocType = alloc_type;
prop.location.id = dev;
prop.location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemPoolCreate(&mem_pool, &prop));
hipMemLocation location{};
location.id = dev;
location.type = hipMemLocationTypeDevice;
HIP_CHECK(hipMemSetMemPool(&location, alloc_type, mem_pool));
HIP_CHECK(hipMemGetMemPool(&curr_mem_pool, &location, alloc_type));
REQUIRE(curr_mem_pool == mem_pool);
HIP_CHECK(hipMemPoolDestroy(mem_pool));
}
}
+14
Ver fichero
@@ -1226,6 +1226,7 @@ typedef enum hipMemAllocationType {
* location while the application is actively using it
*/
hipMemAllocationTypePinned = 0x1,
hipMemAllocationTypeManaged = 0x2,
hipMemAllocationTypeUncached = 0x40000000,
hipMemAllocationTypeMax = 0x7FFFFFFF
} hipMemAllocationType;
@@ -4441,6 +4442,19 @@ hipError_t hipMemPoolExportPointer(hipMemPoolPtrExportData* export_data, void* d
*/
hipError_t hipMemPoolImportPointer(void** dev_ptr, hipMemPool_t mem_pool,
hipMemPoolPtrExportData* export_data);
/**
* @brief Sets memory pool for memory location and allocation type.
*
*
*/
hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type, hipMemPool_t pool);
/**
* @brief Retrieves memory pool for memory location and allocation type.
*
*
*/
hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type);
// Doxygen end of ordered memory allocator
/**
* @}
@@ -4098,6 +4098,18 @@ inline static hipError_t hipMemPoolImportPointer(void** ptr, hipMemPool_t mem_po
}
#endif // CUDA_VERSION >= CUDA_11020
#if CUDA_VERSION >= CUDA_13000
inline static hipError_t hipMemSetMemPool(hipMemLocation* location, hipMemAllocationType type,
hipMemPool_t pool) {
return hipCUDAErrorTohipError(cuMemSetMemPool(location, type, pool));
}
inline static hipError_t hipMemGetMemPool(hipMemPool_t* pool, hipMemLocation* location,
hipMemAllocationType type) {
return hipCUDAErrorTohipError(cuMemGetMemPool(pool, location, type));
}
#endif // CUDA_VERSION >= CUDA_13000
#ifdef __cplusplus
}
#endif
@@ -1009,6 +1009,10 @@ ROCPROFILER_ENUM_LABEL(ROCPROFILER_HIP_RUNTIME_API_ID_hipExtDisableLogging)
ROCPROFILER_ENUM_LABEL(ROCPROFILER_HIP_RUNTIME_API_ID_hipExtEnableLogging)
ROCPROFILER_ENUM_LABEL(ROCPROFILER_HIP_RUNTIME_API_ID_hipExtSetLoggingParams)
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
ROCPROFILER_ENUM_LABEL(ROCPROFILER_HIP_RUNTIME_API_ID_hipMemSetMemPool)
ROCPROFILER_ENUM_LABEL(ROCPROFILER_HIP_RUNTIME_API_ID_hipMemGetMemPool)
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION == 0
static_assert(ROCPROFILER_HIP_RUNTIME_API_ID_LAST == 442);
#elif HIP_RUNTIME_API_TABLE_STEP_VERSION == 1
@@ -1053,6 +1057,8 @@ static_assert(ROCPROFILER_HIP_RUNTIME_API_ID_LAST == 507);
static_assert(ROCPROFILER_HIP_RUNTIME_API_ID_LAST == 508);
#elif HIP_RUNTIME_API_TABLE_STEP_VERSION == 21
static_assert(ROCPROFILER_HIP_RUNTIME_API_ID_LAST == 511);
#elif HIP_RUNTIME_API_TABLE_STEP_VERSION == 22
static_assert(ROCPROFILER_HIP_RUNTIME_API_ID_LAST == 513);
#else
# if !defined(ROCPROFILER_UNSAFE_NO_VERSION_CHECK) && \
(defined(ROCPROFILER_CI) && ROCPROFILER_CI > 0)
@@ -3397,6 +3397,20 @@ typedef union rocprofiler_hip_api_args_t
size_t log_mask;
} hipExtSetLoggingParams;
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
struct
{
hipMemLocation* location;
hipMemAllocationType type;
hipMemPool_t pool;
} hipMemSetMemPool;
struct
{
hipMemPool_t* pool;
hipMemLocation* location;
hipMemAllocationType type;
} hipMemGetMemPool;
#endif
} rocprofiler_hip_api_args_t;
ROCPROFILER_EXTERN_C_FINI
@@ -580,6 +580,10 @@ typedef enum rocprofiler_hip_runtime_api_id_t // NOLINT(performance-enum-size)
ROCPROFILER_HIP_RUNTIME_API_ID_hipExtDisableLogging,
ROCPROFILER_HIP_RUNTIME_API_ID_hipExtEnableLogging,
ROCPROFILER_HIP_RUNTIME_API_ID_hipExtSetLoggingParams,
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
ROCPROFILER_HIP_RUNTIME_API_ID_hipMemSetMemPool,
ROCPROFILER_HIP_RUNTIME_API_ID_hipMemGetMemPool,
#endif
ROCPROFILER_HIP_RUNTIME_API_ID_LAST,
} rocprofiler_hip_runtime_api_id_t;
@@ -628,6 +628,10 @@ ROCP_SDK_ENFORCE_ABI(::HipDispatchTable, hipExtDisableLogging_fn, 508);
ROCP_SDK_ENFORCE_ABI(::HipDispatchTable, hipExtEnableLogging_fn, 509);
ROCP_SDK_ENFORCE_ABI(::HipDispatchTable, hipExtSetLoggingParams_fn, 510);
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
ROCP_SDK_ENFORCE_ABI(::HipDispatchTable, hipMemSetMemPool_fn, 511);
ROCP_SDK_ENFORCE_ABI(::HipDispatchTable, hipMemGetMemPool_fn, 512);
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION == 0
ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 442)
@@ -673,6 +677,8 @@ ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 507)
ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 508)
#elif HIP_RUNTIME_API_TABLE_STEP_VERSION == 21
ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 511)
#elif HIP_RUNTIME_API_TABLE_STEP_VERSION == 22
ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 513)
#else
INTERNAL_CI_ROCP_SDK_ENFORCE_ABI_VERSIONING(::HipDispatchTable, 0)
#endif
@@ -265,6 +265,9 @@ struct formatter<hipMemAllocationType> : rocprofiler::hip::details::base_formatt
ROCP_SDK_HIP_FORMAT_CASE_STMT(hipMemAllocationType, Max);
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 14
ROCP_SDK_HIP_FORMAT_CASE_STMT(hipMemAllocationType, Uncached);
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
ROCP_SDK_HIP_FORMAT_CASE_STMT(hipMemAllocationType, Managed);
#endif
ROCP_SDK_HIP_FORMAT_DFLT_CASE(hipMemAllocationType);
}
@@ -655,6 +655,11 @@ HIP_API_INFO_DEFINITION_0(ROCPROFILER_HIP_TABLE_ID_Runtime, ROCPROFILER_HIP_RUNT
HIP_API_INFO_DEFINITION_0(ROCPROFILER_HIP_TABLE_ID_Runtime, ROCPROFILER_HIP_RUNTIME_API_ID_hipExtEnableLogging, hipExtEnableLogging, hipExtEnableLogging_fn);
HIP_API_INFO_DEFINITION_V(ROCPROFILER_HIP_TABLE_ID_Runtime, ROCPROFILER_HIP_RUNTIME_API_ID_hipExtSetLoggingParams, hipExtSetLoggingParams, hipExtSetLoggingParams_fn, log_level, log_size, log_mask);
#endif
#if HIP_RUNTIME_API_TABLE_STEP_VERSION >= 22
HIP_API_INFO_DEFINITION_V(ROCPROFILER_HIP_TABLE_ID_Runtime, ROCPROFILER_HIP_RUNTIME_API_ID_hipMemSetMemPool, hipMemSetMemPool, hipMemSetMemPool_fn, location, type, pool);
HIP_API_INFO_DEFINITION_V(ROCPROFILER_HIP_TABLE_ID_Runtime, ROCPROFILER_HIP_RUNTIME_API_ID_hipMemGetMemPool, hipMemGetMemPool, hipMemGetMemPool_fn, pool, location, type);
#endif
// clang-format on
#else