From 8f1cb6417a7eba8685f9b0be9e014a0bde16e6de Mon Sep 17 00:00:00 2001 From: "Madsen, Jonathan" Date: Fri, 21 Mar 2025 04:36:58 -0500 Subject: [PATCH] Record loaded runtimes (#6) * Provide rocprofiler_register_iterate_registration_info function - stores the runtime arguments for later reference * Fix compilation error * Removed unused variable * Store api tables in vector * Update license * Replace global_mutex usage with scoped_count - in rocp_invoke_registrations * Formatting --------- Co-authored-by: Jonathan R. Madsen [ROCm/rocprofiler-register commit: fc712350ef8ba1df3870dd8f8e9593bec1a5122e] --- .../rocprofiler-register.h | 44 ++++ .../details/scope_destructor.hpp | 87 ++++++++ .../rocprofiler_register.cpp | 191 ++++++++++++++++-- .../tests/rocprofiler/CMakeLists.txt | 21 +- .../tests/rocprofiler/rocprofiler.cpp | 71 ++++++- 5 files changed, 386 insertions(+), 28 deletions(-) create mode 100644 projects/rocprofiler-register/source/lib/rocprofiler-register/details/scope_destructor.hpp diff --git a/projects/rocprofiler-register/source/include/rocprofiler-register/rocprofiler-register.h b/projects/rocprofiler-register/source/include/rocprofiler-register/rocprofiler-register.h index 79a89fd1b8..f6751a4f59 100644 --- a/projects/rocprofiler-register/source/include/rocprofiler-register/rocprofiler-register.h +++ b/projects/rocprofiler-register/source/include/rocprofiler-register/rocprofiler-register.h @@ -129,6 +129,50 @@ rocprofiler_register_library_api_table( const char* rocprofiler_register_error_string(rocprofiler_register_error_code_t) ROCPROFILER_REGISTER_PUBLIC_API; +/// @brief Struct containing the information about the libraries which have registered +/// with rocprofiler-register. @see rocprofiler_register_iterate_registration_info +typedef struct rocprofiler_register_registration_info_t +{ + size_t size; ///< in case of future extensions + const char* common_name; ///< name of the library + uint32_t lib_version; ///< version + uint64_t api_table_length; ///< number of API tables +} rocprofiler_register_registration_info_t; + +/** + * @brief Callback function for iterating over the libraries which have registered + * with rocprofiler-register. @see rocprofiler_register_iterate_registration_info + * + * @param [in] info Pointer to library registration instance. Invokee should make a copy + * for reference outside of callback. + * @param [in] data User data passed to ::rocprofiler_register_iterate_registration_info + * @return int + * @retval 0 If zero is returned from callback, rocprofiler-register will continue to next + * registration info, if one exists + * @retval -1 If -1 (or any value != 0) is returned from callback, rocprofiler-register + * will cease to iterate over the remaining registration info, if any exists + */ +typedef int (*rocprofiler_register_registration_info_cb_t)( + rocprofiler_register_registration_info_t* info, + void* data); + +/** + * @brief Iterates over all the (valid) libraries which registered their API tables with + * rocprofiler-register. Any libraries which do not have an accepted common name, have an + * invalid import function address (in secure mode), or have registered too many instances + * are not reported by this function. + * + * @param [in] callback Callback function to invoke for each valid registered library + * @param [in] data User data to pass to the callback function + * @return ::rocprofiler_register_error_code_t + * @retval ::ROCP_REG_SUCCESS Always returned + */ +rocprofiler_register_error_code_t +rocprofiler_register_iterate_registration_info( + rocprofiler_register_registration_info_cb_t callback, + void* data) + ROCPROFILER_REGISTER_ATTRIBUTE(nonnull(1)) ROCPROFILER_REGISTER_PUBLIC_API; + #ifdef __cplusplus } #endif diff --git a/projects/rocprofiler-register/source/lib/rocprofiler-register/details/scope_destructor.hpp b/projects/rocprofiler-register/source/lib/rocprofiler-register/details/scope_destructor.hpp new file mode 100644 index 0000000000..16ae37079f --- /dev/null +++ b/projects/rocprofiler-register/source/lib/rocprofiler-register/details/scope_destructor.hpp @@ -0,0 +1,87 @@ +// MIT License +// +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include "fmt/core.h" +#include "glog/logging.h" + +#include +#include + +namespace rocprofiler_register +{ +namespace common +{ +struct scope_destructor +{ + /// \fn scope_destructor(FuncT&& _fini, InitT&& _init) + /// \tparam FuncT "std::function or void (*)()" + /// \tparam InitT "std::function or void (*)()" + /// \param _fini Function to execute when object is destroyed + /// \param _init Function to execute when object is created (optional) + /// + /// \brief Provides a utility to perform an operation when exiting a scope. + template + scope_destructor( + FuncT&& _fini, + InitT&& _init = []() {}); + + ~scope_destructor() { m_functor(); } + + // delete copy operations + scope_destructor(const scope_destructor&) = delete; + scope_destructor& operator=(const scope_destructor&) = delete; + + // allow move operations + scope_destructor(scope_destructor&& rhs) noexcept; + scope_destructor& operator=(scope_destructor&& rhs) noexcept; + +private: + std::function m_functor = []() {}; +}; + +template +scope_destructor::scope_destructor(FuncT&& _fini, InitT&& _init) +: m_functor{ std::forward(_fini) } +{ + _init(); +} + +inline scope_destructor::scope_destructor(scope_destructor&& rhs) noexcept +: m_functor{ std::move(rhs.m_functor) } +{ + rhs.m_functor = []() {}; +} + +inline scope_destructor& +scope_destructor::operator=(scope_destructor&& rhs) noexcept +{ + if(this != &rhs) + { + m_functor = std::move(rhs.m_functor); + rhs.m_functor = []() {}; + } + return *this; +} +} // namespace common +} // namespace rocprofiler_register diff --git a/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp b/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp index ff95239ac0..b65985eb78 100644 --- a/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp +++ b/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp @@ -26,6 +26,7 @@ #include "details/environment.hpp" #include "details/filesystem.hpp" #include "details/logging.hpp" +#include "details/scope_destructor.hpp" #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -40,6 +42,7 @@ #include #include +#include extern "C" { #pragma weak rocprofiler_configure @@ -247,7 +250,13 @@ auto rocp_reg_get_imports(std::index_sequence) rocp_set_api_table_data_t rocp_load_rocprofiler_lib(std::string _rocp_reg_lib); -auto +struct rocp_scan_data +{ + void* handle = nullptr; + rocprofiler_set_api_table_t set_api_table_fn = nullptr; +}; + +rocp_scan_data rocp_reg_scan_for_tools() { auto* _configure_func = dlsym(RTLD_DEFAULT, "rocprofiler_configure"); @@ -265,7 +274,7 @@ rocp_reg_scan_for_tools() if(_found_tool) { if(rocprofiler_lib_handle && rocprofiler_lib_config_fn) - return std::make_pair(rocprofiler_lib_handle, rocprofiler_lib_config_fn); + return rocp_scan_data{ rocprofiler_lib_handle, rocprofiler_lib_config_fn }; if(_rocp_reg_lib.empty()) _rocp_reg_lib = rocprofiler_lib_name; @@ -281,7 +290,7 @@ rocp_reg_scan_for_tools() rocprofiler_lib_config_fn = &rocprofiler_set_api_table; } - return std::make_pair(rocprofiler_lib_handle, rocprofiler_lib_config_fn); + return rocp_scan_data{ rocprofiler_lib_handle, rocprofiler_lib_config_fn }; } rocp_set_api_table_data_t @@ -355,10 +364,105 @@ rocp_load_rocprofiler_lib(std::string _rocp_reg_lib) return std::make_tuple(rocprofiler_lib_handle, rocprofiler_lib_config_fn); } +struct registered_library_api_table +{ + bool propagated = false; + const char* common_name = nullptr; + rocprofiler_register_import_func_t import_func = nullptr; + uint32_t lib_version = 0; + std::vector api_tables = {}; + uint64_t instance_value = 0; +}; + +constexpr auto instance_bits = sizeof(uint64_t) * 8; // bits in instance_counters +constexpr auto max_instances = instance_bits * ROCP_REG_LAST; constexpr auto library_seq = std::make_index_sequence{}; auto global_count = std::atomic{ 0 }; auto import_info = rocp_reg_get_imports(library_seq); auto instance_counters = std::array{}; +auto registered = + std::array, max_instances>{}; + +struct scoped_count +{ + scoped_count() + : value{ ++global_count } + { } + + ~scoped_count() { --global_count; } + + scoped_count(const scoped_count&) = delete; + scoped_count(scoped_count&&) noexcept = delete; + scoped_count& operator=(const scoped_count&) = delete; + scoped_count& operator=(scoped_count&&) noexcept = delete; + + uint32_t value = 0; +}; + +std::optional* +rocp_add_registered_library_api_table(const char* common_name, + rocprofiler_register_import_func_t import_func, + uint32_t lib_version, + void** api_tables, + uint64_t api_tables_len, + uint64_t instance_val) +{ + LOG(INFO) << fmt::format("rocprofiler-register library api table registration:\n\t-" + "name: {}\n\t- version: {}\n\t- # tables: {}", + common_name, + lib_version, + api_tables_len); + + for(auto& itr : registered) + { + if(!itr) + { + auto _tables = std::vector{}; + _tables.reserve(api_tables_len); + for(uint64_t i = 0; i < api_tables_len; ++i) + _tables.emplace_back(api_tables[i]); + + itr = registered_library_api_table{ + false, common_name, import_func, + lib_version, std::move(_tables), instance_val + }; + return &itr; + } + } + + return nullptr; +} + +rocprofiler_register_error_code_t +rocp_invoke_registrations(bool invoke_all) +{ + auto _count = scoped_count{}; + if(_count.value > 1) return ROCP_REG_DEADLOCK; + + for(auto& itr : registered) + { + if(itr && (!itr->propagated || invoke_all)) + { + auto _scan_result = rocp_reg_scan_for_tools(); + + // rocprofiler_set_api_table has been found and we have pass the API data + auto _activate_rocprofiler = (_scan_result.set_api_table_fn != nullptr); + + if(_activate_rocprofiler) + { + auto _ret = _scan_result.set_api_table_fn(itr->common_name, + itr->lib_version, + itr->instance_value, + itr->api_tables.data(), + itr->api_tables.size()); + if(_ret != 0) return ROCP_REG_ROCPROFILER_ERROR; + itr->propagated = true; + } + } + } + + return ROCP_REG_SUCCESS; +} } // namespace extern "C" { @@ -382,22 +486,6 @@ rocprofiler_register_library_api_table( return ROCP_REG_NO_TOOLS; } - struct scoped_count - { - scoped_count() - : value{ ++global_count } - { } - - ~scoped_count() { --global_count; } - - scoped_count(const scoped_count&) = delete; - scoped_count(scoped_count&&) noexcept = delete; - scoped_count& operator=(const scoped_count&) = delete; - scoped_count& operator=(scoped_count&&) noexcept = delete; - - uint32_t value = 0; - }; - auto _count = scoped_count{}; if(_count.value > 1) return ROCP_REG_DEADLOCK; @@ -464,17 +552,33 @@ rocprofiler_register_library_api_table( auto& _bits = *reinterpret_cast(®ister_id->handle); _bits = bitset_t{ (offset_factor * _import_match->library_idx) + _instance_val }; + auto* reginfo = rocp_add_registered_library_api_table(common_name, + import_func, + lib_version, + api_tables, + api_table_length, + _instance_val); + + LOG_IF(WARNING, !reginfo) << fmt::format( + "rocprofiler-register failed to create registration info for " + "{} version {} (instance {})", + common_name, + lib_version, + _instance_val); + if(_bits.to_ulong() != register_id->handle) throw std::runtime_error("error encoding register_id"); // rocprofiler library is dlopened and we have the functor to pass the API data - auto _activate_rocprofiler = (_scan_result.second != nullptr); + auto _activate_rocprofiler = (_scan_result.set_api_table_fn != nullptr); if(_activate_rocprofiler) { - auto _ret = _scan_result.second( + auto _ret = _scan_result.set_api_table_fn( common_name, lib_version, _instance_val, api_tables, api_table_length); if(_ret != 0) return ROCP_REG_ROCPROFILER_ERROR; + + if(reginfo) (*reginfo)->propagated = true; } else { @@ -490,4 +594,49 @@ rocprofiler_register_error_string(rocprofiler_register_error_code_t _ec) return rocprofiler_register_error_string( _ec, std::make_index_sequence{}); } + +rocprofiler_register_error_code_t +rocprofiler_register_iterate_registration_info( + rocprofiler_register_registration_info_cb_t callback, + void* data) +{ + for(const auto& itr : registered) + { + if(itr) + { + auto _info = rocprofiler_register_registration_info_t{ + .size = sizeof(rocprofiler_register_registration_info_t), + .common_name = itr->common_name, + .lib_version = itr->lib_version, + .api_table_length = itr->api_tables.size() + }; + // invoke callback and break if the caller does not return zero + if(callback(&_info, data) != ROCP_REG_SUCCESS) break; + } + } + + return ROCP_REG_SUCCESS; +} + +rocprofiler_register_error_code_t +rocprofiler_register_invoke_nonpropagated_registrations() ROCPROFILER_REGISTER_PUBLIC_API; + +// +// This function can be invoked by ptrace +rocprofiler_register_error_code_t +rocprofiler_register_invoke_nonpropagated_registrations() +{ + return rocp_invoke_registrations(false); +} + +rocprofiler_register_error_code_t +rocprofiler_register_invoke_all_registrations() ROCPROFILER_REGISTER_PUBLIC_API; + +// +// This function can be invoked by ptrace +rocprofiler_register_error_code_t +rocprofiler_register_invoke_all_registrations() +{ + return rocp_invoke_registrations(true); +} } diff --git a/projects/rocprofiler-register/tests/rocprofiler/CMakeLists.txt b/projects/rocprofiler-register/tests/rocprofiler/CMakeLists.txt index 59a84a11bf..6be6a04534 100644 --- a/projects/rocprofiler-register/tests/rocprofiler/CMakeLists.txt +++ b/projects/rocprofiler-register/tests/rocprofiler/CMakeLists.txt @@ -1,11 +1,28 @@ # # mock rocprofiler library # + +find_package(rocprofiler-register REQUIRED) +if(TARGET rocprofiler-register::rocprofiler-register-headers) + get_property( + rocp_reg_INCLUDE_DIR + TARGET rocprofiler-register::rocprofiler-register-headers + PROPERTY INTERFACE_INCLUDE_DIRECTORIES) +else() + get_property( + rocp_reg_INCLUDE_DIR + TARGET rocprofiler-register::rocprofiler-register + PROPERTY INTERFACE_INCLUDE_DIRECTORIES) +endif() + add_library(rocprofiler SHARED) add_library(rocprofiler::rocprofiler ALIAS rocprofiler) target_sources(rocprofiler PRIVATE rocprofiler.cpp) -target_include_directories(rocprofiler PUBLIC $ - $) +target_include_directories( + rocprofiler + PUBLIC $ + $ + $) set_target_properties( rocprofiler PROPERTIES OUTPUT_NAME rocprofiler-sdk diff --git a/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp b/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp index eac61cadff..c513f805dd 100644 --- a/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp +++ b/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -8,8 +9,11 @@ #include #include +#include #include #include +#include +#include #ifndef ROCP_REG_FILE_NAME # define ROCP_REG_FILE_NAME \ @@ -64,6 +68,26 @@ roctx_range_pop(const char* name) { printf("[%s][pop] %s\n", ROCP_REG_FILE_NAME, name); } + +using reginfo_vec_t = std::vector; + +bool +check_registration_info(const char* name, + uint64_t lib_version, + uint64_t num_tables, + const reginfo_vec_t& infovec) +{ + for(const auto& itr : infovec) + { + if(std::string_view{ name } == std::string_view{ itr.common_name }) + { + return std::tie(lib_version, num_tables) == + std::tie(itr.lib_version, itr.api_table_length); + } + } + + return false; +} } // namespace rocprofiler extern "C" { @@ -99,6 +123,34 @@ rocprofiler_set_api_table(const char* name, " did not contain rocprofiler_configure symbol" }; } + auto registration_info = ::rocprofiler::reginfo_vec_t{}; + { + auto* _handle = + dlopen("librocprofiler-register.so", RTLD_LAZY | RTLD_LOCAL | RTLD_NOLOAD); + if(!_handle) + throw std::runtime_error{ + "error opening librocprofiler-register.so library " + }; + auto* _sym = dlsym(_handle, "rocprofiler_register_iterate_registration_info"); + if(!_sym) + throw std::runtime_error{ + "librocprofiler-register.so did not contain " + "rocprofiler_register_iterate_registration_info symbol" + }; + + auto _func = [](rocprofiler_register_registration_info_t* _info, + void* _vdata) -> int { + auto* _vec = static_cast<::rocprofiler::reginfo_vec_t*>(_vdata); + _vec->emplace_back(*_info); + return 0; + }; + + auto iterate_registration_info = + reinterpret_cast( + _sym); + iterate_registration_info(_func, ®istration_info); + } + using hip_table_t = hip::HipApiTable; using hsa_table_t = hsa::HsaApiTable; using roctx_table_t = roctx::ROCTxApiTable; @@ -121,23 +173,23 @@ rocprofiler_set_api_table(const char* name, if(std::string_view{ name } == "hip") { hip_table_t* _table = static_cast(tables[0]); - _table->hip_init_fn = &rocprofiler::hip_init; + _table->hip_init_fn = &::rocprofiler::hip_init; } else if(std::string_view{ name } == "hsa") { hsa_table_t* _table = static_cast(tables[0]); - _table->hsa_init_fn = &rocprofiler::hsa_init; + _table->hsa_init_fn = &::rocprofiler::hsa_init; } else if(std::string_view{ name } == "roctx") { roctx_table_t* _table = static_cast(tables[0]); - _table->roctxRangePush_fn = &rocprofiler::roctx_range_push; - _table->roctxRangePop_fn = &rocprofiler::roctx_range_pop; + _table->roctxRangePush_fn = &::rocprofiler::roctx_range_push; + _table->roctxRangePop_fn = &::rocprofiler::roctx_range_pop; } else if(std::string_view{ name } == "rccl") { rccl_table_t* _table = static_cast(tables[0]); - _table->ncclGetVersion_fn = &rocprofiler::ncclGetVersion; + _table->ncclGetVersion_fn = &::rocprofiler::ncclGetVersion; } else if(std::string_view{ name } == "rocdecode") { @@ -151,6 +203,15 @@ rocprofiler_set_api_table(const char* name, } } + if(!::rocprofiler::check_registration_info( + name, lib_version, num_tables, registration_info)) + { + auto ss = std::stringstream{}; + ss << "no matching registration info for " << name << " " + << " version " << lib_version << " (# tables = " << num_tables << ")"; + throw std::runtime_error{ ss.str() }; + } + return 0; } }