diff --git a/projects/rocprofiler-systems/source/lib/rocprof-sys/library/rocprofiler-sdk/rccl.cpp b/projects/rocprofiler-systems/source/lib/rocprof-sys/library/rocprofiler-sdk/rccl.cpp index bd26bce5a5..3886af5c2d 100644 --- a/projects/rocprofiler-systems/source/lib/rocprof-sys/library/rocprofiler-sdk/rccl.cpp +++ b/projects/rocprofiler-systems/source/lib/rocprof-sys/library/rocprofiler-sdk/rccl.cpp @@ -22,13 +22,19 @@ #include "library/rocprofiler-sdk/rccl.hpp" +#include "core/categories.hpp" +#include "core/components/fwd.hpp" #include "core/config.hpp" #include "core/perfetto.hpp" +#include "core/trace_cache/cache_manager.hpp" +#include "core/trace_cache/sample_type.hpp" #include "library/tracing.hpp" #include "logger/debug.hpp" +#include + namespace rocprofsys { namespace rocprofiler_sdk @@ -70,6 +76,33 @@ write_perfetto_counter_track(uint64_t _val, uint64_t _begin_ts, uint64_t _end_ts } } +template +void +cache_rccl_comm_data_events(size_t bytes, uint64_t timestamp_ns) +{ + static std::mutex _mutex{}; + static uint64_t cumulative_bytes = 0; + { + std::unique_lock _lk{ _mutex }; + bytes = (cumulative_bytes += bytes); + } + const std::string track_name = Track::label; + const std::string event_metadata = "{}"; + const size_t stack_id = 0; + const size_t parent_stack_id = 0; + const size_t correlation_id = 0; + const std::string call_stack = "{}"; + const std::string line_info = "{}"; + const uint32_t device_id = 0; + + trace_cache::get_buffer_storage().store(trace_cache::pmc_event_with_sample{ + static_cast(category_enum_id::value), + track_name.c_str(), timestamp_ns, event_metadata.c_str(), stack_id, + parent_stack_id, correlation_id, call_stack.c_str(), line_info.c_str(), device_id, + static_cast(agent_type::CPU), track_name.c_str(), + static_cast(cumulative_bytes) }); +} + static auto rccl_type_size(ncclDataType_t datatype) { @@ -172,12 +205,18 @@ tool_tracing_callback_rccl(rocprofiler_callback_tracing_record_t record, break; } - if(config::get_use_perfetto() && size > 0) + if(size > 0) { if(is_send) + { + cache_rccl_comm_data_events(size, end_ts); write_perfetto_counter_track(size, begin_ts, end_ts); + } else + { + cache_rccl_comm_data_events(size, end_ts); write_perfetto_counter_track(size, begin_ts, end_ts); + } } } }