Add rccl API support (#66)

* [Draft]: Add rccl API support

* Partial tests

Need to add tests to the cmake file

[ROCm/rocprofiler-register commit: b71f9cabe6]
This commit is contained in:
Mythreya
2024-09-11 15:57:51 -07:00
committed by GitHub
parent ec3d720303
commit 13c0bf8e2b
8 changed files with 267 additions and 0 deletions
@@ -105,6 +105,7 @@ enum rocp_reg_supported_library // NOLINT(performance-enum-size)
ROCP_REG_HIP,
ROCP_REG_ROCTX,
ROCP_REG_HIP_COMPILER,
ROCP_REG_RCCL,
ROCP_REG_LAST,
};
@@ -159,6 +160,11 @@ ROCP_REG_DEFINE_LIBRARY_TRAITS(ROCP_REG_HIP_COMPILER,
"rocprofiler_register_import_hip_compiler",
"libamdhip64.so.[6-9]($|\\.[0-9\\.]+)")
// Traits for RCCL: common name, import symbol name, and the pattern used to
// recognize the shared library by file name.
// The major-version class starts at 1 (not 6 as for the HIP entries) because the
// rccl library is versioned as librccl.so.1[.x...]; a [6-9] class would never
// match it.
ROCP_REG_DEFINE_LIBRARY_TRAITS(ROCP_REG_RCCL,
                               "rccl",
                               "rocprofiler_register_import_rccl",
                               "librccl.so.[1-9]($|\\.[0-9\\.]+)")
ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_SUCCESS, "Success")
ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_NO_TOOLS, "rocprofiler-register found no tools")
ROCP_REG_DEFINE_ERROR_MESSAGE(ROCP_REG_DEADLOCK, "rocprofiler-register deadlocked")
@@ -71,6 +71,7 @@ endif()
add_subdirectory(hsa-runtime)
add_subdirectory(amdhip)
add_subdirectory(roctx)
add_subdirectory(rccl)
add_subdirectory(rocprofiler)
#
@@ -12,6 +12,7 @@ extern "C" {
# pragma weak hsa_init
# pragma weak roctxRangePush
# pragma weak roctxRangePop
# pragma weak ncclGetVersion
#endif
extern void
@@ -26,6 +27,13 @@ roctxRangePush(const char*);
extern void
roctxRangePop(const char*);
// Minimal stand-in for the RCCL result type so this test header does not need the
// real RCCL headers. It has no enumerators; it only gives ncclGetVersion a return type.
// NOTE(review): the fake library header (rccl/rccl.hpp) spells this type as
// `typedef int ncclResult_t` -- confirm the two headers are never included together.
enum ncclResult_t
{
};
// Declaration of the RCCL entry point exercised by the tests. The symbol is also
// listed in the `#pragma weak` block earlier in this header.
extern ncclResult_t
ncclGetVersion(int* version);
#ifdef __cplusplus
}
#endif
@@ -18,6 +18,7 @@ namespace
{
decltype(hip_init)* hip_init_fn = nullptr;
decltype(hsa_init)* hsa_init_fn = nullptr;
decltype(ncclGetVersion)* ncclGetVersion_fn = nullptr;
decltype(roctxRangePush)* roctxRangePush_fn = nullptr;
decltype(roctxRangePush)* roctxRangePop_fn = nullptr;
@@ -27,6 +28,7 @@ enum rocp_reg_test_modes : uint8_t
ROCP_REG_TEST_HIP = (1 << 0),
ROCP_REG_TEST_HSA = (1 << 1),
ROCP_REG_TEST_ROCTX = (1 << 2),
ROCP_REG_TEST_RCCL = (1 << 3),
};
template <uint8_t Idx = ROCP_REG_TEST_NONE>
@@ -73,6 +75,7 @@ resolve_symbols(int _open_mode = RTLD_LOCAL | RTLD_LAZY)
void* amdhip_handle = nullptr;
void* hsart_handle = nullptr;
void* roctx_handle = nullptr;
void* rccl_handle = nullptr;
if constexpr((Idx & ROCP_REG_TEST_HIP) == ROCP_REG_TEST_HIP)
{
@@ -97,5 +100,12 @@ resolve_symbols(int _open_mode = RTLD_LOCAL | RTLD_LAZY)
_resolve_dlsym(roctxRangePush_fn, roctx_handle, "roctxRangePush");
_resolve_dlsym(roctxRangePop_fn, roctx_handle, "roctxRangePop");
}
if constexpr((Idx & ROCP_REG_TEST_RCCL) == ROCP_REG_TEST_RCCL)
{
ncclGetVersion_fn = ncclGetVersion;
if(!ncclGetVersion_fn) _resolve_dlopen(rccl_handle, "librccl.so");
_resolve_dlsym(ncclGetVersion_fn, rccl_handle, "ncclGetVersion");
}
}
} // namespace
@@ -0,0 +1,33 @@
#
#
#
# This guard currently does nothing: the find_package call inside it is commented
# out, so the rocprofiler-register::rocprofiler-register target must already exist.
if(NOT TARGET rocprofiler-register::rocprofiler-register)
# find_package(rocprofiler-register REQUIRED)
endif()
# Fake RCCL shared library used by the tests, built from rccl.cpp / rccl.hpp.
add_library(rccl SHARED)
add_library(rccl::rccl ALIAS rccl)
target_sources(rccl PRIVATE rccl.cpp rccl.hpp)
# Expose this directory to consumers at build time so they can include the header.
target_include_directories(rccl PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>)
target_link_libraries(rccl PRIVATE rocprofiler-register::rocprofiler-register)
# Output file is named "rccl" with SOVERSION 1 and VERSION 1.0.
# NOTE(review): the library-name pattern registered for rccl in rocprofiler-register
# has to accept major version 1 for this library to be recognized -- confirm.
set_target_properties(
rccl
PROPERTIES OUTPUT_NAME rccl
SOVERSION 1
VERSION 1.0)
# Project helper defined elsewhere; presumably strips symbols from the target.
rocp_register_strip_target(rccl)
# Second copy of the same library (same sources, output name and version); the only
# difference is that it is written to the "invalid" subdirectory of the library
# output directory.
# NOTE(review): presumably meant for negative tests -- confirm what makes it invalid.
add_library(rccl-invalid SHARED)
add_library(rccl::rccl-invalid ALIAS rccl-invalid)
target_sources(rccl-invalid PRIVATE rccl.cpp rccl.hpp)
target_include_directories(rccl-invalid
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>)
target_link_libraries(rccl-invalid PRIVATE rocprofiler-register::rocprofiler-register)
set_target_properties(
rccl-invalid
PROPERTIES OUTPUT_NAME rccl
SOVERSION 1
VERSION 1.0
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/invalid)
rocp_register_strip_target(rccl-invalid)
@@ -0,0 +1,149 @@
#include "rccl.hpp"
#include <rocprofiler-register/rocprofiler-register.h>
#include <atomic>
#include <iostream>
#include <mutex>
#include <string_view>
// Version value handed to rocprofiler-register for this library, computed from the
// RCCL_API_TRACE_VERSION_MAJOR / RCCL_API_TRACE_VERSION_PATCH macros in rccl.hpp.
#define ROCP_REG_VERSION \
ROCPROFILER_REGISTER_COMPUTE_VERSION_2(RCCL_API_TRACE_VERSION_MAJOR, \
RCCL_API_TRACE_VERSION_PATCH)
// Provides the import function for "rccl" whose address is later passed to
// rocprofiler_register_library_api_table via ROCPROFILER_REGISTER_IMPORT_FUNC(rccl).
// NOTE(review): presumably this is the symbol rocprofiler-register looks up to
// identify the library -- confirm against rocprofiler-register.h.
ROCPROFILER_REGISTER_DEFINE_IMPORT(rccl, ROCP_REG_VERSION)
// Base name of the current source file (the text after the last '/' in __FILE__),
// as a C string for use in log messages. If __FILE__ contains no '/', find_last_of
// returns npos and npos + 1 wraps to 0, so the whole path is used.
// The pointer comes from a temporary std::string and is only valid until the end
// of the full expression: use it directly as a function argument, never store it.
// Can be overridden by defining ROCP_REG_FILE_NAME before this point.
#ifndef ROCP_REG_FILE_NAME
# define ROCP_REG_FILE_NAME \
::std::string{ __FILE__ } \
.substr(::std::string_view{ __FILE__ }.find_last_of('/') + 1) \
.c_str()
#endif
namespace rccl
{
namespace
{
auto&
get_rccl_api_table_impl()
{
static auto _table = std::atomic<rcclApiFuncTable*>{ nullptr };
return _table;
}
// Builds the API tables and offers a modifiable copy to rocprofiler-register.
//
// Outcome by registration status:
//   - success (0): the modifiable copy becomes the active table
//   - ROCP_REG_NO_TOOLS: the internal table stays active
//   - anything else: an error is printed and the process exits
void
register_profiler_impl()
{
    // table that always points at this library's own implementations
    static auto _const_api_table = rcclApiFuncTable{};
    initialize_rccl_api_table(&_const_api_table);

    // publish the internal table first so that any API call made while the
    // registration below is still running already finds a usable table
    auto& _active = get_rccl_api_table_impl();
    _active.exchange(&_const_api_table);

    // separate copy which the registration is allowed to modify
    static auto _profiler_api_table = rcclApiFuncTable{};
    copy_rccl_api_table(&_profiler_api_table, &_const_api_table);

    // the registration call takes an array of type-erased table pointers;
    // here that array has exactly one entry
    void* _table_ptr = static_cast<void*>(&_profiler_api_table);
    auto  _lib_id    = rocprofiler_register_library_indentifier_t{};

    const auto _status =
        rocprofiler_register_library_api_table("rccl",
                                               &ROCPROFILER_REGISTER_IMPORT_FUNC(rccl),
                                               ROCP_REG_VERSION,
                                               &_table_ptr,
                                               1,
                                               &_lib_id);

    if(_status != 0)
    {
        // having no tools is not an error; every other non-zero status is fatal
        if(_status != ROCP_REG_NO_TOOLS)
        {
            std::cerr << "rccl library failed to register with rocprofiler-register: "
                      << rocprofiler_register_error_string(_status) << "\n";
            exit(EXIT_FAILURE);
        }
        return;
    }

    printf("[%s] rccl identifier %lu\n", ROCP_REG_FILE_NAME, _lib_id.handle);

    // switch from the internal table to the (possibly modified) profiler table
    auto* _expected = &_const_api_table;
    if(_active.compare_exchange_strong(_expected, &_profiler_api_table)) return;

    // nothing else is supposed to write the active table, so reaching this point
    // means a function in this anonymous namespace was called when it should not be
    std::cerr << "register_profiler_impl expected the API table to be the internal "
                 "implementation and yet it is not. something went wrong.\n";
    abort();
}
// Makes sure register_profiler_impl() has been run exactly once before the caller
// goes on to read the API table.
//
//  - Once registration has completed, the only cost is one atomic load.
//  - Threads arriving while another thread is registering block on the mutex until
//    that thread is done, so they never observe a half-registered state.
//  - If the registering thread itself re-enters this function (e.g. a tool calls a
//    rccl function while it is being registered), the call returns immediately.
//    The caller then uses the internal table that register_profiler_impl()
//    publishes before it starts registering. _is_registered is deliberately NOT set
//    on this path, otherwise other threads would stop waiting for registration to
//    finish and tools would miss callbacks.
void
register_profiler()
{
    static auto _is_registered = std::atomic<bool>{ false };
    if(_is_registered.load(std::memory_order_acquire)) return;

    using mutex_t     = std::recursive_mutex;
    using auto_lock_t = std::unique_lock<mutex_t>;

    static auto _once  = std::once_flag{};
    static auto _mutex = mutex_t{};

    // true only on the thread that is currently inside the call_once below.
    // A unique_lock constructed with std::defer_lock never owns its mutex, so
    // owns_lock() cannot be used to detect re-entrance; a per-thread flag can.
    static thread_local bool _is_registering = false;

    // re-entrant call from the registering thread: must not lock again and must not
    // re-enter call_once (a recursive call_once on the same flag never completes)
    if(_is_registering) return;

    // ensures any other thread waits here until the first thread finishes
    auto _lk = auto_lock_t{ _mutex };

    // marks this thread as "registering" for the duration of the scope and resets
    // the flag on every exit path, including exceptions
    struct registering_scope
    {
        explicit registering_scope(bool& _v)
        : value{ _v }
        {
            value = true;
        }
        ~registering_scope() { value = false; }
        registering_scope(const registering_scope&) = delete;
        registering_scope& operator=(const registering_scope&) = delete;

        bool& value;
    };

    {
        registering_scope _scope{ _is_registering };
        // call_once to ensure that we only register once
        std::call_once(_once, register_profiler_impl);
    }

    // registration is complete: release waiting threads and make every later call
    // take the fast path above
    _is_registered.store(true, std::memory_order_release);
}
} // namespace
// Returns the API table that calls should be dispatched through, triggering the
// profiler registration first if it has not happened yet.
rcclApiFuncTable*
get_rccl_api_table()
{
    register_profiler();
    auto& _active = get_rccl_api_table_impl();
    return _active.load(std::memory_order_relaxed);
}
void
rccl_init()
{
printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__);
}
} // namespace rccl
extern "C" {
void
rccl_init(void)
{
rccl::get_rccl_api_table()->ncclGetVersion_fn({});
}
}
@@ -0,0 +1,46 @@
#pragma once
#define RCCL_API_TRACE_VERSION_MAJOR 0
#define RCCL_API_TRACE_VERSION_PATCH 0
#include <cstddef>
#include <cstdint>
// C-linkage declarations imitating the small part of the RCCL API used by the tests.
extern "C" {
// fake rccl function
// result type of the fake API, represented as a plain int here
typedef int ncclResult_t;
// placeholder with no enumerators; not referenced elsewhere in this header
enum ncclDataType_t
{
};
// public entry point of the fake library, exported with default visibility
ncclResult_t
ncclGetVersion(int* version) __attribute__((visibility("default")));
}
namespace rccl
{
// Dispatch table for the fake RCCL API: one function pointer per API function.
struct rcclApiFuncTable
{
// size of the table in bytes; initialize_rccl_api_table sets it to sizeof(rcclApiFuncTable)
uint64_t size = 0;
// function called for ncclGetVersion; null until the table is initialized
decltype(::ncclGetVersion)* ncclGetVersion_fn = nullptr;
};
// Library-internal implementation of ncclGetVersion; initialize_rccl_api_table
// installs its address in the API table.
ncclResult_t
ncclGetVersion(int* version);
// Fills the table behind `dst` with this library's own implementations and
// records the table size. `dst` must point to a valid table.
inline void
initialize_rccl_api_table(rcclApiFuncTable* dst)
{
    rcclApiFuncTable& _table = *dst;
    _table.size              = sizeof(rcclApiFuncTable);
    _table.ncclGetVersion_fn = &::rccl::ncclGetVersion;
}
// Overwrites the table behind `dst` with a member-wise copy of the table behind
// `src`. Both pointers must be valid.
inline void
copy_rccl_api_table(rcclApiFuncTable* dst, const rcclApiFuncTable* src)
{
    const rcclApiFuncTable& _from = *src;
    rcclApiFuncTable&       _to   = *dst;
    _to                           = _from;
}
} // namespace rccl
@@ -1,6 +1,7 @@
#include <amdhip/amdhip.hpp>
#include <hsa-runtime/hsa-runtime.hpp>
#include <rccl/rccl.hpp>
#include <roctx/roctx.hpp>
#include <dlfcn.h>
@@ -29,6 +30,13 @@ hsa_init()
printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__);
}
// Tool-side replacement for ncclGetVersion which rocprofiler_set_api_table installs
// in the rccl table: logs the call and reports a value-initialized result. The
// version argument is ignored.
ncclResult_t
ncclGetVersion(int*)
{
    const auto _result = ncclResult_t{};
    printf("[%s] %s\n", ROCP_REG_FILE_NAME, __FUNCTION__);
    return _result;
}
void
roctx_range_push(const char* name)
{
@@ -78,6 +86,7 @@ rocprofiler_set_api_table(const char* name,
using hip_table_t = hip::HipApiTable;
using hsa_table_t = hsa::HsaApiTable;
using roctx_table_t = roctx::ROCTxApiTable;
using rccl_table_t = rccl::rcclApiFuncTable;
auto* _wrap_v = std::getenv("ROCP_REG_TEST_WRAP");
bool _wrap = (_wrap_v != nullptr && std::stoi(_wrap_v) != 0);
@@ -107,6 +116,11 @@ rocprofiler_set_api_table(const char* name,
_table->roctxRangePush_fn = &rocprofiler::roctx_range_push;
_table->roctxRangePop_fn = &rocprofiler::roctx_range_pop;
}
else if(std::string_view{ name } == "rccl")
{
rccl_table_t* _table = static_cast<rccl_table_t*>(tables[0]);
_table->ncclGetVersion_fn = &rocprofiler::ncclGetVersion;
}
}
return 0;