diff --git a/projects/roctracer/CMakeLists.txt b/projects/roctracer/CMakeLists.txt index 5ebde9efe5..b30adb3b16 100644 --- a/projects/roctracer/CMakeLists.txt +++ b/projects/roctracer/CMakeLists.txt @@ -93,7 +93,6 @@ set ( PUBLIC_HEADERS roctracer_hip.h roctracer_hsa.h roctracer_roctx.h - roctracer_cb_table.h ext/prof_protocol.h ext/hsa_rt_utils.hpp ) diff --git a/projects/roctracer/inc/roctracer_roctx.h b/projects/roctracer/inc/roctracer_roctx.h index 0c08dafc21..3beee3060b 100644 --- a/projects/roctracer/inc/roctracer_roctx.h +++ b/projects/roctracer/inc/roctracer_roctx.h @@ -32,6 +32,10 @@ #define INC_ROCTRACER_ROCTX_H_ #include +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + // ROC-TX API ID enumeration enum roctx_api_id_t { ROCTX_API_ID_roctxMarkA = 0, @@ -70,18 +74,6 @@ typedef struct roctx_api_data_s { } args; } roctx_api_data_t; -#ifdef __cplusplus -#include -namespace roctx { -// ROCTX callbacks table type -typedef roctracer::CbTable cb_table_t; -} // namespace roctx -#endif - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - // Regiter ROCTX callback for given opertaion id bool RegisterApiCallback(uint32_t op, void* callback, void* arg); diff --git a/projects/roctracer/script/hsaap.py b/projects/roctracer/script/hsaap.py index 2e67c75925..dd2ce504bf 100755 --- a/projects/roctracer/script/hsaap.py +++ b/projects/roctracer/script/hsaap.py @@ -324,6 +324,7 @@ class API_DescrParser: self.content += '\n' self.content += '#if PROF_API_IMPL\n' + self.content += '#include \"core/callback_table.h\"\n'; self.content += 'namespace roctracer {\n' self.content += 'namespace hsa_support {\n' self.add_section('API callback functions', '', self.gen_callbacks) @@ -394,7 +395,7 @@ class API_DescrParser: # generate API callbacks def gen_callbacks(self, n, name, call, struct): if n == -1: - self.content += 'typedef CbTable cb_table_t;\n' + self.content += 'typedef CallbackTable cb_table_t;\n' self.content += 'extern cb_table_t cb_table;\n' self.content += '\n' if call != '-': @@ -412,7 +413,7 @@ class API_DescrParser: self.content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n' self.content += ' activity_rtapi_callback_t api_callback_fun = NULL;\n' self.content += ' void* api_callback_arg = NULL;\n' - self.content += ' cb_table.get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n' + self.content += ' cb_table.Get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n' self.content += ' api_data.phase = 0;\n' self.content += ' if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n' if ret_type != 'void': diff --git a/projects/roctracer/src/CMakeLists.txt b/projects/roctracer/src/CMakeLists.txt index 8fff755303..f2068d65a4 100644 --- a/projects/roctracer/src/CMakeLists.txt +++ b/projects/roctracer/src/CMakeLists.txt @@ -53,7 +53,6 @@ target_link_libraries( ${TARGET_LIB} PRIVATE ${HSA_RUNTIME_LIB} c stdc++ ) set ( ROCTX_LIB "roctx64" ) set ( ROCTX_LIB_SRC ${LIB_DIR}/roctx/roctx.cpp - ${LIB_DIR}/roctx/roctx_intercept.cpp ) add_library ( ${ROCTX_LIB} SHARED ${ROCTX_LIB_SRC} ) target_include_directories ( ${ROCTX_LIB} PRIVATE ${LIB_DIR} ${ROOT_DIR} ${ROOT_DIR}/inc ${HSA_RUNTIME_INC_PATH} ${GEN_INC_DIR} ) diff --git a/projects/roctracer/inc/roctracer_cb_table.h b/projects/roctracer/src/core/callback_table.h similarity index 60% rename from projects/roctracer/inc/roctracer_cb_table.h rename to projects/roctracer/src/core/callback_table.h index def96b3488..52c4fccdb8 100644 --- a/projects/roctracer/inc/roctracer_cb_table.h +++ b/projects/roctracer/src/core/callback_table.h @@ -18,56 +18,42 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef CB_TABLE_H_ -#define CB_TABLE_H_ +#ifndef CALLBACK_TABLE_H_ +#define CALLBACK_TABLE_H_ #include +#include #include +#include namespace roctracer { // Generic callbacks table -template class CbTable { +template class CallbackTable { public: - typedef std::mutex mutex_t; + CallbackTable() + // Zero initialize the callbacks array as the function pointer is used to determine if the + // callback is enabled. + : callbacks_() {} - CbTable() { - std::lock_guard lck(mutex_); - for (int i = 0; i < N; i++) { - callback_[i] = NULL; - arg_[i] = NULL; - } + void Set(uint32_t id, activity_rtapi_callback_t callback, void* arg) { + assert(id < N && "id is out of range"); + std::lock_guard lock(mutex_); + callbacks_[id] = {callback, arg}; } - bool set(uint32_t id, activity_rtapi_callback_t callback, void* arg) { - std::lock_guard lck(mutex_); - bool ret = false; - if (id < N) { - callback_[id] = callback; - arg_[id] = arg; - ret = true; - } - return ret; - } - - bool get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) { - std::lock_guard lck(mutex_); - bool ret = false; - if (id < N) { - *callback = callback_[id]; - *arg = arg_[id]; - ret = true; - } - return ret; + void Get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) const { + assert(id < N && "id is out of range"); + std::lock_guard lock(mutex_); + std::tie(*callback, *arg) = callbacks_[id]; } private: - activity_rtapi_callback_t callback_[N]; - void* arg_[N]; - mutex_t mutex_; + std::array, N> callbacks_; + mutable std::mutex mutex_; }; } // namespace roctracer -#endif // CB_TALE_H_ +#endif // CALLBACK_TABLE_H_ diff --git a/projects/roctracer/src/core/roctracer.cpp b/projects/roctracer/src/core/roctracer.cpp index 70b6679603..92bcf7631c 100644 --- a/projects/roctracer/src/core/roctracer.cpp +++ b/projects/roctracer/src/core/roctracer.cpp @@ -788,7 +788,8 @@ static roctracer_status_t roctracer_enable_callback_fun(roctracer_domain_t domai break; } #endif - roctracer::hsa_support::cb_table.set(op, callback, user_data); + if (op >= HSA_API_ID_NUMBER) return ROCTRACER_STATUS_BAD_PARAMETER; + roctracer::hsa_support::cb_table.Set(op, callback, user_data); break; } case ACTIVITY_DOMAIN_HSA_EVT: { @@ -885,7 +886,8 @@ static roctracer_status_t roctracer_disable_callback_fun(roctracer_domain_t doma break; } #endif - roctracer::hsa_support::cb_table.set(op, NULL, NULL); + if (op >= HSA_API_ID_NUMBER) return ROCTRACER_STATUS_BAD_PARAMETER; + roctracer::hsa_support::cb_table.Set(op, NULL, NULL); break; } case ACTIVITY_DOMAIN_HCC_OPS: diff --git a/projects/roctracer/src/roctx/roctx.cpp b/projects/roctracer/src/roctx/roctx.cpp index 486e97b7b4..bbdf670541 100644 --- a/projects/roctracer/src/roctx/roctx.cpp +++ b/projects/roctracer/src/roctx/roctx.cpp @@ -27,6 +27,7 @@ #include #include "inc/ext/prof_protocol.h" +#include "core/callback_table.h" #include "util/exception.h" #include "util/logger.h" @@ -75,6 +76,14 @@ typedef enum { // Library implementation // namespace roctx { + +// ROCTX callbacks table type +typedef roctracer::CallbackTable cb_table_t; + +// callbacks table +cb_table_t cb_table; + + typedef std::stack message_stack_t; typedef std::map thread_map_t; typedef std::mutex map_mutex_t; @@ -124,7 +133,7 @@ PUBLIC_API void roctxMarkA(const char* message) { api_data.args.roctxMarkA.message = strdup(message); activity_rtapi_callback_t api_callback_fun = NULL; void* api_callback_arg = NULL; - roctx::cb_table.get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg); + roctx::cb_table.Get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg); API_METHOD_SUFFIX_NRET @@ -138,7 +147,7 @@ PUBLIC_API int roctxRangePushA(const char* message) { api_data.args.roctxRangePushA.message = strdup(message); activity_rtapi_callback_t api_callback_fun = NULL; void* api_callback_arg = NULL; - roctx::cb_table.get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg); + roctx::cb_table.Get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data, api_callback_arg); @@ -155,7 +164,7 @@ PUBLIC_API int roctxRangePop() { roctx_api_data_t api_data{}; activity_rtapi_callback_t api_callback_fun = NULL; void* api_callback_arg = NULL; - roctx::cb_table.get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg); + roctx::cb_table.Get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data, api_callback_arg); @@ -178,7 +187,7 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { api_data.args.roctxRangeStartA.id = roctx_range_counter; activity_rtapi_callback_t api_callback_fun = NULL; void* api_callback_arg = NULL; - roctx::cb_table.get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg); + roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, api_callback_arg); @@ -193,7 +202,7 @@ PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { api_data.args.roctxRangeStop.id = rangeId; activity_rtapi_callback_t api_callback_fun = NULL; void* api_callback_arg = NULL; - roctx::cb_table.get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg); + roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, api_callback_arg); @@ -213,4 +222,16 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) } } +PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { + if (op >= ROCTX_API_ID_NUMBER) return false; + roctx::cb_table.Set(op, reinterpret_cast(callback), arg); + return true; +} + +PUBLIC_API bool RemoveApiCallback(uint32_t op) { + if (op >= ROCTX_API_ID_NUMBER) return false; + roctx::cb_table.Get(op, NULL, NULL); + return true; +} + } // extern "C" diff --git a/projects/roctracer/src/roctx/roctx_intercept.cpp b/projects/roctracer/src/roctx/roctx_intercept.cpp deleted file mode 100644 index 11de368da5..0000000000 --- a/projects/roctracer/src/roctx/roctx_intercept.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/* -Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -*/ - -#include "inc/roctx.h" -#include "inc/roctracer_roctx.h" -#include "util/logger.h" - -#define PUBLIC_API __attribute__((visibility("default"))) - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Library implementation -// -namespace roctx { - -// callbacks table -cb_table_t cb_table; - -} // namespace roctx - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Public library methods -// -extern "C" { - -PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { - return roctx::cb_table.set(op, reinterpret_cast(callback), arg); -} - -PUBLIC_API bool RemoveApiCallback(uint32_t op) { return roctx::cb_table.set(op, NULL, NULL); } - -} // extern "C"