diff --git a/projects/rocprofiler-systems/source/bin/rocprof-sys-run/impl.cpp b/projects/rocprofiler-systems/source/bin/rocprof-sys-run/impl.cpp index f869ec8ca0..a19121529b 100644 --- a/projects/rocprofiler-systems/source/bin/rocprof-sys-run/impl.cpp +++ b/projects/rocprofiler-systems/source/bin/rocprof-sys-run/impl.cpp @@ -192,6 +192,10 @@ prepare_environment_for_run(parser_data_t& _data) rocprofsys::argparse::add_ld_preload(_data); rocprofsys::argparse::add_ld_library_path(_data); } + + rocprofsys::argparse::add_torch_library_path(_data, _data.verbose > 0); + + rocprofsys::common::consolidate_env_entries(_data.current); } void diff --git a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/impl.cpp b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/impl.cpp index 6a4b27edb4..96008b2dc6 100644 --- a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/impl.cpp +++ b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/impl.cpp @@ -933,3 +933,9 @@ parse_args(int argc, char** argv, std::vector& _env) return _outv; } + +void +add_torch_library_path(std::vector& envp, const std::vector& argv) +{ + rocprofsys::common::add_torch_library_path(envp, argv, verbose > 0, updated_envs); +} diff --git a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.cpp b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.cpp index 122d02e59d..03098fa03a 100644 --- a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.cpp +++ b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.cpp @@ -51,6 +51,8 @@ main(int argc, char** argv) _argv.emplace_back(argv[i]); } + add_torch_library_path(_env, _argv); + print_updated_environment(_env); if(!_argv.empty()) diff --git a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.hpp b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.hpp index 2134bc0680..8878d92635 100644 --- a/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.hpp +++ b/projects/rocprofiler-systems/source/bin/rocprof-sys-sample/rocprof-sys-sample.hpp @@ -35,3 +35,6 @@ get_initial_environment(); std::vector parse_args(int argc, char** argv, std::vector& envp); + +void +add_torch_library_path(std::vector& envp, const std::vector& argv); diff --git a/projects/rocprofiler-systems/source/lib/common/environment.hpp b/projects/rocprofiler-systems/source/lib/common/environment.hpp index 7d2f45eceb..1ff89d355a 100644 --- a/projects/rocprofiler-systems/source/lib/common/environment.hpp +++ b/projects/rocprofiler-systems/source/lib/common/environment.hpp @@ -26,9 +26,12 @@ #include "common/join.hpp" #include +#include #include #include #include +#include +#include #include #include #include @@ -197,7 +200,7 @@ remove_env(std::vector& _environ, std::string_view _env_var, { if(match(itr)) { - free(itr); + std::free(itr); itr = nullptr; } } @@ -266,6 +269,113 @@ discover_llvm_libdir_for_ompt(bool verbose = false) return {}; } +inline bool +is_python_interpreter(std::string_view executable) +{ + if(executable.empty()) return false; + + const auto slash_pos = executable.rfind('/'); + const auto basename = (slash_pos != std::string_view::npos) + ? executable.substr(slash_pos + 1) + : executable; + + if(basename == "python" || basename == "python3") return true; + + constexpr std::string_view python3_prefix = "python3."; + + const bool has_valid_prefix = + basename.size() > python3_prefix.size() && + basename.substr(0, python3_prefix.size()) == python3_prefix; + if(!has_valid_prefix) return false; + + const auto version_digits = basename.substr(python3_prefix.size()); + + return std::all_of(version_digits.begin(), version_digits.end(), + [](unsigned char c) { return std::isdigit(c); }); +} + +inline std::string +discover_torch_libpath(const std::string& python_binary, bool verbose = false) +{ + if(python_binary.empty()) return {}; + + const auto is_safe_executable_path = [](const std::string& path) { + // Allow only a conservative set of characters in the executable path to + // avoid injection when used in a shell command. + for(unsigned char c : path) + { + if(std::isalnum(c) != 0) continue; + switch(c) + { + case '/': + case '.': + case '_': + case '-': + case '+': break; + default: return false; + } + } + return true; + }; + + if(!is_safe_executable_path(python_binary)) + { + ROCPROFSYS_ENVIRON_LOG( + verbose, "Unsafe characters detected in Python interpreter path: %s\n", + python_binary.c_str()); + return {}; + } + + const auto cmd = "\"" + python_binary + + "\" -c \"import torch; print(torch.__path__[0])\" 2>/dev/null"; + + FILE* pipe = popen(cmd.c_str(), "r"); + if(!pipe) + { + ROCPROFSYS_ENVIRON_LOG(verbose, "Failed to execute command: %s\n", cmd.c_str()); + return {}; + } + + char buffer[1024]; + std::string result; + while(fgets(buffer, sizeof(buffer), pipe)) + { + result.append(buffer); + // stop if we've read the full line (torch path is printed on a single line) + if(!result.empty() && result.back() == '\n') break; + } + + int status = pclose(pipe); + + if(status != 0 || result.empty()) + { + ROCPROFSYS_ENVIRON_LOG(verbose, "torch not found for Python interpreter: %s\n", + python_binary.c_str()); + return {}; + } + + while(!result.empty() && + (result.back() == '\n' || result.back() == '\r' || result.back() == ' ')) + { + result.pop_back(); + } + + if(result.empty()) return {}; + + std::string torch_libdir = result + "/lib"; + + if(!::tim::filepath::direxists(torch_libdir)) + { + ROCPROFSYS_ENVIRON_LOG(verbose, "torch lib directory does not exist: %s\n", + torch_libdir.c_str()); + return {}; + } + + ROCPROFSYS_ENVIRON_LOG(verbose, "Discovered torch library path: %s\n", + torch_libdir.c_str()); + return torch_libdir; +} + enum class update_mode : uint8_t { REPLACE = 0, @@ -335,7 +445,7 @@ update_env(std::vector& _environ, std::string_view _env_var, Tp&& _env_va } else { - free(itr); + std::free(itr); itr = strdup(join('=', _env_var, _env_val_str).c_str()); } return; @@ -343,5 +453,145 @@ update_env(std::vector& _environ, std::string_view _env_var, Tp&& _env_va _environ.emplace_back(strdup(join('=', _env_var, _env_val_str).c_str())); } +template +inline void +add_torch_library_path(std::vector& envp, const std::vector& argv, + bool verbose, UpdatedEnvsT& updated_envs) +{ + if(argv.empty() || argv.front() == nullptr) return; + if(!is_python_interpreter(argv.front())) return; + + auto torch_libpath = discover_torch_libpath(argv.front(), verbose); + if(torch_libpath.empty()) return; + + std::unordered_set seen{ torch_libpath }; + std::string result = torch_libpath; + + constexpr std::string_view ld_prefix = "LD_LIBRARY_PATH="; + + auto is_ld_path = [&](char* entry) { + return entry && + std::string_view{ entry }.substr(0, ld_prefix.length()) == ld_prefix; + }; + + for(auto& entry : envp) + { + if(!is_ld_path(entry)) continue; + + std::istringstream stream{ std::string{ entry + ld_prefix.length() } }; + for(std::string path; std::getline(stream, path, ':');) + { + if(!path.empty() && seen.insert(path).second) result += ":" + path; + } + + std::free(entry); + entry = nullptr; + } + + envp.erase(std::remove(envp.begin(), envp.end(), nullptr), envp.end()); + envp.emplace_back(strdup(join("", ld_prefix, result).c_str())); + + updated_envs.emplace(ld_prefix.substr(0, ld_prefix.length() - 1)); +} + +inline void +consolidate_env_entries(std::vector& envp) +{ + constexpr char delim = ':'; + + struct key_data + { + std::vector parts; + std::unordered_set seen; + + void add_unique(std::string part) + { + if(!part.empty() && seen.insert(part).second) + parts.emplace_back(std::move(part)); + } + }; + + auto parse_entry = [](std::string_view entry) + -> std::optional> { + auto eq_pos = entry.find('='); + if(eq_pos == std::string_view::npos) return std::nullopt; + return std::make_pair(entry.substr(0, eq_pos), entry.substr(eq_pos + 1)); + }; + + auto join_parts = [delim](std::string_view key, + const std::vector& parts) { + std::string result; + + const auto total_parts_length = std::accumulate( + parts.begin(), parts.end(), std::size_t{ 0 }, + [](std::size_t acc, const std::string& part) { return acc + part.size(); }); + + const auto delim_count = parts.size() - 1; + const auto equal_sign_length = 1; + + result.reserve(key.size() + equal_sign_length + total_parts_length + delim_count); + result.append(key); + result += '='; + + result = + std::accumulate(parts.begin(), parts.end(), std::move(result), + [delim, &parts](std::string acc, const std::string& part) { + if(part != parts.front()) acc += delim; + acc.append(part); + return acc; + }); + + return result; + }; + + std::unordered_map key_map; + std::vector key_order; + + for(auto* entry : envp) + { + if(!entry) + { + continue; + } + + auto parsed = parse_entry(entry); + if(!parsed) + { + continue; + } + + auto [key, value] = *parsed; + + auto [it, inserted] = key_map.try_emplace(key); + if(inserted) + { + key_order.emplace_back(key); + } + + auto& data = it->second; + std::istringstream stream{ std::string{ value } }; + for(std::string part; std::getline(stream, part, delim);) + { + data.add_unique(part); + } + } + + std::vector result; + result.reserve(key_order.size()); + + for(auto key : key_order) + { + result.emplace_back(strdup(join_parts(key, key_map[key].parts).c_str())); + } + + for(auto* entry : envp) + { + std::free(entry); + entry = nullptr; + } + + envp = std::move(result); +} + } // namespace common } // namespace rocprofsys diff --git a/projects/rocprofiler-systems/source/lib/common/tests/CMakeLists.txt b/projects/rocprofiler-systems/source/lib/common/tests/CMakeLists.txt index 79d3f4ad03..4e2391824c 100644 --- a/projects/rocprofiler-systems/source/lib/common/tests/CMakeLists.txt +++ b/projects/rocprofiler-systems/source/lib/common/tests/CMakeLists.txt @@ -24,6 +24,7 @@ add_library( lib-common-tests OBJECT test_discover_llvm_libdir.cpp + test_environment.cpp test_path.cpp test_remove_env.cpp test_update_env.cpp diff --git a/projects/rocprofiler-systems/source/lib/common/tests/test_environment.cpp b/projects/rocprofiler-systems/source/lib/common/tests/test_environment.cpp new file mode 100644 index 0000000000..7c33603aa1 --- /dev/null +++ b/projects/rocprofiler-systems/source/lib/common/tests/test_environment.cpp @@ -0,0 +1,146 @@ +// Copyright (c) Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + +#include "common/environment.hpp" + +#include + +using namespace rocprofsys::common; + +class IsPythonInterpreterTest : public ::testing::Test +{}; + +TEST_F(IsPythonInterpreterTest, RecognizesPython) +{ + EXPECT_TRUE(is_python_interpreter("python")); + EXPECT_TRUE(is_python_interpreter("python3")); + EXPECT_TRUE(is_python_interpreter("python3.8")); + EXPECT_TRUE(is_python_interpreter("python3.9")); + EXPECT_TRUE(is_python_interpreter("python3.10")); + EXPECT_TRUE(is_python_interpreter("python3.11")); + EXPECT_TRUE(is_python_interpreter("python3.12")); + EXPECT_TRUE(is_python_interpreter("/usr/bin/python")); + EXPECT_TRUE(is_python_interpreter("/usr/bin/python3")); + EXPECT_TRUE(is_python_interpreter("/usr/bin/python3.10")); + EXPECT_TRUE(is_python_interpreter("/home/user/venv/bin/python")); + EXPECT_TRUE(is_python_interpreter("/opt/conda/bin/python3.11")); + EXPECT_FALSE(is_python_interpreter("bash")); + EXPECT_FALSE(is_python_interpreter("sh")); + EXPECT_FALSE(is_python_interpreter("ruby")); + EXPECT_FALSE(is_python_interpreter("node")); + EXPECT_FALSE(is_python_interpreter("java")); + EXPECT_FALSE(is_python_interpreter("/usr/bin/bash")); + EXPECT_FALSE(is_python_interpreter("./my_app")); + EXPECT_FALSE(is_python_interpreter("pythonista")); + EXPECT_FALSE(is_python_interpreter("python_script.py")); + EXPECT_FALSE(is_python_interpreter("mypython")); + EXPECT_FALSE(is_python_interpreter("python2")); + EXPECT_FALSE(is_python_interpreter("python3.")); + EXPECT_FALSE(is_python_interpreter("python3.a")); + EXPECT_FALSE(is_python_interpreter("python3.10a")); + EXPECT_FALSE(is_python_interpreter("python3x10")); + EXPECT_FALSE(is_python_interpreter("")); + EXPECT_FALSE(is_python_interpreter("/usr/bin/")); +} + +class DuplicatedEnvironmentEntriesTest : public ::testing::Test +{}; + +TEST_F(DuplicatedEnvironmentEntriesTest, DuplicateEnvironmentEntries) +{ + std::vector env_vars = { + strdup("PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2"), + strdup("PATH=/usr/local/bin:/usr/bin:/bin"), + }; + + consolidate_env_entries(env_vars); + + ASSERT_EQ(env_vars.size(), 1); + EXPECT_STREQ(env_vars[0], "PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2"); + + for(auto* entry : env_vars) + free(entry); +} + +TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyVector) +{ + std::vector env_vars; + consolidate_env_entries(env_vars); + EXPECT_TRUE(env_vars.empty()); +} + +TEST_F(DuplicatedEnvironmentEntriesTest, HandlesNullEntries) +{ + std::vector env_vars = { + strdup("PATH=/usr/bin"), + nullptr, + strdup("PATH=/bin"), + }; + consolidate_env_entries(env_vars); + ASSERT_EQ(env_vars.size(), 1); + EXPECT_STREQ(env_vars[0], "PATH=/usr/bin:/bin"); + for(auto* entry : env_vars) + std::free(entry); +} + +TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyValues) +{ + std::vector env_vars = { + strdup("EMPTY_VAR="), + strdup("PATH=/usr/bin"), + }; + consolidate_env_entries(env_vars); + ASSERT_EQ(env_vars.size(), 2); + + for(auto* entry : env_vars) + std::free(entry); +} + +class AddTorchLibraryPathTest : public ::testing::Test +{ +protected: + std::unordered_set updated_envs; +}; + +TEST_F(AddTorchLibraryPathTest, SkipsNonPythonExecutables) +{ + std::vector envp = { + strdup("LD_LIBRARY_PATH=/usr/lib"), + }; + std::vector argv = { + strdup("/usr/bin/bash"), + }; + add_torch_library_path(envp, argv, false, updated_envs); + // Should not modify environment + ASSERT_EQ(envp.size(), 1); + EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib"); + for(auto* entry : envp) + std::free(entry); + for(auto* entry : argv) + std::free(entry); +} + +TEST_F(AddTorchLibraryPathTest, HandlesEmptyArgv) +{ + std::vector envp = { + strdup("LD_LIBRARY_PATH=/usr/lib"), + }; + std::vector argv; + add_torch_library_path(envp, argv, false, updated_envs); + ASSERT_EQ(envp.size(), 1); + EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib"); + for(auto* entry : envp) + std::free(entry); +} + +TEST_F(AddTorchLibraryPathTest, HandlesNullArgvFront) +{ + std::vector envp = { + strdup("LD_LIBRARY_PATH=/usr/lib"), + }; + std::vector argv = { nullptr }; + add_torch_library_path(envp, argv, false, updated_envs); + ASSERT_EQ(envp.size(), 1); + for(auto* entry : envp) + std::free(entry); +} diff --git a/projects/rocprofiler-systems/source/lib/core/argparse.cpp b/projects/rocprofiler-systems/source/lib/core/argparse.cpp index c4230ff807..b4ea4c41d1 100644 --- a/projects/rocprofiler-systems/source/lib/core/argparse.cpp +++ b/projects/rocprofiler-systems/source/lib/core/argparse.cpp @@ -168,6 +168,14 @@ add_ld_library_path(parser_data& _data) return _data; } +parser_data& +add_torch_library_path(parser_data& _data, bool verbose) +{ + rocprofsys::common::add_torch_library_path(_data.current, _data.command, verbose, + _data.updated); + return _data; +} + parser_data& add_core_arguments(parser_t& _parser, parser_data& _data) { diff --git a/projects/rocprofiler-systems/source/lib/core/argparse.hpp b/projects/rocprofiler-systems/source/lib/core/argparse.hpp index 9280ddd4f9..744c097b0e 100644 --- a/projects/rocprofiler-systems/source/lib/core/argparse.hpp +++ b/projects/rocprofiler-systems/source/lib/core/argparse.hpp @@ -83,6 +83,9 @@ add_ld_preload(parser_data&); parser_data& add_ld_library_path(parser_data&); +parser_data& +add_torch_library_path(parser_data&, bool verbose = false); + parser_data& add_core_arguments(parser_t&, parser_data&);