From b7661bccfd979add272b655edf09cd4dace94d8d Mon Sep 17 00:00:00 2001 From: "Baraldi, Giovanni" Date: Thu, 5 Dec 2024 05:33:53 +0100 Subject: [PATCH] SWDEV-489158: Adding consumer+producer model to AST evaluation (#13) * Rebased optimizations for rocprofv3 tool * Fixing merge conflicts * Formatting * Open from within mutex * Small name changes * Added operator * removed some parameters * Optimizing counter collection * Re-arrange code * Adding back dimension query * Formatting * Update source/lib/rocprofiler-sdk/thread_trace/att_core.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Formatting 2 * Fix for test compilation * Fix for yield * Adding back check for zero * Improved thread handling * Formatting * Remove automatic start * Adding test * Small fixes * Adding lock for buffer callbacks * Fix for race condition in AST * Adding check for ptr --------- Co-authored-by: Giovanni Baraldi Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../rocprofiler-sdk/counters/CMakeLists.txt | 23 ++- source/lib/rocprofiler-sdk/counters/core.cpp | 17 +- .../counters/dispatch_handlers.cpp | 109 +---------- .../counters/dispatch_handlers.hpp | 13 +- .../rocprofiler-sdk/counters/evaluate_ast.cpp | 20 +- .../counters/sample_consumer.hpp | 109 +++++++++++ .../counters/sample_processing.cpp | 184 ++++++++++++++++++ .../counters/sample_processing.hpp | 52 +++++ .../counters/tests/CMakeLists.txt | 18 ++ .../counters/tests/consumer_test.cpp | 116 +++++++++++ .../rocprofiler-sdk/counters/tests/core.cpp | 10 +- source/lib/rocprofiler-sdk/hsa/queue.cpp | 41 ++-- source/lib/rocprofiler-sdk/hsa/queue.hpp | 4 +- .../pc_sampling/hsa_adapter.cpp | 8 +- .../rocprofiler-sdk/thread_trace/att_core.cpp | 6 +- 15 files changed, 577 insertions(+), 153 deletions(-) create mode 100644 source/lib/rocprofiler-sdk/counters/sample_consumer.hpp create mode 100644 
source/lib/rocprofiler-sdk/counters/sample_processing.cpp create mode 100644 source/lib/rocprofiler-sdk/counters/sample_processing.hpp create mode 100644 source/lib/rocprofiler-sdk/counters/tests/consumer_test.cpp diff --git a/source/lib/rocprofiler-sdk/counters/CMakeLists.txt b/source/lib/rocprofiler-sdk/counters/CMakeLists.txt index cba0268dd7..ef8985e07e 100644 --- a/source/lib/rocprofiler-sdk/counters/CMakeLists.txt +++ b/source/lib/rocprofiler-sdk/counters/CMakeLists.txt @@ -1,9 +1,24 @@ set(ROCPROFILER_LIB_COUNTERS_SOURCES - metrics.cpp dimensions.cpp evaluate_ast.cpp core.cpp id_decode.cpp - dispatch_handlers.cpp controller.cpp device_counting.cpp) + metrics.cpp + dimensions.cpp + evaluate_ast.cpp + core.cpp + id_decode.cpp + dispatch_handlers.cpp + sample_processing.cpp + controller.cpp + device_counting.cpp) set(ROCPROFILER_LIB_COUNTERS_HEADERS - metrics.hpp dimensions.hpp evaluate_ast.hpp core.hpp id_decode.hpp - dispatch_handlers.hpp controller.hpp device_counting.hpp) + metrics.hpp + dimensions.hpp + evaluate_ast.hpp + core.hpp + id_decode.hpp + dispatch_handlers.hpp + sample_processing.hpp + controller.hpp + device_counting.hpp + sample_consumer.hpp) target_sources(rocprofiler-sdk-object-library PRIVATE ${ROCPROFILER_LIB_COUNTERS_SOURCES} ${ROCPROFILER_LIB_COUNTERS_HEADERS}) diff --git a/source/lib/rocprofiler-sdk/counters/core.cpp b/source/lib/rocprofiler-sdk/counters/core.cpp index 5df3629122..e451428a2b 100644 --- a/source/lib/rocprofiler-sdk/counters/core.cpp +++ b/source/lib/rocprofiler-sdk/counters/core.cpp @@ -28,6 +28,7 @@ #include "lib/rocprofiler-sdk/aql/packet_construct.hpp" #include "lib/rocprofiler-sdk/context/context.hpp" #include "lib/rocprofiler-sdk/counters/dispatch_handlers.hpp" +#include "lib/rocprofiler-sdk/counters/sample_processing.hpp" #include "lib/rocprofiler-sdk/hsa/queue_controller.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" @@ -157,6 +158,8 @@ start_context(const context::context* ctx) 
if(!already_enabled) { + callback_thread_start(); + for(auto& cb : ctx->counter_collection->callbacks) { // Insert our callbacks into HSA Interceptor. This @@ -182,12 +185,12 @@ start_context(const context::context* ctx) correlation_id); }, // Completion CB - [=](const hsa::Queue& q, - hsa::rocprofiler_packet kern_pkt, - const hsa::Queue::queue_info_session_t& session, - inst_pkt_t& aql, - kernel_dispatch::profiling_time dispatch_time) { - completed_cb(ctx, cb, q, kern_pkt, session, aql, dispatch_time); + [=](const hsa::Queue& /* q */, + hsa::rocprofiler_packet /* kern_pkt */, + std::shared_ptr& session, + inst_pkt_t& aql, + kernel_dispatch::profiling_time dispatch_time) { + completed_cb(ctx, cb, session, aql, dispatch_time); }); } } @@ -206,6 +209,8 @@ stop_context(const context::context* ctx) }); if(controller) controller->disable_serialization(); + + callback_thread_stop(); } rocprofiler_status_t diff --git a/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp b/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp index b92700b347..f83d4a9a37 100644 --- a/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp +++ b/source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp @@ -30,6 +30,7 @@ #include "lib/rocprofiler-sdk/buffer.hpp" #include "lib/rocprofiler-sdk/context/context.hpp" #include "lib/rocprofiler-sdk/counters/core.hpp" +#include "lib/rocprofiler-sdk/counters/sample_processing.hpp" #include "lib/rocprofiler-sdk/hsa/queue_controller.hpp" #include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" @@ -162,14 +163,13 @@ queue_cb(const context::context* ctx, * Callback called by HSA interceptor when the kernel has completed processing. 
*/ void -completed_cb(const context::context* ctx, - const std::shared_ptr& info, - const hsa::Queue& /*queue*/, - hsa::rocprofiler_packet /*packet*/, - const hsa::Queue::queue_info_session_t& session, - inst_pkt_t& pkts, - kernel_dispatch::profiling_time dispatch_time) +completed_cb(const context::context* ctx, + const std::shared_ptr& info, + std::shared_ptr& ptr_session, + inst_pkt_t& pkts, + kernel_dispatch::profiling_time dispatch_time) { + auto& session = *ptr_session; CHECK(info && ctx); std::shared_ptr prof_config; @@ -198,98 +198,9 @@ completed_cb(const context::context* ctx, // We have no profile config, nothing to output. if(!prof_config) return; - auto decoded_pkt = EvaluateAST::read_pkt(prof_config->pkt_generator.get(), *pkt); - EvaluateAST::read_special_counters( - *prof_config->agent, prof_config->required_special_counters, decoded_pkt); - - prof_config->packets.wlock([&](auto& pkt_vector) { - if(pkt) - { - pkt_vector.emplace_back(std::move(pkt)); - } - }); - - common::container::small_vector out; - rocprofiler::buffer::instance* buf = nullptr; - - if(info->buffer) - { - buf = CHECK_NOTNULL(buffer::get_buffer(info->buffer->handle)); - } - - auto _corr_id_v = - rocprofiler_correlation_id_t{.internal = 0, .external = context::null_user_data}; - if(const auto* _corr_id = session.correlation_id) - { - _corr_id_v.internal = _corr_id->internal; - if(const auto* external = rocprofiler::common::get_val( - session.tracing_data.external_correlation_ids, info->internal_context)) - { - _corr_id_v.external = *external; - } - } - - auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id; - for(auto& ast : prof_config->asts) - { - std::vector>> cache; - auto* ret = ast.evaluate(decoded_pkt, cache); - CHECK(ret); - ast.set_out_id(*ret); - - out.reserve(out.size() + ret->size()); - for(auto& val : *ret) - { - val.agent_id = prof_config->agent->id; - val.dispatch_id = _dispatch_id; - out.emplace_back(val); - } - } - - if(!out.empty()) - { - if(buf) - { - 
auto _header = - common::init_public_api_struct(rocprofiler_dispatch_counting_service_record_t{}); - _header.num_records = out.size(); - _header.correlation_id = _corr_id_v; - if(dispatch_time.status == HSA_STATUS_SUCCESS) - { - _header.start_timestamp = dispatch_time.start; - _header.end_timestamp = dispatch_time.end; - } - _header.dispatch_info = session.callback_record.dispatch_info; - buf->emplace(ROCPROFILER_BUFFER_CATEGORY_COUNTERS, - ROCPROFILER_COUNTER_RECORD_PROFILE_COUNTING_DISPATCH_HEADER, - _header); - - for(auto itr : out) - buf->emplace( - ROCPROFILER_BUFFER_CATEGORY_COUNTERS, ROCPROFILER_COUNTER_RECORD_VALUE, itr); - } - else - { - CHECK(info->record_callback); - - auto dispatch_data = - common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{}); - - dispatch_data.dispatch_info = session.callback_record.dispatch_info; - dispatch_data.correlation_id = _corr_id_v; - if(dispatch_time.status == HSA_STATUS_SUCCESS) - { - dispatch_data.start_timestamp = dispatch_time.start; - dispatch_data.end_timestamp = dispatch_time.end; - } - - info->record_callback(dispatch_data, - out.data(), - out.size(), - session.user_data, - info->record_callback_args); - } - } + completed_cb_params_t params{info, ptr_session, dispatch_time, prof_config, std::move(pkt)}; + process_callback_data(std::move(params)); } + } // namespace counters } // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp b/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp index 3ae2230494..f6a12c8c59 100644 --- a/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp +++ b/source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp @@ -46,12 +46,11 @@ queue_cb(const context::context* ctx, const context::correlation_id* correlation_id); void -completed_cb(const context::context* ctx, - const std::shared_ptr& info, - const hsa::Queue& queue, - hsa::rocprofiler_packet packet, - const hsa::Queue::queue_info_session_t& session, - 
inst_pkt_t& pkts, - kernel_dispatch::profiling_time dispatch_time); +completed_cb(const context::context* ctx, + const std::shared_ptr& info, + std::shared_ptr& session, + inst_pkt_t& pkts, + kernel_dispatch::profiling_time dispatch_time); + } // namespace counters } // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/evaluate_ast.cpp b/source/lib/rocprofiler-sdk/counters/evaluate_ast.cpp index 9adb5ba38f..495d098578 100644 --- a/source/lib/rocprofiler-sdk/counters/evaluate_ast.cpp +++ b/source/lib/rocprofiler-sdk/counters/evaluate_ast.cpp @@ -21,6 +21,8 @@ // SOFTWARE. #include "lib/rocprofiler-sdk/counters/evaluate_ast.hpp" +#include "lib/common/static_object.hpp" +#include "lib/common/synchronized.hpp" #include #include @@ -569,7 +571,9 @@ using property_function_t = int64_t (*)(const rocprofiler_agent_t&); int64_t get_agent_property(std::string_view property, const rocprofiler_agent_t& agent) { - static std::unordered_map props = { + using map_t = std::unordered_map; + + static auto*& _props = common::static_object>::construct(map_t{ GEN_MAP_ENTRY("cpu_cores_count", agent_info.cpu_cores_count), GEN_MAP_ENTRY("simd_count", agent_info.simd_count), GEN_MAP_ENTRY("mem_banks_count", agent_info.mem_banks_count), @@ -599,13 +603,15 @@ get_agent_property(std::string_view property, const rocprofiler_agent_t& agent) GEN_MAP_ENTRY("num_sdma_queues_per_engine", agent_info.num_sdma_queues_per_engine), GEN_MAP_ENTRY("num_cp_queues", agent_info.num_cp_queues), GEN_MAP_ENTRY("max_engine_clk_ccompute", agent_info.max_engine_clk_ccompute), - }; - if(const auto* func = rocprofiler::common::get_val(props, property)) - { - return (*func)(agent); - } + }); - return 0.0; + return CHECK_NOTNULL(_props)->wlock([&property, &agent](map_t& props) -> int64_t { + if(const auto* func = rocprofiler::common::get_val(props, property)) + { + return (*func)(agent); + } + return 0; + }); } void diff --git a/source/lib/rocprofiler-sdk/counters/sample_consumer.hpp 
b/source/lib/rocprofiler-sdk/counters/sample_consumer.hpp new file mode 100644 index 0000000000..486cc991a4 --- /dev/null +++ b/source/lib/rocprofiler-sdk/counters/sample_consumer.hpp @@ -0,0 +1,109 @@ +// MIT License +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#pragma once + +#include "lib/rocprofiler-sdk/counters/sample_processing.hpp" + +#include +#include +#include + +namespace rocprofiler +{ +namespace counters +{ +template +class consumer_thread_t +{ + static constexpr size_t SIZE = 128; + using consume_func_t = std::function; + +public: + consumer_thread_t(consume_func_t func) { this->consume_fn = func; } + virtual ~consumer_thread_t() { exit(); } + + void start() + { + { + std::unique_lock lk(mut); + if(valid.exchange(true)) return; + } + consumer = std::thread{&consumer_thread_t::consumer_loop, this}; + } + + void exit() + { + { + std::unique_lock lk(mut); + if(!valid.exchange(false)) return; + cv.notify_one(); + } + consumer.join(); + } + + void add(DataType&& params) + { + std::unique_lock lk(mut); + + if(read_ptr + buffer.size() <= write_ptr || !valid) + { + // Ring buffer full or consumer thread not running: process on the calling thread + consume_fn(std::move(params)); + return; + } + + buffer.at(write_ptr % buffer.size()) = std::move(params); + write_ptr.fetch_add(1); + cv.notify_one(); + } + +protected: + void consumer_loop() + { + while(true) + { + while(read_ptr == write_ptr) + { + std::unique_lock lk(mut); + cv.wait(lk, [&] { return read_ptr != write_ptr || !valid; }); + if(!valid && read_ptr == write_ptr) return; + } + + auto retrieved = std::move(buffer.at(read_ptr % buffer.size())); + read_ptr.fetch_add(1); + consume_fn(std::move(retrieved)); + } + } + + consume_func_t consume_fn; + std::atomic valid{false}; + std::mutex mut; + std::atomic write_ptr{0}; + std::atomic read_ptr{0}; + std::array buffer; + std::thread consumer; + std::condition_variable cv; +}; + +} // namespace counters +} // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/sample_processing.cpp b/source/lib/rocprofiler-sdk/counters/sample_processing.cpp new file mode 100644 index 0000000000..b8b717e318 --- /dev/null +++ b/source/lib/rocprofiler-sdk/counters/sample_processing.cpp @@ -0,0 +1,184 @@ + + +// MIT License +// +// 
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "lib/rocprofiler-sdk/counters/sample_processing.hpp" + +#include "lib/common/container/small_vector.hpp" +#include "lib/common/synchronized.hpp" +#include "lib/common/utility.hpp" +#include "lib/rocprofiler-sdk/buffer.hpp" +#include "lib/rocprofiler-sdk/context/context.hpp" +#include "lib/rocprofiler-sdk/counters/core.hpp" +#include "lib/rocprofiler-sdk/counters/sample_consumer.hpp" +#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp" +#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" + +#include +#include + +namespace rocprofiler +{ +namespace counters +{ +std::mutex& +get_buffer_mut() +{ + static auto*& mut = common::static_object::construct(); + return *CHECK_NOTNULL(mut); +} + +/** + * Callback called by HSA interceptor when the kernel has completed processing. 
+ */ +void +proccess_completed_cb(completed_cb_params_t&& params) +{ + auto& info = params.info; + auto& session = *params.session; + auto& dispatch_time = params.dispatch_time; + auto& prof_config = params.prof_config; + auto& pkt = params.pkt; + + ROCP_FATAL_IF(pkt == nullptr) << "AQL packet is a nullptr!"; + + auto decoded_pkt = EvaluateAST::read_pkt(prof_config->pkt_generator.get(), *pkt); + EvaluateAST::read_special_counters( + *prof_config->agent, prof_config->required_special_counters, decoded_pkt); + + prof_config->packets.wlock([&](auto& pkt_vector) { pkt_vector.emplace_back(std::move(pkt)); }); + + common::container::small_vector out; + rocprofiler::buffer::instance* buf = nullptr; + + if(info->buffer) + { + buf = CHECK_NOTNULL(buffer::get_buffer(info->buffer->handle)); + } + + auto _corr_id_v = + rocprofiler_correlation_id_t{.internal = 0, .external = context::null_user_data}; + if(const auto* _corr_id = session.correlation_id) + { + _corr_id_v.internal = _corr_id->internal; + if(const auto* external = rocprofiler::common::get_val( + session.tracing_data.external_correlation_ids, info->internal_context)) + { + _corr_id_v.external = *external; + } + } + + auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id; + for(auto& ast : prof_config->asts) + { + std::vector>> cache; + auto* ret = ast.evaluate(decoded_pkt, cache); + CHECK(ret); + ast.set_out_id(*ret); + + out.reserve(out.size() + ret->size()); + for(auto& val : *ret) + { + val.agent_id = prof_config->agent->id; + val.dispatch_id = _dispatch_id; + out.emplace_back(val); + } + } + + if(!out.empty()) + { + if(buf) + { + auto _header = + common::init_public_api_struct(rocprofiler_dispatch_counting_service_record_t{}); + _header.num_records = out.size(); + _header.correlation_id = _corr_id_v; + if(dispatch_time.status == HSA_STATUS_SUCCESS) + { + _header.start_timestamp = dispatch_time.start; + _header.end_timestamp = dispatch_time.end; + } + _header.dispatch_info = 
session.callback_record.dispatch_info; + + auto _lk = std::unique_lock{get_buffer_mut()}; // Buffer records need to be in order + + buf->emplace(ROCPROFILER_BUFFER_CATEGORY_COUNTERS, + ROCPROFILER_COUNTER_RECORD_PROFILE_COUNTING_DISPATCH_HEADER, + _header); + + for(auto itr : out) + buf->emplace( + ROCPROFILER_BUFFER_CATEGORY_COUNTERS, ROCPROFILER_COUNTER_RECORD_VALUE, itr); + } + else + { + CHECK(info->record_callback); + + auto dispatch_data = + common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{}); + + dispatch_data.dispatch_info = session.callback_record.dispatch_info; + dispatch_data.correlation_id = _corr_id_v; + if(dispatch_time.status == HSA_STATUS_SUCCESS) + { + dispatch_data.start_timestamp = dispatch_time.start; + dispatch_data.end_timestamp = dispatch_time.end; + } + + info->record_callback(dispatch_data, + out.data(), + out.size(), + session.user_data, + info->record_callback_args); + } + } +} + +auto& +callback_thread_get() +{ + using consumer_t = consumer_thread_t; + static auto*& _v = common::static_object::construct(proccess_completed_cb); + return *CHECK_NOTNULL(_v); +} + +void +callback_thread_start() +{ + callback_thread_get().start(); +} + +void +callback_thread_stop() +{ + callback_thread_get().exit(); +} + +void +process_callback_data(completed_cb_params_t&& params) +{ + callback_thread_get().add(std::move(params)); +} + +} // namespace counters +} // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/sample_processing.hpp b/source/lib/rocprofiler-sdk/counters/sample_processing.hpp new file mode 100644 index 0000000000..bda7c54fde --- /dev/null +++ b/source/lib/rocprofiler-sdk/counters/sample_processing.hpp @@ -0,0 +1,52 @@ +// MIT License +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#pragma once + +#include "lib/rocprofiler-sdk/context/context.hpp" +#include "lib/rocprofiler-sdk/hsa/aql_packet.hpp" +#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp" + +namespace rocprofiler +{ +namespace counters +{ +struct completed_cb_params_t +{ + std::shared_ptr info; + std::shared_ptr session; + kernel_dispatch::profiling_time dispatch_time; + std::shared_ptr prof_config; + std::unique_ptr pkt; +}; + +void +callback_thread_start(); + +void +callback_thread_stop(); + +void +process_callback_data(completed_cb_params_t&& params); + +} // namespace counters +} // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/tests/CMakeLists.txt b/source/lib/rocprofiler-sdk/counters/tests/CMakeLists.txt index 6c98d1574b..b9057f9e54 100644 --- a/source/lib/rocprofiler-sdk/counters/tests/CMakeLists.txt +++ b/source/lib/rocprofiler-sdk/counters/tests/CMakeLists.txt @@ -84,3 +84,21 @@ gtest_add_tests( WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) set_tests_properties(${counter-tests_TESTS} PROPERTIES TIMEOUT 45 LABELS "unittests") + +set(ROCPROFILER_LIB_CONSUMER_TEST_SOURCES consumer_test.cpp) + +add_executable(consumer-test) +target_sources(consumer-test PRIVATE ${ROCPROFILER_LIB_CONSUMER_TEST_SOURCES}) + +target_link_libraries( + consumer-test rocprofiler-sdk::rocprofiler-sdk-hsa-runtime + rocprofiler-sdk::rocprofiler-sdk-hip rocprofiler-sdk::rocprofiler-sdk-common-library + rocprofiler-sdk::rocprofiler-sdk-static-library GTest::gtest GTest::gtest_main) + +gtest_add_tests( + TARGET consumer-test + SOURCES ${ROCPROFILER_LIB_CONSUMER_TEST_SOURCES} + TEST_LIST consumer-tests_TESTS + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + +set_tests_properties(${consumer-tests_TESTS} PROPERTIES TIMEOUT 45 LABELS "unittests") diff --git a/source/lib/rocprofiler-sdk/counters/tests/consumer_test.cpp b/source/lib/rocprofiler-sdk/counters/tests/consumer_test.cpp new file mode 100644 index 0000000000..c12513dd55 --- /dev/null +++ 
b/source/lib/rocprofiler-sdk/counters/tests/consumer_test.cpp @@ -0,0 +1,116 @@ +// MIT License +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "lib/rocprofiler-sdk/counters/sample_consumer.hpp" + +namespace rocprofiler +{ +namespace counters +{ +constexpr size_t NUM_THREADS = 5; +constexpr size_t NUM_ELEMENTS = 1ul << 17; +using result_array_t = std::array, NUM_ELEMENTS>; +using result_array_ptr_t = std::shared_ptr; + +struct DummyData +{ + size_t index; + size_t increment; + result_array_ptr_t array; +}; + +using consumer_t = consumer_thread_t; + +void +consume_fn(DummyData&& data) +{ + data.array->at(data.index).fetch_add(data.increment); +} + +TEST(consumer, nothread) +{ + auto array = std::make_shared(); + + consumer_t consumer(consume_fn); + consumer.add(DummyData{1, 1, array}); + + EXPECT_EQ(array->at(0).load(), 0); + EXPECT_EQ(array->at(1).load(), 1); +} + +TEST(consumer, singlethread) +{ + auto array = std::make_shared(); + + { + consumer_t consumer(consume_fn); + consumer.start(); + + for(size_t i = 0; i < NUM_ELEMENTS; i++) + consumer.add(DummyData{i, 1, array}); + } + + for(auto& var : *array) + EXPECT_EQ(var.load(), 1); +} + +TEST(consumer, multithreaded) +{ + auto array = std::make_shared(); + consumer_t consumer(consume_fn); + + auto produce_fn = [&](size_t tid) { + for(size_t i = 0; i < NUM_ELEMENTS; i++) + consumer.add(DummyData{i, tid, array}); + }; + + { + std::vector> threads{}; + for(size_t i = 0; i < NUM_THREADS; i++) + threads.push_back(std::async(std::launch::async, produce_fn, i + 1)); + + consumer.start(); + } + + consumer.exit(); + + size_t expected = NUM_THREADS * (NUM_THREADS + 1) / 2; + + for(auto& var : *array) + EXPECT_EQ(var.load(), expected); +} + +} // namespace counters +} // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/counters/tests/core.cpp b/source/lib/rocprofiler-sdk/counters/tests/core.cpp index 31c1b52375..a312df2dd9 100644 --- a/source/lib/rocprofiler-sdk/counters/tests/core.cpp +++ 
b/source/lib/rocprofiler-sdk/counters/tests/core.cpp @@ -459,14 +459,16 @@ TEST(core, check_callbacks) &opt_buff_id), "Could not create buffer"); cb_info->buffer = opt_buff_id; - // hsa::Queue::queue_info_session_t sess = {.queue = fq, .correlation_id = &corr_id}; - hsa::Queue::queue_info_session_t sess = hsa::Queue::queue_info_session_t{.queue = fq}; - sess.correlation_id = &corr_id; + + auto _sess = hsa::Queue::queue_info_session_t{.queue = fq}; + _sess.correlation_id = &corr_id; + + auto sess = std::make_shared(std::move(_sess)); counters::inst_pkt_t pkts; pkts.emplace_back( std::make_pair(std::move(ret_pkt), static_cast(0))); - completed_cb(&ctx, cb_info, fq, pkt, sess, pkts, kernel_dispatch::profiling_time{}); + completed_cb(&ctx, cb_info, sess, pkts, kernel_dispatch::profiling_time{}); rocprofiler_flush_buffer(opt_buff_id); rocprofiler_destroy_buffer(opt_buff_id); } diff --git a/source/lib/rocprofiler-sdk/hsa/queue.cpp b/source/lib/rocprofiler-sdk/hsa/queue.cpp index b776177f92..595b8cd93d 100644 --- a/source/lib/rocprofiler-sdk/hsa/queue.cpp +++ b/source/lib/rocprofiler-sdk/hsa/queue.cpp @@ -43,6 +43,7 @@ #include #include +#include // static assert for rocprofiler_packet ABI compatibility static_assert(sizeof(hsa_ext_amd_aql_pm4_packet_t) == sizeof(hsa_kernel_dispatch_packet_t), @@ -116,8 +117,10 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data) get_balanced_signal_slots().fetch_add(1); - auto& queue_info_session = *static_cast(data); - auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session); + auto& shared_ptr_info = *static_cast*>(data); + auto& queue_info_session = *shared_ptr_info; + + auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session); kernel_dispatch::dispatch_complete(queue_info_session, dispatch_time); @@ -128,7 +131,7 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data) { cb_pair.second(queue_info_session.queue, queue_info_session.kernel_pkt, - queue_info_session, + 
shared_ptr_info, queue_info_session.inst_pkt, dispatch_time); } @@ -164,7 +167,7 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data) } queue_info_session.queue.async_complete(); - delete static_cast(data); + delete &shared_ptr_info; return false; } @@ -446,20 +449,24 @@ WriteInterceptor(const void* packets, // Enqueue the signal into the handler. Will call completed_cb when // signal completes. - queue.signal_async_handler( - completion_signal, - new Queue::queue_info_session_t{.queue = queue, - .inst_pkt = std::move(inst_pkt), - .interrupt_signal = interrupt_signal, - .tid = thr_id, - .enqueue_ts = common::timestamp_ns(), - .user_data = user_data, - .correlation_id = corr_id, - .kernel_pkt = kernel_pkt, - .callback_record = callback_record, - .tracing_data = tracing_data_v}); { + Queue::queue_info_session_t info_session{.queue = queue, + .inst_pkt = std::move(inst_pkt), + .interrupt_signal = interrupt_signal, + .tid = thr_id, + .enqueue_ts = common::timestamp_ns(), + .user_data = user_data, + .correlation_id = corr_id, + .kernel_pkt = kernel_pkt, + .callback_record = callback_record, + .tracing_data = tracing_data_v}; + + auto shared = std::make_shared(std::move(info_session)); + + queue.signal_async_handler(completion_signal, + new std::shared_ptr(shared)); + auto tracer_data = callback_record; tracing::execute_phase_exit_callbacks(tracing_data_v.callback_contexts, tracing_data_v.external_correlation_ids, @@ -561,7 +568,7 @@ Queue::~Queue() } void -Queue::signal_async_handler(const hsa_signal_t& signal, Queue::queue_info_session_t* data) const +Queue::signal_async_handler(const hsa_signal_t& signal, void* data) const { #if !defined(NDEBUG) CHECK_NOTNULL(hsa::get_queue_controller())->_debug_signals.wlock([&](auto& signals) { diff --git a/source/lib/rocprofiler-sdk/hsa/queue.hpp b/source/lib/rocprofiler-sdk/hsa/queue.hpp index 0a432554ea..94c04f8ca8 100644 --- a/source/lib/rocprofiler-sdk/hsa/queue.hpp +++ b/source/lib/rocprofiler-sdk/hsa/queue.hpp 
@@ -87,7 +87,7 @@ public: // Signals the completion of the kernel packet. using completed_cb_t = std::function&, inst_pkt_t&, kernel_dispatch::profiling_time)>; using callback_map_t = std::unordered_map>; @@ -109,7 +109,7 @@ public: virtual const AgentCache& get_agent() const { return _agent; } void create_signal(uint32_t attribute, hsa_signal_t* signal) const; - void signal_async_handler(const hsa_signal_t& signal, Queue::queue_info_session_t* data) const; + void signal_async_handler(const hsa_signal_t& signal, void* data) const; template void signal_callback(FuncT&& func) const; diff --git a/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp b/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp index 6b8f3a04e4..815402a0f9 100644 --- a/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp +++ b/source/lib/rocprofiler-sdk/pc_sampling/hsa_adapter.cpp @@ -347,12 +347,12 @@ pc_sampling_service_finish_configuration(context::pc_sampling_service* service) const rocprofiler::hsa::Queue::queue_info_session_t::external_corr_id_map_t&, const context::correlation_id*) { return nullptr; }, // Completion CB - [](const rocprofiler::hsa::Queue& q, - rocprofiler::hsa::rocprofiler_packet kern_pkt, - const rocprofiler::hsa::Queue::queue_info_session_t& session, + [](const rocprofiler::hsa::Queue& q, + rocprofiler::hsa::rocprofiler_packet kern_pkt, + std::shared_ptr& session, rocprofiler::hsa::inst_pkt_t&, kernel_dispatch::profiling_time) { - kernel_completion_cb(q.get_agent().get_rocp_agent(), kern_pkt, session); + kernel_completion_cb(q.get_agent().get_rocp_agent(), kern_pkt, *session); }); } diff --git a/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp b/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp index a770a8e867..ae000d2764 100644 --- a/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp +++ b/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp @@ -437,9 +437,9 @@ DispatchThreadTracer::start_context() }, [=](const hsa::Queue& /* q */, 
hsa::rocprofiler_packet /* kern_pkt */, - const hsa::Queue::queue_info_session_t& session, - inst_pkt_t& aql, - kernel_dispatch::profiling_time) { this->post_kernel_call(aql, session); }); + std::shared_ptr& session, + inst_pkt_t& aql, + kernel_dispatch::profiling_time) { this->post_kernel_call(aql, *session); }); }); }