Implement ROCTX (#1094)

* Implement roctx
This commit is contained in:
Bertan Dogancay
2024-02-27 15:46:15 -07:00
zatwierdzone przez GitHub
rodzic dae6df6d16
commit b617aecc31
9 zmienionych plików z 364 dodań i 22 usunięć
+16 -10
Wyświetl plik
@@ -19,7 +19,7 @@ option(COLLTRACE "Collective Trace Option"
option(ENABLE_MSCCL_KERNEL "Enable MSCCL while compiling" ON)
option(ENABLE_IFC "Enable indirect function call" OFF)
option(INSTALL_DEPENDENCIES "Force install dependencies" OFF)
option(NVTX "Enable NVTX" OFF)
option(ROCTX "Enable ROCTX" OFF)
option(PROFILE "Enable profiling" OFF)
option(TIMETRACE "Enable time-trace during compilation" OFF)
option(TRACE "Enable additional tracing" OFF)
@@ -373,6 +373,7 @@ set(SRC_FILES
src/include/rccl_vars.h
src/include/rocm_smi_wrap.h
src/include/rocmwrap.h
src/include/roctx.h
src/include/shm.h
src/include/signals.h
src/include/socket.h
@@ -448,6 +449,7 @@ set(SRC_FILES
src/misc/profiler.cc
src/misc/rocm_smi_wrap.cc
src/misc/rocmwrap.cc
src/misc/roctx.cc
src/misc/shmutils.cc
src/misc/signals.cc
src/misc/socket.cc
@@ -547,9 +549,6 @@ if(DEMANGLE_DIR)
endif()
## Set RCCL compile definitions
if(NOT NVTX)
target_compile_definitions(rccl PRIVATE NVTX_NO_IMPL)
endif()
if(COLLTRACE)
target_compile_definitions(rccl PRIVATE ENABLE_COLLTRACE)
endif()
@@ -565,6 +564,11 @@ endif()
if(PROFILE)
target_compile_definitions(rccl PRIVATE ENABLE_PROFILING)
endif()
if(NOT ROCTX)
target_compile_definitions(rccl PRIVATE NVTX_NO_IMPL)
target_compile_definitions(rccl PRIVATE ROCTX_NO_IMPL)
target_compile_definitions(rccl PRIVATE NVTX_DISABLE)
endif()
if(TRACE)
target_compile_definitions(rccl PRIVATE ENABLE_TRACE)
endif()
@@ -638,7 +642,15 @@ if (HAS_BFD)
target_link_libraries(rccl PRIVATE iberty z)
endif()
endif()
if (ROCTX)
target_link_libraries(rccl PRIVATE -lroctx64)
endif()
target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code
target_link_libraries(rccl PRIVATE Threads::Threads)
target_link_libraries(rccl INTERFACE hip::host)
target_link_libraries(rccl PRIVATE hip::device)
target_link_libraries(rccl PRIVATE dl)
target_link_libraries(rccl PRIVATE ${ROCM_SMI_LIBRARIES})
## Set RCCL link options
target_link_options(rccl PRIVATE -parallel-jobs=16) # Use multiple threads to link
@@ -663,12 +675,6 @@ if (HAVE_KERNARG_PRELOAD)
target_link_options(rccl PRIVATE -Xoffload-linker -mllvm=-amdgpu-kernarg-preload-count=16)
endif()
target_link_libraries(rccl PRIVATE Threads::Threads)
target_link_libraries(rccl INTERFACE hip::host)
target_link_libraries(rccl PRIVATE hip::device)
target_link_libraries(rccl PRIVATE dl)
target_link_libraries(rccl PRIVATE ${ROCM_SMI_LIBRARIES})
## Track linking time
set_property(TARGET rccl PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time")
+1
Wyświetl plik
@@ -37,6 +37,7 @@ The root of this repository has a helper script 'install.sh' to build and instal
-l|--local_gpu_only Only compile for local GPU architecture
--no_clean Don't delete files if they already exist
--npkit-enable Compile with npkit enabled
--roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)
-p|--package_build Build RCCL package
--prefix Specify custom directory to install RCCL to (default: /opt/rocm)
--rm-legacy-include-dir Remove legacy include dir Packaging added for file/folder reorg backward compatibility
+7 -7
Wyświetl plik
@@ -25,7 +25,7 @@ install_library=false
msccl_kernel_enabled=true
num_parallel_jobs=$(nproc)
npkit_enabled=false
nvtx_enabled=false
roctx_enabled=false
run_tests=false
run_tests_all=false
time_trace=false
@@ -51,7 +51,7 @@ function display_help()
echo " --amdgpu_targets Only compile for specified GPU architecture(s). For multiple targets, seperate by ';' (builds for all supported GPU architectures by default)"
echo " --no_clean Don't delete files if they already exist"
echo " --npkit-enable Compile with npkit enabled"
echo " --nvtx-enable Compile with nvtx enabled"
echo " --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)"
echo " -p|--package_build Build RCCL package"
echo " --prefix Specify custom directory to install RCCL to (default: /opt/rocm)"
echo " --rm-legacy-include-dir Remove legacy include dir Packaging added for file/folder reorg backward compatibility"
@@ -70,7 +70,7 @@ function display_help()
# check if we have a modern version of getopt that can handle whitespace and long parameters
getopt -T
if [[ $? -eq 4 ]]; then
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,nvtx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
else
echo "Need a new version of getopt"
exit 1
@@ -99,7 +99,7 @@ while true; do
--amdgpu_targets) build_amdgpu_targets=${2}; shift 2 ;;
--no_clean) clean_build=false; shift ;;
--npkit-enable) npkit_enabled=true; shift ;;
--nvtx-enable) nvtx_enabled=true; shift ;;
--roctx-enable) roctx_enabled=true; shift ;;
-p | --package_build) build_package=true; shift ;;
--prefix) install_prefix=${2}; shift 2 ;;
--rm-legacy-include-dir) build_freorg_bkwdcomp=false; shift ;;
@@ -223,9 +223,9 @@ if ($install_dependencies); then
cmake_common_options="${cmake_common_options} -DINSTALL_DEPENDENCIES=ON"
fi
# Enable NVTX
if [[ "${nvtx_enabled}" == true ]]; then
cmake_common_options="${cmake_common_options} -DNVTX=ON"
# Enable ROCTX
if [[ "${roctx_enabled}" == true ]]; then
cmake_common_options="${cmake_common_options} -DROCTX=ON"
fi
cmake_executable=cmake
+41
Wyświetl plik
@@ -68,6 +68,13 @@ NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_
ncclComm_t comm, hipStream_t stream);
ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, hipStream_t stream) {
// Just pass the size of one message and not the total bytes sent/received.
constexpr nvtxPayloadSchemaEntry_t AllToAllSchema[] = {
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"}
};
size_t msgsize = count * ncclTypeSize(datatype);
NVTX3_FUNC_WITH_PARAMS(AllToAll, AllToAllSchema, msgsize)
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
@@ -103,6 +110,18 @@ NCCL_API(ncclResult_t, ncclAllToAllv, const void *sendbuff, const size_t sendcou
ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], const size_t sdispls[],
void *recvbuff, const size_t recvcounts[], const size_t rdispls[],
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) {
struct NvtxParamsAllToAllv {
size_t sendbytes;
size_t recvbytes;
};
// Just pass the size of one send/recv messages and not the total bytes sent/received.
constexpr nvtxPayloadSchemaEntry_t AllToAllvSchema[] = {
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"},
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"}
};
NvtxParamsAllToAllv payload{sendcounts[comm->rank] * ncclTypeSize(datatype), recvcounts[comm->rank] * ncclTypeSize(datatype)};
NVTX3_FUNC_WITH_PARAMS(AllToAllv, AllToAllvSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
@@ -170,6 +189,17 @@ NCCL_API(ncclResult_t, ncclGather, const void* sendbuff, void* recvbuff, size_t
ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream);
ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream) {
struct NvtxParamsGather {
size_t bytes;
int root;
};
constexpr nvtxPayloadSchemaEntry_t GatherSchema[] = {
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"},
{0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsGather, root)}
};
NvtxParamsGather payload{sendcount * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Gather, GatherSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
@@ -254,6 +284,17 @@ NCCL_API(ncclResult_t, ncclScatter, const void* sendbuff, void* recvbuff, size_t
ncclComm_t comm, hipStream_t stream);
ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, int root,
ncclComm_t comm, hipStream_t stream) {
struct NvtxParamsScatter {
size_t bytes;
int root;
};
constexpr nvtxPayloadSchemaEntry_t ScatterSchema[] = {
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"},
{0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsScatter, root)}
};
NvtxParamsScatter payload{recvcount * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Scatter, ScatterSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
+18 -5
Wyświetl plik
@@ -8,6 +8,7 @@
#define NCCL_NVTX_H_
#include "nvtx3/nvtx3.hpp"
#include "roctx.h"
#if __cpp_constexpr >= 201304L && !defined(NVTX3_CONSTEXPR_IF_CPP14)
#define NVTX3_CONSTEXPR_IF_CPP14 constexpr
@@ -22,11 +23,16 @@
#define NVTX_SID_CommAbort 3 // same schema as NVTX_SID_CommInitRank
#define NVTX_SID_AllGather 4
#define NVTX_SID_AllReduce 5
#define NVTX_SID_Broadcast 6
#define NVTX_SID_ReduceScatter 7
#define NVTX_SID_Reduce 8
#define NVTX_SID_Send 9
#define NVTX_SID_Recv 10
#define NVTX_SID_AllToAll 6
#define NVTX_SID_AllToAllv 7
#define NVTX_SID_Broadcast 8
#define NVTX_SID_Gather 9
#define NVTX_SID_MSCCL 10
#define NVTX_SID_ReduceScatter 11
#define NVTX_SID_Reduce 12
#define NVTX_SID_Scatter 13
#define NVTX_SID_Send 14
#define NVTX_SID_Recv 15
// Define static schema ID for the reduction operation.
#define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START
@@ -71,6 +77,12 @@ class payload_schema {
// @param N schema name
// @param S schema (entries)
// @param P payload (struct)
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
#define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \
nvtxPayloadData_t nvtx3_bpl__[] = { \
{NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \
roctx_scoped_range_in const roctx_range__{S, nvtx3_bpl__, std::extent<decltype(S)>::value, "RCCL_" #ID};
#else
#define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \
static const payload_schema schema{S, std::extent<decltype(S)>::value, \
NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \
@@ -79,6 +91,7 @@ class payload_schema {
{NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \
::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \
::nvtx3::v1::scoped_range_in<nccl_domain> const nvtx3_range__{nvtx3_func_attr__};
#endif
extern void initNvtxRegisteredEnums();
+5
Wyświetl plik
@@ -2777,10 +2777,15 @@ inline void mark(Args const&... args) noexcept
* `domain` to which the `registered_string_in` belongs. Else,
* `domain::global` to indicate that the global NVTX domain should be used.
*/
#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HCC__) && !defined(__HIPCC__)
#define NVTX3_V1_FUNC_RANGE_IN(D) \
static ::nvtx3::v1::registered_string_in<D> const nvtx3_func_name__{__func__}; \
static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \
::nvtx3::v1::scoped_range_in<D> const nvtx3_range__{nvtx3_func_attr__};
#else
#define NVTX3_V1_FUNC_RANGE_IN(D) \
roctx_scoped_range_in const roctx_range__{__func__};
#endif
/**
* @brief Convenience macro for generating a range in the specified `domain`
+165
Wyświetl plik
@@ -0,0 +1,165 @@
/*************************************************************************
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef RCCL_ROCTX_H
#define RCCL_ROCTX_H
#include <iostream>
#include <string.h>
#include <map>
#include <roctracer/roctx.h>
#include "nvtx3/nvtx3.hpp"
#include "device.h"
#define MAX_MESSAGE_LENGTH 1024
#define NVTX_PAYLOAD_ENTRY_TYPE_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START
/**
* \brief Equivalent of nvtx types for roctx.
*/
enum roctxPayloadEntryType {
/**
* Only include the required/used types by rccl,
* and needs to be updated in case of new type
* tracing.
*/
ROCTX_PAYLOAD_ENTRY_TYPE_INT,
ROCTX_PAYLOAD_ENTRY_TYPE_SIZE,
ROCTX_PAYLOAD_ENTRY_TYPE_REDOP,
ROCTX_PAYLOAD_NUM_ENTRY_TYPES
};
/**
* \brief Stores the contents of the message that will be used by roctx.
*/
struct roctxPayloadSchemaEntryInfo {
/**
* Description of the data.
*/
const char* name;
/**
* Type of the data.
*/
roctxPayloadEntryType type;
/**
* Union of possible payload types.
*
* Should be in sync with roctxPayloadEntryType.
*/
union {
int typeInt;
size_t typeSize;
ncclDevRedOp_t typeRedOp;
} payload;
};
struct roctxPayloadInfo {
/**
* Payload name. Usually the name of the function/API
* being called from.
*/
const char* id;
/**
* Number of paylod entries
*/
size_t numEntries;
/**
* Pointer to roctxPayloadSchemaEntryInfo elements in memory.
*/
roctxPayloadSchemaEntryInfo* payloadEntries = nullptr;
/**
* Message that will be used by roctx
*/
char* message = nullptr;
};
typedef roctxPayloadInfo* roctxPayloadInfo_t;
extern const char* roctxEntryTypeStr[ROCTX_PAYLOAD_NUM_ENTRY_TYPES];
extern const char* ncclRedOpStr[ncclNumDevRedOps];
/**
* \brief Maps nvtx types to roctx types.
*/
extern std::map<uint64_t, roctxPayloadEntryType> nvtxToRoctx;
/**
* \brief Allocate required memory for roctx
*/
void roctxAlloc(roctxPayloadInfo_t payloadInfo, const size_t numEntries);
/**
* \brief Free all the resources used by roctx
*/
void roctxFree(roctxPayloadInfo_t payloadInfo);
/**
* \brief Extracts payload schema entry info from nvtxPayloadSchemaEntry_t and,
* nvtxPayloadData_t and stores in an array.
*/
void extractPayloadInfo(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data, const size_t numEntries,
const char* schemaName, roctxPayloadInfo_t payloadInfo);
/**
* \brief Stringify roctxPayloadInfo_t struct. Used as roctx message.
*/
void stringify(roctxPayloadInfo_t payloadInfo);
/**
* \brief Class to make roctx calls scoped.
*/
class roctx_scoped_range_in {
public:
/**
* Construct a 'roctx_scoped_range_in' with specified NVTX params,
* 'numEntries', and 'schemaName'
*/
explicit roctx_scoped_range_in(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data,
const size_t numEntries, const char* schemaName) noexcept
{
#ifndef ROCTX_NO_IMPL
roctxAlloc(&payloadInfo, numEntries);
extractPayloadInfo(schema, data, numEntries, schemaName, &payloadInfo);
roctxRangePushA(payloadInfo.message);
#endif
}
/**
* Construct a 'roctx_scoped_range_in' with the specified 'message'
*/
explicit roctx_scoped_range_in(const char* message) noexcept
{
#ifndef ROCTX_NO_IMPL
roctxRangePushA(message);
#endif
}
/**
* Default constructor 'roctx_scoped_range_in'
*/
roctx_scoped_range_in() noexcept : roctx_scoped_range_in{""} {/*no impl*/}
/**
* Destroy the roctx_scoped_range_in, ending the ROCTX range event.
*/
~roctx_scoped_range_in() noexcept
{
#ifndef ROCTX_NO_IMPL
roctxRangePop();
roctxFree(&payloadInfo);
#endif
}
private:
roctxPayloadInfo payloadInfo;
};
#endif // RCCL_ROCTX_H
+99
Wyświetl plik
@@ -0,0 +1,99 @@
/*************************************************************************
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "roctx.h"
std::map<uint64_t, roctxPayloadEntryType> nvtxToRoctx {
{NVTX_PAYLOAD_ENTRY_TYPE_INT, ROCTX_PAYLOAD_ENTRY_TYPE_INT},
{NVTX_PAYLOAD_ENTRY_TYPE_SIZE, ROCTX_PAYLOAD_ENTRY_TYPE_SIZE},
{NVTX_PAYLOAD_ENTRY_TYPE_REDOP, ROCTX_PAYLOAD_ENTRY_TYPE_REDOP}};
const char* roctxEntryTypeStr[ROCTX_PAYLOAD_NUM_ENTRY_TYPES] = {"ROCTX_PAYLOAD_ENTRY_TYPE_INT", "ROCTX_PAYLOAD_ENTRY_TYPE_SIZE", "ROCTX_PAYLOAD_ENTRY_TYPE_REDOP"};
const char* ncclRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "MinMax", "PreMulSum", "SumPostDiv" };
void roctxAlloc(roctxPayloadInfo_t payloadInfo, const size_t numEntries) {
#ifndef ROCTX_NO_IMPL
// Allocate enough memory for numEntries in payloadEntries
payloadInfo->payloadEntries = (roctxPayloadSchemaEntryInfo*)malloc(numEntries * sizeof(roctxPayloadSchemaEntryInfo));
// Allocate memory for the message that will be constructed
payloadInfo->message = (char*)malloc(MAX_MESSAGE_LENGTH * sizeof(char));
#endif
}
void roctxFree(roctxPayloadInfo_t payloadInfo) {
#ifndef ROCTX_NO_IMPL
// Free all the dynamically allocated resources by roctx
if (payloadInfo->payloadEntries) free(payloadInfo->payloadEntries);
if (payloadInfo->message) free((void*)payloadInfo->message);
#endif
}
void extractPayloadInfo(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data, const size_t numEntries,
const char* schemaName, roctxPayloadInfo_t payloadInfo) {
if (payloadInfo->payloadEntries == nullptr) return;
payloadInfo->id = schemaName;
payloadInfo->numEntries = numEntries;
// Iterate over each entry in the schema
for (size_t i = 0; i < payloadInfo->numEntries; ++i) {
// Populate payload schema entry info for roctx
payloadInfo->payloadEntries[i].name = schema[i].name;
payloadInfo->payloadEntries[i].type = nvtxToRoctx[schema[i].type];
// Offset to index into the data stored in nvtxPayloadData_t->payload
uint64_t offset = schema[i].offset;
const void* entryData = reinterpret_cast<const char*>(data->payload) + offset;
// Populate payload union based on the roctx type
switch (payloadInfo->payloadEntries[i].type) {
case ROCTX_PAYLOAD_ENTRY_TYPE_INT: payloadInfo->payloadEntries[i].payload.typeInt = *reinterpret_cast<const int*>(entryData); break;
case ROCTX_PAYLOAD_ENTRY_TYPE_SIZE: payloadInfo->payloadEntries[i].payload.typeSize = *reinterpret_cast<const size_t*>(entryData); break;
case ROCTX_PAYLOAD_ENTRY_TYPE_REDOP: payloadInfo->payloadEntries[i].payload.typeRedOp = *reinterpret_cast<const ncclDevRedOp_t*>(entryData); break;
default: break;
}
}
// Stringify payloadInfo
stringify(payloadInfo);
}
void stringify(roctxPayloadInfo_t payloadInfo) {
if (!payloadInfo->payloadEntries || !payloadInfo->message) return;
int offset = snprintf(payloadInfo->message, MAX_MESSAGE_LENGTH, "{%s: ", payloadInfo->id);
for (size_t i = 0; i < payloadInfo->numEntries; ++i)
{
roctxPayloadSchemaEntryInfo entry = payloadInfo->payloadEntries[i];
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%s: ", entry.name);
switch (entry.type)
{
case ROCTX_PAYLOAD_ENTRY_TYPE_INT:
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%d", entry.payload.typeInt);
break;
case ROCTX_PAYLOAD_ENTRY_TYPE_SIZE:
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%zu", entry.payload.typeSize);
break;
case ROCTX_PAYLOAD_ENTRY_TYPE_REDOP:
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%s",
entry.payload.typeRedOp < ncclNumDevRedOps ? ncclRedOpStr[entry.payload.typeRedOp] : "unknown");
break;
default:
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "unknown roctx payload type");
break;
}
if (i != payloadInfo->numEntries - 1)
offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, ", ");
}
snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "}");
}
+12
Wyświetl plik
@@ -44,6 +44,18 @@ ncclResult_t mscclRunAlgo(
void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream) {
struct NvtxParamsMsccl {
size_t sendbytes;
size_t recvbytes;
};
// Just pass the size of one send/recv messages and not the total bytes sent/received.
constexpr nvtxPayloadSchemaEntry_t MscclSchema[] = {
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"},
{0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"}
};
NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)};
NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload)
mscclStatus& status = mscclGetStatus();
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];