#include "lib/rocprofiler/counters/core.hpp"

#include "lib/common/synchronized.hpp"
#include "lib/rocprofiler/aql/helpers.hpp"
#include "lib/rocprofiler/aql/packet_construct.hpp"
#include "lib/rocprofiler/context/context.hpp"
#include "lib/rocprofiler/hsa/queue_controller.hpp"
#include "lib/rocprofiler/registration.hpp"

// NOTE(review): this file reached review with every `<...>` span stripped.
//   * The original had one angle-bracket #include here whose header name is
//     lost; restore it from version control. The standard headers below are
//     the ones this translation unit visibly relies on.
//   * Template arguments (hsa::AQLPacket, counter_callback_info,
//     aql::AQLPacketConstruct, counters::Metric, rocprofiler_record_counter_t,
//     context::counter_collection_service, std::unordered_map<uint64_t,
//     profile_config>, std::atomic<uint64_t>) were reconstructed from how the
//     values are used in this file -- confirm against the project headers.
#include <atomic>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace rocprofiler
{
namespace counters
{
namespace
{
/**
 * Resolve, for every metric in @p config, the hardware counters it requires
 * and the AST used to evaluate it, and append them to
 * `config.reqired_hw_counters` / `config.asts`.
 *
 * Results are staged locally and committed only after every metric resolved.
 * Writing directly into @p config would leave `reqired_hw_counters` non-empty
 * after a mid-list failure, so a later call would skip this setup and build a
 * packet generator from a partial configuration.
 *
 * @throws std::runtime_error if a metric's counters, the agent's AST map, or
 *         the metric's AST cannot be found.
 */
void
resolve_profile_config(profile_config& config)
{
    decltype(config.reqired_hw_counters) staged_counters;
    decltype(config.asts)                staged_asts;

    auto agent_name = std::string(config.agent.name);
    for(const auto& metric : config.metrics)
    {
        auto req_counters =
            rocprofiler::counters::get_required_hardware_counters(agent_name, metric);
        if(!req_counters)
        {
            throw std::runtime_error(fmt::format("Could not find counter {}", metric.name()));
        }
        staged_counters.insert(req_counters->begin(), req_counters->end());

        const auto& asts      = rocprofiler::counters::get_ast_map();
        const auto* agent_map = rocprofiler::common::get_val(asts, agent_name);
        if(!agent_map)
        {
            throw std::runtime_error(fmt::format("Coult not build AST for {}", agent_name));
        }

        const auto* counter_ast = rocprofiler::common::get_val(*agent_map, metric.name());
        if(!counter_ast)
        {
            throw std::runtime_error(fmt::format("Coult not find AST for {}", metric.name()));
        }
        staged_asts.push_back(*counter_ast);
    }

    // Commit: everything resolved, so publish the results to the config.
    config.reqired_hw_counters.insert(staged_counters.begin(), staged_counters.end());
    for(auto& ast : staged_asts)
    {
        config.asts.push_back(std::move(ast));
    }
}
}  // namespace

/**
 * Callback we get from HSA interceptor when a kernel packet is being enqueued.
 *
 * We return an AQLPacket containing the start/stop/read packets for injection.
 * A packet is taken from the per-callback cache when one is available;
 * otherwise a new one is constructed by the packet generator.
 *
 * @param info  Counter collection state for this callback; nullptr is a no-op.
 * @param queue Queue the kernel is being enqueued on (its agent is used to
 *              build the packet generator on first use).
 * @return The packet to inject, or nullptr when @p info is null.
 * @throws std::runtime_error if the one-time profile setup cannot resolve a
 *         metric (see resolve_profile_config).
 */
std::unique_ptr<hsa::AQLPacket>
queue_cb(const std::shared_ptr<counter_callback_info>& info,
         const hsa::Queue&                              queue,
         hsa::ClientID,
         hsa::rocprofiler_packet)
{
    if(!info) return nullptr;

    std::unique_ptr<hsa::AQLPacket> ret_pkt;

    // Check packet cache
    info->packets.wlock([&](auto& pkt_vector) {
        // Delay packet generator construction until first HSA packet is processed
        // This ensures that HSA exists
        if(!info->pkt_generator)
        {
            // One time setup of profile config
            if(info->profile_cfg.reqired_hw_counters.empty())
            {
                resolve_profile_config(info->profile_cfg);
            }

            info->pkt_generator = std::make_unique<rocprofiler::aql::AQLPacketConstruct>(
                queue.get_agent(),
                std::vector<counters::Metric>{info->profile_cfg.reqired_hw_counters.begin(),
                                              info->profile_cfg.reqired_hw_counters.end()});
        }

        if(!pkt_vector.empty())
        {
            ret_pkt = std::move(pkt_vector.back());
            pkt_vector.pop_back();
        }
    });

    if(!ret_pkt)
    {
        // If we do not have a packet in the cache, create one.
        ret_pkt =
            info->pkt_generator->construct_packet(hsa::get_queue_controller().get_ext_table());
    }

    return ret_pkt;
}

/**
 * Callback called by HSA interceptor when the kernel has completed processing.
 *
 * Decodes the counter data carried by @p pkt, returns the packet to the cache
 * for reuse, evaluates every AST of the profile and hands the resulting
 * records to the user callback (if one is registered).
 *
 * @param info   Counter collection state for this callback; nullptr is a no-op.
 * @param queue  Queue the kernel ran on; its id is forwarded to the user.
 * @param kernel The intercepted kernel packet; its dispatch packet is forwarded.
 * @param pkt    The injected packet produced by queue_cb; nullptr is a no-op.
 */
void
completed_cb(const std::shared_ptr<counter_callback_info>& info,
             const hsa::Queue&                              queue,
             hsa::ClientID,
             hsa::rocprofiler_packet         kernel,
             std::unique_ptr<hsa::AQLPacket> pkt)
{
    // The packet is dereferenced below, so it must be validated up front
    // rather than only when it is handed back to the cache.
    if(!info || !pkt) return;

    // Read data and create user return....
    auto decoded_pkt = EvaluateAST::read_pkt(info->pkt_generator.get(), *pkt);

    // return AQL packet for reuse.
    info->packets.wlock([&](auto& pkt_vector) { pkt_vector.emplace_back(std::move(pkt)); });

    if(!info->user_cb) return;

    std::vector<rocprofiler_record_counter_t> out;
    for(auto& ast : info->profile_cfg.asts)
    {
        auto* ret = ast.evaluate(decoded_pkt);
        CHECK(ret);
        out.insert(out.end(), ret->begin(), ret->end());
    }

    // Maybe move to its own thread?
    info->user_cb(queue.get_id(),
                  info->profile_cfg.agent,
                  rocprofiler_correlation_id_t{},
                  &kernel.kernel_dispatch,
                  info->callback_args,
                  out.data(),  // Data pointer goes here.
                  out.size(),  // Number of objects
                  info->profile_cfg.id);
}

/**
 * Owner of the process-wide cache of counter collection profiles and the
 * entry point for attaching a profile to a context.
 */
class CounterController
{
public:
    // Adds a counter collection profile to our global cache.
    // Note: these profiles can be used across multiple contexts
    // and are independent of the context.
    //
    // Returns the id assigned to the profile (also stored in config.id).
    uint64_t add_profile(profile_config&& config)
    {
        static std::atomic<uint64_t> profile_val{1};
        uint64_t                     ret = 0;
        _configs.wlock([&](auto& data) {
            // Take the id once so the config, the map key and the return
            // value are guaranteed to agree.
            ret       = profile_val.fetch_add(1);
            config.id = rocprofiler_profile_config_id_t{.handle = ret};
            data.emplace(ret, std::move(config));
        });
        return ret;
    }

    // Removes the profile with the given id from the cache (no-op if absent).
    void destroy_profile(uint64_t id)
    {
        _configs.wlock([&](auto& data) { data.erase(id); });
    }

    // Setup the counter collection service. counter_callback_info is created here
    // to contain the counters that need to be collected (specified in profile_id) and
    // the AQL packet generator for injecting packets. Note: the service is created
    // in the stop state.
    //
    // Both lookups use .at(), so an unknown context_id or profile_id surfaces
    // as whatever exception the underlying container's at() raises.
    // Always returns true otherwise.
    bool configure_dispatch(rocprofiler_context_id_t                         context_id,
                            uint64_t                                         profile_id,
                            rocprofiler_profile_counting_dispatch_callback_t callback,
                            void*                                            callback_args) const
    {
        auto& ctx = *rocprofiler::context::get_registered_contexts().at(context_id.handle);

        // Note: A single profile config could be used on multiple contexts
        profile_config cfg;
        _configs.rlock([&](const auto& map) { cfg = map.at(profile_id); });

        if(!ctx.counter_collection)
        {
            ctx.counter_collection =
                std::make_unique<rocprofiler::context::counter_collection_service>();
        }

        auto& cb = *ctx.counter_collection->callbacks.emplace_back(
            std::make_shared<counter_callback_info>());

        cb.user_cb = callback;
        // Secondary copy of the config to be shared with async callback.
        // `cfg` is already a private copy, so hand it over instead of copying again.
        cb.profile_cfg   = std::move(cfg);
        cb.callback_args = callback_args;
        cb.context       = context_id;
        return true;
    }

private:
    // Profile id -> profile config, guarded by the Synchronized wrapper.
    rocprofiler::common::Synchronized<std::unordered_map<uint64_t, profile_config>> _configs;
};

// Returns the single CounterController instance (function-local static).
CounterController&
get_controller()
{
    static CounterController controller;
    return controller;
}

// Registers a profile with the global controller and returns its id.
uint64_t
create_counter_profile(profile_config&& config)
{
    return get_controller().add_profile(std::move(config));
}

// Removes a previously registered profile from the global controller.
void
destroy_counter_profile(uint64_t id)
{
    get_controller().destroy_profile(id);
}

/**
 * Enable counter collection for @p ctx by registering queue_cb/completed_cb
 * with the HSA queue controller for every configured callback.
 * No-op when @p ctx is null, has no counter collection service, or the
 * service is already enabled.
 */
void
start_context(context::context* ctx)
{
    if(!ctx || !ctx->counter_collection) return;

    auto& controller = hsa::get_queue_controller();

    // Only one thread should be attempting to enable/disable this context
    ctx->counter_collection->enabled.wlock([&](auto& enabled) {
        if(enabled) return;

        for(auto& cb : ctx->counter_collection->callbacks)
        {
            // Insert our callbacks into HSA Interceptor. This
            // turns on counter instrumentation.
            // Each lambda holds its own copy of the shared_ptr so the callback
            // info stays alive for as long as the interceptor keeps the lambda.
            cb->queue_id = controller.add_callback(
                cb->profile_cfg.agent,
                [cb](const hsa::Queue& q, hsa::ClientID c, hsa::rocprofiler_packet kern_pkt) {
                    return queue_cb(cb, q, c, kern_pkt);
                },
                // Completion CB
                [cb](const hsa::Queue&               q,
                     hsa::ClientID                   c,
                     hsa::rocprofiler_packet         kern_pkt,
                     std::unique_ptr<hsa::AQLPacket> aql) {
                    completed_cb(cb, q, c, kern_pkt, std::move(aql));
                });
        }
        enabled = true;
    });
}

/**
 * Disable counter collection for @p ctx by removing every callback that
 * start_context registered and resetting the stored queue ids to -1.
 * No-op when @p ctx is null, has no counter collection service, or the
 * service is not enabled.
 */
void
stop_context(context::context* ctx)
{
    if(!ctx || !ctx->counter_collection) return;

    auto& controller = hsa::get_queue_controller();

    ctx->counter_collection->enabled.wlock([&](auto& enabled) {
        if(!enabled) return;

        for(auto& cb : ctx->counter_collection->callbacks)
        {
            // Remove our callbacks from HSA's queue controller
            controller.remove_callback(cb->queue_id);
            cb->queue_id = -1;
        }
        enabled = false;
    });
}

// Free-function entry point: forwards to the global controller's
// CounterController::configure_dispatch.
bool
configure_dispatch(rocprofiler_context_id_t                         context_id,
                   uint64_t                                         profile_id,
                   rocprofiler_profile_counting_dispatch_callback_t callback,
                   void*                                            callback_args)
{
    return get_controller().configure_dispatch(context_id, profile_id, callback, callback_args);
}

}  // namespace counters
}  // namespace rocprofiler