SWDEV-546311 - implement hipKernelGetLibrary & hipLibraryEnumerateKer… (#1143)

* SWDEV-546311 - implement hipKernelGetLibrary & hipLibraryEnumerateKernels API

* Fix for LibraryEnumerateKernel and KernelGetName

* Update Enumerate Kernels to handle 0 numKernels

* Minor fixes to function names

* fix error checking in internal function

* Update changelog for new apis

---------

Co-authored-by: Rahul Manocha <rmanocha@amd.com>
这个提交包含在:
Rahul Manocha
2025-10-27 14:13:17 -07:00
提交者 GitHub
父节点 3e59eebf17
当前提交 f5d901f016
修改 13 个文件,包含 340 行新增26 行删除
+3
查看文件
@@ -7,6 +7,9 @@ Full documentation for HIP is available at [rocm.docs.amd.com](https://rocm.docs
### Added
* New HIP APIs
- `hipLibraryEnumerateKernels` Return Kernel handles within a library
- `hipKernelGetLibrary` Return Library handle for a hipKernel_t handle
- `hipKernelGetName` Return function name for a hipKernel_t handle
- `hipLibraryLoadData` creates library object from code
- `hipLibraryLoadFromFile` creates library object from file
- `hipLibraryUnload` unloads library
@@ -63,7 +63,7 @@
#define HIP_API_TABLE_STEP_VERSION 0
#define HIP_COMPILER_API_TABLE_STEP_VERSION 0
#define HIP_TOOLS_API_TABLE_STEP_VERSION 0
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 16
#define HIP_RUNTIME_API_TABLE_STEP_VERSION 17
// HIP API interface
// HIP compiler dispatch functions
@@ -1105,6 +1105,10 @@ typedef hipError_t (*t_hipLibraryGetKernel)(hipKernel_t* pKernel, hipLibrary_t l
const char* name);
typedef hipError_t (*t_hipLibraryGetKernelCount)(unsigned int *count,
hipLibrary_t library);
// Function-pointer signatures for the hipLibraryEnumerateKernels,
// hipKernelGetLibrary and hipKernelGetName entries of the runtime dispatch table.
typedef hipError_t (*t_hipLibraryEnumerateKernels)(hipKernel_t* kernels, unsigned int numKernels,
hipLibrary_t library);
typedef hipError_t (*t_hipKernelGetLibrary)(hipLibrary_t* library, hipKernel_t kernel);
typedef hipError_t (*t_hipKernelGetName)(const char** name, hipKernel_t kernel);
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
@@ -1683,8 +1687,13 @@ struct HipDispatchTable {
// HIP_RUNTIME_API_TABLE_STEP_VERSION = 16
t_hipStreamCopyAttributes hipStreamCopyAttributes_fn;
// HIP_RUNTIME_API_TABLE_STEP_VERSION = 17
t_hipLibraryEnumerateKernels hipLibraryEnumerateKernels_fn;
t_hipKernelGetLibrary hipKernelGetLibrary_fn;
t_hipKernelGetName hipKernelGetName_fn;
// DO NOT EDIT ABOVE!
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 17
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 18
// ******************************************************************************************* //
//
@@ -463,7 +463,10 @@ enum hip_api_id_t {
HIP_API_ID_hipLibraryGetKernelCount = 443,
HIP_API_ID_hipMemGetHandleForAddressRange = 444,
HIP_API_ID_hipStreamCopyAttributes = 445,
HIP_API_ID_LAST = 445,
HIP_API_ID_hipKernelGetLibrary = 446,
HIP_API_ID_hipLibraryEnumerateKernels = 447,
HIP_API_ID_hipKernelGetName = 448,
HIP_API_ID_LAST = 448,
HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -727,12 +730,15 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipIpcGetMemHandle: return "hipIpcGetMemHandle";
case HIP_API_ID_hipIpcOpenEventHandle: return "hipIpcOpenEventHandle";
case HIP_API_ID_hipIpcOpenMemHandle: return "hipIpcOpenMemHandle";
case HIP_API_ID_hipKernelGetLibrary: return "hipKernelGetLibrary";
case HIP_API_ID_hipKernelGetName: return "hipKernelGetName";
case HIP_API_ID_hipLaunchByPtr: return "hipLaunchByPtr";
case HIP_API_ID_hipLaunchCooperativeKernel: return "hipLaunchCooperativeKernel";
case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: return "hipLaunchCooperativeKernelMultiDevice";
case HIP_API_ID_hipLaunchHostFunc: return "hipLaunchHostFunc";
case HIP_API_ID_hipLaunchKernel: return "hipLaunchKernel";
case HIP_API_ID_hipLaunchKernelExC: return "hipLaunchKernelExC";
case HIP_API_ID_hipLibraryEnumerateKernels: return "hipLibraryEnumerateKernels";
case HIP_API_ID_hipLibraryGetKernel: return "hipLibraryGetKernel";
case HIP_API_ID_hipLibraryGetKernelCount: return "hipLibraryGetKernelCount";
case HIP_API_ID_hipLibraryLoadData: return "hipLibraryLoadData";
@@ -1166,12 +1172,15 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipIpcGetMemHandle", name) == 0) return HIP_API_ID_hipIpcGetMemHandle;
if (strcmp("hipIpcOpenEventHandle", name) == 0) return HIP_API_ID_hipIpcOpenEventHandle;
if (strcmp("hipIpcOpenMemHandle", name) == 0) return HIP_API_ID_hipIpcOpenMemHandle;
if (strcmp("hipKernelGetLibrary", name) == 0) return HIP_API_ID_hipKernelGetLibrary;
if (strcmp("hipKernelGetName", name) == 0) return HIP_API_ID_hipKernelGetName;
if (strcmp("hipLaunchByPtr", name) == 0) return HIP_API_ID_hipLaunchByPtr;
if (strcmp("hipLaunchCooperativeKernel", name) == 0) return HIP_API_ID_hipLaunchCooperativeKernel;
if (strcmp("hipLaunchCooperativeKernelMultiDevice", name) == 0) return HIP_API_ID_hipLaunchCooperativeKernelMultiDevice;
if (strcmp("hipLaunchHostFunc", name) == 0) return HIP_API_ID_hipLaunchHostFunc;
if (strcmp("hipLaunchKernel", name) == 0) return HIP_API_ID_hipLaunchKernel;
if (strcmp("hipLaunchKernelExC", name) == 0) return HIP_API_ID_hipLaunchKernelExC;
if (strcmp("hipLibraryEnumerateKernels", name) == 0) return HIP_API_ID_hipLibraryEnumerateKernels;
if (strcmp("hipLibraryGetKernel", name) == 0) return HIP_API_ID_hipLibraryGetKernel;
if (strcmp("hipLibraryGetKernelCount", name) == 0) return HIP_API_ID_hipLibraryGetKernelCount;
if (strcmp("hipLibraryLoadData", name) == 0) return HIP_API_ID_hipLibraryLoadData;
@@ -2672,6 +2681,16 @@ typedef struct hip_api_data_s {
hipIpcMemHandle_t handle;
unsigned int flags;
} hipIpcOpenMemHandle;
struct {
hipLibrary_t* library;
hipLibrary_t library__val;
hipKernel_t kernel;
} hipKernelGetLibrary;
struct {
const char** name;
const char* name__val;
hipKernel_t kernel;
} hipKernelGetName;
struct {
const void* hostFunction;
} hipLaunchByPtr;
@@ -2711,6 +2730,12 @@ typedef struct hip_api_data_s {
void** args;
void* args__val;
} hipLaunchKernelExC;
struct {
hipKernel_t* kernels;
hipKernel_t kernels__val;
unsigned int numKernels;
hipLibrary_t library;
} hipLibraryEnumerateKernels;
struct {
hipKernel_t* pKernel;
hipKernel_t pKernel__val;
@@ -5307,6 +5332,16 @@ typedef struct hip_api_data_s {
cb_data.args.hipIpcOpenMemHandle.handle = (hipIpcMemHandle_t)handle; \
cb_data.args.hipIpcOpenMemHandle.flags = (unsigned int)flags; \
};
// hipKernelGetLibrary[('hipLibrary_t*', 'library'), ('hipKernel_t', 'kernel')]
// Copies the `library` and `kernel` identifiers visible at the expansion site
// into cb_data.args.hipKernelGetLibrary.
#define INIT_hipKernelGetLibrary_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipKernelGetLibrary.library = (hipLibrary_t*)library; \
cb_data.args.hipKernelGetLibrary.kernel = (hipKernel_t)kernel; \
};
// hipKernelGetName[('const char**', 'name'), ('hipKernel_t', 'kernel')]
// Copies the `name` and `kernel` identifiers visible at the expansion site
// into cb_data.args.hipKernelGetName.
#define INIT_hipKernelGetName_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipKernelGetName.name = (const char**)name; \
cb_data.args.hipKernelGetName.kernel = (hipKernel_t)kernel; \
};
// hipLaunchByPtr[('const void*', 'hostFunction')]
#define INIT_hipLaunchByPtr_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipLaunchByPtr.hostFunction = (const void*)hostFunction; \
@@ -5347,6 +5382,12 @@ typedef struct hip_api_data_s {
cb_data.args.hipLaunchKernelExC.fPtr = (const void*)fPtr; \
cb_data.args.hipLaunchKernelExC.args = (void**)args; \
};
// hipLibraryEnumerateKernels[('hipKernel_t*', 'kernels'), ('unsigned int', 'numKernels'), ('hipLibrary_t', 'library')]
// Copies the `kernels`, `numKernels` and `library` identifiers visible at the
// expansion site into cb_data.args.hipLibraryEnumerateKernels.
#define INIT_hipLibraryEnumerateKernels_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipLibraryEnumerateKernels.kernels = (hipKernel_t*)kernels; \
cb_data.args.hipLibraryEnumerateKernels.numKernels = (unsigned int)numKernels; \
cb_data.args.hipLibraryEnumerateKernels.library = (hipLibrary_t)library; \
};
// hipLibraryGetKernel[('hipKernel_t*', 'pKernel'), ('hipLibrary_t', 'library'), ('const char*', 'name')]
#define INIT_hipLibraryGetKernel_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipLibraryGetKernel.pKernel = (hipKernel_t*)kernel; \
@@ -7632,6 +7673,14 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
case HIP_API_ID_hipIpcOpenMemHandle:
if (data->args.hipIpcOpenMemHandle.devPtr) data->args.hipIpcOpenMemHandle.devPtr__val = *(data->args.hipIpcOpenMemHandle.devPtr);
break;
// hipKernelGetLibrary[('hipLibrary_t*', 'library'), ('hipKernel_t', 'kernel')]
case HIP_API_ID_hipKernelGetLibrary:
if (data->args.hipKernelGetLibrary.library) data->args.hipKernelGetLibrary.library__val = *(data->args.hipKernelGetLibrary.library);
break;
// hipKernelGetName[('const char**', 'name'), ('hipKernel_t', 'kernel')]
case HIP_API_ID_hipKernelGetName:
if (data->args.hipKernelGetName.name) data->args.hipKernelGetName.name__val = *(data->args.hipKernelGetName.name);
break;
// hipLaunchByPtr[('const void*', 'hostFunction')]
case HIP_API_ID_hipLaunchByPtr:
break;
@@ -7655,6 +7704,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
if (data->args.hipLaunchKernelExC.config) data->args.hipLaunchKernelExC.config__val = *(data->args.hipLaunchKernelExC.config);
if (data->args.hipLaunchKernelExC.args) data->args.hipLaunchKernelExC.args__val = *(data->args.hipLaunchKernelExC.args);
break;
// hipLibraryEnumerateKernels[('hipKernel_t*', 'kernels'), ('unsigned int', 'numKernels'), ('hipLibrary_t', 'library')]
case HIP_API_ID_hipLibraryEnumerateKernels:
if (data->args.hipLibraryEnumerateKernels.kernels) data->args.hipLibraryEnumerateKernels.kernels__val = *(data->args.hipLibraryEnumerateKernels.kernels);
break;
// hipLibraryGetKernel[('hipKernel_t*', 'pKernel'), ('hipLibrary_t', 'library'), ('const char*', 'name')]
case HIP_API_ID_hipLibraryGetKernel:
if (data->args.hipLibraryGetKernel.pKernel) data->args.hipLibraryGetKernel.pKernel__val = *(data->args.hipLibraryGetKernel.pKernel);
@@ -10201,6 +10254,20 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
oss << ", flags="; roctracer::hip_support::detail::operator<<(oss, data->args.hipIpcOpenMemHandle.flags);
oss << ")";
break;
case HIP_API_ID_hipKernelGetLibrary:
oss << "hipKernelGetLibrary(";
if (data->args.hipKernelGetLibrary.library == NULL) oss << "library=NULL";
else { oss << "library="; roctracer::hip_support::detail::operator<<(oss, data->args.hipKernelGetLibrary.library__val); }
oss << ", kernel="; roctracer::hip_support::detail::operator<<(oss, data->args.hipKernelGetLibrary.kernel);
oss << ")";
break;
case HIP_API_ID_hipKernelGetName:
oss << "hipKernelGetName(";
if (data->args.hipKernelGetName.name == NULL) oss << "name=NULL";
else { oss << "name="; roctracer::hip_support::detail::operator<<(oss, (void*)data->args.hipKernelGetName.name__val); }
oss << ", kernel="; roctracer::hip_support::detail::operator<<(oss, data->args.hipKernelGetName.kernel);
oss << ")";
break;
case HIP_API_ID_hipLaunchByPtr:
oss << "hipLaunchByPtr(";
oss << "hostFunction="; roctracer::hip_support::detail::operator<<(oss, data->args.hipLaunchByPtr.hostFunction);
@@ -10252,6 +10319,14 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
else { oss << ", args="; roctracer::hip_support::detail::operator<<(oss, data->args.hipLaunchKernelExC.args__val); }
oss << ")";
break;
case HIP_API_ID_hipLibraryEnumerateKernels:
oss << "hipLibraryEnumerateKernels(";
if (data->args.hipLibraryEnumerateKernels.kernels == NULL) oss << "kernels=NULL";
else { oss << "kernels="; roctracer::hip_support::detail::operator<<(oss, data->args.hipLibraryEnumerateKernels.kernels__val); }
oss << ", numKernels="; roctracer::hip_support::detail::operator<<(oss, data->args.hipLibraryEnumerateKernels.numKernels);
oss << ", library="; roctracer::hip_support::detail::operator<<(oss, data->args.hipLibraryEnumerateKernels.library);
oss << ")";
break;
case HIP_API_ID_hipLibraryGetKernel:
oss << "hipLibraryGetKernel(";
if (data->args.hipLibraryGetKernel.pKernel == NULL) oss << "pKernel=NULL";
+3
查看文件
@@ -517,3 +517,6 @@ hipLibraryUnload
hipLibraryGetKernel
hipLibraryGetKernelCount
hipStreamCopyAttributes
hipLibraryEnumerateKernels
hipKernelGetLibrary
hipKernelGetName
+13 -2
查看文件
@@ -875,6 +875,10 @@ hipError_t hipLibraryLoadFromFile(hipLibrary_t* library, const char* fileName,
hipError_t hipLibraryUnload(hipLibrary_t library);
hipError_t hipLibraryGetKernel(hipKernel_t* pKernel, hipLibrary_t library, const char* name);
hipError_t hipLibraryGetKernelCount(unsigned int* count, hipLibrary_t library);
hipError_t hipLibraryEnumerateKernels(hipKernel_t* kernels, unsigned int numKernels,
hipLibrary_t library);
hipError_t hipKernelGetLibrary(hipLibrary_t* library, hipKernel_t kernel);
hipError_t hipKernelGetName(const char** name, hipKernel_t kernel);
} // namespace hip
namespace hip {
@@ -1416,6 +1420,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipLibraryUnload_fn = hip::hipLibraryUnload;
ptrDispatchTable->hipLibraryGetKernel_fn = hip::hipLibraryGetKernel;
ptrDispatchTable->hipLibraryGetKernelCount_fn = hip::hipLibraryGetKernelCount;
ptrDispatchTable->hipLibraryEnumerateKernels_fn = hip::hipLibraryEnumerateKernels;
ptrDispatchTable->hipKernelGetLibrary_fn = hip::hipKernelGetLibrary;
ptrDispatchTable->hipKernelGetName_fn = hip::hipKernelGetName;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -2088,15 +2095,19 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipLibraryGetKernel_fn, 499);
HIP_ENFORCE_ABI(HipDispatchTable, hipLibraryGetKernelCount_fn, 500);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 16
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamCopyAttributes_fn, 501);
// HIP_RUNTIME_API_TABLE_STEP_VERSION == 17
HIP_ENFORCE_ABI(HipDispatchTable, hipLibraryEnumerateKernels_fn, 502);
HIP_ENFORCE_ABI(HipDispatchTable, hipKernelGetLibrary_fn, 503);
HIP_ENFORCE_ABI(HipDispatchTable, hipKernelGetName_fn, 504);
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
//
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 502)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 505)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 16,
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 17,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
"pointers and then update this check so it is true");
#endif
+3
查看文件
@@ -639,6 +639,9 @@ global:
hipLibraryGetKernel;
hipLibraryGetKernelCount;
hipStreamCopyAttributes;
hipLibraryEnumerateKernels;
hipKernelGetLibrary;
hipKernelGetName;
local:
*;
} hip_7.1;
+97 -4
查看文件
@@ -36,12 +36,50 @@ void LibraryContainer::Register(std::string name, int device, hipKernel_t k) {
auto key = std::make_pair(name, device);
if (kernels_.find(key) == kernels_.end()) {
kernels_.insert(std::make_pair(std::make_pair(name, device), k));
if (!hip::PlatformState::instance().RegisterLibraryFunction(k)) {
auto lib = reinterpret_cast<hipLibrary_t>(this);
if (!hip::PlatformState::instance().RegisterLibraryFunction(k, lib)) {
LogPrintfInfo("Already registered: %p", k);
}
}
}
// Finds the entry of kernels_ whose handle equals `kernel` and stores a pointer
// to that entry's function name in *name. The pointer refers to the string held
// in the map key. Returns hipErrorInvalidValue when no entry matches (including
// when kernels_ is empty), hipSuccess otherwise.
hipError_t LibraryContainer::GetKernelName(const char** name, hipKernel_t kernel) {
  hipError_t status = hipErrorInvalidValue;
  for (auto entry = kernels_.begin(); entry != kernels_.end(); ++entry) {
    if (entry->second != kernel) {
      continue;
    }
    // The key is (function name, device id); only the name part is reported.
    *name = entry->first.first.c_str();
    status = hipSuccess;
    break;
  }
  return status;
}
// Writes kernel handles of this library for the current device into `k`.
//
// At most min(maxKernels, functions_.size()) handles are written, in the
// iteration order of functions_. A function already present in kernels_ for the
// current device reuses its handle; otherwise the handle is created through
// getDynFunc() and recorded with Register().
//
// Returns hipSuccess, or the first error reported by getDynFunc(); in the error
// case the entries written so far remain in `k`.
hipError_t LibraryContainer::EnumerateKernels(hipKernel_t* k, unsigned int maxKernels) {
  const auto maxCount = (maxKernels > functions_.size()) ? functions_.size() : maxKernels;
  auto device_id = hip::ihipGetDevice();
  auto m = fatbin_->Module(device_id);
  // Same unsigned type as the bound so the loop comparison is not mixed-sign.
  decltype(functions_.size()) count = 0;
  for (const auto& f : functions_) {
    if (count >= maxCount) {
      break;
    }
    // Start from a known value; getDynFunc() only has to fill it on success.
    hipKernel_t kern = nullptr;
    // Build a handle only for kernels that are not registered yet.
    if (auto ki = kernels_.find(std::make_pair(f.first, device_id)); ki != kernels_.end()) {
      kern = ki->second;
    } else {
      auto ret = f.second.get()->getDynFunc(reinterpret_cast<hipFunction_t*>(&kern), m);
      if (ret != hipSuccess) {
        return ret;
      }
      Register(f.first, device_id, kern);
    }
    k[count++] = kern;
  }
  return hipSuccess;
}
hipError_t LibraryContainer::Kernel(hipKernel_t* k, std::string name) {
auto device_id = hip::ihipGetDevice();
if (auto ki = kernels_.find(std::make_pair(name, device_id)); ki != kernels_.end()) {
@@ -54,7 +92,9 @@ hipError_t LibraryContainer::Kernel(hipKernel_t* k, std::string name) {
return hipErrorNotFound;
}
auto ret = f->second.get()->getDynFunc(reinterpret_cast<hipFunction_t*>(k), m);
if (ret != hipSuccess) {
return ret;
}
// Register it, basically make it available for query through the hip context.
Register(name, device_id, *k);
return hipSuccess;
@@ -93,9 +133,11 @@ hipError_t LibraryContainer::BuildIt() {
IHIP_RETURN_ONFAIL(fatbin_->BuildProgram(device_id));
auto program =
fatbin_->GetProgram(device_id)->getDeviceProgram(*hip::getCurrentDevice()->devices()[0]);
fatbin_->GetProgram(device_id)->getDeviceProgram(*hip::getCurrentDevice()->devices()[0]);
auto mod =
fatbin_->Module(device_id);
// Process Functions
// Process Functions and create kernel handles
std::vector<std::string> function_names;
program->getGlobalFuncFromCodeObj(&function_names);
for (auto& name : function_names) {
@@ -177,4 +219,55 @@ hipError_t hipLibraryGetKernel(hipKernel_t* kernel, hipLibrary_t library, const
ret = l->Kernel(kernel, kname);
HIP_RETURN(ret);
}
// Runtime entry point: fills `kernels` with up to `numKernels` handles from `library`.
// Rejects a null buffer or library with hipErrorInvalidValue, propagates any
// BuildIt() failure, and returns hipSuccess without enumerating when numKernels is 0.
hipError_t hipLibraryEnumerateKernels(hipKernel_t* kernels, unsigned int numKernels,
                                      hipLibrary_t library) {
  HIP_INIT_API(hipLibraryEnumerateKernels, kernels, numKernels, library);

  if (!(kernels != nullptr && library != nullptr)) {
    HIP_RETURN(hipErrorInvalidValue);
  }

  auto container = reinterpret_cast<hip::LibraryContainer*>(library);

  // The library is built before the zero-count shortcut, so build errors are
  // still reported for numKernels == 0.
  auto buildStatus = container->BuildIt();
  if (buildStatus != hipSuccess) {
    HIP_RETURN(buildStatus);
  }

  if (numKernels == 0) {
    HIP_RETURN(hipSuccess);
  }

  HIP_RETURN(container->EnumerateKernels(kernels, numKernels));
}
// Runtime entry point: reports the library that `kernel` was registered with.
// Returns hipErrorInvalidValue for a null argument and hipErrorInvalidHandle
// when the platform state has no library recorded for the kernel.
hipError_t hipKernelGetLibrary(hipLibrary_t* library, hipKernel_t kernel) {
  HIP_INIT_API(hipKernelGetLibrary, library, kernel);

  if (!(library != nullptr && kernel != nullptr)) {
    HIP_RETURN(hipErrorInvalidValue);
  }

  const bool found = hip::PlatformState::instance().GetFunctionLibrary(kernel, library);
  if (found) {
    HIP_RETURN(hipSuccess);
  }

  HIP_RETURN(hipErrorInvalidHandle);
}
// Runtime entry point: reports the function name of `kernel`.
// Returns hipErrorInvalidValue for a null argument, hipErrorInvalidHandle when
// the platform state has no library recorded for the kernel, and otherwise the
// result of LibraryContainer::GetKernelName() on the owning library.
hipError_t hipKernelGetName(const char** name, hipKernel_t kernel) {
  HIP_INIT_API(hipKernelGetName, name, kernel);

  if (!(name != nullptr && kernel != nullptr)) {
    HIP_RETURN(hipErrorInvalidValue);
  }

  hipLibrary_t owner;
  const bool registered = hip::PlatformState::instance().GetFunctionLibrary(kernel, &owner);
  if (!registered) {
    HIP_RETURN(hipErrorInvalidHandle);
  }

  auto container = reinterpret_cast<hip::LibraryContainer*>(owner);
  auto status = container->GetKernelName(name, kernel);
  HIP_RETURN(status);
}
} // namespace hip
+4
查看文件
@@ -57,6 +57,10 @@ class LibraryContainer {
// Register the kernel function, make an entry in global state
void Register(std::string name, int device, hipKernel_t k);
// Enumerate at most maxKernels kernel handles in this library, writing them to k
hipError_t EnumerateKernels(hipKernel_t* k, unsigned int maxKernels);
// Store in *name the function name registered for kernel in this library;
// returns hipErrorInvalidValue when the kernel is not found
hipError_t GetKernelName(const char** name, hipKernel_t kernel);
private:
LibraryContainer() = delete;
LibraryContainer(const LibraryContainer&) = delete;
+12 -3
查看文件
@@ -113,10 +113,10 @@ class PlatformState {
size_t UfdMapSize() const { return ufd_map_.size(); }
inline bool RegisterLibraryFunction(const hipKernel_t f) {
inline bool RegisterLibraryFunction(const hipKernel_t f, const hipLibrary_t l) {
amd::ScopedLock lock(lock_);
if (library_functions_.find(f) == library_functions_.end()) {
library_functions_.insert(f);
library_functions_.insert(std::make_pair(f, l));
return true;
}
return false;
@@ -130,6 +130,15 @@ class PlatformState {
return false;
}
// Looks up the library recorded for kernel handle `f` while holding lock_.
// On a hit, stores the library in *lib and returns true; on a miss, leaves
// *lib untouched and returns false.
inline bool GetFunctionLibrary(const hipKernel_t f, hipLibrary_t* lib) {
  amd::ScopedLock lock(lock_);
  // Single lookup: reuse the iterator instead of a second search via operator[].
  auto entry = library_functions_.find(f);
  if (entry == library_functions_.end()) {
    return false;
  }
  *lib = entry->second;
  return true;
}
private:
// Dynamic Code Object map, keyin module to get the corresponding object
std::unordered_map<hipModule_t, hip::DynCO*> dynCO_map_;
@@ -140,6 +149,6 @@ class PlatformState {
std::unordered_map<std::string, std::shared_ptr<UniqueFD>> ufd_map_; //!< Unique File Desc Map
void* dynamicLibraryHandle_{nullptr};
std::unordered_set<hipKernel_t> library_functions_;
std::unordered_map<hipKernel_t, hipLibrary_t> library_functions_;
};
} // namespace hip
@@ -2040,4 +2040,14 @@ hipError_t hipLibraryGetKernel(hipKernel_t* pKernel, hipLibrary_t library, const
// Public entry point: forwards to the hipLibraryGetKernelCount slot of the dispatch table.
hipError_t hipLibraryGetKernelCount(unsigned int *count, hipLibrary_t library) {
  const auto* dispatch = hip::GetHipDispatchTable();
  return dispatch->hipLibraryGetKernelCount_fn(count, library);
}
// Public entry point: forwards to the hipLibraryEnumerateKernels slot of the dispatch table.
hipError_t hipLibraryEnumerateKernels(hipKernel_t* kernels, unsigned int numKernels,
                                      hipLibrary_t library) {
  const auto* dispatch = hip::GetHipDispatchTable();
  return dispatch->hipLibraryEnumerateKernels_fn(kernels, numKernels, library);
}
// Public entry point: forwards to the hipKernelGetLibrary slot of the dispatch table.
hipError_t hipKernelGetLibrary(hipLibrary_t* library, hipKernel_t kernel) {
  const auto* dispatch = hip::GetHipDispatchTable();
  return dispatch->hipKernelGetLibrary_fn(library, kernel);
}
// Public entry point: forwards to the hipKernelGetName slot of the dispatch table.
hipError_t hipKernelGetName(const char** name, hipKernel_t kernel) {
  const auto* dispatch = hip::GetHipDispatchTable();
  return dispatch->hipKernelGetName_fn(name, kernel);
}
@@ -25,6 +25,7 @@ THE SOFTWARE.
TEST_CASE("Unit_hip_library_load_co") {
constexpr size_t size = 32;
constexpr size_t num_kernels = 3;
std::vector<float> input1, input2;
input1.reserve(size);
input2.reserve(size);
@@ -45,6 +46,19 @@ TEST_CASE("Unit_hip_library_load_co") {
hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));
std::string lib_co = "library_code_load.code";
auto host_verify = [](const std::vector<float>& x, const std::vector<float>&y, const std::vector<float>& expected, int op) {
for (size_t i = 0; i < size; i++) {
float tmp;
switch(op) {
case 0: tmp = x[i] + y[i]; break;
case 1: tmp = x[i] - y[i]; break;
case 2: tmp = x[i] * y[i]; break;
default: tmp = 0;
}
REQUIRE(tmp == expected[i]);
}
};
SECTION("One Kernel") {
hipLibrary_t library;
@@ -54,6 +68,10 @@ TEST_CASE("Unit_hip_library_load_co") {
hipLibraryLoadFromFile(&library, lib_co.data(), nullptr, nullptr, 0, nullptr, nullptr, 0));
HIP_CHECK(hipLibraryGetKernel(&function, library, "add_kernel"));
hipLibrary_t new_library;
HIP_CHECK(hipKernelGetLibrary(&new_library, function));
REQUIRE(new_library == library);
unsigned int count = 0;
HIP_CHECK(hipLibraryGetKernelCount(&count, library));
REQUIRE(count == 3);
@@ -67,11 +85,7 @@ TEST_CASE("Unit_hip_library_load_co") {
std::vector<float> out(size, 0);
HIP_CHECK(hipMemcpy(out.data(), d_out, sizeof(float) * size, hipMemcpyDeviceToHost));
for (size_t i = 0; i < size; i++) {
float tmp = input1[i] + input2[i];
INFO("Index: " << i << " cpu res: " << tmp << " gpu res: " << out[i]);
REQUIRE(out[i] == tmp);
}
host_verify(input1, input2, out, 0);
}
SECTION("Two Kernel") {
@@ -81,6 +95,10 @@ TEST_CASE("Unit_hip_library_load_co") {
HIP_CHECK(
hipLibraryLoadFromFile(&library, lib_co.data(), nullptr, nullptr, 0, nullptr, nullptr, 0));
HIP_CHECK(hipLibraryGetKernel(&function, library, "sub_kernel"));
hipLibrary_t new_library;
HIP_CHECK(hipKernelGetLibrary(&new_library, function));
REQUIRE(new_library == library);
unsigned int count = 0;
HIP_CHECK(hipLibraryGetKernelCount(&count, library));
@@ -95,11 +113,7 @@ TEST_CASE("Unit_hip_library_load_co") {
std::vector<float> out(size, 0);
HIP_CHECK(hipMemcpy(out.data(), d_out, sizeof(float) * size, hipMemcpyDeviceToHost));
for (size_t i = 0; i < size; i++) {
float tmp = input1[i] - input2[i];
INFO("Index: " << i << " cpu res: " << tmp << " gpu res: " << out[i]);
REQUIRE(out[i] == tmp);
}
host_verify(input1, input2, out, 1);
}
SECTION("Three Kernel") {
@@ -109,6 +123,10 @@ TEST_CASE("Unit_hip_library_load_co") {
HIP_CHECK(
hipLibraryLoadFromFile(&library, lib_co.data(), nullptr, nullptr, 0, nullptr, nullptr, 0));
HIP_CHECK(hipLibraryGetKernel(&function, library, "mul_kernel"));
hipLibrary_t new_library;
HIP_CHECK(hipKernelGetLibrary(&new_library, function));
REQUIRE(new_library == library);
unsigned int count = 0;
HIP_CHECK(hipLibraryGetKernelCount(&count, library));
@@ -123,11 +141,45 @@ TEST_CASE("Unit_hip_library_load_co") {
std::vector<float> out(size, 0);
HIP_CHECK(hipMemcpy(out.data(), d_out, sizeof(float) * size, hipMemcpyDeviceToHost));
for (size_t i = 0; i < size; i++) {
float tmp = input1[i] * input2[i];
INFO("Index: " << i << " cpu res: " << tmp << " gpu res: " << out[i]);
REQUIRE(out[i] == tmp);
host_verify(input1, input2, out, 2);
}
// Enumerates every kernel of the library, then for each handle resolves its
// name and owning library, launches it, and checks the result on the host.
SECTION("All Kernels") {
  hipLibrary_t library;
  hipKernel_t functions[num_kernels] = {};

  HIP_CHECK(
      hipLibraryLoadFromFile(&library, lib_co.data(), nullptr, nullptr, 0, nullptr, nullptr, 0));
  HIP_CHECK(hipLibraryEnumerateKernels(functions, num_kernels, library));

  void* args[] = {&d_out, &d_in1, &d_in2};

  // Maps a kernel name to the op code understood by host_verify; -1 if unknown.
  auto kernel_idx = [](const char* kName) {
    std::string ss = kName;
    if (ss == "add_kernel") {
      return 0;
    } else if (ss == "sub_kernel") {
      return 1;
    } else if (ss == "mul_kernel") {
      return 2;
    }
    return -1;
  };

  std::vector<float> out(size, 0);
  for (size_t k = 0; k < num_kernels; k++) {
    const char* kName = nullptr;
    HIP_CHECK(hipKernelGetName(&kName, functions[k]));
    // Guard the std::string construction inside kernel_idx.
    REQUIRE(kName != nullptr);
    INFO("Kernel: " << kName);

    // An unknown name must fail here rather than silently verifying against 0.
    const int op = kernel_idx(kName);
    REQUIRE(op >= 0);

    // Every enumerated handle must map back to the library it came from.
    hipLibrary_t owner = nullptr;
    HIP_CHECK(hipKernelGetLibrary(&owner, functions[k]));
    REQUIRE(owner == library);

    HIP_CHECK(hipLaunchKernel(functions[k], 1, size, args, 0, stream));
    HIP_CHECK(hipStreamSynchronize(stream));

    HIP_CHECK(hipMemcpy(out.data(), d_out, sizeof(float) * size, hipMemcpyDeviceToHost));

    host_verify(input1, input2, out, op);
  }
  HIP_CHECK(hipLibraryUnload(library));
}
HIP_CHECK(hipStreamDestroy(stream));
@@ -6424,6 +6424,35 @@ hipError_t hipLibraryGetKernel(hipKernel_t* pKernel, hipLibrary_t library, const
*/
hipError_t hipLibraryGetKernelCount(unsigned int *count, hipLibrary_t library);
/**
 * @brief Retrieve kernel handles within a library
 *
 * Writes up to @p numKernels handles into @p kernels; when the library holds
 * fewer kernels than @p numKernels, only that many entries are written.
 *
 * @param [out] kernels Buffer for kernel handles
 * @param [in] numKernels Maximum number of kernel handles to return to buffer
 * @param [in] library Library handle to query from
 * @return #hipSuccess, #hipErrorInvalidValue
 */
hipError_t hipLibraryEnumerateKernels(hipKernel_t* kernels, unsigned int numKernels,
hipLibrary_t library);
/**
 * @brief Returns the library handle that a kernel belongs to
 *
 * @param [out] library Returned Library handle
 * @param [in] kernel Kernel to retrieve library Handle
 * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidHandle
 */
hipError_t hipKernelGetLibrary(hipLibrary_t* library, hipKernel_t kernel);
/**
 * @brief Returns the function name of a kernel
 *
 * @param [out] name Returned Kernel Name
 * @param [in] kernel Kernel handle to retrieve name
 * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidHandle
 */
hipError_t hipKernelGetName(const char** name, hipKernel_t kernel);
/**
* @brief Find out attributes for a given function.
* @ingroup Execution
@@ -3672,6 +3672,19 @@ inline static hipError_t hipLibraryGetKernelCount(unsigned int* count, hipLibrar
return hipCUResultTohipError(cuLibraryGetKernelCount(count, library));
}
// NVIDIA backend: calls cuLibraryEnumerateKernels and converts its result to a hipError_t.
inline static hipError_t hipLibraryEnumerateKernels(hipKernel_t* kernels, unsigned int numKernels,
                                                    hipLibrary_t library) {
  CUresult status = cuLibraryEnumerateKernels(kernels, numKernels, library);
  return hipCUResultTohipError(status);
}
// NVIDIA backend: calls cuKernelGetLibrary and converts its result to a hipError_t.
inline static hipError_t hipKernelGetLibrary(hipLibrary_t* library, hipKernel_t kernel) {
  CUresult status = cuKernelGetLibrary(library, kernel);
  return hipCUResultTohipError(status);
}
// NVIDIA backend: calls cuKernelGetName and converts its result to a hipError_t.
inline static hipError_t hipKernelGetName(const char** name, hipKernel_t kernel) {
  CUresult status = cuKernelGetName(name, kernel);
  return hipCUResultTohipError(status);
}
inline static hipError_t hipLaunchKernel(const void* function_address, dim3 numBlocks,
dim3 dimBlocks, void** args, size_t sharedMemBytes,
hipStream_t stream) {