[rocprofiler-sdk] Add support for new RCCL API (#771)
* [rocprofiler-sdk] Add support for new RCCL API Add support for `ncclAllReduceWithBias` * Move func to be in sync with rccl header
Este cometimento está contido em:
cometido por
GitHub
ascendente
4174508dcc
cometimento
43ac6b2ef5
@@ -1106,7 +1106,20 @@ ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclRunAlgo);
|
||||
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo);
|
||||
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommRegister);
|
||||
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclCommDeregister);
|
||||
#if RCCL_API_TRACE_VERSION_PATCH >= 1
|
||||
ROCPROFILER_ENUM_LABEL(ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias);
|
||||
#endif
|
||||
|
||||
#if RCCL_API_TRACE_VERSION_PATCH == 0
|
||||
static_assert(ROCPROFILER_RCCL_API_ID_LAST == 37);
|
||||
#elif RCCL_API_TRACE_VERSION_PATCH == 1
|
||||
static_assert(ROCPROFILER_RCCL_API_ID_LAST == 38);
|
||||
#else
|
||||
# if !defined(ROCPROFILER_UNSAFE_NO_VERSION_CHECK) && \
|
||||
(defined(ROCPROFILER_CI) && ROCPROFILER_CI > 0)
|
||||
static_assert(false, "Support for new RCCL_API_TRACE_VERSION_PATCH enumerations is required");
|
||||
# endif
|
||||
#endif
|
||||
|
||||
// rocprofiler_rocdecode_api_id_t
|
||||
ROCPROFILER_ENUM_INFO(rocprofiler_rocdecode_api_id_t, 0, ROCPROFILER_ROCDECODE_API_ID_LAST, false)
|
||||
|
||||
@@ -317,7 +317,19 @@ typedef union rocprofiler_rccl_api_args_t
|
||||
ncclComm_t comm;
|
||||
void* handle;
|
||||
} ncclCommDeregister;
|
||||
|
||||
#if RCCL_API_TRACE_VERSION_PATCH >= 1
|
||||
struct
|
||||
{
|
||||
const void* sendbuff;
|
||||
void* recvbuff;
|
||||
size_t count;
|
||||
ncclDataType_t datatype;
|
||||
ncclRedOp_t op;
|
||||
struct ncclComm* comm;
|
||||
hipStream_t stream;
|
||||
const void* acc;
|
||||
} ncclAllReduceWithBias;
|
||||
#endif
|
||||
} rocprofiler_rccl_api_args_t;
|
||||
|
||||
ROCPROFILER_EXTERN_C_FINI
|
||||
|
||||
@@ -68,6 +68,8 @@ typedef enum rocprofiler_rccl_api_id_t // NOLINT(performance-enum-size)
|
||||
ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo,
|
||||
ROCPROFILER_RCCL_API_ID_ncclCommRegister,
|
||||
ROCPROFILER_RCCL_API_ID_ncclCommDeregister,
|
||||
|
||||
#if RCCL_API_TRACE_VERSION_PATCH >= 1
|
||||
ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias,
|
||||
#endif
|
||||
ROCPROFILER_RCCL_API_ID_LAST,
|
||||
} rocprofiler_rccl_api_id_t;
|
||||
|
||||
+10
-2
@@ -47,7 +47,7 @@
|
||||
#define RCCL_API_TRACE_VERSION_MAJOR 0
|
||||
|
||||
// should be increased every time new members are added to existing dispatch tables
|
||||
#define RCCL_API_TRACE_VERSION_PATCH 0
|
||||
#define RCCL_API_TRACE_VERSION_PATCH 1
|
||||
|
||||
#if !defined(RCCL_EXTERN_C_INIT)
|
||||
# ifdef __cplusplus
|
||||
@@ -81,6 +81,14 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff,
|
||||
ncclRedOp_t op,
|
||||
struct ncclComm* comm,
|
||||
hipStream_t stream);
|
||||
typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff,
|
||||
void* recvbuff,
|
||||
size_t count,
|
||||
ncclDataType_t datatype,
|
||||
ncclRedOp_t op,
|
||||
struct ncclComm* comm,
|
||||
hipStream_t stream,
|
||||
const void* acc);
|
||||
typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff,
|
||||
void* recvbuff,
|
||||
size_t count,
|
||||
@@ -264,7 +272,7 @@ typedef struct rcclApiFuncTable
|
||||
mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn;
|
||||
ncclCommRegister_fn_t ncclCommRegister_fn;
|
||||
ncclCommDeregister_fn_t ncclCommDeregister_fn;
|
||||
|
||||
ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn;
|
||||
} rcclApiFuncTable;
|
||||
|
||||
RCCL_EXTERN_C_FINI
|
||||
|
||||
@@ -33,9 +33,6 @@ namespace rocprofiler
|
||||
namespace rccl
|
||||
{
|
||||
static_assert(RCCL_API_TRACE_VERSION_MAJOR == 0, "Major version updated for RCCL dispatch table");
|
||||
static_assert(RCCL_API_TRACE_VERSION_PATCH == 0, "Patch version updated for RCCL dispatch table");
|
||||
|
||||
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37)
|
||||
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllGather_fn, 0)
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduce_fn, 1)
|
||||
@@ -74,5 +71,17 @@ ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclRunAlgo_fn, 33)
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, mscclUnloadAlgo_fn, 34)
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommRegister_fn, 35)
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclCommDeregister_fn, 36)
|
||||
#if RCCL_API_TRACE_VERSION_PATCH >= 1
|
||||
ROCP_SDK_ENFORCE_ABI(::rcclApiFuncTable, ncclAllReduceWithBias_fn, 37)
|
||||
#endif
|
||||
|
||||
#if RCCL_API_TRACE_VERSION_PATCH == 0
|
||||
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 37)
|
||||
#elif RCCL_API_TRACE_VERSION_PATCH == 1
|
||||
ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 38)
|
||||
#else
|
||||
INTERNAL_CI_ROCP_SDK_ENFORCE_ABI_VERSIONING(::rcclApiFuncTable, 0)
|
||||
#endif
|
||||
|
||||
} // namespace rccl
|
||||
} // namespace rocprofiler
|
||||
|
||||
@@ -102,6 +102,9 @@ RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ms
|
||||
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_mscclUnloadAlgo, mscclUnloadAlgo, mscclUnloadAlgo_fn, mscclAlgoHandle)
|
||||
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommRegister, ncclCommRegister, ncclCommRegister_fn, comm, buff, size, handle)
|
||||
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclCommDeregister, ncclCommDeregister, ncclCommDeregister_fn, comm, handle)
|
||||
#if RCCL_API_TRACE_VERSION_PATCH >= 1
|
||||
RCCL_API_INFO_DEFINITION_V(ROCPROFILER_RCCL_TABLE_ID, ROCPROFILER_RCCL_API_ID_ncclAllReduceWithBias, ncclAllReduceWithBias, ncclAllReduceWithBias_fn, sendbuff, recvbuff, count, datatype, op, comm, stream, acc)
|
||||
#endif
|
||||
|
||||
#else
|
||||
# error \
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador