From 13c0bf8e2b5280ac7e40aaa7d262b565f6bdf40f Mon Sep 17 00:00:00 2001 From: Mythreya Date: Wed, 11 Sep 2024 15:57:51 -0700 Subject: [PATCH] Add rccl API support (#66) * [Draft]: Add rccl API support * Partial tests Need to add tests to the cmake file [ROCm/rocprofiler-register commit: b71f9cabe6055386d2a2c6d669d4574bbd0e2ad0] --- .../rocprofiler_register.cpp | 6 + .../rocprofiler-register/tests/CMakeLists.txt | 1 + .../rocprofiler-register/tests/common/fwd.h | 8 + .../rocprofiler-register/tests/common/fwd.hpp | 10 ++ .../tests/rccl/CMakeLists.txt | 33 ++++ .../rocprofiler-register/tests/rccl/rccl.cpp | 149 ++++++++++++++++++ .../rocprofiler-register/tests/rccl/rccl.hpp | 46 ++++++ .../tests/rocprofiler/rocprofiler.cpp | 14 ++ 8 files changed, 267 insertions(+) create mode 100644 projects/rocprofiler-register/tests/rccl/CMakeLists.txt create mode 100644 projects/rocprofiler-register/tests/rccl/rccl.cpp create mode 100644 projects/rocprofiler-register/tests/rccl/rccl.hpp diff --git a/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp b/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp index 7c09d2f09b..3cabb87ab0 100644 --- a/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp +++ b/projects/rocprofiler-register/source/lib/rocprofiler-register/rocprofiler_register.cpp @@ -105,6 +105,7 @@ enum rocp_reg_supported_library // NOLINT(performance-enum-size) ROCP_REG_HIP, ROCP_REG_ROCTX, ROCP_REG_HIP_COMPILER, + ROCP_REG_RCCL, ROCP_REG_LAST, }; @@ -159,6 +160,11 @@ ROCP_REG_DEFINE_LIBRARY_TRAITS(ROCP_REG_HIP_COMPILER, "rocprofiler_register_import_hip_compiler", "libamdhip64.so.[6-9]($|\\.[0-9\\.]+)") +ROCP_REG_DEFINE_LIBRARY_TRAITS(ROCP_REG_RCCL, + "rccl", + "rocprofiler_register_import_rccl", + "librccl.so.[6-9]($|\\.[0-9\\.]+)") + ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_SUCCESS, "Success") ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_NO_TOOLS, "rocprofiler-register found no tools") ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_DEADLOCK, "rocprofiler-register deadlocked") diff --git a/projects/rocprofiler-register/tests/CMakeLists.txt b/projects/rocprofiler-register/tests/CMakeLists.txt index 1349669016..36aca52be9 100644 --- a/projects/rocprofiler-register/tests/CMakeLists.txt +++ b/projects/rocprofiler-register/tests/CMakeLists.txt @@ -71,6 +71,7 @@ endif() add_subdirectory(hsa-runtime) add_subdirectory(amdhip) add_subdirectory(roctx) +add_subdirectory(rccl) add_subdirectory(rocprofiler) # diff --git a/projects/rocprofiler-register/tests/common/fwd.h b/projects/rocprofiler-register/tests/common/fwd.h index 571c66c0d4..5d6dcb21c1 100644 --- a/projects/rocprofiler-register/tests/common/fwd.h +++ b/projects/rocprofiler-register/tests/common/fwd.h @@ -12,6 +12,7 @@ extern "C" { # pragma weak hsa_init # pragma weak roctxRangePush # pragma weak roctxRangePop +# pragma weak ncclGetVersion #endif extern void @@ -26,6 +27,13 @@ roctxRangePush(const char*); extern void roctxRangePop(const char*); +enum ncclResult_t +{ +}; + +extern ncclResult_t +ncclGetVersion(int* version); + #ifdef __cplusplus } #endif diff --git a/projects/rocprofiler-register/tests/common/fwd.hpp b/projects/rocprofiler-register/tests/common/fwd.hpp index 6438a2a77b..c9956568e2 100644 --- a/projects/rocprofiler-register/tests/common/fwd.hpp +++ b/projects/rocprofiler-register/tests/common/fwd.hpp @@ -18,6 +18,7 @@ namespace { decltype(hip_init)* hip_init_fn = nullptr; decltype(hsa_init)* hsa_init_fn = nullptr; +decltype(ncclGetVersion)* ncclGetVersion_fn = nullptr; decltype(roctxRangePush)* roctxRangePush_fn = nullptr; decltype(roctxRangePush)* roctxRangePop_fn = nullptr; @@ -27,6 +28,7 @@ enum rocp_reg_test_modes : uint8_t ROCP_REG_TEST_HIP = (1 << 0), ROCP_REG_TEST_HSA = (1 << 1), ROCP_REG_TEST_ROCTX = (1 << 2), + ROCP_REG_TEST_RCCL = (1 << 3), }; template @@ -73,6 +75,7 @@ resolve_symbols(int _open_mode = RTLD_LOCAL | RTLD_LAZY) void* amdhip_handle = nullptr; void* hsart_handle = nullptr; void* roctx_handle = nullptr; + void* rccl_handle = nullptr; if constexpr((Idx & ROCP_REG_TEST_HIP) == ROCP_REG_TEST_HIP) { @@ -97,5 +100,12 @@ resolve_symbols(int _open_mode = RTLD_LOCAL | RTLD_LAZY) _resolve_dlsym(roctxRangePush_fn, roctx_handle, "roctxRangePush"); _resolve_dlsym(roctxRangePop_fn, roctx_handle, "roctxRangePop"); } + + if constexpr((Idx & ROCP_REG_TEST_RCCL) == ROCP_REG_TEST_RCCL) + { + ncclGetVersion_fn = ncclGetVersion; + if(!ncclGetVersion_fn) _resolve_dlopen(rccl_handle, "librccl.so"); + _resolve_dlsym(ncclGetVersion_fn, rccl_handle, "ncclGetVersion"); + } } } // namespace diff --git a/projects/rocprofiler-register/tests/rccl/CMakeLists.txt b/projects/rocprofiler-register/tests/rccl/CMakeLists.txt new file mode 100644 index 0000000000..b9764ac0b3 --- /dev/null +++ b/projects/rocprofiler-register/tests/rccl/CMakeLists.txt @@ -0,0 +1,33 @@ +# +# +# + +if(NOT TARGET rocprofiler-register::rocprofiler-register) + # find_package(rocprofiler-register REQUIRED) +endif() + +add_library(rccl SHARED) +add_library(rccl::rccl ALIAS rccl) +target_sources(rccl PRIVATE rccl.cpp rccl.hpp) +target_include_directories(rccl PUBLIC $) +target_link_libraries(rccl PRIVATE rocprofiler-register::rocprofiler-register) +set_target_properties( + rccl + PROPERTIES OUTPUT_NAME rccl + SOVERSION 1 + VERSION 1.0) +rocp_register_strip_target(rccl) + +add_library(rccl-invalid SHARED) +add_library(rccl::rccl-invalid ALIAS rccl-invalid) +target_sources(rccl-invalid PRIVATE rccl.cpp rccl.hpp) +target_include_directories(rccl-invalid + PUBLIC $) +target_link_libraries(rccl-invalid PRIVATE rocprofiler-register::rocprofiler-register) +set_target_properties( + rccl-invalid + PROPERTIES OUTPUT_NAME rccl + SOVERSION 1 + VERSION 1.0 + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/invalid) +rocp_register_strip_target(rccl-invalid) diff --git a/projects/rocprofiler-register/tests/rccl/rccl.cpp b/projects/rocprofiler-register/tests/rccl/rccl.cpp new file mode 100644 index 0000000000..3607888267 --- /dev/null +++ b/projects/rocprofiler-register/tests/rccl/rccl.cpp @@ -0,0 +1,149 @@ +#include "rccl.hpp" + +#include + +#include +#include +#include +#include + +#define ROCP_REG_VERSION \ + ROCPROFILER_REGISTER_COMPUTE_VERSION_2(RCCL_API_TRACE_VERSION_MAJOR, \ + RCCL_API_TRACE_VERSION_PATCH) + +ROCPROFILER_REGISTER_DEFINE_IMPORT(rccl, ROCP_REG_VERSION) + +#ifndef ROCP_REG_FILE_NAME +# define ROCP_REG_FILE_NAME \ + ::std::string{ __FILE__ } \ + .substr(::std::string_view{ __FILE__ }.find_last_of('/') + 1) \ + .c_str() +#endif + +namespace rccl +{ +namespace +{ +auto& +get_rccl_api_table_impl() +{ + static auto _table = std::atomic{ nullptr }; + return _table; +} + +void +register_profiler_impl() +{ + static auto _const_api_table = rcclApiFuncTable{}; + initialize_rccl_api_table(&_const_api_table); + + // set this before any recursive opportunity arises + get_rccl_api_table_impl().exchange(&_const_api_table); + + // create a copy of the api table for modification by registration + static auto _profiler_api_table = rcclApiFuncTable{}; + copy_rccl_api_table(&_profiler_api_table, &_const_api_table); + + void* _profiler_api_table_v = static_cast(&_profiler_api_table); + + auto lib_id = rocprofiler_register_library_indentifier_t{}; + auto success = + rocprofiler_register_library_api_table("rccl", + &ROCPROFILER_REGISTER_IMPORT_FUNC(rccl), + ROCP_REG_VERSION, + &_profiler_api_table_v, + 1, + &lib_id); + + if(success == 0) + { + printf("[%s] rccl identifier %lu\n", ROCP_REG_FILE_NAME, lib_id.handle); + auto* _api_table = &_const_api_table; + if(!get_rccl_api_table_impl().compare_exchange_strong(_api_table, + &_profiler_api_table)) + { + // with the current impl, if we ever get here, someone is calling one the + // functions in this anonymous namespace that shouldn't + std::cerr + << "register_profiler_impl expected the API table to be the internal " + "implementation and yet it is not. something went wrong.\n"; + abort(); + } + } + else if(success != ROCP_REG_NO_TOOLS) + { + std::cerr << "rccl library failed to register with rocprofiler-register: " + << rocprofiler_register_error_string(success) << "\n"; + exit(EXIT_FAILURE); + } +} + +void +register_profiler() +{ + // this registration scheme is designed to minimize overhead once + // registered (only pay cost of checking atomic boolean) + // once the profiler is registered. If the library has not + // been registered and two or more threads try to register concurrently + // the first thread to acquire the lock below, will block the + // threads until registration is complete. However, + // if the same thread performing the registration re-enters this function + // i.e. this library's API is called during registration, this function + // will prevent a deadlock by not attempting to re-enter the + // the call-once and not releasing any waiting threads by flipping + // the _is_registered field to true. + static auto _is_registered = std::atomic{ false }; + + if(!_is_registered.load(std::memory_order_acquire)) + { + using mutex_t = std::recursive_mutex; + using auto_lock_t = std::unique_lock; + static auto _once = std::once_flag{}; + static auto _mutex = mutex_t{}; + + // defer the lock so we can check for recursion + auto _lk = auto_lock_t{ _mutex, std::defer_lock }; + + // this will be true if the same thread currently executing the call_once invokes + // the library's API while registering the profiler (e.g. tool which wants to + // instrument rccl API invokes a rccl function while registering with the + // profiler) we allow this thread to proceed and access the "const" API table but + // return so it does not flip _is_registered to true, which would result + // in any subsequent threads not waiting until the library is fully registered, + // resulting in missed callbacks for the tools + if(_lk.owns_lock()) return; + + // ensures any subsequent threads wait until the first thread + // finishes registration + _lk.lock(); + // call_once to ensure that we only register once + std::call_once(_once, register_profiler_impl); + // the first thread has completed registration and all + // threads waiting on lock will be released and this + // block will not be entered again + _is_registered.exchange(true, std::memory_order_release); + } +} +} // namespace + +rcclApiFuncTable* +get_rccl_api_table() +{ + register_profiler(); + return get_rccl_api_table_impl().load(std::memory_order_relaxed); +} + +void +rccl_init() +{ + printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__); +} +} // namespace rccl + +extern "C" { +void +rccl_init(void) +{ + rccl::get_rccl_api_table()->ncclGetVersion_fn({}); +} +} diff --git a/projects/rocprofiler-register/tests/rccl/rccl.hpp b/projects/rocprofiler-register/tests/rccl/rccl.hpp new file mode 100644 index 0000000000..d220f1da13 --- /dev/null +++ b/projects/rocprofiler-register/tests/rccl/rccl.hpp @@ -0,0 +1,46 @@ +#pragma once + +#define RCCL_API_TRACE_VERSION_MAJOR 0 +#define RCCL_API_TRACE_VERSION_PATCH 0 + +#include +#include + +extern "C" { +// fake rccl function +typedef int ncclResult_t; + +enum ncclDataType_t +{ +}; + +ncclResult_t +ncclGetVersion(int* version) __attribute__((visibility("default"))); +} + +namespace rccl +{ +struct rcclApiFuncTable +{ + uint64_t size = 0; + decltype(::ncclGetVersion)* ncclGetVersion_fn = nullptr; +}; + +ncclResult_t +ncclGetVersion(int* version); + +// populates rccl api table with function pointers +inline void +initialize_rccl_api_table(rcclApiFuncTable* dst) +{ + dst->size = sizeof(rcclApiFuncTable); + dst->ncclGetVersion_fn = &::rccl::ncclGetVersion; +} + +// copies the api table from src to dst +inline void +copy_rccl_api_table(rcclApiFuncTable* dst, const rcclApiFuncTable* src) +{ + *dst = *src; +} +} // namespace rccl diff --git a/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp b/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp index 519cb7ebee..087c7ead30 100644 --- a/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp +++ b/projects/rocprofiler-register/tests/rocprofiler/rocprofiler.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -29,6 +30,13 @@ hsa_init() printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__); } +ncclResult_t +ncclGetVersion(int*) +{ + printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__); + return {}; +} + void roctx_range_push(const char* name) { @@ -78,6 +86,7 @@ rocprofiler_set_api_table(const char* name, using hip_table_t = hip::HipApiTable; using hsa_table_t = hsa::HsaApiTable; using roctx_table_t = roctx::ROCTxApiTable; + using rccl_table_t = rccl::rcclApiFuncTable; auto* _wrap_v = std::getenv("ROCP_REG_TEST_WRAP"); bool _wrap = (_wrap_v != nullptr && std::stoi(_wrap_v) != 0); @@ -107,6 +116,11 @@ rocprofiler_set_api_table(const char* name, _table->roctxRangePush_fn = &rocprofiler::roctx_range_push; _table->roctxRangePop_fn = &rocprofiler::roctx_range_pop; } + else if(std::string_view{ name } == "rccl") + { + rccl_table_t* _table = static_cast(tables[0]); + _table->ncclGetVersion_fn = &rocprofiler::ncclGetVersion; + } } return 0;