From 43ac6b2ef5f545ba05283ca707577df93d963ffd Mon Sep 17 00:00:00 2001 From: Mythreya Kuricheti Date: Tue, 2 Sep 2025 03:47:44 -0700 Subject: [PATCH] [rocprofiler-sdk] Add support for new RCCL API (#771) * [rocprofiler-sdk] Add support for new RCCL API Add support for `ncclAllReduceWithBias` * Move func to be in sync with rccl header --- .../include/rocprofiler-sdk/cxx/enum_string.hpp | 13 +++++++++++++ .../include/rocprofiler-sdk/rccl/api_args.h | 14 +++++++++++++- .../source/include/rocprofiler-sdk/rccl/api_id.h | 4 +++- .../rocprofiler-sdk/rccl/details/api_trace.h | 12 ++++++++++-- .../source/lib/rocprofiler-sdk/rccl/abi.cpp | 15 ++++++++++++--- .../source/lib/rocprofiler-sdk/rccl/rccl.def.cpp | 3 +++ 6 files changed, 54 insertions(+), 7 deletions(-) diff --git a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/cxx/enum_string.hpp b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/cxx/enum_string.hpp index 975b250ebe..edc98375f3 100644 --- a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/cxx/enum_string.hpp +++ b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/cxx/enum_string.hpp @@ -1106,7 +1106,20 @@ ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclRunAlgo); ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo); ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommRegister); ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommDeregister); +#if RCCL_API_TRACE_VERSION_PATCH >= 1 +ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias); +#endif + +#if RCCL_API_TRACE_VERSION_PATCH == 0 static_assert(ROCPROFILER_RCCL_API_ID_LAST == 37); +#elif RCCL_API_TRACE_VERSION_PATCH == 1 +static_assert(ROCPROFILER_RCCL_API_ID_LAST == 38); +#else +# if !defined(ROCPROFILER_UNSAFE_NO_VERSION_CHECK) && \ + (defined(ROCPROFILER_CI) && ROCPROFILER_CI > 0) +static_assert(false, "Support for new RCCL_API_TRACE_VERSION_PATCH enumerations is required"); +# endif +#endif // rocprofiler_rocdecode_api_id_t ROCPROFILER_ENUM_INFO(rocprofiler_rocdecode_api_id_t, 0, ROCPROFILER_ROCDECODE_API_ID_LAST, false) diff --git a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_args.h b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_args.h index 3af8d07a42..769d3951d2 100644 --- a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_args.h +++ b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_args.h @@ -317,7 +317,19 @@ typedef union rocprofiler_rccl_api_args_t ncclComm_t comm; void* handle; } ncclCommDeregister; - +#if RCCL_API_TRACE_VERSION_PATCH >= 1 + struct + { + const void* sendbuff; + void* recvbuff; + size_t count; + ncclDataType_t datatype; + ncclRedOp_t op; + struct ncclComm* comm; + hipStream_t stream; + const void* acc; + } ncclAllReduceWithBias; +#endif } rocprofiler_rccl_api_args_t; ROCPROFILER_EXTERN_C_FINI diff --git a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_id.h b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_id.h index 1c668ab355..7cfb8a4a4c 100644 --- a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_id.h +++ b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/api_id.h @@ -68,6 +68,8 @@ typedef enum rocprofiler_rccl_api_id_t // NOLINT(performance-enum-size) ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo, ROCPROFILER_RCCL_API_ID_ncclCommRegister, ROCPROFILER_RCCL_API_ID_ncclCommDeregister, - +#if RCCL_API_TRACE_VERSION_PATCH >= 1 + ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias, +#endif ROCPROFILER_RCCL_API_ID_LAST, } rocprofiler_rccl_api_id_t; diff --git a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/details/api_trace.h b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/details/api_trace.h index f1a0f473a9..1c8694b5ae 100644 --- a/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/details/api_trace.h +++ b/projects/rocprofiler-sdk/source/include/rocprofiler-sdk/rccl/details/api_trace.h @@ -47,7 +47,7 @@ #define RCCL_API_TRACE_VERSION_MAJOR 0 // should be increased every time new members are added to existing dispatch tables -#define RCCL_API_TRACE_VERSION_PATCH 0 +#define RCCL_API_TRACE_VERSION_PATCH 1 #if !defined(RCCL_EXTERN_C_INIT) # ifdef __cplusplus @@ -81,6 +81,14 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff, ncclRedOp_t op, struct ncclComm* comm, hipStream_t stream); +typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff, + void* recvbuff, + size_t count, + ncclDataType_t datatype, + ncclRedOp_t op, + struct ncclComm* comm, + hipStream_t stream, + const void* acc); typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff, void* recvbuff, size_t count, @@ -264,7 +272,7 @@ typedef struct rcclApiFuncTable mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn; ncclCommRegister_fn_t ncclCommRegister_fn; ncclCommDeregister_fn_t ncclCommDeregister_fn; - + ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn; } rcclApiFuncTable; RCCL_EXTERN_C_FINI diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/abi.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/abi.cpp index bf516162b0..a177c5281c 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/abi.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/abi.cpp @@ -33,9 +33,6 @@ namespace rocprofiler namespace rccl { static_assert(RCCL_API_TRACE_VERSION_MAJOR == 0, "Major version updated for RCCL dispatch table"); -static_assert(RCCL_API_TRACE_VERSION_PATCH == 0, "Patch version updated for RCCL dispatch table"); - -ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37) ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllGather_fn, 0) ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduce_fn, 1) @@ -74,5 +71,17 @@ ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclRunAlgo_fn, 33) ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclUnloadAlgo_fn, 34) ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommRegister_fn, 35) ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommDeregister_fn, 36) +#if RCCL_API_TRACE_VERSION_PATCH >= 1 +ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduceWithBias_fn, 37) +#endif + +#if RCCL_API_TRACE_VERSION_PATCH == 0 +ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37) +#elif RCCL_API_TRACE_VERSION_PATCH == 1 +ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 38) +#else +INTERNAL_CI_ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 0) +#endif + } // namespace rccl } // namespace rocprofiler diff --git a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/rccl.def.cpp b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/rccl.def.cpp index 1b9f98c6a3..324ad8c0f9 100644 --- a/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/rccl.def.cpp +++ b/projects/rocprofiler-sdk/source/lib/rocprofiler-sdk/rccl/rccl.def.cpp @@ -102,6 +102,9 @@ RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ms RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo, mscclUnloadAlgo, mscclUnloadAlgo_fn, mscclAlgoHandle) RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommRegister, ncclCommRegister, ncclCommRegister_fn, comm, buff, size, handle) RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommDeregister, ncclCommDeregister, ncclCommDeregister_fn, comm, handle) +#if RCCL_API_TRACE_VERSION_PATCH >= 1 +RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias, ncclAllReduceWithBias, ncclAllReduceWithBias_fn, sendbuff, recvbuff, count, datatype, op, comm, stream, acc) +#endif #else # error \