Add automatic PyTorch library discovery for Python applications (#2623)

* Add automatic PyTorch library discovery for Python applications (#2623)
This commit is contained in:
Milan Radosavljevic
2026-01-20 08:42:49 +01:00
کامیت شده توسط GitHub
والد c83b3aae07
کامیت b533f56197
9فایلهای تغییر یافته به همراه425 افزوده شده و 2 حذف شده
@@ -192,6 +192,10 @@ prepare_environment_for_run(parser_data_t& _data)
rocprofsys::argparse::add_ld_preload(_data);
rocprofsys::argparse::add_ld_library_path(_data);
}
rocprofsys::argparse::add_torch_library_path(_data, _data.verbose > 0);
rocprofsys::common::consolidate_env_entries(_data.current);
}
void
@@ -933,3 +933,9 @@ parse_args(int argc, char** argv, std::vector<char*>& _env)
return _outv;
}
void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv)
{
rocprofsys::common::add_torch_library_path(envp, argv, verbose > 0, updated_envs);
}
@@ -51,6 +51,8 @@ main(int argc, char** argv)
_argv.emplace_back(argv[i]);
}
add_torch_library_path(_env, _argv);
print_updated_environment(_env);
if(!_argv.empty())
@@ -35,3 +35,6 @@ get_initial_environment();
std::vector<char*>
parse_args(int argc, char** argv, std::vector<char*>& envp);
void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv);
@@ -26,9 +26,12 @@
#include "common/join.hpp"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
@@ -197,7 +200,7 @@ remove_env(std::vector<char*>& _environ, std::string_view _env_var,
{
if(match(itr))
{
free(itr);
std::free(itr);
itr = nullptr;
}
}
@@ -266,6 +269,113 @@ discover_llvm_libdir_for_ompt(bool verbose = false)
return {};
}
inline bool
is_python_interpreter(std::string_view executable)
{
if(executable.empty()) return false;
const auto slash_pos = executable.rfind('/');
const auto basename = (slash_pos != std::string_view::npos)
? executable.substr(slash_pos + 1)
: executable;
if(basename == "python" || basename == "python3") return true;
constexpr std::string_view python3_prefix = "python3.";
const bool has_valid_prefix =
basename.size() > python3_prefix.size() &&
basename.substr(0, python3_prefix.size()) == python3_prefix;
if(!has_valid_prefix) return false;
const auto version_digits = basename.substr(python3_prefix.size());
return std::all_of(version_digits.begin(), version_digits.end(),
[](unsigned char c) { return std::isdigit(c); });
}
inline std::string
discover_torch_libpath(const std::string& python_binary, bool verbose = false)
{
if(python_binary.empty()) return {};
const auto is_safe_executable_path = [](const std::string& path) {
// Allow only a conservative set of characters in the executable path to
// avoid injection when used in a shell command.
for(unsigned char c : path)
{
if(std::isalnum(c) != 0) continue;
switch(c)
{
case '/':
case '.':
case '_':
case '-':
case '+': break;
default: return false;
}
}
return true;
};
if(!is_safe_executable_path(python_binary))
{
ROCPROFSYS_ENVIRON_LOG(
verbose, "Unsafe characters detected in Python interpreter path: %s\n",
python_binary.c_str());
return {};
}
const auto cmd = "\"" + python_binary +
"\" -c \"import torch; print(torch.__path__[0])\" 2>/dev/null";
FILE* pipe = popen(cmd.c_str(), "r");
if(!pipe)
{
ROCPROFSYS_ENVIRON_LOG(verbose, "Failed to execute command: %s\n", cmd.c_str());
return {};
}
char buffer[1024];
std::string result;
while(fgets(buffer, sizeof(buffer), pipe))
{
result.append(buffer);
// stop if we've read the full line (torch path is printed on a single line)
if(!result.empty() && result.back() == '\n') break;
}
int status = pclose(pipe);
if(status != 0 || result.empty())
{
ROCPROFSYS_ENVIRON_LOG(verbose, "torch not found for Python interpreter: %s\n",
python_binary.c_str());
return {};
}
while(!result.empty() &&
(result.back() == '\n' || result.back() == '\r' || result.back() == ' '))
{
result.pop_back();
}
if(result.empty()) return {};
std::string torch_libdir = result + "/lib";
if(!::tim::filepath::direxists(torch_libdir))
{
ROCPROFSYS_ENVIRON_LOG(verbose, "torch lib directory does not exist: %s\n",
torch_libdir.c_str());
return {};
}
ROCPROFSYS_ENVIRON_LOG(verbose, "Discovered torch library path: %s\n",
torch_libdir.c_str());
return torch_libdir;
}
enum class update_mode : uint8_t
{
REPLACE = 0,
@@ -335,7 +445,7 @@ update_env(std::vector<char*>& _environ, std::string_view _env_var, Tp&& _env_va
}
else
{
free(itr);
std::free(itr);
itr = strdup(join('=', _env_var, _env_val_str).c_str());
}
return;
@@ -343,5 +453,145 @@ update_env(std::vector<char*>& _environ, std::string_view _env_var, Tp&& _env_va
_environ.emplace_back(strdup(join('=', _env_var, _env_val_str).c_str()));
}
template <typename UpdatedEnvsT>
inline void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv,
bool verbose, UpdatedEnvsT& updated_envs)
{
if(argv.empty() || argv.front() == nullptr) return;
if(!is_python_interpreter(argv.front())) return;
auto torch_libpath = discover_torch_libpath(argv.front(), verbose);
if(torch_libpath.empty()) return;
std::unordered_set<std::string> seen{ torch_libpath };
std::string result = torch_libpath;
constexpr std::string_view ld_prefix = "LD_LIBRARY_PATH=";
auto is_ld_path = [&](char* entry) {
return entry &&
std::string_view{ entry }.substr(0, ld_prefix.length()) == ld_prefix;
};
for(auto& entry : envp)
{
if(!is_ld_path(entry)) continue;
std::istringstream stream{ std::string{ entry + ld_prefix.length() } };
for(std::string path; std::getline(stream, path, ':');)
{
if(!path.empty() && seen.insert(path).second) result += ":" + path;
}
std::free(entry);
entry = nullptr;
}
envp.erase(std::remove(envp.begin(), envp.end(), nullptr), envp.end());
envp.emplace_back(strdup(join("", ld_prefix, result).c_str()));
updated_envs.emplace(ld_prefix.substr(0, ld_prefix.length() - 1));
}
inline void
consolidate_env_entries(std::vector<char*>& envp)
{
constexpr char delim = ':';
struct key_data
{
std::vector<std::string> parts;
std::unordered_set<std::string> seen;
void add_unique(std::string part)
{
if(!part.empty() && seen.insert(part).second)
parts.emplace_back(std::move(part));
}
};
auto parse_entry = [](std::string_view entry)
-> std::optional<std::pair<std::string_view, std::string_view>> {
auto eq_pos = entry.find('=');
if(eq_pos == std::string_view::npos) return std::nullopt;
return std::make_pair(entry.substr(0, eq_pos), entry.substr(eq_pos + 1));
};
auto join_parts = [delim](std::string_view key,
const std::vector<std::string>& parts) {
std::string result;
const auto total_parts_length = std::accumulate(
parts.begin(), parts.end(), std::size_t{ 0 },
[](std::size_t acc, const std::string& part) { return acc + part.size(); });
const auto delim_count = parts.size() - 1;
const auto equal_sign_length = 1;
result.reserve(key.size() + equal_sign_length + total_parts_length + delim_count);
result.append(key);
result += '=';
result =
std::accumulate(parts.begin(), parts.end(), std::move(result),
[delim, &parts](std::string acc, const std::string& part) {
if(part != parts.front()) acc += delim;
acc.append(part);
return acc;
});
return result;
};
std::unordered_map<std::string_view, key_data> key_map;
std::vector<std::string_view> key_order;
for(auto* entry : envp)
{
if(!entry)
{
continue;
}
auto parsed = parse_entry(entry);
if(!parsed)
{
continue;
}
auto [key, value] = *parsed;
auto [it, inserted] = key_map.try_emplace(key);
if(inserted)
{
key_order.emplace_back(key);
}
auto& data = it->second;
std::istringstream stream{ std::string{ value } };
for(std::string part; std::getline(stream, part, delim);)
{
data.add_unique(part);
}
}
std::vector<char*> result;
result.reserve(key_order.size());
for(auto key : key_order)
{
result.emplace_back(strdup(join_parts(key, key_map[key].parts).c_str()));
}
for(auto* entry : envp)
{
std::free(entry);
entry = nullptr;
}
envp = std::move(result);
}
} // namespace common
} // namespace rocprofsys
@@ -24,6 +24,7 @@ add_library(
lib-common-tests
OBJECT
test_discover_llvm_libdir.cpp
test_environment.cpp
test_path.cpp
test_remove_env.cpp
test_update_env.cpp
@@ -0,0 +1,146 @@
// Copyright (c) Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT
#include "common/environment.hpp"
#include <gtest/gtest.h>
using namespace rocprofsys::common;
class IsPythonInterpreterTest : public ::testing::Test
{};
TEST_F(IsPythonInterpreterTest, RecognizesPython)
{
EXPECT_TRUE(is_python_interpreter("python"));
EXPECT_TRUE(is_python_interpreter("python3"));
EXPECT_TRUE(is_python_interpreter("python3.8"));
EXPECT_TRUE(is_python_interpreter("python3.9"));
EXPECT_TRUE(is_python_interpreter("python3.10"));
EXPECT_TRUE(is_python_interpreter("python3.11"));
EXPECT_TRUE(is_python_interpreter("python3.12"));
EXPECT_TRUE(is_python_interpreter("/usr/bin/python"));
EXPECT_TRUE(is_python_interpreter("/usr/bin/python3"));
EXPECT_TRUE(is_python_interpreter("/usr/bin/python3.10"));
EXPECT_TRUE(is_python_interpreter("/home/user/venv/bin/python"));
EXPECT_TRUE(is_python_interpreter("/opt/conda/bin/python3.11"));
EXPECT_FALSE(is_python_interpreter("bash"));
EXPECT_FALSE(is_python_interpreter("sh"));
EXPECT_FALSE(is_python_interpreter("ruby"));
EXPECT_FALSE(is_python_interpreter("node"));
EXPECT_FALSE(is_python_interpreter("java"));
EXPECT_FALSE(is_python_interpreter("/usr/bin/bash"));
EXPECT_FALSE(is_python_interpreter("./my_app"));
EXPECT_FALSE(is_python_interpreter("pythonista"));
EXPECT_FALSE(is_python_interpreter("python_script.py"));
EXPECT_FALSE(is_python_interpreter("mypython"));
EXPECT_FALSE(is_python_interpreter("python2"));
EXPECT_FALSE(is_python_interpreter("python3."));
EXPECT_FALSE(is_python_interpreter("python3.a"));
EXPECT_FALSE(is_python_interpreter("python3.10a"));
EXPECT_FALSE(is_python_interpreter("python3x10"));
EXPECT_FALSE(is_python_interpreter(""));
EXPECT_FALSE(is_python_interpreter("/usr/bin/"));
}
class DuplicatedEnvironmentEntriesTest : public ::testing::Test
{};
TEST_F(DuplicatedEnvironmentEntriesTest, DuplicateEnvironmentEntries)
{
std::vector<char*> env_vars = {
strdup("PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2"),
strdup("PATH=/usr/local/bin:/usr/bin:/bin"),
};
consolidate_env_entries(env_vars);
ASSERT_EQ(env_vars.size(), 1);
EXPECT_STREQ(env_vars[0], "PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2");
for(auto* entry : env_vars)
free(entry);
}
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyVector)
{
std::vector<char*> env_vars;
consolidate_env_entries(env_vars);
EXPECT_TRUE(env_vars.empty());
}
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesNullEntries)
{
std::vector<char*> env_vars = {
strdup("PATH=/usr/bin"),
nullptr,
strdup("PATH=/bin"),
};
consolidate_env_entries(env_vars);
ASSERT_EQ(env_vars.size(), 1);
EXPECT_STREQ(env_vars[0], "PATH=/usr/bin:/bin");
for(auto* entry : env_vars)
std::free(entry);
}
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyValues)
{
std::vector<char*> env_vars = {
strdup("EMPTY_VAR="),
strdup("PATH=/usr/bin"),
};
consolidate_env_entries(env_vars);
ASSERT_EQ(env_vars.size(), 2);
for(auto* entry : env_vars)
std::free(entry);
}
class AddTorchLibraryPathTest : public ::testing::Test
{
protected:
std::unordered_set<std::string> updated_envs;
};
TEST_F(AddTorchLibraryPathTest, SkipsNonPythonExecutables)
{
std::vector<char*> envp = {
strdup("LD_LIBRARY_PATH=/usr/lib"),
};
std::vector<char*> argv = {
strdup("/usr/bin/bash"),
};
add_torch_library_path(envp, argv, false, updated_envs);
// Should not modify environment
ASSERT_EQ(envp.size(), 1);
EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib");
for(auto* entry : envp)
std::free(entry);
for(auto* entry : argv)
std::free(entry);
}
TEST_F(AddTorchLibraryPathTest, HandlesEmptyArgv)
{
std::vector<char*> envp = {
strdup("LD_LIBRARY_PATH=/usr/lib"),
};
std::vector<char*> argv;
add_torch_library_path(envp, argv, false, updated_envs);
ASSERT_EQ(envp.size(), 1);
EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib");
for(auto* entry : envp)
std::free(entry);
}
TEST_F(AddTorchLibraryPathTest, HandlesNullArgvFront)
{
std::vector<char*> envp = {
strdup("LD_LIBRARY_PATH=/usr/lib"),
};
std::vector<char*> argv = { nullptr };
add_torch_library_path(envp, argv, false, updated_envs);
ASSERT_EQ(envp.size(), 1);
for(auto* entry : envp)
std::free(entry);
}
@@ -168,6 +168,14 @@ add_ld_library_path(parser_data& _data)
return _data;
}
parser_data&
add_torch_library_path(parser_data& _data, bool verbose)
{
rocprofsys::common::add_torch_library_path(_data.current, _data.command, verbose,
_data.updated);
return _data;
}
parser_data&
add_core_arguments(parser_t& _parser, parser_data& _data)
{
@@ -83,6 +83,9 @@ add_ld_preload(parser_data&);
parser_data&
add_ld_library_path(parser_data&);
parser_data&
add_torch_library_path(parser_data&, bool verbose = false);
parser_data&
add_core_arguments(parser_t&, parser_data&);