From 9f4651f20fedcae9f57a96d8f5b2b0d65ab1bbcf Mon Sep 17 00:00:00 2001 From: Karthikeyan Arumugam Date: Tue, 23 Dec 2025 07:33:10 -0800 Subject: [PATCH] Add support for AMD AINIC within RCCL default internal network plugin. (#2078) * Added support for AMD ROCm net-ib alongside vanilla net-ib, with auto-generation to detect conflicts early during NCCL sync and enable future customizations. * Integrated AMD AINIC support in RCCL for out-of-the-box usage, leveraging performance improvements by default, channel pinning for optimal pipeline performance, and extended support for 32B in-line CTS messages. * Implemented internal derivation of AINIC-specific flags when RCCL AINIC environment parameter is set, and checks before initializing AINIC net-ib methods. * Included snapshot of auto-generated ROCm net-ib file (src/transport/net_ib_rocm.cc) for reference. * Fixed typos in RCCL param API (RCCL_AINIC_ROCE) and dlclose. * Updated plugin loading logic: * Load internal ROCmIB plugin only when NCCL_NET_PLUGIN is not set. * Load default internal net-ib only when not AINIC and no external plugin env is set. --- CMakeLists.txt | 10 + cmake/rocmIb.cmake | 249 +++ ext-src/rocm_netib.patch | 797 ++++++++ src/include/ionic/ionicdvcore.h | 20 + src/include/ionic/ionicdvsymbols.h | 16 + src/include/ionic/ionicdvwrap.h | 17 + src/include/net.h | 5 + src/misc/ionicdvsymbols.cc | 60 + src/misc/ionicdvwrap.cc | 59 + src/plugin/net.cc | 16 +- src/transport/net.cc | 19 +- src/transport/net_ib_rocm.cc | 3006 ++++++++++++++++++++++++++++ 12 files changed, 4262 insertions(+), 12 deletions(-) create mode 100644 cmake/rocmIb.cmake create mode 100644 ext-src/rocm_netib.patch create mode 100644 src/include/ionic/ionicdvcore.h create mode 100644 src/include/ionic/ionicdvsymbols.h create mode 100644 src/include/ionic/ionicdvwrap.h create mode 100644 src/misc/ionicdvsymbols.cc create mode 100644 src/misc/ionicdvwrap.cc create mode 100644 src/transport/net_ib_rocm.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 8fded5ddd9..933ae4c47d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,9 @@ endif() # Determine which GPU architectures to build for set(GPU_TARGETS "${DEFAULT_GPUS}" CACHE STRING "Target default GPUs if GPU_TARGETS is not defined.") +# ROCM NetIB patch +include(cmake/rocmIb.cmake) + # Modify GPU architectures for Address Sanitizer builds by appending "xnack+" if (BUILD_ADDRESS_SANITIZER) SET(amdgpu_targets "") @@ -571,6 +574,9 @@ set(SRC_FILES src/include/mlx5/mlx5dvcore.h src/include/mlx5/mlx5dvsymbols.h src/include/mlx5/mlx5dvwrap.h + src/include/ionic/ionicdvcore.h + src/include/ionic/ionicdvsymbols.h + src/include/ionic/ionicdvwrap.h src/include/msccl/msccl_lifecycle.h src/include/msccl/msccl_parser.h src/include/msccl/msccl_scheduler.h @@ -647,6 +653,8 @@ set(SRC_FILES src/misc/ipcsocket.cc src/misc/mlx5dvsymbols.cc src/misc/mlx5dvwrap.cc + src/misc/ionicdvsymbols.cc + src/misc/ionicdvwrap.cc src/misc/npkit.cc # src/misc/nvmlwrap.cc src/misc/nvmlwrap_stub.cc @@ -695,6 +703,7 @@ set(SRC_FILES src/transport/generic.cc src/transport/net.cc src/transport/net_ib.cc + src/transport/net_ib_rocm.cc src/transport/net_socket.cc src/transport/nvls.cc src/transport/p2p.cc @@ -862,6 +871,7 @@ target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device/network/unpack) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include/mlx5) +target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include/ionic) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include/plugin) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/gensrc) target_include_directories(rccl PRIVATE ${HSA_INCLUDE_PATH}) diff --git a/cmake/rocmIb.cmake b/cmake/rocmIb.cmake new file mode 100644 index 0000000000..b37064d21d --- /dev/null +++ b/cmake/rocmIb.cmake @@ -0,0 +1,249 @@ +# MIT License +# +# Copyright (c) 2020 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Dependencies + +# HIP dependency is handled earlier in the project cmake file +# when VerifyCompiler.cmake is included. + +# GIT + +# Test dependencies + +# For downloading, building, and installing required dependencies +include(cmake/DownloadProject.cmake) + +message(STATUS "Generating ROCM NetIB... ") + +# ------------------------- +# Configurable paths +# ------------------------- +# Path to RCCL source tree (local clone) +set(RCCL_SRC_DIR "${CMAKE_SOURCE_DIR}" CACHE PATH "Path to RCCL source directory") +# Path to patch file +set(ROCM_NETIB_PATCH_FILE "${CMAKE_SOURCE_DIR}/ext-src/rocm_netib.patch" CACHE FILEPATH "ROCM NETIB Patch file to apply to RCCL") +set(ROCM_NETIB_FILE "${CMAKE_SOURCE_DIR}/src/transport/net_ib_rocm.cc" CACHE FILEPATH "Generated ROCM NETIB file") + +# ------------------------- +# Find tools +# ------------------------- +find_program(PATCH_EXECUTABLE patch) +find_program(SED_EXECUTABLE sed) + +execute_process( + COMMAND ${CMAKE_COMMAND} -E echo "Applying RCCL ROCM NetIB patch... to ${CMAKE_SOURCE_DIR}" + COMMAND bash -c "patch -p1 -i ${ROCM_NETIB_PATCH_FILE} -o ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/NCCL_PARAM(Ib/NCCL_PARAM(RocmIb/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/RCCL_PARAM(Ib/RCCL_PARAM(RocmIb/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclParamIb/ncclParamRocmIb/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/rcclParamIb/rcclParamRocmIb/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbMergedDevs/rocmIbMergedDevs/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDevs/rocmIbDevs/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbLock/rocmIbLock/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ibProviderName/rocmIbProviderName/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbAsyncThread/rocmIbAsyncThread/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbGdrSupport/rocmIbGdrSupport/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDmaBufSupport/rocmIbDmaBufSupport/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbInitCommDevBase/rocmIbInitCommDevBase/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDestroyBase/rocmIbDestroyBase/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbRtrQp/rocmIbRtrQp/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbRtsQp/rocmIbRtsQp/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ForceEnableGdrdma/RocmForceEnableGdrdma/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbCheckVProps/rocmIbCheckVProps/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbGetRequest/rocmIbGetRequest/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbFreeRequest/rocmIbFreeRequest/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbRegMrDmaBufInternal/rocmIbRegMrDmaBufInternal/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbGetNetCommDevBase/rocmIbGetNetCommDevBase/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDeregMrInternal/rocmIbDeregMrInternal/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbPostFifo/rocmIbPostFifo/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/reqTypeStr/rocmIbReqTypeStr/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/rcclNetP2pPolicy/rcclRocmNetP2pPolicy/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbMakeVDeviceInternal/rocmIbMakeVDeviceInternal/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbMakeVDevice/rocmIbMakeVDevice/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbInit/rocmIbInit/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDevices/rocmIbDevices/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbGetPhysProperties/rocmIbGetPhysProperties/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbGetProperties/rocmIbGetProperties/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbListen\(/rocmIbListen\(/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbListen,/rocmIbListen,/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbConnect\(/rocmIbConnect\(/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbConnect /rocmIbConnect /g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbConnect,/rocmIbConnect,/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbAccept/rocmIbAccept/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbTest/rocmIbTest/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbRegMrDmaBuf/rocmIbRegMrDmaBuf/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbRegMr/rocmIbRegMr/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbDeregMr/rocmIbDeregMr/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbIsend/rocmIbIsend/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbIrecv/rocmIbIrecv/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbIflush/rocmIbIflush/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbCloseSend/rocmIbCloseSend/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbCloseRecv/rocmIbCloseRecv/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclIbCloseListen/rocmIbCloseListen/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +execute_process( + COMMAND bash -c "sed -i 's/ncclNetIb/rocmNetIb/g' ${ROCM_NETIB_FILE}" + WORKING_DIRECTORY ${RCCL_SRC_DIR} +) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") diff --git a/ext-src/rocm_netib.patch b/ext-src/rocm_netib.patch new file mode 100644 index 0000000000..84a0e90685 --- /dev/null +++ b/ext-src/rocm_netib.patch @@ -0,0 +1,797 @@ +diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc +index 7af56a6c..5c3e3d46 100644 +--- a/src/transport/net_ib.cc ++++ b/src/transport/net_ib.cc +@@ -28,6 +28,7 @@ + + #include "ibvwrap.h" + #include "mlx5/mlx5dvwrap.h" ++#include "ionic/ionicdvwrap.h" + #include "graph/xml.h" + + #define MAXSUFFIXSIZE 16 +@@ -107,9 +108,31 @@ struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_VDEVS]; + struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; + pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; + static int ncclIbRelaxedOrderingEnabled = 0; ++static bool rcclAinicRoce = 0; ++static bool rcclCtsInlineData = 0; ++static bool rcclCtsOffloadEnabled = 0; ++static bool ncclIbUseInline = 0; ++static int ncclIbGdrFlushDisable = 0; ++ ++enum ncclIbChannelType { ++ ncclIbChannelTypeCts = 0, ++ ncclIbChannelTypeData = 1, ++ ncclIbChannelTypeMax = 2 ++}; ++ ++struct ncclChannelToUd { ++ int channelId; ++ bool udId; ++ bool udAllocated; ++}; ++ ++static ncclChannelToUd nccl_channel_ud_map[MAXCHANNELS][ncclIbChannelTypeMax]; ++static bool nccl_channel_last_ud[MAX_IB_DEVS][ncclIbChannelTypeMax]; + + #define NCCL_IB_LLSTR(ll) (((ll) == IBV_LINK_LAYER_INFINIBAND) ? "IB" : (((ll) == IBV_LINK_LAYER_ETHERNET) ? "RoCE" : "UNSPECIFIED")) + ++#define NCCL_CTS_QP_SLOT_INVALID 0xFF ++ + #define NCCL_IB_SL_DEFAULT 0 + #define NCCL_IB_TC_DEFAULT 0 + +@@ -131,6 +154,13 @@ NCCL_PARAM(IbEceEnable,"IB_ECE_ENABLE",1); + NCCL_PARAM(IbDataDirect,"IB_DATA_DIRECT",1); + NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); + RCCL_PARAM(IbQpsPerP2p, "IB_QPS_PER_P2P", 0); ++NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); ++ ++// AMD AINIC ++RCCL_PARAM(CtsInlineData, "CTS_INLINE_DATA", -1); ++RCCL_PARAM(CtsOffloadEnabled, "CTS_OFFLOAD_ENABLED", -1); ++ ++extern int64_t rcclParamAinicRoce(); + + static ncclResult_t ncclIbStatsInit(struct ncclIbStats* stat) { + __atomic_store_n(&stat->fatalErrorCount, 0, __ATOMIC_RELAXED); +@@ -630,6 +660,10 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction, ncclProfilerCallback_t pr + static int shownIbHcaEnv = 0; + if(wrap_ibv_symbols() != ncclSuccess) { return ncclInternalError; } + if(wrap_mlx5dv_symbols() != ncclSuccess) { INFO(NCCL_NET, "NET/IB : Failed to open mlx5dv symbols. Advance features like CX-8 Direct-NIC will be disabled."); } ++ if(wrap_ionicdv_symbols() != ncclSuccess) { ++ WARN("NET/IB : Failed to open ionicdv symbols. Advance features like AINIC UD load balancing will be disabled."); ++ return ncclInternalError; ++ } + + // Detect IB cards + int nIbDevs = 0; +@@ -783,6 +817,24 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction, ncclProfilerCallback_t pr + INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", + ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); + ++ ncclIbUseInline = ncclParamIbUseInline(); ++ ncclIbGdrFlushDisable = ncclParamIbGdrFlushDisable(); ++ ++ rcclAinicRoce = ((rcclParamAinicRoce() == 1) ? true : false); ++ if (rcclAinicRoce) { ++ // for AINIC, these params are defaulted to enabled unless user forces it to disable(0). ++ rcclCtsInlineData = ((rcclParamCtsInlineData() == 0) ? false : true); ++ rcclCtsOffloadEnabled = ((rcclParamCtsOffloadEnabled() == 0) ? false : true); ++ // for AINIC IbUseInline is enabled by default always ++ ncclIbUseInline = true; ++ // for AINIC GDR flush is disabled by default ++ ncclIbGdrFlushDisable = 1; ++ ++ INFO(NCCL_INIT|NCCL_NET, "NET/IB : AINIC RoCEv2 optimizations enabled: CTS Inline Data: %s; CTS Offload: %s; " ++ "IB Use Inline: enabled; GDR Flush: disabled", rcclCtsInlineData ? "Enabled": "Disabled", ++ rcclCtsOffloadEnabled ? "Enabled": "Disabled"); ++ } ++ + pthread_mutex_unlock(&ncclIbLock); + } + exit: +@@ -1112,6 +1164,8 @@ struct ncclIbListenComm { + struct ncclIbCommStage stage; + }; + ++#define MAX_INLINE_DATA_SIZE 24 ++ + struct alignas(64) ncclIbSendFifo { + uint64_t addr; + uint64_t size; +@@ -1122,10 +1176,21 @@ struct alignas(64) ncclIbSendFifo { + char padding[16]; + }; + ++struct alignas(32) ncclIbSendFifoCtsInline { ++ uint64_t addr; ++ uint32_t rkeys[1]; ++ int size; ++ uint8_t nreqs; ++ uint16_t tag; ++ uint32_t idx; ++ char padding[9]; ++} __attribute__((packed)); ++ + struct ncclIbQp { + struct ibv_qp* qp; + int devIndex; + int remDevIdx; ++ int8_t ctsQpSlot; + }; + + struct ncclIbRemSizesFifo { +@@ -1172,6 +1237,7 @@ struct ncclIbSendComm { + struct ncclIbNetCommBase base; + // Start with fifo and ibv structs as they have alignment restrictions + struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; ++ struct ncclIbSendFifoCtsInline fifo_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS]; + struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS + 1]; + // Each dev correlates to a mergedIbDev +@@ -1187,6 +1253,7 @@ struct ncclIbSendComm { + static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset"); + static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); + static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples"); ++static_assert((sizeof(struct ncclIbSendFifoCtsInline) % 32) == 0, "ncclIbSendFifoCtsInline element size must be 32-byte multiples"); + static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned"); + static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned"); + +@@ -1201,6 +1268,7 @@ struct ncclIbGpuFlush { + + struct ncclIbRemFifo { + struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; ++ struct ncclIbSendFifoCtsInline elems_cts_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + uint64_t fifoTail; + uint64_t addr; + uint32_t flags; +@@ -1265,20 +1333,59 @@ returning: + return res; + } + +-ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, int access_flags, void* qp_context, struct ncclIbQp* qp) { ++ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, ++ int access_flags, void* qp_context, struct ncclIbQp* qp, ++ int channel_id, bool data_qp, int8_t cts_qp_slot) { + struct ibv_qp_init_attr qpInitAttr; ++ enum ncclIbChannelType channel_type = (data_qp ? ncclIbChannelTypeData : ncclIbChannelTypeCts); + memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr)); + qpInitAttr.qp_context = qp_context; + qpInitAttr.send_cq = base->cq; + qpInitAttr.recv_cq = base->cq; + qpInitAttr.qp_type = IBV_QPT_RC; ++ ++ if (rcclAinicRoce) { ++ if (!nccl_channel_ud_map[channel_id][channel_type].udAllocated) { ++ bool lud = nccl_channel_last_ud[base->ibDevN][channel_type]; ++ nccl_channel_ud_map[channel_id][channel_type].udId = lud; ++ nccl_channel_ud_map[channel_id][channel_type].udAllocated = true; ++ nccl_channel_last_ud[base->ibDevN][channel_type] = ++ !(nccl_channel_last_ud[base->ibDevN][channel_type]); ++ } ++ if (nccl_channel_ud_map[channel_id][channel_type].udId) { ++ wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_HIGH); ++ } else { ++ wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_LOW); ++ } ++ qpInitAttr.sq_sig_all |= (1 << 16); ++ if (data_qp) { ++ qpInitAttr.sq_sig_all |= (1 << 17); ++ } else { ++ qpInitAttr.sq_sig_all &= (~(1 << 17)); ++ } ++ qpInitAttr.sq_sig_all |= (1 << 18); ++ ++ if (rcclCtsOffloadEnabled) { ++ qpInitAttr.sq_sig_all |= (1 << 19); ++ } else { ++ qpInitAttr.sq_sig_all &= (~(1 << 19)); ++ } ++ } ++ + // We might send 2 messages per send (RDMA and RDMA_WITH_IMM) + qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS; + qpInitAttr.cap.max_recv_wr = MAX_REQUESTS; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; +- qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0; ++ if (rcclCtsInlineData) { ++ qpInitAttr.cap.max_inline_data = MAX_INLINE_DATA_SIZE; ++ } else { ++ qpInitAttr.cap.max_inline_data = ncclIbUseInline ? sizeof(struct ncclIbSendFifo) : 0; ++ } + NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr)); ++ if (rcclAinicRoce) { ++ NCCLCHECK(wrap_ionicdv_qp_set_gda(qp->qp, false, true)); ++ } + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); + qpAttr.qp_state = IBV_QPS_INIT; +@@ -1288,6 +1395,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, + NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p", + ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd); ++ if (rcclAinicRoce) { ++ qp->ctsQpSlot = cts_qp_slot; ++ } + return ncclSuccess; + } + +@@ -1371,7 +1481,7 @@ fail: + goto exit; + } + +-ncclResult_t ncclIbConnect(int dev, ncclNetCommConfig_t* config, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { ++ncclResult_t ncclIbConnect(int dev, ncclNetCommConfig_t* config, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** sendDevComm) { + ncclResult_t ret = ncclSuccess; + struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; + struct ncclIbCommStage* stage = &handle->stage; +@@ -1379,8 +1489,13 @@ ncclResult_t ncclIbConnect(int dev, ncclNetCommConfig_t* config, void* opaqueHan + int ready; + uint8_t link_layer = IBV_LINK_LAYER_UNSPECIFIED; + int isP2p = 0; ++ int channel_id = 0; + *sendComm = NULL; + ++ if (rcclAinicRoce) { ++ channel_id = ((ncclNet_ctxt_t *)sendDevComm)->chId; ++ } ++ + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; +@@ -1461,7 +1576,7 @@ ib_recv_dev_list: + for (int q = 0; q < comm->base.nqps; q++) { + ncclIbSendCommDev* commDev = comm->devs + devIndex; + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; +- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q), ret, fail); ++ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); + comm->base.qps[q].devIndex = devIndex; + meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex; +@@ -1486,7 +1601,11 @@ ib_recv_dev_list: + devInfo->lid = ibDev->portAttr.lid; + devInfo->ibv_dev_index = commDev->base.ibDevN; + // Prepare my fifo +- NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ if (rcclCtsInlineData) { ++ NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo_inline, sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ } else { ++ NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ } + devInfo->fifoRkey = commDev->fifoMr->rkey; + + // Pack local GID info +@@ -1528,7 +1647,11 @@ ib_recv_dev_list: + return ncclInternalError; + } + } +- meta.fifoAddr = (uint64_t)comm->fifo; ++ if (rcclCtsInlineData) { ++ meta.fifoAddr = (uint64_t)comm->fifo_inline; ++ } else { ++ meta.fifoAddr = (uint64_t)comm->fifo; ++ } + meta.sl = (ncclParamIbSl() != -1) ? ncclParamIbSl() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_SL_DEFAULT; + meta.tc = (ncclParamIbTc() != -1) ? ncclParamIbTc() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_TC_DEFAULT; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); +@@ -1673,18 +1796,22 @@ ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDevicePro + return ncclSuccess; + } + +-NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); + RCCL_PARAM(IbGdrFlushGpuMemNoRelaxedOrdering, "GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING", 1); + +-ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { ++ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) { + ncclResult_t ret = ncclSuccess; + struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm; + struct ncclIbCommStage* stage = &lComm->stage; + struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm; + int ready; + int link_layer = IBV_LINK_LAYER_UNSPECIFIED; ++ int channel_id = 0; + *recvComm = NULL; + ++ if (rcclAinicRoce) { ++ channel_id = ((ncclNet_ctxt_t *) recvDevComm)->chId; ++ } ++ + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; +@@ -1814,7 +1941,7 @@ ib_recv: + // Local ibDevN + ibDevN = rComm->devs[devIndex].base.ibDevN; + ibDev = ncclIbDevs + ibDevN; +- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp), ret, fail); ++ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp, channel_id, false, q), ret, fail); + qp->devIndex = devIndex; + devIndex = (devIndex + 1) % rComm->base.vProps.ndevs; + +@@ -1840,16 +1967,22 @@ ib_recv: + + useDmaBuf = (ncclIbDmaBufSupport(lComm->dev) == ncclSuccess); + rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || useDmaBuf) +- && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; ++ && (ncclIbGdrFlushDisable == 0)) ? 1 : 0; + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { + rCommDev = rComm->devs + i; + ibDev = ncclIbDevs + rCommDev->base.ibDevN; + + // Retain remote fifo info and prepare my RDMA ops + rComm->remFifo.addr = remMeta.fifoAddr; +- NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ if (rcclCtsInlineData) { ++ NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems_cts_inline, ++ sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, ++ IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ } else { ++ NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); ++ } + rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; +- if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; ++ if (ncclIbUseInline) rComm->remFifo.flags = IBV_SEND_INLINE; + + // Allocate Flush dummy buffer for GPU Direct RDMA + if (rComm->flushEnabled) { +@@ -1887,7 +2020,7 @@ ib_recv: + rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem; + rCommDev->gpuFlush.sge.length = 1; + rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey; +- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp), ret, fail); ++ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); + struct ncclIbDevInfo devInfo; + devInfo.lid = ibDev->portAttr.lid; + devInfo.link_layer = ibDev->portAttr.link_layer; +@@ -2115,10 +2248,15 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { + + NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0); + +-ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ++ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot, bool use_write_op) { + struct ncclIbRequest** reqs = comm->fifoReqs[slot]; + volatile struct ncclIbSendFifo* slots = comm->fifo[slot]; +- int nreqs = slots[0].nreqs; ++ int nreqs; ++ if (rcclCtsOffloadEnabled) { ++ nreqs = 1; ++ } else { ++ nreqs = slots[0].nreqs; ++ } + if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; + + uint64_t wr_id = 0ULL; +@@ -2130,7 +2268,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { + sge->addr=(uintptr_t)reqs[r]->send.data; + wr->opcode = IBV_WR_RDMA_WRITE; + wr->send_flags = 0; +- wr->wr.rdma.remote_addr = slots[r].addr; ++ if (rcclCtsOffloadEnabled) { ++ wr->wr.rdma.remote_addr = 0xdeadbeef; ++ } else { ++ wr->wr.rdma.remote_addr = slots[r].addr; ++ } + wr->next = wr + 1; + wr_id += (reqs[r] - comm->base.reqs) << (r*8); + #ifdef NCCL_ENABLE_NET_PROFILING +@@ -2141,7 +2283,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { + // Write size as immediate data. In the case of multi-send, only write + // 0 or 1 as size to indicate whether there was data sent or received. + uint32_t immData = 0; +- if (nreqs == 1) { ++ if ((nreqs == 1) && (use_write_op == false)) { + immData = reqs[0]->send.size; + } else { + int* sizes = comm->remSizesFifo.elems[slot]; +@@ -2151,22 +2293,24 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { + } + + struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; +- if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { +- // When using ADAPTIVE_ROUTING, send the bulk of the data first as an +- // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote +- // completion. +- lastWr++; +- memset(lastWr, 0, sizeof(struct ibv_send_wr)); +- if (nreqs > 1) { +- // Write remote sizes Fifo +- lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); +- lastWr->num_sge = 1; +- lastWr->sg_list = &comm->remSizesFifo.sge; ++ if (use_write_op == false) { ++ if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { ++ // When using ADAPTIVE_ROUTING, send the bulk of the data first as an ++ // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote ++ // completion. ++ lastWr++; ++ memset(lastWr, 0, sizeof(struct ibv_send_wr)); ++ if (nreqs > 1) { ++ // Write remote sizes Fifo ++ lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); ++ lastWr->num_sge = 1; ++ lastWr->sg_list = &comm->remSizesFifo.sge; ++ } + } ++ lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; ++ lastWr->imm_data = immData; + } + lastWr->wr_id = wr_id; +- lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; +- lastWr->imm_data = immData; + lastWr->next = NULL; + lastWr->send_flags = IBV_SEND_SIGNALED; + +@@ -2182,7 +2326,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { + //ncclIbAddEvent(reqs[r], devIndex, &comm->devs[devIndex].base); + + // Select proper rkey (needed even for 0-size send) +- comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; ++ if (rcclCtsOffloadEnabled) { ++ comm->wrs[r].wr.rdma.rkey = 0xbade; ++ } else { ++ comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; ++ } + + int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align; + int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); +@@ -2198,7 +2346,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { + } + } + +- if (nreqs > 1) { ++ if ((use_write_op == false) && (nreqs > 1)) { + // Also make sure lastWr writes remote sizes using the right lkey + comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey; + lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex]; +@@ -2256,32 +2404,46 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* + NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); + + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; ++ bool use_write_op = false; ++ if (rcclAinicRoce) { ++ use_write_op = (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) ? true : false; ++ } + + // Wait for the receiver to have posted the corresponding receive + int nreqs = 0; + volatile struct ncclIbSendFifo* slots; + ++ if (rcclCtsOffloadEnabled) { ++ nreqs = 1; ++ } ++ + int slot = (comm->fifoHead) % MAX_REQUESTS; + struct ncclIbRequest** reqs = comm->fifoReqs[slot]; +- slots = comm->fifo[slot]; +- uint64_t idx = comm->fifoHead+1; +- if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; } +- nreqs = slots[0].nreqs; +- // Wait until all data has arrived +- for (int r=1; rfifo[slot]; ++ uint64_t idx = comm->fifoHead+1; ++ if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; } ++ nreqs = slots[0].nreqs; ++ // Wait until all data has arrived ++ for (int r=1; r slots[r].size) size = slots[r].size; +- // Sanity checks +- if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { +- char line[SOCKET_NAME_MAXLEN + 1]; +- union ncclSocketAddress addr; +- ncclSocketGetAddr(&comm->base.sock, &addr); +- WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", +- r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); +- return ncclInternalError; ++ if (!rcclCtsOffloadEnabled) { ++ if (reqs[r] != NULL || slots[r].tag != tag) continue; ++ ++ if (size > slots[r].size) size = slots[r].size; ++ // Sanity checks ++ if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { ++ char line[SOCKET_NAME_MAXLEN + 1]; ++ union ncclSocketAddress addr; ++ ncclSocketGetAddr(&comm->base.sock, &addr); ++ WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", ++ r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); ++ return ncclInternalError; ++ } ++ } else{ ++ if (reqs[r] != NULL) continue; + } + + struct ncclIbRequest* req; +@@ -2325,10 +2487,12 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* + } + + TIME_START(0); +- NCCLCHECK(ncclIbMultiSend(comm, slot)); ++ NCCLCHECK(ncclIbMultiSend(comm, slot, use_write_op)); + + // Clear slots[0]->nreqs, as well as other fields to help debugging and sanity checks +- memset((void*)slots, 0, sizeof(struct ncclIbSendFifo)); ++ if (!rcclCtsOffloadEnabled) { ++ memset((void*)slots, 0, sizeof(struct ncclIbSendFifo)); ++ } + memset(reqs, 0, NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbRequest*)); + comm->fifoHead++; + TIME_STOP(0); +@@ -2341,30 +2505,60 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* + + ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { + struct ibv_send_wr wr; ++ struct ncclIbSendFifo* localElem = NULL; ++ struct ncclIbSendFifoCtsInline* localElemCtsInline = NULL; ++ uint64_t localElemRef; ++ int qpIndex = 0; ++ ncclIbQp* ctsQp = NULL; + memset(&wr, 0, sizeof(wr)); + + int slot = comm->remFifo.fifoTail%MAX_REQUESTS; + req->recv.sizes = comm->sizesFifo[slot]; + for (int i=0; irecv.sizes[i] = 0; +- struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot]; ++ if (rcclCtsInlineData) { ++ localElemCtsInline = comm->remFifo.elems_cts_inline[slot]; ++ } else { ++ localElem = comm->remFifo.elems[slot]; ++ } + +- // Select the next devIndex (local) and QP to use for posting this CTS message +- // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value +- ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; +- comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; ++ if (rcclAinicRoce) { ++ qpIndex = comm->base.qpIndex; ++ ctsQp = comm->base.qps + qpIndex; ++ } else { ++ // Select the next devIndex (local) and QP to use for posting this CTS message ++ // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value ++ ctsQp = comm->base.qps + comm->base.devIndex; ++ comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; ++ } + + for (int i=0; ibase.vProps.ndevs; j++) +- localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; ++ // Send all applicable rkeys ++ for (int j = 0; j < comm->base.vProps.ndevs; j++) ++ localElemCtsInline[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; ++ ++ localElemCtsInline[i].nreqs = n; ++ localElemCtsInline[i].size = sizes[i]; // Sanity/Debugging ++ localElemCtsInline[i].tag = tags[i]; ++ localElemCtsInline[i].idx = comm->remFifo.fifoTail+1; ++ localElemRef = (uint64_t)localElemCtsInline; ++ ++ } else { ++ localElem[i].addr = (uint64_t)data[i]; + +- localElem[i].nreqs = n; +- localElem[i].size = sizes[i]; // Sanity/Debugging +- localElem[i].tag = tags[i]; +- localElem[i].idx = comm->remFifo.fifoTail+1; ++ // Send all applicable rkeys ++ for (int j = 0; j < comm->base.vProps.ndevs; j++) ++ localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; ++ ++ localElem[i].nreqs = n; ++ localElem[i].size = sizes[i]; // Sanity/Debugging ++ localElem[i].tag = tags[i]; ++ localElem[i].idx = comm->remFifo.fifoTail+1; ++ localElemRef = (uint64_t)localElem; ++ } + } + wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo); + +@@ -2372,8 +2566,12 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz + wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey; + + // Set the correct sge properties +- comm->devs[ctsQp->devIndex].fifoSge.addr = (uint64_t)localElem; +- comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); ++ comm->devs[ctsQp->devIndex].fifoSge.addr = localElemRef; ++ if (rcclCtsInlineData) { ++ comm->devs[ctsQp->devIndex].fifoSge.length = MAX_INLINE_DATA_SIZE; ++ } else { ++ comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); ++ } + wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge; + wr.num_sge = 1; + +@@ -2403,7 +2601,13 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz + // + // slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled. + // This works out that each fifo posting QP gets drained +- if (slot == ctsQp->devIndex) { ++ if (rcclAinicRoce) { ++ if (slot == ctsQp->ctsQpSlot) { ++ wr.send_flags |= IBV_SEND_SIGNALED; ++ wr.wr_id = req - comm->base.reqs; ++ ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); ++ } ++ } else if (slot == ctsQp->devIndex) { + wr.send_flags |= IBV_SEND_SIGNALED; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); +@@ -2418,10 +2622,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz + + comm->remFifo.fifoTail++; + ++ if (rcclAinicRoce) { ++ // Select the next qpIndex ++ comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; ++ } + return ncclSuccess; + } + + ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** phandles, void** request) { ++ ncclResult_t res = ncclSuccess; ++ bool netOptRecvCompletionEnabled = false; + struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; + if (comm->base.ready == 0) { + WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); +@@ -2431,6 +2641,11 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* + if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; + NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); + ++ if (rcclAinicRoce) { ++ if (*request == (void *) NCCL_NET_OPTIONAL_RECV_COMPLETION) { ++ netOptRecvCompletionEnabled = true; ++ } ++ } + struct ncclIbRequest* req; + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); + req->type = NCCL_NET_IB_REQ_RECV; +@@ -2444,50 +2659,64 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* + req->devBases[i] = &comm->devs[i].base; + } + +- struct ibv_recv_wr wr; +- memset(&wr, 0, sizeof(wr)); +- wr.wr_id = req - comm->base.reqs; +- wr.sg_list = NULL; +- wr.num_sge = 0; ++ if (!netOptRecvCompletionEnabled) { ++ struct ibv_recv_wr wr; ++ memset(&wr, 0, sizeof(wr)); ++ wr.wr_id = req - comm->base.reqs; ++ wr.sg_list = NULL; ++ wr.num_sge = 0; + +- TIME_START(1); +- // Select either all QPs, or one qp per-device +- const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; ++ TIME_START(1); ++ // Select either all QPs, or one qp per-device ++ const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; + +- // Post recvs +- struct ibv_recv_wr* bad_wr; +- for (int i = 0; i < nqps; i++) { +- struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; +- ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); ++ // Post recvs ++ struct ibv_recv_wr* bad_wr; ++ int qpIndex = comm->base.qpIndex; ++ for (int i = 0; i < nqps; i++) { ++ struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; ++ ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); + #ifdef NCCL_ENABLE_NET_PROFILING +- // Start a QP event for every request in the multirecv and every qp +- for (int r = 0; r < n; r++) { +- int nEventHandles = req->pInfo[r].nEventHandles; +- assert(nEventHandles < MAX_QPS_PER_REQ); +- req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex; +- // Store info for profiler +- int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; +- req->pInfo[r].data.type = ncclProfileQp; +- req->pInfo[r].data.qp.device = qp->devIndex; +- req->pInfo[r].data.qp.wr_id = wr.wr_id; +- req->pInfo[r].data.qp.qpNum = qp->qp->qp_num; +- NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data)); +- req->pInfo[r].nEventHandles++; +- } ++ // Start a QP event for every request in the multirecv and every qp ++ for (int r = 0; r < n; r++) { ++ int nEventHandles = req->pInfo[r].nEventHandles; ++ assert(nEventHandles < MAX_QPS_PER_REQ); ++ req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex; ++ // Store info for profiler ++ int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; ++ req->pInfo[r].data.type = ncclProfileQp; ++ req->pInfo[r].data.qp.device = qp->devIndex; ++ req->pInfo[r].data.qp.wr_id = wr.wr_id; ++ req->pInfo[r].data.qp.qpNum = qp->qp->qp_num; ++ NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data)); ++ req->pInfo[r].nEventHandles++; ++ } + #endif +- NCCLCHECK(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr)); +- comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; +- } ++ NCCLCHECKGOTO(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr), res, err); ++ // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs ++ // inside ncclIbPostFifo() ++ if (rcclAinicRoce) { ++ qpIndex = (qpIndex+1)%comm->base.nqps; ++ } else { ++ comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; ++ } ++ } + +- TIME_STOP(1); ++ TIME_STOP(1); ++ } // netOptRecvCompletionEnabled = false + + // Post to FIFO to notify sender + TIME_START(2); +- NCCLCHECK(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req)); ++ NCCLCHECKGOTO(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req), res, err); + TIME_STOP(2); + + *request = req; + return ncclSuccess; ++err: ++ if (req) { ++ ncclIbFreeRequest(req); ++ } ++ return res; + } + + ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { +@@ -2556,6 +2785,8 @@ static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) { + } + #endif + ++#define NCCL_CQ_POLL_MAX_EVENT 16 ++ + ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { + struct ncclIbRequest *r = (struct ncclIbRequest*)request; + *done = 0; +@@ -2589,13 +2820,18 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { + + int totalWrDone = 0; + int wrDone = 0; +- struct ibv_wc wcs[4]; ++ struct ibv_wc wcs[NCCL_CQ_POLL_MAX_EVENT]; ++ int cqMaxPollEvent = 4; ++ if (rcclAinicRoce) { ++ cqMaxPollEvent = NCCL_CQ_POLL_MAX_EVENT; ++ } + + for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) { + TIME_START(3); + // If we expect any completions from this device's CQ + if (r->events[i]) { +- NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone)); ++ NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, cqMaxPollEvent, ++ wcs, &wrDone)); + totalWrDone += wrDone; + if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } + if (wrDone == 0) continue; +@@ -2742,7 +2978,7 @@ ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p) { + } + + ncclNet_t ncclNetIb = { +- "IB", ++ "ROCM-IB", + ncclIbInit, + ncclIbDevices, + ncclIbGetProperties, diff --git a/src/include/ionic/ionicdvcore.h b/src/include/ionic/ionicdvcore.h new file mode 100644 index 0000000000..539f99ffd2 --- /dev/null +++ b/src/include/ionic/ionicdvcore.h @@ -0,0 +1,20 @@ +#ifndef NCCL_IONICDV_CORE_H_ +#define NCCL_IONICDV_CORE_H_ + +/* Basic ionic direct verbs structs. + * Needed to dynamically load ionic direct verbs functions without + * explicit including of ionic direct verbs header. + */ + +#include +#include +#include +#include +#include "ibvwrap.h" + +enum ionicdv_reg_udma_mask { + IONIC_UDMA_MASK_LOW = 1, + IONIC_UDMA_MASK_HIGH = 2 +}; + +#endif // NCCL_IONICDV_CORE_H_ diff --git a/src/include/ionic/ionicdvsymbols.h b/src/include/ionic/ionicdvsymbols.h new file mode 100644 index 0000000000..813203d800 --- /dev/null +++ b/src/include/ionic/ionicdvsymbols.h @@ -0,0 +1,16 @@ +#ifndef NCCL_IONICDV_SYMBOLS_H_ +#define NCCL_IONICDV_SYMBOLS_H_ + +#include "ionic/ionicdvcore.h" +#include "nccl.h" + +/* Ionic Direct Verbs Function Pointers*/ +struct ncclIonicdvSymbols { + int (*ionicdv_internal_qp_set_gda)(struct ibv_qp *qp, bool enable_send, bool enable_recv); + int (*ionicdv_internal_pd_set_udma_mask)(struct ibv_pd *ibpd, uint8_t udma_mask); +}; + +/* Constructs ionic direct verbs symbols per rdma-core linking or dynamic loading mode */ +ncclResult_t buildIonicdvSymbols(struct ncclIonicdvSymbols* ionicdvSymbols); + +#endif // NCCL_IONICDV_SYMBOLS_H_ diff --git a/src/include/ionic/ionicdvwrap.h b/src/include/ionic/ionicdvwrap.h new file mode 100644 index 0000000000..510367465d --- /dev/null +++ b/src/include/ionic/ionicdvwrap.h @@ -0,0 +1,17 @@ +#ifndef NCCL_IONICDVWRAP_H_ +#define NCCL_IONICDVWRAP_H_ + +#include +#include +#include "ionic/ionicdvcore.h" +#include "core.h" +#include "ibvwrap.h" +#include +#include + +ncclResult_t wrap_ionicdv_symbols(void); +/* NCCL wrappers of ionic direct verbs functions */ +ncclResult_t wrap_ionicdv_qp_set_gda(struct ibv_qp *ibqp, bool enable_send, bool enable_recv); +ncclResult_t wrap_ionicdv_pd_set_udma_mask(struct ibv_pd *ibpd, uint8_t udma_mask); + +#endif // NCCL_IONICDVWRAP_H_ diff --git a/src/include/net.h b/src/include/net.h index 3c8567fc31..c51cc7b7e6 100644 --- a/src/include/net.h +++ b/src/include/net.h @@ -25,4 +25,9 @@ extern ncclNet_t ncclNetSocket; extern ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p); +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +extern ncclNet_t rocmNetIb; +extern ncclResult_t rcclRocmNetP2pPolicy(void* handle, int isP2p); +#endif + #endif diff --git a/src/misc/ionicdvsymbols.cc b/src/misc/ionicdvsymbols.cc new file mode 100644 index 0000000000..bb93a46891 --- /dev/null +++ b/src/misc/ionicdvsymbols.cc @@ -0,0 +1,60 @@ +#include +#include + +#include "ionic/ionicdvsymbols.h" + +/* ionicdv dynamic loading mode. Symbols are loaded from shared objects. */ +#include +#include "core.h" + +// IONICDV Library versioning +#define IONIC_VERSION "IONIC_1.0" + +ncclResult_t buildIonicdvSymbols(struct ncclIonicdvSymbols* ionicdvSymbols) { + static void* ionicdvhandle = NULL; + void* tmp; + void** cast; + + ionicdvhandle = dlopen("libionic.so", RTLD_NOW); + if (!ionicdvhandle) { + ionicdvhandle = dlopen("libionic.so.1", RTLD_NOW); + if (!ionicdvhandle) { + INFO(NCCL_INIT, "Failed to open libionic.so[.1]"); + goto teardown; + } + } + +#define LOAD_SYM(handle, symbol, funcptr) do { \ + cast = (void**)&funcptr; \ + tmp = dlvsym(handle, symbol, IONIC_VERSION); \ + if (tmp == NULL) { \ + WARN("dlvsym failed on %s - %s version %s", symbol, dlerror(), IONIC_VERSION); \ + goto teardown; \ + } else { \ + WARN("dlvsym loaded successfully for %s - version %s", symbol, IONIC_VERSION); \ + } \ + *cast = tmp; \ + } while (0) + +// Attempt to load a specific symbol version - fail silently +#define LOAD_SYM_VERSION(handle, symbol, funcptr, version) do { \ + cast = (void**)&funcptr; \ + *cast = dlvsym(handle, symbol, version); \ + if (*cast == NULL) { \ + INFO(NCCL_NET, "dlvsym failed on %s - %s version %s", symbol, dlerror(), version); \ + } \ + } while (0) + + LOAD_SYM(ionicdvhandle, "ionic_dv_qp_set_gda", ionicdvSymbols->ionicdv_internal_qp_set_gda); + LOAD_SYM(ionicdvhandle, "ionic_dv_pd_set_udma_mask", ionicdvSymbols->ionicdv_internal_pd_set_udma_mask); + INFO(NCCL_INIT, "Loaded dlvsym from libionic.so[.1]"); + + return ncclSuccess; + +teardown: + ionicdvSymbols->ionicdv_internal_qp_set_gda = NULL; + ionicdvSymbols->ionicdv_internal_pd_set_udma_mask = NULL; + + if (ionicdvhandle != NULL) dlclose(ionicdvhandle); + return ncclSystemError; +} diff --git a/src/misc/ionicdvwrap.cc b/src/misc/ionicdvwrap.cc new file mode 100644 index 0000000000..ece63303b0 --- /dev/null +++ b/src/misc/ionicdvwrap.cc @@ -0,0 +1,59 @@ +#include "ionic/ionicdvwrap.h" +#include +#include +#include "param.h" + +#include "ionic/ionicdvcore.h" +#include "ionic/ionicdvsymbols.h" + +static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; +static ncclResult_t initResult; +struct ncclIonicdvSymbols ionicdvSymbols; + +extern int64_t rcclParamAinicRoce(); + +ncclResult_t wrap_ionicdv_symbols(void) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + if (rcclParamAinicRoce() == 1) { + pthread_once(&initOnceControl, + [](){ initResult = buildIonicdvSymbols(&ionicdvSymbols); }); + return initResult; + } +#endif + // simply return for unsupported platform/NIC. + return ncclSuccess; +} + +/* CHECK_NOT_NULL: helper macro to check for NULL symbol */ +#define CHECK_NOT_NULL(container, internal_name) \ + if (container.internal_name == NULL) { \ + WARN("lib wrapper not initialized."); \ + return ncclInternalError; \ + } + +#define IONICDV_INT_CHECK_RET_ERRNO(container, internal_name, call, success_retval, name) \ + CHECK_NOT_NULL(container, internal_name); \ + int ret = container.call; \ + if (ret != success_retval) { \ + INFO(NCCL_NET, "Call to " name " failed with error %s errno %d", strerror(ret), ret); \ + return ncclSystemError; \ + } else { \ + INFO(NCCL_NET, "Call to " name " success with ret %d", ret); \ + } \ + return ncclSuccess; + +ncclResult_t wrap_ionicdv_qp_set_gda(struct ibv_qp *qp, bool enable_send, bool enable_recv) { + if (ionicdvSymbols.ionicdv_internal_qp_set_gda == NULL) { + errno = EOPNOTSUPP; + return ncclSystemError; + } + IONICDV_INT_CHECK_RET_ERRNO(ionicdvSymbols, ionicdv_internal_qp_set_gda, ionicdv_internal_qp_set_gda(qp, enable_send, enable_recv), 0, "ionic_dv_qp_set_gda"); +} + +ncclResult_t wrap_ionicdv_pd_set_udma_mask(struct ibv_pd *ibpd, uint8_t udma_mask) { + if (ionicdvSymbols.ionicdv_internal_pd_set_udma_mask == NULL) { + errno = EOPNOTSUPP; + return ncclSystemError; + } + IONICDV_INT_CHECK_RET_ERRNO(ionicdvSymbols, ionicdv_internal_pd_set_udma_mask, ionicdv_internal_pd_set_udma_mask(ibpd, udma_mask), 0, "ionic_dv_pd_set_udma_mask"); +} diff --git a/src/plugin/net.cc b/src/plugin/net.cc index e8203e20ae..321cd74da7 100644 --- a/src/plugin/net.cc +++ b/src/plugin/net.cc @@ -30,6 +30,8 @@ extern getNcclCollNet_t getNcclCollNet_v8; extern getNcclCollNet_t getNcclCollNet_v9; extern getNcclCollNet_t getNcclCollNet_v10; +extern int64_t rcclParamAinicRoce(); + NCCL_PARAM(NetPluginRefCount, "NET_PLUGIN_REF_COUNT", 1); #define NCCL_NET_VERSION_COUNT 5 int ncclNetVersion[NCCL_NET_VERSION_COUNT] = {10, 9, 8, 7, 6}; @@ -244,8 +246,18 @@ static void initPluginLibsOnceFunc() { } // Add 2 internal ib and socket plugins - netPluginLibs[pluginCounter].ncclNet = &ncclNetIb; - netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady; +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + if ((rcclParamAinicRoce() == 1) && !(envNetPlugin)) { + // For AINIC add rocm internal ib instead of default internal ib + netPluginLibs[pluginCounter].ncclNet = &rocmNetIb; + netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady; + } else { +#endif + netPluginLibs[pluginCounter].ncclNet = &ncclNetIb; + netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady; +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + } +#endif netPluginLibs[pluginCounter].ncclNet = &ncclNetSocket; netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady; pluginCount = pluginCounter; diff --git a/src/transport/net.cc b/src/transport/net.cc index 158c3e71b9..9087604d88 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -29,8 +29,6 @@ static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large"); -#define RCCL_ANP_PLUGIN_STR "RCCL-ANP" - #define NCCL_NET_MAP_HOSTMEM 0 #define NCCL_NET_MAP_DEVMEM 1 #define NCCL_NET_MAP_SHARED_HOSTMEM 2 @@ -199,6 +197,7 @@ struct setupReq { }; NCCL_PARAM(NetOptionalRecvCompletion, "NET_OPTIONAL_RECV_COMPLETION", 1); +RCCL_PARAM(AinicRoce, "AINIC_ROCE", 0); static_assert(sizeof(ncclNetHandle_t) + sizeof(int) <= CONNECT_SIZE, "Not large enough ncclConnect to hold ncclNetHandle_t and useGdr flag"); // Forward declaration @@ -769,12 +768,12 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str ncclNet_ctxt_t ncclNetCtxt = {}; struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources); ncclNetCommConfig_t commConfig = {0}; + bool rcclAinicRoce = ((rcclParamAinicRoce() == 1) ? true : false); if (reqSize != sizeof(netSendConnectArgs)) return ncclInternalError; ncclResult_t ret = ncclSuccess; netSendConnectArgs* req = (netSendConnectArgs*) reqBuff; commConfig.trafficClass = req->trafficClass == NCCL_CONFIG_UNDEF_INT ? NCCL_NET_TRAFFIC_CLASS_UNDEF : req->trafficClass; NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, false /*isRecv*/, &resources->netDeviceHandle)); - bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR)); // Only call rcclNetP2pPolicy for ncclNetIb if (proxyState->ncclNet == &ncclNetIb) { @@ -804,7 +803,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str comms->activeConnect[resources->channelId] = (resources->tpLocalRank + 1); if (comms->sendComm[resources->channelId] == NULL && comms->activeConnect[resources->channelId] == (resources->tpLocalRank + 1)) { - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->connect(resources->netDev, &commConfig, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt); @@ -816,7 +815,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str resources->netSendComm = comms->sendComm[resources->channelId]; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; } else { - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->connect(resources->netDev, &commConfig, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { @@ -825,7 +824,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str } } else { // Connect to remote peer - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->connect(resources->netDev, &commConfig, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { @@ -979,8 +978,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str resources->tpRemoteProxyRank = req->proxyRank; ncclResult_t ret = ncclSuccess; ncclNet_ctxt_t ncclNetCtxt = {}; + bool rcclAinicRoce = ((rcclParamAinicRoce() == 1) ? true : false); - bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR)); NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle)); // Finish connection establishment from remote peer if (resources->shared) { @@ -1007,7 +1006,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str //try connecting while comm is null if (comms->recvComm[resources->channelId] == NULL && comms->activeAccept[resources->channelId] == (resources->tpLocalRank + 1)) { - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt); @@ -1019,7 +1018,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str resources->netRecvComm = comms->recvComm[resources->channelId]; if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; } else { - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { @@ -1028,7 +1027,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str } } else { // Connect to remote peer - if (rccl_anp) { + if (rcclAinicRoce) { ncclNetCtxt.chId = resources->channelId; ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { diff --git a/src/transport/net_ib_rocm.cc b/src/transport/net_ib_rocm.cc new file mode 100644 index 0000000000..c1674c01de --- /dev/null +++ b/src/transport/net_ib_rocm.cc @@ -0,0 +1,3006 @@ +/************************************************************************* + * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "nccl.h" +#include "core.h" +#include "socket.h" +#include "net.h" +#include "graph.h" +#include "utils.h" +#include "param.h" +#include "profiler/net_ib.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#define ENABLE_TIMER 0 +#include "timer.h" +#include + +#include "ibvwrap.h" +#include "mlx5/mlx5dvwrap.h" +#include "ionic/ionicdvwrap.h" +#include "graph/xml.h" + +#define MAXSUFFIXSIZE 16 +#define MAXNAMESIZE (64 + MAXSUFFIXSIZE) +static char ncclIbIfName[MAX_IF_NAME_SIZE+1]; +static union ncclSocketAddress ncclIbIfAddr; + +struct ncclIbMr { + uintptr_t addr; + size_t pages; + int refs; + ibv_mr *mr; +}; + +struct ncclIbMrCache { + struct ncclIbMr *slots; + int capacity, population; +}; + +static int ncclNMergedIbDevs = -1; +#define NCCL_IB_MAX_DEVS_PER_NIC 4 +#define MAX_MERGED_DEV_NAME (MAXNAMESIZE*NCCL_IB_MAX_DEVS_PER_NIC)+NCCL_IB_MAX_DEVS_PER_NIC +struct alignas(64) ncclIbMergedDev { + ncclNetVDeviceProps_t vProps; + int speed; + char devName[MAX_MERGED_DEV_NAME]; // Up to NCCL_IB_MAX_DEVS_PER_NIC * name size, and a character for each '+' +}; + +struct ncclIbStats { + int fatalErrorCount; +}; + +enum ncclIbProvider { + IB_PROVIDER_NONE = 0, + IB_PROVIDER_MLX5 = 1, + IB_PROVIDER_MAX = 2, +}; + +const char* rocmIbProviderName[] = { + "None", + "Mlx5", +}; + +static int ncclNIbDevs = -1; +struct alignas(64) ncclIbDev { + pthread_mutex_t lock; + int device; + uint64_t guid; + uint8_t portNum; + uint8_t link; + int speed; + ibv_context* context; + int pdRefs; + ibv_pd* pd; + char devName[MAXNAMESIZE]; + char* pciPath; + char* virtualPciPath; + int realPort; + int maxQp; + float latency; + struct ncclIbMrCache mrCache; + int ar; // ADAPTIVE_ROUTING + struct ibv_port_attr portAttr; + struct ncclIbStats stats; + int dmaBufSupported; + enum ncclIbProvider ibProvider; + union { + struct { + int dataDirect; + } mlx5; + } capsProvider; +}; + +#define MAX_IB_DEVS 32 +#define MAX_IB_VDEVS MAX_IB_DEVS*8 +struct ncclIbMergedDev rocmIbMergedDevs[MAX_IB_VDEVS]; +struct ncclIbDev rocmIbDevs[MAX_IB_DEVS]; +pthread_mutex_t rocmIbLock = PTHREAD_MUTEX_INITIALIZER; +static int ncclIbRelaxedOrderingEnabled = 0; +static bool rcclAinicRoce = 0; +static bool rcclCtsInlineData = 0; +static bool rcclCtsOffloadEnabled = 0; +static bool ncclIbUseInline = 0; +static int ncclIbGdrFlushDisable = 0; + +enum ncclIbChannelType { + ncclIbChannelTypeCts = 0, + ncclIbChannelTypeData = 1, + ncclIbChannelTypeMax = 2 +}; + +struct ncclChannelToUd { + int channelId; + bool udId; + bool udAllocated; +}; + +static ncclChannelToUd nccl_channel_ud_map[MAXCHANNELS][ncclIbChannelTypeMax]; +static bool nccl_channel_last_ud[MAX_IB_DEVS][ncclIbChannelTypeMax]; + +#define NCCL_IB_LLSTR(ll) (((ll) == IBV_LINK_LAYER_INFINIBAND) ? "IB" : (((ll) == IBV_LINK_LAYER_ETHERNET) ? "RoCE" : "UNSPECIFIED")) + +#define NCCL_CTS_QP_SLOT_INVALID 0xFF + +#define NCCL_IB_SL_DEFAULT 0 +#define NCCL_IB_TC_DEFAULT 0 + +NCCL_PARAM(RocmIbGidIndex, "IB_GID_INDEX", -1); +NCCL_PARAM(RocmIbRoutableFlidIbGidIndex, "IB_ROUTABLE_FLID_GID_INDEX", 1); +NCCL_PARAM(RocmIbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2); +NCCL_PARAM(RocmIbTimeout, "IB_TIMEOUT", 20); +NCCL_PARAM(RocmIbRetryCnt, "IB_RETRY_CNT", 7); +NCCL_PARAM(RocmIbPkey, "IB_PKEY", 0); +NCCL_PARAM(RocmIbUseInline, "IB_USE_INLINE", 0); +NCCL_PARAM(RocmIbSl, "IB_SL", -1); +NCCL_PARAM(RocmIbTc, "IB_TC", -1); +NCCL_PARAM(RocmIbArThreshold, "IB_AR_THRESHOLD", 8192); +NCCL_PARAM(RocmIbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); +NCCL_PARAM(RocmIbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); +NCCL_PARAM(RocmIbFifoTc, "IB_FIFO_TC", -1); +NCCL_PARAM(RocmIbAsyncEvents,"IB_RETURN_ASYNC_EVENTS",1); +NCCL_PARAM(RocmIbEceEnable,"IB_ECE_ENABLE",1); +NCCL_PARAM(RocmIbDataDirect,"IB_DATA_DIRECT",1); +NCCL_PARAM(RocmIbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); +RCCL_PARAM(RocmIbQpsPerP2p, "IB_QPS_PER_P2P", 0); +NCCL_PARAM(RocmIbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); + +// AMD AINIC +RCCL_PARAM(CtsInlineData, "CTS_INLINE_DATA", -1); +RCCL_PARAM(CtsOffloadEnabled, "CTS_OFFLOAD_ENABLED", -1); + +extern int64_t rcclParamAinicRoce(); + +static ncclResult_t ncclIbStatsInit(struct ncclIbStats* stat) { + __atomic_store_n(&stat->fatalErrorCount, 0, __ATOMIC_RELAXED); + return ncclSuccess; +} +static void ncclIbStatsFatalError(struct ncclIbStats* stat){ + __atomic_fetch_add(&stat->fatalErrorCount, 1, __ATOMIC_RELAXED); +} +static ncclResult_t ncclIbStatsCheckFatalCount(struct ncclIbStats* stat, const char* funcName) { + if (ncclParamRocmIbAsyncEvents() && __atomic_load_n(&stat->fatalErrorCount, __ATOMIC_RELAXED)) { + ERROR("RCCL encountered a communication fatal error (detected in %s)\n", funcName); + ERROR("RCCL cannot recover from this network failure and now exiting. Please check the network health."); + return ncclSystemError; + } + return ncclSuccess; +} +static void ncclIbQpFatalError(struct ibv_qp* qp) { + ncclIbStatsFatalError((struct ncclIbStats*)qp->qp_context); +} +static void ncclIbCqFatalError(struct ibv_cq* cq) { + ncclIbStatsFatalError((struct ncclIbStats*)cq->cq_context); +} +// Calculate number of QPs based on P2P flag and device counts +static int ncclIbCalculateNqps(int isP2p, int localNdevs, int remoteNdevs, const char* funcName) { + auto qp_multiplier = (rcclParamRocmIbQpsPerP2p() > 0 && isP2p) ? + rcclParamRocmIbQpsPerP2p() : ncclParamRocmIbQpsPerConn(); + int localNqps = qp_multiplier * localNdevs; + int remoteNqps = qp_multiplier * remoteNdevs; + int maxNqps = (remoteNqps > localNqps) ? remoteNqps : localNqps; + INFO(NCCL_NET, "NET/IB: %s Max Nqps=%d, localNqps=%d, remoteNqps=%d", + funcName, maxNqps, localNqps, remoteNqps); + return maxNqps; +} + +static void ncclIbDevFatalError(struct ncclIbDev* dev) { + ncclIbStatsFatalError(&dev->stats); +} + +pthread_t rocmIbAsyncThread; +static void* rocmIbAsyncThreadMain(void* args) { + struct ncclIbDev* dev = (struct ncclIbDev*)args; + while (1) { + struct ibv_async_event event; + if (ncclSuccess != wrap_ibv_get_async_event(dev->context, &event)) { break; } + char *str; + struct ibv_cq* cq = event.element.cq; // only valid if CQ error + struct ibv_qp* qp = event.element.qp; // only valid if QP error + struct ibv_srq* srq = event.element.srq; // only valid if SRQ error + if (ncclSuccess != wrap_ibv_event_type_str(&str, event.event_type)) { break; } + switch (event.event_type) { + case IBV_EVENT_DEVICE_FATAL: + // the above is device fatal error + WARN("NET/IB : %s:%d async fatal event: %s", dev->devName, dev->portNum, str); + ncclIbDevFatalError(dev); + break; + case IBV_EVENT_CQ_ERR: + // the above is a CQ fatal error + WARN("NET/IB : %s:%d async fatal event on CQ (%p): %s", dev->devName, dev->portNum, cq, str); + ncclIbCqFatalError(cq); + break; + case IBV_EVENT_QP_FATAL: + case IBV_EVENT_QP_REQ_ERR: + case IBV_EVENT_QP_ACCESS_ERR: + // the above are QP fatal errors + WARN("NET/IB : %s:%d async fatal event on QP (%p): %s", dev->devName, dev->portNum, qp, str); + ncclIbQpFatalError(qp); + break; + case IBV_EVENT_SRQ_ERR: + // SRQ are not used in NCCL + WARN("NET/IB : %s:%d async fatal event on SRQ, unused for now (%p): %s", dev->devName, dev->portNum, srq, str); + break; + case IBV_EVENT_PATH_MIG_ERR: + case IBV_EVENT_PORT_ERR: + case IBV_EVENT_PATH_MIG: + case IBV_EVENT_PORT_ACTIVE: + case IBV_EVENT_SQ_DRAINED: + case IBV_EVENT_LID_CHANGE: + case IBV_EVENT_PKEY_CHANGE: + case IBV_EVENT_SM_CHANGE: + case IBV_EVENT_QP_LAST_WQE_REACHED: + case IBV_EVENT_CLIENT_REREGISTER: + case IBV_EVENT_SRQ_LIMIT_REACHED: + // the above are non-fatal + WARN("NET/IB : %s:%d Got async error event: %s", dev->devName, dev->portNum, str); + break; + case IBV_EVENT_COMM_EST: + break; + default: + WARN("NET/IB : %s:%d unknown event type (%d)", dev->devName, dev->portNum, event.event_type); + break; + } + // acknowledgment needs to happen last to avoid user-after-free + if (ncclSuccess != wrap_ibv_ack_async_event(&event)) { break; } + } + return NULL; +} + +static sa_family_t envIbAddrFamily(void) { + sa_family_t family = AF_INET; + const char* env = ncclGetEnv("NCCL_IB_ADDR_FAMILY"); + if (env == NULL || strlen(env) == 0) { + return family; + } + + INFO(NCCL_ENV, "NCCL_IB_ADDR_FAMILY set by environment to %s", env); + + if (strcmp(env, "AF_INET") == 0) { + family = AF_INET; + } else if (strcmp(env, "AF_INET6") == 0) { + family = AF_INET6; + } + + return family; +} + +static void* envIbAddrRange(sa_family_t af, int* mask) { + *mask = 0; + static struct in_addr addr; + static struct in6_addr addr6; + void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6; + + const char* env = ncclGetEnv("NCCL_IB_ADDR_RANGE"); + if (NULL == env || strlen(env) == 0) { + return NULL; + } + + INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env); + + char addrString[128] = { 0 }; + snprintf(addrString, 128, "%s", env); + char *addrStrPtr = addrString; + char *maskStrPtr = strstr(addrString, "/"); + if (NULL == maskStrPtr) { + return NULL; + } + *(maskStrPtr++) = '\0'; + + if (inet_pton(af, addrStrPtr, ret) == 0) { + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + return NULL; + } + + *mask = (int)strtol(maskStrPtr, NULL, 10); + if (af == AF_INET && *mask > 32) { + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } else if (af == AF_INET6 && *mask > 128) { + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + *mask = 0; + ret = NULL; + } + + return ret; +} + +static sa_family_t getGidAddrFamily(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + bool isIpV4Mapped = ((a->s6_addr32[0] | a->s6_addr32[1]) | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL; + bool isIpV4MappedMulticast = (a->s6_addr32[0] == htonl(0xff0e0000) && ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL)); + return (isIpV4Mapped || isIpV4MappedMulticast) ? AF_INET : AF_INET6; +} + +static bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) { + struct in_addr *base = NULL; + struct in6_addr *base6 = NULL; + struct in6_addr *addr6 = NULL;; + if (af == AF_INET) { + base = (struct in_addr *)prefix; + } else { + base6 = (struct in6_addr *)prefix; + } + addr6 = (struct in6_addr *)gid->raw; + +#define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1))) + + int i = 0; + while (prefixlen > 0 && i < 4) { + if (af == AF_INET) { + int mask = NETMASK(prefixlen); + if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) { + break; + } + prefixlen = 0; + break; + } else { + if (prefixlen >= 32) { + if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) { + break; + } + prefixlen -= 32; + ++i; + } else { + int mask = NETMASK(prefixlen); + if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) { + break; + } + prefixlen = 0; + } + } + } + + return (prefixlen == 0) ? true : false; +} + +static bool configuredGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + int trailer = (a->s6_addr32[1] | a->s6_addr32[2] | a->s6_addr32[3]); + if (((a->s6_addr32[0] | trailer) == 0UL) || ((a->s6_addr32[0] == htonl(0xfe800000)) && (trailer == 0UL))) { + return false; + } + return true; +} + +static bool linkLocalGid(union ibv_gid* gid) { + const struct in6_addr *a = (struct in6_addr *)gid->raw; + if (a->s6_addr32[0] == htonl(0xfe800000) && a->s6_addr32[1] == 0UL) { + return true; + } + return false; +} + +static bool validGid(union ibv_gid* gid) { + return (configuredGid(gid) && !linkLocalGid(gid)); +} + +static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) { + char gidRoceVerStr[16] = { 0 }; + char roceTypePath[PATH_MAX] = { 0 }; + snprintf(roceTypePath, sizeof(roceTypePath), "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); + + int fd = open(roceTypePath, O_RDONLY); + if (fd == -1) { + WARN("NET/IB: open failed in ncclIbRoceGetVersionNum: %s", strerror(errno)); + return ncclSystemError; + } + int ret = read(fd, gidRoceVerStr, 15); + close(fd); + + if (ret == -1) { + // In containerized environments, read could return EINVAL if the GID index is not mapped to the + // container sysfs. In this case return ncclSuccess and let the caller move to next GID index. + if (errno == EINVAL) return ncclSuccess; + WARN("NET/IB: read failed in ncclIbRoceGetVersionNum: %s", strerror(errno)); + return ncclSystemError; + } + + if (strlen(gidRoceVerStr)) { + if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) { + *version = 1; + } else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) { + *version = 2; + } + } + + return ncclSuccess; +} + +static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) { + union ibv_gid gid, gidCandidate; + NCCLCHECK(wrap_ibv_query_gid(context, portNum, *gidIndex, &gid)); + NCCLCHECK(wrap_ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate)); + + sa_family_t usrFam = af; + sa_family_t gidFam = getGidAddrFamily(&gid); + sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate); + bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate); + + if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) { + *gidIndex = gidIndexCandidate; + } else { + if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) { + return ncclSuccess; + } + int usrRoceVer = roceVer; + int gidRoceVerNum, gidRoceVerNumCandidate = -1; + const char* deviceName = wrap_ibv_get_device_name(context->device); + NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum)); + NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate)); + if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) { + *gidIndex = gidIndexCandidate; + } + } + + return ncclSuccess; +} + +// GID Format +// global: | 64b - subnet-prefix | 64b - EUI | +// raw : | 10b fixed | 22b 0 | 16b FLID | 16b subnet-prefix | 64b - EUI | +static uint16_t ncclIbExtractLocalSubnetPrefix(uint64_t subnet_prefix) +{ + return (be64toh(subnet_prefix) & 0xffff); +} + +static int ncclIbExtractFlid (union ibv_gid *gid) +{ + return ntohs(*((uint16_t*)((uintptr_t)(gid->raw) + 4))); +} + +static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) { + int gidTblLen = portAttr->gid_tbl_len; + + //for IB, choose GID Index that will have routable FLID if present + if (portAttr->link_layer == IBV_LINK_LAYER_INFINIBAND) { + union ibv_gid gid; + int routableGidIndex = ncclParamRocmIbRoutableFlidIbGidIndex(); + if (routableGidIndex < gidTblLen) { + NCCLCHECK(wrap_ibv_query_gid(context, portNum, routableGidIndex, &gid)); + if (ncclIbExtractFlid(&gid) != 0) { + *gidIndex = routableGidIndex; + return ncclSuccess; + } + } + *gidIndex = 0; + return ncclSuccess; + } + + //for ROCE + *gidIndex = ncclParamRocmIbGidIndex(); + if (*gidIndex >= 0) { + return ncclSuccess; + } + + sa_family_t userAddrFamily = envIbAddrFamily(); + int userRoceVersion = ncclParamRocmIbRoceVersionNum(); + int prefixlen; + void *prefix = envIbAddrRange(userAddrFamily, &prefixlen); + + *gidIndex = 0; + for (int gidIndexNext = 1; gidIndexNext < gidTblLen; ++gidIndexNext) { + NCCLCHECK(ncclUpdateGidIndex(context, portNum, userAddrFamily, prefix, prefixlen, userRoceVersion, gidIndexNext, gidIndex)); + } + + return ncclSuccess; +} + +NCCL_PARAM(RocmIbDisable, "IB_DISABLE", 0); +NCCL_PARAM(RocmIbMergeVfs, "IB_MERGE_VFS", 1); +NCCL_PARAM(RocmIbMergeNics, "IB_MERGE_NICS", 1); + +// Returns 0 if this is the path of two VFs of the same physical device +static int ncclIbMatchVfPath(char* path1, char* path2) { + // Merge multi-port NICs into the same PCI device + if (ncclParamRocmIbMergeVfs()) { + return strncmp(path1, path2, strlen(path1)-4) == 0; + } else { + return strncmp(path1, path2, strlen(path1)-1) == 0; + } +} + +static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) { + char devicePath[PATH_MAX]; + snprintf(devicePath, PATH_MAX, "/sys/class/infiniband/%s/device", devName); + char* p = realpath(devicePath, NULL); + if (p == NULL) { + WARN("Could not find real path of %s (%s)", devName, devicePath); + } else { + // Merge multi-port NICs into the same PCI device + p[strlen(p)-1] = '0'; + // Also merge virtual functions (VF) into the same device + if (ncclParamRocmIbMergeVfs()) p[strlen(p)-3] = p[strlen(p)-4] = '0'; + // Keep the real port aside (the ibv port is always 1 on recent cards) + *realPort = 0; + for (int d=0; dndevs > 1) { + INFO(NCCL_NET, "NET/IB : Skipping makeVDevice, NCCL_IB_MERGE_NICS=0"); + return ncclInvalidUsage; + } + + if (props->ndevs == 0) { + WARN("NET/IB : Can't make virtual NIC with 0 devices"); + return ncclInvalidUsage; + } + + if (ncclNMergedIbDevs == MAX_IB_VDEVS) { + WARN("NET/IB : Cannot allocate any more virtual devices (%d)", MAX_IB_VDEVS); + return ncclInvalidUsage; + } + + // Always count up number of merged devices + ncclIbMergedDev* mDev = rocmIbMergedDevs + ncclNMergedIbDevs; + mDev->vProps.ndevs = 0; + mDev->speed = 0; + + for (int i = 0; i < props->ndevs; i++) { + ncclIbDev* dev = rocmIbDevs + props->devs[i]; + if (mDev->vProps.ndevs == NCCL_IB_MAX_DEVS_PER_NIC) return ncclInvalidUsage; + mDev->vProps.devs[mDev->vProps.ndevs++] = props->devs[i]; + mDev->speed += dev->speed; + // Each successive time, copy the name '+' new name + if (mDev->vProps.ndevs > 1) { + snprintf(mDev->devName + strlen(mDev->devName), sizeof(mDev->devName) - strlen(mDev->devName), "+%s", dev->devName); + // First time, copy the plain name + } else { + strncpy(mDev->devName, dev->devName, MAXNAMESIZE); + } + } + + // Check link layers + ncclIbDev* dev0 = rocmIbDevs + props->devs[0]; + for (int i = 1; i < props->ndevs; i++) { + if (props->devs[i] >= ncclNIbDevs) { + WARN("NET/IB : Cannot use physical device %d, max %d", props->devs[i], ncclNIbDevs); + return ncclInvalidUsage; + } + ncclIbDev* dev = rocmIbDevs + props->devs[i]; + if (dev->link != dev0->link) { + WARN("NET/IB : Attempted to merge incompatible devices: [%d]%s:%d/%s and [%d]%s:%d/%s. Try selecting NICs of only one link type using NCCL_IB_HCA", + props->devs[0], dev0->devName, dev0->portNum, NCCL_IB_LLSTR(dev0->link), props->devs[i], dev->devName, dev->portNum, NCCL_IB_LLSTR(dev->link)); + return ncclInvalidUsage; + } + } + + *d = ncclNMergedIbDevs++; + INFO(NCCL_NET, "NET/IB : Made virtual device [%d] name=%s speed=%d ndevs=%d", *d, mDev->devName, mDev->speed, mDev->vProps.ndevs); + return ncclSuccess; +} + +ncclResult_t rocmIbMakeVDevice(int* d, ncclNetVDeviceProps_t* props) { + pthread_mutex_lock(&rocmIbLock); + ncclResult_t res = rocmIbMakeVDeviceInternal(d, props); + pthread_mutex_unlock(&rocmIbLock); + return res; +} + +static ncclProfilerCallback_t ncclProfilerFunction; + +ncclResult_t rocmIbInit(ncclDebugLogger_t logFunction, ncclProfilerCallback_t profFunction) { + ncclResult_t ret = ncclSuccess; + ncclProfilerFunction = profFunction; + if (ncclParamRocmIbDisable()) return ncclInternalError; + static int shownIbHcaEnv = 0; + if(wrap_ibv_symbols() != ncclSuccess) { return ncclInternalError; } + if(wrap_mlx5dv_symbols() != ncclSuccess) { INFO(NCCL_NET, "NET/IB : Failed to open mlx5dv symbols. Advance features like CX-8 Direct-NIC will be disabled."); } + if(wrap_ionicdv_symbols() != ncclSuccess) { + WARN("NET/IB : Failed to open ionicdv symbols. Advance features like AINIC UD load balancing will be disabled."); + return ncclInternalError; + } + + // Detect IB cards + int nIbDevs = 0; + struct ibv_device** devices = NULL; + + if (ncclNIbDevs == -1) { + pthread_mutex_lock(&rocmIbLock); + wrap_ibv_fork_init(); + if (ncclNIbDevs == -1) { + int nIpIfs = 0; + ncclNIbDevs = 0; + ncclNMergedIbDevs = 0; + NCCLCHECK(ncclFindInterfaces(ncclIbIfName, &ncclIbIfAddr, MAX_IF_NAME_SIZE, 1, &nIpIfs)); + if (nIpIfs != 1) { + WARN("NET/IB : No IP interface found."); + ret = ncclInternalError; + goto fail; + } + + // Check if user defined which IB device:port to use + const char* userIbEnv = ncclGetEnv("NCCL_IB_HCA"); + if (userIbEnv != NULL && shownIbHcaEnv++ == 0) INFO(NCCL_NET|NCCL_ENV, "NCCL_IB_HCA set to %s", userIbEnv); + struct netIf userIfs[MAX_IB_DEVS]; + bool searchNot = userIbEnv && userIbEnv[0] == '^'; + if (searchNot) userIbEnv++; + bool searchExact = userIbEnv && userIbEnv[0] == '='; + if (searchExact) userIbEnv++; + int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS); + + if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; } + + for (int d=0; dname); + continue; + } + enum ncclIbProvider ibProvider = IB_PROVIDER_NONE; + char dataDirectDevicePath[PATH_MAX]; + int dataDirectSupported = 0; + int skipNetDevForDataDirect = 0; + if (wrap_mlx5dv_is_supported(devices[d])) { + ibProvider = IB_PROVIDER_MLX5; + snprintf(dataDirectDevicePath, PATH_MAX, "/sys"); + if((ncclMlx5dvDmaBufCapable(context)) && (wrap_mlx5dv_get_data_direct_sysfs_path(context, dataDirectDevicePath + 4, PATH_MAX - 4) == ncclSuccess)) { + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Data Direct DMA Interface is detected for device:%s", devices[d]->name); + // Now check whether Data Direct has been disabled by the user + if(ncclParamRocmIbDataDirect() == 1) { dataDirectSupported = 1; skipNetDevForDataDirect = 1; } + if(ncclParamRocmIbDataDirect() == 2) { dataDirectSupported = 1; skipNetDevForDataDirect = 0; } + } + } + int nPorts = 0; + struct ibv_device_attr devAttr; + memset(&devAttr, 0, sizeof(devAttr)); + if (ncclSuccess != wrap_ibv_query_device(context, &devAttr)) { + WARN("NET/IB : Unable to query device %s", devices[d]->name); + if (ncclSuccess != wrap_ibv_close_device(context)) + { + ret = ncclInternalError; + goto fail; + } + continue; + } + for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) { + // dataDirect = 0 exposes the devices normally, dataDirect = 1 exposes the devices through direct NIC + for (int dataDirect = skipNetDevForDataDirect; dataDirect < 1 + dataDirectSupported; ++dataDirect) { + struct ibv_port_attr portAttr; + if (ncclSuccess != wrap_ibv_query_port(context, port_num, &portAttr)) { + WARN("NET/IB : Unable to query port_num %d", port_num); + continue; + } + if (portAttr.state != IBV_PORT_ACTIVE) continue; + if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND + && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) continue; + + // check against user specified HCAs/ports + if (! (matchIfList(devices[d]->name, port_num, userIfs, nUserIfs, searchExact) ^ searchNot)) { + continue; + } + pthread_mutex_init(&rocmIbDevs[ncclNIbDevs].lock, NULL); + rocmIbDevs[ncclNIbDevs].device = d; + rocmIbDevs[ncclNIbDevs].ibProvider = ibProvider; + rocmIbDevs[ncclNIbDevs].guid = devAttr.sys_image_guid; + rocmIbDevs[ncclNIbDevs].portAttr = portAttr; + rocmIbDevs[ncclNIbDevs].portNum = port_num; + rocmIbDevs[ncclNIbDevs].link = portAttr.link_layer; + if (portAttr.active_speed_ex) + // A non-zero active_speed_ex indicates XDR rate (0x100) or higher + rocmIbDevs[ncclNIbDevs].speed = ncclIbSpeed(portAttr.active_speed_ex) * ncclIbWidth(portAttr.active_width); + else + rocmIbDevs[ncclNIbDevs].speed = ncclIbSpeed(portAttr.active_speed) * ncclIbWidth(portAttr.active_width); + rocmIbDevs[ncclNIbDevs].context = context; + rocmIbDevs[ncclNIbDevs].pdRefs = 0; + rocmIbDevs[ncclNIbDevs].pd = NULL; + if (!dataDirect) { + strncpy(rocmIbDevs[ncclNIbDevs].devName, devices[d]->name, MAXNAMESIZE); + NCCLCHECKGOTO(ncclIbGetPciPath(rocmIbDevs[ncclNIbDevs].devName, &rocmIbDevs[ncclNIbDevs].pciPath, &rocmIbDevs[ncclNIbDevs].realPort), ret, fail); + } else { + snprintf(rocmIbDevs[ncclNIbDevs].devName, MAXNAMESIZE, "%s_dma", devices[d]->name); + NCCLCHECK(ncclCalloc(&rocmIbDevs[ncclNIbDevs].pciPath, PATH_MAX)); + strncpy(rocmIbDevs[ncclNIbDevs].pciPath, dataDirectDevicePath, PATH_MAX); + rocmIbDevs[ncclNIbDevs].capsProvider.mlx5.dataDirect = 1; + } + rocmIbDevs[ncclNIbDevs].maxQp = devAttr.max_qp; + rocmIbDevs[ncclNIbDevs].mrCache.capacity = 0; + rocmIbDevs[ncclNIbDevs].mrCache.population = 0; + rocmIbDevs[ncclNIbDevs].mrCache.slots = NULL; + NCCLCHECK(ncclIbStatsInit(&rocmIbDevs[ncclNIbDevs].stats)); + + // Enable ADAPTIVE_ROUTING by default on IB networks + // But allow it to be overloaded by an env parameter + rocmIbDevs[ncclNIbDevs].ar = (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) ? 1 : 0; + if (ncclParamRocmIbAdaptiveRouting() != -2) rocmIbDevs[ncclNIbDevs].ar = ncclParamRocmIbAdaptiveRouting(); + + INFO(NCCL_NET,"NET/IB: [%d] %s:%s:%d/%s provider=%s speed=%d context=%p pciPath=%s ar=%d", d, devices[d]->name, devices[d]->dev_name, rocmIbDevs[ncclNIbDevs].portNum, + NCCL_IB_LLSTR(portAttr.link_layer), rocmIbProviderName[rocmIbDevs[ncclNIbDevs].ibProvider], rocmIbDevs[ncclNIbDevs].speed, context, rocmIbDevs[ncclNIbDevs].pciPath, rocmIbDevs[ncclNIbDevs].ar); + + PTHREADCHECKGOTO(pthread_create(&rocmIbAsyncThread, NULL, rocmIbAsyncThreadMain, rocmIbDevs + ncclNIbDevs), "pthread_create", ret, fail); + ncclSetThreadName(rocmIbAsyncThread, "NCCL IbAsync %2d", ncclNIbDevs); + PTHREADCHECKGOTO(pthread_detach(rocmIbAsyncThread), "pthread_detach", ret, fail); // will not be pthread_join()'d + + // Add this plain physical device to the list of virtual devices + int vDev; + ncclNetVDeviceProps_t vProps = {0}; + vProps.ndevs = 1; + vProps.devs[0] = ncclNIbDevs; + NCCLCHECK(rocmIbMakeVDeviceInternal(&vDev, &vProps)); + + ncclNIbDevs++; + nPorts++; + } + } + if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; } + } + if (ncclSuccess != wrap_ibv_free_device_list(devices)) { ret = ncclInternalError; goto fail;} + } + if (ncclNIbDevs == 0) { + INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found."); + } + + // Print out all net devices to the user (in the same format as before) + char line[2048]; + line[0] = '\0'; + // Determine whether RELAXED_ORDERING is enabled and possible + ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); + for (int d = 0; d < ncclNIbDevs; d++) { + snprintf(line+strlen(line), sizeof(line)-strlen(line), " [%d]%s:%d/%s", d, rocmIbDevs[d].devName, + rocmIbDevs[d].portNum, NCCL_IB_LLSTR(rocmIbDevs[d].link)); + } + char addrline[SOCKET_NAME_MAXLEN+1]; + INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", + ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); + + ncclIbUseInline = ncclParamRocmIbUseInline(); + ncclIbGdrFlushDisable = ncclParamRocmIbGdrFlushDisable(); + + rcclAinicRoce = ((rcclParamAinicRoce() == 1) ? true : false); + if (rcclAinicRoce) { + // for AINIC, these params are defaulted to enabled unless user forces it to disable(0). + rcclCtsInlineData = ((rcclParamCtsInlineData() == 0) ? false : true); + rcclCtsOffloadEnabled = ((rcclParamCtsOffloadEnabled() == 0) ? false : true); + // for AINIC IbUseInline is enabled by default always + ncclIbUseInline = true; + // for AINIC GDR flush is disabled by default + ncclIbGdrFlushDisable = 1; + + INFO(NCCL_INIT|NCCL_NET, "NET/IB : AINIC RoCEv2 optimizations enabled: CTS Inline Data: %s; CTS Offload: %s; " + "IB Use Inline: enabled; GDR Flush: disabled", rcclCtsInlineData ? "Enabled": "Disabled", + rcclCtsOffloadEnabled ? "Enabled": "Disabled"); + } + + pthread_mutex_unlock(&rocmIbLock); + } +exit: + return ret; +fail: + if(ncclSuccess != wrap_ibv_free_device_list(devices)){WARN("NET/IB : Unable to free device list");} + pthread_mutex_unlock(&rocmIbLock); + goto exit; +} + +ncclResult_t rocmIbDevices(int* ndev) { + *ndev = ncclNMergedIbDevs; + return ncclSuccess; +} + +// Introduce RCCL_FORCE_ENABLE_GDRDMA to force load GPU-NIC RDMA module +// Use ONLY for debugging! +RCCL_PARAM(RocmForceEnableGdrdma, "FORCE_ENABLE_GDRDMA", -1); + +// Detect whether GDR can work on a given NIC with the current CUDA device +// Returns : +// ncclSuccess : GDR works +// ncclSystemError : no module or module loaded but not supported by GPU +#define KNL_MODULE_LOADED(a) ((access(a, F_OK) == -1) ? 0 : 1) +static int ncclIbGdrModuleLoaded = 0; // 1 = true, 0 = false +static void ibGdrSupportInitOnce() { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + if (rcclParamRocmForceEnableGdrdma() == 1) { + // RCCL_FORCE_ENABLE_GDRDMA=1 enables GPU-NIC RDMA only from RCCL-side + // Requires support from NIC driver modules + // Use ONLY for debugging! + ncclIbGdrModuleLoaded = 1; + INFO(NCCL_INIT, "RCCL_FORCE_ENABLE_GDRDMA = 1, so explicitly setting ncclIbGdrModuleLoaded = 1"); + } + + if (ncclIbGdrModuleLoaded == 0) { + // Check for `memory_peers` directory containing `amdkfd/version` + // This `memory_peers` directory is created by NIC-GPU driver interaction + // On Linux kernel 5.15.0 (e.g. Ubuntu 22.04), `memory_peers` is created under `/sys/kernel/mm/` + // However, on newer kernels like Ubuntu 24.04.1 (Linux kernel 6.8.0) or Ubuntu 22.04.4 HWE (Linux kernel 6.5.0), + // this `memory_peers` directory is either not created (go to else-if condition) + // or created under a different path like `/sys/kernel/` or `/sys/` (depending on your ib_peer_mem module) + const char* memory_peers_paths[] = {"/sys/kernel/mm/memory_peers/amdkfd/version", + "/sys/kernel/memory_peers/amdkfd/version", + "/sys/memory_peers/amdkfd/version", + NULL}; + int i = 0; + + while (memory_peers_paths[i]) { + if (access(memory_peers_paths[i], F_OK) == 0) { + ncclIbGdrModuleLoaded = 1; + INFO(NCCL_INIT,"Found %s", memory_peers_paths[i]); + break; + } else { + ncclIbGdrModuleLoaded = 0; + } + ++i; + } + + char strValue[MAX_STR_LEN]; + ncclTopoGetStrFromSys("/sys/devices/virtual/dmi/id", "bios_version", strValue); + if (strncmp("Hyper-V UEFI Release", strValue, 20) == 0) { + int roMode = ncclParamRocmIbPciRelaxedOrdering(); + ncclTopoGetStrFromSys("/proc/sys/kernel", "numa_balancing", strValue); + if (strcmp(strValue, "1") == 0 && roMode == 0) + ncclIbGdrModuleLoaded = 0; + } + + if (ncclIbGdrModuleLoaded == 0) { + // Check for `ib_register_peer_memory_client` symbol in `/proc/kallsyms` + // if your system uses native OS ib_peer module + char buf[256]; + FILE *fp = NULL; + fp = fopen("/proc/kallsyms", "r"); + + if (fp == NULL) { + INFO(NCCL_INIT,"Could not open /proc/kallsyms"); + } else { + while (fgets(buf, sizeof(buf), fp) != NULL) { + if (strstr(buf, "t ib_register_peer_memory_client") != NULL || + strstr(buf, "T ib_register_peer_memory_client") != NULL) { + ncclIbGdrModuleLoaded = 1; + INFO(NCCL_INIT,"Found ib_register_peer_memory_client in /proc/kallsyms"); + break; + } + } + } + } + } +#else + // Check for the nv_peer_mem module being loaded + ncclIbGdrModuleLoaded = KNL_MODULE_LOADED("/sys/kernel/mm/memory_peers/nv_mem/version") || + KNL_MODULE_LOADED("/sys/kernel/mm/memory_peers/nv_mem_nc/version") || + KNL_MODULE_LOADED("/sys/module/nvidia_peermem/version"); +#endif +} + +ncclResult_t rocmIbGdrSupport() { + static pthread_once_t once = PTHREAD_ONCE_INIT; + pthread_once(&once, ibGdrSupportInitOnce); + if (!ncclIbGdrModuleLoaded) + return ncclSystemError; + return ncclSuccess; +} + +static __thread int ibDmaSupportInitDev; // which device to init, must be thread local +static void ibDmaBufSupportInitOnce(){ + ncclResult_t res; + int dev_fail = 0; + + // This is a physical device, not a virtual one, so select from ibDevs + ncclIbMergedDev* mergedDev = rocmIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = rocmIbDevs + mergedDev->vProps.devs[0]; + struct ibv_pd* pd; + struct ibv_context* ctx = ibDev->context; + res = rocmLibraryInit(); + if (res != ncclSuccess) goto failure; + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) + dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + // stop the search and goto failure + if (dev_fail) goto failure; + ibDev->dmaBufSupported = 1; + return; +failure: + ibDev->dmaBufSupported = -1; + return; +} +// Detect whether DMA-BUF support is present in the kernel +// Returns : +// ncclSuccess : DMA-BUF support is available +// ncclSystemError : DMA-BUF is not supported by the kernel +ncclResult_t rocmIbDmaBufSupport(int dev) { + struct oncewrap { + pthread_once_t once = PTHREAD_ONCE_INIT; + }; + static oncewrap onces[MAX_IB_DEVS]; + // init the device only once + ibDmaSupportInitDev = dev; + pthread_once(&onces[dev].once, ibDmaBufSupportInitOnce); + ncclIbMergedDev* mergedDev = rocmIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = rocmIbDevs + mergedDev->vProps.devs[0]; + int dmaBufSupported = ibDev->dmaBufSupported; + if (dmaBufSupported == 1) return ncclSuccess; + return ncclSystemError; +} + +#define NCCL_NET_IB_MAX_RECVS 8 + +ncclResult_t rocmIbGetPhysProperties(int dev, ncclNetProperties_t* props) { + struct ncclIbDev* ibDev = rocmIbDevs + dev; + pthread_mutex_lock(&ibDev->lock); + props->name = ibDev->devName; + props->speed = ibDev->speed; + props->pciPath = ibDev->pciPath; + props->guid = ibDev->guid; + props->ptrSupport = NCCL_PTR_HOST; + if (rocmIbGdrSupport() == ncclSuccess) { + props->ptrSupport |= NCCL_PTR_CUDA; // GDR support via nv_peermem + } + props->regIsGlobal = 1; + if (rocmIbDmaBufSupport(dev) == ncclSuccess) { + props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF + } + props->forceFlush = 0; + if (ibDev->capsProvider.mlx5.dataDirect) { + props->forceFlush = 1; + } + props->latency = 0; // Not set + props->port = ibDev->portNum + ibDev->realPort; + props->maxComms = ibDev->maxQp; + props->maxRecvs = NCCL_NET_IB_MAX_RECVS; + props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; + pthread_mutex_unlock(&ibDev->lock); + return ncclSuccess; +} + +ncclResult_t rocmIbGetProperties(int dev, ncclNetProperties_t* props) { + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Requested properties for vNic %d, only %d vNics have been created", dev, ncclNMergedIbDevs); + return ncclInvalidUsage; + } + struct ncclIbMergedDev* mergedDev = rocmIbMergedDevs + dev; + // Take the rest of the properties from an arbitrary sub-device (should be the same) + NCCLCHECK(rocmIbGetPhysProperties(mergedDev->vProps.devs[0], props)); + props->name = mergedDev->devName; + props->speed = mergedDev->speed; + memcpy(&props->vProps, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); + return ncclSuccess; +} + +// We need to support NCCL_NET_MAX_REQUESTS for each concurrent receive +#define MAX_REQUESTS (NCCL_NET_MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS) +static_assert(MAX_REQUESTS <= 256, "request id are encoded in wr_id and we need up to 8 requests ids per completion"); + +#define NCCL_IB_MAX_QPS 128 + +// Per-QP connection metatdata +struct ncclIbQpInfo { + uint32_t qpn; + + // Fields needed for ece (enhanced connection establishment) + struct ibv_ece ece; + int ece_supported; + int devIndex; +}; + +// Per-Dev connection metadata +struct ncclIbDevInfo { + uint32_t lid; + uint8_t ib_port; + enum ibv_mtu mtu; + uint8_t link_layer; + + // For RoCE and IB Rounter + union ibv_gid gid; + + // FIFO RDMA info + uint32_t fifoRkey; + + //remote dev info + union ibv_gid remoteGid; + + int ibv_dev_index; +}; + +// Struct containing everything needed to establish connections +struct ncclIbConnectionMetadata { + struct ncclIbQpInfo qpInfo[NCCL_IB_MAX_QPS]; + struct ncclIbDevInfo devs[NCCL_IB_MAX_DEVS_PER_NIC]; + char devName[MAX_MERGED_DEV_NAME]; + uint64_t fifoAddr; + int ndevs; + int tc; + int sl; + int isP2p; +}; + +enum ncclIbCommState { + ncclIbCommStateStart = 0, + ncclIbCommStateConnect = 1, + ncclIbCommStateAccept = 3, + ncclIbCommStateSend = 4, + ncclIbCommStateRecv = 5, + ncclIbCommStateConnecting = 6, + ncclIbCommStateConnected = 7, + ncclIbCommStatePendingReady = 8, + ncclIbCommStateSendDevList = 9, + ncclIbCommStateRecvDevList = 10, +}; + +struct ncclIbCommStage { + enum ncclIbCommState state; + int offset; + void* buffer; + void* comm; +}; + +struct ncclIbHandle { + union ncclSocketAddress connectAddr; // Filled by the target + uint64_t magic; // random number to help debugging + int isP2p; // P2P flag + struct ncclIbCommStage stage; // Used by the other side when connecting +}; + +// Retain local RoCE address for error logging +struct ncclIbGidInfo { + uint8_t link_layer; + union ibv_gid localGid; + int32_t localGidIndex; +}; + +#define NCCL_NET_IB_REQ_UNUSED 0 +#define NCCL_NET_IB_REQ_SEND 1 +#define NCCL_NET_IB_REQ_RECV 2 +#define NCCL_NET_IB_REQ_FLUSH 3 +const char* rocmIbReqTypeStr[] = { "Unused", "Send", "Recv", "Flush" }; + +#define MAX_QPS_PER_REQ 8 +struct ncclProfilerInfo { + void* qpEventHandles[MAX_QPS_PER_REQ]; + int qpIndex[MAX_QPS_PER_REQ]; + int nEventHandles; + ncclProfilerNetIbDescr_v1_t data; + void* pHandle; +}; + +struct ncclIbRequest { + struct ncclIbNetCommBase* base; + int type; + struct ncclSocket* sock; + int events[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbNetCommDevBase* devBases[NCCL_IB_MAX_DEVS_PER_NIC]; +#ifdef NCCL_ENABLE_NET_PROFILING + struct ncclProfilerInfo pInfo[NCCL_NET_IB_MAX_RECVS]; +#endif + int nreqs; + union { + struct { + int size; + void* data; + uint32_t lkeys[NCCL_IB_MAX_DEVS_PER_NIC]; + int offset; + } send; + struct { + int* sizes; + } recv; + }; +}; + +struct ncclIbNetCommDevBase { + int ibDevN; + struct ibv_pd* pd; + struct ibv_cq* cq; + uint64_t pad[2]; + struct ncclIbGidInfo gidInfo; +}; + +struct ncclIbListenComm { + int dev; + struct ncclSocket sock; + struct ncclIbCommStage stage; +}; + +#define MAX_INLINE_DATA_SIZE 24 + +struct alignas(64) ncclIbSendFifo { + uint64_t addr; + uint64_t size; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; + uint32_t nreqs; + uint32_t tag; + uint64_t idx; + char padding[16]; +}; + +struct alignas(32) ncclIbSendFifoCtsInline { + uint64_t addr; + uint32_t rkeys[1]; + int size; + uint8_t nreqs; + uint16_t tag; + uint32_t idx; + char padding[9]; +} __attribute__((packed)); + +struct ncclIbQp { + struct ibv_qp* qp; + int devIndex; + int remDevIdx; + int8_t ctsQpSlot; +}; + +struct ncclIbRemSizesFifo { + int elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + uint64_t fifoTail; + uint64_t addr; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; + uint32_t flags; + struct ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ibv_sge sge; +}; + +// A per-dev struct for netIbSendComm +struct alignas(8) ncclIbSendCommDev { + struct ncclIbNetCommDevBase base; + struct ibv_mr* fifoMr; +}; + + +// Wrapper to track an MR per-device, if needed +struct ncclIbMrHandle { + ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; +}; + +struct alignas(32) ncclIbNetCommBase { + ncclNetVDeviceProps_t vProps; + bool isSend; + struct ncclIbRequest reqs[MAX_REQUESTS]; + struct ncclIbQp qps[NCCL_IB_MAX_QPS]; + int nqps; + int qpIndex; + int devIndex; + struct ncclSocket sock; + int ready; + // Track necessary remDevInfo here + int nRemDevs; + int nDataQps; + struct ncclIbDevInfo remDevs[NCCL_IB_MAX_DEVS_PER_NIC]; + // statistics about the comm + struct ncclIbStats stats; +}; + +struct ncclIbSendComm { + struct ncclIbNetCommBase base; + // Start with fifo and ibv structs as they have alignment restrictions + struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ncclIbSendFifoCtsInline fifo_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS]; + struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS + 1]; + // Each dev correlates to a mergedIbDev + struct ncclIbSendCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbRequest* fifoReqs[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ncclIbRemSizesFifo remSizesFifo; + uint64_t fifoHead; + int ar; // Use adaptive routing when all merged devices have it enabled +}; +// The SendFifo needs to be 32-byte aligned and each element needs +// to be a 32-byte multiple, so that an entry does not get split and +// written out of order when IB Relaxed Ordering is enabled +static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset"); +static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); +static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples"); +static_assert((sizeof(struct ncclIbSendFifoCtsInline) % 32) == 0, "ncclIbSendFifoCtsInline element size must be 32-byte multiples"); +static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned"); +static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned"); + +struct ncclIbGpuFlush { + struct ibv_mr* hostMr; + struct ibv_mr* gpuMr; + int* gpuFlushGpuMem; + struct ibv_sge sge; + struct ncclIbQp qp; + int dmabuf_fd; +}; + +struct ncclIbRemFifo { + struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ncclIbSendFifoCtsInline elems_cts_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + uint64_t fifoTail; + uint64_t addr; + uint32_t flags; +}; + +struct alignas(16) ncclIbRecvCommDev { + struct ncclIbNetCommDevBase base; + struct ncclIbGpuFlush gpuFlush; + struct ibv_mr* fifoMr; + struct ibv_sge fifoSge; + struct ibv_mr* sizesFifoMr; +}; + +struct ncclIbRecvComm { + struct ncclIbNetCommBase base; + struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbRemFifo remFifo; + int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + int gpuFlushHostMem; + int flushEnabled; +}; +static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbRecvComm fifo must be 32-byte aligned"); + +static void ncclIbAddEvent(struct ncclIbRequest* req, int devIndex, struct ncclIbNetCommDevBase* base) { + req->events[devIndex]++; + req->devBases[devIndex] = base; +} +ncclResult_t rocmIbInitCommDevBase(int ibDevN, struct ncclIbNetCommDevBase* base, void* cq_context) { + base->ibDevN = ibDevN; + ncclIbDev* ibDev = rocmIbDevs + ibDevN; + pthread_mutex_lock(&ibDev->lock); + if (0 == ibDev->pdRefs++) { + ncclResult_t res; + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&ibDev->pd, ibDev->context), res, failure); + if (0) { + failure: + pthread_mutex_unlock(&ibDev->lock); + return res; + } + } + base->pd = ibDev->pd; + pthread_mutex_unlock(&ibDev->lock); + + // CQ is sized to accommodate the max SQ + RQ WQE completions. If each SQ WQE could be signaled, then, + // for each QP, there can be 2*MAX_REQUESTS completions for SQ and MAX_REQUESTS completions for RQ. + NCCLCHECK(wrap_ibv_create_cq(&base->cq, ibDev->context, 3*MAX_REQUESTS*ncclParamRocmIbQpsPerConn(), cq_context, NULL, 0)); + + return ncclSuccess; +} + +ncclResult_t rocmIbDestroyBase(struct ncclIbNetCommDevBase* base) { + ncclResult_t res; + NCCLCHECK(wrap_ibv_destroy_cq(base->cq)); + + pthread_mutex_lock(&rocmIbDevs[base->ibDevN].lock); + if (0 == --rocmIbDevs[base->ibDevN].pdRefs) { + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(rocmIbDevs[base->ibDevN].pd), res, returning); + } + res = ncclSuccess; +returning: + pthread_mutex_unlock(&rocmIbDevs[base->ibDevN].lock); + return res; +} + +ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, + int access_flags, void* qp_context, struct ncclIbQp* qp, + int channel_id, bool data_qp, int8_t cts_qp_slot) { + struct ibv_qp_init_attr qpInitAttr; + enum ncclIbChannelType channel_type = (data_qp ? ncclIbChannelTypeData : ncclIbChannelTypeCts); + memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr)); + qpInitAttr.qp_context = qp_context; + qpInitAttr.send_cq = base->cq; + qpInitAttr.recv_cq = base->cq; + qpInitAttr.qp_type = IBV_QPT_RC; + + if (rcclAinicRoce) { + if (!nccl_channel_ud_map[channel_id][channel_type].udAllocated) { + bool lud = nccl_channel_last_ud[base->ibDevN][channel_type]; + nccl_channel_ud_map[channel_id][channel_type].udId = lud; + nccl_channel_ud_map[channel_id][channel_type].udAllocated = true; + nccl_channel_last_ud[base->ibDevN][channel_type] = + !(nccl_channel_last_ud[base->ibDevN][channel_type]); + } + if (nccl_channel_ud_map[channel_id][channel_type].udId) { + wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_HIGH); + } else { + wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_LOW); + } + qpInitAttr.sq_sig_all |= (1 << 16); + if (data_qp) { + qpInitAttr.sq_sig_all |= (1 << 17); + } else { + qpInitAttr.sq_sig_all &= (~(1 << 17)); + } + qpInitAttr.sq_sig_all |= (1 << 18); + + if (rcclCtsOffloadEnabled) { + qpInitAttr.sq_sig_all |= (1 << 19); + } else { + qpInitAttr.sq_sig_all &= (~(1 << 19)); + } + } + + // We might send 2 messages per send (RDMA and RDMA_WITH_IMM) + qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS; + qpInitAttr.cap.max_recv_wr = MAX_REQUESTS; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + if (rcclCtsInlineData) { + qpInitAttr.cap.max_inline_data = MAX_INLINE_DATA_SIZE; + } else { + qpInitAttr.cap.max_inline_data = ncclIbUseInline ? sizeof(struct ncclIbSendFifo) : 0; + } + NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr)); + if (rcclAinicRoce) { + NCCLCHECK(wrap_ionicdv_qp_set_gda(qp->qp, false, true)); + } + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = ncclParamRocmIbPkey(); + qpAttr.port_num = ib_port; + qpAttr.qp_access_flags = access_flags; + NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p", + ib_port, base->ibDevN, rocmIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd); + if (rcclAinicRoce) { + qp->ctsQpSlot = cts_qp_slot; + } + return ncclSuccess; +} + +ncclResult_t rocmIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool fifoTc, int tc, int sl) { + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); + qpAttr.qp_state = IBV_QPS_RTR; + qpAttr.path_mtu = info->mtu; + qpAttr.dest_qp_num = dest_qp_num; + qpAttr.rq_psn = 0; + qpAttr.max_dest_rd_atomic = 1; + qpAttr.min_rnr_timer = 12; + if (info->link_layer == IBV_LINK_LAYER_ETHERNET) { + qpAttr.ah_attr.is_global = 1; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix; + qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id; + qpAttr.ah_attr.grh.flow_label = 0; + qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; + qpAttr.ah_attr.grh.hop_limit = 255; + qpAttr.ah_attr.grh.traffic_class = fifoTc && ncclParamRocmIbFifoTc() != -1 ? ncclParamRocmIbFifoTc() : tc; + } else { + //pick lid if subnet prefixs are same, FLID if they are not + if (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) == + ncclIbExtractLocalSubnetPrefix(info->gid.global.subnet_prefix)) { + qpAttr.ah_attr.is_global = 0; + qpAttr.ah_attr.dlid = info->lid; + } else { + uint16_t flid = ncclIbExtractFlid(&info->gid); + if (flid == 0) { + WARN("Warning: remote FLID configured as zero even when endpoints are on different subnets, using dlid as fallback"); + qpAttr.ah_attr.dlid = info->lid; + } else { + qpAttr.ah_attr.dlid = ncclIbExtractFlid(&info->gid); + } + qpAttr.ah_attr.is_global = 1; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix; + qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id; + qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; + qpAttr.ah_attr.grh.hop_limit = 255; + } + } + qpAttr.ah_attr.sl = sl; + qpAttr.ah_attr.src_path_bits = 0; + qpAttr.ah_attr.port_num = info->ib_port; + TRACE(NCCL_NET, "NET/IB : rocmIbRtrQp qpn=%u mtu=%d dst=%u ll=%u port=%u sl: %d tc: %d", qp->qp_num, info->mtu, dest_qp_num, info->link_layer, info->ib_port, qpAttr.ah_attr.sl, qpAttr.ah_attr.grh.traffic_class); + NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)); + return ncclSuccess; +} + +ncclResult_t rocmIbRtsQp(struct ibv_qp* qp) { + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); + qpAttr.qp_state = IBV_QPS_RTS; + qpAttr.timeout = ncclParamRocmIbTimeout(); + qpAttr.retry_cnt = ncclParamRocmIbRetryCnt(); + qpAttr.rnr_retry = 7; + qpAttr.sq_psn = 0; + qpAttr.max_rd_atomic = 1; + NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC)); + return ncclSuccess; +} + +ncclResult_t rocmIbListen(int dev, void* opaqueHandle, void** listenComm) { + ncclResult_t ret = ncclSuccess; + struct ncclIbListenComm* comm; + NCCLCHECK(ncclCalloc(&comm, 1)); + struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; + static_assert(sizeof(struct ncclIbHandle) < NCCL_NET_HANDLE_MAXSIZE, "ncclIbHandle size too large"); + memset(handle, 0, sizeof(struct ncclIbHandle)); + comm->dev = dev; + handle->magic = NCCL_SOCKET_MAGIC; + NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketListen(&comm->sock), ret, fail); + NCCLCHECKGOTO(ncclSocketGetAddr(&comm->sock, &handle->connectAddr), ret, fail); + *listenComm = comm; +exit: + return ret; +fail: + (void)ncclSocketClose(&comm->sock); + free(comm); + goto exit; +} + +ncclResult_t rocmIbConnect(int dev, ncclNetCommConfig_t* config, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** sendDevComm) { + ncclResult_t ret = ncclSuccess; + struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; + struct ncclIbCommStage* stage = &handle->stage; + struct ncclIbSendComm* comm = (struct ncclIbSendComm*)stage->comm; + int ready; + uint8_t link_layer = IBV_LINK_LAYER_UNSPECIFIED; + int isP2p = 0; + int channel_id = 0; + *sendComm = NULL; + + if (rcclAinicRoce) { + channel_id = ((ncclNet_ctxt_t *)sendDevComm)->chId; + } + + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnecting) goto ib_connect; + if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; + if (stage->state != ncclIbCommStateStart) { + WARN("Error: trying to connect already connected sendComm"); + return ncclInternalError; + } + stage->buffer = NULL; + + NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); + NCCLCHECKGOTO(ncclIbStatsInit(&comm->base.stats), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + stage->comm = comm; + stage->state = ncclIbCommStateConnect; + NCCLCHECKGOTO(ncclSocketConnect(&comm->base.sock), ret, fail); + +ib_connect_check: + /* since ncclSocketConnect is async, we must check if connection is complete */ + NCCLCHECKGOTO(ncclSocketReady(&comm->base.sock, &ready), ret, fail); + if (!ready) return ncclSuccess; + + // IB Setup + struct ncclIbMergedDev* mergedDev; + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existent virtual device %d", dev); + return ncclInternalError; + } + + mergedDev = rocmIbMergedDevs + dev; + comm->base.vProps = mergedDev->vProps; + comm->base.isSend = true; + stage->state = ncclIbCommStateSendDevList; + stage->offset = 0; + struct ncclIbConnectionMetadata meta; + NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(meta)), ret, fail); + memcpy(stage->buffer, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_send_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->state = ncclIbCommStateRecvDevList; + stage->offset = 0; + +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + stage->offset = 0; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + mergedDev = rocmIbMergedDevs + dev; + comm->base.vProps = mergedDev->vProps; + + // Read isP2p from handle + isP2p = handle->isP2p; + INFO(NCCL_NET, "NET/IB: rocmIbConnect isP2p=%d", isP2p); + comm->base.nqps = ncclIbCalculateNqps(isP2p, comm->base.vProps.ndevs, + remoteVProps.ndevs, __func__); + + // Init PD, Ctx for each IB device + comm->ar = 1; // Set to 1 for logic + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + int ibDevN = comm->base.vProps.devs[i]; + NCCLCHECKGOTO(rocmIbInitCommDevBase(ibDevN, &comm->devs[i].base, &comm->base.stats), ret, fail); + comm->ar = comm->ar && rocmIbDevs[ibDevN].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled + } + + memset(&meta, 0, sizeof(meta)); + meta.ndevs = comm->base.vProps.ndevs; + meta.isP2p = isP2p; + // Alternate QPs between devices + int devIndex; + devIndex = 0; + for (int q = 0; q < comm->base.nqps; q++) { + ncclIbSendCommDev* commDev = comm->devs + devIndex; + ncclIbDev* ibDev = rocmIbDevs + commDev->base.ibDevN; + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); + comm->base.qps[q].devIndex = devIndex; + meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex; + + if (ncclParamRocmIbEceEnable()) { + // Query ece capabilities (enhanced connection establishment) + NCCLCHECKGOTO(wrap_ibv_query_ece(comm->base.qps[q].qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported), ret, fail); + } else { + meta.qpInfo[q].ece_supported = 0; + } + devIndex = (devIndex + 1) % comm->base.vProps.ndevs; + } + + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + ncclIbSendCommDev* commDev = comm->devs + i; + ncclIbDev* ibDev = rocmIbDevs + commDev->base.ibDevN; + + // Write to the metadata struct via this pointer + ncclIbDevInfo* devInfo = meta.devs + i; + devInfo->ib_port = ibDev->portNum; + devInfo->mtu = ibDev->portAttr.active_mtu; + devInfo->lid = ibDev->portAttr.lid; + devInfo->ibv_dev_index = commDev->base.ibDevN; + // Prepare my fifo + if (rcclCtsInlineData) { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo_inline, sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } else { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } + devInfo->fifoRkey = commDev->fifoMr->rkey; + + // Pack local GID info + devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + NCCLCHECKGOTO(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &commDev->base.gidInfo.localGidIndex), ret, fail); + NCCLCHECKGOTO(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid), ret, fail); + devInfo->gid.global.subnet_prefix = commDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo->gid.global.interface_id = commDev->base.gidInfo.localGid.global.interface_id; + + // info logging + for (int q = 0; q < comm->base.nqps; q++) { + // Print just the QPs for this dev + if (comm->base.qps[q].devIndex == i) { + if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x", + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", + dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, + devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey); + } else { // RoCE + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, + (int64_t)commDev->base.gidInfo.localGidIndex, + devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey); + } + // Log ECE info + if (meta.qpInfo[q].ece_supported) { + INFO(NCCL_NET,"NET/IB: IbDev %d Port %d qpn %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x}", + commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, + meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask); + } + } + } + if (link_layer == IBV_LINK_LAYER_UNSPECIFIED) link_layer = devInfo->link_layer; + if (link_layer != devInfo->link_layer) { + int ibDev0 = comm->devs[0].base.ibDevN; + WARN("NET/IB : Attempted to connect incompatible devices: [%d]%s:%d/%s and [%d]%s:%d/%s. Try selecting NICs of only one link type using NCCL_IB_HCA", + commDev->base.ibDevN, ibDev->devName, ibDev->portNum, NCCL_IB_LLSTR(ibDev->portAttr.link_layer), ibDev0, rocmIbDevs[ibDev0].devName, rocmIbDevs[ibDev0].portNum, NCCL_IB_LLSTR(link_layer)); + return ncclInternalError; + } + } + if (rcclCtsInlineData) { + meta.fifoAddr = (uint64_t)comm->fifo_inline; + } else { + meta.fifoAddr = (uint64_t)comm->fifo; + } + meta.sl = (ncclParamRocmIbSl() != -1) ? ncclParamRocmIbSl() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_SL_DEFAULT; + meta.tc = (ncclParamRocmIbTc() != -1) ? ncclParamRocmIbTc() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_TC_DEFAULT; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); + + stage->state = ncclIbCommStateSend; + stage->offset = 0; + + memcpy(stage->buffer, &meta, sizeof(meta)); + +ib_send: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(meta), &stage->offset), ret, fail); + if (stage->offset != sizeof(meta)) return ncclSuccess; + + stage->state = ncclIbCommStateConnecting; + stage->offset = 0; + // Clear the staging buffer for re-use + memset(stage->buffer, 0, sizeof(meta)); + +ib_connect: + struct ncclIbConnectionMetadata remMeta; + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclIbConnectionMetadata), &stage->offset), ret, fail); + if (stage->offset != sizeof(remMeta)) return ncclSuccess; + + memcpy(&remMeta, stage->buffer, sizeof(ncclIbConnectionMetadata)); + + comm->base.nRemDevs = remMeta.ndevs; + + // ensure that the remote devices have the same link layer than the local devices used in the connection. + if (comm->base.vProps.ndevs > 0) { + int ibDev0 = comm->devs[0].base.ibDevN; + link_layer = rocmIbDevs[ibDev0].portAttr.link_layer; + for (int i = 0; i < remMeta.ndevs; i++) { + if (remMeta.devs[i].link_layer != link_layer) { + WARN("NET/IB : Remote %s device is incompatible with the local [%d]%s:%d/%s. Try selecting NICs of only one link type using NCCL_IB_HCA", + NCCL_IB_LLSTR(remMeta.devs[i].link_layer), ibDev0, rocmIbDevs[ibDev0].devName, rocmIbDevs[ibDev0].portNum, NCCL_IB_LLSTR(link_layer)); + return ncclInternalError; + } + } + } + + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. + for (int i = 0; i < remMeta.ndevs; i++) { + comm->base.remDevs[i] = remMeta.devs[i]; + comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].gid.global.interface_id; + comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].gid.global.subnet_prefix; + // Retain remote sizes fifo info and prepare RDMA ops + comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey; + comm->remSizesFifo.addr = remMeta.fifoAddr; + } + + for (int i=0; i < comm->base.vProps.ndevs; i++) { + NCCLCHECKGOTO(wrap_ibv_reg_mr(comm->remSizesFifo.mrs+i, comm->devs[i].base.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } + comm->base.nRemDevs = remMeta.ndevs; + + for (int q = 0; q < comm->base.nqps; q++) { + struct ncclIbQpInfo* remQpInfo = remMeta.qpInfo + q; + struct ncclIbDevInfo* remDevInfo = remMeta.devs + remQpInfo->devIndex; + + // Assign per-QP remDev + comm->base.qps[q].remDevIdx = remQpInfo->devIndex; + int devIndex = comm->base.qps[q].devIndex; + ncclIbSendCommDev* commDev = comm->devs + devIndex; + + struct ibv_qp* qp = comm->base.qps[q].qp; + if (remQpInfo->ece_supported) { + struct ncclIbQp* nqp = comm->base.qps + q; + int ibDevN = comm->devs[nqp->devIndex].base.ibDevN; + struct ncclIbDev* ibDev = rocmIbDevs + ibDevN; + INFO(NCCL_NET,"NET/IB: IbDev %d Port %d qpn %d set_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x}", + ibDevN, ibDev->portNum, qp->qp_num, remMeta.qpInfo[q].ece_supported, remMeta.qpInfo[q].ece.vendor_id, remMeta.qpInfo[q].ece.options, remMeta.qpInfo[q].ece.comp_mask); + NCCLCHECKGOTO(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported), ret, fail); + } + + ncclIbDev* ibDev = rocmIbDevs + commDev->base.ibDevN; + remDevInfo->mtu = std::min(remDevInfo->mtu, ibDev->portAttr.active_mtu); + NCCLCHECKGOTO(rocmIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false, remMeta.tc, remMeta.sl), ret, fail); + NCCLCHECKGOTO(rocmIbRtsQp(qp), ret, fail); + } + + comm->base.nDataQps = std::max(comm->base.vProps.ndevs, comm->base.nRemDevs); + + comm->base.ready = 1; + stage->state = ncclIbCommStateConnected; + stage->offset = 0; + +ib_send_ready: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, &comm->base.ready, sizeof(int), &stage->offset), ret, fail); + if (stage->offset != sizeof(int)) return ncclSuccess; + + *sendComm = comm; +exit: + if (stage->buffer) free(stage->buffer); + stage->state = ncclIbCommStateStart; + return ret; +fail: + free(comm); + goto exit; +} + +NCCL_PARAM(RocmIbWarnRailLocal, "IB_WARN_RAIL_LOCAL", 0); + +ncclResult_t rocmIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDeviceProps_t* vProps2) { + ncclNetVDeviceProps_t outVProps = {0}; + ncclNetVDeviceProps_t* minVProps = vProps2; + ncclNetVDeviceProps_t* maxVProps = vProps1; + if (vProps2->ndevs > vProps1->ndevs) { + minVProps = vProps1; + maxVProps = vProps2; + } + + // Find the intersection of devices + for (int i = 0; i < minVProps->ndevs; i++) { + int dev = minVProps->devs[i]; + for (int j = 0; j < maxVProps->ndevs; j++) { + // Found + if (maxVProps->devs[j] == dev) { + outVProps.devs[outVProps.ndevs++] = dev; + } + } + } + + // In the case that at least one side has a fused NIC but there are no matching physical NICs, we should check if the user wants this + if (ncclParamRocmIbWarnRailLocal() && outVProps.ndevs < maxVProps->ndevs) { + char local[128]; + int cursor = 1; + snprintf(local, sizeof(local), "%d", vProps1->devs[0]); + for (int i = 1; i < vProps1->ndevs; i++) { + snprintf(local+cursor, sizeof(local)-cursor, ",%d", vProps1->devs[i]); + cursor += 2; + } + char remote[128]; + snprintf(remote, sizeof(remote), "%d", vProps2->devs[0]); + cursor = 1; + for (int i = 1; i < vProps2->ndevs; i++) { + snprintf(remote+cursor, sizeof(remote)-cursor, ",%d", vProps2->devs[i]); + cursor += 2; + } + INFO(NCCL_NET, "NET/IB : There are mismatched physical devices between local (%s) and remote (%s). To disable this warning, set NCCL_IB_WARN_RAIL_LOCAL=0", local, remote); + } + + return ncclSuccess; +} + +RCCL_PARAM(RocmIbGdrFlushGpuMemNoRelaxedOrdering, "GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING", 1); + +ncclResult_t rocmIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) { + ncclResult_t ret = ncclSuccess; + struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm; + struct ncclIbCommStage* stage = &lComm->stage; + struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm; + int ready; + int link_layer = IBV_LINK_LAYER_UNSPECIFIED; + int channel_id = 0; + *recvComm = NULL; + + if (rcclAinicRoce) { + channel_id = ((ncclNet_ctxt_t *) recvDevComm)->chId; + } + + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; + if (stage->state == ncclIbCommStateRecv) goto ib_recv; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStatePendingReady) goto ib_recv_ready; + if (stage->state != ncclIbCommStateStart) { + WARN("Listencomm in unknown state %d", stage->state); + return ncclInternalError; + } + + NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm))); + NCCLCHECKGOTO(ncclIbStatsInit(&rComm->base.stats), ret, fail); + stage->comm = rComm; + stage->state = ncclIbCommStateAccept; + NCCLCHECKGOTO(ncclSocketInit(&rComm->base.sock), ret, fail); + NCCLCHECKGOTO(ncclSocketAccept(&rComm->base.sock, &lComm->sock), ret, fail); + + // Alloc stage->buffer here to be used for all following steps + struct ncclIbConnectionMetadata remMeta; + stage->offset = 0; + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta))); + +ib_accept_check: + NCCLCHECKGOTO(ncclSocketReady(&rComm->base.sock, &ready), ret, fail); + if (!ready) return ncclSuccess; + stage->state = ncclIbCommStateRecvDevList; + stage->offset = 0; + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + if (lComm->dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existent virtual device %d", lComm->dev); + return ncclInternalError; + } + + // Reduce the physical device list and store in the connection base + struct ncclIbMergedDev* mergedDev; + mergedDev = rocmIbMergedDevs + lComm->dev; + NCCLCHECK(rocmIbCheckVProps(&mergedDev->vProps, &remoteVProps)); + rComm->base.vProps = mergedDev->vProps; + memcpy(stage->buffer, &rComm->base.vProps, sizeof(ncclNetVDeviceProps_t)); + rComm->base.isSend = false; + stage->offset = 0; + stage->state = ncclIbCommStateSendDevList; + +ib_send_dev_list: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset), ret, fail); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->offset = 0; + stage->state = ncclIbCommStateRecv; + +ib_recv: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset), ret, fail); + if (stage->offset != sizeof(remMeta)) return ncclSuccess; + + /* copy back the received info */ + memcpy(&remMeta, stage->buffer, sizeof(struct ncclIbConnectionMetadata)); + + // IB setup + // Pre-declare variables because of goto + struct ncclIbDev* ibDev; + int ibDevN; + struct ncclIbRecvCommDev* rCommDev; + struct ncclIbDevInfo* remDevInfo; + struct ncclIbQp* qp; + bool useDmaBuf; + + mergedDev = rocmIbMergedDevs + lComm->dev; + rComm->base.nRemDevs = remMeta.ndevs; + rComm->base.nqps = ncclIbCalculateNqps(remMeta.isP2p, rComm->base.vProps.ndevs, + remMeta.ndevs, __func__); + if (rComm->base.nRemDevs != rComm->base.vProps.ndevs) { + INFO(NCCL_NET, "NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", + mergedDev->devName, rComm->base.vProps.ndevs, remMeta.devName, rComm->base.nRemDevs); + } + + // Metadata to send back to requestor (sender) + struct ncclIbConnectionMetadata meta; + memset(&meta, 0, sizeof(meta)); + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { + rCommDev = rComm->devs + i; + ibDevN = rComm->base.vProps.devs[i]; + NCCLCHECKGOTO(rocmIbInitCommDevBase(ibDevN, &rCommDev->base, &rComm->base.stats), ret, fail); + ibDev = rocmIbDevs + ibDevN; + NCCLCHECKGOTO(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex), ret, fail); + NCCLCHECKGOTO(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid), ret, fail); + if (link_layer == IBV_LINK_LAYER_UNSPECIFIED) link_layer = ibDev->portAttr.link_layer; + if (link_layer != ibDev->portAttr.link_layer) { + int ibDev0 = rComm->devs[0].base.ibDevN; + WARN("NET/IB : Attempted to connect incompatible devices: [%d]%s:%d/%s and [%d]%s:%d/%s. Try selecting NICs of only one link type using NCCL_IB_HCA", + ibDevN, ibDev->devName, ibDev->portNum, NCCL_IB_LLSTR(ibDev->portAttr.link_layer), ibDev0, rocmIbDevs[ibDev0].devName, rocmIbDevs[ibDev0].portNum, NCCL_IB_LLSTR(link_layer)); + return ncclInternalError; + } + } + + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. + for (int i = 0; i < remMeta.ndevs; i++) { + rComm->base.remDevs[i] = remMeta.devs[i]; + rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].gid.global.interface_id; + rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].gid.global.subnet_prefix; + if (remMeta.devs[i].link_layer != link_layer) { + int ibDev0 = rComm->devs[0].base.ibDevN; + WARN("NET/IB : Remote %s device is incompatible with the local [%d]%s:%d/%s. Try selecting NICs of only one link type using NCCL_IB_HCA", + NCCL_IB_LLSTR(remMeta.devs[i].link_layer), ibDev0, rocmIbDevs[ibDev0].devName, rocmIbDevs[ibDev0].portNum, NCCL_IB_LLSTR(link_layer)); + return ncclInternalError; + } + } + + // Stripe QP creation across merged devs + // Make sure to get correct remote peer dev and QP info + int remDevIndex; + int devIndex; + devIndex = 0; + for (int q = 0; q < rComm->base.nqps; q++) { + remDevIndex = remMeta.qpInfo[q].devIndex; + remDevInfo = remMeta.devs + remDevIndex; + qp = rComm->base.qps+q; + rCommDev = rComm->devs + devIndex; + qp->remDevIdx = remDevIndex; + + // Local ibDevN + ibDevN = rComm->devs[devIndex].base.ibDevN; + ibDev = rocmIbDevs + ibDevN; + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp, channel_id, false, q), ret, fail); + qp->devIndex = devIndex; + devIndex = (devIndex + 1) % rComm->base.vProps.ndevs; + + // Set the ece (enhanced connection establishment) on this QP before RTR + if (remMeta.qpInfo[q].ece_supported) { + // Coverity suspects a copy-paste error below due to the use of remMeta in one argument and meta in another. + // However, this has been confirmed to be intentional. + // coverity[copy_paste_error] + NCCLCHECKGOTO(wrap_ibv_set_ece(qp->qp, &remMeta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported), ret, fail); + } else { + meta.qpInfo[q].ece_supported = 0; + } + + NCCLCHECKGOTO(rocmIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, true, remMeta.tc, remMeta.sl), ret, fail); + NCCLCHECKGOTO(rocmIbRtsQp(qp->qp), ret, fail); + + // Query the reduced ece for this QP (matching enhancements between the requestor and the responder) + // Store this in our own qpInfo for returning to the requestor + if (remMeta.qpInfo[q].ece_supported && meta.qpInfo[q].ece_supported) { + NCCLCHECKGOTO(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported), ret, fail); + } + } + + useDmaBuf = (rocmIbDmaBufSupport(lComm->dev) == ncclSuccess); + rComm->flushEnabled = ((rocmIbGdrSupport() == ncclSuccess || useDmaBuf) + && (ncclIbGdrFlushDisable == 0)) ? 1 : 0; + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { + rCommDev = rComm->devs + i; + ibDev = rocmIbDevs + rCommDev->base.ibDevN; + + // Retain remote fifo info and prepare my RDMA ops + rComm->remFifo.addr = remMeta.fifoAddr; + if (rcclCtsInlineData) { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems_cts_inline, + sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, + IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } else { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } + rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; + if (ncclIbUseInline) rComm->remFifo.flags = IBV_SEND_INLINE; + + // Allocate Flush dummy buffer for GPU Direct RDMA + if (rComm->flushEnabled) { + if (rcclParamRocmIbGdrFlushGpuMemNoRelaxedOrdering()) { +#if defined(HIP_UNCACHED_MEMORY) + NCCLCHECKGOTO(ncclCudaCalloc(&rCommDev->gpuFlush.gpuFlushGpuMem, sizeof(int), hipDeviceMallocUncached), ret, fail); +#else + NCCLCHECKGOTO(ncclCudaCalloc(&rCommDev->gpuFlush.gpuFlushGpuMem, sizeof(int), hipDeviceMallocFinegrained), ret, fail); +#endif + if (useDmaBuf) + { + uint64_t export_offset = 0; + void *aligned_ptr = NULL; + size_t aligned_size = 0; + get_aligned_ptr_and_size(rCommDev->gpuFlush.gpuFlushGpuMem, sizeof(int) /*devicebuffersize*/, &aligned_ptr, &aligned_size); + hsa_status_t export_status = pfn_hsa_amd_portable_export_dmabuf(aligned_ptr, aligned_size, &rCommDev->gpuFlush.dmabuf_fd, &export_offset); + if (rCommDev->gpuFlush.dmabuf_fd < 0 || export_status != HSA_STATUS_SUCCESS) + { + WARN("Failed to export DMA BUF"); + goto fail; + } + NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&rCommDev->gpuFlush.gpuMr, rCommDev->base.pd, export_offset, sizeof(int), (uint64_t)rCommDev->gpuFlush.gpuFlushGpuMem /*iova*/, rCommDev->gpuFlush.dmabuf_fd, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ), ret, fail); + } + else + { + rCommDev->gpuFlush.dmabuf_fd = -1; + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->gpuFlush.gpuMr, rCommDev->base.pd, rCommDev->gpuFlush.gpuFlushGpuMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ), ret, fail); + } + } else { + rCommDev->gpuFlush.gpuFlushGpuMem = nullptr; + rCommDev->gpuFlush.gpuMr = nullptr; + rCommDev->gpuFlush.dmabuf_fd = -1; + } + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->gpuFlush.hostMr, rCommDev->base.pd, &rComm->gpuFlushHostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE), ret, fail); + rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem; + rCommDev->gpuFlush.sge.length = 1; + rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey; + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); + struct ncclIbDevInfo devInfo; + devInfo.lid = ibDev->portAttr.lid; + devInfo.link_layer = ibDev->portAttr.link_layer; + devInfo.ib_port = ibDev->portNum; + devInfo.gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo.gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; + devInfo.mtu = ibDev->portAttr.active_mtu; + NCCLCHECKGOTO(rocmIbRtrQp(rCommDev->gpuFlush.qp.qp, &rCommDev->base.gidInfo, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo, false, remMeta.tc, remMeta.sl), ret, fail); + NCCLCHECKGOTO(rocmIbRtsQp(rCommDev->gpuFlush.qp.qp), ret, fail); + } + + // Fill Handle + meta.devs[i].lid = ibDev->portAttr.lid; + meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + meta.devs[i].ib_port = ibDev->portNum; + meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; + meta.devs[i].mtu = ibDev->portAttr.active_mtu; + meta.devs[i].ibv_dev_index = rCommDev->base.ibDevN; + + // Prepare sizes fifo + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rComm->devs[i].sizesFifoMr, rComm->devs[i].base.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + meta.devs[i].fifoRkey = rComm->devs[i].sizesFifoMr->rkey; + } + meta.fifoAddr = (uint64_t)rComm->sizesFifo; + meta.sl = remMeta.sl; + meta.tc = remMeta.tc; + + for (int q = 0; q < rComm->base.nqps; q++) { + meta.qpInfo[q].qpn = rComm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = rComm->base.qps[q].devIndex; + } + meta.ndevs = rComm->base.vProps.ndevs; + meta.isP2p = remMeta.isP2p; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); + rComm->base.nDataQps = std::max(rComm->base.vProps.ndevs, rComm->base.nRemDevs); + + stage->state = ncclIbCommStateSend; + stage->offset = 0; + if (stage->buffer) { + free(stage->buffer); + stage->buffer = NULL; + } + NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbConnectionMetadata)), ret, fail); + memcpy(stage->buffer, &meta, sizeof(struct ncclIbConnectionMetadata)); + +ib_send: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(struct ncclIbConnectionMetadata), &stage->offset), ret, fail); + if (stage->offset < sizeof(struct ncclIbConnectionMetadata)) return ncclSuccess; + + stage->offset = 0; + stage->state = ncclIbCommStatePendingReady; + +ib_recv_ready: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, &rComm->base.ready, sizeof(int), &stage->offset), ret, fail); + if (stage->offset != sizeof(int)) return ncclSuccess; + + *recvComm = rComm; +exit: + /* reset lComm stage */ + if (stage->buffer) free(stage->buffer); + stage->state = ncclIbCommStateStart; + stage->offset = 0; + stage->comm = NULL; + stage->buffer = NULL; + return ret; +fail: + free(rComm); + goto exit; +} + +ncclResult_t rocmIbGetRequest(struct ncclIbNetCommBase* base, struct ncclIbRequest** req) { + for (int i=0; ireqs+i; + if (r->type == NCCL_NET_IB_REQ_UNUSED) { + r->base = base; + r->sock = NULL; + memset(r->devBases, 0, sizeof(r->devBases)); + memset(r->events, 0, sizeof(r->events)); + *req = r; + return ncclSuccess; + } + } + WARN("NET/IB : unable to allocate requests"); + *req = NULL; + return ncclInternalError; +} + +ncclResult_t rocmIbFreeRequest(struct ncclIbRequest* r) { + r->type = NCCL_NET_IB_REQ_UNUSED; + return ncclSuccess; +} + +ncclResult_t rocmIbTest(void* request, int* done, int* size); + +ncclResult_t rocmIbRegMrDmaBufInternal(ncclIbNetCommDevBase* base, void* data, size_t size, int type, uint64_t offset, int fd, ibv_mr** mhandle) { + static __thread uintptr_t pageSize = 0; + if (pageSize == 0) pageSize = sysconf(_SC_PAGESIZE); + struct ncclIbMrCache* cache = &rocmIbDevs[base->ibDevN].mrCache; + uintptr_t addr = (uintptr_t)data & -pageSize; + size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; + ncclResult_t res; + pthread_mutex_lock(&rocmIbDevs[base->ibDevN].lock); + for (int slot=0; /*true*/; slot++) { + if (slot == cache->population || addr < cache->slots[slot].addr) { // didn't find in cache + if (cache->population == cache->capacity) { // must grow cache + cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity; + NCCLCHECKGOTO(ncclRealloc(&cache->slots, cache->population, cache->capacity), res, returning); + } + // Deregister / register + struct ibv_mr* mr; + unsigned int flags = IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ; + if (ncclIbRelaxedOrderingEnabled) flags |= IBV_ACCESS_RELAXED_ORDERING; + if (fd != -1) { + /* DMA-BUF support */ + if (!rocmIbDevs[base->ibDevN].capsProvider.mlx5.dataDirect) { + NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, base->pd, offset, pages*pageSize, addr, fd, flags), res, returning); + } else { + NCCLCHECKGOTO(wrap_mlx5dv_reg_dmabuf_mr(&mr, base->pd, offset, pages*pageSize, addr, fd, flags, MLX5DV_REG_DMABUF_ACCESS_DATA_DIRECT), res, returning); + } + } else { + if (ncclIbRelaxedOrderingEnabled) { + // Use IBVERBS_1.8 API - needed for IBV_ACCESS_RELAXED_ORDERING support + NCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, base->pd, (void*)addr, pages*pageSize, addr, flags), res, returning); + } + else { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, base->pd, (void*)addr, pages*pageSize, flags), res, returning); + } + } + TRACE(NCCL_INIT|NCCL_NET,"regAddr=0x%lx size=%lld rkey=0x%x lkey=0x%x fd=%d", (unsigned long)addr, (long long)pages*pageSize, mr->rkey, mr->lkey, fd); + if (slot != cache->population) memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclIbMr)); + cache->slots[slot].addr = addr; + cache->slots[slot].pages = pages; + cache->slots[slot].refs = 1; + cache->slots[slot].mr = mr; + cache->population += 1; + *mhandle = mr; + res = ncclSuccess; + goto returning; + } else if ((addr >= cache->slots[slot].addr) && + ((addr-cache->slots[slot].addr)/pageSize+pages) <= cache->slots[slot].pages) { + cache->slots[slot].refs += 1; + *mhandle = cache->slots[slot].mr; + res = ncclSuccess; + goto returning; + } + } +returning: + pthread_mutex_unlock(&rocmIbDevs[base->ibDevN].lock); + return res; +} + +struct ncclIbNetCommDevBase* rocmIbGetNetCommDevBase(ncclIbNetCommBase* base, int devIndex) { + if (base->isSend) { + struct ncclIbSendComm* sComm = (struct ncclIbSendComm*) base; + return &sComm->devs[devIndex].base; + } else { + struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*) base; + return &rComm->devs[devIndex].base; + } +} + +/* DMA-BUF support */ +ncclResult_t rocmIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { + ncclResult_t ret = ncclSuccess; + assert(size > 0); + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) malloc(sizeof(struct ncclIbMrHandle)); + for (int i = 0; i < base->vProps.ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = rocmIbGetNetCommDevBase(base, i); + NCCLCHECKGOTO(rocmIbRegMrDmaBufInternal(devComm, data, size, type, offset, fd, mhandleWrapper->mrs + i), ret, fail); + } + *mhandle = (void*) mhandleWrapper; +exit: + return ret; +fail: + free(mhandleWrapper); + goto exit; +} + +ncclResult_t rocmIbRegMr(void* comm, void* data, size_t size, int type, void** mhandle) { + return rocmIbRegMrDmaBuf(comm, data, size, type, 0ULL, -1, mhandle); +} + +ncclResult_t rocmIbDeregMrInternal(ncclIbNetCommDevBase* base, ibv_mr* mhandle) { + struct ncclIbMrCache* cache = &rocmIbDevs[base->ibDevN].mrCache; + ncclResult_t res; + pthread_mutex_lock(&rocmIbDevs[base->ibDevN].lock); + for (int i=0; i < cache->population; i++) { + if (mhandle == cache->slots[i].mr) { + if (0 == --cache->slots[i].refs) { + memmove(&cache->slots[i], &cache->slots[--cache->population], sizeof(struct ncclIbMr)); + if (cache->population == 0) { + free(cache->slots); + cache->slots = NULL; + cache->capacity = 0; + } + NCCLCHECKGOTO(wrap_ibv_dereg_mr(mhandle), res, returning); + } + res = ncclSuccess; + goto returning; + } + } + WARN("NET/IB: could not find mr %p inside cache of %d entries", mhandle, cache->population); + res = ncclInternalError; +returning: + pthread_mutex_unlock(&rocmIbDevs[base->ibDevN].lock); + return res; +} + +ncclResult_t rocmIbDeregMr(void* comm, void* mhandle) { + if (mhandle == NULL) return ncclSuccess; + + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; + for (int i = 0; i < base->vProps.ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = rocmIbGetNetCommDevBase(base, i); + NCCLCHECK(rocmIbDeregMrInternal(devComm, mhandleWrapper->mrs[i])); + } + free(mhandleWrapper); + return ncclSuccess; +} + +NCCL_PARAM(RocmIbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0); + +ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot, bool use_write_op) { + struct ncclIbRequest** reqs = comm->fifoReqs[slot]; + volatile struct ncclIbSendFifo* slots = comm->fifo[slot]; + int nreqs; + if (rcclCtsOffloadEnabled) { + nreqs = 1; + } else { + nreqs = slots[0].nreqs; + } + if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; + + uint64_t wr_id = 0ULL; + for (int r=0; rwrs+r; + memset(wr, 0, sizeof(struct ibv_send_wr)); + + struct ibv_sge* sge = comm->sges+r; + sge->addr=(uintptr_t)reqs[r]->send.data; + wr->opcode = IBV_WR_RDMA_WRITE; + wr->send_flags = 0; + if (rcclCtsOffloadEnabled) { + wr->wr.rdma.remote_addr = 0xdeadbeef; + } else { + wr->wr.rdma.remote_addr = slots[r].addr; + } + wr->next = wr + 1; + wr_id += (reqs[r] - comm->base.reqs) << (r*8); +#ifdef NCCL_ENABLE_NET_PROFILING + reqs[r]->pInfo[0].nEventHandles = 0; +#endif + } + + // Write size as immediate data. In the case of multi-send, only write + // 0 or 1 as size to indicate whether there was data sent or received. + uint32_t immData = 0; + if ((nreqs == 1) && (use_write_op == false)) { + immData = reqs[0]->send.size; + } else { + int* sizes = comm->remSizesFifo.elems[slot]; + for (int r=0; rsend.size; + comm->remSizesFifo.sge.addr = (uint64_t)sizes; + comm->remSizesFifo.sge.length = nreqs*sizeof(int); + } + + struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; + if (use_write_op == false) { + if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamRocmIbArThreshold())) { + // When using ADAPTIVE_ROUTING, send the bulk of the data first as an + // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote + // completion. + lastWr++; + memset(lastWr, 0, sizeof(struct ibv_send_wr)); + if (nreqs > 1) { + // Write remote sizes Fifo + lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); + lastWr->num_sge = 1; + lastWr->sg_list = &comm->remSizesFifo.sge; + } + } + lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + lastWr->imm_data = immData; + } + lastWr->wr_id = wr_id; + lastWr->next = NULL; + lastWr->send_flags = IBV_SEND_SIGNALED; + + // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work + const int align = 128; + int nqps = ncclParamRocmIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; + for (int i = 0; i < nqps; i++) { + int qpIndex = comm->base.qpIndex; + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; + for (int r=0; rdevs[devIndex].base); + + // Select proper rkey (needed even for 0-size send) + if (rcclCtsOffloadEnabled) { + comm->wrs[r].wr.rdma.rkey = 0xbade; + } else { + comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; + } + + int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align; + int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); + if (length <= 0) { + comm->wrs[r].sg_list = NULL; + comm->wrs[r].num_sge = 0; + } else { + // Select proper lkey + comm->sges[r].lkey = reqs[r]->send.lkeys[devIndex]; + comm->sges[r].length = length; + comm->wrs[r].sg_list = comm->sges+r; + comm->wrs[r].num_sge = 1; + } + } + + if ((use_write_op == false) && (nreqs > 1)) { + // Also make sure lastWr writes remote sizes using the right lkey + comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey; + lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex]; + } + + struct ibv_send_wr* bad_wr; +#ifdef NCCL_ENABLE_NET_PROFILING + // QP profiling loop + for (int r=0; rpInfo[0].nEventHandles; + assert(nEventHandles < MAX_QPS_PER_REQ); + reqs[r]->pInfo[0].qpIndex[nEventHandles] = qpIndex; + // Store info for profiler + int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; + reqs[r]->pInfo[0].data.type = ncclProfileQp; + reqs[r]->pInfo[0].data.qp.device = devIndex; + reqs[r]->pInfo[0].data.qp.wr_id = comm->wrs[r].wr_id; + reqs[r]->pInfo[0].data.qp.opcode = comm->wrs[r].opcode; + reqs[r]->pInfo[0].data.qp.qpNum = qp->qp->qp_num; + reqs[r]->pInfo[0].data.qp.length = comm->sges[r].length; + void* pHandle = reqs[r]->pInfo[0].pHandle; + NCCLCHECK(ncclProfilerFunction(&reqs[r]->pInfo[0].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, pHandle, pluginId, &reqs[r]->pInfo[0].data)); + reqs[r]->pInfo[0].nEventHandles++; + } +#endif + NCCLCHECK(wrap_ibv_post_send(qp->qp, comm->wrs, &bad_wr)); + + for (int r=0; rsend.size, nqps), align) * align; + reqs[r]->send.offset += chunkSize; + comm->sges[r].addr += chunkSize; + comm->wrs[r].wr.rdma.remote_addr += chunkSize; + + TRACE(NCCL_VERBS, "Posted send wr_id=%lu, wr_indx=%d, qp_num=%d, src_nic=%d, dst_nic=%d, dlid=%d, opcode=%d, send_flags=%d, imm_data=%d, remote_addr=%lx, rkey=%x, length=%d, lkey=%x", + comm->wrs[r].wr_id, r, qp->qp->qp_num, comm->devs[qp->devIndex].base.ibDevN , comm->base.remDevs[qp->remDevIdx].ibv_dev_index, comm->base.remDevs[qp->remDevIdx].lid, + comm->wrs[r].opcode, comm->wrs[r].send_flags, comm->wrs[r].imm_data, comm->wrs[r].wr.rdma.remote_addr, + comm->wrs[r].wr.rdma.rkey,comm->wrs[r].sg_list ? comm->wrs[r].sg_list->length : 0, comm->wrs[r].sg_list ? comm->wrs[r].sg_list->lkey : 0); + } + + // Select the next qpIndex + comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; + } + + return ncclSuccess; +} + +ncclResult_t rocmIbIsend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void* phandle, void** request) { + struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; + if (comm->base.ready == 0) { + WARN("NET/IB: rocmIbIsend() called when comm->base.ready == 0"); + *request = NULL; + return ncclInternalError; + } + NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); + + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; + bool use_write_op = false; + if (rcclAinicRoce) { + use_write_op = (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) ? true : false; + } + + // Wait for the receiver to have posted the corresponding receive + int nreqs = 0; + volatile struct ncclIbSendFifo* slots; + + if (rcclCtsOffloadEnabled) { + nreqs = 1; + } + + int slot = (comm->fifoHead) % MAX_REQUESTS; + struct ncclIbRequest** reqs = comm->fifoReqs[slot]; + if (!rcclCtsOffloadEnabled) { + slots = comm->fifo[slot]; + uint64_t idx = comm->fifoHead+1; + if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; } + nreqs = slots[0].nreqs; + // Wait until all data has arrived + for (int r=1; r slots[r].size) size = slots[r].size; + // Sanity checks + if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->base.sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", + r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); + return ncclInternalError; + } + } else{ + if (reqs[r] != NULL) continue; + } + + struct ncclIbRequest* req; + NCCLCHECK(rocmIbGetRequest(&comm->base, &req)); + req->type = NCCL_NET_IB_REQ_SEND; + req->sock = &comm->base.sock; + req->base = &comm->base; + req->nreqs = nreqs; + req->send.size = size; + req->send.data = data; + req->send.offset = 0; +#ifdef NCCL_ENABLE_NET_PROFILING + req->pInfo[0].pHandle = phandle; +#endif + + // Populate events + int nEvents = ncclParamRocmIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; + int qpIndex = comm->base.qpIndex; + // Count down + while (nEvents > 0) { + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; + ncclIbAddEvent(req, devIndex, &comm->devs[devIndex].base); + // Track the valid lkey for this RDMA_Write + req->send.lkeys[devIndex] = mhandleWrapper->mrs[devIndex]->lkey; + nEvents--; + // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs inside ncclIbMultiSend() + qpIndex = (qpIndex+1)%comm->base.nqps; + } + + // Store all lkeys + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + req->send.lkeys[i] = mhandleWrapper->mrs[i]->lkey; + } + + *request = reqs[r] = req; + + // If this is a multi-recv, send only when all requests have matched. + for (int r=0; rnreqs, as well as other fields to help debugging and sanity checks + if (!rcclCtsOffloadEnabled) { + memset((void*)slots, 0, sizeof(struct ncclIbSendFifo)); + } + memset(reqs, 0, NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbRequest*)); + comm->fifoHead++; + TIME_STOP(0); + return ncclSuccess; + } + + *request = NULL; + return ncclSuccess; +} + +ncclResult_t rocmIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { + struct ibv_send_wr wr; + struct ncclIbSendFifo* localElem = NULL; + struct ncclIbSendFifoCtsInline* localElemCtsInline = NULL; + uint64_t localElemRef; + int qpIndex = 0; + ncclIbQp* ctsQp = NULL; + memset(&wr, 0, sizeof(wr)); + + int slot = comm->remFifo.fifoTail%MAX_REQUESTS; + req->recv.sizes = comm->sizesFifo[slot]; + for (int i=0; irecv.sizes[i] = 0; + if (rcclCtsInlineData) { + localElemCtsInline = comm->remFifo.elems_cts_inline[slot]; + } else { + localElem = comm->remFifo.elems[slot]; + } + + if (rcclAinicRoce) { + qpIndex = comm->base.qpIndex; + ctsQp = comm->base.qps + qpIndex; + } else { + // Select the next devIndex (local) and QP to use for posting this CTS message + // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value + ctsQp = comm->base.qps + comm->base.devIndex; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; + } + + for (int i=0; ibase.vProps.ndevs; j++) + localElemCtsInline[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + + localElemCtsInline[i].nreqs = n; + localElemCtsInline[i].size = sizes[i]; // Sanity/Debugging + localElemCtsInline[i].tag = tags[i]; + localElemCtsInline[i].idx = comm->remFifo.fifoTail+1; + localElemRef = (uint64_t)localElemCtsInline; + + } else { + localElem[i].addr = (uint64_t)data[i]; + + // Send all applicable rkeys + for (int j = 0; j < comm->base.vProps.ndevs; j++) + localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + + localElem[i].nreqs = n; + localElem[i].size = sizes[i]; // Sanity/Debugging + localElem[i].tag = tags[i]; + localElem[i].idx = comm->remFifo.fifoTail+1; + localElemRef = (uint64_t)localElem; + } + } + wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo); + + // Lookup the correct fifoRkey + wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey; + + // Set the correct sge properties + comm->devs[ctsQp->devIndex].fifoSge.addr = localElemRef; + if (rcclCtsInlineData) { + comm->devs[ctsQp->devIndex].fifoSge.length = MAX_INLINE_DATA_SIZE; + } else { + comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); + } + wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge; + wr.num_sge = 1; + + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = comm->remFifo.flags; // IBV_SEND_INLINE + + // We need to occasionally post a request with the IBV_SEND_SIGNALED flag, otherwise + // the send queue will never empty. + // + // From https://www.rdmamojo.com/2014/06/30/working-unsignaled-completions/ + // "How to use Unsignaled Completion?" / "Gotchas and Pitfalls" + // All posted Send Requested, Signaled and Unsignaled, are considered outstanding until + // a Work Completion that they, or Send Requests that were posted after them, was polled + // from the Completion Queue associated with the Send Queue. This means if one works with + // a Queue Pair that was configured to work with Unsignaled Completions, he must make + // sure that occasionally (before the Send Queue is full with outstanding Send Requests) + // a Send Request that generate Work Completion will be posted. + // + // Not following this rule may lead to a case that the Send Queue is full with Send + // Requests that won't generate Work Completion: + // + // - The Send Queue is full, so no new Send Requests can be posted to it + // - The Send Queue can't be emptied, since no Work Completion can be generated anymore + // (the reason is that no Work Completion, that can generate Work Completion that + // polling it will empty the Send Queue, can be posted) + // - The status of all posted Send Request is considered unknown + // + // slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled. + // This works out that each fifo posting QP gets drained + if (rcclAinicRoce) { + if (slot == ctsQp->ctsQpSlot) { + wr.send_flags |= IBV_SEND_SIGNALED; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); + } + } else if (slot == ctsQp->devIndex) { + wr.send_flags |= IBV_SEND_SIGNALED; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); + } + + struct ibv_send_wr* bad_wr; + NCCLCHECK(wrap_ibv_post_send(ctsQp->qp, &wr, &bad_wr)); + + TRACE(NCCL_VERBS, "Posted send wr_id=%lu, wr_indx=%d, qp_num=%d, src_nic=%d, dst_nic=%d, dlid=%lu, opcode=%d, send_flags=%d, imm_data=%d, remote_addr=%lx, rkey=%x, length=%d, lkey=%x", + wr.wr_id, 0, ctsQp->qp->qp_num, comm->devs[ctsQp->devIndex].base.ibDevN, comm->base.remDevs[ctsQp->remDevIdx].ibv_dev_index, comm->base.remDevs[ctsQp->remDevIdx].lid, + wr.opcode, wr.send_flags, wr.imm_data, wr.wr.rdma.remote_addr, wr.wr.rdma.rkey, wr.sg_list ? wr.sg_list->length : 0, wr.sg_list ? wr.sg_list->lkey : 0); + + comm->remFifo.fifoTail++; + + if (rcclAinicRoce) { + // Select the next qpIndex + comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; + } + return ncclSuccess; +} + +ncclResult_t rocmIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** phandles, void** request) { + ncclResult_t res = ncclSuccess; + bool netOptRecvCompletionEnabled = false; + struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; + if (comm->base.ready == 0) { + WARN("NET/IB: rocmIbIrecv() called when comm->base.ready == 0"); + *request = NULL; + return ncclInternalError; + } + if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; + NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); + + if (rcclAinicRoce) { + if (*request == (void *) NCCL_NET_OPTIONAL_RECV_COMPLETION) { + netOptRecvCompletionEnabled = true; + } + } + struct ncclIbRequest* req; + NCCLCHECK(rocmIbGetRequest(&comm->base, &req)); + req->type = NCCL_NET_IB_REQ_RECV; + req->sock = &comm->base.sock; + req->nreqs = n; +#ifdef NCCL_ENABLE_NET_PROFILING + for (int r = 0; r < n && phandles; r++) req->pInfo[r].nEventHandles = 0; +#endif + + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + req->devBases[i] = &comm->devs[i].base; + } + + if (!netOptRecvCompletionEnabled) { + struct ibv_recv_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; + wr.sg_list = NULL; + wr.num_sge = 0; + + TIME_START(1); + // Select either all QPs, or one qp per-device + const int nqps = ncclParamRocmIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; + + // Post recvs + struct ibv_recv_wr* bad_wr; + int qpIndex = comm->base.qpIndex; + for (int i = 0; i < nqps; i++) { + struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; + ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); +#ifdef NCCL_ENABLE_NET_PROFILING + // Start a QP event for every request in the multirecv and every qp + for (int r = 0; r < n; r++) { + int nEventHandles = req->pInfo[r].nEventHandles; + assert(nEventHandles < MAX_QPS_PER_REQ); + req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex; + // Store info for profiler + int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; + req->pInfo[r].data.type = ncclProfileQp; + req->pInfo[r].data.qp.device = qp->devIndex; + req->pInfo[r].data.qp.wr_id = wr.wr_id; + req->pInfo[r].data.qp.qpNum = qp->qp->qp_num; + NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data)); + req->pInfo[r].nEventHandles++; + } +#endif + NCCLCHECKGOTO(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr), res, err); + // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs + // inside rocmIbPostFifo() + if (rcclAinicRoce) { + qpIndex = (qpIndex+1)%comm->base.nqps; + } else { + comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; + } + } + + TIME_STOP(1); + } // netOptRecvCompletionEnabled = false + + // Post to FIFO to notify sender + TIME_START(2); + NCCLCHECKGOTO(rocmIbPostFifo(comm, n, data, sizes, tags, mhandles, req), res, err); + TIME_STOP(2); + + *request = req; + return ncclSuccess; +err: + if (req) { + rocmIbFreeRequest(req); + } + return res; +} + +ncclResult_t rocmIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { + struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; + int last = -1; + for (int i=0; iflushEnabled == 0 || last == -1) return ncclSuccess; + + // Only flush once using the last non-zero receive + struct ncclIbRequest* req; + NCCLCHECK(rocmIbGetRequest(&comm->base, &req)); + req->type = NCCL_NET_IB_REQ_FLUSH; + req->sock = &comm->base.sock; + struct ncclIbMrHandle* mhandle = (struct ncclIbMrHandle*) mhandles[last]; + + // We don't know which devIndex the recv was on, so we flush on all devices + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; + if (rcclParamRocmIbGdrFlushGpuMemNoRelaxedOrdering()) { + wr.wr.rdma.remote_addr = (uint64_t)(comm->devs[i].gpuFlush.gpuFlushGpuMem); + wr.wr.rdma.rkey = comm->devs[i].gpuFlush.gpuMr->rkey; + wr.sg_list = &comm->devs[i].gpuFlush.sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = 0; + struct ibv_send_wr* bad_wr; + NCCLCHECK(wrap_ibv_post_send(comm->devs[i].gpuFlush.qp.qp, &wr, &bad_wr)); + } + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; + if (rcclParamRocmIbGdrFlushGpuMemNoRelaxedOrdering()) { + wr.wr.rdma.remote_addr = (uint64_t)(comm->devs[i].gpuFlush.gpuFlushGpuMem); + wr.wr.rdma.rkey = comm->devs[i].gpuFlush.gpuMr->rkey; + } else { + wr.wr.rdma.remote_addr = (uint64_t)data[last]; + wr.wr.rdma.rkey = mhandle->mrs[i]->rkey; + } + wr.sg_list = &comm->devs[i].gpuFlush.sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; + + TIME_START(4); + struct ibv_send_wr* bad_wr; + NCCLCHECK(wrap_ibv_post_send(comm->devs[i].gpuFlush.qp.qp, &wr, &bad_wr)); + TIME_STOP(4); + + ncclIbAddEvent(req, i, &comm->devs[i].base); + } + + *request = req; + return ncclSuccess; +} + +#define HCA_NAME(req, index) ((req)->devBases[(index)]->pd->context->device->name) + +#ifdef NCCL_ENABLE_NET_PROFILING +static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) { + for (int i = 0; i < MAX_QPS_PER_REQ; i++) { + int qpIndex = req->pInfo[request].qpIndex[i]; + if (req->base->qps[qpIndex].qp->qp_num == qpNumber) return i; + } + return 0; +} +#endif + +#define NCCL_CQ_POLL_MAX_EVENT 16 + +ncclResult_t rocmIbTest(void* request, int* done, int* sizes) { + struct ncclIbRequest *r = (struct ncclIbRequest*)request; + *done = 0; + while (1) { + NCCLCHECK(ncclIbStatsCheckFatalCount(&r->base->stats,__func__)); + if (r->events[0] == 0 && r->events[1] == 0 && r->events[2] == 0 && r->events[3] == 0) { + TRACE(NCCL_NET, "r=%p done", r); + *done = 1; + if (sizes && r->type == NCCL_NET_IB_REQ_RECV) { + for (int i=0; inreqs; i++) { + sizes[i] = r->recv.sizes[i]; +#ifdef NCCL_ENABLE_NET_PROFILING + for (int j = 0; j < r->pInfo[i].nEventHandles; j++) { + NCCLCHECK(ncclProfilerFunction(&r->pInfo[i].qpEventHandles[j], ncclProfilerNetEventStop, NULL, 0, NULL)); + } +#endif + } + } + if (sizes && r->type == NCCL_NET_IB_REQ_SEND) { + sizes[0] = r->send.size; +#ifdef NCCL_ENABLE_NET_PROFILING + for (int j = 0; j < r->pInfo[0].nEventHandles; j++) { + NCCLCHECK(ncclProfilerFunction(&r->pInfo[0].qpEventHandles[j], ncclProfilerNetEventStop, NULL, 0, NULL)); + } +#endif + } + // Stop all remaining Qp events for this event + NCCLCHECK(rocmIbFreeRequest(r)); + return ncclSuccess; + } + + int totalWrDone = 0; + int wrDone = 0; + struct ibv_wc wcs[NCCL_CQ_POLL_MAX_EVENT]; + int cqMaxPollEvent = 4; + if (rcclAinicRoce) { + cqMaxPollEvent = NCCL_CQ_POLL_MAX_EVENT; + } + + for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) { + TIME_START(3); + // If we expect any completions from this device's CQ + if (r->events[i]) { + NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, cqMaxPollEvent, + wcs, &wrDone)); + totalWrDone += wrDone; + if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } + if (wrDone == 0) continue; + for (int w=0; wstatus != IBV_WC_SUCCESS) { + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + char localGidString[INET6_ADDRSTRLEN] = ""; + char remoteGidString[INET6_ADDRSTRLEN] = ""; + const char* localGidStr = NULL, *remoteGidStr = NULL; + if (r->devBases[i]->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { + localGidStr = ibvGetGidStr(&r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); + remoteGidStr = ibvGetGidStr(&r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); + } + + char line[SOCKET_NAME_MAXLEN+1]; + char *hcaName = r->devBases[i]->pd->context->device->name; + WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%u vendor err %u (%s)%s%s%s%s hca %s", + ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, rocmIbReqTypeStr[r->type], + localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName); + return ncclRemoteError; + } + + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + struct ncclIbRequest* req = r->base->reqs+(wc->wr_id & 0xff); + + #ifdef ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN+1]; + TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%u wr_id=%lu r=%p type=%d events={%d,%d,%d,%d}, i=%d", + ncclSocketToString(&addr, line), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], req->events[2], req->events[3], i); + #endif + if (req && req->type == NCCL_NET_IB_REQ_SEND) { + for (int j = 0; j < req->nreqs; j++) { + struct ncclIbRequest* sendReq = r->base->reqs+((wc->wr_id >> (j*8)) & 0xff); + if ((sendReq->events[i] <= 0)) { + WARN("NET/IB: sendReq(%p)->events={%d,%d,%d,%d}, i=%d, j=%d <= 0", sendReq, sendReq->events[0], sendReq->events[1], sendReq->events[2], sendReq->events[3], i, j); + return ncclInternalError; + } + sendReq->events[i]--; +#ifdef NCCL_ENABLE_NET_PROFILING + // Stop Qp event for sendReq + int qpIndex = getReqQpIndex(sendReq, j, wc->qp_num); + NCCLCHECK(ncclProfilerFunction(&sendReq->pInfo[j].qpEventHandles[qpIndex], ncclProfilerNetEventStop, NULL, 0, NULL)); +#endif + } + } else { + if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + if (req->type != NCCL_NET_IB_REQ_RECV) { + WARN("NET/IB: wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM and req->type=%d", req->type); + return ncclInternalError; + } + if (req->nreqs == 1) { + req->recv.sizes[0] = wc->imm_data; + } + } + req->events[i]--; +#ifdef NCCL_ENABLE_NET_PROFILING + // Stop Qp event for workFifo + for (int j = 0; j < req->nreqs; j++) { + int qpIndex = getReqQpIndex(req, j, wc->qp_num); + NCCLCHECK(ncclProfilerFunction(&req->pInfo[j].qpEventHandles[qpIndex], ncclProfilerNetEventStop, NULL, 0, NULL)); + } +#endif + } + } + // Once the IB fatal event is reported in the async thread, we want to propagate this error + // to communicator and prevent further polling to reduce error pollution. + NCCLCHECK(ncclIbStatsCheckFatalCount(&rocmIbDevs[r->devBases[i]->ibDevN].stats,__func__)); + } + } + + // If no CQEs found on any device, return and come back later + if (totalWrDone == 0) return ncclSuccess; + } +} + +ncclResult_t rocmIbCloseSend(void* sendComm) { + struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; + if (comm) { + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + struct ncclIbSendCommDev* commDev = comm->devs + i; + if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (comm->remSizesFifo.mrs[i] != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mrs[i])); + NCCLCHECK(rocmIbDestroyBase(&commDev->base)); + } + free(comm); + } + TIME_PRINT("IB"); + return ncclSuccess; +} + +ncclResult_t rocmIbCloseRecv(void* recvComm) { + struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; + if (comm) { + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + struct ncclIbRecvCommDev* commDev = comm->devs + i; + if (comm->flushEnabled) { + if (commDev->gpuFlush.gpuFlushGpuMem != nullptr) { + NCCLCHECK(ncclCudaFree(commDev->gpuFlush.gpuFlushGpuMem)); + commDev->gpuFlush.gpuFlushGpuMem = nullptr; + if (commDev->gpuFlush.gpuMr != nullptr) NCCLCHECK(wrap_ibv_dereg_mr(commDev->gpuFlush.gpuMr)); + commDev->gpuFlush.gpuMr = nullptr; + if(commDev->gpuFlush.dmabuf_fd > 0) { close(commDev->gpuFlush.dmabuf_fd);} + } + if (commDev->gpuFlush.qp.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(commDev->gpuFlush.qp.qp)); + if (commDev->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->gpuFlush.hostMr)); + } + if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (commDev->sizesFifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->sizesFifoMr)); + NCCLCHECK(rocmIbDestroyBase(&commDev->base)); + } + free(comm); + } + return ncclSuccess; +} + +ncclResult_t rocmIbCloseListen(void* listenComm) { + struct ncclIbListenComm* comm = (struct ncclIbListenComm*)listenComm; + if (comm) { + NCCLCHECK(ncclSocketClose(&comm->sock)); + free(comm); + } + return ncclSuccess; +} + +ncclResult_t rcclRocmNetP2pPolicy(void* handle, int isP2p) { + if (!handle) return ncclInvalidArgument; + struct ncclIbHandle* ibHandle = (struct ncclIbHandle*)handle; + if (ibHandle->magic != NCCL_SOCKET_MAGIC) return ncclInvalidArgument; + ibHandle->isP2p = isP2p; + return ncclSuccess; +} + +ncclNet_t rocmNetIb = { + "ROCM-IB", + rocmIbInit, + rocmIbDevices, + rocmIbGetProperties, + rocmIbListen, + rocmIbConnect, + rocmIbAccept, + rocmIbRegMr, + rocmIbRegMrDmaBuf, + rocmIbDeregMr, + rocmIbIsend, + rocmIbIrecv, + rocmIbIflush, + rocmIbTest, + rocmIbCloseSend, + rocmIbCloseRecv, + rocmIbCloseListen, + NULL /* getDeviceMr */, + NULL /* irecvConsumed */, + rocmIbMakeVDevice +}; + +/* + ncclIbSetProperties, + ncclIbRefreshDevices +*/