diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index 90786f8600..4204dc5ed2 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -19,7 +19,7 @@ option(COLLTRACE "Collective Trace Option" option(ENABLE_MSCCL_KERNEL "Enable MSCCL while compiling" ON) option(ENABLE_IFC "Enable indirect function call" OFF) option(INSTALL_DEPENDENCIES "Force install dependencies" OFF) -option(NVTX "Enable NVTX" OFF) +option(ROCTX "Enable ROCTX" OFF) option(PROFILE "Enable profiling" OFF) option(TIMETRACE "Enable time-trace during compilation" OFF) option(TRACE "Enable additional tracing" OFF) @@ -373,6 +373,7 @@ set(SRC_FILES src/include/rccl_vars.h src/include/rocm_smi_wrap.h src/include/rocmwrap.h + src/include/roctx.h src/include/shm.h src/include/signals.h src/include/socket.h @@ -448,6 +449,7 @@ set(SRC_FILES src/misc/profiler.cc src/misc/rocm_smi_wrap.cc src/misc/rocmwrap.cc + src/misc/roctx.cc src/misc/shmutils.cc src/misc/signals.cc src/misc/socket.cc @@ -547,9 +549,6 @@ if(DEMANGLE_DIR) endif() ## Set RCCL compile definitions -if(NOT NVTX) - target_compile_definitions(rccl PRIVATE NVTX_NO_IMPL) -endif() if(COLLTRACE) target_compile_definitions(rccl PRIVATE ENABLE_COLLTRACE) endif() @@ -565,6 +564,11 @@ endif() if(PROFILE) target_compile_definitions(rccl PRIVATE ENABLE_PROFILING) endif() +if(NOT ROCTX) + target_compile_definitions(rccl PRIVATE NVTX_NO_IMPL) + target_compile_definitions(rccl PRIVATE ROCTX_NO_IMPL) + target_compile_definitions(rccl PRIVATE NVTX_DISABLE) +endif() if(TRACE) target_compile_definitions(rccl PRIVATE ENABLE_TRACE) endif() @@ -638,7 +642,15 @@ if (HAS_BFD) target_link_libraries(rccl PRIVATE iberty z) endif() endif() +if (ROCTX) + target_link_libraries(rccl PRIVATE -lroctx64) +endif() target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code +target_link_libraries(rccl PRIVATE Threads::Threads) +target_link_libraries(rccl INTERFACE hip::host) +target_link_libraries(rccl PRIVATE hip::device) +target_link_libraries(rccl PRIVATE dl) +target_link_libraries(rccl PRIVATE ${ROCM_SMI_LIBRARIES}) ## Set RCCL link options target_link_options(rccl PRIVATE -parallel-jobs=16) # Use multiple threads to link @@ -663,12 +675,6 @@ if (HAVE_KERNARG_PRELOAD) target_link_options(rccl PRIVATE -Xoffload-linker -mllvm=-amdgpu-kernarg-preload-count=16) endif() -target_link_libraries(rccl PRIVATE Threads::Threads) -target_link_libraries(rccl INTERFACE hip::host) -target_link_libraries(rccl PRIVATE hip::device) -target_link_libraries(rccl PRIVATE dl) -target_link_libraries(rccl PRIVATE ${ROCM_SMI_LIBRARIES}) - ## Track linking time set_property(TARGET rccl PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time") diff --git a/projects/rccl/README.md b/projects/rccl/README.md index d6feae34fe..e54311bdce 100644 --- a/projects/rccl/README.md +++ b/projects/rccl/README.md @@ -37,6 +37,7 @@ The root of this repository has a helper script 'install.sh' to build and instal -l|--local_gpu_only Only compile for local GPU architecture --no_clean Don't delete files if they already exist --npkit-enable Compile with npkit enabled + --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program) -p|--package_build Build RCCL package --prefix Specify custom directory to install RCCL to (default: /opt/rocm) --rm-legacy-include-dir Remove legacy include dir Packaging added for file/folder reorg backward compatibility diff --git a/projects/rccl/install.sh b/projects/rccl/install.sh index 018c7abb33..4705579804 100755 --- a/projects/rccl/install.sh +++ b/projects/rccl/install.sh @@ -25,7 +25,7 @@ install_library=false msccl_kernel_enabled=true num_parallel_jobs=$(nproc) npkit_enabled=false -nvtx_enabled=false +roctx_enabled=false run_tests=false run_tests_all=false time_trace=false @@ -51,7 +51,7 @@ function display_help() echo " --amdgpu_targets Only compile for specified GPU architecture(s). For multiple targets, seperate by ';' (builds for all supported GPU architectures by default)" echo " --no_clean Don't delete files if they already exist" echo " --npkit-enable Compile with npkit enabled" - echo " --nvtx-enable Compile with nvtx enabled" + echo " --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)" echo " -p|--package_build Build RCCL package" echo " --prefix Specify custom directory to install RCCL to (default: /opt/rocm)" echo " --rm-legacy-include-dir Remove legacy include dir Packaging added for file/folder reorg backward compatibility" @@ -70,7 +70,7 @@ function display_help() # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ $? -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,nvtx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") else echo "Need a new version of getopt" exit 1 @@ -99,7 +99,7 @@ while true; do --amdgpu_targets) build_amdgpu_targets=${2}; shift 2 ;; --no_clean) clean_build=false; shift ;; --npkit-enable) npkit_enabled=true; shift ;; - --nvtx-enable) nvtx_enabled=true; shift ;; + --roctx-enable) roctx_enabled=true; shift ;; -p | --package_build) build_package=true; shift ;; --prefix) install_prefix=${2}; shift 2 ;; --rm-legacy-include-dir) build_freorg_bkwdcomp=false; shift ;; @@ -223,9 +223,9 @@ if ($install_dependencies); then cmake_common_options="${cmake_common_options} -DINSTALL_DEPENDENCIES=ON" fi -# Enable NVTX -if [[ "${nvtx_enabled}" == true ]]; then - cmake_common_options="${cmake_common_options} -DNVTX=ON" +# Enable ROCTX +if [[ "${roctx_enabled}" == true ]]; then + cmake_common_options="${cmake_common_options} -DROCTX=ON" fi cmake_executable=cmake diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index d3d3fc81d9..3fa2f83d9d 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -68,6 +68,13 @@ NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_ ncclComm_t comm, hipStream_t stream); ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) { + // Just pass the size of one message and not the total bytes sent/received. + constexpr nvtxPayloadSchemaEntry_t AllToAllSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes]"} + }; + size_t msgsize = count * ncclTypeSize(datatype); + NVTX3_FUNC_WITH_PARAMS(AllToAll, AllToAllSchema, msgsize) + if (mscclAvailable() && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, @@ -103,6 +110,18 @@ NCCL_API(ncclResult_t, ncclAllToAllv, const void *sendbuff, const size_t sendcou ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], const size_t sdispls[], void *recvbuff, const size_t recvcounts[], const size_t rdispls[], ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream) { + struct NvtxParamsAllToAllv { + size_t sendbytes; + size_t recvbytes; + }; + // Just pass the size of one send/recv messages and not the total bytes sent/received. + constexpr nvtxPayloadSchemaEntry_t AllToAllvSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"} + }; + NvtxParamsAllToAllv payload{sendcounts[comm->rank] * ncclTypeSize(datatype), recvcounts[comm->rank] * ncclTypeSize(datatype)}; + NVTX3_FUNC_WITH_PARAMS(AllToAllv, AllToAllvSchema, payload) + if (mscclAvailable() && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls, @@ -170,6 +189,17 @@ NCCL_API(ncclResult_t, ncclGather, const void* sendbuff, void* recvbuff, size_t ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream); ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream) { + struct NvtxParamsGather { + size_t bytes; + int root; + }; + constexpr nvtxPayloadSchemaEntry_t GatherSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsGather, root)} + }; + NvtxParamsGather payload{sendcount * ncclTypeSize(datatype), root}; + NVTX3_FUNC_WITH_PARAMS(Gather, GatherSchema, payload) + if (mscclAvailable() && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, @@ -254,6 +284,17 @@ NCCL_API(ncclResult_t, ncclScatter, const void* sendbuff, void* recvbuff, size_t ncclComm_t comm, hipStream_t stream); ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, int root, ncclComm_t comm, hipStream_t stream) { + struct NvtxParamsScatter { + size_t bytes; + int root; + }; + constexpr nvtxPayloadSchemaEntry_t ScatterSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Bytes"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "Root", nullptr, 0, offsetof(NvtxParamsScatter, root)} + }; + NvtxParamsScatter payload{recvcount * ncclTypeSize(datatype), root}; + NVTX3_FUNC_WITH_PARAMS(Scatter, ScatterSchema, payload) + if (mscclAvailable() && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, diff --git a/projects/rccl/src/include/nvtx.h b/projects/rccl/src/include/nvtx.h index ab32ef27f0..b09d67c879 100644 --- a/projects/rccl/src/include/nvtx.h +++ b/projects/rccl/src/include/nvtx.h @@ -8,6 +8,7 @@ #define NCCL_NVTX_H_ #include "nvtx3/nvtx3.hpp" +#include "roctx.h" #if __cpp_constexpr >= 201304L && !defined(NVTX3_CONSTEXPR_IF_CPP14) #define NVTX3_CONSTEXPR_IF_CPP14 constexpr @@ -22,11 +23,16 @@ #define NVTX_SID_CommAbort 3 // same schema as NVTX_SID_CommInitRank #define NVTX_SID_AllGather 4 #define NVTX_SID_AllReduce 5 -#define NVTX_SID_Broadcast 6 -#define NVTX_SID_ReduceScatter 7 -#define NVTX_SID_Reduce 8 -#define NVTX_SID_Send 9 -#define NVTX_SID_Recv 10 +#define NVTX_SID_AllToAll 6 +#define NVTX_SID_AllToAllv 7 +#define NVTX_SID_Broadcast 8 +#define NVTX_SID_Gather 9 +#define NVTX_SID_MSCCL 10 +#define NVTX_SID_ReduceScatter 11 +#define NVTX_SID_Reduce 12 +#define NVTX_SID_Scatter 13 +#define NVTX_SID_Send 14 +#define NVTX_SID_Recv 15 // Define static schema ID for the reduction operation. #define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START @@ -71,6 +77,12 @@ class payload_schema { // @param N schema name // @param S schema (entries) // @param P payload (struct) +#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) +#define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ + nvtxPayloadData_t nvtx3_bpl__[] = { \ + {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ + roctx_scoped_range_in const roctx_range__{S, nvtx3_bpl__, std::extent::value, "RCCL_" #ID}; +#else #define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ static const payload_schema schema{S, std::extent::value, \ NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \ @@ -79,6 +91,7 @@ class payload_schema { {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; +#endif extern void initNvtxRegisteredEnums(); diff --git a/projects/rccl/src/include/nvtx3/nvtx3.hpp b/projects/rccl/src/include/nvtx3/nvtx3.hpp index 8c62acd469..a4cef3849f 100644 --- a/projects/rccl/src/include/nvtx3/nvtx3.hpp +++ b/projects/rccl/src/include/nvtx3/nvtx3.hpp @@ -2777,10 +2777,15 @@ inline void mark(Args const&... args) noexcept * `domain` to which the `registered_string_in` belongs. Else, * `domain::global` to indicate that the global NVTX domain should be used. */ +#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HCC__) && !defined(__HIPCC__) #define NVTX3_V1_FUNC_RANGE_IN(D) \ static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \ ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; +#else +#define NVTX3_V1_FUNC_RANGE_IN(D) \ + roctx_scoped_range_in const roctx_range__{__func__}; +#endif /** * @brief Convenience macro for generating a range in the specified `domain` diff --git a/projects/rccl/src/include/roctx.h b/projects/rccl/src/include/roctx.h new file mode 100644 index 0000000000..77751d243c --- /dev/null +++ b/projects/rccl/src/include/roctx.h @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef RCCL_ROCTX_H +#define RCCL_ROCTX_H + +#include +#include +#include + +#include +#include "nvtx3/nvtx3.hpp" +#include "device.h" + +#define MAX_MESSAGE_LENGTH 1024 +#define NVTX_PAYLOAD_ENTRY_TYPE_REDOP 11 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + +/** + * \brief Equivalent of nvtx types for roctx. +*/ +enum roctxPayloadEntryType { + /** + * Only include the required/used types by rccl, + * and needs to be updated in case of new type + * tracing. + */ + ROCTX_PAYLOAD_ENTRY_TYPE_INT, + ROCTX_PAYLOAD_ENTRY_TYPE_SIZE, + ROCTX_PAYLOAD_ENTRY_TYPE_REDOP, + ROCTX_PAYLOAD_NUM_ENTRY_TYPES +}; + +/** + * \brief Stores the contents of the message that will be used by roctx. +*/ +struct roctxPayloadSchemaEntryInfo { + /** + * Description of the data. + */ + const char* name; + + /** + * Type of the data. + */ + roctxPayloadEntryType type; + + /** + * Union of possible payload types. + * + * Should be in sync with roctxPayloadEntryType. + */ + union { + int typeInt; + size_t typeSize; + ncclDevRedOp_t typeRedOp; + } payload; +}; + +struct roctxPayloadInfo { + /** + * Payload name. Usually the name of the function/API + * being called from. + */ + const char* id; + + /** + * Number of paylod entries + */ + size_t numEntries; + + /** + * Pointer to roctxPayloadSchemaEntryInfo elements in memory. + */ + roctxPayloadSchemaEntryInfo* payloadEntries = nullptr; + + /** + * Message that will be used by roctx + */ + char* message = nullptr; +}; + +typedef roctxPayloadInfo* roctxPayloadInfo_t; + +extern const char* roctxEntryTypeStr[ROCTX_PAYLOAD_NUM_ENTRY_TYPES]; +extern const char* ncclRedOpStr[ncclNumDevRedOps]; + +/** + * \brief Maps nvtx types to roctx types. +*/ +extern std::map nvtxToRoctx; + +/** + * \brief Allocate required memory for roctx +*/ +void roctxAlloc(roctxPayloadInfo_t payloadInfo, const size_t numEntries); + +/** + * \brief Free all the resources used by roctx +*/ +void roctxFree(roctxPayloadInfo_t payloadInfo); + +/** + * \brief Extracts payload schema entry info from nvtxPayloadSchemaEntry_t and, + * nvtxPayloadData_t and stores in an array. +*/ +void extractPayloadInfo(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data, const size_t numEntries, + const char* schemaName, roctxPayloadInfo_t payloadInfo); + +/** + * \brief Stringify roctxPayloadInfo_t struct. Used as roctx message. +*/ +void stringify(roctxPayloadInfo_t payloadInfo); + +/** + * \brief Class to make roctx calls scoped. +*/ +class roctx_scoped_range_in { +public: + /** + * Construct a 'roctx_scoped_range_in' with specified NVTX params, + * 'numEntries', and 'schemaName' + */ + explicit roctx_scoped_range_in(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data, + const size_t numEntries, const char* schemaName) noexcept + { +#ifndef ROCTX_NO_IMPL + roctxAlloc(&payloadInfo, numEntries); + extractPayloadInfo(schema, data, numEntries, schemaName, &payloadInfo); + roctxRangePushA(payloadInfo.message); +#endif + } + + /** + * Construct a 'roctx_scoped_range_in' with the specified 'message' + */ + explicit roctx_scoped_range_in(const char* message) noexcept + { +#ifndef ROCTX_NO_IMPL + roctxRangePushA(message); +#endif + } + + /** + * Default constructor 'roctx_scoped_range_in' + */ + roctx_scoped_range_in() noexcept : roctx_scoped_range_in{""} {/*no impl*/} + + /** + * Destroy the roctx_scoped_range_in, ending the ROCTX range event. + */ + ~roctx_scoped_range_in() noexcept + { +#ifndef ROCTX_NO_IMPL + roctxRangePop(); + roctxFree(&payloadInfo); +#endif + } +private: + roctxPayloadInfo payloadInfo; +}; + +#endif // RCCL_ROCTX_H \ No newline at end of file diff --git a/projects/rccl/src/misc/roctx.cc b/projects/rccl/src/misc/roctx.cc new file mode 100644 index 0000000000..d3e5deef60 --- /dev/null +++ b/projects/rccl/src/misc/roctx.cc @@ -0,0 +1,99 @@ +/************************************************************************* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "roctx.h" + +std::map nvtxToRoctx { + {NVTX_PAYLOAD_ENTRY_TYPE_INT, ROCTX_PAYLOAD_ENTRY_TYPE_INT}, + {NVTX_PAYLOAD_ENTRY_TYPE_SIZE, ROCTX_PAYLOAD_ENTRY_TYPE_SIZE}, + {NVTX_PAYLOAD_ENTRY_TYPE_REDOP, ROCTX_PAYLOAD_ENTRY_TYPE_REDOP}}; + +const char* roctxEntryTypeStr[ROCTX_PAYLOAD_NUM_ENTRY_TYPES] = {"ROCTX_PAYLOAD_ENTRY_TYPE_INT", "ROCTX_PAYLOAD_ENTRY_TYPE_SIZE", "ROCTX_PAYLOAD_ENTRY_TYPE_REDOP"}; +const char* ncclRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "MinMax", "PreMulSum", "SumPostDiv" }; + +void roctxAlloc(roctxPayloadInfo_t payloadInfo, const size_t numEntries) { +#ifndef ROCTX_NO_IMPL + // Allocate enough memory for numEntries in payloadEntries + payloadInfo->payloadEntries = (roctxPayloadSchemaEntryInfo*)malloc(numEntries * sizeof(roctxPayloadSchemaEntryInfo)); + + // Allocate memory for the message that will be constructed + payloadInfo->message = (char*)malloc(MAX_MESSAGE_LENGTH * sizeof(char)); +#endif +} + +void roctxFree(roctxPayloadInfo_t payloadInfo) { +#ifndef ROCTX_NO_IMPL + // Free all the dynamically allocated resources by roctx + if (payloadInfo->payloadEntries) free(payloadInfo->payloadEntries); + if (payloadInfo->message) free((void*)payloadInfo->message); +#endif +} + +void extractPayloadInfo(const nvtxPayloadSchemaEntry_t* schema, const nvtxPayloadData_t* data, const size_t numEntries, + const char* schemaName, roctxPayloadInfo_t payloadInfo) { + + if (payloadInfo->payloadEntries == nullptr) return; + + payloadInfo->id = schemaName; + payloadInfo->numEntries = numEntries; + + // Iterate over each entry in the schema + for (size_t i = 0; i < payloadInfo->numEntries; ++i) { + // Populate payload schema entry info for roctx + payloadInfo->payloadEntries[i].name = schema[i].name; + payloadInfo->payloadEntries[i].type = nvtxToRoctx[schema[i].type]; + + // Offset to index into the data stored in nvtxPayloadData_t->payload + uint64_t offset = schema[i].offset; + const void* entryData = reinterpret_cast(data->payload) + offset; + + // Populate payload union based on the roctx type + switch (payloadInfo->payloadEntries[i].type) { + case ROCTX_PAYLOAD_ENTRY_TYPE_INT: payloadInfo->payloadEntries[i].payload.typeInt = *reinterpret_cast(entryData); break; + case ROCTX_PAYLOAD_ENTRY_TYPE_SIZE: payloadInfo->payloadEntries[i].payload.typeSize = *reinterpret_cast(entryData); break; + case ROCTX_PAYLOAD_ENTRY_TYPE_REDOP: payloadInfo->payloadEntries[i].payload.typeRedOp = *reinterpret_cast(entryData); break; + default: break; + } + } + + // Stringify payloadInfo + stringify(payloadInfo); +} + +void stringify(roctxPayloadInfo_t payloadInfo) { + if (!payloadInfo->payloadEntries || !payloadInfo->message) return; + + int offset = snprintf(payloadInfo->message, MAX_MESSAGE_LENGTH, "{%s: ", payloadInfo->id); + + for (size_t i = 0; i < payloadInfo->numEntries; ++i) + { + roctxPayloadSchemaEntryInfo entry = payloadInfo->payloadEntries[i]; + + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%s: ", entry.name); + + switch (entry.type) + { + case ROCTX_PAYLOAD_ENTRY_TYPE_INT: + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%d", entry.payload.typeInt); + break; + case ROCTX_PAYLOAD_ENTRY_TYPE_SIZE: + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%zu", entry.payload.typeSize); + break; + case ROCTX_PAYLOAD_ENTRY_TYPE_REDOP: + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "%s", + entry.payload.typeRedOp < ncclNumDevRedOps ? ncclRedOpStr[entry.payload.typeRedOp] : "unknown"); + break; + default: + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "unknown roctx payload type"); + break; + } + + if (i != payloadInfo->numEntries - 1) + offset += snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, ", "); + } + + snprintf(payloadInfo->message + offset, MAX_MESSAGE_LENGTH - offset, "}"); +} \ No newline at end of file diff --git a/projects/rccl/src/msccl.cc b/projects/rccl/src/msccl.cc index be2459da0f..fd33c94961 100644 --- a/projects/rccl/src/msccl.cc +++ b/projects/rccl/src/msccl.cc @@ -44,6 +44,18 @@ ncclResult_t mscclRunAlgo( void* recvBuff, const size_t recvCounts[], const size_t rDisPls[], size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op, mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, hipStream_t stream) { + struct NvtxParamsMsccl { + size_t sendbytes; + size_t recvbytes; + }; + // Just pass the size of one send/recv messages and not the total bytes sent/received. + constexpr nvtxPayloadSchemaEntry_t MscclSchema[] = { + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Send)"}, + {0, NVTX_PAYLOAD_ENTRY_TYPE_SIZE, "Message size [bytes] (Recv)"} + }; + NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)}; + NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload) + mscclStatus& status = mscclGetStatus(); struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle]; struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];