SWDEV-489158: Adding consumer+producer model to AST evaluation (#13)

* Rebased optimizations for rocprofv3 tool

* Fixing merge conflicts

* Formatting

* Open from within mutex

* Small name changes

* Added operator

* removed some parameters

* Optimizing counter collection

* Re-arrange code

* Adding back dimension query

* Formatting

* Update source/lib/rocprofiler-sdk/thread_trace/att_core.cpp

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Formatting 2

* Fix for test compilation

* Fix for yield

* Adding back check for zero

* Improved thread handling

* Formatting

* Remove automatic start

* Adding test

* Small fixes

* Adding lock for buffer callbacks

* Fix for race condition in AST

* Adding check for ptr

---------

Co-authored-by: Giovanni Baraldi <gbaraldi@amd.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Baraldi, Giovanni
2024-12-05 05:33:53 +01:00
committed by GitHub
parent c42bdc3128
commit b7661bccfd
15 changed files with 577 additions and 153 deletions
@@ -1,9 +1,24 @@
set(ROCPROFILER_LIB_COUNTERS_SOURCES
metrics.cpp dimensions.cpp evaluate_ast.cpp core.cpp id_decode.cpp
dispatch_handlers.cpp controller.cpp device_counting.cpp)
metrics.cpp
dimensions.cpp
evaluate_ast.cpp
core.cpp
id_decode.cpp
dispatch_handlers.cpp
sample_processing.cpp
controller.cpp
device_counting.cpp)
set(ROCPROFILER_LIB_COUNTERS_HEADERS
metrics.hpp dimensions.hpp evaluate_ast.hpp core.hpp id_decode.hpp
dispatch_handlers.hpp controller.hpp device_counting.hpp)
metrics.hpp
dimensions.hpp
evaluate_ast.hpp
core.hpp
id_decode.hpp
dispatch_handlers.hpp
sample_processing.hpp
controller.hpp
device_counting.hpp
sample_consumer.hpp)
target_sources(rocprofiler-sdk-object-library PRIVATE ${ROCPROFILER_LIB_COUNTERS_SOURCES}
${ROCPROFILER_LIB_COUNTERS_HEADERS})
+11 -6
View File
@@ -28,6 +28,7 @@
#include "lib/rocprofiler-sdk/aql/packet_construct.hpp"
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/counters/dispatch_handlers.hpp"
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"
@@ -157,6 +158,8 @@ start_context(const context::context* ctx)
if(!already_enabled)
{
callback_thread_start();
for(auto& cb : ctx->counter_collection->callbacks)
{
// Insert our callbacks into HSA Interceptor. This
@@ -182,12 +185,12 @@ start_context(const context::context* ctx)
correlation_id);
},
// Completion CB
[=](const hsa::Queue& q,
hsa::rocprofiler_packet kern_pkt,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time dispatch_time) {
completed_cb(ctx, cb, q, kern_pkt, session, aql, dispatch_time);
[=](const hsa::Queue& /* q */,
hsa::rocprofiler_packet /* kern_pkt */,
std::shared_ptr<hsa::Queue::queue_info_session_t>& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time dispatch_time) {
completed_cb(ctx, cb, session, aql, dispatch_time);
});
}
}
@@ -206,6 +209,8 @@ stop_context(const context::context* ctx)
});
if(controller) controller->disable_serialization();
callback_thread_stop();
}
rocprofiler_status_t
@@ -30,6 +30,7 @@
#include "lib/rocprofiler-sdk/buffer.hpp"
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/counters/core.hpp"
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"
@@ -162,14 +163,13 @@ queue_cb(const context::context* ctx,
* Callback called by HSA interceptor when the kernel has completed processing.
*/
void
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
const hsa::Queue& /*queue*/,
hsa::rocprofiler_packet /*packet*/,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time)
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
std::shared_ptr<hsa::Queue::queue_info_session_t>& ptr_session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time)
{
auto& session = *ptr_session;
CHECK(info && ctx);
std::shared_ptr<profile_config> prof_config;
@@ -198,98 +198,9 @@ completed_cb(const context::context* ctx,
// We have no profile config, nothing to output.
if(!prof_config) return;
auto decoded_pkt = EvaluateAST::read_pkt(prof_config->pkt_generator.get(), *pkt);
EvaluateAST::read_special_counters(
*prof_config->agent, prof_config->required_special_counters, decoded_pkt);
prof_config->packets.wlock([&](auto& pkt_vector) {
if(pkt)
{
pkt_vector.emplace_back(std::move(pkt));
}
});
common::container::small_vector<rocprofiler_record_counter_t, 128> out;
rocprofiler::buffer::instance* buf = nullptr;
if(info->buffer)
{
buf = CHECK_NOTNULL(buffer::get_buffer(info->buffer->handle));
}
auto _corr_id_v =
rocprofiler_correlation_id_t{.internal = 0, .external = context::null_user_data};
if(const auto* _corr_id = session.correlation_id)
{
_corr_id_v.internal = _corr_id->internal;
if(const auto* external = rocprofiler::common::get_val(
session.tracing_data.external_correlation_ids, info->internal_context))
{
_corr_id_v.external = *external;
}
}
auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id;
for(auto& ast : prof_config->asts)
{
std::vector<std::unique_ptr<std::vector<rocprofiler_record_counter_t>>> cache;
auto* ret = ast.evaluate(decoded_pkt, cache);
CHECK(ret);
ast.set_out_id(*ret);
out.reserve(out.size() + ret->size());
for(auto& val : *ret)
{
val.agent_id = prof_config->agent->id;
val.dispatch_id = _dispatch_id;
out.emplace_back(val);
}
}
if(!out.empty())
{
if(buf)
{
auto _header =
common::init_public_api_struct(rocprofiler_dispatch_counting_service_record_t{});
_header.num_records = out.size();
_header.correlation_id = _corr_id_v;
if(dispatch_time.status == HSA_STATUS_SUCCESS)
{
_header.start_timestamp = dispatch_time.start;
_header.end_timestamp = dispatch_time.end;
}
_header.dispatch_info = session.callback_record.dispatch_info;
buf->emplace(ROCPROFILER_BUFFER_CATEGORY_COUNTERS,
ROCPROFILER_COUNTER_RECORD_PROFILE_COUNTING_DISPATCH_HEADER,
_header);
for(auto itr : out)
buf->emplace(
ROCPROFILER_BUFFER_CATEGORY_COUNTERS, ROCPROFILER_COUNTER_RECORD_VALUE, itr);
}
else
{
CHECK(info->record_callback);
auto dispatch_data =
common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{});
dispatch_data.dispatch_info = session.callback_record.dispatch_info;
dispatch_data.correlation_id = _corr_id_v;
if(dispatch_time.status == HSA_STATUS_SUCCESS)
{
dispatch_data.start_timestamp = dispatch_time.start;
dispatch_data.end_timestamp = dispatch_time.end;
}
info->record_callback(dispatch_data,
out.data(),
out.size(),
session.user_data,
info->record_callback_args);
}
}
completed_cb_params_t params{info, ptr_session, dispatch_time, prof_config, std::move(pkt)};
process_callback_data(std::move(params));
}
} // namespace counters
} // namespace rocprofiler
@@ -46,12 +46,11 @@ queue_cb(const context::context* ctx,
const context::correlation_id* correlation_id);
void
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
const hsa::Queue& queue,
hsa::rocprofiler_packet packet,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time);
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
std::shared_ptr<hsa::Queue::queue_info_session_t>& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time);
} // namespace counters
} // namespace rocprofiler
@@ -21,6 +21,8 @@
// SOFTWARE.
#include "lib/rocprofiler-sdk/counters/evaluate_ast.hpp"
#include "lib/common/static_object.hpp"
#include "lib/common/synchronized.hpp"
#include <algorithm>
#include <cstdint>
@@ -569,7 +571,9 @@ using property_function_t = int64_t (*)(const rocprofiler_agent_t&);
int64_t
get_agent_property(std::string_view property, const rocprofiler_agent_t& agent)
{
static std::unordered_map<std::string_view, property_function_t> props = {
using map_t = std::unordered_map<std::string_view, property_function_t>;
static auto*& _props = common::static_object<common::Synchronized<map_t>>::construct(map_t{
GEN_MAP_ENTRY("cpu_cores_count", agent_info.cpu_cores_count),
GEN_MAP_ENTRY("simd_count", agent_info.simd_count),
GEN_MAP_ENTRY("mem_banks_count", agent_info.mem_banks_count),
@@ -599,13 +603,15 @@ get_agent_property(std::string_view property, const rocprofiler_agent_t& agent)
GEN_MAP_ENTRY("num_sdma_queues_per_engine", agent_info.num_sdma_queues_per_engine),
GEN_MAP_ENTRY("num_cp_queues", agent_info.num_cp_queues),
GEN_MAP_ENTRY("max_engine_clk_ccompute", agent_info.max_engine_clk_ccompute),
};
if(const auto* func = rocprofiler::common::get_val(props, property))
{
return (*func)(agent);
}
});
return 0.0;
return CHECK_NOTNULL(_props)->wlock([&property, &agent](map_t& props) -> int64_t {
if(const auto* func = rocprofiler::common::get_val(props, property))
{
return (*func)(agent);
}
return 0;
});
}
void
@@ -0,0 +1,109 @@
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#pragma once
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include <array>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>
namespace rocprofiler
{
namespace counters
{
/**
 * @brief Work queue with a single consumer thread and a fixed-size ring buffer.
 *
 * Producers hand items to add(). While the consumer thread is running and the
 * ring has a free slot, the item is queued and later passed to the consume
 * function on the consumer thread. Otherwise (consumer not started, or ring
 * full) the consume function is invoked synchronously on the calling thread,
 * so items are never dropped. As a consequence the consume function can run on
 * a producer thread while the consumer thread is also executing it.
 *
 * start() and exit() may be called from any thread; they are serialized with
 * respect to each other.
 *
 * @tparam DataType Payload type; must be default-constructible (ring storage)
 *                  and move-assignable.
 */
template <typename DataType>
class consumer_thread_t
{
    static constexpr size_t SIZE = 128;
    using consume_func_t         = std::function<void(DataType&&)>;

public:
    /// @param func Function invoked once for every item passed to add().
    consumer_thread_t(consume_func_t func)
    : consume_fn{std::move(func)}
    {}

    /// Stops the consumer thread (after it drained the ring) if it is running.
    virtual ~consumer_thread_t() { exit(); }

    /// Starts the consumer thread. Does nothing if it is already running.
    void start()
    {
        // Serializes start()/exit() so the thread handle is never assigned while
        // another caller is joining it (move-assigning onto a joinable
        // std::thread calls std::terminate).
        std::lock_guard<std::mutex>  lifecycle{lifecycle_mut};
        std::unique_lock<std::mutex> lk(mut);
        if(valid.load()) return;

        // `mut` is still held: the new thread needs it before it can inspect
        // `valid`, so it cannot observe the flag as false and exit prematurely.
        consumer = std::thread{&consumer_thread_t::consumer_loop, this};
        valid.store(true);
    }

    /// Stops the consumer thread and waits for it to process everything that is
    /// still queued. Does nothing if the thread is not running.
    void exit()
    {
        std::lock_guard<std::mutex> lifecycle{lifecycle_mut};
        {
            std::unique_lock<std::mutex> lk(mut);
            if(!valid.exchange(false)) return;
            cv.notify_one();
        }
        // Join without holding `mut`: the consumer needs it to wake up and drain.
        if(consumer.joinable()) consumer.join();
    }

    /// Queues @p params for the consumer thread, or processes it right away on
    /// the calling thread when the consumer is not running or the ring is full.
    void add(DataType&& params)
    {
        std::unique_lock<std::mutex> lk(mut);
        if(read_ptr + buffer.size() <= write_ptr || !valid)
        {
            // If not possible to use consumer thread, process with this thread
            consume_fn(std::move(params));
            return;
        }

        buffer.at(write_ptr % buffer.size()) = std::move(params);
        write_ptr.fetch_add(1);
        cv.notify_one();
    }

protected:
    /// Body of the consumer thread: pops items in FIFO order until exit() was
    /// requested and the ring is empty.
    void consumer_loop()
    {
        while(true)
        {
            while(read_ptr == write_ptr)
            {
                std::unique_lock<std::mutex> lk(mut);
                cv.wait(lk, [&] { return read_ptr != write_ptr || !valid; });
                if(!valid && read_ptr == write_ptr) return;
            }

            // The slot is read without the lock: producers only reuse a slot
            // after read_ptr has moved past it.
            auto retrieved = std::move(buffer.at(read_ptr % buffer.size()));
            read_ptr.fetch_add(1);
            consume_fn(std::move(retrieved));
        }
    }

    consume_func_t             consume_fn;
    std::atomic<bool>          valid{false};
    std::mutex                 mut;            // guards ring writes and `valid` transitions
    std::mutex                 lifecycle_mut;  // serializes start() and exit()
    std::atomic<size_t>        write_ptr{0};
    std::atomic<size_t>        read_ptr{0};
    std::array<DataType, SIZE> buffer;
    std::thread                consumer;
    std::condition_variable    cv;
};
} // namespace counters
} // namespace rocprofiler
@@ -0,0 +1,184 @@
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include "lib/common/container/small_vector.hpp"
#include "lib/common/synchronized.hpp"
#include "lib/common/utility.hpp"
#include "lib/rocprofiler-sdk/buffer.hpp"
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/counters/core.hpp"
#include "lib/rocprofiler-sdk/counters/sample_consumer.hpp"
#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"
#include <rocprofiler-sdk/fwd.h>
#include <rocprofiler-sdk/rocprofiler.h>
namespace rocprofiler
{
namespace counters
{
// Returns the mutex used to keep a dispatch header and its value records
// contiguous when they are written into a buffer.
std::mutex&
get_buffer_mut()
{
    static auto*& buffer_mutex = common::static_object<std::mutex>::construct();
    return *CHECK_NOTNULL(buffer_mutex);
}
/**
 * Processes one completed dispatch sample: decodes the counter packet, evaluates
 * every AST of the profile config and emits the resulting records to the
 * configured buffer or, when no buffer is set, to the record callback.
 *
 * This is the consume function of the sample consumer, so it can run on the
 * consumer thread or directly on the thread that submitted the sample.
 *
 * @param params Sample to process. The AQL packet it owns is moved into
 *               prof_config->packets once it has been decoded.
 */
void
proccess_completed_cb(completed_cb_params_t&& params)
{
    auto& info          = params.info;
    auto& session       = *params.session;
    auto& dispatch_time = params.dispatch_time;
    auto& prof_config   = params.prof_config;
    auto& pkt           = params.pkt;

    ROCP_FATAL_IF(pkt == nullptr) << "AQL packet is a nullptr!";

    auto decoded_pkt = EvaluateAST::read_pkt(prof_config->pkt_generator.get(), *pkt);
    EvaluateAST::read_special_counters(
        *prof_config->agent, prof_config->required_special_counters, decoded_pkt);

    // Decoding is done; hand the packet over to the config's packet list.
    prof_config->packets.wlock([&](auto& pkt_vector) { pkt_vector.emplace_back(std::move(pkt)); });

    common::container::small_vector<rocprofiler_record_counter_t, 128> out;
    rocprofiler::buffer::instance*                                     buf = nullptr;

    if(info->buffer)
    {
        buf = CHECK_NOTNULL(buffer::get_buffer(info->buffer->handle));
    }

    auto _corr_id_v =
        rocprofiler_correlation_id_t{.internal = 0, .external = context::null_user_data};
    if(const auto* _corr_id = session.correlation_id)
    {
        _corr_id_v.internal = _corr_id->internal;
        if(const auto* external = rocprofiler::common::get_val(
               session.tracing_data.external_correlation_ids, info->internal_context))
        {
            _corr_id_v.external = *external;
        }
    }

    auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id;
    for(auto& ast : prof_config->asts)
    {
        std::vector<std::unique_ptr<std::vector<rocprofiler_record_counter_t>>> cache;
        auto* ret = ast.evaluate(decoded_pkt, cache);
        CHECK(ret);
        ast.set_out_id(*ret);

        out.reserve(out.size() + ret->size());
        for(auto& val : *ret)
        {
            val.agent_id    = prof_config->agent->id;
            val.dispatch_id = _dispatch_id;
            out.emplace_back(val);
        }
    }

    // Nothing was produced, nothing to report.
    if(out.empty()) return;

    if(buf)
    {
        auto _header =
            common::init_public_api_struct(rocprofiler_dispatch_counting_service_record_t{});
        _header.num_records    = out.size();
        _header.correlation_id = _corr_id_v;
        if(dispatch_time.status == HSA_STATUS_SUCCESS)
        {
            _header.start_timestamp = dispatch_time.start;
            _header.end_timestamp   = dispatch_time.end;
        }
        _header.dispatch_info = session.callback_record.dispatch_info;

        auto _lk = std::unique_lock{get_buffer_mut()};  // Buffer records need to be in order
        buf->emplace(ROCPROFILER_BUFFER_CATEGORY_COUNTERS,
                     ROCPROFILER_COUNTER_RECORD_PROFILE_COUNTING_DISPATCH_HEADER,
                     _header);

        // Iterate by reference: the records are only read here, copying each one
        // before handing it to the buffer is wasted work.
        for(auto& itr : out)
            buf->emplace(
                ROCPROFILER_BUFFER_CATEGORY_COUNTERS, ROCPROFILER_COUNTER_RECORD_VALUE, itr);
    }
    else
    {
        CHECK(info->record_callback);

        auto dispatch_data =
            common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{});

        dispatch_data.dispatch_info  = session.callback_record.dispatch_info;
        dispatch_data.correlation_id = _corr_id_v;
        if(dispatch_time.status == HSA_STATUS_SUCCESS)
        {
            dispatch_data.start_timestamp = dispatch_time.start;
            dispatch_data.end_timestamp   = dispatch_time.end;
        }

        info->record_callback(dispatch_data,
                              out.data(),
                              out.size(),
                              session.user_data,
                              info->record_callback_args);
    }
}
auto&
callback_thread_get()
{
using consumer_t = consumer_thread_t<completed_cb_params_t>;
static auto*& _v = common::static_object<consumer_t>::construct(proccess_completed_cb);
return *CHECK_NOTNULL(_v);
}
// Starts the shared sample consumer thread.
void
callback_thread_start()
{
    auto& worker = callback_thread_get();
    worker.start();
}
// Stops the shared sample consumer thread.
void
callback_thread_stop()
{
    auto& worker = callback_thread_get();
    worker.exit();
}
// Submits a completed-dispatch sample to the shared sample consumer.
void
process_callback_data(completed_cb_params_t&& params)
{
    auto& worker = callback_thread_get();
    worker.add(std::move(params));
}
} // namespace counters
} // namespace rocprofiler
@@ -0,0 +1,52 @@
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#pragma once
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/hsa/aql_packet.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"
namespace rocprofiler
{
namespace counters
{
// Everything needed to process one completed dispatch away from the completion
// callback. Instances are built with positional aggregate initialization, so
// the member order below must not change.
struct completed_cb_params_t
{
// Callback info of the counter collection; provides the buffer or record callback.
std::shared_ptr<counter_callback_info> info;
// Queue session of the dispatch, shared so it stays alive while the sample is pending.
std::shared_ptr<hsa::Queue::queue_info_session_t> session;
// Dispatch timing; start/end are only reported when its status is HSA_STATUS_SUCCESS.
kernel_dispatch::profiling_time dispatch_time;
// Profile configuration whose ASTs are evaluated for this dispatch.
std::shared_ptr<profile_config> prof_config;
// Counter AQL packet to decode; moved into prof_config->packets during processing.
std::unique_ptr<rocprofiler::hsa::AQLPacket> pkt;
};
// Starts the consumer thread that processes completed-dispatch samples.
void
callback_thread_start();
// Stops the consumer thread that processes completed-dispatch samples.
void
callback_thread_stop();
// Hands one completed-dispatch sample to the consumer for processing.
void
process_callback_data(completed_cb_params_t&& params);
} // namespace counters
} // namespace rocprofiler
@@ -84,3 +84,21 @@ gtest_add_tests(
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_tests_properties(${counter-tests_TESTS} PROPERTIES TIMEOUT 45 LABELS "unittests")
# Unit tests for the sample consumer thread (consumer_thread_t), built as a
# separate gtest executable.
set(ROCPROFILER_LIB_CONSUMER_TEST_SOURCES consumer_test.cpp)
add_executable(consumer-test)
target_sources(consumer-test PRIVATE ${ROCPROFILER_LIB_CONSUMER_TEST_SOURCES})
target_link_libraries(
consumer-test rocprofiler-sdk::rocprofiler-sdk-hsa-runtime
rocprofiler-sdk::rocprofiler-sdk-hip rocprofiler-sdk::rocprofiler-sdk-common-library
rocprofiler-sdk::rocprofiler-sdk-static-library GTest::gtest GTest::gtest_main)
# Register every TEST() found in the sources; the names go into consumer-tests_TESTS.
gtest_add_tests(
TARGET consumer-test
SOURCES ${ROCPROFILER_LIB_CONSUMER_TEST_SOURCES}
TEST_LIST consumer-tests_TESTS
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# Same timeout and label as the other unit tests in this file.
set_tests_properties(${consumer-tests_TESTS} PROPERTIES TIMEOUT 45 LABELS "unittests")
@@ -0,0 +1,116 @@
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include <algorithm>
#include <array>
#include <cstdint>
#include <future>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include <fmt/core.h>
#include <gtest/gtest.h>
#include "lib/rocprofiler-sdk/counters/sample_consumer.hpp"
namespace rocprofiler
{
namespace counters
{
// Number of concurrent producers in the multithreaded test.
constexpr size_t NUM_THREADS = 5;
// Number of items each producer submits; also the number of result slots.
constexpr size_t NUM_ELEMENTS = 1ul << 17;
// One atomic accumulator per item index, shared between the test and the consume function.
using result_array_t = std::array<std::atomic<size_t>, NUM_ELEMENTS>;
using result_array_ptr_t = std::shared_ptr<result_array_t>;
// Payload pushed through the consumer: adds `increment` to slot `index` of `array`.
struct DummyData
{
// Slot of the result array to update.
size_t index;
// Amount added to the slot.
size_t increment;
// Result array the update is applied to.
result_array_ptr_t array;
};
// Consumer type under test.
using consumer_t = consumer_thread_t<DummyData>;
// Consume function used by all tests: applies the increment carried by the
// item to its slot in the result array.
void
consume_fn(DummyData&& data)
{
    auto& slot = data.array->at(data.index);
    slot.fetch_add(data.increment);
}
// Without start(), add() must process the item on the calling thread.
TEST(consumer, nothread)
{
    auto results = std::make_shared<result_array_t>();

    consumer_t inline_consumer(consume_fn);
    inline_consumer.add(DummyData{1, 1, results});

    // Only slot 1 was touched, and it is already updated when add() returns.
    EXPECT_EQ(results->at(0).load(), 0);
    EXPECT_EQ(results->at(1).load(), 1);
}
// One producer with a running consumer: every item is processed exactly once.
TEST(consumer, singlethread)
{
    auto results = std::make_shared<result_array_t>();

    {
        // Scoped so that the destructor stops the consumer before the results
        // are inspected.
        consumer_t worker(consume_fn);
        worker.start();

        size_t idx = 0;
        while(idx < NUM_ELEMENTS)
        {
            worker.add(DummyData{idx, 1, results});
            idx++;
        }
    }

    for(auto& slot : *results)
        EXPECT_EQ(slot.load(), 1);
}
// Several producers submit concurrently while the consumer is started midway;
// every slot must end up with the sum of all producer ids.
TEST(consumer, multithreaded)
{
    auto results = std::make_shared<result_array_t>();

    consumer_t worker(consume_fn);

    // Each producer adds its own id to every slot.
    auto enqueue_all = [&](size_t tid) {
        size_t idx = 0;
        while(idx < NUM_ELEMENTS)
        {
            worker.add(DummyData{idx, tid, results});
            idx++;
        }
    };

    {
        std::vector<std::future<void>> producers{};
        for(size_t tid = 1; tid <= NUM_THREADS; tid++)
            producers.push_back(std::async(std::launch::async, enqueue_all, tid));

        // Started while producers may already be submitting items.
        worker.start();
        // Leaving the scope waits for all producers (std::async futures block in
        // their destructor).
    }
    worker.exit();

    // 1 + 2 + ... + NUM_THREADS
    const size_t expected = NUM_THREADS * (NUM_THREADS + 1) / 2;
    for(auto& slot : *results)
        EXPECT_EQ(slot.load(), expected);
}
} // namespace counters
} // namespace rocprofiler
@@ -459,14 +459,16 @@ TEST(core, check_callbacks)
&opt_buff_id),
"Could not create buffer");
cb_info->buffer = opt_buff_id;
// hsa::Queue::queue_info_session_t sess = {.queue = fq, .correlation_id = &corr_id};
hsa::Queue::queue_info_session_t sess = hsa::Queue::queue_info_session_t{.queue = fq};
sess.correlation_id = &corr_id;
auto _sess = hsa::Queue::queue_info_session_t{.queue = fq};
_sess.correlation_id = &corr_id;
auto sess = std::make_shared<hsa::Queue::queue_info_session_t>(std::move(_sess));
counters::inst_pkt_t pkts;
pkts.emplace_back(
std::make_pair(std::move(ret_pkt), static_cast<counters::ClientID>(0)));
completed_cb(&ctx, cb_info, fq, pkt, sess, pkts, kernel_dispatch::profiling_time{});
completed_cb(&ctx, cb_info, sess, pkts, kernel_dispatch::profiling_time{});
rocprofiler_flush_buffer(opt_buff_id);
rocprofiler_destroy_buffer(opt_buff_id);
}
+24 -17
View File
@@ -43,6 +43,7 @@
#include <hsa/hsa_ext_amd.h>
#include <atomic>
#include <memory>
// static assert for rocprofiler_packet ABI compatibility
static_assert(sizeof(hsa_ext_amd_aql_pm4_packet_t) == sizeof(hsa_kernel_dispatch_packet_t),
@@ -116,8 +117,10 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data)
get_balanced_signal_slots().fetch_add(1);
auto& queue_info_session = *static_cast<Queue::queue_info_session_t*>(data);
auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session);
auto& shared_ptr_info = *static_cast<std::shared_ptr<Queue::queue_info_session_t>*>(data);
auto& queue_info_session = *shared_ptr_info;
auto dispatch_time = kernel_dispatch::get_dispatch_time(queue_info_session);
kernel_dispatch::dispatch_complete(queue_info_session, dispatch_time);
@@ -128,7 +131,7 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data)
{
cb_pair.second(queue_info_session.queue,
queue_info_session.kernel_pkt,
queue_info_session,
shared_ptr_info,
queue_info_session.inst_pkt,
dispatch_time);
}
@@ -164,7 +167,7 @@ AsyncSignalHandler(hsa_signal_value_t /*signal_v*/, void* data)
}
queue_info_session.queue.async_complete();
delete static_cast<Queue::queue_info_session_t*>(data);
delete &shared_ptr_info;
return false;
}
@@ -446,20 +449,24 @@ WriteInterceptor(const void* packets,
// Enqueue the signal into the handler. Will call completed_cb when
// signal completes.
queue.signal_async_handler(
completion_signal,
new Queue::queue_info_session_t{.queue = queue,
.inst_pkt = std::move(inst_pkt),
.interrupt_signal = interrupt_signal,
.tid = thr_id,
.enqueue_ts = common::timestamp_ns(),
.user_data = user_data,
.correlation_id = corr_id,
.kernel_pkt = kernel_pkt,
.callback_record = callback_record,
.tracing_data = tracing_data_v});
{
Queue::queue_info_session_t info_session{.queue = queue,
.inst_pkt = std::move(inst_pkt),
.interrupt_signal = interrupt_signal,
.tid = thr_id,
.enqueue_ts = common::timestamp_ns(),
.user_data = user_data,
.correlation_id = corr_id,
.kernel_pkt = kernel_pkt,
.callback_record = callback_record,
.tracing_data = tracing_data_v};
auto shared = std::make_shared<Queue::queue_info_session_t>(std::move(info_session));
queue.signal_async_handler(completion_signal,
new std::shared_ptr<Queue::queue_info_session_t>(shared));
auto tracer_data = callback_record;
tracing::execute_phase_exit_callbacks(tracing_data_v.callback_contexts,
tracing_data_v.external_correlation_ids,
@@ -561,7 +568,7 @@ Queue::~Queue()
}
void
Queue::signal_async_handler(const hsa_signal_t& signal, Queue::queue_info_session_t* data) const
Queue::signal_async_handler(const hsa_signal_t& signal, void* data) const
{
#if !defined(NDEBUG)
CHECK_NOTNULL(hsa::get_queue_controller())->_debug_signals.wlock([&](auto& signals) {
+2 -2
View File
@@ -87,7 +87,7 @@ public:
// Signals the completion of the kernel packet.
using completed_cb_t = std::function<void(const Queue&,
const rocprofiler_packet&,
const Queue::queue_info_session_t&,
std::shared_ptr<Queue::queue_info_session_t>&,
inst_pkt_t&,
kernel_dispatch::profiling_time)>;
using callback_map_t = std::unordered_map<ClientID, std::pair<queue_cb_t, completed_cb_t>>;
@@ -109,7 +109,7 @@ public:
virtual const AgentCache& get_agent() const { return _agent; }
void create_signal(uint32_t attribute, hsa_signal_t* signal) const;
void signal_async_handler(const hsa_signal_t& signal, Queue::queue_info_session_t* data) const;
void signal_async_handler(const hsa_signal_t& signal, void* data) const;
template <typename FuncT>
void signal_callback(FuncT&& func) const;
@@ -347,12 +347,12 @@ pc_sampling_service_finish_configuration(context::pc_sampling_service* service)
const rocprofiler::hsa::Queue::queue_info_session_t::external_corr_id_map_t&,
const context::correlation_id*) { return nullptr; },
// Completion CB
[](const rocprofiler::hsa::Queue& q,
rocprofiler::hsa::rocprofiler_packet kern_pkt,
const rocprofiler::hsa::Queue::queue_info_session_t& session,
[](const rocprofiler::hsa::Queue& q,
rocprofiler::hsa::rocprofiler_packet kern_pkt,
std::shared_ptr<rocprofiler::hsa::Queue::queue_info_session_t>& session,
rocprofiler::hsa::inst_pkt_t&,
kernel_dispatch::profiling_time) {
kernel_completion_cb(q.get_agent().get_rocp_agent(), kern_pkt, session);
kernel_completion_cb(q.get_agent().get_rocp_agent(), kern_pkt, *session);
});
}
@@ -437,9 +437,9 @@ DispatchThreadTracer::start_context()
},
[=](const hsa::Queue& /* q */,
hsa::rocprofiler_packet /* kern_pkt */,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time) { this->post_kernel_call(aql, session); });
std::shared_ptr<hsa::Queue::queue_info_session_t>& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time) { this->post_kernel_call(aql, *session); });
});
}