2
0

[rocprofiler-sdk] Add support for new RCCL API (#771)

* [rocprofiler-sdk] Add support for new RCCL API

Add support for `ncclAllReduceWithBias`

* Move func to be in sync with rccl header
Este cometimento está contido em:
Mythreya Kuricheti
2025-09-02 03:47:44 -07:00
cometido por GitHub
ascendente 4174508dcc
cometimento 43ac6b2ef5
6 ficheiros modificados com 54 adições e 7 eliminações
@@ -1106,7 +1106,20 @@ ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclRunAlgo);
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo);
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommRegister);
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommDeregister);
#if RCCL_API_TRACE_VERSION_PATCH >= 1
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias);
#endif
#if RCCL_API_TRACE_VERSION_PATCH == 0
static_assert(ROCPROFILER_RCCL_API_ID_LAST == 37);
#elif RCCL_API_TRACE_VERSION_PATCH == 1
static_assert(ROCPROFILER_RCCL_API_ID_LAST == 38);
#else
# if !defined(ROCPROFILER_UNSAFE_NO_VERSION_CHECK) && \
(defined(ROCPROFILER_CI) && ROCPROFILER_CI > 0)
static_assert(false, "Support for new RCCL_API_TRACE_VERSION_PATCH enumerations is required");
# endif
#endif
// rocprofiler_rocdecode_api_id_t
ROCPROFILER_ENUM_INFO(rocprofiler_rocdecode_api_id_t, 0, ROCPROFILER_ROCDECODE_API_ID_LAST, false)
@@ -317,7 +317,19 @@ typedef union rocprofiler_rccl_api_args_t
ncclComm_t comm;
void* handle;
} ncclCommDeregister;
#if RCCL_API_TRACE_VERSION_PATCH >= 1
struct
{
const void* sendbuff;
void* recvbuff;
size_t count;
ncclDataType_t datatype;
ncclRedOp_t op;
struct ncclComm* comm;
hipStream_t stream;
const void* acc;
} ncclAllReduceWithBias;
#endif
} rocprofiler_rccl_api_args_t;
ROCPROFILER_EXTERN_C_FINI
@@ -68,6 +68,8 @@ typedef enum rocprofiler_rccl_api_id_t // NOLINT(performance-enum-size)
ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo,
ROCPROFILER_RCCL_API_ID_ncclCommRegister,
ROCPROFILER_RCCL_API_ID_ncclCommDeregister,
#if RCCL_API_TRACE_VERSION_PATCH >= 1
ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias,
#endif
ROCPROFILER_RCCL_API_ID_LAST,
} rocprofiler_rccl_api_id_t;
@@ -47,7 +47,7 @@
#define RCCL_API_TRACE_VERSION_MAJOR 0
// should be increased every time new members are added to existing dispatch tables
#define RCCL_API_TRACE_VERSION_PATCH 0
#define RCCL_API_TRACE_VERSION_PATCH 1
#if !defined(RCCL_EXTERN_C_INIT)
# ifdef __cplusplus
@@ -81,6 +81,14 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff,
ncclRedOp_t op,
struct ncclComm* comm,
hipStream_t stream);
typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff,
void* recvbuff,
size_t count,
ncclDataType_t datatype,
ncclRedOp_t op,
struct ncclComm* comm,
hipStream_t stream,
const void* acc);
typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff,
void* recvbuff,
size_t count,
@@ -264,7 +272,7 @@ typedef struct rcclApiFuncTable
mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn;
ncclCommRegister_fn_t ncclCommRegister_fn;
ncclCommDeregister_fn_t ncclCommDeregister_fn;
ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn;
} rcclApiFuncTable;
RCCL_EXTERN_C_FINI
@@ -33,9 +33,6 @@ namespace rocprofiler
namespace rccl
{
static_assert(RCCL_API_TRACE_VERSION_MAJOR == 0, "Major version updated for RCCL dispatch table");
static_assert(RCCL_API_TRACE_VERSION_PATCH == 0, "Patch version updated for RCCL dispatch table");
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37)
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllGather_fn, 0)
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduce_fn, 1)
@@ -74,5 +71,17 @@ ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclRunAlgo_fn, 33)
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclUnloadAlgo_fn, 34)
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommRegister_fn, 35)
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommDeregister_fn, 36)
#if RCCL_API_TRACE_VERSION_PATCH >= 1
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduceWithBias_fn, 37)
#endif
#if RCCL_API_TRACE_VERSION_PATCH == 0
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37)
#elif RCCL_API_TRACE_VERSION_PATCH == 1
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 38)
#else
INTERNAL_CI_ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 0)
#endif
} // namespace rccl
} // namespace rocprofiler
@@ -102,6 +102,9 @@ RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ms
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo, mscclUnloadAlgo, mscclUnloadAlgo_fn, mscclAlgoHandle)
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommRegister, ncclCommRegister, ncclCommRegister_fn, comm, buff, size, handle)
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommDeregister, ncclCommDeregister, ncclCommDeregister_fn, comm, handle)
#if RCCL_API_TRACE_VERSION_PATCH >= 1
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias, ncclAllReduceWithBias, ncclAllReduceWithBias_fn, sendbuff, recvbuff, count, datatype, op, comm, stream, acc)
#endif
#else
# error \