SWDEV-389161: Fixing CTF plugin to work with the TraceBuffer

Change-Id: I4bd9f21bb91d6bd7cee1417d88a81d9d3be6ea9b


[ROCm/rocprofiler commit: 1bba393b1b]
Bu işleme şunda yer alıyor:
Ammar ELWazir
2023-05-23 18:13:05 +00:00
ebeveyn 4ba25a5c39
işleme c81ec8a710
5 değiştirilmiş dosya ile 84 ekleme ve 162 silme
+1 -1
Dosyayı Görüntüle
@@ -22,5 +22,5 @@
add_subdirectory(file)
add_subdirectory(perfetto)
#add_subdirectory(ctf)
add_subdirectory(ctf)
add_subdirectory(att)
+3 -4
Dosyayı Görüntüle
@@ -38,7 +38,7 @@ rocm_ctf::Plugin* the_plugin = nullptr;
} // namespace
ROCPROFILER_EXPORT int rocprofiler_plugin_initialize(const uint32_t rocprofiler_major_version,
const uint32_t rocprofiler_minor_version) {
const uint32_t rocprofiler_minor_version) {
if (rocprofiler_major_version != ROCPROFILER_VERSION_MAJOR ||
rocprofiler_minor_version < ROCPROFILER_VERSION_MINOR) {
return -1;
@@ -89,8 +89,7 @@ ROCPROFILER_EXPORT int rocprofiler_plugin_write_buffer_records(
}
ROCPROFILER_EXPORT int rocprofiler_plugin_write_record(
const rocprofiler_record_tracer_t record, const void* data,
rocprofiler_plugin_trace_record_data_t tracer_data) {
const rocprofiler_record_tracer_t record, rocprofiler_plugin_tracer_extra_data_t tracer_data) {
assert(the_plugin);
if (record.header.id.handle == 0) {
@@ -98,7 +97,7 @@ ROCPROFILER_EXPORT int rocprofiler_plugin_write_record(
}
try {
the_plugin->HandleTracerRecord(record, rocprofiler_session_id_t{0}, tracer_data, data);
the_plugin->HandleTracerRecord(record, rocprofiler_session_id_t{0}, tracer_data);
} catch (const std::exception& exc) {
std::cerr << "rocprofiler_plugin_write_record(): " << exc.what() << std::endl;
return -1;
+26 -115
Dosyayı Görüntüle
@@ -55,7 +55,8 @@ namespace {
// Abstract tracer event record using the barectf context type `CtxT`.
template <typename CtxT> class TracerEventRecord : public BarectfEventRecord<CtxT> {
protected:
explicit TracerEventRecord(const rocprofiler_record_tracer_t& record, const std::uint64_t clock_val)
explicit TracerEventRecord(const rocprofiler_record_tracer_t& record,
const std::uint64_t clock_val)
: BarectfEventRecord<CtxT>{clock_val},
op_{record.operation_id.id},
thread_id_{record.thread_id.value},
@@ -135,74 +136,16 @@ class RocTxEventRecord final : public TracerEventRecord<barectf_roctx_ctx> {
explicit RocTxEventRecord(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id)
: TracerEventRecord<barectf_roctx_ctx>{record, GetRecordBeginClockVal(record)},
id_{QueryId(record, session_id)},
msg_{QueryMsg(record, session_id)} {}
explicit RocTxEventRecord(const rocprofiler_record_tracer_t& record, uint64_t roctx_id,
std::string roctx_msg)
: TracerEventRecord<barectf_roctx_ctx>{record, GetRecordBeginClockVal(record)},
id_{roctx_id},
msg_{roctx_msg} {}
id_{record.operation_id.id},
msg_{
rocmtools::cxx_demangle(reinterpret_cast<const char*>(record.api_data_handle.handle))} {
}
// Emits this record through the barectf rocTX tracing function, passing the
// thread ID, the rocTX ID and the message (as a C string) of this record.
void Write(barectf_roctx_ctx& barectf_ctx) const override {
  const auto thread_id = GetThreadId();
  const char* const message = msg_.c_str();

  barectf_roctx_trace_roctx(&barectf_ctx, thread_id, id_, message);
}
private:
// Queries and returns the rocTX message of the record `record` and
// session ID `session_id`.
//
// Returns an empty string if not available.
static std::string QueryMsg(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id) {
// Query size first.
std::size_t msg_size = 0;
[[maybe_unused]] auto ret = rocprofiler_query_roctx_tracer_api_data_info_size(
session_id, ROCPROFILER_ROCTX_MESSAGE, record.api_data_handle, record.operation_id,
&msg_size);
assert(ret == ROCPROFILER_STATUS_SUCCESS && "Query rocTX message size");
if (msg_size == 0) {
// No size: return empty string.
return {};
}
// Query data (borrowed from the record: no need to free).
char* msg = nullptr;
ret = rocprofiler_query_roctx_tracer_api_data_info(
session_id, ROCPROFILER_ROCTX_MESSAGE, record.api_data_handle, record.operation_id, &msg);
assert(ret == ROCPROFILER_STATUS_SUCCESS && "Query rocTX message");
if (!msg) {
// No data: return empty string.
return {};
}
return rocmtools::cxx_demangle(msg);
}
// Queries and returns the rocTX ID of the record `record` and the
// session ID `session_id`.
//
// Returns 0 if anything goes wrong.
static std::uint64_t QueryId(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id) {
try {
return std::stoull(QueryAllocStr(
[&record, session_id](const auto size) {
return rocprofiler_query_roctx_tracer_api_data_info_size(
session_id, ROCPROFILER_ROCTX_ID, record.api_data_handle, record.operation_id, size);
},
[&record, session_id](const auto str) {
return rocprofiler_query_roctx_tracer_api_data_info(
session_id, ROCPROFILER_ROCTX_ID, record.api_data_handle, record.operation_id, str);
}));
} catch (...) {
return 0;
}
}
std::uint64_t id_;
std::string msg_;
};
@@ -211,7 +154,8 @@ class RocTxEventRecord final : public TracerEventRecord<barectf_roctx_ctx> {
class HsaApiEventRecord : public TracerEventRecord<barectf_hsa_api_ctx> {
protected:
explicit HsaApiEventRecord(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id, const std::uint64_t clock_val)
const rocprofiler_session_id_t session_id,
const std::uint64_t clock_val)
: TracerEventRecord<barectf_hsa_api_ctx>{record, clock_val},
api_data_{QueryApiData(record, session_id)} {}
explicit HsaApiEventRecord(const rocprofiler_record_tracer_t& record,
@@ -255,8 +199,7 @@ class HsaApiEventRecordBegin final : public HsaApiEventRecord {
: HsaApiEventRecord{record, session_id, GetRecordBeginClockVal(record)} {}
explicit HsaApiEventRecordBegin(const rocprofiler_record_tracer_t& record,
hsa_api_data_t& api_data)
: HsaApiEventRecord{record, GetRecordBeginClockVal(record),
api_data} {}
: HsaApiEventRecord{record, GetRecordBeginClockVal(record), api_data} {}
void Write(barectf_hsa_api_ctx& barectf_ctx) const override {
// Include generated switch statement.
@@ -270,8 +213,7 @@ class HsaApiEventRecordEnd final : public HsaApiEventRecord {
explicit HsaApiEventRecordEnd(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id)
: HsaApiEventRecord{record, session_id, GetRecordEndClockVal(record)} {}
explicit HsaApiEventRecordEnd(const rocprofiler_record_tracer_t& record,
hsa_api_data_t& api_data)
explicit HsaApiEventRecordEnd(const rocprofiler_record_tracer_t& record, hsa_api_data_t& api_data)
: HsaApiEventRecord{record, GetRecordBeginClockVal(record), api_data} {}
void Write(barectf_hsa_api_ctx& barectf_ctx) const override {
@@ -288,7 +230,7 @@ class HipApiEventRecord : public TracerEventRecord<barectf_hip_api_ctx> {
const std::uint64_t clock_val)
: TracerEventRecord<barectf_hip_api_ctx>{record, clock_val},
api_data_{QueryApiData(record, session_id)},
kernel_name_{QueryKernelName(record, session_id)} {}
kernel_name_{nullptr} {}
explicit HipApiEventRecord(const rocprofiler_record_tracer_t& record,
const std::uint64_t clock_val, hip_api_data_t& api_data,
std::string kernel_name)
@@ -323,32 +265,6 @@ class HipApiEventRecord : public TracerEventRecord<barectf_hip_api_ctx> {
return *reinterpret_cast<const hip_api_data_t*>(data);
}
// Queries and returns the kernel name of the record `record` and
// session ID `session_id`.
//
// Returns an empty string if not available.
static std::string QueryKernelName(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id) {
const auto kernel_name = QueryAllocStr(
[&record, session_id](const auto size) {
return rocprofiler_query_hip_tracer_api_data_info_size(
session_id, ROCPROFILER_HIP_KERNEL_NAME, record.api_data_handle, record.operation_id,
size);
},
[&record, session_id](const auto str) {
return rocprofiler_query_hip_tracer_api_data_info(session_id, ROCPROFILER_HIP_KERNEL_NAME,
record.api_data_handle,
record.operation_id, str);
});
if (kernel_name.size() > 1) {
// Return demangled version.
return rocmtools::cxx_demangle(kernel_name);
}
return kernel_name;
}
hip_api_data_t api_data_;
std::string kernel_name_;
};
@@ -409,7 +325,8 @@ class HsaHandleTypeEventRecord final : public BarectfEventRecord<barectf_hsa_han
// Abstract API operation event record.
class ApiOpEventRecord : public TracerEventRecord<barectf_api_ops_ctx> {
protected:
explicit ApiOpEventRecord(const rocprofiler_record_tracer_t& record, const std::uint64_t clock_val)
explicit ApiOpEventRecord(const rocprofiler_record_tracer_t& record,
const std::uint64_t clock_val)
: TracerEventRecord<barectf_api_ops_ctx>{record, clock_val} {}
};
@@ -540,11 +457,12 @@ class ProfilerEventRecord : public BarectfEventRecord<barectf_profiler_ctx> {
static std::string QueryKernelName(const rocprofiler_record_profiler_t& record) {
const auto kernel_name = QueryAllocStr(
[&record](const auto size) {
return rocprofiler_query_kernel_info_size(ROCPROFILER_KERNEL_NAME, record.kernel_id, size);
return rocprofiler_query_kernel_info_size(ROCPROFILER_KERNEL_NAME, record.kernel_id,
size);
},
[&record](const auto str) {
return rocprofiler_query_kernel_info(ROCPROFILER_KERNEL_NAME, record.kernel_id,
const_cast<const char**>(str));
const_cast<const char**>(str));
});
if (kernel_name.size() <= 1) {
@@ -590,7 +508,7 @@ class ProfilerEventRecord : public BarectfEventRecord<barectf_profiler_ctx> {
const char* counter_name = nullptr;
ret = rocprofiler_query_counter_info(session_id, ROCPROFILER_COUNTER_NAME,
counter.counter_handler, &counter_name);
counter.counter_handler, &counter_name);
assert(ret == ROCPROFILER_STATUS_SUCCESS && "Query counter name");
if (!counter_name) {
@@ -713,33 +631,25 @@ Plugin::Plugin(const std::size_t packet_size, const fs::path& trace_dir,
void Plugin::HandleTracerRecord(const rocprofiler_record_tracer_t& record,
const rocprofiler_session_id_t session_id,
rocprofiler_plugin_trace_record_data_t tracer_data,
const void* data) {
rocprofiler_plugin_tracer_extra_data_t tracer_data) {
std::lock_guard<std::mutex> lock{lock_};
// Depending on the domain, create and add an event record to the
// corresponding tracer.
switch (record.domain) {
case ACTIVITY_DOMAIN_ROCTX:
/* If the API data is nullptr then the call is asynchronous */
if (data == nullptr)
roctx_tracer_.AddEventRecord(std::make_shared<const RocTxEventRecord>(record, session_id));
else {
const char* roctx_message = reinterpret_cast<const char*>(data);
std::string roctx_msg(roctx_message);
roctx_tracer_.AddEventRecord(
std::make_shared<const RocTxEventRecord>(record, tracer_data.roctx_id, roctx_msg));
}
roctx_tracer_.AddEventRecord(std::make_shared<const RocTxEventRecord>(record, session_id));
break;
case ACTIVITY_DOMAIN_HSA_API: {
/* If the API data is nullptr then the call is asynchronous */
if (data == nullptr) {
if (record.api_data_handle.handle == nullptr) {
hsa_api_tracer_.AddEventRecord(
std::make_shared<const HsaApiEventRecordBegin>(record, session_id));
hsa_api_tracer_.AddEventRecord(
std::make_shared<const HsaApiEventRecordEnd>(record, session_id));
} else {
hsa_api_data_t hsa_api_data = *reinterpret_cast<const hsa_api_data_t*>(data);
hsa_api_data_t hsa_api_data =
*reinterpret_cast<const hsa_api_data_t*>(record.api_data_handle.handle);
hsa_api_tracer_.AddEventRecord(
std::make_shared<const HsaApiEventRecordBegin>(record, hsa_api_data));
hsa_api_tracer_.AddEventRecord(
@@ -749,14 +659,15 @@ void Plugin::HandleTracerRecord(const rocprofiler_record_tracer_t& record,
}
case ACTIVITY_DOMAIN_HIP_API: {
/* If the API data is nullptr then the call is asynchronous */
if (data == nullptr) {
if (record.api_data_handle.handle == nullptr) {
hip_api_tracer_.AddEventRecord(
std::make_shared<const HipApiEventRecordBegin>(record, session_id));
hip_api_tracer_.AddEventRecord(
std::make_shared<const HipApiEventRecordEnd>(record, session_id));
} else {
std::string kernel_name;
hip_api_data_t hip_api_data = *reinterpret_cast<const hip_api_data_t*>(data);
hip_api_data_t hip_api_data =
*reinterpret_cast<const hip_api_data_t*>(record.api_data_handle.handle);
if (tracer_data.kernel_name != nullptr)
kernel_name = rocmtools::cxx_demangle(std::string(tracer_data.kernel_name));
else
@@ -797,7 +708,7 @@ void Plugin::HandleBufferRecords(const rocprofiler_record_header_t* begin,
const rocprofiler_buffer_id_t buffer_id) {
while (begin && begin < end) {
if (begin->kind == ROCPROFILER_TRACER_RECORD) {
rocprofiler_plugin_trace_record_data_t tracer_data = {};
rocprofiler_plugin_tracer_extra_data_t tracer_data = {};
HandleTracerRecord(*reinterpret_cast<const rocprofiler_record_tracer_t*>(begin), session_id,
tracer_data);
} else {
+1 -2
Dosyayı Görüntüle
@@ -55,8 +55,7 @@ class Plugin final {
// Handles a tracer record.
void HandleTracerRecord(const rocprofiler_record_tracer_t& record,
rocprofiler_session_id_t session_id,
rocprofiler_plugin_trace_record_data_t tracer_data,
const void* data = nullptr);
rocprofiler_plugin_tracer_extra_data_t tracer_data);
// Handles a profiler record.
+53 -40
Dosyayı Görüntüle
@@ -140,72 +140,83 @@ std::optional<rocprofiler_plugin_t> plugin;
struct hsa_api_trace_entry_t {
std::atomic<uint32_t> valid;
rocprofiler_record_tracer_t record;
hsa_api_data_t* api_data;
const char* function_name;
hsa_api_trace_entry_t(rocprofiler_record_tracer_t tracer_record, const char* function_name_str,
const hsa_api_data_t& data)
const hsa_api_data_t* data)
: valid(rocprofiler::TRACE_ENTRY_INIT) {
record = tracer_record;
function_name = function_name_str ? strdup(function_name_str) : nullptr;
record.api_data_handle.handle = &data;
api_data = reinterpret_cast<hsa_api_data_t*>(malloc(sizeof(hsa_api_data_t)));
memcpy(api_data, data, sizeof(hsa_api_data_t));
record.api_data_handle.handle = api_data;
}
// Releases the heap storage owned by this entry: the duplicated function
// name string and the private copy of the HSA API data.
~hsa_api_trace_entry_t() {
  // free() is a no-op on a null pointer, so no explicit null checks are
  // needed before releasing either member.
  free(const_cast<char*>(function_name));
  free(api_data);
}
~hsa_api_trace_entry_t() { if (function_name != nullptr) free(const_cast<char*>(function_name)); }
};
struct roctx_trace_entry_t {
std::atomic<rocprofiler::TraceEntryState> valid;
const char* roctx_message;
rocprofiler_record_tracer_t record;
roctx_trace_entry_t(rocprofiler_record_tracer_t tracer_record, const char* roctx_message_str)
: valid(rocprofiler::TRACE_ENTRY_INIT) {
record = tracer_record;
roctx_message_str? record.api_data_handle.handle=strdup(roctx_message_str):nullptr;
roctx_message = roctx_message_str ? strdup(roctx_message_str) : nullptr;
record.api_data_handle.handle = roctx_message;
}
// Releases the duplicated rocTX message string owned by this entry.
~roctx_trace_entry_t() {
  // free() is a no-op on a null pointer, which covers entries created
  // without a message.
  free(const_cast<char*>(roctx_message));
}
~roctx_trace_entry_t() { }
};
struct hip_api_trace_entry_t {
std::atomic<uint32_t> valid;
hip_api_data_t* api_data;
rocprofiler_record_tracer_t record;
const char* function_name;
const char* kernel_name;
hip_api_trace_entry_t(rocprofiler_record_tracer_t tracer_record, const char* kernel_name_str,
const char* function_name_str, const hip_api_data_t& data)
const char* function_name_str, const hip_api_data_t* data)
: valid(rocprofiler::TRACE_ENTRY_INIT) {
record = tracer_record;
kernel_name = kernel_name_str ? strdup(kernel_name_str) : nullptr;
function_name = function_name_str ? strdup(function_name_str) : nullptr;
record.api_data_handle.handle=&data;
api_data = reinterpret_cast<hip_api_data_t*>(malloc(sizeof(hip_api_data_t)));
memcpy(api_data, data, sizeof(hip_api_data_t));
record.api_data_handle.handle = api_data;
}
~hip_api_trace_entry_t() {
if (function_name != nullptr) free(const_cast<char*>(function_name));
if (kernel_name != nullptr) free(const_cast<char*>(kernel_name));
if (function_name != nullptr) free(const_cast<char*>(function_name));
if (kernel_name != nullptr) free(const_cast<char*>(kernel_name));
if (api_data != nullptr) free(const_cast<hip_api_data_t*>(api_data));
}
};
rocprofiler::TraceBuffer<hip_api_trace_entry_t> hip_api_buffer(
"HIP API", 0x200000, [](hip_api_trace_entry_t* entry) {
assert(plugin && "plugin is not initialized");
rocprofiler_plugin_tracer_extra_data_t tracer_extra_data;
tracer_extra_data.function_name = entry->function_name;
tracer_extra_data.kernel_name = entry->kernel_name;
plugin->write_callback_record(entry->record, tracer_extra_data);
plugin->write_callback_record(
entry->record,
rocprofiler_plugin_tracer_extra_data_t{.function_name = entry->function_name,
.kernel_name = entry->kernel_name});
});
rocprofiler::TraceBuffer<hsa_api_trace_entry_t> hsa_api_buffer(
"HSA API", 0x200000, [](hsa_api_trace_entry_t* entry) {
assert(plugin && "plugin is not initialized");
rocprofiler_plugin_tracer_extra_data_t tracer_extra_data;
tracer_extra_data.function_name = entry->function_name;
plugin->write_callback_record(entry->record,
tracer_extra_data);
plugin->write_callback_record(
entry->record,
rocprofiler_plugin_tracer_extra_data_t{.function_name = entry->function_name});
});
rocprofiler::TraceBuffer<roctx_trace_entry_t> roctx_trace_buffer(
"rocTX API", 0x200000, [](roctx_trace_entry_t* entry) {
assert(plugin && "plugin is not initialized");
rocprofiler_plugin_tracer_extra_data_t tracer_extra_data;
tracer_extra_data.function_name = nullptr;
plugin->write_callback_record(
entry->record, tracer_extra_data);
plugin->write_callback_record(entry->record, rocprofiler_plugin_tracer_extra_data_t{});
});
} // namespace
@@ -376,13 +387,11 @@ void finish() {
if (session_created.load(std::memory_order_relaxed)) {
session_created.exchange(false, std::memory_order_release);
CHECK_ROCPROFILER(rocprofiler_terminate_session(session_id));
rocprofiler::TraceBufferBase::FlushAll();
for ([[maybe_unused]] rocprofiler_buffer_id_t buffer_id : buffer_ids) {
CHECK_ROCPROFILER(rocprofiler_flush_data(session_id, buffer_id));
}
}
// CHECK_ROCPROFILER(rocprofiler_destroy_session(session_id));
// CHECK_ROCPROFILER(rocprofiler_finalize());
}
// load plugins
@@ -427,12 +436,13 @@ void sync_api_trace_callback(rocprofiler_record_tracer_t tracer_record, rocprof
char* data = nullptr;
size_t size = 0;
CHECK_ROCPROFILER(rocprofiler_query_hip_tracer_api_data_info_size(
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle, tracer_record.operation_id, &size));
if(size > 0)
CHECK_ROCPROFILER(rocprofiler_query_hip_tracer_api_data_info(
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle, tracer_record.operation_id, &data));
hip_api_data_t hip_api_data = *reinterpret_cast<hip_api_data_t*>(data);
//std::cout << "in api calback" << strlen(function_name_c) << "\t" << function_name_c << std::endl;
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle,
tracer_record.operation_id, &size));
if (size > 0)
CHECK_ROCPROFILER(rocprofiler_query_hip_tracer_api_data_info(
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle,
tracer_record.operation_id, &data));
hip_api_data_t* hip_api_data = reinterpret_cast<hip_api_data_t*>(data);
hip_api_trace_entry_t& entry = hip_api_buffer.Emplace(
tracer_record,
(const char*)kernel_name_c ? strdup(kernel_name_c) : nullptr,
@@ -457,19 +467,22 @@ void sync_api_trace_callback(rocprofiler_record_tracer_t tracer_record, rocprof
CHECK_ROCPROFILER(rocprofiler_query_hip_tracer_api_data_info_size(
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle, tracer_record.operation_id, &size));
CHECK_ROCPROFILER(rocprofiler_query_hip_tracer_api_data_info(
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle, tracer_record.operation_id, &data));
hsa_api_data_t hsa_api_data = *reinterpret_cast<hsa_api_data_t*>(data);
hsa_api_trace_entry_t& entry = hsa_api_buffer.Emplace(
tracer_record,
(const char*)(function_name_c),
hsa_api_data);
session_id, ROCPROFILER_HIP_API_DATA, tracer_record.api_data_handle,
tracer_record.operation_id, &data));
hsa_api_data_t* hsa_api_data = reinterpret_cast<hsa_api_data_t*>(data);
hsa_api_trace_entry_t& entry =
hsa_api_buffer.Emplace(tracer_record, (const char*)(function_name_c), hsa_api_data);
entry.valid.store(rocprofiler::TRACE_ENTRY_COMPLETE, std::memory_order_release);
}
if (tracer_record.domain == ACTIVITY_DOMAIN_ROCTX) {
size_t roctx_message_size = 0;
char *roctx_message_str = nullptr;
uint64_t roctx_id=0;
CHECK_ROCPROFILER(rocprofiler_query_roctx_tracer_api_data_info_size(
size_t roctx_message_size = 0;
char* roctx_message_str = nullptr;
CHECK_ROCPROFILER(rocprofiler_query_roctx_tracer_api_data_info_size(
session_id, ROCPROFILER_ROCTX_MESSAGE, tracer_record.api_data_handle,
tracer_record.operation_id, &roctx_message_size));
if (roctx_message_size > 1) {
roctx_message_str = (char*)malloc(roctx_message_size * sizeof(char));
CHECK_ROCPROFILER(rocprofiler_query_roctx_tracer_api_data_info(
session_id, ROCPROFILER_ROCTX_MESSAGE, tracer_record.api_data_handle,
tracer_record.operation_id, &roctx_message_size));
if (roctx_message_size > 1) {