diff --git a/projects/rocprofiler-systems/source/lib/include/library/components/pthread_gotcha.hpp b/projects/rocprofiler-systems/source/lib/include/library/components/pthread_gotcha.hpp index 51817756d9..e58e8a8494 100644 --- a/projects/rocprofiler-systems/source/lib/include/library/components/pthread_gotcha.hpp +++ b/projects/rocprofiler-systems/source/lib/include/library/components/pthread_gotcha.hpp @@ -26,6 +26,7 @@ #include "library/defines.hpp" #include "library/timemory.hpp" +#include #include namespace omnitrace @@ -59,12 +60,24 @@ struct pthread_gotcha : tim::component::base static void configure(); static void shutdown(); - // threads can set this to avoid starting sampling on child threads - static bool& enable_sampling_on_child_threads(); + // query the calling thread's current setting + static bool sampling_enabled_on_child_threads(); + + // set a new value for a region (e.g. disable sampling right before thread creation); returns the previous value + static bool push_enable_sampling_on_child_threads(bool _v); + + // restore the most recently saved value (no-op if the history is empty); returns the value now in effect + static bool pop_enable_sampling_on_child_threads(); + + // make sure every newly created thread starts with this value + static void set_sampling_on_all_future_threads(bool _v); // pthread_create int operator()(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), void* arg) const; + +private: + static bool& sampling_on_child_threads(); }; using pthread_gotcha_t = tim::component::gotcha<2, std::tuple<>, pthread_gotcha>; diff --git a/projects/rocprofiler-systems/source/lib/src/library.cpp b/projects/rocprofiler-systems/source/lib/src/library.cpp index b67ce203f3..fcf1fa332b 100644 --- a/projects/rocprofiler-systems/source/lib/src/library.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library.cpp @@ -26,6 +26,7 @@ #include "library/components/functors.hpp" #include "library/components/fwd.hpp" #include "library/components/mpi_gotcha.hpp" +#include "library/components/pthread_gotcha.hpp" #include "library/config.hpp" #include 
"library/critical_trace.hpp" #include "library/debug.hpp" @@ -383,7 +384,7 @@ omnitrace_init_library_hidden() OMNITRACE_CONDITIONAL_PRINT_F(get_verbose() >= 0, "Disabling critical trace in %s mode...\n", std::to_string(_mode).c_str()); - get_use_sampling() = true; + get_use_sampling() = tim::get_env("OMNITRACE_USE_SAMPLING", true); get_use_critical_trace() = false; } @@ -441,10 +442,11 @@ omnitrace_init_tooling_hidden() auto _dtor = scope::destructor{ []() { if(get_use_sampling()) { - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); thread_sampler::setup(); sampling::setup(); - pthread_gotcha::enable_sampling_on_child_threads() = true; + pthread_gotcha::pop_enable_sampling_on_child_threads(); + pthread_gotcha::push_enable_sampling_on_child_threads(get_use_sampling()); sampling::unblock_signals(); } get_main_bundle()->start(); @@ -453,7 +455,7 @@ omnitrace_init_tooling_hidden() if(get_use_sampling()) { - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); sampling::block_signals(); } @@ -692,6 +694,10 @@ omnitrace_init_hidden(const char* _mode, bool _is_binary_rewrite, const char* _a tim::set_env("OMNITRACE_MODE", _mode, 0); config::is_binary_rewrite() = _is_binary_rewrite; + // default OMNITRACE_USE_SAMPLING to ON if mode is sampling, otherwise default to OFF + tim::set_env("OMNITRACE_USE_SAMPLING", (get_mode() == Mode::Sampling) ? "ON" : "OFF", + 0); + // default to KokkosP enabled when sampling, otherwise default to off tim::set_env("OMNITRACE_USE_KOKKOSP", (get_mode() == Mode::Sampling) ? 
"ON" : "OFF", 0); @@ -738,7 +744,8 @@ omnitrace_finalize_hidden(void) library_functors::configure([](const char*) {}, [](const char*) {}); - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); + pthread_gotcha::set_sampling_on_all_future_threads(false); auto _debug_init = get_debug_finalize(); auto _debug_value = get_debug(); diff --git a/projects/rocprofiler-systems/source/lib/src/library/components/backtrace.cpp b/projects/rocprofiler-systems/source/lib/src/library/components/backtrace.cpp index ff9e33e053..81d699bc3d 100644 --- a/projects/rocprofiler-systems/source/lib/src/library/components/backtrace.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library/components/backtrace.cpp @@ -595,10 +595,9 @@ backtrace::post_process(int64_t _tid) _process_perfetto(_data, false); else { - auto _v = pthread_gotcha::enable_sampling_on_child_threads(); - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); std::thread{ _process_perfetto, _data, true }.join(); - pthread_gotcha::enable_sampling_on_child_threads() = _v; + pthread_gotcha::pop_enable_sampling_on_child_threads(); } } diff --git a/projects/rocprofiler-systems/source/lib/src/library/components/pthread_gotcha.cpp b/projects/rocprofiler-systems/source/lib/src/library/components/pthread_gotcha.cpp index 71de8b81dd..cdadbff29f 100644 --- a/projects/rocprofiler-systems/source/lib/src/library/components/pthread_gotcha.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library/components/pthread_gotcha.cpp @@ -26,6 +26,7 @@ #include "library/config.hpp" #include "library/debug.hpp" #include "library/sampling.hpp" +#include "library/thread_data.hpp" #include #include @@ -83,6 +84,21 @@ stop_bundle(bundle_t& _bundle, int64_t _tid) // exclude popping wall-clock _bundle.pop(main_pw_t{}, _tid); } + +auto +get_thread_index() +{ + static std::atomic _c{ 0 }; + static thread_local 
int64_t _v = _c++; + return _v; +} + +auto& +get_sampling_on_child_threads_history(int64_t _idx = get_thread_index()) +{ + static auto _v = std::array, OMNITRACE_MAX_THREADS>{}; + return _v.at(_idx); +} } // namespace pthread_gotcha::wrapper::wrapper(routine_t _routine, void* _arg, bool _enable_sampling, @@ -99,11 +115,10 @@ pthread_gotcha::wrapper::operator()() const { std::shared_ptr _bundle{}; std::set _signals{}; - auto& _enable_sampling = pthread_gotcha::enable_sampling_on_child_threads(); - auto _active = (get_state() == omnitrace::State::Active); - int64_t _tid = -1; - auto _is_sampling = false; - auto _dtor = scope::destructor{ [&]() { + auto _active = (get_state() == omnitrace::State::Active); + int64_t _tid = -1; + auto _is_sampling = false; + auto _dtor = scope::destructor{ [&]() { if(_is_sampling) { sampling::block_signals(_signals); @@ -121,7 +136,7 @@ pthread_gotcha::wrapper::operator()() const if(_active) get_cpu_cid_stack(threading::get_id(), m_parent_tid); - if(m_enable_sampling && _enable_sampling && _active) + if(m_enable_sampling && _active) { _tid = threading::get_id(); threading::set_thread_name(TIMEMORY_JOIN(" ", "Thread", _tid).c_str()); @@ -133,10 +148,10 @@ pthread_gotcha::wrapper::operator()() const .first->second; } if(_bundle) start_bundle(*_bundle); - _is_sampling = true; - _enable_sampling = false; - _signals = sampling::setup(); - _enable_sampling = true; + _is_sampling = true; + push_enable_sampling_on_child_threads(false); + _signals = sampling::setup(); + pop_enable_sampling_on_child_threads(); sampling::unblock_signals(); } @@ -187,10 +202,48 @@ pthread_gotcha::shutdown() bundles.clear(); } -bool& -pthread_gotcha::enable_sampling_on_child_threads() +bool +pthread_gotcha::sampling_enabled_on_child_threads() { - static thread_local bool _v = get_use_sampling(); + return sampling_on_child_threads(); +} + +bool +pthread_gotcha::push_enable_sampling_on_child_threads(bool _v) +{ + auto& _hist = get_sampling_on_child_threads_history(); 
+ bool _last = sampling_on_child_threads(); + _hist.emplace_back(_last); + sampling_on_child_threads() = _v; + return _last; +} + +bool +pthread_gotcha::pop_enable_sampling_on_child_threads() +{ + auto& _hist = get_sampling_on_child_threads_history(); + if(!_hist.empty()) + { + bool _restored = _hist.back(); + _hist.pop_back(); + sampling_on_child_threads() = _restored; + } + return sampling_on_child_threads(); +} + +void +pthread_gotcha::set_sampling_on_all_future_threads(bool _v) +{ + for(size_t i = 0; i < max_supported_threads; ++i) + get_sampling_on_child_threads_history(i).emplace_back(_v); +} + +bool& +pthread_gotcha::sampling_on_child_threads() +{ + static thread_local bool _v = get_sampling_on_child_threads_history().empty() + ? false + : get_sampling_on_child_threads_history().back(); return _v; } @@ -200,7 +253,7 @@ pthread_gotcha::operator()(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), void* arg) const { bundle_t _bundle{ "pthread_create" }; - auto _enable_sampling = enable_sampling_on_child_threads(); + auto _enable_sampling = sampling_enabled_on_child_threads(); auto _active = (get_state() == omnitrace::State::Active); int64_t _tid = (_active) ? 
threading::get_id() : 0; diff --git a/projects/rocprofiler-systems/source/lib/src/library/components/rocm_smi.cpp b/projects/rocprofiler-systems/source/lib/src/library/components/rocm_smi.cpp index fcbd1ebd6d..ccf922aafd 100644 --- a/projects/rocprofiler-systems/source/lib/src/library/components/rocm_smi.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library/components/rocm_smi.cpp @@ -347,8 +347,7 @@ setup() if(is_initialized() || !get_use_rocm_smi()) return; - auto _enable_samp = pthread_gotcha::enable_sampling_on_child_threads(); - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); // assign the data value to determined by rocm-smi data::device_count = device_count(); @@ -387,7 +386,7 @@ setup() data::setup(); - pthread_gotcha::enable_sampling_on_child_threads() = _enable_samp; + pthread_gotcha::pop_enable_sampling_on_child_threads(); } void diff --git a/projects/rocprofiler-systems/source/lib/src/library/components/roctracer.cpp b/projects/rocprofiler-systems/source/lib/src/library/components/roctracer.cpp index 396a8cdfe0..868f43eadb 100644 --- a/projects/rocprofiler-systems/source/lib/src/library/components/roctracer.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library/components/roctracer.cpp @@ -30,7 +30,6 @@ #include "library/sampling.hpp" #include "library/thread_data.hpp" -namespace rocm_smi = omnitrace::rocm_smi; using namespace omnitrace; namespace tim @@ -204,7 +203,7 @@ extern "C" bool OnLoad(HsaApiTable* table, uint64_t runtime_version, uint64_t failed_tool_count, const char* const* failed_tool_names) { - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); OMNITRACE_CONDITIONAL_BASIC_PRINT(get_debug_env() || get_verbose_env() > 0, "[%s]\n", __FUNCTION__); tim::consume_parameters(table, runtime_version, failed_tool_count, @@ -297,7 +296,7 @@ extern "C" rocm_smi::set_state(State::Active); 
comp::roctracer::setup(); - pthread_gotcha::enable_sampling_on_child_threads() = true; + pthread_gotcha::pop_enable_sampling_on_child_threads(); return true; } diff --git a/projects/rocprofiler-systems/source/lib/src/library/thread_sampler.cpp b/projects/rocprofiler-systems/source/lib/src/library/thread_sampler.cpp index 7ea9a40b5e..666ecda9c5 100644 --- a/projects/rocprofiler-systems/source/lib/src/library/thread_sampler.cpp +++ b/projects/rocprofiler-systems/source/lib/src/library/thread_sampler.cpp @@ -104,8 +104,7 @@ sampler::setup() // shutdown if already running shutdown(); - auto _enable_samp = pthread_gotcha::enable_sampling_on_child_threads(); - pthread_gotcha::enable_sampling_on_child_threads() = false; + pthread_gotcha::push_enable_sampling_on_child_threads(false); if(get_use_rocm_smi()) { @@ -142,7 +141,7 @@ sampler::setup() _fut.wait(); - pthread_gotcha::enable_sampling_on_child_threads() = _enable_samp; + pthread_gotcha::pop_enable_sampling_on_child_threads(); set_state(State::Active); }