Make roctracer_cb_table.h a private header

Move roctracer_cb_table.h to the src/core directory, as it should not
be exposed as a public header, and rename it callback_table.h

Change-Id: Ib448cbd32a275df0268d53bd8d1da0bdc9201470
Этот коммит содержится в:
Laurent Morichetti
2022-04-18 16:13:08 -07:00
родитель dc22139977
Коммит cd62d841fa
8 изменённых файлов: 57 добавлений и 107 удалений
-1
Просмотреть файл
@@ -93,7 +93,6 @@ set ( PUBLIC_HEADERS
roctracer_hip.h
roctracer_hsa.h
roctracer_roctx.h
roctracer_cb_table.h
ext/prof_protocol.h
ext/hsa_rt_utils.hpp
)
+4 -12
Просмотреть файл
@@ -32,6 +32,10 @@
#define INC_ROCTRACER_ROCTX_H_
#include <roctx.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// ROC-TX API ID enumeration
enum roctx_api_id_t {
ROCTX_API_ID_roctxMarkA = 0,
@@ -70,18 +74,6 @@ typedef struct roctx_api_data_s {
} args;
} roctx_api_data_t;
#ifdef __cplusplus
#include <roctracer_cb_table.h>
namespace roctx {
// ROCTX callbacks table type
typedef roctracer::CbTable<ROCTX_API_ID_NUMBER> cb_table_t;
} // namespace roctx
#endif
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// Regiter ROCTX callback for given opertaion id
bool RegisterApiCallback(uint32_t op, void* callback, void* arg);
+3 -2
Просмотреть файл
@@ -324,6 +324,7 @@ class API_DescrParser:
self.content += '\n'
self.content += '#if PROF_API_IMPL\n'
self.content += '#include \"core/callback_table.h\"\n';
self.content += 'namespace roctracer {\n'
self.content += 'namespace hsa_support {\n'
self.add_section('API callback functions', '', self.gen_callbacks)
@@ -394,7 +395,7 @@ class API_DescrParser:
# generate API callbacks
def gen_callbacks(self, n, name, call, struct):
if n == -1:
self.content += 'typedef CbTable<HSA_API_ID_NUMBER> cb_table_t;\n'
self.content += 'typedef CallbackTable<HSA_API_ID_NUMBER> cb_table_t;\n'
self.content += 'extern cb_table_t cb_table;\n'
self.content += '\n'
if call != '-':
@@ -412,7 +413,7 @@ class API_DescrParser:
self.content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
self.content += ' activity_rtapi_callback_t api_callback_fun = NULL;\n'
self.content += ' void* api_callback_arg = NULL;\n'
self.content += ' cb_table.get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n'
self.content += ' cb_table.Get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n'
self.content += ' api_data.phase = 0;\n'
self.content += ' if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
if ret_type != 'void':
-1
Просмотреть файл
@@ -53,7 +53,6 @@ target_link_libraries( ${TARGET_LIB} PRIVATE ${HSA_RUNTIME_LIB} c stdc++ )
set ( ROCTX_LIB "roctx64" )
set ( ROCTX_LIB_SRC
${LIB_DIR}/roctx/roctx.cpp
${LIB_DIR}/roctx/roctx_intercept.cpp
)
add_library ( ${ROCTX_LIB} SHARED ${ROCTX_LIB_SRC} )
target_include_directories ( ${ROCTX_LIB} PRIVATE ${LIB_DIR} ${ROOT_DIR} ${ROOT_DIR}/inc ${HSA_RUNTIME_INC_PATH} ${GEN_INC_DIR} )
+20 -34
Просмотреть файл
@@ -18,56 +18,42 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. */
#ifndef CB_TABLE_H_
#define CB_TABLE_H_
#ifndef CALLBACK_TABLE_H_
#define CALLBACK_TABLE_H_
#include <ext/prof_protocol.h>
#include <cassert>
#include <mutex>
#include <utility>
namespace roctracer {
// Generic callbacks table
template <int N> class CbTable {
template <uint32_t N> class CallbackTable {
public:
typedef std::mutex mutex_t;
CallbackTable()
// Zero initialize the callbacks array as the function pointer is used to determine if the
// callback is enabled.
: callbacks_() {}
CbTable() {
std::lock_guard<mutex_t> lck(mutex_);
for (int i = 0; i < N; i++) {
callback_[i] = NULL;
arg_[i] = NULL;
}
void Set(uint32_t id, activity_rtapi_callback_t callback, void* arg) {
assert(id < N && "id is out of range");
std::lock_guard lock(mutex_);
callbacks_[id] = {callback, arg};
}
bool set(uint32_t id, activity_rtapi_callback_t callback, void* arg) {
std::lock_guard<mutex_t> lck(mutex_);
bool ret = false;
if (id < N) {
callback_[id] = callback;
arg_[id] = arg;
ret = true;
}
return ret;
}
bool get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) {
std::lock_guard<mutex_t> lck(mutex_);
bool ret = false;
if (id < N) {
*callback = callback_[id];
*arg = arg_[id];
ret = true;
}
return ret;
void Get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) const {
assert(id < N && "id is out of range");
std::lock_guard lock(mutex_);
std::tie(*callback, *arg) = callbacks_[id];
}
private:
activity_rtapi_callback_t callback_[N];
void* arg_[N];
mutex_t mutex_;
std::array<std::pair<activity_rtapi_callback_t /* callback */, void* /* arg */>, N> callbacks_;
mutable std::mutex mutex_;
};
} // namespace roctracer
#endif // CB_TALE_H_
#endif // CALLBACK_TABLE_H_
+4 -2
Просмотреть файл
@@ -788,7 +788,8 @@ static roctracer_status_t roctracer_enable_callback_fun(roctracer_domain_t domai
break;
}
#endif
roctracer::hsa_support::cb_table.set(op, callback, user_data);
if (op >= HSA_API_ID_NUMBER) return ROCTRACER_STATUS_BAD_PARAMETER;
roctracer::hsa_support::cb_table.Set(op, callback, user_data);
break;
}
case ACTIVITY_DOMAIN_HSA_EVT: {
@@ -885,7 +886,8 @@ static roctracer_status_t roctracer_disable_callback_fun(roctracer_domain_t doma
break;
}
#endif
roctracer::hsa_support::cb_table.set(op, NULL, NULL);
if (op >= HSA_API_ID_NUMBER) return ROCTRACER_STATUS_BAD_PARAMETER;
roctracer::hsa_support::cb_table.Set(op, NULL, NULL);
break;
}
case ACTIVITY_DOMAIN_HCC_OPS:
+26 -5
Просмотреть файл
@@ -27,6 +27,7 @@
#include <stack>
#include "inc/ext/prof_protocol.h"
#include "core/callback_table.h"
#include "util/exception.h"
#include "util/logger.h"
@@ -75,6 +76,14 @@ typedef enum {
// Library implementation
//
namespace roctx {
// ROCTX callbacks table type
typedef roctracer::CallbackTable<ROCTX_API_ID_NUMBER> cb_table_t;
// callbacks table
cb_table_t cb_table;
typedef std::stack<std::string> message_stack_t;
typedef std::map<uint32_t, message_stack_t*> thread_map_t;
typedef std::mutex map_mutex_t;
@@ -124,7 +133,7 @@ PUBLIC_API void roctxMarkA(const char* message) {
api_data.args.roctxMarkA.message = strdup(message);
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg);
roctx::cb_table.Get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg);
API_METHOD_SUFFIX_NRET
@@ -138,7 +147,7 @@ PUBLIC_API int roctxRangePushA(const char* message) {
api_data.args.roctxRangePushA.message = strdup(message);
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg);
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data,
api_callback_arg);
@@ -155,7 +164,7 @@ PUBLIC_API int roctxRangePop() {
roctx_api_data_t api_data{};
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg);
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data,
api_callback_arg);
@@ -178,7 +187,7 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) {
api_data.args.roctxRangeStartA.id = roctx_range_counter;
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg);
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data,
api_callback_arg);
@@ -193,7 +202,7 @@ PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) {
api_data.args.roctxRangeStop.id = rangeId;
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg);
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data,
api_callback_arg);
@@ -213,4 +222,16 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg)
}
}
PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::cb_table.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
return true;
}
PUBLIC_API bool RemoveApiCallback(uint32_t op) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::cb_table.Get(op, NULL, NULL);
return true;
}
} // extern "C"
-50
Просмотреть файл
@@ -1,50 +0,0 @@
/*
Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#include "inc/roctx.h"
#include "inc/roctracer_roctx.h"
#include "util/logger.h"
#define PUBLIC_API __attribute__((visibility("default")))
///////////////////////////////////////////////////////////////////////////////////////////////////
// Library implementation
//
namespace roctx {
// callbacks table
cb_table_t cb_table;
} // namespace roctx
///////////////////////////////////////////////////////////////////////////////////////////////////
// Public library methods
//
extern "C" {
PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
return roctx::cb_table.set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
}
PUBLIC_API bool RemoveApiCallback(uint32_t op) { return roctx::cb_table.set(op, NULL, NULL); }
} // extern "C"