From 72fdb3e8ed24fdc9ec906c8a360c12eb40662889 Mon Sep 17 00:00:00 2001 From: "Jonathan R. Madsen" Date: Thu, 14 Mar 2024 01:25:43 -0500 Subject: [PATCH] Fix tracing context domain logic for operations (#621) * Fix tracing context domain logic for operations - logic error: domain enabled (all operations all implicitly enabled) + domain enabled for subset of operations resulted in only explicitly enabled operations being treated as enabled - domain_context: split single bitset for operations in all domains into array of bitsets for each domain * Update lib/common/mpl.hpp - assert_false for static_asserts in if constexpr expressions * Update lib/rocprofiler-sdk/tests/contexts.cpp - Tests for validating logic regarding domain and operations for callback and buffer tracing [ROCm/rocprofiler-sdk commit: 7ab1a8015fdc44a35f6ab44ff7998a62ef4d21a5] --- .../rocprofiler-sdk/source/lib/common/mpl.hpp | 6 + .../lib/rocprofiler-sdk/context/domain.cpp | 42 +- .../lib/rocprofiler-sdk/context/domain.hpp | 16 +- .../lib/rocprofiler-sdk/tests/CMakeLists.txt | 4 +- .../lib/rocprofiler-sdk/tests/contexts.cpp | 384 ++++++++++++++++++ 5 files changed, 435 insertions(+), 17 deletions(-) create mode 100644 projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/contexts.cpp diff --git a/projects/rocprofiler-sdk/source/lib/common/mpl.hpp b/projects/rocprofiler-sdk/source/lib/common/mpl.hpp index 151041a01e..c96270ca80 100644 --- a/projects/rocprofiler-sdk/source/lib/common/mpl.hpp +++ b/projects/rocprofiler-sdk/source/lib/common/mpl.hpp @@ -156,6 +156,12 @@ struct unqualified_type template using unqualified_type_t = typename unqualified_type::type; + +template +struct assert_false +{ + static constexpr auto value = false; +}; } // namespace mpl } // namespace common } // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.cpp index 2e6cd7d9c3..132662aa05 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.cpp @@ -21,8 +21,14 @@ // SOFTWARE. #include "lib/rocprofiler-sdk/context/domain.hpp" + +#include #include +#include + +#include + namespace rocprofiler { namespace context @@ -31,25 +37,41 @@ template bool domain_context::operator()(DomainT _domain) const { - return ((1 << _domain) & domains) == (1 << _domain); + constexpr uint64_t one = 1; + + if(_domain <= domain_info::none) return false; + + auto _didx = (_domain - 1); + return ((one << _didx) & domains) == (one << _didx); } template bool domain_context::operator()(DomainT _domain, uint32_t _op) const { - auto _offset = (_domain * opcode_padding_v); - return (*this)(_domain) && (opcodes.none() || opcodes.test(_offset + _op)); + if(_domain <= domain_info::none) return false; + + auto _didx = (_domain - 1); + + if(_didx >= array_size) return false; + + return (*this)(_domain) && (opcodes.at(_didx).none() || opcodes.at(_didx).test(_op)); } template rocprofiler_status_t add_domain(domain_context& _cfg, DomainT _domain) { - if(_domain <= domain_info::none || _domain >= domain_info::last) - return ROCPROFILER_STATUS_ERROR_KIND_NOT_FOUND; + static_assert((1 << domain_info::last) < std::numeric_limits::max(), + "uint64_t cannot handle all the domains"); - _cfg.domains |= (1 << _domain); + if(_domain <= domain_info::none) return ROCPROFILER_STATUS_ERROR_KIND_NOT_FOUND; + + auto _didx = (_domain - 1); + + if(_didx >= _cfg.array_size) return ROCPROFILER_STATUS_ERROR_KIND_NOT_FOUND; + + _cfg.domains |= (1 << _didx); return ROCPROFILER_STATUS_SUCCESS; } @@ -57,15 +79,13 @@ template rocprofiler_status_t add_domain_op(domain_context& _cfg, DomainT _domain, uint32_t _op) { - if(_domain <= domain_info::none || _domain >= domain_info::last) + if(_domain <= domain_info::none || (_domain - 1) >= _cfg.array_size) return ROCPROFILER_STATUS_ERROR_KIND_NOT_FOUND; if(_op >= domain_info::padding) return ROCPROFILER_STATUS_ERROR_OPERATION_NOT_FOUND; - auto _offset = (_domain * domain_info::padding); - if(_offset >= _cfg.opcodes.size()) return ROCPROFILER_STATUS_ERROR_OPERATION_NOT_FOUND; - - _cfg.opcodes.set(_offset + _op, true); + auto _didx = (_domain - 1); + _cfg.opcodes.at(_didx).set(_op, true); return ROCPROFILER_STATUS_SUCCESS; } diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.hpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.hpp index 630379338e..9903d9452f 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.hpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/context/domain.hpp @@ -65,8 +65,16 @@ struct domain_context rocprofiler_buffer_tracing_kind_t>; static_assert(common::mpl::is_one_of::value, "Unsupported domain type"); - static constexpr auto opcode_padding_v = domain_info::padding; - static constexpr auto max_opcodes_v = opcode_padding_v * domain_info::last; + static constexpr auto none = domain_info::none; + static constexpr auto last = domain_info::last; + + static_assert(last > none, "last must be > none"); + + static constexpr int64_t array_size = (last - none - 1); + static constexpr auto span_size = domain_info::padding; + + using bitset_type = std::bitset; + using array_type = std::array; /// check if domain is enabled bool operator()(DomainT) const; @@ -74,8 +82,8 @@ struct domain_context /// check if op in a domain is enabled bool operator()(DomainT, uint32_t) const; - int64_t domains = 0; - std::bitset opcodes = {}; + uint64_t domains = 0; + array_type opcodes = {}; }; template diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/CMakeLists.txt b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/CMakeLists.txt index e489a1c8e4..e33377ff7f 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/CMakeLists.txt +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/CMakeLists.txt @@ -11,8 +11,8 @@ include(GoogleTest) # # -------------------------------------------------------------------------------------- # -set(rocprofiler_lib_sources agent.cpp buffer.cpp hsa.cpp naming.cpp timestamp.cpp - version.cpp hsa_barrier.cpp) +set(rocprofiler_lib_sources agent.cpp buffer.cpp contexts.cpp hsa.cpp naming.cpp + timestamp.cpp version.cpp hsa_barrier.cpp) add_executable(rocprofiler-lib-tests) target_sources(rocprofiler-lib-tests PRIVATE ${rocprofiler_lib_sources} details/agent.cpp) diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/contexts.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/contexts.cpp new file mode 100644 index 0000000000..2443b7d362 --- /dev/null +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/tests/contexts.cpp @@ -0,0 +1,384 @@ +// MIT License +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include + +#include "lib/common/demangle.hpp" +#include "lib/common/mpl.hpp" +#include "lib/common/utility.hpp" +#include "lib/rocprofiler-sdk/context/context.hpp" +#include "lib/rocprofiler-sdk/context/domain.hpp" + +#include +#include + +namespace context = ::rocprofiler::context; +namespace common = ::rocprofiler::common; + +namespace +{ +#define EXPECT_ROCP_SUCCESS(...) \ + EXPECT_EQ(ROCPROFILER_STATUS_SUCCESS, (__VA_ARGS__)) << #__VA_ARGS__ + +#define EXPECT_ROCP_SUCCESS_STREAM(_VAR_NAME, ...) \ + auto _VAR_NAME = (__VA_ARGS__); \ + EXPECT_EQ(ROCPROFILER_STATUS_SUCCESS, _VAR_NAME) << #__VA_ARGS__ << " :: " + +template +auto +get_operation_name_impl(Tp kind, uint32_t op) +{ + const char* opname = "()"; + + if constexpr(std::is_same::value) + EXPECT_ROCP_SUCCESS( + rocprofiler_query_callback_tracing_kind_operation_name(kind, op, &opname, nullptr)); + else if constexpr(std::is_same::value) + EXPECT_ROCP_SUCCESS( + rocprofiler_query_buffer_tracing_kind_operation_name(kind, op, &opname, nullptr)); + else + static_assert(common::mpl::assert_false::value, "invalid type"); + + return std::string_view{opname}; +} + +#define get_operation_name(...) get_operation_name_impl(__VA_ARGS__) + +template +auto +get_operations_impl(Tp kind) +{ + using opvector_t = std::map; + + auto iterate_operations = [](Tp _kind_v, rocprofiler_tracing_operation_t op, void* data) { + auto* _data = static_cast(data); + + _data->emplace(op, get_operation_name(_kind_v, op)); + return 0; + }; + + auto opdata = opvector_t{}; + if constexpr(std::is_same::value) + rocprofiler_iterate_callback_tracing_kind_operations(kind, iterate_operations, &opdata); + else if constexpr(std::is_same::value) + rocprofiler_iterate_buffer_tracing_kind_operations(kind, iterate_operations, &opdata); + else + static_assert(common::mpl::assert_false::value, "invalid type"); + + return opdata; +} + +#define get_operations(...) get_operations_impl(__VA_ARGS__) + +template +auto +get_domain_name(Tp idx_v) +{ + const char* _name = "()"; + + if constexpr(std::is_same::value) + EXPECT_ROCP_SUCCESS(rocprofiler_query_callback_tracing_kind_name(idx_v, &_name, nullptr)); + else if constexpr(std::is_same::value) + EXPECT_ROCP_SUCCESS(rocprofiler_query_buffer_tracing_kind_name(idx_v, &_name, nullptr)); + else + static_assert(common::mpl::assert_false::value, "invalid type"); + + return std::string_view{_name}; +} + +template +struct kind_info; + +template <> +struct kind_info +{ + using type = rocprofiler_callback_tracing_kind_t; +}; + +template <> +struct kind_info +{ + using type = rocprofiler_buffer_tracing_kind_t; +}; + +template +using kind_info_t = typename kind_info::type; + +template +auto +add_domain_impl(Tp* tracer, int val) +{ + using kind_type = kind_info_t; + + static auto type_name = common::cxx_demangle(typeid(kind_type).name()); + + auto idx = static_cast(val); + + auto loc_info = std::stringstream{}; + loc_info << type_name << " (kind=" << val << ") :: " << get_domain_name(idx); + + // should initially be false + EXPECT_FALSE(tracer->domains(idx)) << loc_info.str(); + + EXPECT_ROCP_SUCCESS_STREAM(_status, context::add_domain(tracer->domains, idx)) + << loc_info.str() << " returned " << _status + << " :: " << rocprofiler_get_status_string(_status); + EXPECT_TRUE(tracer->domains(idx)) << loc_info.str(); +} + +#define add_domain(...) add_domain_impl(__VA_ARGS__) + +template +auto +add_domain_op_impl(Tp* tracer, int val, uint32_t op) +{ + using kind_type = kind_info_t; + + static auto type_name = common::cxx_demangle(typeid(kind_type).name()); + + auto idx = static_cast(val); + + auto loc_info = std::stringstream{}; + loc_info << type_name << " (kind=" << val << ", op=" << op << ") :: " << get_domain_name(idx); + + // conditional enabling of domain + if(!tracer->domains(idx)) add_domain(tracer, val); + + EXPECT_ROCP_SUCCESS_STREAM(_status, context::add_domain_op(tracer->domains, idx, op)) + << loc_info.str() << " returned " << _status + << " :: " << rocprofiler_get_status_string(_status); + EXPECT_TRUE(tracer->domains(idx, op)) << loc_info.str(); +} + +#define add_domain_op(...) add_domain_op_impl(__VA_ARGS__) + +template +auto +check_operations_impl(Tp* tracer, int val, BoolT = {}) +{ + using kind_type = kind_info_t; + + auto idx = static_cast(val); + + auto operations = get_operations(idx); + for(auto oitr : operations) + { + if constexpr(BoolT::value) + { + EXPECT_TRUE(tracer->domains(idx, oitr.first)) + << get_domain_name(idx) << " (operation=" << oitr.first << "/" << oitr.second + << ")"; + } + else + { + EXPECT_FALSE(tracer->domains(idx, oitr.first)) + << get_domain_name(idx) << " (operation=" << oitr.first << "/" << oitr.second + << ")"; + } + } +} + +#define check_operations(...) check_operations_impl(__VA_ARGS__) + +template +auto +check_operation_impl(Tp* tracer, int val, int op, BoolT) +{ + using kind_type = kind_info_t; + + auto idx = static_cast(val); + + auto operations = get_operations(idx); + auto opname = operations.at(op); + + if constexpr(BoolT::value) + { + EXPECT_TRUE(tracer->domains(idx, op)) + << get_domain_name(idx) << " (operation=" << op << "/" << opname << ")"; + } + else + { + EXPECT_FALSE(tracer->domains(idx, op)) + << get_domain_name(idx) << " (operation=" << op << "/" << opname << ")"; + } +} + +#define check_operation(...) check_operation_impl(__VA_ARGS__) +} // namespace + +TEST(contexts, callback_tracing) +{ + constexpr auto none = ROCPROFILER_CALLBACK_TRACING_NONE; + constexpr auto last = ROCPROFILER_CALLBACK_TRACING_LAST; + + auto get_tracer = []() -> auto* + { + static auto ctx = context::context{}; + ctx.callback_tracer.reset(); + ctx.callback_tracer = std::make_unique(); + return ctx.callback_tracer.get(); + }; + + { + auto* tracer = get_tracer(); + + EXPECT_EQ(tracer->callback_data.size(), last); + + for(int i = none + 1; i < last; ++i) + { + auto idx = static_cast(i); + EXPECT_FALSE(tracer->domains(idx)) << "i=" << i << " :: " << get_domain_name(idx); + } + + for(int i = none + 1; i < last; ++i) + { + add_domain(tracer, i); + check_operations(tracer, i); + } + + check_operations(tracer, none, std::false_type{}); + check_operations(tracer, last, std::false_type{}); + } + + { + auto* tracer = get_tracer(); + + for(int i = last - 1; i > none; --i) + { + add_domain(tracer, i); + check_operations(tracer, i); + } + } + + { + auto* tracer = get_tracer(); + + auto fully_enabled = std::set{ROCPROFILER_CALLBACK_TRACING_HIP_RUNTIME_API, + ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT, + ROCPROFILER_CALLBACK_TRACING_MARKER_CONTROL_API, + ROCPROFILER_CALLBACK_TRACING_MARKER_CORE_API, + ROCPROFILER_CALLBACK_TRACING_MARKER_NAME_API}; + + for(auto i : fully_enabled) + { + add_domain(tracer, i); + check_operations(tracer, i); + } + + for(int i = none + 1; i < last; ++i) + { + if(fully_enabled.count(i) == 0) + { + check_operations(tracer, i, std::false_type{}); + } + } + + add_domain_op(tracer, + ROCPROFILER_CALLBACK_TRACING_HIP_COMPILER_API, + ROCPROFILER_HIP_COMPILER_API_ID___hipPushCallConfiguration); + + auto extra_enabled = fully_enabled; + extra_enabled.emplace(ROCPROFILER_CALLBACK_TRACING_HIP_COMPILER_API); + + for(auto itrv : extra_enabled) + { + auto itr = static_cast(itrv); + EXPECT_TRUE(tracer->domains(itr)) << get_domain_name(itr); + } + + check_operation(tracer, + ROCPROFILER_CALLBACK_TRACING_HIP_COMPILER_API, + ROCPROFILER_HIP_COMPILER_API_ID___hipPushCallConfiguration, + std::true_type{}); + + auto operations = get_operations(ROCPROFILER_CALLBACK_TRACING_HIP_COMPILER_API); + operations.erase(ROCPROFILER_HIP_COMPILER_API_ID___hipPushCallConfiguration); + + for(auto itr : operations) + { + check_operation(tracer, + ROCPROFILER_CALLBACK_TRACING_HIP_COMPILER_API, + itr.first, + std::false_type{}); + } + } + + { + auto* tracer = get_tracer(); + for(int i = none + 1; i < last; ++i) + { + check_operations(tracer, i, std::false_type{}); + } + } +} + +TEST(contexts, buffer_tracing) +{ + constexpr auto none = ROCPROFILER_BUFFER_TRACING_NONE; + constexpr auto last = ROCPROFILER_BUFFER_TRACING_LAST; + + auto get_tracer = []() -> auto* + { + static auto ctx = context::context{}; + ctx.buffered_tracer.reset(); + ctx.buffered_tracer = std::make_unique(); + return ctx.buffered_tracer.get(); + }; + + { + auto* tracer = get_tracer(); + + EXPECT_EQ(tracer->buffer_data.size(), last); + + for(int i = none + 1; i < last; ++i) + { + auto idx = static_cast(i); + EXPECT_FALSE(tracer->domains(idx)) << "i=" << i << " :: " << get_domain_name(idx); + } + + for(int i = none + 1; i < last; ++i) + { + add_domain(tracer, i); + check_operations(tracer, i); + } + } + + { + auto* tracer = get_tracer(); + for(int i = last - 1; i > none; --i) + { + add_domain(tracer, i); + check_operations(tracer, i); + } + } + + { + auto* tracer = get_tracer(); + for(int i = none + 1; i < last; ++i) + { + check_operations(tracer, i, std::false_type{}); + } + } +}