SWDEV-546179 - hipModuleGetFunctionCount AMD impl (#782)

* SWDEV-546179 - hipModuleGetFunctionCount AMD impl

* SWDEV-546179 - return invalid for count ptr

* SWDEV-546179 - hipModuleGetFunctionCount CHANGELOG.md update

[ROCm/clr commit: dfb46db2fb]
This commit is contained in:
GunaShekar, Ajay
2025-08-13 20:28:12 -07:00
committed by GitHub
parent 5c412edcd1
commit 76328ecfde
12 changed files with 87 additions and 5 deletions
+7
View File
@@ -2,6 +2,13 @@
Full documentation for HIP is available at [rocm.docs.amd.com](https://rocm.docs.amd.com/projects/HIP/en/latest/index.html)
## HIP 7.1 for ROCm 7.1
### Added
* New HIP APIs
- `hipModuleGetFunctionCount` returns the number of functions within a module
## HIP 7.0 for ROCm 7.0
### Added
@@ -63,7 +63,7 @@
#define HIP_API_TABLE_STEP_VERSION 0
#define HIP_COMPILER_API_TABLE_STEP_VERSION 0
#define HIP_TOOLS_API_TABLE_STEP_VERSION 0
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 13
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 14
// HIP API interface
// HIP compiler dispatch functions
@@ -635,6 +635,7 @@ typedef hipError_t (*t_hipMipmappedArrayGetLevel)(hipArray_t* pLevelArray,
unsigned int level);
typedef hipError_t (*t_hipModuleGetFunction)(hipFunction_t* function, hipModule_t module,
const char* kname);
typedef hipError_t (*t_hipModuleGetFunctionCount)(unsigned int* count, hipModule_t module);
typedef hipError_t (*t_hipModuleGetGlobal)(hipDeviceptr_t* dptr, size_t* bytes, hipModule_t hmod,
const char* name);
typedef hipError_t (*t_hipModuleGetTexRef)(textureReference** texRef, hipModule_t hmod,
@@ -1588,10 +1589,13 @@ struct HipDispatchTable {
t_hipMemGetHandleForAddressRange hipMemGetHandleForAddressRange_fn;
// HIP_RUNTIME_API_TABLE_STEP_VERSION = 13
t_hipModuleGetFunctionCount hipModuleGetFunctionCount_fn;
// HIP_RUNTIME_API_TABLE_STEP_VERSION = 14
// removed HIP_MEMSET_NODE_PARAMS replaced by hipMemsetParams
// DO NOT EDIT ABOVE!
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 13
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 14
// ******************************************************************************************* //
//
@@ -438,7 +438,8 @@ enum hip_api_id_t {
HIP_API_ID_hipLinkDestroy = 418,
HIP_API_ID_hipLaunchKernelExC = 419,
HIP_API_ID_hipDrvLaunchKernelEx = 420,
HIP_API_ID_LAST = 420,
HIP_API_ID_hipModuleGetFunctionCount = 421,
HIP_API_ID_LAST = 421,
HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -885,6 +886,7 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipUserObjectRelease: return "hipUserObjectRelease";
case HIP_API_ID_hipUserObjectRetain: return "hipUserObjectRetain";
case HIP_API_ID_hipWaitExternalSemaphoresAsync: return "hipWaitExternalSemaphoresAsync";
case HIP_API_ID_hipModuleGetFunctionCount: return "hipModuleGetFunctionCount";
};
return "unknown";
};
@@ -1300,6 +1302,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipUserObjectRelease", name) == 0) return HIP_API_ID_hipUserObjectRelease;
if (strcmp("hipUserObjectRetain", name) == 0) return HIP_API_ID_hipUserObjectRetain;
if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipWaitExternalSemaphoresAsync;
if (strcmp("hipModuleGetFunctionCount", name) == 0) return HIP_API_ID_hipModuleGetFunctionCount;
return HIP_API_ID_NONE;
}
@@ -3289,6 +3292,11 @@ typedef struct hip_api_data_s {
const char* kname;
char kname__val;
} hipModuleGetFunction;
struct {
unsigned int* count;
unsigned int count__val;
hipModule_t mod;
} hipModuleGetFunctionCount;
struct {
hipDeviceptr_t* dptr;
hipDeviceptr_t dptr__val;
@@ -6236,6 +6244,12 @@ typedef struct hip_api_data_s {
cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = (unsigned int)numExtSems; \
cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \
};
// hipModuleGetFunctionCount[('unsigned int*', 'count'), ('hipModule_t', 'mod')]
#define INIT_hipModuleGetFunctionCount_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipModuleGetFunctionCount.count = (unsigned int*)count; \
cb_data.args.hipModuleGetFunctionCount.mod = (hipModule_t)mod; \
};
#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)
// Macros for non-public API primitives
@@ -7907,6 +7921,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
if (data->args.hipWaitExternalSemaphoresAsync.extSemArray) data->args.hipWaitExternalSemaphoresAsync.extSemArray__val = *(data->args.hipWaitExternalSemaphoresAsync.extSemArray);
if (data->args.hipWaitExternalSemaphoresAsync.paramsArray) data->args.hipWaitExternalSemaphoresAsync.paramsArray__val = *(data->args.hipWaitExternalSemaphoresAsync.paramsArray);
break;
// hipModuleGetFunctionCount[('unsigned int*', 'count'), ('hipModule_t', 'mod')]
case HIP_API_ID_hipModuleGetFunctionCount:
if (data->args.hipModuleGetFunctionCount.count) data->args.hipModuleGetFunctionCount.count__val = *(data->args.hipModuleGetFunctionCount.count);
break;
default: break;
};
}
@@ -11195,6 +11213,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipWaitExternalSemaphoresAsync.stream);
oss << ")";
break;
case HIP_API_ID_hipModuleGetFunctionCount:
oss << "hipModuleGetFunctionCount(";
if (data->args.hipModuleGetFunctionCount.count == NULL) oss << "count=NULL";
else { oss << "count="; roctracer::hip_support::detail::operator<<(oss, data->args.hipModuleGetFunctionCount.count__val); }
oss << ", mod="; roctracer::hip_support::detail::operator<<(oss, data->args.hipModuleGetFunctionCount.mod);
oss << ")";
break;
default: oss << "unknown";
};
return strdup(oss.str().c_str());
+1
View File
@@ -492,3 +492,4 @@ hipLinkCreate
hipLinkDestroy
hipLaunchKernelExC
hipDrvLaunchKernelEx
hipModuleGetFunctionCount
+6 -2
View File
@@ -519,6 +519,7 @@ hipError_t hipMipmappedArrayDestroy(hipMipmappedArray_t hMipmappedArray);
hipError_t hipMipmappedArrayGetLevel(hipArray_t* pLevelArray, hipMipmappedArray_t hMipMappedArray,
unsigned int level);
hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module, const char* kname);
hipError_t hipModuleGetFunctionCount(unsigned int* count, hipModule_t mod);
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes, hipModule_t hmod,
const char* name);
hipError_t hipModuleGetTexRef(textureReference** texRef, hipModule_t hmod, const char* name);
@@ -1174,6 +1175,7 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipMipmappedArrayDestroy_fn = hip::hipMipmappedArrayDestroy;
ptrDispatchTable->hipMipmappedArrayGetLevel_fn = hip::hipMipmappedArrayGetLevel;
ptrDispatchTable->hipModuleGetFunction_fn = hip::hipModuleGetFunction;
ptrDispatchTable->hipModuleGetFunctionCount_fn = hip::hipModuleGetFunctionCount;
ptrDispatchTable->hipModuleGetGlobal_fn = hip::hipModuleGetGlobal;
ptrDispatchTable->hipModuleGetTexRef_fn = hip::hipModuleGetTexRef;
ptrDispatchTable->hipModuleLaunchCooperativeKernel_fn = hip::hipModuleLaunchCooperativeKernel;
@@ -1989,15 +1991,17 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipLaunchKernelExC_fn, 474);
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvLaunchKernelEx_fn, 475);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 12
HIP_ENFORCE_ABI(HipDispatchTable, hipMemGetHandleForAddressRange_fn, 476);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 13
HIP_ENFORCE_ABI(HipDispatchTable, hipModuleGetFunctionCount_fn, 477);
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
//
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 477)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 478)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 13,
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 14,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
"pointers and then update this check so it is true");
#endif
@@ -124,6 +124,15 @@ hipError_t DynCO::getDynFunc(hipFunction_t* hfunc, std::string func_name) {
return it->second->getDynFunc(hfunc, module_);
}
hipError_t DynCO::getFuncCount(unsigned int* count) {
amd::ScopedLock lock(dclock_);
if (count == nullptr) {
return hipErrorInvalidValue;
}
*count = functions_.size();
return hipSuccess;
}
bool DynCO::isValidDynFunc(const void* hfunc) {
amd::ScopedLock lock(dclock_);
return std::any_of(functions_.begin(), functions_.end(),
@@ -111,6 +111,7 @@ public:
//Gets GlobalVar/Functions from a dynamically loaded code object
hipError_t getDynFunc(hipFunction_t* hfunc, std::string func_name);
hipError_t getFuncCount(unsigned int* count);
bool isValidDynFunc(const void* hfunc);
hipError_t getDeviceVar(DeviceVar** dvar, std::string var_name);
+7
View File
@@ -605,3 +605,10 @@ global:
local:
*;
} hip_6.4;
hip_7.1 {
global:
hipModuleGetFunctionCount;
local:
*;
} hip_6.5;
+9
View File
@@ -89,6 +89,15 @@ hipError_t hipModuleGetFunction(hipFunction_t* hfunc, hipModule_t hmod, const ch
HIP_RETURN(hipSuccess);
}
hipError_t hipModuleGetFunctionCount(unsigned int* count, hipModule_t mod) {
HIP_INIT_API(hipModuleGetFunctionCount, count, mod);
if (mod == nullptr) {
HIP_RETURN(hipErrorInvalidResourceHandle);
}
HIP_RETURN(PlatformState::instance().getFuncCount(count, mod););
}
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes, hipModule_t hmod,
const char* name) {
HIP_INIT_API(hipModuleGetGlobal, dptr, bytes, hmod, name);
+11
View File
@@ -834,6 +834,17 @@ hipError_t PlatformState::getDynFunc(hipFunction_t* hfunc, hipModule_t hmod,
return it->second->getDynFunc(hfunc, func_name);
}
hipError_t PlatformState::getFuncCount(unsigned int* count, hipModule_t hmod) {
amd::ScopedLock lock(lock_);
auto it = dynCO_map_.find(hmod);
if (it == dynCO_map_.end()) {
LogPrintfError("Cannot find the module: 0x%x", hmod);
return hipErrorNotFound;
}
return it->second->getFuncCount(count);
}
bool PlatformState::isValidDynFunc(const void* hfunc) {
amd::ScopedLock lock(lock_);
return std::any_of(dynCO_map_.begin(), dynCO_map_.end(),
+1
View File
@@ -63,6 +63,7 @@ class PlatformState {
hipError_t unloadModule(hipModule_t hmod);
bool isValidDynFunc(const void* hfunc);
hipError_t getDynFunc(hipFunction_t* hfunc, hipModule_t hmod, const char* func_name);
hipError_t getFuncCount(unsigned int* count, hipModule_t hmod);
hipError_t getDynGlobalVar(const char* hostVar, hipModule_t hmod, hipDeviceptr_t* dev_ptr,
size_t* size_ptr);
hipError_t getDynTexRef(const char* hostVar, hipModule_t hmod, textureReference** texRef);
@@ -1220,6 +1220,9 @@ hipError_t hipMipmappedArrayGetLevel(hipArray_t* pLevelArray, hipMipmappedArray_
hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module, const char* kname) {
return hip::GetHipDispatchTable()->hipModuleGetFunction_fn(function, module, kname);
}
hipError_t hipModuleGetFunctionCount(unsigned int* count, hipModule_t mod) {
return hip::GetHipDispatchTable()->hipModuleGetFunctionCount_fn(count, mod);
}
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes, hipModule_t hmod,
const char* name) {
return hip::GetHipDispatchTable()->hipModuleGetGlobal_fn(dptr, bytes, hmod, name);