Add automatic PyTorch library discovery for Python applications (#2623)
* Add automatic PyTorch library discovery for Python applications (#2623)
此提交包含在:
@@ -192,6 +192,10 @@ prepare_environment_for_run(parser_data_t& _data)
|
||||
rocprofsys::argparse::add_ld_preload(_data);
|
||||
rocprofsys::argparse::add_ld_library_path(_data);
|
||||
}
|
||||
|
||||
rocprofsys::argparse::add_torch_library_path(_data, _data.verbose > 0);
|
||||
|
||||
rocprofsys::common::consolidate_env_entries(_data.current);
|
||||
}
|
||||
|
||||
void
|
||||
|
||||
@@ -933,3 +933,9 @@ parse_args(int argc, char** argv, std::vector<char*>& _env)
|
||||
|
||||
return _outv;
|
||||
}
|
||||
|
||||
void
|
||||
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv)
|
||||
{
|
||||
rocprofsys::common::add_torch_library_path(envp, argv, verbose > 0, updated_envs);
|
||||
}
|
||||
|
||||
@@ -51,6 +51,8 @@ main(int argc, char** argv)
|
||||
_argv.emplace_back(argv[i]);
|
||||
}
|
||||
|
||||
add_torch_library_path(_env, _argv);
|
||||
|
||||
print_updated_environment(_env);
|
||||
|
||||
if(!_argv.empty())
|
||||
|
||||
@@ -35,3 +35,6 @@ get_initial_environment();
|
||||
|
||||
std::vector<char*>
|
||||
parse_args(int argc, char** argv, std::vector<char*>& envp);
|
||||
|
||||
void
|
||||
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv);
|
||||
|
||||
@@ -26,9 +26,12 @@
|
||||
|
||||
#include "common/join.hpp"
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
@@ -197,7 +200,7 @@ remove_env(std::vector<char*>& _environ, std::string_view _env_var,
|
||||
{
|
||||
if(match(itr))
|
||||
{
|
||||
free(itr);
|
||||
std::free(itr);
|
||||
itr = nullptr;
|
||||
}
|
||||
}
|
||||
@@ -266,6 +269,113 @@ discover_llvm_libdir_for_ompt(bool verbose = false)
|
||||
return {};
|
||||
}
|
||||
|
||||
inline bool
|
||||
is_python_interpreter(std::string_view executable)
|
||||
{
|
||||
if(executable.empty()) return false;
|
||||
|
||||
const auto slash_pos = executable.rfind('/');
|
||||
const auto basename = (slash_pos != std::string_view::npos)
|
||||
? executable.substr(slash_pos + 1)
|
||||
: executable;
|
||||
|
||||
if(basename == "python" || basename == "python3") return true;
|
||||
|
||||
constexpr std::string_view python3_prefix = "python3.";
|
||||
|
||||
const bool has_valid_prefix =
|
||||
basename.size() > python3_prefix.size() &&
|
||||
basename.substr(0, python3_prefix.size()) == python3_prefix;
|
||||
if(!has_valid_prefix) return false;
|
||||
|
||||
const auto version_digits = basename.substr(python3_prefix.size());
|
||||
|
||||
return std::all_of(version_digits.begin(), version_digits.end(),
|
||||
[](unsigned char c) { return std::isdigit(c); });
|
||||
}
|
||||
|
||||
inline std::string
|
||||
discover_torch_libpath(const std::string& python_binary, bool verbose = false)
|
||||
{
|
||||
if(python_binary.empty()) return {};
|
||||
|
||||
const auto is_safe_executable_path = [](const std::string& path) {
|
||||
// Allow only a conservative set of characters in the executable path to
|
||||
// avoid injection when used in a shell command.
|
||||
for(unsigned char c : path)
|
||||
{
|
||||
if(std::isalnum(c) != 0) continue;
|
||||
switch(c)
|
||||
{
|
||||
case '/':
|
||||
case '.':
|
||||
case '_':
|
||||
case '-':
|
||||
case '+': break;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if(!is_safe_executable_path(python_binary))
|
||||
{
|
||||
ROCPROFSYS_ENVIRON_LOG(
|
||||
verbose, "Unsafe characters detected in Python interpreter path: %s\n",
|
||||
python_binary.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
const auto cmd = "\"" + python_binary +
|
||||
"\" -c \"import torch; print(torch.__path__[0])\" 2>/dev/null";
|
||||
|
||||
FILE* pipe = popen(cmd.c_str(), "r");
|
||||
if(!pipe)
|
||||
{
|
||||
ROCPROFSYS_ENVIRON_LOG(verbose, "Failed to execute command: %s\n", cmd.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
char buffer[1024];
|
||||
std::string result;
|
||||
while(fgets(buffer, sizeof(buffer), pipe))
|
||||
{
|
||||
result.append(buffer);
|
||||
// stop if we've read the full line (torch path is printed on a single line)
|
||||
if(!result.empty() && result.back() == '\n') break;
|
||||
}
|
||||
|
||||
int status = pclose(pipe);
|
||||
|
||||
if(status != 0 || result.empty())
|
||||
{
|
||||
ROCPROFSYS_ENVIRON_LOG(verbose, "torch not found for Python interpreter: %s\n",
|
||||
python_binary.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
while(!result.empty() &&
|
||||
(result.back() == '\n' || result.back() == '\r' || result.back() == ' '))
|
||||
{
|
||||
result.pop_back();
|
||||
}
|
||||
|
||||
if(result.empty()) return {};
|
||||
|
||||
std::string torch_libdir = result + "/lib";
|
||||
|
||||
if(!::tim::filepath::direxists(torch_libdir))
|
||||
{
|
||||
ROCPROFSYS_ENVIRON_LOG(verbose, "torch lib directory does not exist: %s\n",
|
||||
torch_libdir.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
ROCPROFSYS_ENVIRON_LOG(verbose, "Discovered torch library path: %s\n",
|
||||
torch_libdir.c_str());
|
||||
return torch_libdir;
|
||||
}
|
||||
|
||||
enum class update_mode : uint8_t
|
||||
{
|
||||
REPLACE = 0,
|
||||
@@ -335,7 +445,7 @@ update_env(std::vector<char*>& _environ, std::string_view _env_var, Tp&& _env_va
|
||||
}
|
||||
else
|
||||
{
|
||||
free(itr);
|
||||
std::free(itr);
|
||||
itr = strdup(join('=', _env_var, _env_val_str).c_str());
|
||||
}
|
||||
return;
|
||||
@@ -343,5 +453,145 @@ update_env(std::vector<char*>& _environ, std::string_view _env_var, Tp&& _env_va
|
||||
_environ.emplace_back(strdup(join('=', _env_var, _env_val_str).c_str()));
|
||||
}
|
||||
|
||||
template <typename UpdatedEnvsT>
|
||||
inline void
|
||||
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv,
|
||||
bool verbose, UpdatedEnvsT& updated_envs)
|
||||
{
|
||||
if(argv.empty() || argv.front() == nullptr) return;
|
||||
if(!is_python_interpreter(argv.front())) return;
|
||||
|
||||
auto torch_libpath = discover_torch_libpath(argv.front(), verbose);
|
||||
if(torch_libpath.empty()) return;
|
||||
|
||||
std::unordered_set<std::string> seen{ torch_libpath };
|
||||
std::string result = torch_libpath;
|
||||
|
||||
constexpr std::string_view ld_prefix = "LD_LIBRARY_PATH=";
|
||||
|
||||
auto is_ld_path = [&](char* entry) {
|
||||
return entry &&
|
||||
std::string_view{ entry }.substr(0, ld_prefix.length()) == ld_prefix;
|
||||
};
|
||||
|
||||
for(auto& entry : envp)
|
||||
{
|
||||
if(!is_ld_path(entry)) continue;
|
||||
|
||||
std::istringstream stream{ std::string{ entry + ld_prefix.length() } };
|
||||
for(std::string path; std::getline(stream, path, ':');)
|
||||
{
|
||||
if(!path.empty() && seen.insert(path).second) result += ":" + path;
|
||||
}
|
||||
|
||||
std::free(entry);
|
||||
entry = nullptr;
|
||||
}
|
||||
|
||||
envp.erase(std::remove(envp.begin(), envp.end(), nullptr), envp.end());
|
||||
envp.emplace_back(strdup(join("", ld_prefix, result).c_str()));
|
||||
|
||||
updated_envs.emplace(ld_prefix.substr(0, ld_prefix.length() - 1));
|
||||
}
|
||||
|
||||
inline void
|
||||
consolidate_env_entries(std::vector<char*>& envp)
|
||||
{
|
||||
constexpr char delim = ':';
|
||||
|
||||
struct key_data
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
std::unordered_set<std::string> seen;
|
||||
|
||||
void add_unique(std::string part)
|
||||
{
|
||||
if(!part.empty() && seen.insert(part).second)
|
||||
parts.emplace_back(std::move(part));
|
||||
}
|
||||
};
|
||||
|
||||
auto parse_entry = [](std::string_view entry)
|
||||
-> std::optional<std::pair<std::string_view, std::string_view>> {
|
||||
auto eq_pos = entry.find('=');
|
||||
if(eq_pos == std::string_view::npos) return std::nullopt;
|
||||
return std::make_pair(entry.substr(0, eq_pos), entry.substr(eq_pos + 1));
|
||||
};
|
||||
|
||||
auto join_parts = [delim](std::string_view key,
|
||||
const std::vector<std::string>& parts) {
|
||||
std::string result;
|
||||
|
||||
const auto total_parts_length = std::accumulate(
|
||||
parts.begin(), parts.end(), std::size_t{ 0 },
|
||||
[](std::size_t acc, const std::string& part) { return acc + part.size(); });
|
||||
|
||||
const auto delim_count = parts.size() - 1;
|
||||
const auto equal_sign_length = 1;
|
||||
|
||||
result.reserve(key.size() + equal_sign_length + total_parts_length + delim_count);
|
||||
result.append(key);
|
||||
result += '=';
|
||||
|
||||
result =
|
||||
std::accumulate(parts.begin(), parts.end(), std::move(result),
|
||||
[delim, &parts](std::string acc, const std::string& part) {
|
||||
if(part != parts.front()) acc += delim;
|
||||
acc.append(part);
|
||||
return acc;
|
||||
});
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string_view, key_data> key_map;
|
||||
std::vector<std::string_view> key_order;
|
||||
|
||||
for(auto* entry : envp)
|
||||
{
|
||||
if(!entry)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parsed = parse_entry(entry);
|
||||
if(!parsed)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto [key, value] = *parsed;
|
||||
|
||||
auto [it, inserted] = key_map.try_emplace(key);
|
||||
if(inserted)
|
||||
{
|
||||
key_order.emplace_back(key);
|
||||
}
|
||||
|
||||
auto& data = it->second;
|
||||
std::istringstream stream{ std::string{ value } };
|
||||
for(std::string part; std::getline(stream, part, delim);)
|
||||
{
|
||||
data.add_unique(part);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char*> result;
|
||||
result.reserve(key_order.size());
|
||||
|
||||
for(auto key : key_order)
|
||||
{
|
||||
result.emplace_back(strdup(join_parts(key, key_map[key].parts).c_str()));
|
||||
}
|
||||
|
||||
for(auto* entry : envp)
|
||||
{
|
||||
std::free(entry);
|
||||
entry = nullptr;
|
||||
}
|
||||
|
||||
envp = std::move(result);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace rocprofsys
|
||||
|
||||
@@ -24,6 +24,7 @@ add_library(
|
||||
lib-common-tests
|
||||
OBJECT
|
||||
test_discover_llvm_libdir.cpp
|
||||
test_environment.cpp
|
||||
test_path.cpp
|
||||
test_remove_env.cpp
|
||||
test_update_env.cpp
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common/environment.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace rocprofsys::common;
|
||||
|
||||
class IsPythonInterpreterTest : public ::testing::Test
|
||||
{};
|
||||
|
||||
TEST_F(IsPythonInterpreterTest, RecognizesPython)
|
||||
{
|
||||
EXPECT_TRUE(is_python_interpreter("python"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3.8"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3.9"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3.10"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3.11"));
|
||||
EXPECT_TRUE(is_python_interpreter("python3.12"));
|
||||
EXPECT_TRUE(is_python_interpreter("/usr/bin/python"));
|
||||
EXPECT_TRUE(is_python_interpreter("/usr/bin/python3"));
|
||||
EXPECT_TRUE(is_python_interpreter("/usr/bin/python3.10"));
|
||||
EXPECT_TRUE(is_python_interpreter("/home/user/venv/bin/python"));
|
||||
EXPECT_TRUE(is_python_interpreter("/opt/conda/bin/python3.11"));
|
||||
EXPECT_FALSE(is_python_interpreter("bash"));
|
||||
EXPECT_FALSE(is_python_interpreter("sh"));
|
||||
EXPECT_FALSE(is_python_interpreter("ruby"));
|
||||
EXPECT_FALSE(is_python_interpreter("node"));
|
||||
EXPECT_FALSE(is_python_interpreter("java"));
|
||||
EXPECT_FALSE(is_python_interpreter("/usr/bin/bash"));
|
||||
EXPECT_FALSE(is_python_interpreter("./my_app"));
|
||||
EXPECT_FALSE(is_python_interpreter("pythonista"));
|
||||
EXPECT_FALSE(is_python_interpreter("python_script.py"));
|
||||
EXPECT_FALSE(is_python_interpreter("mypython"));
|
||||
EXPECT_FALSE(is_python_interpreter("python2"));
|
||||
EXPECT_FALSE(is_python_interpreter("python3."));
|
||||
EXPECT_FALSE(is_python_interpreter("python3.a"));
|
||||
EXPECT_FALSE(is_python_interpreter("python3.10a"));
|
||||
EXPECT_FALSE(is_python_interpreter("python3x10"));
|
||||
EXPECT_FALSE(is_python_interpreter(""));
|
||||
EXPECT_FALSE(is_python_interpreter("/usr/bin/"));
|
||||
}
|
||||
|
||||
class DuplicatedEnvironmentEntriesTest : public ::testing::Test
|
||||
{};
|
||||
|
||||
TEST_F(DuplicatedEnvironmentEntriesTest, DuplicateEnvironmentEntries)
|
||||
{
|
||||
std::vector<char*> env_vars = {
|
||||
strdup("PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2"),
|
||||
strdup("PATH=/usr/local/bin:/usr/bin:/bin"),
|
||||
};
|
||||
|
||||
consolidate_env_entries(env_vars);
|
||||
|
||||
ASSERT_EQ(env_vars.size(), 1);
|
||||
EXPECT_STREQ(env_vars[0], "PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/bin2");
|
||||
|
||||
for(auto* entry : env_vars)
|
||||
free(entry);
|
||||
}
|
||||
|
||||
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyVector)
|
||||
{
|
||||
std::vector<char*> env_vars;
|
||||
consolidate_env_entries(env_vars);
|
||||
EXPECT_TRUE(env_vars.empty());
|
||||
}
|
||||
|
||||
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesNullEntries)
|
||||
{
|
||||
std::vector<char*> env_vars = {
|
||||
strdup("PATH=/usr/bin"),
|
||||
nullptr,
|
||||
strdup("PATH=/bin"),
|
||||
};
|
||||
consolidate_env_entries(env_vars);
|
||||
ASSERT_EQ(env_vars.size(), 1);
|
||||
EXPECT_STREQ(env_vars[0], "PATH=/usr/bin:/bin");
|
||||
for(auto* entry : env_vars)
|
||||
std::free(entry);
|
||||
}
|
||||
|
||||
TEST_F(DuplicatedEnvironmentEntriesTest, HandlesEmptyValues)
|
||||
{
|
||||
std::vector<char*> env_vars = {
|
||||
strdup("EMPTY_VAR="),
|
||||
strdup("PATH=/usr/bin"),
|
||||
};
|
||||
consolidate_env_entries(env_vars);
|
||||
ASSERT_EQ(env_vars.size(), 2);
|
||||
|
||||
for(auto* entry : env_vars)
|
||||
std::free(entry);
|
||||
}
|
||||
|
||||
class AddTorchLibraryPathTest : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
std::unordered_set<std::string> updated_envs;
|
||||
};
|
||||
|
||||
TEST_F(AddTorchLibraryPathTest, SkipsNonPythonExecutables)
|
||||
{
|
||||
std::vector<char*> envp = {
|
||||
strdup("LD_LIBRARY_PATH=/usr/lib"),
|
||||
};
|
||||
std::vector<char*> argv = {
|
||||
strdup("/usr/bin/bash"),
|
||||
};
|
||||
add_torch_library_path(envp, argv, false, updated_envs);
|
||||
// Should not modify environment
|
||||
ASSERT_EQ(envp.size(), 1);
|
||||
EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib");
|
||||
for(auto* entry : envp)
|
||||
std::free(entry);
|
||||
for(auto* entry : argv)
|
||||
std::free(entry);
|
||||
}
|
||||
|
||||
TEST_F(AddTorchLibraryPathTest, HandlesEmptyArgv)
|
||||
{
|
||||
std::vector<char*> envp = {
|
||||
strdup("LD_LIBRARY_PATH=/usr/lib"),
|
||||
};
|
||||
std::vector<char*> argv;
|
||||
add_torch_library_path(envp, argv, false, updated_envs);
|
||||
ASSERT_EQ(envp.size(), 1);
|
||||
EXPECT_STREQ(envp[0], "LD_LIBRARY_PATH=/usr/lib");
|
||||
for(auto* entry : envp)
|
||||
std::free(entry);
|
||||
}
|
||||
|
||||
TEST_F(AddTorchLibraryPathTest, HandlesNullArgvFront)
|
||||
{
|
||||
std::vector<char*> envp = {
|
||||
strdup("LD_LIBRARY_PATH=/usr/lib"),
|
||||
};
|
||||
std::vector<char*> argv = { nullptr };
|
||||
add_torch_library_path(envp, argv, false, updated_envs);
|
||||
ASSERT_EQ(envp.size(), 1);
|
||||
for(auto* entry : envp)
|
||||
std::free(entry);
|
||||
}
|
||||
@@ -168,6 +168,14 @@ add_ld_library_path(parser_data& _data)
|
||||
return _data;
|
||||
}
|
||||
|
||||
parser_data&
|
||||
add_torch_library_path(parser_data& _data, bool verbose)
|
||||
{
|
||||
rocprofsys::common::add_torch_library_path(_data.current, _data.command, verbose,
|
||||
_data.updated);
|
||||
return _data;
|
||||
}
|
||||
|
||||
parser_data&
|
||||
add_core_arguments(parser_t& _parser, parser_data& _data)
|
||||
{
|
||||
|
||||
@@ -83,6 +83,9 @@ add_ld_preload(parser_data&);
|
||||
parser_data&
|
||||
add_ld_library_path(parser_data&);
|
||||
|
||||
parser_data&
|
||||
add_torch_library_path(parser_data&, bool verbose = false);
|
||||
|
||||
parser_data&
|
||||
add_core_arguments(parser_t&, parser_data&);
|
||||
|
||||
|
||||
新增問題並參考
封鎖使用者