diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d216189bd..020b65fdc0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -423,7 +423,6 @@ set(SRC_FILES src/msccl.cc src/proxy.cc src/rccl_wrap.cc - src/register.cc src/transport.cc src/device/all_gather.h src/device/all_reduce.h @@ -499,6 +498,7 @@ set(SRC_FILES src/include/param.h src/include/profiler.h src/include/proxy.h + src/include/ras.h src/include/rccl_common.h src/include/rccl_vars.h src/include/register.h @@ -589,6 +589,16 @@ set(SRC_FILES src/misc/msccl/msccl_parser.cc src/misc/msccl/msccl_setup.cc src/misc/msccl/msccl_status.cc + src/ras/client.cc + src/ras/client_support.cc + src/ras/collectives.cc + src/ras/peers.cc + src/ras/ras.cc + src/ras/ras_internal.h + src/ras/rasnet.cc + src/register/coll_reg.cc + src/register/register.cc + src/register/sendrecv_reg.cc src/transport/coll_net.cc src/transport/generic.cc src/transport/net.cc diff --git a/cmake/scripts/add_unroll.sh b/cmake/scripts/add_unroll.sh index c14a66a09d..8b1cddfff3 100644 --- a/cmake/scripts/add_unroll.sh +++ b/cmake/scripts/add_unroll.sh @@ -21,14 +21,11 @@ HIP_FILE=$1 if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then - sed -i "s/template/template/g" "$HIP_FILE" - sed -i "s/template/template/g" "$HIP_FILE" - sed -i "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/\\(runRing]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/runTreeUpDown>/runTreeUpDown, COLL_UNROLL>/" "$HIP_FILE" - sed -i "s/\\(runTreeSplit]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + perl -pi -e 's/(template/\1, int COLL_UNROLL\2>/g' "$HIP_FILE" + perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, COLL_UNROLL>/g' "$HIP_FILE" + perl -pi -e 's/(runRing\()/\1, COLL_UNROLL\2/g' "$HIP_FILE" + perl -pi -e 's/(runTreeUpDown\(/\1, COLL_UNROLL>(/' "$HIP_FILE" + perl -pi -e 's/(runTreeSplit\(/\1, COLL_UNROLL>(/' "$HIP_FILE" sed -i "s/\\(struct RunWorkColl]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" sed -i "s/\\(struct RunWorkBatch]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" diff --git a/ext-net/README.md b/ext-net/README.md index 781fd904a4..aa1a3945e6 100644 --- a/ext-net/README.md +++ b/ext-net/README.md @@ -60,9 +60,9 @@ of newer ones. The `nccl/` directory is populated with `net_vX.h` files extracting all relevant definitions from old API versions. It also provides error codes in `err.h`. -# API (v6) +# API (v9) -Below is the main `ncclNet_v6` struct. Each function is explained in later sections. +Below is the main `ncclNet_v9` struct. Each function is explained in later sections. ``` typedef struct { @@ -73,7 +73,7 @@ typedef struct { // Return the number of adapters. ncclResult_t (*devices)(int* ndev); // Get various device properties. - ncclResult_t (*getProperties)(int dev, ncclNetProperties_v6_t* props); + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); // Create a receiving object and provide a handle to connect to it. The // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged // between ranks to create a connection. @@ -82,24 +82,26 @@ typedef struct { // This call must not block for the connection to be established, and instead // should return successfully with sendComm == NULL with the expectation that // it will be called again until sendComm != NULL. - ncclResult_t (*connect)(int dev, void* handle, void** sendComm); + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm); // Finalize connection establishment after remote peer has called connect. // This call must not block for the connection to be established, and instead // should return successfully with recvComm == NULL with the expectation that // it will be called again until recvComm != NULL. - ncclResult_t (*accept)(void* listenComm, void** recvComm); + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm); // Register/Deregister memory. Comm can be either a sendComm or a recvComm. // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. - ncclResult_t (*regMr)(void* comm, void* data, int size, int type, void** mhandle); + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); /* DMA-BUF support */ ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); ncclResult_t (*deregMr)(void* comm, void* mhandle); // Asynchronous send to a peer. // May return request == NULL if the call cannot be performed (or would block) - ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request); + ncclResult_t (*isend)(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request); // Asynchronous recv from a peer. // May return request == NULL if the call cannot be performed (or would block) - ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request); + ncclResult_t (*irecv)(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request); // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is // visible to the GPU ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); @@ -110,7 +112,17 @@ typedef struct { ncclResult_t (*closeSend)(void* sendComm); ncclResult_t (*closeRecv)(void* recvComm); ncclResult_t (*closeListen)(void* listenComm); -} ncclNet_v6_t; + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); + + // Virtual NIC APIs. makeVDevice will create a virtual NIC given the specified properties, and tell the caller + // what index this new vNIC exists at + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclNet_t; ``` ## Error codes @@ -136,11 +148,19 @@ not need to rely on CUDA, this should not be common. NCCL will call the `init` function first, then query the number of network devices with the `devices` function, getting each network device properties with `getProperties`. +If NCCL wishes to initialize virtual devices, used in NIC fusion currently, it can call `makeVDevice` +specifying a list of physical devices (the original devices listed from `devices`) it wishes to +merge together. If the plugin does not support NIC fusion, it can set `makeVDevice` to null. + To establish a connection between two network devices, NCCL will first call `listen` on the receiving side, pass the returned handle to the sender side of the connection, and call `connect` with that handle. Finally, `accept` will be called on the receiving side to finalize the connection establishment. +`connect` and `accept` can receive an optional `netDevComm` pointer from the caller, if the caller +wishes to make use of device networking. This parameter may be ignored by the plugin if it does +not support device-side networking. + Once the connection is established, communication will be done using the functions `isend`, `irecv` and `test`. Prior to calling `isend` or `irecv`, NCCL will call the `regMr` function on all buffers to allow RDMA NICs to prepare buffers. `deregMr` will be used to unregister buffers. @@ -219,6 +239,12 @@ different offset within the original buffer, with a smaller size, etc), then der The call to ncclCommDeregister should call the final deregMr() and effectively remove the mapping on the network adapter. +The `forceFlush` field can request the NCCL core to call flush for all transfers. By default, +flushes are only called when the GPU architecture or PCI topology would not not guarantee correct +PCI ordering. Plugins can set it to one if the NIC operates in a mode where e.g. the data and the +completion paths use different PCI links and therefore need a call to flush() to guarantee +ordering. + The `speed` field indicates the speed of the network port in Mbps (10^6 bits per second). This is important to ensure proper optimization of flows within the node. @@ -234,6 +260,17 @@ The `maxComms` field indicates the maximum number of connections we can create. The `maxRecvs` field indicates the maximum number for grouped receive operations (see grouped receive). +The `netDeviceType` indicates which type of device networking this plugin supports. The current supported +options are `NCCL_NET_DEVICE_HOST` and `NCCL_NET_DEVICE_UNPACK`. + +The `netDeviceVersion` indicates the version of device networking this plugin supports. Currently, this must match the associated netDeviceVersion of this netDeviceType compiled into NCCL core. Net device functionality is built as apart of NCCL core's device code. + +The `maxP2pBytes` and `maxCollBytes` fields indicate the maximum size the plugin can handle for +point-to-point and collective calls. This will tell the NCCL core to cut large operations into +multiple smaller chunks if needed. + +`vProps` is the list of devices that have been fused into the current device. Each entry is an index pointing to the child device. + ### Connection establishment Connections are used in an unidirectional manner. There is therefore a sender side and a receiver @@ -332,6 +369,12 @@ handled by a single request handle. The sizes provided to `irecv` can (and will) be larger than the size of the `isend` operation. The contrary (receive size being lower than the send size) is an error, however. +NCCL sets request pointer in `irecv` to `NCCL_NET_OPTIONAL_RECV_COMPLETION` when it is using +LL or LL128 protocols. In these cases, NCCL polls on flag embedded in data to detect completion +of irecv and is resilient to redundant network writes. This allows the plugin to optimize request +completions on such irecvs (for example, complete the request immediately). The plugin is still +expected to set a valid request pointer on return which NCCL can poll to check for completion. + Note: for a given connection, send/receive operations should always match in the order they were posted. Tags provided for receive operations are only used to assign a given send operation to one of the buffers of the first (multi-)receive in the queue, not to allow for out-of-order tag diff --git a/ext-net/example/nccl/net.h b/ext-net/example/nccl/net.h index 2aea8c439b..112967ab86 100644 --- a/ext-net/example/nccl/net.h +++ b/ext-net/example/nccl/net.h @@ -12,6 +12,8 @@ #include "err.h" #define NCCL_NET_HANDLE_MAXSIZE 128 +#define NCCL_MAX_NET_SIZE_BYTES (1*1024*1024*1024*1024L) //1TB +#define NCCL_NET_OPTIONAL_RECV_COMPLETION 0x1 #define NCCL_PTR_HOST 0x1 #define NCCL_PTR_CUDA 0x2 @@ -20,6 +22,7 @@ // Maximum number of requests per comm object #define NCCL_NET_MAX_REQUESTS 32 +#include "net_v9.h" #include "net_v8.h" #include "net_v7.h" #include "net_v6.h" diff --git a/ext-net/example/nccl/net_device.h b/ext-net/example/nccl/net_device.h index b430d90646..874fb5999a 100644 --- a/ext-net/example/nccl/net_device.h +++ b/ext-net/example/nccl/net_device.h @@ -25,6 +25,7 @@ typedef struct { } ncclNetDeviceHandle_v7_t; typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t; -typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_t; +typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_v9_t; +typedef ncclNetDeviceHandle_v9_t ncclNetDeviceHandle_t; #endif diff --git a/ext-net/example/nccl/net_v8.h b/ext-net/example/nccl/net_v8.h index 3161558205..54a61f61b4 100644 --- a/ext-net/example/nccl/net_v8.h +++ b/ext-net/example/nccl/net_v8.h @@ -23,8 +23,6 @@ typedef struct { int netDeviceVersion; // Version number for network offload } ncclNetProperties_v8_t; -typedef ncclNetProperties_v8_t ncclNetProperties_t; - typedef struct { // Name of the network (mainly for logs) const char* name; diff --git a/ext-net/example/nccl/net_v9.h b/ext-net/example/nccl/net_v9.h new file mode 100644 index 0000000000..61035ecc93 --- /dev/null +++ b/ext-net/example/nccl/net_v9.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. + */ + +#ifndef NCCL_NET_V9_H_ +#define NCCL_NET_V9_H_ + +#include "net_device.h" + +#define NCCL_NET_MAX_DEVS_PER_NIC_V9 4 +#define NCCL_NET_MAX_DEVS_PER_NIC NCCL_NET_MAX_DEVS_PER_NIC_V9 +typedef struct { + int ndevs; + int devs[NCCL_NET_MAX_DEVS_PER_NIC_V9]; +} ncclNetVDeviceProps_v9_t; +typedef ncclNetVDeviceProps_v9_t ncclNetVDeviceProps_t; + +typedef struct { + char* name; // Used mostly for logging. + char* pciPath; // Path to the PCI device in /sys. + uint64_t guid; // Unique identifier for the NIC chip. Important for + // cards with multiple PCI functions (Physical or virtual). + int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF] + int regIsGlobal; // regMr is not tied to a particular comm + int forceFlush; // Force a flush on receives + int speed; // Port speed in Mbps. + int port; // Port number. + float latency; // Network latency + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives. + ncclNetDeviceType netDeviceType; // Network offload type + int netDeviceVersion; // Version number for network offload + ncclNetVDeviceProps_v9_t vProps; + size_t maxP2pBytes; // Max transfer size for point-to-point operations + size_t maxCollBytes; // Max transfer size for collective operations +} ncclNetProperties_v9_t; + +typedef ncclNetProperties_v9_t ncclNetProperties_t; + +typedef struct { + // Name of the network (mainly for logs) + const char* name; + // Initialize the network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create a connection. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Connect to a handle and return a sending comm object for that peer. + // This call must not block for the connection to be established, and instead + // should return successfully with sendComm == NULL with the expectation that + // it will be called again until sendComm != NULL. + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v9_t** sendDevComm); + // Finalize connection establishment after remote peer has called connect. + // This call must not block for the connection to be established, and instead + // should return successfully with recvComm == NULL with the expectation that + // it will be called again until recvComm != NULL. + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v9_t** recvDevComm); + // Register/Deregister memory. Comm can be either a sendComm or a recvComm. + // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + // Asynchronous send to a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*isend)(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request); + // Asynchronous recv from a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*irecv)(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* sizes); + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); + + // Virtual NIC APIs. makeVDevice will create a virtual NIC given the specified properties, and tell the caller + // what index this new vNIC exists at + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclNet_v9_t; + +#endif // end include guard diff --git a/ext-net/example/nccl/types.h b/ext-net/example/nccl/types.h index e40f5b50d4..8274c203c9 100644 --- a/ext-net/example/nccl/types.h +++ b/ext-net/example/nccl/types.h @@ -16,8 +16,8 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, ncclBfloat16 = 9, - ncclFp8E4M3 = 10, - ncclFp8E5M2 = 11, + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, } ncclDataType_t; #endif diff --git a/ext-net/example/plugin.c b/ext-net/example/plugin.c index 128dde9b47..2852242617 100644 --- a/ext-net/example/plugin.c +++ b/ext-net/example/plugin.c @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2024, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -7,15 +7,15 @@ #include "net.h" #define __hidden __attribute__ ((visibility("hidden"))) +#define NCCL_PLUGIN_MAX_RECVS 1 int max_requests = NCCL_NET_MAX_REQUESTS; __hidden ncclResult_t pluginInit(ncclDebugLogger_t logFunction) { return ncclSuccess; } __hidden ncclResult_t pluginDevices(int* ndev) { *ndev = 0; return ncclSuccess; } - __hidden ncclResult_t pluginPciPath(int dev, char** path) { return ncclInternalError; } __hidden ncclResult_t pluginPtrSupport(int dev, int* supportedTypes) { return ncclInternalError; } -__hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v8_t* props) { +__hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_t* props) { // Below are default values, if unsure don't change. props->name = "Example"; @@ -27,6 +27,8 @@ __hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v8_t* props props->ptrSupport = NCCL_PTR_HOST; // If you regMr has a fast registration cache, set to 1. If set to 0, user buffer registration may be disabled. props->regIsGlobal = 0; + // Force flush after receive. Needed if the control path and data path use a different path to the GPU + props->forceFlush = 0; // Speed in *Mbps*. 100000 means 100G props->speed = 100000; // Port number, used in conjunction with guid @@ -36,20 +38,27 @@ __hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v8_t* props // Maximum number of comm objects we can create. props->maxComms = 1024*1024; // Maximum number of receive operations taken by irecv(). - props->maxRecvs = 1; + props->maxRecvs = NCCL_PLUGIN_MAX_RECVS; // Coupling with NCCL network device-side code. - props->netDeviceType = 0; + props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; - return ncclInternalError; + // Used to tell NCCL core whether this is a virtual device fusing multiple physical devices. + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + // maximum transfer sizes the plugin can handle + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; + props->maxCollBytes = NCCL_MAX_NET_SIZE_BYTES; + return ncclSuccess; } + __hidden ncclResult_t pluginListen(int dev, void* handle, void** listenComm) { return ncclInternalError; } -__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm) { return ncclInternalError; } -__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm) { return ncclInternalError; } +__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** sendDevComm) { return ncclInternalError; } +__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) { return ncclInternalError; } __hidden ncclResult_t pluginRegMr(void* collComm, void* data, size_t size, int type, void** mhandle) { return ncclInternalError; } __hidden ncclResult_t pluginRegMrDmaBuf(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { return ncclInternalError; } __hidden ncclResult_t pluginDeregMr(void* collComm, void* mhandle) { return ncclInternalError;} -__hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { return ncclInternalError; } -__hidden ncclResult_t pluginIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { return ncclInternalError; } +__hidden ncclResult_t pluginIsend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { return ncclInternalError; } +__hidden ncclResult_t pluginIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { return ncclInternalError; } __hidden ncclResult_t pluginIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { return ncclInternalError; } __hidden ncclResult_t pluginTest(void* request, int* done, int* size) { return ncclInternalError; } __hidden ncclResult_t pluginCloseSend(void* sendComm) { return ncclInternalError; } @@ -57,10 +66,11 @@ __hidden ncclResult_t pluginCloseRecv(void* recvComm) { return ncclInternalError __hidden ncclResult_t pluginCloseListen(void* listenComm) { return ncclInternalError; } __hidden ncclResult_t pluginIrecvConsumed(void* recvComm, int n, void* request) { return ncclInternalError; } __hidden ncclResult_t pluginGetDeviceMr(void* comm, void* mhandle, void** dptr_mhandle) { return ncclInternalError; } +__hidden ncclResult_t pluginMakeVDevice(int* d, ncclNetVDeviceProps_t* props) { return ncclInternalError; } #define PLUGIN_NAME "Plugin" -const ncclNet_v8_t ncclNetPlugin_v8 = { +ncclNet_v9_t ncclNetPlugin_v9 = { .name = PLUGIN_NAME, .init = pluginInit, .devices = pluginDevices, @@ -80,8 +90,60 @@ const ncclNet_v8_t ncclNetPlugin_v8 = { .closeListen = pluginCloseListen, .getDeviceMr = pluginGetDeviceMr, .irecvConsumed = pluginIrecvConsumed, + .makeVDevice = pluginMakeVDevice, }; +__hidden ncclResult_t pluginGetProperties_v8(int dev, ncclNetProperties_v8_t* props_v8) { + ncclNetProperties_t props; + ncclResult_t ret = pluginGetProperties(dev, &props); + if (ret != ncclSuccess) return ret; + props_v8->name = props.name; + props_v8->pciPath = props.pciPath; + props_v8->guid = props.guid; + props_v8->ptrSupport = props.ptrSupport; + props_v8->regIsGlobal = props.regIsGlobal; + props_v8->speed = props.speed; + props_v8->latency = props.latency; + props_v8->port = props.port; + props_v8->maxComms = props.maxComms; + props_v8->maxRecvs = props.maxRecvs; + props_v8->netDeviceType = props.netDeviceType; + props_v8->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; +} + +__hidden ncclResult_t pluginIsend_v8(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { + return pluginIsend(sendComm, data, (int)size, tag, mhandle, request); +} + +__hidden ncclResult_t pluginIrecv_v8(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { + size_t sizesOut[NCCL_PLUGIN_MAX_RECVS]; + for (int i=0; iguid = props.guid; props_v7->ptrSupport = props.ptrSupport; props_v7->speed = props.speed; + props_v7->latency = props.latency; props_v7->port = props.port; props_v7->maxComms = props.maxComms; props_v7->maxRecvs = props.maxRecvs; @@ -114,8 +177,8 @@ const ncclNet_v7_t ncclNetPlugin_v7 = { .regMr = pluginRegMr_v7, .regMrDmaBuf = pluginRegMrDmaBuf, .deregMr = pluginDeregMr, - .isend = pluginIsend, - .irecv = pluginIrecv, + .isend = pluginIsend_v8, + .irecv = pluginIrecv_v8, .iflush = pluginIflush, .test = pluginTest, .closeSend = pluginCloseSend, @@ -134,6 +197,7 @@ __hidden ncclResult_t pluginGetProperties_v6(int dev, ncclNetProperties_v6_t* pr props_v6->guid = props.guid; props_v6->ptrSupport = props.ptrSupport; props_v6->speed = props.speed; + props_v6->latency = props.latency; props_v6->port = props.port; props_v6->maxComms = props.maxComms; props_v6->maxRecvs = props.maxRecvs; @@ -154,8 +218,8 @@ const ncclNet_v6_t ncclNetPlugin_v6 = { .regMr = pluginRegMr_v7, .regMrDmaBuf = pluginRegMrDmaBuf, .deregMr = pluginDeregMr, - .isend = pluginIsend, - .irecv = pluginIrecv, + .isend = pluginIsend_v8, + .irecv = pluginIrecv_v8, .iflush = pluginIflush, .test = pluginTest, .closeSend = pluginCloseSend, @@ -174,8 +238,8 @@ const ncclNet_v5_t ncclNetPlugin_v5 = { .accept = pluginAccept_v6, .regMr = pluginRegMr_v7, .deregMr = pluginDeregMr, - .isend = pluginIsend, - .irecv = pluginIrecv, + .isend = pluginIsend_v8, + .irecv = pluginIrecv_v8, .iflush = pluginIflush, .test = pluginTest, .closeSend = pluginCloseSend, @@ -198,11 +262,11 @@ static ncclResult_t pluginGetProperties_v4(int dev, ncclNetProperties_v4_t* prop return ncclSuccess; } static ncclResult_t pluginIsend_v4(void *sendComm, void* data, int size, void *mhandle, void** request) { - return pluginIsend(sendComm, data, size, 0, mhandle, request); + return pluginIsend_v8(sendComm, data, size, 0, mhandle, request); } static ncclResult_t pluginIrecv_v4(void* recvComm, void* data, int size, void* mhandle, void** request) { int tag = 0; - return pluginIrecv(recvComm, 1, &data, &size, &tag, &mhandle, request); + return pluginIrecv_v8(recvComm, 1, &data, &size, &tag, &mhandle, request); } static ncclResult_t pluginIflush_v4(void* recvComm, void* data, int size, void* mhandle, void** request) { return pluginIflush(recvComm, 1, &data, &size, &mhandle, request); diff --git a/ext-profiler/example/event.h b/ext-profiler/example/event.h index 7432808133..1486a22482 100644 --- a/ext-profiler/example/event.h +++ b/ext-profiler/example/event.h @@ -14,6 +14,7 @@ #define MAX_CHANNELS 32 #define MAX_STEPS 16 +#define MAX_OPS 16 // Up to 64K ranks for PAT #define PROXY_OP_SEND_STATE_OFFSET (ncclProfilerProxyOpSendPosted) #define PROXY_OP_RECV_STATE_OFFSET (ncclProfilerProxyOpRecvPosted) @@ -86,7 +87,7 @@ struct taskEventBase { int rank; // rank of the operation in NCCL communicator const char* name; // FIXME: unused uint64_t commHash; // communicator identifier - uint8_t func; // ncclFunc* + const char* func; // ncclFunc* int refCount; // number of references for this operation struct group* parent; // parent event group struct taskEventBase* next; // next top level event in group @@ -102,16 +103,14 @@ struct collective { size_t count; size_t trafficBytes; int root; - uint8_t datatype; + const char* datatype; uint8_t nMaxChannels; - uint8_t algo; - uint8_t proto; - int op; + const char* algo; + const char* proto; int nWarps; - int isCollnet; - int isNvls; - struct proxyOp send[MAX_CHANNELS];// array of send proxy operation events - struct proxyOp recv[MAX_CHANNELS];// array of recv proxy operation events + struct proxyOp send[MAX_CHANNELS][MAX_OPS];// array of send proxy operation events + struct proxyOp recv[MAX_CHANNELS][MAX_OPS];// array of recv proxy operation events + int nProxyOps[MAX_CHANNELS]; }; struct p2p { @@ -119,9 +118,9 @@ struct p2p { uint8_t func; void const* buff; size_t count; - uint8_t datatype; + const char* datatype; int peer; - struct proxyOp op; + struct proxyOp op[MAX_CHANNELS]; }; struct group { diff --git a/ext-profiler/example/nccl/profiler.h b/ext-profiler/example/nccl/profiler.h index db7bc3feae..6680cfecef 100644 --- a/ext-profiler/example/nccl/profiler.h +++ b/ext-profiler/example/nccl/profiler.h @@ -13,6 +13,7 @@ #include "common.h" #include "err.h" +#include "profiler_v2.h" #include "profiler_v1.h" #endif // end include guard diff --git a/ext-profiler/example/nccl/profiler_v1.h b/ext-profiler/example/nccl/profiler_v1.h index 8724a1c662..7d34bed57f 100644 --- a/ext-profiler/example/nccl/profiler_v1.h +++ b/ext-profiler/example/nccl/profiler_v1.h @@ -9,16 +9,6 @@ #include -enum { - ncclProfileGroup = (1 << 0), // group event type - ncclProfileColl = (1 << 1), // host collective call event type - ncclProfileP2p = (1 << 2), // host point-to-point call event type - ncclProfileProxyOp = (1 << 3), // proxy operation event type - ncclProfileProxyStep = (1 << 4), // proxy step event type - ncclProfileProxyCtrl = (1 << 5), // proxy control event type - ncclProfileNumEvents = ( 6), -}; - typedef struct { uint8_t type; // event type descriptor: ncclProfileColl, ... void* parentObj; // pointer to the profiler parent object (for coll is the group) @@ -69,42 +59,8 @@ typedef struct { }; } ncclProfilerEventDescr_v1_t; -typedef enum { - ncclProfilerProxyOpSendPosted, - ncclProfilerProxyOpSendRemFifoWait, - ncclProfilerProxyOpSendTransmitted, - ncclProfilerProxyOpSendDone, - ncclProfilerProxyOpRecvPosted, - ncclProfilerProxyOpRecvReceived, - ncclProfilerProxyOpRecvTransmitted, - ncclProfilerProxyOpRecvDone, - - /* Legacy proxy profiler states */ - ncclProfilerProxyStepSendGPUWait, - ncclProfilerProxyStepSendWait, - ncclProfilerProxyStepRecvWait, - ncclProfilerProxyStepRecvFlushWait, - ncclProfilerProxyStepRecvGPUWait, - - /* Legacy proxy control states */ - ncclProfilerProxyCtrlIdle, - ncclProfilerProxyCtrlActive, - ncclProfilerProxyCtrlSleep, - ncclProfilerProxyCtrlWakeup, - ncclProfilerProxyCtrlAppend, - ncclProfilerProxyCtrlAppendEnd, -} ncclProfilerEventState_v1_t; - -typedef union { - struct { - size_t transSize; - int steps; - } proxyOp; - - struct { - int appendedProxyOps; - } proxyCtrl; -} ncclProfilerEventStateArgs_v1_t; +typedef ncclProfilerEventState_v2_t ncclProfilerEventState_v1_t; +typedef ncclProfilerEventStateArgs_v2_t ncclProfilerEventStateArgs_v1_t; typedef struct { const char* name; @@ -142,9 +98,4 @@ typedef struct { ncclResult_t (*finalize)(void* context); } ncclProfiler_v1_t; -typedef ncclProfilerEventDescr_v1_t ncclProfilerEventDescr_t; -typedef ncclProfilerEventState_v1_t ncclProfilerEventState_t; -typedef ncclProfilerEventStateArgs_v1_t ncclProfilerEventStateArgs_t; -typedef ncclProfiler_v1_t ncclProfiler_t; - #endif diff --git a/ext-profiler/example/nccl/profiler_v2.h b/ext-profiler/example/nccl/profiler_v2.h new file mode 100644 index 0000000000..aab4ccf868 --- /dev/null +++ b/ext-profiler/example/nccl/profiler_v2.h @@ -0,0 +1,146 @@ +/************************************************************************* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_PROFILER_V2_H_ +#define NCCL_PROFILER_V2_H_ + +#include + +enum { + ncclProfileGroup = (1 << 0), // group event type + ncclProfileColl = (1 << 1), // host collective call event type + ncclProfileP2p = (1 << 2), // host point-to-point call event type + ncclProfileProxyOp = (1 << 3), // proxy operation event type + ncclProfileProxyStep = (1 << 4), // proxy step event type + ncclProfileProxyCtrl = (1 << 5), // proxy control event type +}; + +typedef struct { + uint8_t type; // event type descriptor: ncclProfileColl, ... + void* parentObj; // pointer to the profiler parent object (for coll is the group) + int rank; // originating rank + union { + struct { + const char* name; + uint64_t commHash; + uint64_t seqNumber; + const char* func; + void const* sendBuff; + void* recvBuff; + size_t count; + int root; + const char* datatype; + size_t trafficBytes; + uint8_t nMaxChannels; + uint8_t nWarps; + const char* algo; + const char* proto; + } coll; + + struct { + const char* name; + uint64_t commHash; + const char* func; + void* buff; + const char* datatype; + size_t count; + int peer; + } p2p; + + struct { + pid_t pid; // pid of the originating process + uint8_t channelId; // channel id for this proxy operation + int peer; // remote rank for send/recv + int nSteps; // number of steps for this proxy operation + int chunkSize; // amount of data transferred by this proxy operation + int isSend; + } proxyOp; + + struct { + int step; + } proxyStep; + }; +} ncclProfilerEventDescr_v2_t; + +typedef enum { + ncclProfilerProxyOpSendPosted, + ncclProfilerProxyOpSendRemFifoWait, + ncclProfilerProxyOpSendTransmitted, + ncclProfilerProxyOpSendDone, + ncclProfilerProxyOpRecvPosted, + ncclProfilerProxyOpRecvReceived, + ncclProfilerProxyOpRecvTransmitted, + ncclProfilerProxyOpRecvDone, + + /* Legacy proxy profiler states */ + ncclProfilerProxyStepSendGPUWait, + ncclProfilerProxyStepSendWait, + ncclProfilerProxyStepRecvWait, + ncclProfilerProxyStepRecvFlushWait, + ncclProfilerProxyStepRecvGPUWait, + + /* Legacy proxy control states */ + ncclProfilerProxyCtrlIdle, + ncclProfilerProxyCtrlActive, + ncclProfilerProxyCtrlSleep, + ncclProfilerProxyCtrlWakeup, + ncclProfilerProxyCtrlAppend, + ncclProfilerProxyCtrlAppendEnd, +} ncclProfilerEventState_v2_t; + +typedef union { + struct { + size_t transSize; + int steps; + } proxyOp; + + struct { + int appendedProxyOps; + } proxyCtrl; +} ncclProfilerEventStateArgs_v2_t; + +typedef struct { + const char* name; + + // init - initialize the profiler plugin + // Input + // - context : opaque profiler context object for separating profiler behavior across comms + // Output + // - eActivationMask: bitmask of active events set by the plugin + ncclResult_t (*init)(void** context, int* eActivationMask); + + // startEvent - initialize and start a new event for the supplied event descriptor inside the eventset + // Input + // - context: opaque profiler context object + // - eDescr : pointer to ncclProfilerEventDescr_t object + // Output + // - eHandle: return event handle for supplied event descriptor object + ncclResult_t (*startEvent)(void* context, void** eHandle, ncclProfilerEventDescr_v2_t* eDescr); + + // stopEvent - stop/finalize an event inside and event set + // Input + // - eHandle: handle to event object + ncclResult_t (*stopEvent)(void* eHandle); + + // recordEventState - record event state transitions and event attribute updates + // Input + // - eHandle : handle to event object created through startEvent + // - eStateArgs: optional argument used to capture event attribute updates associated with the state transition + // - eState : event state transition + ncclResult_t (*recordEventState)(void* eHandle, ncclProfilerEventState_v2_t eState, ncclProfilerEventStateArgs_v2_t* eStateArgs); + + // finalize - finalize the profiler plugin + // Input + // - context: opaque profiler context object + ncclResult_t (*finalize)(void* context); +} ncclProfiler_v2_t; + +typedef ncclProfilerEventDescr_v2_t ncclProfilerEventDescr_t; +typedef ncclProfilerEventState_v2_t ncclProfilerEventState_t; +typedef ncclProfilerEventStateArgs_v2_t ncclProfilerEventStateArgs_t; +typedef ncclProfiler_v2_t ncclProfiler_t; + +#endif diff --git a/ext-profiler/example/plugin.c b/ext-profiler/example/plugin.c index f9de60813a..64d5d8be1d 100644 --- a/ext-profiler/example/plugin.c +++ b/ext-profiler/example/plugin.c @@ -21,11 +21,18 @@ static int initialized; // initialization counter for profiler static double startTime; // profiler start time -static int groupPoolSize = 16; -static int collPoolSize = 16; -static int p2pPoolSize = 1024; -static int proxyCtrlPoolSize = 16; -static int detachPoolSize = 128; +static const int defaultEActivationMask = ncclProfileColl | ncclProfileP2p; +static const int defaultGroupPoolSize = 16; +static const int defaultCollPoolSize = 16; +static const int defaultP2pPoolSize = 1024; +static const int defaultProxyCtrlPoolSize = 16; +static const int defaultDetachPoolSize = 128; + +static int groupPoolSize; +static int collPoolSize; +static int p2pPoolSize; +static int proxyCtrlPoolSize; +static int detachPoolSize; static int detachPoolBase; static int detachPoolIndex; static int detachPoolDone; @@ -56,25 +63,25 @@ __hidden ncclResult_t exampleProfilerInit(void** context, int* eActivationMask) pthread_mutex_lock(&lock); if (__atomic_fetch_add(&initialized, 1, __ATOMIC_RELAXED) == 0) { // first thread initializes event mask, environment and detach pool - __atomic_store_n(eActivationMask, ncclProfileColl | ncclProfileP2p, __ATOMIC_RELAXED); - if (getenv("NCCL_PROFILE_EVENT_MASK")) { - __atomic_store_n(eActivationMask, atoi(getenv("NCCL_PROFILE_EVENT_MASK")), __ATOMIC_RELAXED); - } - if (getenv("NCCL_PROFILE_GROUP_POOL_SIZE")) { - groupPoolSize = atoi(getenv("NCCL_PROFILE_GROUP_POOL_SIZE")); - } - if (getenv("NCCL_PROFILE_COLL_POOL_SIZE")) { - collPoolSize = atoi(getenv("NCCL_PROFILE_COLL_POOL_SIZE")); - } - if (getenv("NCCL_PROFILE_P2P_POOL_SIZE")) { - p2pPoolSize = atoi(getenv("NCCL_PROFILE_P2P_POOL_SIZE")); - } - if (getenv("NCCL_PROFILE_PROXY_CTRL_POOL_SIZE")) { - proxyCtrlPoolSize = atoi(getenv("NCCL_PROFILE_PROXY_CTRL_POOL_SIZE")); - } - if (getenv("NCCL_PROFILE_PROXY_DETACH_POOL_SIZE")) { - detachPoolSize = atoi(getenv("NCCL_PROFILE_PROXY_DETACH_POOL_SIZE")); - } + const char* str; + str = getenv("NCCL_PROFILE_EVENT_MASK"); + __atomic_store_n(eActivationMask, str ? atoi(str) : defaultEActivationMask, __ATOMIC_RELAXED); + + str = getenv("NCCL_PROFILE_GROUP_POOL_SIZE"); + groupPoolSize = str ? atoi(str) : defaultGroupPoolSize; + + str = getenv("NCCL_PROFILE_COLL_POOL_SIZE"); + collPoolSize = str ? atoi(str) : defaultCollPoolSize; + + str = getenv("NCCL_PROFILE_P2P_POOL_SIZE"); + p2pPoolSize = str ? atoi(str) : defaultP2pPoolSize; + + str = getenv("NCCL_PROFILE_PROXY_CTRL_POOL_SIZE"); + proxyCtrlPoolSize = str ? atoi(str) : defaultProxyCtrlPoolSize; + + str = getenv("NCCL_PROFILE_PROXY_DETACH_POOL_SIZE"); + detachPoolSize = str ? atoi(str) : defaultDetachPoolSize; + // detach pool is used to store PXN proxyOps and is shared among threads detachPool = (struct proxyOp *)calloc(detachPoolSize, sizeof(*detachPool)); if (detachPool == NULL) { @@ -107,6 +114,13 @@ __hidden ncclResult_t exampleProfilerInit(void** context, int* eActivationMask) ctx->proxyCtrlPool = (struct proxyCtrl *)calloc(proxyCtrlPoolSize, sizeof(*ctx->proxyCtrlPool)); if (ctx->proxyCtrlPool == NULL) goto fail; + // Print event pool sizes for debugging + //fprintf(stdout, "Profiler: Group pool size (bytes): %lu\n", sizeof(struct group)*groupPoolSize); + //fprintf(stdout, "Profiler: Coll pool size (bytes): %lu\n", sizeof(struct collective)*collPoolSize); + //fprintf(stdout, "Profiler: P2p pool size (bytes): %lu\n", sizeof(struct p2p)*p2pPoolSize); + //fprintf(stdout, "Profiler: Proxy pool size (bytes): %lu\n", sizeof(struct proxyCtrl)*proxyCtrlPoolSize); + //fprintf(stdout, "Profiler: PXN pool size (bytes): %lu\n", sizeof(struct proxyOp)*detachPoolSize); + *context = ctx; return ncclSuccess; @@ -154,7 +168,7 @@ __hidden ncclResult_t exampleProfilerFinalize(void* context) { free(ctx); // last thread cleans up shared detach pool - if (__atomic_fetch_sub(&initialized, 1, __ATOMIC_RELAXED) - 1 == 0) { + if (__atomic_sub_fetch(&initialized, 1, __ATOMIC_RELAXED) == 0) { start = (detachPoolIndex - detachPoolSize >= 0) ? detachPoolIndex - detachPoolSize : 0; end = detachPoolIndex; for (int i = start; i < end; i++) { @@ -171,7 +185,7 @@ __hidden ncclResult_t exampleProfilerFinalize(void* context) { __hidden void updateEvent(void* handle); -__hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, ncclProfilerEventDescr_v1_t* eDescr) { +__hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, ncclProfilerEventDescr_t* eDescr) { *eHandle = NULL; struct context* ctx = (struct context *)context; if (eDescr->type == ncclProfileGroup) { @@ -185,14 +199,15 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n if (base->type == ncclProfileColl) { struct collective* c = (struct collective *)base; // reset event proxyOps & proxySteps - memset(c->send, 0, sizeof(struct proxyOp)*MAX_CHANNELS); - memset(c->recv, 0, sizeof(struct proxyOp)*MAX_CHANNELS); + memset(c->send, 0, sizeof(struct proxyOp)*MAX_CHANNELS*MAX_OPS); + memset(c->recv, 0, sizeof(struct proxyOp)*MAX_CHANNELS*MAX_OPS); + memset(c->nProxyOps, 0, sizeof(int)*MAX_CHANNELS); // release collective events in the group and return them to the collective pool __atomic_fetch_add(&ctx->collPoolBase, 1, __ATOMIC_RELAXED); } else if (base->type == ncclProfileP2p) { struct p2p* p = (struct p2p *)base; // reset event proxyOp and proxySteps - memset(&p->op, 0, sizeof(struct proxyOp)); + memset(&p->op, 0, sizeof(struct proxyOp)*MAX_CHANNELS); // release p2p events in the group and return them to the p2p pool __atomic_fetch_add(&ctx->p2pPoolBase, 1, __ATOMIC_RELAXED); } @@ -203,7 +218,6 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n return ncclSuccess; } event->type = ncclProfileGroup; - __atomic_store_n(&event->refCount, 1, __ATOMIC_RELAXED); event->ctx = ctx; event->groupId = groupId; event->startTs = gettime() - startTime; @@ -238,14 +252,11 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n event->count = eDescr->coll.count; event->root = eDescr->coll.root; event->datatype = eDescr->coll.datatype; - event->op = eDescr->coll.op; event->trafficBytes = eDescr->coll.trafficBytes; event->nMaxChannels = eDescr->coll.nMaxChannels; event->nWarps = eDescr->coll.nWarps; event->algo = eDescr->coll.algo; event->proto = eDescr->coll.proto; - event->isCollnet = eDescr->coll.isCollnet; - event->isNvls = eDescr->coll.isNvls; *eHandle = event; taskEventQueueEnqueue(parent, (struct taskEventBase *)event); // increment the group ref counter so the event will staty open @@ -326,9 +337,13 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n if (eventBase->type == ncclProfileColl) { struct collective* parent = (struct collective *)eDescr->parentObj; - struct proxyOp* event = (eDescr->proxyOp.isSend) ? &parent->send[eDescr->proxyOp.channelId] : &parent->recv[eDescr->proxyOp.channelId]; + int channelId = eDescr->proxyOp.channelId; + struct proxyOp* event = (eDescr->proxyOp.isSend) ? + &parent->send[channelId][parent->nProxyOps[channelId]++] : + &parent->recv[channelId][parent->nProxyOps[channelId]++]; + event->type = ncclProfileProxyOp; - event->channelId = eDescr->proxyOp.channelId; + event->channelId = channelId; event->pid = eDescr->proxyOp.pid; event->rank = eDescr->rank; event->peer = eDescr->proxyOp.peer; @@ -338,13 +353,14 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n event->parent = eventBase; event->startTs = gettime() - startTime; *eHandle = event; - __atomic_store_n(&parent->base.refCount, 1, __ATOMIC_RELAXED); + __atomic_fetch_add(&parent->base.refCount, 1, __ATOMIC_RELAXED); debugEvent(event, "ProxyOpStart"); } else { // ncclProfileP2p struct p2p* parent = (struct p2p *)eDescr->parentObj; - struct proxyOp* event = &parent->op; + int channelId = eDescr->proxyOp.channelId; + struct proxyOp* event = &parent->op[channelId]; event->type = ncclProfileProxyOp; - event->channelId = eDescr->proxyOp.channelId; + event->channelId = channelId; event->pid = eDescr->proxyOp.pid; event->rank = eDescr->rank; event->peer = eDescr->proxyOp.peer; @@ -354,7 +370,7 @@ __hidden ncclResult_t exampleProfilerStartEvent(void* context, void** eHandle, n event->parent = eventBase; event->startTs = gettime() - startTime; *eHandle = event; - __atomic_store_n(&parent->base.refCount, 1, __ATOMIC_RELAXED); + __atomic_fetch_add(&parent->base.refCount, 1, __ATOMIC_RELAXED); debugEvent(event, "ProxyOpStart"); } } else if (eDescr->type == ncclProfileProxyStep) { @@ -379,7 +395,7 @@ void updateEvent(void* handle) { uint8_t type = *(uint8_t *)handle; if (type == ncclProfileGroup) { struct group* event = (struct group *)handle; - if (__atomic_fetch_sub(&event->refCount, 1, __ATOMIC_RELAXED) == 1) { + if (__atomic_sub_fetch(&event->refCount, 1, __ATOMIC_RELAXED) == 0) { event->stopTs = gettime() - startTime; // return group event to the pool __atomic_fetch_add(&event->ctx->groupPoolBase, 1, __ATOMIC_RELAXED); @@ -387,7 +403,7 @@ void updateEvent(void* handle) { debugEvent(event, "GroupStop"); } else if (type == ncclProfileColl) { struct collective* event = (struct collective *)handle; - if (__atomic_fetch_sub(&event->base.refCount, 1, __ATOMIC_RELAXED) == 1) { + if (__atomic_sub_fetch(&event->base.refCount, 1, __ATOMIC_RELAXED) == 0) { event->base.stopTs = gettime() - startTime; debugEvent(event, "CollStop"); updateEvent(event->base.parent); @@ -396,7 +412,7 @@ void updateEvent(void* handle) { debugEvent(event, "CollStop"); } else if (type == ncclProfileP2p) { struct p2p* event = (struct p2p *)handle; - if (__atomic_fetch_sub(&event->base.refCount, 1, __ATOMIC_RELAXED) == 1) { + if (__atomic_sub_fetch(&event->base.refCount, 1, __ATOMIC_RELAXED) == 0) { event->base.stopTs = gettime() - startTime; debugEvent(event, "P2pStop"); updateEvent(event->base.parent); @@ -408,7 +424,7 @@ void updateEvent(void* handle) { event->stopTs = gettime() - startTime; if (event->pid != pid) { // only for proxyOps that don't have a parent collective/p2p (i.e., PXN) - int done = __atomic_fetch_add(&detachPoolDone, 1, __ATOMIC_RELAXED) + 1; + int done = __atomic_add_fetch(&detachPoolDone, 1, __ATOMIC_RELAXED); if (done == detachPoolSize) { // reset the event completed (done) counter __atomic_store_n(&detachPoolDone, 0, __ATOMIC_RELAXED); @@ -451,12 +467,20 @@ __hidden ncclResult_t exampleProfilerStopEvent(void* eHandle) { struct collective* event = (struct collective *)eHandle; event->base.stopTs = gettime() - startTime; return ncclSuccess; + } else if (type == ncclProfileP2p) { + // stopping the p2p event in NCCL core does not + // mean the p2p has completed. It means the p2p + // was submitted/enqueued so we need to keep the event open + struct p2p* event = (struct p2p *)eHandle; + event->base.stopTs = gettime() - startTime; + return ncclSuccess; } + updateEvent(eHandle); return ncclSuccess; } -__hidden ncclResult_t exampleProfilerRecordEventState(void* eHandle, ncclProfilerEventState_v1_t eState, ncclProfilerEventStateArgs_v1_t* eStateArgs) { +__hidden ncclResult_t exampleProfilerRecordEventState(void* eHandle, ncclProfilerEventState_t eState, ncclProfilerEventStateArgs_t* eStateArgs) { // the event handle might be null if we run out of events if (eHandle == NULL) return ncclSuccess; @@ -482,7 +506,7 @@ __hidden ncclResult_t exampleProfilerRecordEventState(void* eHandle, ncclProfile return ncclSuccess; } -ncclProfiler_v1_t ncclProfiler_v1 = { +ncclProfiler_t ncclProfiler_v2 = { "Example-profiler", exampleProfilerInit, exampleProfilerStartEvent, diff --git a/ext-profiler/example/print_event.c b/ext-profiler/example/print_event.c index 490ba7ce44..f26a9eeb21 100644 --- a/ext-profiler/example/print_event.c +++ b/ext-profiler/example/print_event.c @@ -11,56 +11,6 @@ #define __hidden __attribute__ ((visibility("hidden"))) -__hidden const char* ncclFuncToString(int func) { - switch(func) { - case 0: - return "ncclBroadcast"; - case 1: - return "ncclReduce"; - case 2: - return "ncclAllGather"; - case 3: - return "ncclReduceScatter"; - case 4: - return "ncclAllReduce"; - case 5: - return "ncclSendRecv"; - case 6: - return "ncclSend"; - case 7: - return "ncclRecv"; - } - return NULL; -} - -__hidden const char* ncclAlgoToString(int algo) { - switch(algo) { - case 0: - return "Tree"; - case 1: - return "Ring"; - case 2: - return "CollnetDirect"; - case 3: - return "CollnetChain"; - case 4: - return "Nvls"; - case 5: - return "NvlsTree"; - } -} - -__hidden const char* ncclProtoToString(int proto) { - switch(proto) { - case 0: - return "LL"; - case 1: - return "LL128"; - case 2: - return "Simple"; - } -} - // FIXME: chrome tracing asynchronous events (following used) allow event nesting for events that have same id and category // It appears that nesting more than three events causes issues. Therefore, every event is given an increasing id and a // category that matches the type of event (GROUP, COLL, P2P, PROXY, NET) @@ -77,24 +27,24 @@ __hidden void printGroupEventTrailer(FILE* fh, struct group* event) { static __thread int collId; __hidden void printCollEventHeader(FILE* fh, struct collective* event) { - fprintf(fh, "{\"name\": \"%s\", \"cat\": \"COLL\", \"ph\": \"b\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f, \"args\": {\"SeqNum\": %lu, \"CommHash\": %lu, \"Rank\": %d, \"Count\": %lu, \"Datatype\": %d, \"Algorithm\": \"%s\", \"Protocol\": \"%s\", \"nMaxChannels\": %d}},\n", - ncclFuncToString(event->base.func), collId, getpid(), 1, event->base.startTs, event->seqNumber, event->base.commHash, event->base.rank, event->count, event->datatype, ncclAlgoToString(event->algo), ncclProtoToString(event->proto), event->nMaxChannels); + fprintf(fh, "{\"name\": \"%s\", \"cat\": \"COLL\", \"ph\": \"b\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f, \"args\": {\"SeqNum\": %lu, \"CommHash\": %lu, \"Rank\": %d, \"Count\": %lu, \"Datatype\": \"%s\", \"Algorithm\": \"%s\", \"Protocol\": \"%s\", \"nMaxChannels\": %d}},\n", + event->base.func, collId, getpid(), 1, event->base.startTs, event->seqNumber, event->base.commHash, event->base.rank, event->count, event->datatype, event->algo, event->proto, event->nMaxChannels); } __hidden void printCollEventTrailer(FILE* fh, struct collective* event) { fprintf(fh, "{\"name\": \"%s\", \"cat\": \"COLL\", \"ph\": \"e\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f},\n", - ncclFuncToString(event->base.func), collId++, getpid(), 1, event->base.stopTs); + event->base.func, collId++, getpid(), 1, event->base.stopTs); } static __thread int p2pId; __hidden void printP2pEventHeader(FILE* fh, struct p2p* event) { - fprintf(fh, "{\"name\": \"%s\", \"cat\": \"P2P\", \"ph\": \"b\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f, \"args\": {\"CommHash\": %lu, \"Rank\": %d, \"Peer\": %d, \"Count\": %lu, \"Datatype\": %d}},\n", - ncclFuncToString(event->base.func), p2pId, getpid(), 1, event->base.startTs, event->base.commHash, event->base.rank, event->peer, event->count, event->datatype); + fprintf(fh, "{\"name\": \"%s\", \"cat\": \"P2P\", \"ph\": \"b\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f, \"args\": {\"CommHash\": %lu, \"Rank\": %d, \"Peer\": %d, \"Count\": %lu, \"Datatype\": \"%s\"}},\n", + event->base.func, p2pId, getpid(), 1, event->base.startTs, event->base.commHash, event->base.rank, event->peer, event->count, event->datatype); } __hidden void printP2pEventTrailer(FILE* fh, struct p2p* event) { fprintf(fh, "{\"name\": \"%s\", \"cat\": \"P2P\", \"ph\": \"e\", \"id\": %d, \"pid\": %d, \"tid\": %d, \"ts\": %f},\n", - ncclFuncToString(event->base.func), p2pId++, getpid(), 1, event->base.stopTs); + event->base.func, p2pId++, getpid(), 1, event->base.stopTs); } static __thread int proxyOpId; @@ -250,14 +200,18 @@ void printEvent(FILE* fh, void* handle) { struct collective* c = (struct collective *)handle; printCollEventHeader(fh, c); for (int i = 0; i < MAX_CHANNELS; i++) { - printEvent(fh, &c->send[i]); - printEvent(fh, &c->recv[i]); + for (int j = 0; j < c->nProxyOps[i]; j++) { + printEvent(fh, &c->send[i][j]); + printEvent(fh, &c->recv[i][j]); + } } printCollEventTrailer(fh, c); } else if (type == ncclProfileP2p) { struct p2p* p = (struct p2p *)handle; printP2pEventHeader(fh, p); - printEvent(fh, &p->op); + for (int i = 0; i < MAX_CHANNELS; i++) { + printEvent(fh, &p->op[i]); + } printP2pEventTrailer(fh, p); } else if (type == ncclProfileProxyOp) { struct proxyOp* p = (struct proxyOp *)handle; diff --git a/ext-tuner/example/nccl/tuner.h b/ext-tuner/example/nccl/tuner.h index aafabd72d8..77b543d12c 100644 --- a/ext-tuner/example/nccl/tuner.h +++ b/ext-tuner/example/nccl/tuner.h @@ -67,6 +67,7 @@ typedef struct { // - numPipeOps: number of operations in the group // - numAlgo: number of algorithms in collCostTable // - numProto: number of protocols in collCostTable + // - regBuff: can register user buffer // // Outputs: // - nChannels: number of channels (hence SMs) to be used. @@ -82,15 +83,15 @@ typedef struct { // Unset fields will be set automatically by NCCL. ncclResult_t (*getCollInfo)(void* context, ncclFunc_t collType, size_t nBytes, int numPipeOps, float** collCostTable, int numAlgo, int numProto, - int* nChannels); + int regBuff, int* nChannels); // Terminates the plugin and cleans up any resources that the plugin allocated. // context: tuner context object ncclResult_t (*destroy)(void* context); -} ncclTuner_v3_t; +} ncclTuner_v4_t; -typedef ncclTuner_v3_t ncclTuner_t; +typedef ncclTuner_v4_t ncclTuner_t; -#define NCCL_TUNER_PLUGIN_SYMBOL "ncclTunerPlugin_v3" +#define NCCL_TUNER_PLUGIN_SYMBOL "ncclTunerPlugin_v4" #endif diff --git a/ext-tuner/example/plugin.c b/ext-tuner/example/plugin.c index 416ff3da7f..a66000275f 100644 --- a/ext-tuner/example/plugin.c +++ b/ext-tuner/example/plugin.c @@ -226,7 +226,7 @@ __hidden ncclResult_t pluginDestroy(void* context) { return ncclSuccess; } #define PLUGIN_NAME "Example" -const ncclTuner_v3_t ncclTunerPlugin_v3 = { +const ncclTuner_v4_t ncclTunerPlugin_v4 = { .name = PLUGIN_NAME, .init = pluginInit, .getCollInfo = pluginGetCollInfo, diff --git a/makefiles/common.mk b/makefiles/common.mk index 59e4151cee..82164ab5c0 100644 --- a/makefiles/common.mk +++ b/makefiles/common.mk @@ -12,6 +12,7 @@ DEBUG ?= 0 ASAN ?= 0 UBSAN ?= 0 TRACE ?= 0 +WERROR ?= 0 PROFAPI ?= 1 NVTX ?= 1 RDMA_CORE ?= 0 @@ -115,6 +116,10 @@ ifeq ($(NVTX), 0) CXXFLAGS += -DNVTX_DISABLE endif +ifneq ($(WERROR), 0) +CXXFLAGS += -Werror +endif + ifneq ($(KEEP), 0) NVCUFLAGS += -keep endif diff --git a/makefiles/version.mk b/makefiles/version.mk index bcc0ff3ce1..2523009340 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 23 -NCCL_PATCH := 4 +NCCL_MINOR := 24 +NCCL_PATCH := 3 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/src/Makefile b/src/Makefile index b254eac32c..2c5d9e863e 100644 --- a/src/Makefile +++ b/src/Makefile @@ -7,17 +7,22 @@ include ../makefiles/common.mk include ../makefiles/version.mk ##### src files -INCEXPORTS := nccl.h nccl_net.h +INCEXPORTS := nccl.h LIBSRCFILES := \ bootstrap.cc channel.cc collectives.cc debug.cc enqueue.cc group.cc \ - init.cc init_nvtx.cc net.cc proxy.cc transport.cc register.cc \ + init.cc init_nvtx.cc net.cc proxy.cc transport.cc \ $(wildcard graph/*.cc) \ $(wildcard misc/*.cc) \ - $(wildcard transport/*.cc) + $(wildcard transport/*.cc) \ + $(wildcard register/*.cc) \ + $(filter-out ras/client.cc,$(wildcard ras/*.cc)) +BINSRCFILES := ras/client.cc ##### lib files LIBNAME := libnccl.so STATICLIBNAME := libnccl_static.a +##### binaries +BINNAME := ncclras ##### pkgconfig files PKGCONFIGFILE := nccl.pc ##### dirs @@ -26,11 +31,12 @@ INCDIR := $(BUILDDIR)/include LIBDIR := $(BUILDDIR)/lib OBJDIR := $(BUILDDIR)/obj PKGDIR := $(BUILDDIR)/lib/pkgconfig +BINDIR := $(BUILDDIR)/bin ##### target files CUDARTLIB ?= cudart_static +# Use compatibility shim only with static cudart; see https://github.com/NVIDIA/nccl/issues/658 ifeq ($(CUDARTLIB), cudart_static) - # Use compatibility shim only with static cudart; see https://github.com/NVIDIA/nccl/issues/658 LIBSRCFILES += enhcompat.cc endif @@ -40,18 +46,21 @@ LIBTARGET := $(LIBNAME:%=%.$(NCCL_MAJOR).$(NCCL_MINOR).$(NCCL_PATCH)) STATICLIBTARGET := $(STATICLIBNAME) PKGTARGET := $(PKGCONFIGFILE) LIBOBJ := $(LIBSRCFILES:%.cc=$(OBJDIR)/%.o) -DEPFILES := $(LIBOBJ:%.o=%.d) +BINOBJ := $(BINSRCFILES:%.cc=$(OBJDIR)/%.o) +DEPFILES := $(LIBOBJ:%.o=%.d) $(BINOBJ:%.o=%.d) LDFLAGS += -L${CUDA_LIB} -l$(CUDARTLIB) -lpthread -lrt -ldl DEVMANIFEST := $(BUILDDIR)/obj/device/manifest ##### rules -build : lib staticlib +build : lib staticlib binary lib : $(INCTARGETS) $(LIBDIR)/$(LIBTARGET) $(PKGDIR)/$(PKGTARGET) staticlib : $(LIBDIR)/$(STATICLIBTARGET) +binary : $(BINDIR)/$(BINNAME) + $(DEVMANIFEST): ALWAYS_REBUILD $(INCTARGETS) $(MAKE) -C ./device @@ -85,6 +94,11 @@ $(LIBDIR)/$(STATICLIBTARGET): $(LIBOBJ) $(DEVMANIFEST) mkdir -p $(LIBDIR) ar cr $@ $(LIBOBJ) $$(cat $(DEVMANIFEST)) +$(BINDIR)/$(BINNAME): $(BINOBJ) + @printf "Linking %-35s > %s\n" $(BINNAME) $@ + mkdir -p $(BINDIR) + $(CXX) $(CXXFLAGS) $^ -o $@ + $(PKGDIR)/nccl.pc : nccl.pc.in mkdir -p $(PKGDIR) @printf "Generating %-35s > %s\n" $< $@ @@ -121,15 +135,17 @@ $(OBJDIR)/%.o : %.cc $(INCTARGETS) clean : $(MAKE) -C device clean - rm -rf ${INCDIR} ${LIBDIR} ${PKGDIR} ${OBJDIR} + rm -rf ${BINDIR} ${INCDIR} ${LIBDIR} ${PKGDIR} ${OBJDIR} install : build mkdir -p $(PREFIX)/lib mkdir -p $(PREFIX)/lib/pkgconfig mkdir -p $(PREFIX)/include + mkdir -p $(PREFIX)/bin cp -P -v $(BUILDDIR)/lib/lib* $(PREFIX)/lib/ cp -P -v $(BUILDDIR)/lib/pkgconfig/* $(PREFIX)/lib/pkgconfig/ cp -v $(BUILDDIR)/include/* $(PREFIX)/include/ + cp -v $(BUILDDIR)/bin/ncclras $(PREFIX)/bin/ FILESTOFORMAT := $(shell find . -name ".\#*" -prune -o \( -name "*.cc" -o -name "*.h" \) -print | grep -v -E 'ibvwrap.h|nvmlwrap.h|gdrwrap.h|nccl.h') # Note that formatting.mk defines a new target so in order to not overwrite the default target, diff --git a/src/bootstrap.cc b/src/bootstrap.cc index 51db5da040..599fa2c049 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -14,6 +14,7 @@ #include "proxy.h" #include "signals.h" // [RCCL] #include "param.h" +#include "ras.h" #define BOOTSTRAP_N_CHECK_ABORT 10000 #define BOOTSTRAP_TAG_CONNECT (0x1 << 31) @@ -71,7 +72,7 @@ static int localIdFromRoot(int rank, int root, int nRanks, int nRoots) { int ir = BOOTSTRAP_PID(root, nRoots); return rank - firstRankFromRoot(ir, nRanks, nRoots); } -// return the number of child for a root, root will be periodized +// Check if the given rank is the first rank from the root static int isFirstFromRoot(int rank, int root, int nRanks, int nRoots) { return (rank == firstRankFromRoot(root, nRanks, nRoots)); } @@ -111,13 +112,13 @@ ncclResult_t bootstrapNetInit() { if (nIfs <= 0) { WARN("Bootstrap : no socket interface found"); pthread_mutex_unlock(&bootstrapNetLock); - return ncclInternalError; + return ncclInvalidUsage; } } char line[SOCKET_NAME_MAXLEN+MAX_IF_NAME_SIZE+2]; - sprintf(line, " %s:", bootstrapNetIfName); + snprintf(line, sizeof(line), " %s:", bootstrapNetIfName); ncclSocketToString(&bootstrapNetIfAddr, line+strlen(line)); - INFO(NCCL_BOOTSTRAP, "Bootstrap : Using%s", line); + INFO(NCCL_BOOTSTRAP, "Bootstrap: Using%s", line); bootstrapNetInitDone = 1; } pthread_mutex_unlock(&bootstrapNetLock); @@ -153,7 +154,7 @@ static ncclResult_t netIsend(ncclNet_t* net, void* sendComm, void* data, int siz int* done) { if (*done) return ncclSuccess; if (!*sendReq) { - NCCLCHECK(net->isend(sendComm, data, size, tag, dataHandle, sendReq)); + NCCLCHECK(net->isend(sendComm, data, (size_t)size, tag, dataHandle, sendReq)); } if (*sendReq) { NCCLCHECK(net->test(*sendReq, done, NULL)); @@ -167,7 +168,8 @@ static ncclResult_t netIrecv(ncclNet_t* net, void* recvComm, void* data, int siz int* done) { if (*done) return ncclSuccess; if (!*recvReq) { - NCCLCHECK(net->irecv(recvComm, 1, &data, &size, &tag, &dataHandle, recvReq)); + size_t size64 = size; + NCCLCHECK(net->irecv(recvComm, 1, &data, &size64, &tag, &dataHandle, recvReq)); } if (*recvReq) { NCCLCHECK(net->test(*recvReq, done, NULL)); @@ -303,7 +305,7 @@ static void* bootstrapRoot(void* rargs) { // if the number of root > 1, we will receive one extra info from the first local_id of the next root n2send = nRankFromRoot(iroot, nranks, nroots); nrecv = n2send + ((nroots > 1) ? 1 : 0); - NCCLCHECKGOTO(ncclCalloc(&rankInfo, nrecv * sizeof(union ringConnectInfo)), res, out); + NCCLCHECKGOTO(ncclCalloc(&rankInfo, nrecv), res, out); NCCLCHECKGOTO(ncclCalloc(&rankAddressesRoot, nrecv), res, out); } @@ -493,29 +495,37 @@ static ncclResult_t netGetDevice(int rank, struct ncclComm* comm, int* dev) { struct netIf userIfs[MAX_OOB_DEVS]; int nUserIfs = parseStringList(userIfEnv, userIfs, MAX_OOB_DEVS); // loop over the device and return the first one matching - int devId = 0; int nDev = 0; NCCLCHECK(comm->ncclNet->devices(&nDev)); + int devId = 0; while (devId < nDev) { ncclNetProperties_t props; comm->ncclNet->getProperties(devId, &props); // check against user specified HCAs/ports - bool found = matchIfList(props.name, props.port, userIfs, nUserIfs, searchExact) ^ searchNot; - if (found) { + if (matchIfList(props.name, props.port, userIfs, nUserIfs, searchExact) ^ searchNot) { + // All plain physical devices have been initialized at this point devOOB = devId; break; } devId++; } if (devOOB == -1) { - WARN("no device found matching NCCL_OOB_NET_IFNAME=%s, ignoring", userIfEnv); - goto noEnv; + if (!searchNot) + WARN("no device found matching %s%s, verify NCCL_OOB_NET_IFNAME", searchExact ? "exactly " : "", userIfEnv); + else + WARN("no device found after excluding %s%s, verify NCCL_OOB_NET_IFNAME", searchExact ? "exactly " : "", userIfEnv); + pthread_mutex_unlock(&bootstrapNetLock); + return ncclInvalidArgument; } } else { - noEnv: // default choice is device 0 devOOB = 0; } + // display info on the chosen device + ncclNetProperties_t props; + ncclResult_t res = comm->ncclNet->getProperties(devOOB, &props); + bool hasProp = res == ncclSuccess; + INFO(NCCL_BOOTSTRAP, "Bootstrap: Using %s:%d", (hasProp) ? props.name : "N/A", (hasProp) ? props.port : -1); } pthread_mutex_unlock(&bootstrapNetLock); } @@ -546,7 +556,8 @@ static ncclResult_t socketRingConnect(ncclSocketAddress* addr, struct ncclSocket } static ncclResult_t ringAllInfo(struct ncclComm* comm, struct bootstrapState* state, union ncclSocketAddress* peerAddresss, - union ncclSocketAddress* peerProxy, uint64_t* peerUDS) { + union ncclSocketAddress* peerProxy, uint64_t* peerUDS, + struct rasRankInit* rasRanks) { ncclResult_t res = ncclSuccess; int rank = comm->rank; int nRanks = comm->nRanks; @@ -554,6 +565,7 @@ static ncclResult_t ringAllInfo(struct ncclComm* comm, struct bootstrapState* st union ncclSocketAddress peerAddress; union ncclSocketAddress peerProxy; uint64_t peerUDS; + struct rasRankInit rasRank; }* ringData = NULL; NCCLCHECK(ncclCalloc(&ringData, nRanks)); @@ -564,6 +576,8 @@ static ncclResult_t ringAllInfo(struct ncclComm* comm, struct bootstrapState* st memcpy(&(ringData[rank].peerProxy), peerProxy + rank, sizeof(union ncclSocketAddress)); if (peerUDS) memcpy(&(ringData[rank].peerUDS), peerUDS + rank, sizeof(uint64_t)); + if (rasRanks) + memcpy(&(ringData[rank].rasRank), rasRanks + rank, sizeof(*rasRanks)); // allgather NCCLCHECKGOTO(bootstrapAllGather(state, ringData, sizeof(struct bootstrapRingData)), res, exit); @@ -576,6 +590,8 @@ static ncclResult_t ringAllInfo(struct ncclComm* comm, struct bootstrapState* st memcpy(peerProxy + irank, &(ringData[irank].peerProxy), sizeof(union ncclSocketAddress)); if (peerUDS) memcpy(peerUDS + irank, &(ringData[irank].peerUDS), sizeof(uint64_t)); + if (rasRanks) + memcpy(rasRanks + irank, &(ringData[irank].rasRank), sizeof(*rasRanks)); } exit: @@ -599,7 +615,10 @@ fail: NCCL_PARAM(StaggerRate, "UID_STAGGER_RATE", 7000); NCCL_PARAM(StaggerThreshold, "UID_STAGGER_THRESHOLD", 256); +NCCL_PARAM(RasEnable, "RAS_ENABLE", 1); + ncclResult_t bootstrapInit(int nHandles, void* handles, struct ncclComm* comm) { + ncclResult_t result = ncclSuccess; int rank = comm->rank; int nranks = comm->nRanks; // char nextPeerHandle[NCCL_NET_HANDLE_MAXSIZE]; @@ -608,6 +627,8 @@ ncclResult_t bootstrapInit(int nHandles, void* handles, struct ncclComm* comm) { struct ncclSocket sock, listenSockRoot; struct extInfo info = {0}; union ringConnectInfo nextPeer; + bool performRasAddRanks = true; + struct rasRankInit* rasRanks = nullptr; uint64_t timers[BOOTSTRAP_INIT_TIME_N] = {0}; @@ -701,23 +722,45 @@ ncclResult_t bootstrapInit(int nHandles, void* handles, struct ncclComm* comm) { // in case of failure, those resources will be free'd when calling bootstrapDestroy, so we can return immediatly NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks)); NCCLCHECK(ncclCalloc(&proxySocket, 1)); - NCCLCHECK(createListenSocket(comm, comm->magic, proxySocket, state->peerProxyAddresses + rank, ncclSocketTypeProxy)); + NCCLCHECKGOTO(createListenSocket(comm, comm->magic, proxySocket, state->peerProxyAddresses + rank, ncclSocketTypeProxy), result, fail); - NCCLCHECK(ncclCalloc(&state->peerProxyAddressesUDS, nranks)); - NCCLCHECK(getUDS(state->peerProxyAddressesUDS + rank)); + NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddressesUDS, nranks), result, fail); + NCCLCHECKGOTO(getUDS(state->peerProxyAddressesUDS + rank), result, fail); // create a socket for others to reach out (P2P) union ncclSocketAddress peerSocketAddress; - NCCLCHECK(createListenSocket(comm, comm->magic, &STATE_LISTEN(state, peerSocket), &peerSocketAddress, ncclSocketTypeBootstrap)); - NCCLCHECK(ncclCalloc(&state->peerP2pAddresses, nranks * sizeof(union ncclSocketAddress))); + NCCLCHECKGOTO(createListenSocket(comm, comm->magic, &STATE_LISTEN(state, peerSocket), &peerSocketAddress, ncclSocketTypeBootstrap), result, fail); + NCCLCHECKGOTO(ncclCalloc(&state->peerP2pAddresses, nranks), result, fail); memcpy(state->peerP2pAddresses + rank, &peerSocketAddress, sizeof(union ncclSocketAddress)); + // Initialize RAS + if (ncclParamRasEnable() == 1) { + // The RAS thread will take care of freeing the memory allocated below. + NCCLCHECK(ncclCalloc(&rasRanks, nranks)); + memcpy(&rasRanks[rank].addr, &bootstrapNetIfAddr, sizeof(rasRanks[rank].addr)); + rasRanks[rank].pid = getpid(); + rasRanks[rank].cudaDev = comm->cudaDev; + rasRanks[rank].nvmlDev = comm->nvmlDev; + if (ncclRasCommInit(comm, rasRanks+rank) != ncclSuccess) { + INFO(NCCL_INIT|NCCL_RAS, "Continuing in spite of a RAS initialization error"); + // We should still participate in the ringAllInfo below as the peers will be waiting for us. + // Just make sure that the address is clearly invalid... + memset(rasRanks+rank, '\0', sizeof(*rasRanks)); + performRasAddRanks = false; + } + } + BOOTSTRAP_PROF_OPEN(timers[BOOTSTRAP_INIT_TIME_RING]); - NCCLCHECK(ringAllInfo(comm, state, state->peerP2pAddresses, state->peerProxyAddresses, state->peerProxyAddressesUDS)); + NCCLCHECKGOTO(ringAllInfo(comm, state, state->peerP2pAddresses, state->peerProxyAddresses, state->peerProxyAddressesUDS, rasRanks), result, fail); BOOTSTRAP_PROF_CLOSE(timers[BOOTSTRAP_INIT_TIME_RING]); // Create the service proxy and get the UDS - NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS)); + NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS), result, fail); + + if (ncclParamRasEnable() == 1 && performRasAddRanks) { + if (ncclRasAddRanks(rasRanks, nranks) != ncclSuccess) + INFO(NCCL_INIT|NCCL_RAS, "Continuing in spite of a RAS initialization error"); + } BOOTSTRAP_PROF_CLOSE(timers[BOOTSTRAP_INIT_TIME_TOTAL]); TRACE(NCCL_BOOTSTRAP, "rank %d nranks %d - DONE", rank, nranks); @@ -727,8 +770,11 @@ ncclResult_t bootstrapInit(int nHandles, void* handles, struct ncclComm* comm) { timers[BOOTSTRAP_INIT_TIME_RECV] / 1e9, timers[BOOTSTRAP_INIT_TIME_RING] / 1e9, timers[BOOTSTRAP_INIT_TIME_DELAY] / 1e9); - - return ncclSuccess; +exit: + return result; +fail: + free(proxySocket); + goto exit; } ncclResult_t bootstrapSplit(uint64_t magic, struct ncclComm* comm, struct ncclComm* parent, int color, int key, int* parentRanks) { @@ -766,6 +812,11 @@ ncclResult_t bootstrapSplit(uint64_t magic, struct ncclComm* comm, struct ncclCo union ncclSocketAddress peerSocketAddress; NCCLCHECK(createListenSocket(comm, comm->magic, &STATE_LISTEN(state, peerSocket), &peerSocketAddress, ncclSocketTypeBootstrap)); + if (ncclParamRasEnable() == 1) { + if (ncclRasCommInit(comm, nullptr) != ncclSuccess) + INFO(NCCL_INIT|NCCL_RAS, "Continuing in spite of a RAS initialization error"); + } + // Get addr from next rank using the parent's connections NCCLCHECKGOTO(bootstrapSend(parent->bootstrap, prev, BOOTSTRAP_TAG_COMMSPLIT, &info, sizeof(union ringConnectInfo)), ret, fail); NCCLCHECKGOTO(bootstrapRecv(parent->bootstrap, next, BOOTSTRAP_TAG_COMMSPLIT, &nextPeer, sizeof(union ringConnectInfo)), ret, fail); @@ -778,14 +829,14 @@ ncclResult_t bootstrapSplit(uint64_t magic, struct ncclComm* comm, struct ncclCo NCCLCHECK(socketRingConnect(&nextPeer.addr, &STATE_RING(state, socket.send), &STATE_LISTEN(state, socket), &STATE_RING(state, socket.recv), comm->magic, state->abortFlag)); } - NCCLCHECKGOTO(ncclCalloc(&state->peerP2pAddresses, nranks * sizeof(union ncclSocketAddress)), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&state->peerP2pAddresses, nranks), ret, fail); memcpy(state->peerP2pAddresses + rank, &peerSocketAddress, sizeof(union ncclSocketAddress)); if (parent->config.splitShare) { /* map local rank to top parent local rank. */ for (int i = 0; i < nranks; ++i) { comm->topParentRanks[i] = parent->topParentRanks[parentRanks[i]]; } - NCCLCHECKGOTO(ringAllInfo(comm, state, state->peerP2pAddresses, NULL, NULL), ret, fail); + NCCLCHECKGOTO(ringAllInfo(comm, state, state->peerP2pAddresses, NULL, NULL, NULL), ret, fail); } else { NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddresses, nranks), ret, fail); NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddressesUDS, nranks), ret, fail); @@ -793,7 +844,7 @@ ncclResult_t bootstrapSplit(uint64_t magic, struct ncclComm* comm, struct ncclCo NCCLCHECKGOTO(ncclCalloc(&proxySocket, 1), ret, fail); NCCLCHECKGOTO(getUDS(state->peerProxyAddressesUDS + rank), ret, fail); NCCLCHECKGOTO(createListenSocket(comm, comm->magic, proxySocket, state->peerProxyAddresses + rank, ncclSocketTypeProxy), ret, fail); - NCCLCHECKGOTO(ringAllInfo(comm, state, state->peerP2pAddresses, state->peerProxyAddresses, state->peerProxyAddressesUDS), ret, fail); + NCCLCHECKGOTO(ringAllInfo(comm, state, state->peerP2pAddresses, state->peerProxyAddresses, state->peerProxyAddressesUDS, NULL), ret, fail); NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS), ret, fail); } @@ -816,7 +867,7 @@ static ncclResult_t socketConnect(void* commState, int peer, int tag, struct ncc struct bootstrapState* state = (struct bootstrapState*)commState; struct socketAckInfo ack = (struct socketAckInfo){.rank = state->rank, .tag = tag}; - NCCLCHECKGOTO(ncclSocketInit(sock, state->peerP2pAddresses + peer, state->magic, ncclSocketTypeBootstrap), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(sock, state->peerP2pAddresses + peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail); NCCLCHECKGOTO(ncclSocketConnect(sock), ret, fail); NCCLCHECKGOTO(socketSend(sock, &ack, sizeof(struct socketAckInfo)), ret, fail); return ncclSuccess; diff --git a/src/collectives.cc b/src/collectives.cc index 5207589476..f005488a2d 100644 --- a/src/collectives.cc +++ b/src/collectives.cc @@ -47,16 +47,12 @@ const char* ncclDatatypeToString(ncclDataType_t type) { case ncclUint32: return "ncclUint32"; case ncclInt64: return "ncclInt64"; case ncclUint64: return "ncclUint64"; -#if defined(RCCL_FLOAT8) - case ncclFp8E4M3: return "ncclFp8E4M3"; - case ncclFp8E5M2: return "ncclFp8E5M2"; -#endif case ncclFloat16: return "ncclFloat16"; case ncclFloat32: return "ncclFloat32"; case ncclFloat64: return "ncclFloat64"; -#if defined(RCCL_BFLOAT16) case ncclBfloat16: return "ncclBfloat16"; -#endif + case ncclFloat8e4m3: return "ncclFloat8e4m3"; + case ncclFloat8e5m2: return "ncclFloat8e5m2"; default: return "Unknown"; } } @@ -116,8 +112,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream); } - NCCLCHECK(ncclEnqueueCheck(&info)); - return ncclSuccess; + return ncclEnqueueCheck(&info); } NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count, @@ -156,8 +151,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream); } - NCCLCHECK(ncclEnqueueCheck(&info)); - return ncclSuccess; + return ncclEnqueueCheck(&info); } RCCL_PARAM(AllToAllPivotEnable, "ALL_TO_ALL_PIVOT_ENABLE", 0); @@ -309,8 +303,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream); } - NCCLCHECK(ncclEnqueueCheck(&info)); - return ncclSuccess; + return ncclEnqueueCheck(&info); } /* Deprecated original "in place" function, similar to MPI */ NCCL_API(ncclResult_t, ncclBcast, void* buff, size_t count, ncclDataType_t datatype, int root, @@ -318,8 +311,7 @@ NCCL_API(ncclResult_t, ncclBcast, void* buff, size_t count, ncclDataType_t datat ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) { NCCLCHECK(Recorder::instance().record(rrBcast, buff, buff, count, datatype, comm, stream, root)); - NCCLCHECK(ncclBroadcast(buff, buff, count, datatype, root, comm, stream)); - return ncclSuccess; + return ncclBroadcast(buff, buff, count, datatype, root, comm, stream); } NCCL_API(ncclResult_t, ncclGather, const void* sendbuff, void* recvbuff, size_t sendcount, @@ -405,8 +397,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count, count, datatype, root, 0, op, mscclFuncReduce, comm, stream); } - NCCLCHECK(ncclEnqueueCheck(&info)); - return ncclSuccess; + return ncclEnqueueCheck(&info); } NCCL_API(ncclResult_t, ncclReduceScatter, const void* sendbuff, void* recvbuff, size_t recvcount, @@ -445,8 +436,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream); } - NCCLCHECK(ncclEnqueueCheck(&info)); - return ncclSuccess; + return ncclEnqueueCheck(&info); } NCCL_API(ncclResult_t, ncclScatter, const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, int root, @@ -532,12 +522,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream); } - ncclResult_t ret; - NCCLCHECK(ncclGroupStart()); - NCCLCHECKGOTO(ncclEnqueueCheck(&info), ret, exit); -exit: - NCCLCHECK(ncclGroupEnd()); - return ret; + return ncclEnqueueCheck(&info); } NCCL_API(ncclResult_t, ncclRecv, void* recvbuff, size_t count, ncclDataType_t datatype, int peer, @@ -563,10 +548,5 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream); } - ncclResult_t ret; - NCCLCHECK(ncclGroupStart()); - NCCLCHECKGOTO(ncclEnqueueCheck(&info), ret, exit); -exit: - NCCLCHECK(ncclGroupEnd()); - return ret; + return ncclEnqueueCheck(&info); } diff --git a/src/debug.cc b/src/debug.cc index be2a284d41..ad1329b524 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -8,6 +8,7 @@ #include "nccl_net.h" #include #include +#include #include #include #include @@ -89,6 +90,8 @@ static void ncclDebugInit() { mask = NCCL_REG; } else if (strcasecmp(subsys, "PROFILE") == 0) { mask = NCCL_PROFILE; + } else if (strcasecmp(subsys, "RAS") == 0) { + mask = NCCL_RAS; } else if (strcasecmp(subsys, "VERBS") == 0) { mask = NCCL_VERBS; } else if (strcasecmp(subsys, "ALL") == 0) { @@ -226,6 +229,19 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file } } +NCCL_API(void, ncclResetDebugInit); +void ncclResetDebugInit() { + // Cleans up from a previous ncclDebugInit() and reruns. + // Use this after changing NCCL_DEBUG and related parameters in the environment. + __atomic_load_n(&ncclDebugLevel, __ATOMIC_ACQUIRE); + if (ncclDebugFile != stdout) { + fclose(ncclDebugFile); + ncclDebugFile = stdout; + } + ncclDebugLevel = -1; + ncclDebugInit(); +} + NCCL_PARAM(SetThreadName, "SET_THREAD_NAME", 0); void ncclSetThreadName(pthread_t thread, const char *fmt, ...) { diff --git a/src/device/all_gather.h b/src/device/all_gather.h index fbb36b512b..c54c90d1f3 100644 --- a/src/device/all_gather.h +++ b/src/device/all_gather.h @@ -10,7 +10,7 @@ #include "primitives.h" namespace { - template + template #if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__) __device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #else @@ -20,10 +20,10 @@ namespace { ncclRing *ring = &ncclShmem.channel.ring; const int *ringRanks = ring->userRanks; const int nranks = ncclShmem.comm.nRanks; - size_t count, partOffset, partCount, chunkCount; + ssize_t count, partOffset, partCount, chunkCount; ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &partOffset, &partCount, &chunkCount); - size_t offset; - size_t dataOffset; + ssize_t offset; + ssize_t dataOffset; int nelem; int rankDest; @@ -51,103 +51,126 @@ namespace { ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); } #endif - + int workNthreads; T *inputBuf = (T*)work->sendbuff; T *outputBuf = (T*)work->recvbuff; - // Coverity reports that the callee treats &ring->next as an array. However, due to the use of - // FanSymmetric<1>, only the first element is ever accessed, so it's fine. - // coverity[callee_ptr_arith:FALSE] - Primitives, 0, Proto, 0> prims - (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work); + + // If isNetOffload == true, we only use 1 warp to drive Ring algo/network communication + // and the rest of warps proceed to copy src data into dst buffer in parallel when AG + // is not in-place. + if (isNetOffload) { + workNthreads = WARP_SIZE; + chunkCount = NCCL_MAX_NET_SIZE; + } else { + workNthreads = nthreads; + } + + if (tid < workNthreads) { + // Coverity reports that the callee treats &ring->next as an array. However, due to the use of + // FanSymmetric<1>, only the first element is ever accessed, so it's fine. + // coverity[callee_ptr_arith:FALSE] + Primitives, 0, Proto, 0, isNetOffload> prims + (tid, workNthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work, NULL, isNetOffload ? NCCL_MAX_NET_SIZE : 0); #if defined(ENABLE_NPKIT) - if (tid == 0) { - prims.npKitCtxIdx = npKitCtxIdx; - } + if (tid == 0) { + prims.npKitCtxIdx = npKitCtxIdx; + } #endif + for (size_t elemOffset = 0; elemOffset < partCount; elemOffset += chunkCount) { + /////////////// begin AllGather steps /////////////// + nelem = min(chunkCount, partCount - elemOffset); + dataOffset = partOffset + elemOffset; - for (size_t elemOffset = 0; elemOffset < partCount; elemOffset += chunkCount) { - /////////////// begin AllGather steps /////////////// - nelem = min(chunkCount, partCount - elemOffset); - dataOffset = partOffset + elemOffset; - - // step 0: push data to next GPU - rankDest = ringRanks[0]; - offset = dataOffset + rankDest * count; + // step 0: push data to next GPU + rankDest = ringRanks[0]; + offset = dataOffset + rankDest * count; #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_SEND_ENTRY) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_SEND_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - prims.npKitDataProcessTotalTime = 0; - } + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_SEND_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + prims.npKitDataProcessTotalTime = 0; + } #endif - if (inputBuf + dataOffset == outputBuf + offset) { // In place - prims.directSend(dataOffset, offset, nelem); - } else { - prims.directCopySend(dataOffset, offset, nelem); - } + if ((inputBuf + dataOffset == outputBuf + offset) || isNetOffload) { // In place or onePPN + prims.directSend(dataOffset, offset, nelem); + } else { + prims.directCopySend(dataOffset, offset, nelem); + } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_SEND_EXIT) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_SEND_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_SEND_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } #endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_ENTRY) - if (tid == 0 && nranks > 2) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_ENTRY, nelem*(nranks-2)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - prims.npKitDataProcessTotalTime = 0; - } + if (tid == 0 && nranks > 2) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_ENTRY, nelem*(nranks-2)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + prims.npKitDataProcessTotalTime = 0; + } #endif - // k-2 steps: copy to next GPU - for (int j=1; j 2) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_EXIT, nelem*(nranks-2)*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } + if (tid == 0 && nranks > 2) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_RECV_COPY_SEND_EXIT, nelem*(nranks-2)*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } #endif - // Make final copy from buffer to dest. - rankDest = ringRanks[1]; - offset = dataOffset + rankDest * count; + // Make final copy from buffer to dest. + rankDest = ringRanks[1]; + offset = dataOffset + rankDest * count; #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_ENTRY) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - prims.npKitDataProcessTotalTime = 0; - } + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + prims.npKitDataProcessTotalTime = 0; + } #endif - // Final wait/copy. - prims.directRecv(offset, offset, nelem); - + // Final wait/copy. + prims.directRecv(offset, offset, nelem); + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); - } + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } #endif - } + } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_EXIT) - if (tid == 0) { - NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_EXIT, count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), - ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + if (tid == 0) { + NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_EXIT, count*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), + ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); + } +#endif + } else if (inputBuf != outputBuf + ringRanks[0] * count) { + inputBuf = inputBuf + partOffset; + outputBuf = outputBuf + partOffset + ringRanks[0] * count; + reduceCopy + (tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, partCount); } +#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__) + // we have to wait for all warps before we can proceed to the next work; + // otherwise, we can have contention if next work will use the outputBuf + // in this work. We use bar 14 to avoid conflicts with prims barrier and + // __syncthread(). + if (isNetOffload) barrier_sync(14, nThreads); #endif } } @@ -155,8 +178,15 @@ namespace { template struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - using Proto = ProtoSimple; - runRing(tid, nthreads, work); +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + bool isNetOffload = false; +#else + bool isNetOffload = work->isOneRPN && work->netRegUsed; +#endif + if (isNetOffload) + runRing, true>(tid, nthreads, work); + else + runRing, false>(tid, nthreads, work); } }; @@ -186,7 +216,7 @@ struct RunWorkCollsendbuff; T *outputBuf = (T*)work->recvbuff; Primitives, 0, Proto, 0> prims - (tid, nthreads, NULL, NULL, inputBuf, outputBuf, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, nullptr, false, false, 0, primsModePatAg); + (tid, nthreads, NULL, NULL, inputBuf, outputBuf, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, nullptr, nullptr, 0, primsModePatAg); PatAGAlgorithm patAlgo(chunkCount*sizeof(T), NCCL_STEPS, channelOffset, channelOffset + channelCount, count, chunkCount, rank, nranks); int last = 0; @@ -227,6 +257,7 @@ struct RunWorkCollnHeads * count, nelem, count, -1, 0); } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 } else if (tid < tidEndBcast) { // Bcast through NVLS using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>; @@ -238,6 +269,7 @@ struct RunWorkColl Coverity think prims.index can be greater than 1 } } else { /* direct allgather */ @@ -294,11 +326,11 @@ struct RunWorkCollchannelLo; char* inbuf = (char*)work->sendbuff; char* outbuf = (char*)work->recvbuff; - ssize_t sizePerRank = work->collnet.count*sizeof(T); - bool inPlace = (inbuf == outbuf + ncclShmem.comm.rank*sizePerRank); + ssize_t countPerRank = work->collnet.count*sizeof(T); + bool inPlace = (inbuf == outbuf + ncclShmem.comm.rank*countPerRank); - ssize_t railAllBeg = min(railGridOffset + part*chunkSize, nNodes*sizePerRank); - ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + ssize_t railAllBeg = min(railGridOffset + part*chunkSize, nNodes*countPerRank); + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*countPerRank); int railAllSize = railAllEnd - railAllBeg; if (tid < nDsts) dstSizes[tid] = railAllSize; @@ -311,15 +343,15 @@ struct RunWorkColl (tid, tn, 0, nullptr, false, - /*nSrcs=*/1, [=]__device__(int s/*==0*/) -> void* { - return work->regUsed && (recvDirectFlag & NCCL_DIRECT_READ) ? (char*)srcPtrs[src] + userOneBeg : (char*)srcPtrs[src] + railAllOffset; - }, - /*nDsts=*/outIsDst+nDsts, [=]__device__(int d) -> void* { - return d < outIsDst ? outbuf + userOneBeg - : work->regUsed && (sendDirectFlag & NCCL_DIRECT_WRITE) ? (char*)dstPtrs[d-outIsDst] + userOneBeg - : (char*)dstPtrs[d-outIsDst] + railAllOffset; - }, - delta); + /*nSrcs=*/1, [=]__device__(int s/*==0*/) -> void* { + return work->regUsed && (recvDirectFlag & NCCL_P2P_READ) ? (char*)srcPtrs[src] + userOneBeg : (char*)srcPtrs[src] + railAllOffset; + }, + /*nDsts=*/outIsDst+nDsts, [=]__device__(int d) -> void* { + return d < outIsDst ? outbuf + userOneBeg + : work->regUsed && (sendDirectFlag & NCCL_P2P_WRITE) ? (char*)dstPtrs[d-outIsDst] + userOneBeg + : (char*)dstPtrs[d-outIsDst] + railAllOffset; + }, + delta); } railAllOffset += delta; node += 1; @@ -352,8 +384,9 @@ struct RunWorkCollchannelHi - work->channelLo + 1; struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; int const &nNodes = ncclShmem.comm.nNodes; - ssize_t sizePerRank = work->collnet.count*sizeof(T); + ssize_t countPerRank = work->collnet.count; size_t chunkSize = work->collnet.chunkCount; + const int hasDn = (direct->down[0] >= 0) ? 1 : 0; bool isMultiRail = (direct->nHeads > 1); int nWarps1 = 1; int nWarps2 = (isMultiRail ? 2 : 1); @@ -367,9 +400,12 @@ struct RunWorkCollregUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (tid == 0) { - int steps = (int)divUp(nNodes * sizePerRank * sizeof(T), NCCL_MAX_COLLNET_SIZE); + // If this rank has local peers (i.e, hasDn == true), we cannot offload all data to network. + // In this case, steps should be computed based on chunkSize and so on; otherwise, we just + // bump the step by 1 to kick off collnet progress. + int steps = hasDn ? (int)divUp(nNodes * countPerRank, nChannels * chunkSize) : 1; Primitives, /*Direct=*/0, Proto, 0>::sendPeerNotify(direct->out, 1, steps); } __syncwarp(); @@ -378,11 +414,11 @@ struct RunWorkColl, /*Direct=*/0, Proto, 0> prims(tid, tn, nullptr, &direct->out, work->sendbuff, nullptr, /*redOpArg=*/0, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t railGridOffset = 0; railGridOffset < nNodes * sizePerRank; railGridOffset += nChannels * chunkSize) { + for (ssize_t railGridOffset = 0; railGridOffset < nNodes * countPerRank; railGridOffset += nChannels * chunkSize) { ssize_t railAllBeg = railGridOffset + part * chunkSize; - ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes * sizePerRank); - ssize_t railOneBeg = ncclShmem.comm.node * sizePerRank; - ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes * countPerRank); + ssize_t railOneBeg = ncclShmem.comm.node * countPerRank; + ssize_t railOneEnd = railOneBeg + countPerRank; ssize_t beg = max(railAllBeg, railOneBeg); ssize_t end = min(railAllEnd, railOneEnd); prims.send(beg - railOneBeg, max(ssize_t(0), end - beg)); @@ -394,10 +430,9 @@ struct RunWorkCollregUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed && !hasDn) { if (tid == 0) { - int steps = (int)divUp(nNodes * sizePerRank * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/0, Proto, 0>::recvPeerNotify(direct->out, 0, steps); + Primitives, /*Direct=*/0, Proto, 0>::recvPeerNotify(direct->out, 0, 1); } __syncwarp(); } else { @@ -405,7 +440,7 @@ struct RunWorkColl, /*Direct=*/1, Proto, 0> prims(tid, tn, &direct->out, direct->heads + 1, nullptr, work->recvbuff, /*redOpArg=*/0, 1 * Proto::MaxGroupWidth, 0, 0, work); - for (ssize_t railGridOffset = 0; railGridOffset < nNodes * sizePerRank; railGridOffset += nChannels * chunkSize) { + for (ssize_t railGridOffset = 0; railGridOffset < nNodes * countPerRank; railGridOffset += nChannels * chunkSize) { Scatterer scat; scat.work = work; scat.chunkSize = chunkSize; @@ -423,7 +458,7 @@ struct RunWorkColl, /*Direct=*/1, Proto, 0> prims(tid, tn, direct->heads+1, nullptr, nullptr, work->recvbuff, /*redOpArg=*/0, 2*Proto::MaxGroupWidth, 0, 0, work); - for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + for (ssize_t railGridOffset=0; railGridOffset < nNodes*countPerRank; railGridOffset += nChannels*chunkSize) { Scatterer scat; scat.work = work; scat.chunkSize = chunkSize; diff --git a/src/device/all_reduce.h b/src/device/all_reduce.h index 7144e427d1..835ab708d4 100644 --- a/src/device/all_reduce.h +++ b/src/device/all_reduce.h @@ -166,7 +166,7 @@ namespace { chunkOffset = chunk * chunkCount; offset = gridOffset + elemOffset + chunkOffset; nelem = (int)min(chunkCount, remCount - chunkOffset); - prims.directRecvCopyDirectSend(offset, nelem); + prims.directRecvCopyDirectSend(offset, offset, nelem); } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_REDUCE_RING_DIRECT_RECV_COPY_SEND_EXIT) @@ -336,7 +336,7 @@ namespace { for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { offset = gridOffset + elemOffset; nelem = min(chunkCount, channelCount - elemOffset); - prims.directRecvCopyDirectSend(offset, nelem); + prims.directRecvCopyDirectSend(offset, offset, nelem); } } @@ -535,7 +535,7 @@ namespace { for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { offset = gridOffset + elemOffset; nelem = min(chunkCount, channelCount - elemOffset); - prims.directRecvCopyDirectSend(offset, nelem); + prims.directRecvCopyDirectSend(offset, offset, nelem); } } @@ -599,22 +599,30 @@ struct RunWorkColl; if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) { // Scatter - Primitives, /*Direct=*/0, Proto, 0> + Primitives, /*Direct=*/1, Proto, 0> prims(tid-tidStartScatter, nThreadsScatter, NULL, direct->up, work->sendbuff, work->recvbuff, - work->redOpArg, 2*Proto::MaxGroupWidth, 1, 1); + work->redOpArg, 2*Proto::MaxGroupWidth, 1, 1, work); + ssize_t offsetBase, peerOffset; + ssize_t maxNelems; + if (work->netRegUsed) { + offsetBase = bid * chunkSize; + maxNelems = size; // never be the min + peerOffset = nChannels * chunkSize; + } else { + offsetBase = bid * direct->nHeads * chunkSize; + maxNelems = direct->nHeads * chunkSize; + peerOffset = chunkSize; + } + // For collnet UB case, we need to organize buffers differently for contiguous buffer access + // across channels. This access pattern should be consistent with code in coll_net.cc for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; - int nelem = min(direct->nHeads*chunkSize, size-offset); - if (work->regUsed) { - prims.directScatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); - } else { - prims.scatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); - } + ssize_t offset = gridOffset + offsetBase; + ssize_t nelem = min(maxNelems, size - offset); + prims.scatter(offset, nelem, chunkSize, peerOffset, direct->headRank, direct->shift); } // Coverity complains about a possible overrun inside the destructor of "prims", but that's actually // a false positive. @@ -622,24 +630,20 @@ struct RunWorkColl= tidStartReduce && direct->out != -1) { if (hasDn) { // Reduce, send to network - Primitives, /*Direct=*/0, Proto, 0> + Primitives, /*Direct=*/1, Proto, 0> prims(tid-tidStartReduce, nThreadsReduce, direct->down, &direct->out, work->sendbuff, work->recvbuff, - work->redOpArg, 3*Proto::MaxGroupWidth, 1, 1); + work->redOpArg, 3*Proto::MaxGroupWidth, 1, 1, work); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; - int nelem = min(chunkSize, size-offset); - if (work->regUsed) { - prims.directRecvReduceSend(offset, nelem); - } else { - prims.recvReduceSend(offset, nelem); - } + ssize_t offset = work->netRegUsed ? gridOffset + (bid + direct->headRank * nChannels) * chunkSize + : gridOffset + (bid * direct->nHeads + direct->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.recvReduceDirectSend(offset, offset, nelem); } } else { // Directly send to network - if (work->regUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (tid == tidStartReduce) { - int steps = (int)divUp(size * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/0, Proto, 0>::sendPeerNotify(direct->out, 1, steps); + Primitives, /*Direct=*/0, Proto, 0>::sendPeerNotify(direct->out, 1, 1); } __syncwarp(); } else { @@ -647,8 +651,8 @@ struct RunWorkCollout, work->sendbuff, work->recvbuff, work->redOpArg, 3*Proto::MaxGroupWidth, 1, 1); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; - int nelem = min(chunkSize, size-offset); + ssize_t offset = gridOffset + (bid * direct->nHeads + direct->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); prims.send(offset, nelem); } } @@ -657,11 +661,22 @@ struct RunWorkColl, /*Direct=*/0, Proto, 0> prims(tid, nThreadsGather, direct->up, NULL, work->sendbuff, work->recvbuff, - work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work); + work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work); + ssize_t offsetBase, peerOffset; + ssize_t maxNelems; + if (work->netRegUsed) { + offsetBase = bid * chunkSize; + maxNelems = size; // never be the min + peerOffset = nChannels * chunkSize; + } else { + offsetBase = bid * direct->nHeads * chunkSize; + maxNelems = direct->nHeads * chunkSize; + peerOffset = chunkSize; + } for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*direct->nHeads*chunkSize; - int nelem = min(direct->nHeads*chunkSize, size-offset); - prims.directGather(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); + ssize_t offset = gridOffset + offsetBase; + ssize_t nelem = min(maxNelems, size - offset); + prims.directGather(offset, nelem, chunkSize, peerOffset, direct->headRank, direct->shift); } } else if (tid >= tidStartBcast && tid < tidStartScatter && direct->out != -1) { if (hasDn) { @@ -673,15 +688,15 @@ struct RunWorkCollout, direct->down, work->sendbuff, work->recvbuff, work->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, work); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid*direct->nHeads+direct->headRank)*chunkSize; - int nelem = min(chunkSize, size-offset); - prims.recvCopyDirectSend(offset, nelem, /*postOp=*/true); + ssize_t offset = work->netRegUsed ? gridOffset + (bid + direct->headRank * nChannels) * chunkSize + : gridOffset + (bid * direct->nHeads + direct->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.directRecvCopyDirectSend(offset, offset, nelem, /*postOp=*/true); } } else { - if (work->regUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (tid == tidStartBcast) { - int steps = (int)divUp(size * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/0, Proto, 0>::recvPeerNotify(direct->out, 0, steps); + Primitives, /*Direct=*/0, Proto, 0>::recvPeerNotify(direct->out, 0, 1); } __syncwarp(); } else { @@ -725,8 +740,6 @@ struct RunWorkCollnHeads * chunkSize; - ssize_t offset; - int nelem; int remCount = channelCount%(nvls->nHeads*chunkSize); int lastChunkSize = alignUp(divUp(remCount, nvls->nHeads), 16384/sizeof(T)); @@ -738,8 +751,8 @@ struct RunWorkCollredOpArg, 0 * Proto::MaxGroupWidth, 1, 1); for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { if (channelCount - elemOffset < loopCount) chunkSize = lastChunkSize; - offset = gridOffset + elemOffset; - nelem = work->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + ssize_t offset = gridOffset + elemOffset; + int nelem = work->regUsed ? 0 : min(loopCount, channelCount - elemOffset); prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); } } else if (tid < tidEndGather) { @@ -750,8 +763,8 @@ struct RunWorkCollredOpArg, 1 * Proto::MaxGroupWidth, 1, 1); for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { if (channelCount - elemOffset < loopCount) chunkSize = lastChunkSize; - offset = gridOffset + elemOffset; - nelem = work->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + ssize_t offset = gridOffset + elemOffset; + int nelem = work->regUsed ? 0 : min(loopCount, channelCount - elemOffset); prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); } } else if (tid < tidEndReduce) { @@ -761,7 +774,8 @@ struct RunWorkColldown, &nvls->down, NULL, NULL, work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, work); for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { - ssize_t chunkOffset; + ssize_t chunkOffset, offset; + int nelem; if (channelCount - elemOffset < loopCount) chunkSize = lastChunkSize; chunkOffset = elemOffset + nvls->headRank * chunkSize; offset = gridOffset + chunkOffset; @@ -787,6 +801,7 @@ struct RunWorkCollregUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 } else if (tid < tidEndGather) { // Gather using Proto = ProtoSimple<1, 1, COLL_UNROLL>; @@ -795,38 +810,23 @@ struct RunWorkCollredOpArg, 1 * Proto::MaxGroupWidth, 1, 1); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; - int nelem = work->regUsed ? 0 :min(nvls->nHeads * chunkSize, size - offset); + int nelem = work->regUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); } } else if (tid < tidEndReduce && nvls->headRank != -1) { - if (!hasOut) { - // Reduce, broadcast through NVLS - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>; - // Coverity complains about a possible overrun inside the class below, but that's actually - // a false positive. - // coverity[identity_transfer:FALSE] - Primitives, /*Direct=*/1, Proto, 0> - prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, NULL, NULL, - work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, work); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.directRecvDirectSend(offset, offset, nelem); - } - } else { - // Reduce, send to network - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; - // Coverity complains about a possible overrun inside the class below, but that's actually - // a false positive. - // coverity[identity_transfer:FALSE] - Primitives, /*Direct=*/1, Proto, 0> - prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->out, NULL, NULL, - work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 1, work); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.directRecvDirectSend(offset, offset, nelem); - } + // Reduce, send to network + using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; + // Coverity complains about a possible overrun inside the class below, but that's actually + // a false positive. + // coverity[identity_transfer:FALSE] + Primitives, /*Direct=*/1, Proto, 0> + prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->out, NULL, work->recvbuff, + work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 1, work); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = work->regUsed && work->netRegUsed ? gridOffset + (nvls->headRank * nChannels + bid) * chunkSize + : gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.directRecvDirectSend(offset, offset, nelem); } } else if (tid < tidEndBcast && nvls->headRank != -1) { // Recv from network, broadcast @@ -835,10 +835,11 @@ struct RunWorkColl, /*Direct=*/1, Proto, 0> - prims(tid - tidEndReduce, nThreadsBcast, &nvls->out, &nvls->down, NULL, NULL, + prims(tid - tidEndReduce, nThreadsBcast, &nvls->out, &nvls->down, NULL, work->recvbuff, work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, work); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; + ssize_t offset = work->regUsed && work->netRegUsed ? gridOffset + (nvls->headRank * nChannels + bid) * chunkSize + : gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; int nelem = min(chunkSize, size - offset); prims.directRecvDirectSend(offset, offset, nelem); } @@ -991,10 +992,9 @@ struct RunWorkCollregUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (groupTid == 0) { - int steps = (int)divUp(size * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/1, Proto, 0>::sendPeerNotify(send, connIndex, steps); + Primitives, /*Direct=*/1, Proto, 0>::sendPeerNotify(send, connIndex, 1); } __syncwarp(); } else { @@ -1004,8 +1004,10 @@ struct RunWorkColl Coverity think prims.index can be greater than 1 prims.directSend(offset, offset, nelem); } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 } } else { Primitives, /*Direct=*/1, Proto, 0> @@ -1014,18 +1016,19 @@ struct RunWorkColl Coverity think prims.index can be greater than 1 prims.directRecvReduceDirectSend(offset, offset, nelem); } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 } } else { if (recv == nranks) { // I'm the first in the broadcast chain, I need to perform the division (postOp) if (send == -1) { - if (work->regUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (groupTid == 0) { - int steps = (int)divUp(size * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/1, Proto, 0>::recvPeerNotify(recv, connIndex, steps); + Primitives, /*Direct=*/1, Proto, 0>::recvPeerNotify(recv, connIndex, 1); } __syncwarp(); } else { @@ -1051,7 +1054,7 @@ struct RunWorkColluserRanks[0]; const int nextRank = ring->userRanks[1]; const int root = work->root; - size_t size; - size_t chunkCount; - size_t channelCount; - size_t gridOffset; + ssize_t size; + ssize_t chunkCount; + ssize_t channelCount; + ssize_t gridOffset; ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount); size_t offset; int nelem; + int workNthreads; + bool isNetOffload = work->isOneRPN && work->netRegUsed; #if defined(ENABLE_NPKIT) int npKitCtxIdx = bid; @@ -55,39 +57,51 @@ namespace { T *inputBuf = (T*)work->sendbuff; T *outputBuf = (T*)work->recvbuff; - // Coverity reports that the callee treats &ring->next as an array. However, due to the use of - // FanSymmetric<1>, only the first element is ever accessed, so it's fine. - // coverity[callee_ptr_arith:FALSE] - Primitives, 0, Proto, 0> - prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work); + workNthreads = isNetOffload ? WARP_SIZE : nthreads; + + if (tid < workNthreads) { + // Coverity reports that the callee treats &ring->next as an array. However, due to the use of + // FanSymmetric<1>, only the first element is ever accessed, so it's fine. + // coverity[callee_ptr_arith:FALSE] + Primitives, 0, Proto, 0> + prims(tid, workNthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work); #if defined(ENABLE_NPKIT) - if (tid == 0) { - prims.npKitCtxIdx = npKitCtxIdx; - } + if (tid == 0) { + prims.npKitCtxIdx = npKitCtxIdx; + } #endif - for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { - offset = gridOffset + elemOffset; - nelem = min(chunkCount, channelCount - elemOffset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); - if (rank == root) { - if (inputBuf == outputBuf) { - prims.directSend(offset, offset, nelem); + if (rank == root) { + if (inputBuf == outputBuf || isNetOffload) { + prims.directSend(offset, offset, nelem); + } else { + prims.directCopySend(offset, offset, nelem); + } + } else if (nextRank == root) { + prims.directRecv(offset, offset, nelem); } else { - prims.directCopySend(offset, offset, nelem); + prims.directRecvCopyDirectSend(offset, offset, nelem); } - } else if (nextRank == root) { - prims.directRecv(offset, offset, nelem); - } else { - prims.directRecvCopyDirectSend(offset, nelem); } + } else if (inputBuf != outputBuf && rank == root) { + inputBuf = inputBuf + gridOffset; + outputBuf = outputBuf + gridOffset; + reduceCopy + (tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, channelCount); } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_EXIT) if (tid == 0) { NpKit::CollectGpuEvent(NPKIT_EVENT_BROADCAST_RING_EXIT, size*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx); } +#endif +#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__) + if (isNetOffload) barrier_sync(14, nThreads); #endif } } diff --git a/src/device/common.h b/src/device/common.h index 010c0c9cc9..97eb1a04d3 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -59,8 +59,8 @@ collTrace->p2p.recvConnIndex = p2pWork->recvConnIndex; \ collTrace->p2p.sendProtoLL = p2pWork->sendProtoLL; \ collTrace->p2p.recvProtoLL = p2pWork->recvProtoLL; \ - collTrace->p2p.sendRegistered = p2pWork->sendRegistered; \ - collTrace->p2p.recvRegistered = p2pWork->recvRegistered; \ + collTrace->p2p.sendRegistered = p2pWork->sendNetReg; \ + collTrace->p2p.recvRegistered = p2pWork->recvNetReg; \ collTrace->p2pOpCount[0] = p2pWork->sendOpCount; \ collTrace->p2pOpCount[1] = p2pWork->recvOpCount; \ collTrace->type = (launch_type) | ncclCollTraceP2pElemType; \ @@ -628,6 +628,9 @@ __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONST __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); #endif +#define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \ + __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {} + #ifdef USE_INDIRECT_FUNCTION_CALL #define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \ __device__ void ncclDevFunc_##suffix() { \ diff --git a/src/device/common_kernel.h b/src/device/common_kernel.h index c7d6d44d28..a601618aef 100644 --- a/src/device/common_kernel.h +++ b/src/device/common_kernel.h @@ -67,19 +67,23 @@ __device__ __forceinline__ void reduceCopyPacks( uintptr_t minSrcs[MinSrcs + !MinSrcs]; uintptr_t minDsts[MinDsts + !MinDsts]; #pragma unroll - for (int s=0; s < MinSrcs; s++) + for (int s=0; s < MinSrcs; s++) { minSrcs[s] = cvta_to_global(srcPtrFn(s)) + threadBytesBehind; + } + #pragma unroll - for (int d=0; d < MinDsts; d++) + for (int d=0; d < MinDsts; d++) { // Yes, for some template arguments this code will be unreachable. That's fine. // coverity[dead_error_line] minDsts[d] = cvta_to_global(dstPtrFn(d)) + threadBytesBehind; + } // We dictate loop termination condition according to whether partial hunks // can be handled or not. while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) { BytePack acc[Unroll]; + // minSrcs[0] cannot be nullptr so we always process it { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0); #pragma unroll Unroll for (int u=0; u < Unroll; u++) { @@ -165,7 +169,8 @@ __device__ __forceinline__ void reduceCopyPacks( } } for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) { - uintptr_t dst = cvta_to_global(dstPtrFn(d)) + threadBytesBehind; + uintptr_t dstPtr = cvta_to_global(dstPtrFn(d)); + uintptr_t dst = dstPtr + threadBytesBehind; #pragma unroll Unroll for (int u=0; u < Unroll; u++) { st_global(dst, acc[u]); @@ -175,11 +180,15 @@ __device__ __forceinline__ void reduceCopyPacks( nWarps = nThreads/WARP_SIZE; #pragma unroll - for (int s=0; s < MinSrcs; s++) minSrcs[s] += (nWarps-1)*BytePerHunk; + for (int s=0; s < MinSrcs; s++) { + minSrcs[s] += (nWarps-1)*BytePerHunk; + } #pragma unroll // Yes, for some template arguments this code will be unreachable. That's fine. // coverity[dead_error_line] - for (int d=0; d < MinDsts; d++) minDsts[d] += (nWarps-1)*BytePerHunk; + for (int d=0; d < MinDsts; d++) { + minDsts[d] += (nWarps-1)*BytePerHunk; + } threadBytesBehind += nWarps*BytePerHunk; threadBytesAhead -= nWarps*BytePerHunk; nHunksAhead -= nWarps; diff --git a/src/device/generate.py b/src/device/generate.py index b97292da63..bcc15aefb6 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -6,7 +6,7 @@ import subprocess # Order of redops, tys, protos, algos must match src/include/device.h all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"] all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] -all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16", "f8", "bf8"] +all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] all_protos = ["LL","LL128","SIMPLE"] all_algos = ["TREE","RING"] all_unroll = ["1", "2", "4"] @@ -253,7 +253,7 @@ def equivalent_primary(coll, algo, proto, redop, ty, unroll): unroll = str(coll_unroll) if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned - if redop in ("Sum","Prod","PreMulSum") and ty[0]=="i": + if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i": ty = "u"+ty[1:] # map signed integer min/max to unsigned for non-NVLS elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo): @@ -510,8 +510,8 @@ ty_to_cxx = { "f32": "float", "f64": "double", "bf16": "hip_bfloat16", - "f8": "rccl_float8", - "bf8": "rccl_bfloat8", + "f8e4m3": "rccl_float8", + "f8e5m2": "rccl_bfloat8" } # Generate each /.cpp: diff --git a/src/device/network/unpack/unpack.h b/src/device/network/unpack/unpack.h index 2f8fc13930..44098977d3 100644 --- a/src/device/network/unpack/unpack.h +++ b/src/device/network/unpack/unpack.h @@ -33,17 +33,21 @@ inline __device__ void load64gpu(const uint64_t* ptr, uint64_t &v) { // Map internal association of handle with group and peer index (called once at init time) inline __device__ void ncclNetDeviceUnpackSetup(void* ohandle, const int group, const int index) { struct unpackNetDeviceHandle* handle = (struct unpackNetDeviceHandle*) ohandle; + // coverity[index_parm:FALSE] ncclShmem.groups[group].devicePlugin.unpack.g_meta[index] = handle->meta; ncclShmem.devicePlugin.unpack.bounce_buf = handle->bounce_buf; + // coverity[index_parm:FALSE] ncclShmem.groups[group].devicePlugin.unpack.head[index] = handle->head; } inline __device__ void ncclNetDeviceIncrementHead(const int group, const int index) { + // coverity[index_parm:FALSE] ncclShmem.groups[group].devicePlugin.unpack.head[index]++; } inline __device__ void ncclNetDeviceSaveHead(void* ohandle, const int group, const int index) { struct unpackNetDeviceHandle* handle = (struct unpackNetDeviceHandle*) ohandle; + // coverity[index_parm:FALSE] handle->head = ncclShmem.groups[group].devicePlugin.unpack.head[index]; } diff --git a/src/device/onerank.cu b/src/device/onerank.cu index f59de72428..25bb2ea442 100644 --- a/src/device/onerank.cu +++ b/src/device/onerank.cu @@ -71,13 +71,13 @@ ncclResult_t ncclLaunchOneRank(void* dst, void const* src, size_t nElts, struct case ncclUint32: kernel = (void const*)&oneRankReduce>; break; case ncclInt64: kernel = (void const*)&oneRankReduce>; break; case ncclUint64: kernel = (void const*)&oneRankReduce>; break; +#if defined(RCCL_FLOAT8) + case ncclFloat8e4m3: kernel = (void const*)&oneRankReduce>; break; + case ncclFloat8e5m2: kernel = (void const*)&oneRankReduce>; break; +#endif case ncclFloat16: kernel = (void const*)&oneRankReduce>; break; #if defined(RCCL_BFLOAT16) case ncclBfloat16: kernel = (void const*)&oneRankReduce>; break; -#endif -#if defined(RCCL_FLOAT8) - case ncclFp8E4M3: kernel = (void const*)&oneRankReduce>; break; - case ncclFp8E5M2: kernel = (void const*)&oneRankReduce>; break; #endif case ncclFloat32: kernel = (void const*)&oneRankReduce>; break; case ncclFloat64: kernel = (void const*)&oneRankReduce>; break; diff --git a/src/device/primitives.h b/src/device/primitives.h index 195a873436..f0c6986884 100644 --- a/src/device/primitives.h +++ b/src/device/primitives.h @@ -142,7 +142,7 @@ struct FanSymmetric { }; // The primitives class. Specialized per protocol in the other headers. -template +template class Primitives; // Used by LL & LL128 to implement direct members in the naive way. @@ -160,9 +160,12 @@ struct PrimitivesWithoutDirect { __device__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { static_cast(this)->copySend(inpIx, outIx, eltN, postOp); } - __device__ void directRecvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) { + __device__ void directRecvCopyDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { static_cast(this)->recvCopySend(outIx, eltN, /*postOp=*/false); } + __device__ void directRecvDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + return; + } __device__ void recvReduceCopyDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { // Direct is only for the send part static_cast(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); diff --git a/src/device/prims_ll.h b/src/device/prims_ll.h index 073da7ee50..387348942c 100644 --- a/src/device/prims_ll.h +++ b/src/device/prims_ll.h @@ -10,9 +10,9 @@ #include "npkit/npkit.h" #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { // In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile // This is because of a recv buffer which is allocated to MaxRecv length in send-only cases diff --git a/src/device/prims_ll128.h b/src/device/prims_ll128.h index f1ded84b57..21804f5526 100644 --- a/src/device/prims_ll128.h +++ b/src/device/prims_ll128.h @@ -20,9 +20,9 @@ #endif #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index 44637ef434..e0d89a9620 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -21,31 +21,27 @@ enum primsMode { }; template + int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload> class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p + T, RedOp, Fan, Direct, ProtoSimple, P2p, isNetOffload > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; static constexpr int RoleInput = 0x01, - RoleOutput = 0x02, - RoleWaitRecv = 0x04, - RoleWaitSend = 0x08, - RolePostSend = 0x10, - RolePostRecv = 0x20, - Aborted = 0x40, - NetRegMode = 0x80, - ConnFifoEnabled = 0x100, - DirectWrite = 0x200, - DirectRead = 0x400, - PatMode = 0x800, - NvlsMinPolling = 0x1000, - NetDeviceUnpack = 0x2000, - AnyNetDeviceUnpack = 0x4000, - NvlsDirectRead = 0x8000, - NvlsDirectWrite = 0x10000, - IpcWrite = 0x20000, - IpcRead = 0x40000; + RoleOutput = 0x02, + RoleWaitRecv = 0x04, + RoleWaitSend = 0x08, + RolePostSend = 0x10, + RolePostRecv = 0x20, + Aborted = 0x40, + NetRegMode = 0x80, + ConnFifoEnabled = 0x100, + DirectWrite = 0x200, + DirectRead = 0x400, + PatMode = 0x800, + NvlsMinPolling = 0x1000, + NetDeviceUnpack = 0x2000, + AnyNetDeviceUnpack = 0x4000; const int tid, tidInBlock; const int nthreads; int nworkers; @@ -121,12 +117,9 @@ private: template __device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) { const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; - const bool noRecvWait = DirectRecv && Src && (flags & (DirectRead | IpcRead)); // no wait when directly reading from remote input - const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write // Yes, for some template arguments this code will be unreachable. That's fine. // coverity[dead_error_line] - if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) || - ((flags & (Send*RoleWaitSend)) && !noSendWait)) { + if ((flags & (Recv * RoleWaitRecv)) || (flags & (Send * RoleWaitSend))) { int spins = 0; repeat = 50; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { @@ -143,27 +136,38 @@ private: } if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { - if (flags & ConnFifoEnabled) + if ((flags & ConnFifoEnabled) && (flags & (Send * RoleWaitSend))) connFifo[step%NCCL_STEPS].size = nelts*sizeof(T); void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst) : (ncclShmem.groups[group].srcs + Src); if (flags & NetRegMode) { - // Do nothing + if (P2p) { + ptrs[index] = NULL; + } else { + if (isSendNotRecv) { + if (!Recv) + ptrs[index] = NULL; + else + ptrs[index] = (T*)ncclShmem.groups[group].userOutput + dstIx + offset; + } else { + ptrs[index] = (T*)ncclShmem.groups[group].userOutput + srcIx + offset; + } + } } else if ((flags & ConnFifoEnabled) && connFifo[step%NCCL_STEPS].mode == NCCL_MODE_OFFSET) { ptrs[index] = connEltsFifo + loadInt(&connFifo[step%NCCL_STEPS].offset)/sizeof(T); } else if (isSendNotRecv && DirectSend) { - if (flags & (DirectWrite | NvlsDirectWrite | IpcWrite)) { + if (flags & DirectWrite) { ptrs[index] = directBuff + dstIx + offset; - } else if ((flags & DirectRead) || (flags & IpcRead)) { // empty send + } else if (flags & DirectRead) { // empty send ptrs[index] = nullptr; } else { ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*connStepSize; } } else if (!isSendNotRecv && DirectRecv) { - if (flags & (DirectRead | NvlsDirectRead | IpcRead)) { + if (flags & DirectRead) { ptrs[index] = directBuff + srcIx + offset; - } else if ((flags & DirectWrite) || (flags & IpcWrite)) { + } else if (flags & DirectWrite) { ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer } else { ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*connStepSize; @@ -214,7 +218,7 @@ private: int slice = 0; int offset = 0; - if (tid < nworkers && offset < nelem && ((flags & NetRegMode) == 0)) { + if (tid < nworkers && offset < nelem && !isNetOffload) { // Worker-only loop for non-empty slices. Non-workers and empty slices are // processed in the loop following this if block. The benefit of splitting // the loop like this is we pull two branches out of the critical path. @@ -263,7 +267,7 @@ private: * so we need to check whether MultimemSrcs and MultimemDsts are 0. */ && MultimemSrcs == 0 && MultimemDsts == 0 && !Src) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy - if (Send) { + if (Send && Dst && ncclShmem.groups[group].srcs[0] != ncclShmem.groups[group].dsts[1]) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY) if (tid == 0) { @@ -350,13 +354,24 @@ private: constexpr int PreOpSrcs = SrcBuf != Input ? 0 : DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; - reduceCopy - (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, - Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs, - Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts, - workSize); + if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) { + // this case should only be directCopySend() with registered buffers and send to net peer + reduceCopy + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, + Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs, + 1, ncclShmem.groups[group].dsts, + workSize); + } else { + reduceCopy + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, + Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs, + Send * fan.nsend() + Dst, ncclShmem.groups[group].dsts, + workSize); + } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) if (tid == 0) { @@ -372,9 +387,14 @@ private: } #endif + } else { + // we will come here when calling prims.directSend with net peer, + // in this case, ncclShmem.groups[group].dsts[0] == NULL, so we + // skip data flush. + workSize = 0; } barrier(); // This barrier has a counterpart in following loop - postPeer(0 < sliceSize); + postPeer(0 < workSize); offset += sliceSize; slice += 1; // Yes, for some template arguments this code will be unreachable. That's fine. @@ -391,10 +411,11 @@ private: sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; { // Only workers could have Wait roles so we know the slice must be empty // since we've exited the loop above. - waitPeer(0, 0, 0, 0); + waitPeer(0, 0, 0, sliceSize); } barrier(); // Has couterpart in preceding worker-only loop. - postPeer(0 < sliceSize); + int workSize = ncclShmem.aborted ? 0 : sliceSize; + postPeer(0 < workSize); offset += sliceSize; slice += 1; } @@ -487,17 +508,17 @@ public: ptrs[index] = connEltsFifo + offset/sizeof(T); } else if (Direct && fn.work->regUsed) { if (isSendNotRecv) { - if (flags & (DirectWrite | IpcWrite)) { + if (flags & DirectWrite) { ptrs[index] = directBuff; - } else if (flags & (DirectRead | IpcRead)) { // empty send + } else if (flags & DirectRead) { // empty send ptrs[index] = nullptr; } else { ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; } } else { - if (flags & (DirectRead | IpcRead)) { + if (flags & DirectRead) { ptrs[index] = directBuff; - } else if (flags & (DirectWrite | IpcWrite)) { + } else if (flags & DirectWrite) { if (Send) ptrs[index] = directBuff; // send to next from my output buffer else @@ -580,7 +601,7 @@ private: int i = (j+shift)%fan.nsend(); ssize_t pOffset = i*peerOffset; // Skip the data I am responsible of reducing myself - if (skip >= 0 && i >= skip) pOffset += peerElem; + if (skip >= 0 && i >= skip) pOffset += peerOffset; void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset; ssize_t realPeerSize = min(realSize, totalElem-pOffset); if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) { @@ -592,7 +613,7 @@ private: } else if (Recv) { if (tid==0) ncclShmem.groups[group].dsts[0] = (T*)ncclShmem.groups[group].userOutput + outIx + offset; ssize_t pOffset = index*peerOffset; - if (skip >= 0 && index >= skip) pOffset += peerElem; + if (skip >= 0 && index >= skip) pOffset += peerOffset; // Adjust remote index with peer offset in case we are directly pulling from peer's output buffer waitPeer(outIx+pOffset, outIx+pOffset, offset, realSize); subBarrier(); @@ -600,7 +621,7 @@ private: for (int j=0; j= 0 && i >= skip) pOffset += peerElem; + if (skip >= 0 && i >= skip) pOffset += peerOffset; void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset; ssize_t realPeerSize = min(realSize, totalElem-pOffset); if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0; @@ -614,7 +635,7 @@ private: } } - __device__ __forceinline__ void loadRecvConn(ncclDevChannelPeer *peer, int connIndex, uint32_t direct, int regFlag) { + __device__ __forceinline__ void loadRecvConn(ncclDevChannelPeer *peer, int connIndex, uint32_t direct, int ipcRegFlag, int netRegFlag) { conn = &peer->recv[connIndex]; if (conn->netDeviceHandle.netDeviceType == NCCL_NET_DEVICE_UNPACK) { // handle must be a device ptr @@ -639,33 +660,34 @@ private: if (conn->connFifo != nullptr) { flags |= ConnFifoEnabled; connFifo = conn->connFifo; - } else if (Direct && regFlag) { - // User buffers have been registered - if (conn->flags & (NCCL_IPC_READ | NCCL_IPC_WRITE)) { - if (P2p) { - flags |= conn->flags & NCCL_IPC_WRITE ? IpcWrite : IpcRead; - } else if (connIndex == 1 && direct) { - flags |= IpcRead; - } else { - flags |= direct & NCCL_DIRECT_READ ? IpcRead : IpcWrite; + } + if (Direct) { + if (ipcRegFlag) { + // User buffers have been registered + if (conn->flags & (NCCL_P2P_READ | NCCL_P2P_WRITE)) { + if (P2p) { + flags |= conn->flags & NCCL_P2P_WRITE ? DirectWrite : DirectRead; + } else if (connIndex == 1 && direct) { + flags |= DirectRead; + } else { + flags |= direct & NCCL_P2P_READ ? DirectRead : DirectWrite; + } + } else if ((conn->flags & NCCL_NVLS_MIN_POLL)) { + /* NVLS direct */ + flags |= DirectRead; } - } else if (conn->flags & (NCCL_DIRECT_WRITE | NCCL_DIRECT_READ)) { - if (P2p) { - flags |= conn->flags & NCCL_DIRECT_WRITE ? DirectWrite : DirectRead; - } else if (connIndex == 1 && direct) { - flags |= DirectRead; // scatter-reduce use direct pull - } else { - flags |= direct & NCCL_DIRECT_READ ? DirectRead : DirectWrite; + } + if (netRegFlag) { + if (conn->flags & NCCL_DIRECT_NIC) { + flags |= NetRegMode; + connFifo[step % NCCL_STEPS].size = 0; } - } else if ((conn->flags & NCCL_NVLS_MIN_POLL)) { - /* NVLS direct */ - flags |= NvlsDirectRead; } } } } - __device__ __forceinline__ void loadSendConn(ncclDevChannelPeer *peer, int connIndex, uint32_t direct, int regFlag) { + __device__ __forceinline__ void loadSendConn(ncclDevChannelPeer *peer, int connIndex, uint32_t direct, int ipcRegFlag, int netRegFlag) { conn = &peer->send[connIndex]; step = conn->step; step = roundUp(step, SlicePerChunk*StepPerSlice); @@ -685,27 +707,26 @@ private: connStepCache = loadStepValue(connStepPtr); connStepSize = conn->stepSize/sizeof(T); connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; - if (connFifo == nullptr && Direct && regFlag) { - // User buffers have been registered - if (conn->flags & (NCCL_IPC_READ | NCCL_IPC_WRITE)) { - if (P2p) { - flags |= conn->flags & NCCL_IPC_WRITE ? IpcWrite : IpcRead; - } else if (connIndex == 1 && direct) { - flags |= IpcRead; - } else { - flags |= direct & NCCL_DIRECT_READ ? IpcRead : IpcWrite; + if (Direct) { + if (ipcRegFlag) { + // User buffers have been registered + if (conn->flags & (NCCL_P2P_WRITE | NCCL_P2P_READ)) { + if (P2p) { + flags |= conn->flags & NCCL_P2P_WRITE ? DirectWrite : DirectRead; + } else if (connIndex == 1 && direct) { + flags |= DirectRead; // scatter-reduce use direct pull + } else { + flags |= direct & NCCL_P2P_READ ? DirectRead : DirectWrite; + } + } else if ((conn->flags & NCCL_NVLS_MIN_POLL)) { + /* NVLS direct */ + flags |= DirectWrite; } - } else if (conn->flags & (NCCL_DIRECT_WRITE | NCCL_DIRECT_READ)) { - if (P2p) { - flags |= conn->flags & NCCL_DIRECT_WRITE ? DirectWrite : DirectRead; - } else if (connIndex == 1 && direct) { - flags |= DirectRead; // scatter-reduce use direct pull - } else { - flags |= direct & NCCL_DIRECT_READ ? DirectRead : DirectWrite; + } + if (netRegFlag) { + if (conn->flags & NCCL_DIRECT_NIC) { + flags |= NetRegMode; } - } else if ((conn->flags & NCCL_NVLS_MIN_POLL)) { - /* NVLS direct */ - flags |= NvlsDirectWrite; } } } @@ -715,8 +736,8 @@ public: __forceinline__ __device__ Primitives( int tid, int nthreads, int const *recvPeers, int const *sendPeers, void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0, - uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclDevWorkColl* e = nullptr, - bool ipcReg = false, bool netReg = false, int stepSize_ = 0, int mode = primsModeDefault + uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclDevWorkColl* collWork = nullptr, + struct ncclDevWorkP2p* p2pWork = nullptr, int stepSize_ = 0, int mode = primsModeDefault ): tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group), stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_) { @@ -785,11 +806,23 @@ public: // Coverity thinks that index could be -1 here but that's not actually the case. // coverity[negative_returns:FALSE] - if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, e ? e->direct : 0, e ? e->regUsed : ipcReg); - // coverity[negative_returns:FALSE] - if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, e ? e->direct : 0, e ? e->regUsed : ipcReg); - - if (netReg) flags |= NetRegMode; + int sendIpcReg; + int recvIpcReg; + int sendNetReg; + int recvNetReg; + if (P2p) { + sendIpcReg = p2pWork ? p2pWork->sendIpcReg : 0; + recvIpcReg = p2pWork ? p2pWork->recvIpcReg : 0; + sendNetReg = p2pWork ? p2pWork->sendNetReg : 0; + recvNetReg = p2pWork ? p2pWork->recvNetReg : 0; + } else { + recvIpcReg = sendIpcReg = collWork ? collWork->regUsed : 0; + recvNetReg = sendNetReg = collWork ? collWork->netRegUsed : 0; + } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 + if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg); + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 + if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg); // if (barrierAny(flags & NetDeviceUnpack)) { // flags |= AnyNetDeviceUnpack; @@ -801,8 +834,10 @@ public: // } // } - // coverity[negative_returns:FALSE] - setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)e, (uint8_t)(e ? e->regUsed : ipcReg), peer); + // coverity[negative_returns:FALSE] => coverity thinks that index could be -1 but that's not actually the case + // coverity[var_deref_model] => coverity thinks work can dereferenced if NULL but this is not the case + setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer); + // coverity[uninit_member] => coverity thinks fan.n is not initialized } __forceinline__ __device__ ~Primitives() { @@ -825,6 +860,16 @@ public: // Make sure all threads are done writing back conn->step and done using // ncclShmem.groups[group] barrier(); + + if ((flags & DirectRead) && (flags & RoleWaitSend) && P2p) { + // For sendrecv DirectRead, sender needs to wait for receiver reading data from src. + // This has to be done after barrier() since post thread might have contention with + // this check. + int spins = 0; + volatile uint64_t* tail = conn->tail; + volatile uint64_t* head = conn->head; + while (*tail > *head) if (checkAbort(spins)) break; + } } __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer) { @@ -835,10 +880,10 @@ public: } if (Direct && ipcReg) { - bool recvProvider = (flags & RoleWaitRecv) && (flags & DirectWrite || flags & IpcWrite); - bool sendAcceptor = (flags & RoleWaitSend) && (flags & DirectWrite || flags & IpcWrite || flags & NvlsDirectWrite); - bool sendProvider = (flags & RoleWaitSend) && (flags & DirectRead || flags & IpcRead); // sender provides direct buffer (to be fetched) - bool recvAcceptor = (flags & RoleWaitRecv) && (flags & DirectRead || flags & IpcRead || flags & NvlsDirectRead); // receiver accepts direct buffer + bool recvProvider = (flags & RoleWaitRecv) && (flags & DirectWrite); + bool sendAcceptor = (flags & RoleWaitSend) && (flags & DirectWrite); + bool sendProvider = (flags & RoleWaitSend) && (flags & DirectRead); // sender provides direct buffer (to be fetched) + bool recvAcceptor = (flags & RoleWaitRecv) && (flags & DirectRead); // receiver accepts direct buffer if (recvProvider) { int spins = 0; void* volatile* slot = ncclShmem.groups[group].recvConns[index]->ptrExchange; @@ -851,6 +896,7 @@ public: exchgPtr = (T*)outputBuf; } else { int localPeer = ncclShmem.comm.rankToLocalRank[peer]; + // coverity[deref_parm:FALSE] => work cannot be NULL if ipcReg != NULL exchgPtr = (T*)(work->coll.recvbuffOffset + work->coll.recvbuffRmtAddrs[localPeer]); } *slot = reinterpret_cast(exchgPtr); @@ -869,6 +915,7 @@ public: directBuff = reinterpret_cast(ptr); *slot = nullptr; } else { + // coverity[var_deref_op] directBuff = (T*)work->dnOutputs[index]; } } @@ -889,8 +936,10 @@ public: } else { int localPeer = ncclShmem.comm.rankToLocalRank[peer]; if (MaxRecv == 0) + // coverity[var_deref_op] exchgPtr = (T*)(work->coll.sendbuffOffset + work->coll.sendbuffRmtAddrs[localPeer]); else + // coverity[var_deref_op] exchgPtr = (T*)(work->coll.recvbuffOffset + work->coll.recvbuffRmtAddrs[localPeer]); } @@ -987,11 +1036,11 @@ public: __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, eltN, postOp); } - __device__ __forceinline__ void directRecvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) { - genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, eltN, postOp); + __device__ __forceinline__ void directRecvCopyDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<1, 1, 1, 1, -1, Output>(inpIx, outIx, eltN, postOp); } - __device__ __forceinline__ void directRecvDirectSend(intptr_t inpIx, intptr_t outIx, int eltN) { - genericOp<1, 1, 1, 1, -1, -1>(inpIx, outIx, eltN, false); + __device__ __forceinline__ void directRecvDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<1, 1, 1, 1, -1, -1>(inpIx, outIx, eltN, postOp); } __device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, eltN, postOp); @@ -1010,6 +1059,9 @@ public: __device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) { genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp); } + __device__ __forceinline__ void recvReduceDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { + genericOp<0, 1, 1, 1, Input, -1>(inpIx, outIx, eltN, postOp); + } __device__ __forceinline__ void directRecvReduceDirectSend(intptr_t inpIx, intptr_t outIx, ssize_t eltN, bool postOp=false) { genericOp<1, 1, 1, 1, Input, -1>(inpIx, outIx, eltN, postOp); } diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h index 9d7b773780..a4df12c25e 100644 --- a/src/device/reduce_kernel.h +++ b/src/device/reduce_kernel.h @@ -289,13 +289,30 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h #endif #if defined(RCCL_FLOAT8) +#if __CUDA_ARCH__ >= 900 + SPECIALIZE_REDUCE(FuncSum, __nv_fp8_e4m3, 1, __nv_fp8_e4m3, __nv_fp8_e4m3(__hadd(__half(x),__half(y)))) + SPECIALIZE_REDUCE(FuncSum, __nv_fp8_e4m3, 2, __nv_fp8x2_e4m3, __nv_fp8x2_e4m3(__hadd2(__half2(x),__half2(y)))) + SPECIALIZE_REDUCE(FuncProd, __nv_fp8_e4m3, 1, __nv_fp8_e4m3, __nv_fp8_e4m3(__hmul(__half(x),__half(y)))) + SPECIALIZE_REDUCE(FuncProd, __nv_fp8_e4m3, 2, __nv_fp8x2_e4m3, __nv_fp8x2_e4m3(__hmul2(__half2(x),__half2(y)))) + SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e4m3, 1, __nv_fp8_e4m3, __nv_fp8_e4m3(fn.isMinNotMax ? __hmin(__half(x),__half(y)) : __hmax(__half(x),__half(y)))) + SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e4m3, 2, __nv_fp8x2_e4m3, __nv_fp8x2_e4m3(fn.isMinNotMax ? __hmin2(__half2(x),__half2(y)) : __hmax2(__half2(x),__half2(y)))) + + SPECIALIZE_REDUCE(FuncSum, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(__hadd(__half(x),__half(y)))) + SPECIALIZE_REDUCE(FuncSum, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(__hadd2(__half2(x),__half2(y)))) + SPECIALIZE_REDUCE(FuncProd, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(__hmul(__half(x),__half(y)))) + SPECIALIZE_REDUCE(FuncProd, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(__hmul2(__half2(x),__half2(y)))) + SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(fn.isMinNotMax ? __hmin(__half(x), __half(y)) : __hmax(__half(x), __half(y)))) + SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(fn.isMinNotMax ? __hmin2(__half2(x), __half2(y)) : __hmax2(__half2(x), __half2(y)))) +#else SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, rccl_float8(float(x) + float(y))) SPECIALIZE_REDUCE(FuncProd, rccl_float8, 1, rccl_float8, rccl_float8(float(x) * float(y))) SPECIALIZE_REDUCE(FuncMinMax, rccl_float8, 1, rccl_float8, rccl_float8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y)))) + SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) + float(y))) SPECIALIZE_REDUCE(FuncProd, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) * float(y))) SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y)))) #endif +#endif #undef SPECIALIZE_REDUCE @@ -416,7 +433,7 @@ struct FuncPreMulSum { using EltType = half; half2 scalar; __device__ FuncPreMulSum(uint64_t opArg=0) { - union { uint64_t u64; half val; }; + union { uint64_t u64; __half val; }; u64 = opArg; scalar.x = val; scalar.y = val; @@ -450,6 +467,31 @@ struct FuncPreMulSum { #endif #if defined(RCCL_FLOAT8) +#if __CUDA_ARCH__ >= 900 + template<> + struct FuncPreMulSum<__nv_fp8_e4m3> { + using EltType = __nv_fp8_e4m3; + __half2 scalar2; + __device__ FuncPreMulSum(uint64_t opArg) { + union { uint64_t u64; __nv_fp8_storage_t val; }; + u64 = opArg; + scalar2.x = __half(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)); + scalar2.y = scalar2.x; + } + }; + + template<> + struct FuncPreMulSum<__nv_fp8_e5m2> { + using EltType = __nv_fp8_e5m2; + __half2 scalar2; + __device__ FuncPreMulSum(uint64_t opArg) { + union { uint64_t u64; __nv_fp8_storage_t val; }; + u64 = opArg; + scalar2.x = __half(__nv_cvt_fp8_to_halfraw(val, __NV_E5M2)); + scalar2.y = scalar2.x; + } + }; +#else template<> struct FuncPreMulSum { // Change these to switch between all prescale, all postscale, or both by sqrt(N). @@ -480,12 +522,13 @@ struct FuncPreMulSum { } }; #endif +#endif -template -struct Apply_Reduce, /*EltPerPack=*/1> { - __device__ static BytePack reduce(FuncPreMulSum fn, BytePack a, BytePack b) { +template +struct Apply_Reduce, EltPerPack> { + __device__ static BytePack reduce(FuncPreMulSum fn, BytePack a, BytePack b) { // FuncPreMulSum reduce dispatches to FuncSum. - return Apply_Reduce, 1>::reduce(FuncSum(), a, b); + return Apply_Reduce, EltPerPack>::reduce(FuncSum(), a, b); } }; @@ -548,7 +591,49 @@ struct Apply_PreOp, /*EltPerPack=*/1> { #endif #endif +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp of FuncPreMulSum for fp8. + #if defined(RCCL_FLOAT8) +#if __CUDA_ARCH__ >= 900 + template<> + struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_fp8_e4m3> fn, BytePack a + ) { + return toPack<__nv_fp8_e4m3>(__nv_fp8_e4m3(__hmul(__half(fromPack<__nv_fp8_e4m3>(a)), fn.scalar2.x))); + } + }; + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_fp8_e4m3> fn, BytePack a + ) { + return toPack<__nv_fp8x2_e4m3>(__nv_fp8x2_e4m3(__hmul2(__half2(fromPack<__nv_fp8x2_e4m3>(a)), fn.scalar2))); + } + }; + + template<> + struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_fp8_e5m2> fn, BytePack a + ) { + return toPack<__nv_fp8_e5m2>(__nv_fp8_e5m2(__hmul(__half(fromPack<__nv_fp8_e5m2>(a)), fn.scalar2.x))); + } + }; + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_fp8_e5m2> fn, BytePack a + ) { + return toPack<__nv_fp8x2_e5m2>(__nv_fp8x2_e5m2(__hmul2(__half2(fromPack<__nv_fp8x2_e5m2>(a)), fn.scalar2))); + } + }; +#else template<> struct Apply_PreOp, /*EltPerPack=*/1> { static constexpr bool IsIdentity = false; @@ -571,6 +656,7 @@ struct Apply_PreOp, /*EltPerPack=*/1> { } }; #endif +#endif //////////////////////////////////////////////////////////////////////////////// // FuncSumPostDiv @@ -583,34 +669,74 @@ struct RedOpArg> { } }; -template::value> -struct FuncSumPostDiv_IntOnly; - template -struct FuncSumPostDiv: FuncSumPostDiv_IntOnly { - __device__ FuncSumPostDiv(uint64_t opArg=0): - FuncSumPostDiv_IntOnly(opArg) { +struct Divider { + __device__ __forceinline__ static T divide(T dividend, T divisor) { + return dividend / divisor; } }; +template<> +struct Divider { + __device__ __forceinline__ static uint64_t divide(uint64_t dividend, uint64_t divisor) { + if (divisor == 0) { + return UINT64_MAX; + } + + uint64_t quotient = 0; + uint64_t remainder = 0; + + #pragma unroll 64 + for (int i = 63; i >= 0; --i) { + remainder = (remainder << 1) | ((dividend >> i) & 1); + if (remainder >= divisor) { + remainder -= divisor; + quotient |= (1ULL << i); + } + } + + return quotient; + } +}; + template -struct FuncSumPostDiv_IntOnly: FuncSum { +struct FuncSumPostDiv { + static_assert(T(0) < T(-1), "FuncSumPostDiv is only for implementing ncclAvg on uint types."); using EltType = T; - int divisor; - __device__ FuncSumPostDiv_IntOnly(uint64_t opArg=0): divisor(opArg) {} + using UintType = typename std::conditional::type; + uint32_t divisor:31, isSigned:1; + UintType recip; + + __device__ FuncSumPostDiv(uint64_t opArg=0) { + isSigned = opArg & 1; + divisor = opArg >> 1; + recip = Divider::divide(UintType(-1), divisor); + } + __device__ T divide(T x) { + // x is negative iff we are in signed mode and the top bit is set + bool xneg = isSigned && (x & ~(T(-1)>>1)); + // Compute abs(x): + // T(-x) vs -T(x) is critical. We have to negate then truncate the bits. Consider + // if we are doing signed 8-bit types, thus T=uint8_t. The value -1 is encoded + // as 0xff. -T(0xff) when promoted to 32-bit (which is implicit by compiler) + // gives 0xffffff01, but T(-0xff) is 0x1, and that is the abs value we want. + UintType xabs = xneg ? T(-x) : x; + // Compute quotient by multiplying by reciprical. + UintType q = sizeof(T)==8 ? __umul64hi(xabs, recip) : __umulhi(xabs, recip); + // Quotient may be off by one so do a fixup. + if (xabs - q*divisor >= divisor) q += 1; + // If original x was negative then we have to negate it back since we were + // working with its abs val. + return xneg ? -T(q) : T(q); + } }; -template -struct FuncSumPostDiv_IntOnly { - static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types."); -}; - -template -struct Apply_Reduce, /*EltPerPack=*/1>: - Apply_Reduce, 1> { - __device__ static BytePack reduce(FuncSumPostDiv fn, BytePack a, BytePack b) { +template +struct Apply_Reduce, EltPerPack>: + Apply_Reduce, EltPerPack> { + __device__ static BytePack reduce(FuncSumPostDiv fn, BytePack a, BytePack b) { // FuncSumPostDiv reduce dispatches to FuncSum. - return Apply_Reduce, 1>::reduce(FuncSum(), a, b); + return Apply_Reduce, EltPerPack>::reduce(FuncSum(), a, b); } }; @@ -618,7 +744,7 @@ template struct Apply_PostOp, /*EltPerPack=*/1> { static constexpr bool IsIdentity = false; __device__ static BytePack postOp(FuncSumPostDiv fn, BytePack a) { - return toPack(fromPack(a) / fn.divisor); + return toPack(fn.divide(fromPack(a))); } }; diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h index 71431464bd..5c2085bd6d 100644 --- a/src/device/reduce_scatter.h +++ b/src/device/reduce_scatter.h @@ -165,7 +165,7 @@ struct RunWorkCollsendbuff; T *outputBuf = (T*)work->recvbuff; Primitives, 0, Proto, 0> prims - (tid, nthreads, NULL, NULL, inputBuf, outputBuf, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, nullptr, false, false, 0, primsModePatRs); + (tid, nthreads, NULL, NULL, inputBuf, outputBuf, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, nullptr, nullptr, 0, primsModePatRs); PatRSAlgorithm patAlgo(chunkCount*sizeof(T), NCCL_STEPS, channelOffset, channelOffset + channelCount, count, chunkCount, rank, nranks); int last = 0; @@ -213,6 +213,7 @@ struct RunWorkCollnHeads * count, nelem, count, -1, 0); } + // coverity[overrun-call] => Coverity think prims.index can be greater than 1 } else if (tid < tidEndReduce) { // Reduce through NVLS using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; @@ -282,10 +283,10 @@ struct RunWorkCollnHeads; int part = ncclShmem.channelId - work->channelLo; void* inbuf = (void*)work->sendbuff; - ssize_t sizePerRank = work->collnet.count; + ssize_t countPerRank = work->collnet.count; - ssize_t railAllBeg = min(railGridOffset + part*chunkSize, nNodes*sizePerRank); - ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + ssize_t railAllBeg = min(railGridOffset + part*chunkSize, nNodes*countPerRank); + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*countPerRank); int railAllSize = railAllEnd - railAllBeg; if (tid < nDsts) dstSizes[tid] = railAllSize; @@ -298,15 +299,15 @@ struct RunWorkCollredOpArg, &work->redOpArg, false, /*nSrcs=*/1+nSrcs, [=]__device__(int s) { return s==0 ? (T*)inbuf + userOneBeg - : work->regUsed && (recvDirectFlag & NCCL_DIRECT_READ) + : work->regUsed && (recvDirectFlag & NCCL_P2P_READ) ? (T*)srcPtrs[s-1] + userOneBeg : (T*)srcPtrs[s-1] + railAllOffset; }, @@ -340,7 +341,8 @@ struct RunWorkCollcollnet.chunkCount); - ssize_t sizePerRank = work->collnet.count; + ssize_t countPerRank = work->collnet.count; + const int hasDn = (direct->down[0] >= 0) ? 1 : 0; if (direct->out == -1) __builtin_trap(); bool isMultiRail = (direct->nHeads > 1); @@ -357,15 +359,15 @@ struct RunWorkColl, /*Direct=*/1, Proto, 0> + Primitives, /*Direct=*/0, Proto, 0> prims(tid, tn, nullptr, direct->heads+1, work->sendbuff, nullptr, - work->redOpArg, 0*Proto::MaxGroupWidth, 1, 1, work); - for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + work->redOpArg, 0*Proto::MaxGroupWidth, 1, 1); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*countPerRank; railGridOffset += nChannels*chunkSize) { Scatterer scat; scat.work = work; scat.chunkSize = chunkSize; scat.railGridOffset = railGridOffset; - prims.template process(scat, NCCL_DIRECT_READ, 0); + prims.template process(scat, 0, 0); } return; } @@ -373,23 +375,22 @@ struct RunWorkCollregUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed && !hasDn) { if (tid == 0) { - int steps = (int)divUp(nNodes * sizePerRank * sizeof(T), NCCL_MAX_COLLNET_SIZE); - Primitives, /*Direct=*/0, Proto, 0>::sendPeerNotify(direct->out, 1, steps); + Primitives, /*Direct=*/0, Proto, 0>::sendPeerNotify(direct->out, 1, 1); } __syncwarp(); } else { // Phase 2: Reduce from peers + local input -> send to network - Primitives, /*Direct=*/1, Proto, 0> + Primitives, /*Direct=*/0, Proto, 0> prims(tid, tn, direct->heads + 1, &direct->out, nullptr, nullptr, - work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1, work); - for (ssize_t railGridOffset = 0; railGridOffset < nNodes * sizePerRank; railGridOffset += nChannels * chunkSize) { + work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); + for (ssize_t railGridOffset = 0; railGridOffset < nNodes * countPerRank; railGridOffset += nChannels * chunkSize) { Scatterer scat; scat.work = work; scat.chunkSize = chunkSize; scat.railGridOffset = railGridOffset; - prims.template process(scat, 0, NCCL_DIRECT_READ); + prims.template process(scat, 0, 0); } } return; @@ -398,9 +399,9 @@ struct RunWorkCollregUsed == NCCL_COLLNET_REG_BUFFER) { + if (work->netRegUsed) { if (tid == 0) { - int steps = (int)divUp(nNodes * sizePerRank * sizeof(T), NCCL_MAX_COLLNET_SIZE); + int steps = hasDn ? (int)divUp(nNodes * countPerRank, nChannels * chunkSize) : 1; Primitives, /*Direct=*/0, Proto, 0>::recvPeerNotify(direct->out, 0, steps); } __syncwarp(); @@ -409,11 +410,11 @@ struct RunWorkColl, /*Direct=*/0, Proto, 0> prims(tid, tn, &direct->out, nullptr, nullptr, work->recvbuff, work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0); - for (ssize_t railGridOffset = 0; railGridOffset < nNodes * sizePerRank; railGridOffset += nChannels * chunkSize) { + for (ssize_t railGridOffset = 0; railGridOffset < nNodes * countPerRank; railGridOffset += nChannels * chunkSize) { ssize_t railAllBeg = railGridOffset + part * chunkSize; - ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes * sizePerRank); - ssize_t railOneBeg = ncclShmem.comm.node * sizePerRank; - ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes * countPerRank); + ssize_t railOneBeg = ncclShmem.comm.node * countPerRank; + ssize_t railOneEnd = railOneBeg + countPerRank; ssize_t beg = max(railAllBeg, railOneBeg); ssize_t end = min(railAllEnd, railOneEnd); prims.recv(beg - railOneBeg, max(ssize_t(0), end - beg), /*postOp=*/true); diff --git a/src/device/sendrecv.h b/src/device/sendrecv.h index c016023cf6..349dff4f1c 100644 --- a/src/device/sendrecv.h +++ b/src/device/sendrecv.h @@ -19,7 +19,9 @@ struct RunWorkBatch __device__ void runSend(int tid, int tn, int group, struct ncclDevWorkP2p* work) { size_t bytes = work->sendBytes; - int chunkSize = work->sendIpcReg && ncclShmem.comm.isNvlink ? (1 << 30) : u32fp8Decode(work->sendChunkSize_u32fp8); + bool useLargeChunk = (work->sendIpcReg && ncclShmem.comm.isAllNvlink) || work->sendNetReg; + int chunkSize = useLargeChunk ? NCCL_MAX_NET_SIZE : u32fp8Decode(work->sendChunkSize_u32fp8); + int stepSize = useLargeChunk ? NCCL_MAX_NET_SIZE : ncclShmem.comm.p2pChunkSize; #if defined(ENABLE_NPKIT) bool isNpKitThread = (tid == 0); @@ -42,8 +44,7 @@ struct RunWorkBatch, 0, Proto, 1> prims(tid, tn, nullptr, &work->sendRank, work->sendAddr, nullptr, - /*redOpArg(ignored)=*/0, group, work->sendConnIndex, work->sendConnIndex, nullptr, - /*ipcReg=*/work->sendIpcReg, /*netReg=*/work->sendRegistered, ncclShmem.comm.p2pChunkSize); + /*redOpArg(ignored)=*/0, group, work->sendConnIndex, work->sendConnIndex, nullptr, work, stepSize); #if defined(ENABLE_NPKIT) if (isNpKitThread) { @@ -64,7 +65,7 @@ struct RunWorkBatchsendRegistered == 0); + } while (cursor < bytes); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_SEND_EXIT) if (isNpKitThread) { @@ -77,7 +78,9 @@ struct RunWorkBatch __device__ void runRecv(int tid, int tn, int group, struct ncclDevWorkP2p* work) { size_t bytes = work->recvBytes; - int chunkSize = work->recvIpcReg && ncclShmem.comm.isNvlink ? (1 << 30) : u32fp8Decode(work->recvChunkSize_u32fp8); + bool useLargeChunk = (work->recvIpcReg && ncclShmem.comm.isAllNvlink) || work->recvNetReg; + int chunkSize = useLargeChunk ? NCCL_MAX_NET_SIZE : u32fp8Decode(work->recvChunkSize_u32fp8); + int stepSize = useLargeChunk ? NCCL_MAX_NET_SIZE : ncclShmem.comm.p2pChunkSize; #if defined(ENABLE_NPKIT) bool isNpKitThread = (tid == 0); @@ -100,8 +103,7 @@ struct RunWorkBatch, 0, Proto, 1> prims(tid, tn, &work->recvRank, nullptr, nullptr, work->recvAddr, - /*redOpArg(ignored)=*/0, group, work->recvConnIndex, work->recvConnIndex, nullptr, - /*ipcReg=*/work->recvIpcReg, /*netReg=*/work->recvRegistered, ncclShmem.comm.p2pChunkSize); + /*redOpArg(ignored)=*/0, group, work->recvConnIndex, work->recvConnIndex, nullptr, work, stepSize); #if defined(ENABLE_NPKIT) if (isNpKitThread) { @@ -122,7 +124,7 @@ struct RunWorkBatchrecvRegistered == 0); + } while (cursor < bytes); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_RECV_EXIT) if (isNpKitThread) { diff --git a/src/enqueue.cc b/src/enqueue.cc index 93c040dd53..08ce05f82b 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -24,6 +24,7 @@ #include "api_trace.h" #include // std::memcpy #include // PRIx64 +#include using namespace rccl; @@ -105,15 +106,6 @@ static inline int ncclFuncTrafficPerByte(ncclFunc_t func, int nRanks) { default: return 1; } } -static inline size_t ncclFuncSendCount(ncclFunc_t func, int nRanks, size_t count) { - return func == ncclFuncReduceScatter ? nRanks*count : count; -} -static inline size_t ncclFuncRecvCount(ncclFunc_t func, int nRanks, size_t count) { - return func == ncclFuncAllGather ? nRanks*count : count; -} -rccl_static_inline size_t ncclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count) { - return func == ncclFuncAllGather || func == ncclFuncReduceScatter ? nRanks*count : count; -} /*****************************************************************************/ /* Launch system : synchronization and CUDA kernel launch */ @@ -273,301 +265,8 @@ static void finishPlan(struct ncclComm* comm, struct ncclKernelPlan* plan) { } } -int64_t ncclParamLocalRegister(); NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 1); -struct ncclIpcCleanupCallback { - struct ncclCommCallback base; - void* ptr; -}; -static ncclResult_t cleanupIpc(struct ncclComm* comm, struct ncclCommCallback* cb) { - struct ncclIpcCleanupCallback* me = (struct ncclIpcCleanupCallback*)cb; - CUDACHECKIGNORE(cudaIpcCloseMemHandle(me->ptr)); - free(me); - return ncclSuccess; -} - -static ncclResult_t registerCheckP2PConnection(struct ncclComm* comm, struct ncclConnector* conn, struct ncclTopoGraph* graph, int peer, bool* needReg) { - if (conn->connected) { - if (conn->conn.flags & (NCCL_IPC_READ | NCCL_IPC_WRITE | NCCL_DIRECT_READ | NCCL_DIRECT_WRITE)) { - *needReg = true; - } else { - // network connection - *needReg = false; - } - } else { - struct ncclPeerInfo* peerInfo = &comm->peerInfo[peer]; - struct ncclPeerInfo* myInfo = &comm->peerInfo[comm->rank]; - int canConnect = 0; - NCCLCHECK(ncclTransports[0]->canConnect(&canConnect, comm, graph, myInfo, peerInfo)); - if (canConnect) { - *needReg = true; - } else { - *needReg = false; - } - } - return ncclSuccess; -} - -static ncclResult_t registerCollBuffers( - struct ncclComm* comm, struct ncclTaskColl* info, - void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], - void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], - struct ncclIntruQueue* cleanupQueue, - bool* regNeedConnect - ) { - ncclResult_t result = ncclSuccess; - - info->regBufType = NCCL_REGULAR_BUFFER; - *regNeedConnect = true; - if (!(ncclParamLocalRegister() || (comm->planner.persistent && ncclParamGraphRegister()))) goto exit; -#if CUDART_VERSION >= 11030 - if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) { - if (!comm->nvlsRegSupport || info->opDev.op == ncclDevPreMulSum) goto exit; - bool regBufUsed = false; - const void *sendbuff = info->sendbuff; - void *recvbuff = info->recvbuff; - if (info->func == ncclFuncAllGather) sendbuff = NULL; - if (info->func == ncclFuncReduceScatter) recvbuff = NULL; - size_t elementSize = ncclTypeSize(info->datatype); - size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); - size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); - - /* first try local registration. */ - if (ncclParamLocalRegister()) { - ncclNvlsLocalRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, ®BufUsed, outRegBufSend, outRegBufRecv); - } - - if (regBufUsed == false && comm->planner.persistent && ncclParamGraphRegister()) { - ncclNvlsGraphRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, ®BufUsed, outRegBufSend, outRegBufRecv, cleanupQueue, &info->nCleanupQueueElts); - } - - if (regBufUsed) { - *regNeedConnect = false; - /* tweak NVLS channels usage; for registered NVLS buffer, we only need 4/5 channels to - * saturate bandwidth. */ - if (comm->nNodes == 1) { - if (info->func == ncclFuncReduceScatter) - info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 5)); - else - info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 4)); - } else { - info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 6)); - } - info->regBufType = NCCL_NVLS_REG_BUFFER; - } - } else if ((info->algorithm == NCCL_ALGO_COLLNET_DIRECT || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) && comm->collNetRegSupport && info->opDev.op != ncclDevPreMulSum && info->opDev.op != ncclDevSumPostDiv) { - size_t elementSize = ncclTypeSize(info->datatype); - size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); - size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); - int sendRegBufFlag = 0; - int recvRegBufFlag = 0; - void *sendHandle, *recvHandle; - - if (ncclParamLocalRegister()) { - ncclCollnetLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &sendRegBufFlag, &sendHandle); - info->sendMhandle = sendHandle; - if (sendRegBufFlag) { - ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &recvRegBufFlag, &recvHandle); - info->recvMhandle = recvHandle; - } - } - - if ((sendRegBufFlag == 0 || recvRegBufFlag == 0) && comm->planner.persistent && ncclParamGraphRegister()) { - if (!sendRegBufFlag) { - ncclCollnetGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &sendRegBufFlag, &sendHandle, cleanupQueue, &info->nCleanupQueueElts); - info->sendMhandle = sendHandle; - } - if (sendRegBufFlag && !recvRegBufFlag) { - ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &recvRegBufFlag, &recvHandle, cleanupQueue, &info->nCleanupQueueElts); - info->recvMhandle = recvHandle; - } - } - - if (sendRegBufFlag && recvRegBufFlag) { - info->nMaxChannels = 1; - info->regBufType = NCCL_COLLNET_REG_BUFFER; - if (sendRegBufFlag == 1 && recvRegBufFlag == 1) { - INFO(NCCL_REG, "rank %d successfully registered collNet sendbuff %p (handle %p), sendbuff size %ld, recvbuff %p (handle %p), recvbuff size %ld", comm->rank, info->sendbuff, sendHandle, sendbuffSize, info->recvbuff, recvHandle, recvbuffSize); - } - } - } else if (comm->intraNodeP2pSupport && info->protocol == NCCL_PROTO_SIMPLE) { - // IPC buffer registration - if (info->func == ncclFuncReduceScatter) goto exit; - if (info->algorithm == NCCL_ALGO_RING && ((info->func == ncclFuncAllReduce && info->sendbuff == info->recvbuff) || info->func == ncclFuncReduce)) goto exit; - if ((info->algorithm == NCCL_ALGO_TREE || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) && info->sendbuff == info->recvbuff) goto exit; - if (info->func == ncclFuncAllGather && info->algorithm == NCCL_ALGO_PAT) goto exit; - - int peerRanks[NCCL_MAX_LOCAL_RANKS]; - int nPeers = 0; - size_t elementSize = ncclTypeSize(info->datatype); - size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); - size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); - int regBufFlag = 0; - memset(peerRanks, 0xff, sizeof(int) * NCCL_MAX_LOCAL_RANKS); - - if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { - struct ncclChannel* channel = comm->channels; - for (int r = 0; r < NCCL_MAX_DIRECT_ARITY; ++r) { - for (int updown = 0; updown < 2; ++updown) { - int peer; - if (updown == 0) - peer = channel->collnetDirect.up[r]; - else - peer = channel->collnetDirect.down[r]; - if (peer != -1) { - struct ncclConnector* peerConn = &channel->peers[peer]->recv[0]; - bool needReg = false; - - NCCLCHECK(registerCheckP2PConnection(comm, peerConn, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], peer, &needReg)); - if (needReg) { - bool found = false; - for (int p = 0; p < nPeers; ++p) { - if (peerRanks[p] == peer) { - found = true; - break; - } - } - if (!found) peerRanks[nPeers++] = peer; - } - } - } - } - - if (nPeers > 0) { - if (ncclParamLocalRegister()) - ncclIpcLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->sendbuffOffset, &info->sendbuffRmtAddrs); - if (!regBufFlag && comm->planner.persistent && ncclParamGraphRegister()) { - ncclIpcGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->sendbuffOffset, &info->sendbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); - } - if (regBufFlag) { - if (ncclParamLocalRegister()) - ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); - if (!regBufFlag && comm->planner.persistent && ncclParamGraphRegister()) { - ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); - } - } - } - if (regBufFlag) { - info->regBufType = NCCL_IPC_REG_BUFFER; - } - } else if (info->algorithm == NCCL_ALGO_RING) { - struct ncclReg* recvRegRecord; - NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvbuffSize, &recvRegRecord)); - if (recvRegRecord == NULL) goto exit; - for (int c = 0; c < comm->nChannels; ++c) { - struct ncclChannel* channel = comm->channels + c; - for (int r = 0; r < 2; ++r) { - bool needReg = false; - int peer; - struct ncclConnector* peerConn; - // P2P transport - if (r == 0) - peer = channel->ring.prev; - else - peer = channel->ring.next; - peerConn = &channel->peers[peer]->recv[0]; - NCCLCHECK(registerCheckP2PConnection(comm, peerConn, &comm->graphs[NCCL_ALGO_RING], peer, &needReg)); - - if (needReg) { - bool found = false; - for (int p = 0; p < nPeers; ++p) { - if (peerRanks[p] == peer) { - found = true; - break; - } - } - if (!found) peerRanks[nPeers++] = peer; - } - } - } - if (nPeers > 0) { - if (ncclParamLocalRegister()) { - ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); - } - if (!regBufFlag && comm->planner.persistent && ncclParamGraphRegister()) { - ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); - } - } - if (regBufFlag) { - info->regBufType = NCCL_IPC_REG_BUFFER; - } - } else if (info->algorithm == NCCL_ALGO_TREE || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) { - struct ncclReg* recvRegRecord; - NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvbuffSize, &recvRegRecord)); - if (recvRegRecord == NULL) goto exit; - for (int c = 0; c < comm->nChannels; ++c) { - struct ncclChannel* channel = comm->channels + c; - struct ncclTree* tree = NULL; - int peers[NCCL_MAX_TREE_ARITY + 1]; - - if (info->algorithm == NCCL_ALGO_TREE) - tree = &channel->tree; - else - tree = &channel->collnetChain; - for (int p = 0; p < NCCL_MAX_TREE_ARITY; ++p) peers[p] = tree->down[p]; - peers[NCCL_MAX_TREE_ARITY] = tree->up; - for (int p = 0; p < NCCL_MAX_TREE_ARITY + 1; ++p) { - int peer = peers[p]; - bool peerNeedReg = false; - struct ncclConnector* recvConn = NULL; - // P2P transport - if (peer == -1 || peer == comm->nRanks) continue; - recvConn = &channel->peers[peer]->recv[0]; - NCCLCHECK(registerCheckP2PConnection(comm, recvConn, &comm->graphs[info->algorithm], peer, &peerNeedReg)); - - if (peerNeedReg) { - bool found = false; - for (int pindex = 0; pindex < nPeers; ++pindex) { - if (peerRanks[pindex] == peer) { - found = true; - break; - } - } - if (!found) peerRanks[nPeers++] = peer; - } - } - } - if (nPeers > 0) { - if (ncclParamLocalRegister()) { - ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); - } - if (!regBufFlag && comm->planner.persistent && ncclParamGraphRegister()) { - ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); - } - } - if (regBufFlag) { - info->regBufType = NCCL_IPC_REG_BUFFER; - } - } - - if (info->regBufType == NCCL_IPC_REG_BUFFER && comm->nNodes == 1 && 16 < info->nMaxChannels && info->nMaxChannels <= 24) { - info->nMaxChannels = 16; - } - } -#endif -exit: - return result; -} - -static ncclResult_t registerP2pBuffer(struct ncclComm* comm, void* userbuff, int peerRank, size_t size, int* regFlag, void** regAddr, struct ncclIntruQueue* cleanupQueue) { - ncclResult_t ret = ncclSuccess; - uintptr_t offset = 0; - uintptr_t* peerRmtAddrs = NULL; - - *regFlag = 0; - if (ncclParamLocalRegister()) { - ncclIpcLocalRegisterBuffer(comm, userbuff, size, &peerRank, 1, NCCL_IPC_SENDRECV, regFlag, &offset, &peerRmtAddrs); - } - if (*regFlag == 0 && comm->planner.persistent && ncclParamGraphRegister()) { - ncclIpcGraphRegisterBuffer(comm, userbuff, size, &peerRank, 1, NCCL_IPC_SENDRECV, regFlag, &offset, &peerRmtAddrs, reinterpret_cast(cleanupQueue), NULL); - } - - if (*regFlag) - *regAddr = (void*)((uintptr_t)peerRmtAddrs + offset); - return ret; -} - static ncclResult_t getCollNetSupport(struct ncclComm* comm, struct ncclTaskColl* task, int* collNetSupport); rccl_static ncclResult_t getAlgoInfo( struct ncclComm* comm, struct ncclTaskColl* task, @@ -593,10 +292,72 @@ static bool testBudget( return ok; } +ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { + struct ncclKernelPlanner* planner = &comm->planner; + struct ncclTaskColl *task; + + task = ncclIntruQueueHead(&planner->collTaskQueue); + while (task != nullptr) { + // Build a ncclDevWorkColl[Reg?] struct for each task. + void* regBufSend[NCCL_MAX_LOCAL_RANKS]; + void* regBufRecv[NCCL_MAX_LOCAL_RANKS]; + bool regNeedConnect = true; + struct ncclWorkList* workNode = NULL; + struct ncclDevWorkColl devWork = {}; + + if (task->algorithm == NCCL_ALGO_NVLS_TREE || task->algorithm == NCCL_ALGO_NVLS) { + workNode = ncclIntruQueueDequeue(&planner->tmpCollWorkQueue); + goto next; + } + ncclRegisterCollBuffers(comm, task, regBufSend, regBufRecv, &planner->collCleanupQueue, ®NeedConnect); + + devWork.sendbuff = (void*)task->sendbuff; + devWork.recvbuff = (void*)task->recvbuff; + devWork.sendbuffOffset = task->sendbuffOffset; + devWork.recvbuffOffset = task->recvbuffOffset; + devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs; + devWork.recvbuffRmtAddrs = task->recvbuffRmtAddrs; + devWork.root = task->root; + devWork.nWarps = task->nWarps; + devWork.redOpArg = task->opDev.scalarArg; + devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr; + devWork.oneNode = (comm->nNodes == 1); + devWork.isOneRPN = comm->isOneRPN; + devWork.netRegUsed = devWork.regUsed = 0; + if (task->regBufType & NCCL_NET_REG_BUFFER) + devWork.netRegUsed = 1; + if (task->regBufType & (NCCL_IPC_REG_BUFFER | NCCL_NVLS_REG_BUFFER)) + devWork.regUsed = 1; + + if (task->regBufType & NCCL_NVLS_REG_BUFFER) { + struct ncclDevWorkCollReg workReg = {}; + workReg.coll = devWork; // C++ struct assignment + /* NVLS only has one send and recv buffer registered */ + workReg.dnInputs[0] = regBufSend[0]; + workReg.dnOutputs[0] = regBufRecv[0]; + workNode = ncclMemoryStackAllocInlineArray(&comm->memScoped, 1); + workNode->workType = ncclDevWorkTypeCollReg; + workNode->size = sizeof(struct ncclDevWorkCollReg); + memcpy((void*)(workNode+1), (void*)&workReg, workNode->size); + } else { + workNode = ncclMemoryStackAllocInlineArray(&comm->memScoped, 1); + workNode->workType = ncclDevWorkTypeColl; + workNode->size = sizeof(struct ncclDevWorkColl); + memcpy((void*)(workNode+1), (void*)&devWork, workNode->size); + } +next: + ncclIntruQueueEnqueue(&planner->collWorkQueue, workNode); + task = task->next; + } + assert(ncclIntruQueueEmpty(&planner->tmpCollWorkQueue)); + return ncclSuccess; +} + // Called once per ncclGroup to organize the user submitted tasks in // comm->planner so that they can be peeled off into plans. ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool* needConnect, ncclSimInfo_t* simInfo) { struct ncclKernelPlanner* planner = &comm->planner; + planner->persistent = ncclCudaGraphValid(planner->capturingGraph); // Tasks from the sorter come out ordered size descending. struct ncclTaskColl* task = ncclTaskCollSorterDequeueAll(&planner->collSorter); // Tasks are assembled by (fn,op,ty) size ascending. @@ -695,7 +456,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool void* regBufSend[NCCL_MAX_LOCAL_RANKS]; void* regBufRecv[NCCL_MAX_LOCAL_RANKS]; bool regNeedConnect = true; - registerCollBuffers(comm, task, regBufSend, regBufRecv, &planner->collCleanupQueue, ®NeedConnect); + ncclRegisterCollNvlsBuffers(comm, task, regBufSend, regBufRecv, &planner->collCleanupQueue, ®NeedConnect); if (comm->runtimeConn && comm->initAlgoChannels[task->algorithm] == false) { if (task->algorithm == NCCL_ALGO_NVLS_TREE && comm->initAlgoChannels[NCCL_ALGO_NVLS] == false && regNeedConnect == true) { @@ -709,34 +470,30 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool } } - struct ncclDevWorkColl devWork = {}; - devWork.sendbuff = (void*)task->sendbuff; - devWork.recvbuff = (void*)task->recvbuff; - devWork.sendbuffOffset = task->sendbuffOffset; - devWork.recvbuffOffset = task->recvbuffOffset; - devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs; - devWork.recvbuffRmtAddrs = task->recvbuffRmtAddrs; - devWork.root = task->root; - devWork.nWarps = task->nWarps; - devWork.redOpArg = task->opDev.scalarArg; - devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr; - devWork.oneNode = (comm->nNodes == 1); - devWork.regUsed = task->regBufType; - devWork.pivotA2ANumBiRings = comm->topo->pivotA2ANumBiRings; - devWork.opCount = task->opCount; + if (task->algorithm == NCCL_ALGO_NVLS_TREE || task->algorithm == NCCL_ALGO_NVLS) { + struct ncclDevWorkColl devWork = {}; + devWork.sendbuff = (void*)task->sendbuff; + devWork.recvbuff = (void*)task->recvbuff; + devWork.sendbuffOffset = task->sendbuffOffset; + devWork.recvbuffOffset = task->recvbuffOffset; + devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs; + devWork.recvbuffRmtAddrs = task->recvbuffRmtAddrs; + devWork.root = task->root; + devWork.nWarps = task->nWarps; + devWork.redOpArg = task->opDev.scalarArg; + devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr; + devWork.oneNode = (comm->nNodes == 1); + devWork.netRegUsed = devWork.regUsed = 0; + if (task->regBufType & NCCL_NET_REG_BUFFER) + devWork.netRegUsed = 1; + if (task->regBufType & (NCCL_IPC_REG_BUFFER | NCCL_NVLS_REG_BUFFER)) + devWork.regUsed = 1; + devWork.pivotA2ANumBiRings = comm->topo->pivotA2ANumBiRings; + devWork.opCount = task->opCount; - struct ncclWorkList* workNode; - switch (task->regBufType) { - case NCCL_REGULAR_BUFFER: - case NCCL_IPC_REG_BUFFER: - case NCCL_COLLNET_REG_BUFFER: - { workNode = ncclMemoryStackAllocInlineArray(&comm->memScoped, 1); - workNode->workType = ncclDevWorkTypeColl; - workNode->size = sizeof(struct ncclDevWorkColl); - memcpy((void*)(workNode+1), (void*)&devWork, workNode->size); - } break; - case NCCL_NVLS_REG_BUFFER: - { struct ncclDevWorkCollReg workReg = {}; + struct ncclWorkList* workNode; + if (task->regBufType & NCCL_NVLS_REG_BUFFER) { + struct ncclDevWorkCollReg workReg = {}; workReg.coll = devWork; // C++ struct assignment /* NVLS only has one send and recv buffer registered */ workReg.dnInputs[0] = regBufSend[0]; @@ -744,15 +501,16 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool workNode = ncclMemoryStackAllocInlineArray(&comm->memScoped, 1); workNode->workType = ncclDevWorkTypeCollReg; workNode->size = sizeof(struct ncclDevWorkCollReg); - memcpy((void*)(workNode+1), (void*)&workReg, workNode->size); - } break; - default: - /* impossible value */ - WARN("Invalid regBufType %d", task->regBufType); - return ncclInvalidArgument; - } + memcpy((void*)(workNode + 1), (void*)&workReg, workNode->size); + } else { + workNode = ncclMemoryStackAllocInlineArray(&comm->memScoped, 1); + workNode->workType = ncclDevWorkTypeColl; + workNode->size = sizeof(struct ncclDevWorkColl); + memcpy((void*)(workNode + 1), (void*)&devWork, workNode->size); + } - ncclIntruQueueEnqueue(&planner->collWorkQueue, workNode); + ncclIntruQueueEnqueue(&planner->tmpCollWorkQueue, workNode); + } task = task->next; } @@ -935,15 +693,32 @@ static ncclResult_t scheduleCollTasksToPlan( struct ncclProxyOp* proxyOp; if (c == (int)devWork->channelLo) { proxyOp = &proxyOpLo; + proxyOp->loopOffset = 0; + proxyOp->channelSize = countLo * elementSize; } else if (c == (int)devWork->channelHi) { proxyOp = &proxyOpHi; + proxyOp->loopOffset = (countLo + nMidChannels * countMid) * elementSize; + proxyOp->channelSize = countHi * elementSize; } else { proxyOp = &proxyOpMid; + proxyOp->loopOffset = (countLo + (c - devWork->channelLo - 1) * countMid) * elementSize; + proxyOp->channelSize = countMid * elementSize; } proxyOp->channelId = c; proxyOp->opCount = proxyOpId; proxyOp->task.coll = task; proxyOp->rank = comm->rank; + proxyOp->ringAlgo = NULL; + if (proxyOp->reg && task->algorithm == NCCL_ALGO_RING && (task->recvNetHandles[c] || task->sendNetHandles[c])) { + if (task->func == ncclFuncAllGather) { + proxyOp->ringAlgo = new RingAGAlgorithm(task->sendbuff, task->recvbuff, comm->nRanks, comm->channels[c].ring.userRanks, proxyOp->chunkSteps, proxyOp->sliceSteps, proxyOp->chunkSize, proxyOp->sliceSize, proxyOp->loopOffset, proxyOp->channelSize, elementSize, task->count * elementSize, task->sendNetHandles[c], task->recvNetHandles[c], task->srecvNetHandles[c]); + } else if (task->func == ncclFuncAllReduce) { + proxyOp->ringAlgo = new RingARAlgorithm(task->sendbuff, task->recvbuff, comm->nRanks, comm->channels[c].ring.index, proxyOp->chunkSteps, proxyOp->sliceSteps, proxyOp->chunkSize, proxyOp->sliceSize, proxyOp->loopOffset, proxyOp->channelSize, elementSize, task->sendNetHandles[c], task->recvNetHandles[c], task->srecvNetHandles[c]); + } else if (task->func == ncclFuncBroadcast) { + proxyOp->ringAlgo = new RingBCAlgorithm(task->sendbuff, task->recvbuff, comm->rank, task->root, comm->nRanks, comm->channels[c].ring.userRanks, proxyOp->chunkSteps, proxyOp->sliceSteps, proxyOp->chunkSize, proxyOp->sliceSize, proxyOp->loopOffset, proxyOp->channelSize, task->sendNetHandles[c], task->recvNetHandles[c], task->srecvNetHandles[c]); + } + proxyOp->ringAlgo->incRefCount(); + } proxyOp->connIndex = 0; if (task->protocol == NCCL_PROTO_SIMPLE && task->algorithm == NCCL_ALGO_RING) { if (comm->useIntraNet && nBytes > rcclParamIntraNetThreshold()) { @@ -971,6 +746,10 @@ static ncclResult_t scheduleCollTasksToPlan( } if (comm->rank == 0) { + INFO(NCCL_TUNING, "%s: %ld Bytes -> Algo %s proto %s channel{Lo..Hi}={%d..%d}", + ncclFuncToString(task->func), task->count * ncclTypeSize(task->datatype), ncclAlgoToString(task->algorithm), + ncclProtoToString(task->protocol), devWork->channelLo, devWork->channelHi); + if (task->isCollnet) { TRACE(NCCL_COLL, "Collective %s(%s, %s, %s, %s) count=%ld devFuncId=%d channel{Lo..Hi}={%d..%d} count=%ld chunkCount=%d", ncclFuncToString(task->func), ncclDevRedOpToString(task->opDev.op), @@ -1029,6 +808,7 @@ static ncclResult_t addP2pToPlan( bool protoLL[2] = {!selfSend, !selfSend}; bool network[2] = {false, false}; bool proxySameProcess[2] = {true, true}; + void** handles[2] = {NULL, NULL}; uint8_t base = ncclP2pChannelBaseForRound(comm, p2pRound); if (comm->p2pNet) { @@ -1062,7 +842,7 @@ static ncclResult_t addP2pToPlan( int chunkSize[2]; int chunkDataSize[2]; int chunkDataSize_u32fp8[2]; - bool registered[2] = {false, false}; + bool netRegistered[2] = {false, false}; bool ipcRegistered[2] = {false, false}; for (int dir=0; dir < 2; dir++) { // 0=recv, 1=send @@ -1088,10 +868,20 @@ static ncclResult_t addP2pToPlan( if (protocol[dir] == NCCL_PROTO_LL) chunkSize[dir] *= 2; if (network[dir]) { - if (bytes[dir] > 0 && proxySameProcess[dir] && protocol[dir] == NCCL_PROTO_SIMPLE) { - struct ncclReg* regRecord; - NCCLCHECK(ncclRegFind(comm, addrs[dir], bytes[dir], ®Record)); - registered[dir] = regRecord && regRecord->nDevs; + if (bytes[dir] > 0 && proxySameProcess[dir] && protocol[dir] == NCCL_PROTO_SIMPLE && (ncclPxnDisable(comm) || !comm->isAllNvlink)) { + int regFlag = 0; + NCCLCHECK(ncclCalloc(&handles[dir], nChannelsMax)); + for (int part = 0; part < nChannelsMax; part++) { + int channelId = ncclP2pChannelForPart(comm->p2pnChannels, base, part, nChannelsMax, comm->nNodes); + struct ncclChannelPeer** channelPeers = comm->channels[channelId].peers; + int peerRank = dir ? sendRank : recvRank; + struct ncclConnector* conn = dir ? &channelPeers[peerRank]->send[connIndex[dir]] + : &channelPeers[peerRank]->recv[connIndex[dir]]; + if (conn->conn.flags & NCCL_DIRECT_NIC) + ncclRegisterP2pNetBuffer(comm, addrs[dir], bytes[dir], conn, ®Flag, &handles[dir][part], &plan->cleanupQueue); + if (!regFlag) break; + } + netRegistered[dir] = regFlag ? true : false; } } else if (bytes[dir] > 0 && addrs[dir] && protocol[dir] == NCCL_PROTO_SIMPLE && !selfSend) { int peerRank = dir ? sendRank : recvRank; @@ -1101,12 +891,12 @@ static ncclResult_t addP2pToPlan( struct ncclConnector* conn = dir ? &channelPeers[peerRank]->send[connIndex[dir]] : &channelPeers[peerRank]->recv[connIndex[dir]]; void* regAddr = NULL; - if (conn->conn.flags & (NCCL_IPC_WRITE | NCCL_IPC_READ | NCCL_DIRECT_WRITE | NCCL_DIRECT_READ)) { + if (conn->conn.flags & (NCCL_P2P_WRITE | NCCL_P2P_READ)) { // We require users registering buffers on both sides - NCCLCHECK(registerP2pBuffer(comm, addrs[dir], peerRank, bytes[dir], ®Flag, ®Addr, &plan->cleanupQueue)); + NCCLCHECK(ncclRegisterP2pIpcBuffer(comm, addrs[dir], bytes[dir], peerRank, ®Flag, ®Addr, &plan->cleanupQueue)); if (regFlag) { - if (dir == 0 && conn->conn.flags & (NCCL_IPC_WRITE | NCCL_DIRECT_WRITE)) recvAddr = regAddr; - else if (dir == 1 && conn->conn.flags & (NCCL_IPC_READ | NCCL_DIRECT_READ)) sendAddr = regAddr; + if (dir == 0 && (conn->conn.flags & NCCL_P2P_WRITE)) recvAddr = regAddr; + else if (dir == 1 && (conn->conn.flags & NCCL_P2P_READ)) sendAddr = regAddr; } } ipcRegistered[dir] = regFlag ? true : false; @@ -1138,7 +928,7 @@ static ncclResult_t addP2pToPlan( work->channelBase = base; work->nSendChannels = nChannels[1]; work->sendProtoLL = protoLL[1]; - work->sendRegistered = registered[1]; + work->sendNetReg = netRegistered[1]; work->sendIpcReg = ipcRegistered[1]; work->sendChunkSize_u32fp8 = chunkDataSize_u32fp8[1]; work->sendRank = sendRank; @@ -1148,7 +938,7 @@ static ncclResult_t addP2pToPlan( work->sendOpCount = sendOpCount; work->nRecvChannels = nChannels[0]; work->recvProtoLL = protoLL[0]; - work->recvRegistered = registered[0]; + work->recvNetReg = netRegistered[0]; work->recvIpcReg = ipcRegistered[0]; work->recvChunkSize_u32fp8 = chunkDataSize_u32fp8[0]; work->recvRank = recvRank; @@ -1169,7 +959,7 @@ static ncclResult_t addP2pToPlan( op->protocol = protocol[dir]; op->pattern = dir ? ncclPatternSend : ncclPatternRecv; op->chunkSize = chunkSize[dir]; - op->reg = registered[dir]; + op->reg = netRegistered[dir]; op->coll = p2pTasks[dir] ? p2pTasks[dir]->func : 0; op->task.p2p = p2pTasks[dir]; op->rank = comm->rank; @@ -1207,9 +997,10 @@ static ncclResult_t addP2pToPlan( size_t partBeg, partEnd; ncclP2pPartBounds(nParts, part, bytes, &partBeg, &partEnd); if (proxyOps[dir].reg) { - proxyOps[dir].nsteps = 1; - proxyOps[dir].recvbuff = (uint8_t*)addr+partBeg; - proxyOps[dir].nbytes = partEnd-partBeg; + (dir ? proxyOps[dir].sendbuff : proxyOps[dir].recvbuff) = (uint8_t*)addr + partBeg; + (dir ? proxyOps[dir].sendMhandle : proxyOps[dir].recvMhandle) = handles[dir][part]; + proxyOps[dir].nbytes = partEnd - partBeg; + proxyOps[dir].nsteps = DIVUP(proxyOps[dir].nbytes, NCCL_MAX_NET_SIZE); } else { proxyOps[dir].nsteps = divUp(partEnd-partBeg, chunkDataSize); proxyOps[dir].nbytes = std::min(partEnd-partBeg, chunkDataSize); @@ -1289,6 +1080,8 @@ static ncclResult_t scheduleP2pTasksToPlan( // Skip send to self in-place (we don't need to support this). ncclIntruQueueDequeue(&peers[sendRank].sendQueue); ncclIntruQueueDequeue(&peers[recvRank].recvQueue); + ncclMemoryPoolFree(&comm->memPool_ncclTaskP2p, send); + ncclMemoryPoolFree(&comm->memPool_ncclTaskP2p, recv); comm->planner.nTasksP2p -= 2; } else { // Ensure room for worst case of one new batch per channel. @@ -1393,7 +1186,13 @@ static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* pla plan->kernelArgs->workBuf = comm->workFifoBufDev; break; case ncclDevWorkStorageTypePersistent: - fifoBufHost = aligned_alloc(16, workBytes); // We rely on 16-byte alignment + // We rely on 16-byte alignment + #if __cplusplus >= 201103L + fifoBufHost = aligned_alloc(16, ROUNDUP(workBytes, 16)); + #else + static_assert(16 <= alignof(max_align_t), "We rely on 16-byte alignment."); + fifoBufHost = malloc(workBytes); + #endif fifoCursor = 0; fifoMask = ~0u; break; @@ -1436,37 +1235,41 @@ static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* pla break; case ncclDevWorkStorageTypePersistent: { ncclResult_t result = ncclSuccess; + struct uploadWork_cleanup_t* cleanup = nullptr; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; void* fifoBufDev = nullptr; - CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + CUDACHECKGOTO(cudaThreadExchangeStreamCaptureMode(&mode), result, fail); // Acquire deviceStream to gain access to deviceStream.cudaStream. Since the // user's graph will be launched later, and it also acquires the deviceStream, // it will observe this upload. - NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->deviceStream), result, finish_scope); + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->deviceStream), result, fail); - CUDACHECKGOTO(cudaMallocAsync(&fifoBufDev, workBytes, comm->memPool, comm->sharedRes->deviceStream.cudaStream), result, finish_scope); + CUDACHECKGOTO(cudaMallocAsync(&fifoBufDev, workBytes, comm->memPool, comm->sharedRes->deviceStream.cudaStream), result, fail); plan->workBufPersistent = fifoBufDev; plan->kernelArgs->workBuf = fifoBufDev; - CUDACHECKGOTO(cudaMemcpyAsync(fifoBufDev, fifoBufHost, workBytes, cudaMemcpyDefault, comm->sharedRes->deviceStream.cudaStream), result, finish_scope); + // coverity[uninit_use_in_call:FALSE] => fifoBufHost is never NULL + CUDACHECKGOTO(cudaMemcpyAsync(fifoBufDev, fifoBufHost, workBytes, cudaMemcpyDefault, comm->sharedRes->deviceStream.cudaStream), result, fail); cudaEvent_t memcpyDone; - CUDACHECKGOTO(cudaEventCreateWithFlags(&memcpyDone, cudaEventDisableTiming), result, finish_scope); - CUDACHECKGOTO(cudaEventRecord(memcpyDone, comm->sharedRes->deviceStream.cudaStream), result, finish_scope); + CUDACHECKGOTO(cudaEventCreateWithFlags(&memcpyDone, cudaEventDisableTiming), result, fail); + CUDACHECKGOTO(cudaEventRecord(memcpyDone, comm->sharedRes->deviceStream.cudaStream), result, fail); - struct uploadWork_cleanup_t* cleanup; - NCCLCHECK(ncclCalloc(&cleanup, 1)); + NCCLCHECKGOTO(ncclCalloc(&cleanup, 1), result, fail); cleanup->base.fn = uploadWork_cleanup_fn; cleanup->base.event = memcpyDone; cleanup->hostBuf = fifoBufHost; - ncclIntruQueueEnqueue(&comm->eventCallbackQueue, &cleanup->base); + ncclIntruQueueEnqueue(&comm->eventCallbackQueue, (struct ncclCommEventCallback *)cleanup); - NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->deviceStream), result, finish_scope); - NCCLCHECKGOTO(ncclCommPollEventCallbacks(comm), result, finish_scope); + NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->deviceStream), result, fail); + NCCLCHECKGOTO(ncclCommPollEventCallbacks(comm), result, fail); finish_scope: - CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - if (result != ncclSuccess) return result; + if (mode != cudaStreamCaptureModeRelaxed) (void)cudaThreadExchangeStreamCaptureMode(&mode); + return result; + fail: + if (!cleanup) free(fifoBufHost); + goto finish_scope; } break; default: break; } @@ -1478,6 +1281,7 @@ static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* uint64_t p2pOpBump[MAXCHANNELS] = {/*0...*/}; // Advance comm's collOpCount by number of colls in this plan. comm->sharedRes->collOpCount += plan->collOpCount; + comm->collOpCount += plan->collOpCount; struct ncclProxyOp* op = ncclIntruQueueHead(&plan->proxyOpQueue); while (op != nullptr) { @@ -1500,18 +1304,9 @@ static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* NCCLCHECK(ncclProxySaveOp(comm, op, nullptr)); op->opCount = oldId; // Restore for next uploadProxyOps() - - struct ncclProxyOp* opNext = op->enqNext; - if (!plan->persistent) { - // Non-persistent kernels upload ops only once so can be free'd here. - ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, op); - } - op = opNext; + op = op->enqNext; } - // Erase proxyOpQueue since all ops were free'd back to mempool. - if (!plan->persistent) ncclIntruQueueConstruct(&plan->proxyOpQueue); - for (int c=0; c < MAXCHANNELS; c++) { // Advance channel's p2pOpCount by number of p2p's in this plan channel. comm->sharedRes->p2pOpCount[c] += p2pOpBump[c]; @@ -1540,6 +1335,8 @@ static void HIPRT_CB hostStreamPlanCallback(void *plan_) { if (result != ncclSuccess) { WARN("hostStreamPlanCallback() failed : %s", ncclGetErrorString(result)); } + if (!plan->persistent) ncclAtomicRefCountDecrement(&plan->comm->noncapturedRefs); + return; } static ncclResult_t reclaimPlan(struct ncclComm* comm, struct ncclCommCallback* me) { @@ -1547,20 +1344,24 @@ static ncclResult_t reclaimPlan(struct ncclComm* comm, struct ncclCommCallback* if (plan->persistent) { comm->persistentRefs -= 1; NCCLCHECK(ncclCudaFree(plan->workBufPersistent)); - struct ncclProxyOp* q = ncclIntruQueueHead(&plan->proxyOpQueue); - while (q != nullptr) { - struct ncclProxyOp* q1 = q->enqNext; - ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, q); - q = q1; - } - ncclResult_t result = ncclSuccess; - while (!ncclIntruQueueEmpty(&plan->cleanupQueue)) { - struct ncclCommCallback* cb = ncclIntruQueueDequeue(&plan->cleanupQueue); - ncclResult_t res1 = cb->fn(comm, cb); // Expect to reclaim memory of cb - if (res1 != ncclSuccess) result = res1; - } - NCCLCHECK(result); } + // Free proxy ops + struct ncclProxyOp* q = ncclIntruQueueHead(&plan->proxyOpQueue); + while (q != nullptr) { + struct ncclProxyOp* q1 = q->enqNext; + if (q->ringAlgo && q->ringAlgo->decRefCount() == 0) delete q->ringAlgo; + ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, q); + q = q1; + } + // Run other free callbacks + ncclResult_t result = ncclSuccess; + while (!ncclIntruQueueEmpty(&plan->cleanupQueue)) { + struct ncclCommCallback* cb = ncclIntruQueueDequeue(&plan->cleanupQueue); + ncclResult_t res1 = cb->fn(comm, cb); // Expect to reclaim memory of cb + if (res1 != ncclSuccess) result = res1; + } + NCCLCHECK(result); + // Free plan struct ncclMemoryPoolFree(&comm->memPool_ncclKernelPlan, plan); return ncclSuccess; } @@ -1582,10 +1383,6 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { planner->persistent = persistent; int nPlans = 0; - // Poll for callbacks sent to us from other threads. Typically these free - // resources from to our memory pools. - NCCLCHECK(ncclCommPollCallbacks(comm, /*waitSome=*/false)); - if (planner->nTasksColl + planner->nTasksP2p != 0) { do { memset(&planner->wipPlan, 0, sizeof(planner->wipPlan)); @@ -1655,7 +1452,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { CUDACHECK(hipStreamWaitEvent(planner->streams->stream, comm->doneEvent, 0)); } - if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking) { + if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking || __atomic_load_n(&comm->noncapturedRefs, __ATOMIC_ACQUIRE)) { // We have to launch host tasks to push proxy args. We are careful to only // do this if necessary since host tasks impose a high performance cost in CUDA. bool acquired = false; @@ -1665,6 +1462,8 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { acquired = true; NCCLCHECKGOTO(ncclStrongStreamAcquire(planner->capturingGraph, &comm->sharedRes->hostStream), result, failure); } + if (!persistent) ncclAtomicRefCountIncrement(&comm->noncapturedRefs); + plan->isHostCbEnq = true; NCCLCHECKGOTO(ncclStrongStreamLaunchHost(planner->capturingGraph, &comm->sharedRes->hostStream, hostStreamPlanCallback, plan), result, failure); } } @@ -1680,6 +1479,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { NCCLCHECKGOTO(ncclCudaGraphAddDestructor(planner->capturingGraph, persistentDestructor, (void*)planHead), result, failure); } } + failure: return result; } @@ -1776,7 +1576,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan } ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKernelPlan* plan) { - if (!(plan->persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking)) { + if (!(plan->persistent || ncclCudaLaunchBlocking || plan->isHostCbEnq)) { // We are not using the host stream for proxy ops and reclaimation submission. NCCLCHECK(hostStreamPlanTask(comm, plan)); } else { @@ -1862,8 +1662,7 @@ static void initCollCostTable(float** collCostTable) { static ncclResult_t updateCollCostTable( struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes, int collNetSupport, int nvlsSupport, int numPipeOps, - float** collCostTable, int* backupAlgo, int* backupProto, float* backupTime - ) { + float** collCostTable) { float (*table)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable; if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot) { @@ -1887,16 +1686,12 @@ static ncclResult_t updateCollCostTable( table[a][p] = NCCL_ALGO_PROTO_IGNORE; continue; } - bool backup; - float time; - NCCLCHECK(ncclTopoGetAlgoTime(comm, info->func, a, p, nBytes, numPipeOps, &time, &backup)); - if (!backup) { - table[a][p] = time; - } else { - if (time >= 0.0 && time < *backupTime) { - *backupAlgo = a; - *backupProto = p; - *backupTime = time; + NCCLCHECK(ncclTopoGetAlgoTime(comm, info->func, a, p, nBytes, numPipeOps, &table[a][p])); + // Relegate fp8 reduction trees of sufficient depth that they incur precision loss + // to be least preferred. + if (info->datatype == ncclFloat8e4m3 || info->datatype == ncclFloat8e5m2) { + if (a == NCCL_ALGO_RING && comm->nRanks > 8) { + table[a][p] *= 1024.0; // Any factor large enough to act as a partition between lossy and non-lossy algos. } } } @@ -1907,7 +1702,7 @@ static ncclResult_t updateCollCostTable( static ncclResult_t topoGetAlgoInfo( struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes, - float** collCostTable, int backupAlgo, int backupProto, float backupTime, ncclSimInfo_t* simInfo + float** collCostTable, ncclSimInfo_t* simInfo ) { float (*table)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable; @@ -1932,18 +1727,23 @@ static ncclResult_t topoGetAlgoInfo( // Yes, we are first assigning and then testing if protocol is sane, but that's OK in this case. // coverity[check_after_sink] if (info->algorithm == NCCL_ALGO_UNDEF || info->protocol == NCCL_PROTO_UNDEF) { - if (backupAlgo == NCCL_ALGO_UNDEF || backupProto == NCCL_PROTO_UNDEF) { - WARN("Error : no algorithm/protocol available"); - return ncclInternalError; + char ncclAlgoEnvStr[1024] = ""; + char ncclProtoEnvStr[1024] = ""; + char* algoEnv = getenv("NCCL_ALGO"); + if (algoEnv) { + snprintf(ncclAlgoEnvStr, 1023, " NCCL_ALGO was set to %s.", algoEnv); } - info->algorithm = backupAlgo; - info->protocol = backupProto; - time = backupTime; + char* protoEnv = getenv("NCCL_PROTO"); + if (protoEnv) { + snprintf(ncclProtoEnvStr, 1023, " NCCL_PROTO was set to %s.", protoEnv); + } + WARN("Error : no algorithm/protocol available for function %s with datatype %s.%s%s", ncclFuncToString(info->func), ncclDatatypeToString(info->datatype), ncclAlgoEnvStr, ncclProtoEnvStr); + return (algoEnv || protoEnv) ? ncclInvalidUsage : ncclInternalError; } rcclUpdateCollectiveProtocol(comm, nBytes, info); - if (comm->rank == 0) INFO(NCCL_TUNING, "%s: %ld Bytes -> Algo %d proto %d time %f", ncclFuncToString(info->func), nBytes, info->algorithm, info->protocol, time); if (simInfo) simInfo->estimatedTime = time; TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", nBytes, info->algorithm, info->protocol, time); + int nc = comm->nChannels; int nt = comm->maxThreads[info->algorithm][info->protocol]; int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol]; @@ -2030,19 +1830,24 @@ rccl_static ncclResult_t getAlgoInfo( info->algorithm = NCCL_ALGO_UNDEF; info->protocol = NCCL_PROTO_UNDEF; int nMaxChannels = 0; - int backupAlgo = NCCL_ALGO_UNDEF; - int backupProto = NCCL_PROTO_UNDEF; - float backupTime = 3600000000.0; float collCostTable[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; initCollCostTable((float **)collCostTable); - NCCLCHECK(updateCollCostTable(comm, info, nBytes, collNetSupport, nvlsSupport, numPipeOps, (float **)collCostTable, &backupAlgo, &backupProto, &backupTime)); + NCCLCHECK(updateCollCostTable(comm, info, nBytes, collNetSupport, nvlsSupport, numPipeOps, (float **)collCostTable)); if (comm->tuner != NULL) { + size_t elementSize = ncclTypeSize(info->datatype); + size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); + size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); + struct ncclReg* regSendBuf; + struct ncclReg* regRecvBuf; + NCCLCHECK(ncclRegFind(comm, info->sendbuff, sendbuffSize, ®SendBuf)); + NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvbuffSize, ®RecvBuf)); + int regBuff = ((regSendBuf && regRecvBuf) || (ncclCudaGraphValid(comm->planner.capturingGraph) && ncclParamGraphRegister())); NCCLCHECK(comm->tuner->getCollInfo( comm->tunerContext, info->func, nBytes, numPipeOps, (float **)collCostTable, NCCL_NUM_ALGORITHMS, NCCL_NUM_PROTOCOLS, - &nMaxChannels)); + regBuff, &nMaxChannels)); } - NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, backupAlgo, backupProto, backupTime, simInfo)); + NCCLCHECK(topoGetAlgoInfo(comm, info, nBytes, (float **)collCostTable, simInfo)); info->nMaxChannels = nMaxChannels == 0 ? info->nMaxChannels : nMaxChannels; return ncclSuccess; } @@ -2095,37 +1900,7 @@ static ncclResult_t calcCollChunking( } int nstepsPerLoop, nchunksPerLoop; - switch (pattern) { - case ncclPatternTreeUp: - case ncclPatternTreeDown: - case ncclPatternTreeUpDown: - case ncclPatternPatUp: - case ncclPatternPatDown: - case ncclPatternPipelineFrom: - case ncclPatternPipelineTo: - case ncclPatternCollnetChain: - nstepsPerLoop = nchunksPerLoop = 1; - break; - case ncclPatternNvls: - nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].nvls.nHeads; - break; - case ncclPatternCollnetDirect: - nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].collnetDirect.nHeads; - break; - case ncclPatternRing: - nstepsPerLoop = comm->nRanks-1; nchunksPerLoop = comm->nRanks; - break; - case ncclPatternRingTwice: - nstepsPerLoop = 2*(comm->nRanks-1); nchunksPerLoop = comm->nRanks; - break; - case ncclPatternNvlsTree: - nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].nvls.nHeads; - break; - default: - WARN("Unknown pattern %d", pattern); - return ncclInternalError; - } - + size_t loopOffset = 0; int stepSize = comm->buffSizes[info->protocol]/NCCL_STEPS; int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1; int sliceSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->sliceSteps : 1; @@ -2201,22 +1976,60 @@ static ncclResult_t calcCollChunking( // Compute directFlags of work struct. if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { // Set direct direction for broadcast-gather (read or write) - *outDirectFlags = (nBytes/nChannels <= 1024 * 4) ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + *outDirectFlags = (nBytes/nChannels <= 1024 * 4) ? NCCL_P2P_READ : NCCL_P2P_WRITE; } else { *outDirectFlags = 0; } // Compute nSteps for proxies - //if (comm->rank == 0) printf("Coll %d, size %ld -> %dx%d, chunkSize %d (algo %d proto%d)\n", info->func, info->nBytes, info->nChannels, info->nThreads, chunkSize, info->algorithm, info->protocol); chunkSize = chunkSize / grainSize * grainSize; // align chunkSize to multiple grainSize - int nLoops = (int)DIVUP(nBytes, size_t(nChannels)*nchunksPerLoop*chunkSize); + switch (pattern) { + case ncclPatternTreeUp: + case ncclPatternTreeDown: + case ncclPatternTreeUpDown: + case ncclPatternPatUp: + case ncclPatternPatDown: + case ncclPatternPipelineFrom: + case ncclPatternPipelineTo: + case ncclPatternCollnetChain: + nstepsPerLoop = nchunksPerLoop = 1; + break; + case ncclPatternNvls: + nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].nvls.nHeads; + loopOffset = nChannels * chunkSize * comm->channels[0].nvls.headRank; + break; + case ncclPatternCollnetDirect: + nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].collnetDirect.nHeads; + loopOffset = nChannels * chunkSize * comm->channels[0].collnetDirect.headRank; + break; + case ncclPatternRing: + nstepsPerLoop = comm->nRanks-1; nchunksPerLoop = comm->nRanks; + break; + case ncclPatternRingTwice: + nstepsPerLoop = 2*(comm->nRanks-1); nchunksPerLoop = comm->nRanks; + break; + case ncclPatternNvlsTree: + nstepsPerLoop = 1; nchunksPerLoop = comm->channels[0].nvls.nHeads; + break; + default: + WARN("Unknown pattern %d", pattern); + return ncclInternalError; + } + + // Compute nSteps for proxies + size_t loopSize = size_t(nChannels)*nchunksPerLoop*chunkSize; + int nLoops = (int)DIVUP(nBytes, loopSize); memset(proxyOp, 0, sizeof(*proxyOp)); proxyOp->nsteps = nstepsPerLoop * nLoops * chunkSteps; proxyOp->sliceSteps = sliceSteps; proxyOp->chunkSteps = chunkSteps; proxyOp->chunkSize = chunkSize; + proxyOp->sliceSize = chunkSize / chunkSteps * sliceSteps; + proxyOp->loopSize = loopSize; + proxyOp->loopOffset = loopOffset; proxyOp->protocol = info->protocol; proxyOp->dtype = info->datatype; + proxyOp->algorithm = info->algorithm; if (info->opDev.op == ncclDevPreMulSum || info->opDev.op == ncclDevSumPostDiv) { proxyOp->redOp = ncclSum; // Network sees avg as sum } else { @@ -2225,17 +2038,50 @@ static ncclResult_t calcCollChunking( proxyOp->pattern = pattern; proxyOp->coll = info->func; proxyOp->root = info->root; + proxyOp->isOneRPN = comm->isOneRPN; // This is used by P2P to reduce the receive buffer size. We don't use it in collectives // because some protocols need to transmit more than the total size, plus they sometimes // round up proxyOp->nbytes = stepSize*sliceSteps; - if (info->regBufType == NCCL_COLLNET_REG_BUFFER) { + if (info->regBufType & NCCL_NET_REG_BUFFER) { proxyOp->reg = 1; - proxyOp->nsteps = DIVUP(nBytes, NCCL_MAX_COLLNET_SIZE); - proxyOp->sendMhandle = info->sendMhandle; + if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT || info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) { + if (proxyOp->isOneRPN) { + proxyOp->nsteps = 1; + proxyOp->loopOffset = 0; + proxyOp->sendbuff = (uint8_t*)info->sendbuff; + proxyOp->sendMhandle = info->sendMhandle; + } else { + if (info->func == ncclFuncAllGather || info->func == ncclFuncReduceScatter) { + proxyOp->nbytes = nBytes / nchunksPerLoop; + proxyOp->loopSize = proxyOp->loopSize / nchunksPerLoop; + proxyOp->loopOffset = 0; + if (info->func == ncclFuncAllGather) { + proxyOp->sendbuff = (uint8_t*)info->sendbuff; + proxyOp->sendMhandle = info->sendMhandle; + } + } else { + proxyOp->sendbuff = (uint8_t*)info->recvbuff; + proxyOp->sendMhandle = info->recvMhandle; + } + } + } else if (info->algorithm == NCCL_ALGO_RING) { + if (proxyOp->isOneRPN && info->func == ncclFuncAllGather) { + proxyOp->chunkSize = NCCL_MAX_NET_SIZE; + proxyOp->sliceSize = NCCL_MAX_NET_SIZE; + proxyOp->chunkSteps = 1; + proxyOp->sliceSteps = 1; + proxyOp->loopSize = size_t(nChannels) * nchunksPerLoop * proxyOp->chunkSize; + proxyOp->nsteps = DIVUP(nBytes, proxyOp->loopSize) * nstepsPerLoop; + proxyOp->loopOffset = 0; + } + } else { + WARN("Net registration invalid algorithm %s", ncclAlgoToString(info->algorithm)); + return ncclInternalError; + } + proxyOp->recvMhandle = info->recvMhandle; - proxyOp->sendbuff = (uint8_t*)info->sendbuff; proxyOp->recvbuff = (uint8_t*)info->recvbuff; proxyOp->nbytes = nBytes; } else { @@ -2254,7 +2100,7 @@ static ncclResult_t calcCollChunking( proxyOp->nbytes = DIVUP(nBytes, nChannels); } - *outChunkSize = chunkSize; + *outChunkSize = proxyOp->chunkSize; return ncclSuccess; } @@ -2262,22 +2108,17 @@ static ncclResult_t hostToDevRedOp( ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm ) { union { - int8_t i8; - uint8_t u8; - int32_t i32; - uint32_t u32; - int64_t i64; - uint64_t u64; - half f16; - float f32; - double f64; -#if defined(RCCL_BFLOAT16) - hip_bfloat16 bf16; -#endif -#if defined(RCCL_FLOAT8) - rccl_float8 fp8_e4m3; - rccl_bfloat8 fp8_e5m2; -#endif + int8_t i8; uint8_t u8; + int32_t i32; uint32_t u32; + int64_t i64; uint64_t u64; + __half f16; float f32; double f64; + #if defined(RCCL_BFLOAT16) + hip_bfloat16 bf16; + #endif + #if defined(RCCL_FLOAT8) + rccl_float8 f8; + rccl_bfloat8 bf8; + #endif void *ptr; }; u64 = 0; @@ -2288,7 +2129,8 @@ static ncclResult_t hostToDevRedOp( if (nbits <= 0) return ncclInvalidArgument; uint64_t allBits = uint64_t(-1)>>(64-nbits); uint64_t signBit = allBits^(allBits>>1); - + bool datatype_signed = false; + switch (int(op)) { case ncclSum: opFull->op = ncclDevSum; break; case ncclProd: opFull->op = ncclDevProd; break; @@ -2306,30 +2148,32 @@ static ncclResult_t hostToDevRedOp( case ncclAvg: switch ((int)datatype) { case ncclInt8: case ncclInt32: case ncclInt64: + datatype_signed = true; + // no break, we want to fall through... case ncclUint8: case ncclUint32: case ncclUint64: opFull->op = ncclDevSumPostDiv; - u64 = comm->nRanks; + u64 = comm->nRanks<<1 | datatype_signed; break; + #if defined(RCCL_FLOAT8) + case ncclFloat8e4m3: + opFull->op = ncclDevPreMulSum; + f8 = static_cast(float(1.0/comm->nRanks)); + break; + case ncclFloat8e5m2: + opFull->op = ncclDevPreMulSum; + bf8 = static_cast(float(1.0/comm->nRanks)); + break; + #endif case ncclFloat16: opFull->op = ncclDevPreMulSum; f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x break; -#if defined(RCCL_BFLOAT16) + #if defined(RCCL_BFLOAT16) case ncclBfloat16: opFull->op = ncclDevPreMulSum; bf16 = (hip_bfloat16)(float(1.0/comm->nRanks)); break; -#endif -#if defined(RCCL_FLOAT8) - case ncclFp8E4M3: - opFull->op = ncclDevPreMulSum; - fp8_e4m3 = static_cast(float(1.0/comm->nRanks)); - break; - case ncclFp8E5M2: - opFull->op = ncclDevPreMulSum; - fp8_e5m2 = static_cast(float(1.0/comm->nRanks)); - break; -#endif + #endif case ncclFloat32: opFull->op = ncclDevPreMulSum; f32 = float(1.0/comm->nRanks); @@ -2423,6 +2267,13 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { // Empty collectives can be discarded. if (info->count == 0) return ncclSuccess; + if (info->datatype == ncclFloat8e4m3 || info->datatype == ncclFloat8e5m2) { + if (comm->minCompCap < 90) { + WARN("FP8 reduction support begins with sm90 capable devices."); + return ncclInvalidArgument; + } + } + // Copy reduction op state from op handle into info struct here since the // op handle may be destroyed before ncclGroupEnd(). struct ncclDevRedOpFull opDev; diff --git a/src/graph/paths.cc b/src/graph/paths.cc index ff96cdba39..73081ba555 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -253,11 +253,31 @@ ncclResult_t ncclGetLevel(int* level, const char* disableEnv, const char* levelE NCCL_PARAM(IgnoreDisabledP2p, "IGNORE_DISABLED_P2P", 0); int ncclTopoUserP2pLevel = -1; -ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int rank1, int rank2, int* p2p, int *read, int* intermediateRank) { +ncclResult_t ncclTopoCheckP2p(struct ncclComm* comm, struct ncclTopoSystem* system, int rank1, int rank2, + int* p2p, int *read, int* intermediateRank) { + int mnnvl = 0; + struct ncclPeerInfo* info1 = NULL; + struct ncclPeerInfo* info2 = NULL; *p2p = 0; if (read) *read = 0; if (intermediateRank) *intermediateRank = -1; + // Rule out different nodes / isolated containers + if (comm) { + info1 = comm->peerInfo+rank1; + info2 = comm->peerInfo+rank2; + if (info1->hostHash != info2->hostHash) { + if (comm->MNNVL) { + NCCLCHECK(ncclTopoCheckMNNVL(comm->topo, info1, info2, &mnnvl)); + if (!mnnvl) return ncclSuccess; + } else { + return ncclSuccess; + } + } else if (info1->shmDev != info2->shmDev) { + return ncclSuccess; + } + } + // Get GPUs from topology int g1, g2; NCCLCHECK(ncclTopoRankToIndex(system, rank1, &g1)); @@ -304,7 +324,8 @@ compare: if (*p2p == 1) { // NCCL_IGNORE_DISABLED_P2P=2 is used by unit tests that don't want to // validate against NVML at all since they are pretending to be on other hw. - if (g1 != g2 && ncclParamIgnoreDisabledP2p() != 2) { + if (g1 != g2 && (comm == NULL || (info1->hostHash == comm->peerInfo[comm->rank].hostHash && + info1->hostHash == info2->hostHash)) && ncclParamIgnoreDisabledP2p() != 2) { int indexes[3] = {-1,-1,-1}; int verticeN = 0; NCCLCHECK(ncclNvmlEnsureInitialized()); @@ -364,14 +385,14 @@ ncclResult_t ncclTopoCheckMNNVL(struct ncclTopoSystem* system, struct ncclPeerIn NCCL_PARAM(NetGdrRead, "NET_GDR_READ", -2); int ncclTopoUserGdrLevel = -1; -ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* system, int64_t busId, int64_t netId, int read, int* useGdr) { +ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* system, int rank, int64_t netId, int read, int* useGdr) { *useGdr = 0; // Get GPU and NET int n, g; NCCLCHECK(ncclTopoIdToIndex(system, NET, netId, &n)); struct ncclTopoNode* net = system->nodes[NET].nodes+n; - NCCLCHECK(ncclTopoIdToIndex(system, GPU, busId, &g)); + NCCLCHECK(ncclTopoRankToIndex(system, rank, &g)); struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g; // Check that both the NIC and GPUs support it @@ -433,12 +454,32 @@ ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* system, int64_t busId, int6 distance = proxyGpu->paths[NET][n].type; } if (distance > netGdrLevel) { - INFO(NCCL_NET,"GPU Direct RDMA Disabled for GPU %lx / HCA %lx (distance %d > %d)", busId, netId, distance, netGdrLevel); + INFO(NCCL_NET,"GPU Direct RDMA Disabled for GPU %d / HCA %lx (distance %d > %d)", rank, netId, distance, netGdrLevel); return ncclSuccess; } *useGdr = 1; - INFO(NCCL_NET,"GPU Direct RDMA Enabled for GPU %lx / HCA %lx (distance %d <= %d), read %d", busId, netId, distance, netGdrLevel, read); + INFO(NCCL_NET,"GPU Direct RDMA Enabled for GPU %d / HCA %lx (distance %d <= %d), read %d", rank, netId, distance, netGdrLevel, read); + return ncclSuccess; +} + +ncclResult_t ncclTopoIsGdrAvail(struct ncclTopoSystem* system, int rank, bool *avail) { + int netNum = system->nodes[NET].count; + int useGdr = 0; + *avail = false; + for (int n = 0; n < netNum; n++) { + int64_t netId = system->nodes[NET].nodes[n].id; + NCCLCHECK(ncclTopoCheckGdr(system, rank, netId, 1, &useGdr)); + if (useGdr) { + *avail = true; + break; + } + NCCLCHECK(ncclTopoCheckGdr(system, rank, netId, 0, &useGdr)); + if (useGdr) { + *avail = true; + break; + } + } return ncclSuccess; } @@ -446,15 +487,20 @@ ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* system, int64_t busId, int6 NCCL_PARAM(NetForceFlush, "NET_FORCE_FLUSH", 0); // Determine whether we need to flush the GDR recv buffers -ncclResult_t ncclTopoNeedFlush(struct ncclTopoSystem* system, int64_t busId, int* flush) { +ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int netDev, int rank, int* flush) { + *flush = 1; + ncclNetProperties_t props; + NCCLCHECK(comm->ncclNet->getProperties(netDev, &props)); + if (props.forceFlush == 1 || ncclParamNetForceFlush()) return ncclSuccess; int g; - NCCLCHECK(ncclTopoIdToIndex(system, GPU, busId, &g)); + struct ncclTopoSystem* system = comm->topo; + NCCLCHECK(ncclTopoRankToIndex(system, rank, &g)); struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) *flush = 1; #else // Flush is required on Ampere and earlier - *flush = gpu->gpu.cudaCompCap < 90 ? 1 : ncclParamNetForceFlush(); + if (gpu->gpu.cudaCompCap >= 90) *flush = 0; #endif return ncclSuccess; } @@ -549,7 +595,7 @@ ncclResult_t ncclTopoGetPxnRanks(struct ncclComm* comm, int** intermediateRanks, NCCLCHECK(ncclTopoGetNetDev(comm, comm->rank, NULL, 0, rank, &netId, NULL, &proxyRank)); if (proxyRank == comm->rank) continue; int useGdr; - NCCLCHECK(ncclTopoCheckGdr(comm->topo, comm->busId, netId, 1, &useGdr)); + NCCLCHECK(ncclTopoCheckGdr(comm->topo, comm->rank, netId, 1, &useGdr)); if (useGdr == 0) continue; int found = 0; for (int r=0; rnodes[GPU].count; g++) { for (int p=0; pnodes[GPU].count; p++) { int p2p; - NCCLCHECK(ncclTopoCheckP2p(system, system->nodes[GPU].nodes[p].gpu.rank, system->nodes[GPU].nodes[g].gpu.rank, &p2p, NULL, NULL)); + NCCLCHECK(ncclTopoCheckP2p(comm, system, system->nodes[GPU].nodes[p].gpu.rank, + system->nodes[GPU].nodes[g].gpu.rank, &p2p, NULL, NULL)); if (p2p == 0) { // Divert all traffic through the CPU int cpu; @@ -695,7 +742,7 @@ ncclResult_t ncclTopoComputePaths(struct ncclTopoSystem* system, struct ncclComm if (gpu->paths[NET][n].type < PATH_PHB) { // Update path when we dont want to / can't use GPU Direct RDMA. int gdr; - NCCLCHECK(ncclTopoCheckGdr(system, system->nodes[GPU].nodes[g].id, netNode->id, 0, &gdr)); + NCCLCHECK(ncclTopoCheckGdr(system, system->nodes[GPU].nodes[g].gpu.rank, netNode->id, 0, &gdr)); if (gdr == 0) { // We cannot use GPU Direct RDMA, divert all traffic through the CPU local to the GPU int localCpu; @@ -787,7 +834,7 @@ ncclResult_t ncclTopoTrimSystem(struct ncclTopoSystem* system, struct ncclComm* for (int g = 0; g < system->nodes[GPU].count; g++) { int64_t netId; NCCLCHECKGOTO(ncclTopoGetLocalNet(system, system->nodes[GPU].nodes[g].gpu.rank, 0, &netId, nullptr), ret, fail); - NCCLCHECKGOTO(ncclTopoCheckGdr(system, system->nodes[GPU].nodes[g].id, netId, 1, &gdr), ret, fail); + NCCLCHECKGOTO(ncclTopoCheckGdr(system, system->nodes[GPU].nodes[g].gpu.rank, netId, 1, &gdr), ret, fail); if (!gdr) break; } if (gdr && !allXgmi) { diff --git a/src/graph/search.cc b/src/graph/search.cc index 73294066dd..5cd44e9f28 100644 --- a/src/graph/search.cc +++ b/src/graph/search.cc @@ -1260,7 +1260,7 @@ ncclResult_t ncclTopoPrintGraph(struct ncclTopoSystem* system, struct ncclTopoGr } } if (system->nodes[NET].count > 0 && system->nodes[GPU].count != system->nRanks && !graph->nIntraChannels) { - sprintf(line+offset, " %s/%lx-%lx", topoNodeTypeStr[NET], NCCL_TOPO_ID_SYSTEM_ID(graph->inter[2*c+1]), NCCL_TOPO_ID_LOCAL_ID(graph->inter[2*c])); + sprintf(line+offset, " %s/%lx-%lx", topoNodeTypeStr[NET], NCCL_TOPO_ID_SYSTEM_ID(graph->inter[2*c+1]), NCCL_TOPO_ID_LOCAL_ID(graph->inter[2*c+1])); offset = strlen(line); } INFO(NCCL_GRAPH, "%s", line); diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 365e440509..54266dddf6 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -302,7 +302,7 @@ static ncclResult_t ncclTopoPrintRec(struct ncclTopoNode* node, struct ncclTopoN NCCLCHECK(ncclTopoPrintRec(link->remNode, node, line, nextOffset)); } else { if (link->remNode->type == NET) { - sprintf(line+nextOffset, "%s/%lx-%lx (%lx/%d/%f)", topoNodeTypeStr[link->remNode->type], NCCL_TOPO_ID_SYSTEM_ID(link->remNode->id), NCCL_TOPO_ID_LOCAL_ID(link->remNode->id), link->remNode->net.asic, link->remNode->net.port, link->remNode->net.bw); + sprintf(line+nextOffset, "%s/%lx-%lx (%d/%lx/%d/%f)", topoNodeTypeStr[link->remNode->type], NCCL_TOPO_ID_SYSTEM_ID(link->remNode->id), NCCL_TOPO_ID_LOCAL_ID(link->remNode->id), link->remNode->net.collSupport, link->remNode->net.asic, link->remNode->net.port, link->remNode->net.bw); } else { sprintf(line+nextOffset, "%s/%lx-%lx", topoNodeTypeStr[link->remNode->type], NCCL_TOPO_ID_SYSTEM_ID(link->remNode->id), NCCL_TOPO_ID_LOCAL_ID(link->remNode->id)); } @@ -390,6 +390,7 @@ ncclResult_t ncclTopoAddNic(struct ncclXmlNode* xmlNic, struct ncclTopoSystem* s if (strcmp(xmlNet->name, "net") != 0) continue; int index; NCCLCHECK(xmlGetAttrIndex(xmlNet, "dev", &index)); + // This means that the "dev" attribute wasn't set on this net xml node. That means it should not be added to the system topology graph if (index == -1) continue; NCCLCHECK(ncclTopoAddNet(xmlNet, system, nic, systemId, busId)); } @@ -426,7 +427,7 @@ struct kvDict kvDictPciGen[] = { { "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, { "16 GT/s", 120 }, { "32 GT/s", 240 }, /* Kernel 5.6 and earlier */ { "2.5 GT/s PCIe", 15 }, { "5.0 GT/s PCIe", 30 }, { "8.0 GT/s PCIe", 60 }, { "16.0 GT/s PCIe", 120 }, { "32.0 GT/s PCIe", 240 }, { "64.0 GT/s PCIe", 480 }, { NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane -ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* system, struct ncclTopoNode* parent, int systemId) { +ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* system, struct ncclTopoNode* parent, int systemId, int numaId) { const char* str; int type; @@ -453,9 +454,9 @@ ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* s if (xmlNic != NULL) { type = NIC; // Ignore sub device ID and merge multi-port NICs into one PCI device. - busId &= 0xfffffffffffffff0; struct ncclTopoNode* nicNode = NULL; - int64_t id = NCCL_TOPO_ID(systemId, busId); + int64_t localNicId = NCCL_TOPO_LOCAL_NIC_ID(numaId, busId); + int64_t id = NCCL_TOPO_ID(systemId, localNicId); NCCLCHECK(ncclTopoGetNode(system, &nicNode, type, id)); if (nicNode == NULL) { NCCLCHECK(ncclTopoCreateNode(system, &nicNode, type, id)); @@ -476,7 +477,7 @@ ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* s for (int s=0; snSubs; s++) { struct ncclXmlNode* xmlSubPci = xmlPci->subs[s]; if (strcmp(xmlSubPci->name, "pcilink") != 0) { // PCI links will be added later - NCCLCHECK(ncclTopoAddPci(xmlSubPci, system, node, systemId)); + NCCLCHECK(ncclTopoAddPci(xmlSubPci, system, node, systemId, numaId)); } } } @@ -550,12 +551,14 @@ ncclResult_t ncclTopoAddCpu(struct ncclXmlNode* xmlCpu, struct ncclTopoSystem* s } for (int s=0; snSubs; s++) { struct ncclXmlNode* node = xmlCpu->subs[s]; - if (strcmp(node->name, "pci") == 0) NCCLCHECK(ncclTopoAddPci(node, system, cpu, systemId)); + if (strcmp(node->name, "pci") == 0) NCCLCHECK(ncclTopoAddPci(node, system, cpu, systemId, numaId)); if (strcmp(node->name, "nic") == 0) { struct ncclTopoNode* nic = NULL; - NCCLCHECK(ncclTopoGetNode(system, &nic, NIC, 0)); + int64_t localNicId = NCCL_TOPO_LOCAL_NIC_ID(numaId, 0); + int64_t id = NCCL_TOPO_ID(systemId, localNicId); + NCCLCHECK(ncclTopoGetNode(system, &nic, NIC, id)); if (nic == NULL) { - NCCLCHECK(ncclTopoCreateNode(system, &nic, NIC, NCCL_TOPO_ID(systemId, 0))); + NCCLCHECK(ncclTopoCreateNode(system, &nic, NIC, id)); NCCLCHECK(ncclTopoConnectNodes(cpu, nic, LINK_PCI, LOC_BW)); NCCLCHECK(ncclTopoConnectNodes(nic, cpu, LINK_PCI, LOC_BW)); } @@ -812,14 +815,528 @@ ncclResult_t ncclTopoRefreshBcmP2pLinks(void) { return ncclSuccess; } -ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** system) { +// This is just checking for direct descendence +int ncclTopoCheckPix(ncclXmlNode* common, ncclXmlNode** nodes, int nNodes) { + const char* tempBusId; + // If the common parent isn't a pci switch, then this isn't PIX + NCCLCHECK(xmlGetAttrStr(common, "busid", &tempBusId)); + if (tempBusId == NULL) return 0; + TRACE(NCCL_GRAPH, "Checking pix for busid=%s", tempBusId); + + // All the nodes must have a "nic" which is a parent, and then a pci node (busid) which must be a child of the "common" + for (int i = 0; i < nNodes; i++) { + ncclXmlNode* node = nodes[i]; + if (strcmp(node->name, "net") == 0) { + node = node->parent; + if (node == NULL) return 0; + if (strcmp(node->name, "nic") == 0) { + node = node->parent; + if (node == NULL) return 0; + // All nodes must descend from the same first level pci switch + if (strcmp(node->name, "pci") == 0) { + TRACE(NCCL_GRAPH, "Comparing parent of node=%p to common=%p", node->parent, common); + if (node->parent != common) return 0; + } + } + } + } + + return 1; +} + +#define NCCL_TOPO_XML_DEPTH_MAX 256 +typedef struct xmlNodeStack { + ncclXmlNode* elems[NCCL_TOPO_XML_DEPTH_MAX]; + int tail; + + ncclXmlNode* top() { + if (!empty()) { + return elems[tail - 1]; + } else { + return NULL; + } + } + + ncclXmlNode* pop() { + ncclXmlNode* node = top(); + if (node) { + tail--; + } + return node; + } + + void push(ncclXmlNode* node) { + if (tail < NCCL_TOPO_XML_DEPTH_MAX) { + elems[tail++] = node; + } + } + + bool empty() { + return tail == 0; + } + +} xmlNodeStack; + +// 1. Find the common parent xmlNode between the given set of nodes +ncclResult_t ncclTopoGetPath(ncclXmlNode** nodes, int nNodes, int* path, ncclXmlNode** parent) { + // Track a stack of parents per-net node being merged + xmlNodeStack* parents; + NCCLCHECK(ncclCalloc(&parents, nNodes)); + // Find the common parent + ncclXmlNode* common = NULL; + + if (nNodes == 1) { + common = nodes[0]; + *path = PATH_LOC; + goto out; + } + + for (int i = 0; i < nNodes; i++) { + ncclXmlNode* temp; + temp = nodes[i]; + while (temp) { + parents[i].push(temp); + temp = strcmp(temp->name, "system") == 0 ? NULL : temp->parent; + } + } + + common = NULL; + int c; + c = 1; + while (c && !parents[0].empty()) { + ncclXmlNode* temp = parents[0].top(); + for (int i = 1; i < nNodes; i++) { + if (!parents[i].empty()) { + c &= (temp == parents[i].top()); + } else { + c = 0; + break; + } + } + + if (c) { + common = temp; + if (common == NULL) TRACE(NCCL_GRAPH, "COMMON IS NULL"); + for (int i = 0; i < nNodes; i++) { + parents[i].pop(); + } + // Check multi-port while we still have the mismatched parents + // For multi-port to be true, all parents (peers) must have the busId attribute with all but the last character matching + } else { + int multiPort = 1; + const char* tempBusId; + + NCCLCHECK(xmlGetAttr(temp, "busid", &tempBusId)); + if (tempBusId) { + for (int i = 1; i < nNodes; i++) { + if (!parents[i].empty()) { + const char* busId; + NCCLCHECK(xmlGetAttr(parents[i].top(), "busid", &busId)); + if (busId) { + if (strlen(busId) != strlen(tempBusId)) { + multiPort = 0; + break; + } + if (strncmp(busId, tempBusId, strlen(busId)-1) != 0) { + multiPort = 0; + break; + } + } else { + multiPort = 0; + break; + } + } + } + } else { + multiPort = 0; + } + + if (multiPort) { + *path = PATH_PORT; + goto out; + } + } + } + + if (common == NULL) { + *path = PATH_DIS; + } else if (strcmp(common->name,"system") == 0) { + *path = PATH_SYS; + } else if (strcmp(common->name, "cpu") == 0) { + *path = PATH_PHB; + } else if (strcmp(common->name, "nic") == 0) { + *path = PATH_PORT; + } else if (strcmp(common->name, "net") == 0) { + *path = PATH_PORT; + } else if (ncclTopoCheckPix(common, nodes, nNodes)) { + *path = PATH_PIX; + } else { + *path = PATH_PXB; + } + +out: + *parent = common; + free(parents); + return ncclSuccess; +} + +ncclResult_t ncclTopoMakeUniqueBusId(struct ncclXml* xml, char* busId, struct ncclXmlNode** pciNode, struct ncclXmlNode* parent) { + int i = 0; + int64_t rBusId; + NCCLCHECK(busIdToInt64(busId, &rBusId)); + // Try to find an unused busid - NCCL expects leaf busid to be unique + while (i < 100) { + rBusId++; + TRACE(NCCL_GRAPH, "Trying to make new busId %lx", rBusId); + int64ToBusId(rBusId, busId); + struct ncclXmlNode* temp = NULL; + NCCLCHECK(xmlFindTagKv(xml, "pci", &temp, "busid", busId)); + if (temp == NULL) { + NCCLCHECK(xmlAddNode(xml, parent, "pci", pciNode)); + NCCLCHECK(xmlSetAttr(*pciNode, "busid", busId)); + TRACE(NCCL_GRAPH, "Made new busId %lx", rBusId); + return ncclSuccess; + } + TRACE(NCCL_GRAPH, "Conflicting busId %lx", rBusId); + i++; + } + + WARN("TOPO/NET : Couldn't generate unique busId after %d tries", i); + return ncclInternalError; +} + +ncclResult_t ncclTopoMakePciParent(struct ncclXml* xml, struct ncclXmlNode** parent, struct ncclXmlNode* physNetNode) { + struct ncclXmlNode* newBusId = NULL; + struct ncclXmlNode* pci = physNetNode->parent; + if (pci) { + pci = pci->parent; + if (pci) { + if (strcmp(pci->name, "pci") == 0) { + char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + memset(busId, 0, sizeof(busId)); + const char* originalBusId; + // Seed busId with the current NIC 0's busId to make discovering a unique hash quicker + NCCLCHECK(xmlGetAttrStr(pci, "busid", &originalBusId)); + snprintf(busId, sizeof(busId), "%s", originalBusId); + NCCLCHECK(ncclTopoMakeUniqueBusId(xml, busId, &newBusId, *parent)); + for (int i = 0; i < pci->nAttrs; i++) { + NCCLCHECK(xmlSetAttr(newBusId, pci->attrs[i].key, pci->attrs[i].value)); + } + NCCLCHECK(xmlSetAttr(newBusId, "busid", busId)); + *parent = newBusId; + } + } + } + + if (newBusId == NULL) { + const char* name; + NCCLCHECK(xmlGetAttr(physNetNode, "name", &name)); + WARN("TOPO/NET : Can't find busId of child 0 %s", name); + return ncclInternalError; + } + + return ncclSuccess; +} + +ncclResult_t ncclTopoMakeVnic(ncclComm_t comm, struct ncclXml* xml, ncclNetVDeviceProps_t* vProps, +struct ncclXmlNode** physNetNodes, struct ncclXmlNode** netNode, ncclResult_t (*makeVDevice)(int*, ncclNetVDeviceProps_t*)) { + if (vProps->ndevs > NCCL_NET_MAX_DEVS_PER_NIC) { + WARN("TOPO/NET : Tried to merge too many NICs. %d > %d", vProps->ndevs, NCCL_NET_MAX_DEVS_PER_NIC); + return ncclInternalError; + } + + // Trigger the merge, then get the new device's properties + int vDevIndex = 0; + ncclResult_t ret = makeVDevice(&vDevIndex, vProps); + if (ret == ncclInvalidUsage) { + WARN("TOPO/NET : Tried merging multiple devices together and failed. Try setting NCCL_NET_MERGE_LEVEL=LOC"); + NCCLCHECK(ret); + } + + INFO(NCCL_GRAPH, "TOPO/NET : Made vNic %d", vDevIndex); + return ncclSuccess; +} + +ncclResult_t ncclTopoForceMerge(ncclComm_t comm, struct ncclXml* xml, char* str, int* placedDevs, ncclNetProperties_t* propsList, struct ncclXmlNode** physNetNodes, int nPhysDevs, ncclResult_t (*makeVDevice)(int*, ncclNetVDeviceProps_t*)) { + INFO(NCCL_ENV|NCCL_NET, "TOPO/NET : Force-fusing NICs using NCCL_NET_FORCE_MERGE=%s", str); + char* semi_token; + char* semi = strtok_r(str, ";", &semi_token); + while (semi) { + TRACE(NCCL_NET, "Fusing %s", semi); + struct netIf userIfs[NCCL_NET_MAX_DEVS_PER_NIC]; + int nUserIfs = parseStringList(semi, userIfs, NCCL_NET_MAX_DEVS_PER_NIC); + if (nUserIfs == 0) { + INFO(NCCL_NET, "NET/IB : Invalid NCCL_NET_FORCE_MERGE specified %s. Couldn't parse substring %s. Please provide a semicolon-delimited list of comma-delimited NIC groups.", + str, semi); + continue; + } + + ncclNetVDeviceProps_t vProps = {0}; + for (int d = 0; d < nPhysDevs; d++) { + if (matchIfList(propsList[d].name, propsList[d].port, userIfs, nUserIfs, 1)) { + vProps.devs[vProps.ndevs++] = d; + } + } + + if (vProps.ndevs != nUserIfs) { + WARN("TOPO/NET : Only matched %d devices, %d requested from %s", + vProps.ndevs, nUserIfs, semi); + return ncclInvalidUsage; + } + + if (vProps.ndevs > NCCL_NET_MAX_DEVS_PER_NIC) { + WARN("Specified fused NIC %s which has too many devices (%d). Max %d", semi, vProps.ndevs, NCCL_NET_MAX_DEVS_PER_NIC); + return ncclInvalidUsage; + } + + struct ncclXmlNode* netNode; + NCCLCHECK(ncclTopoMakeVnic(comm, xml, &vProps, physNetNodes, &netNode, makeVDevice)); + + // Only set that a device is "placed" after successfully making a vNic (it's possible to exit before this) + for (int i = 0; i < vProps.ndevs; i++) { + placedDevs[vProps.devs[i]] = 1; + } + + semi = strtok_r(NULL, ";", &semi_token);; + } + + return ncclSuccess; +} + +ncclResult_t ncclTopoAutoMerge(ncclComm_t comm, struct ncclXml* xml, int mergeLevel, int* placedDevs, ncclNetProperties_t* propsList, struct ncclXmlNode** physNetNodes, int nPhysDevs, ncclResult_t (*makeVDevice)(int*, ncclNetVDeviceProps_t*)) { + // Compute the path type between each device + int* paths = NULL; + ncclResult_t res = ncclSuccess; + ncclCalloc(&paths, nPhysDevs*nPhysDevs); + TRACE(NCCL_GRAPH, "Allocated %d paths", nPhysDevs*nPhysDevs); + for (int i = 0; i < nPhysDevs; i++) { + for (int j = 0; j < nPhysDevs; j++) { + struct ncclXmlNode* nodes[2]; + nodes[0] = physNetNodes[i]; + nodes[1] = physNetNodes[j]; + struct ncclXmlNode* parent; + NCCLCHECKGOTO(ncclTopoGetPath(nodes, 2, &paths[i*nPhysDevs + j], &parent), res, out); + } + } + + // Place all remaining physical devices into a virtual device given the mergeLevel criteria + for (int i = 0; i < nPhysDevs; i++) { + // Select the first unplaced device "i" as the root + if (placedDevs[i] == 0) { + // Init a new vDevice + ncclNetVDeviceProps_t vProps; + vProps = {0}; + vProps.devs[vProps.ndevs++] = i; + placedDevs[i] = 1; + TRACE(NCCL_GRAPH, "Placed dev %d", i); + + // Select each unplaced device "j" which is at most "mergeLevel" distance from "i", but not equal to "i" + // (Don't merge the same device with itself) + for (int j = 0; j < nPhysDevs; j++) { + if (paths[i*nPhysDevs + j] <= mergeLevel && + placedDevs[j] == 0 && j != i) { + vProps.devs[vProps.ndevs++] = j; + placedDevs[j] = 1; + TRACE(NCCL_GRAPH, "Placed dev %d path=%d", j, paths[i*nPhysDevs + j] ); + } + if (vProps.ndevs == NCCL_NET_MAX_DEVS_PER_NIC) break; + } + + if (vProps.ndevs > NCCL_NET_MAX_DEVS_PER_NIC) { + WARN("TOPO/NET : Tried to merge too many NICs. %d > %d", vProps.ndevs, NCCL_NET_MAX_DEVS_PER_NIC); + return ncclInternalError; + } + + struct ncclXmlNode* netNode; + NCCLCHECKGOTO(ncclTopoMakeVnic(comm, xml, &vProps, physNetNodes, &netNode, makeVDevice), res, out); + } + } + +out: + free(paths); + return res; +} + +struct kvDict nicPathKvList[] = { + { "LOC", PATH_LOC }, + { "PORT", PATH_PORT }, + { "PIX", PATH_PIX }, + { "PXB", PATH_PXB }, + { "PXN", PATH_PXN }, + { "PHB", PATH_PHB }, + { "SYS", PATH_SYS }, + { NULL, 0 } +}; + +ncclResult_t ncclTopoGetVNicParent(struct ncclXml* xml, ncclResult_t (*getProperties)(int, ncclNetProperties_t*), ncclNetVDeviceProps_t* vProps, ncclXmlNode** parent) { + ncclNetProperties_t props[NCCL_NET_MAX_DEVS_PER_NIC]; + ncclXmlNode* physNetNodes[NCCL_NET_MAX_DEVS_PER_NIC]; + for (int i = 0; i < vProps->ndevs; i++) { + NCCLCHECK(getProperties(vProps->devs[i], props + i)); + struct ncclXmlNode* physNetNode; + NCCLCHECK(xmlFindTagKv(xml, "net", &physNetNode, "name", props[i].name)); + physNetNodes[i] = physNetNode; + TRACE(NCCL_GRAPH, "Re-found physical ncclNet node %d %s", i, props[i].name); + } + + int path = PATH_LOC; + NCCLCHECK(ncclTopoGetPath(physNetNodes, vProps->ndevs, &path, parent)); + if (path == PATH_LOC) { + *parent = NULL; + } else if (parent && strcmp((*parent)->name, "pci") == 0) { + // If the common parent is PCI, we must reparent the new NIC under a made up busId + NCCLCHECK(ncclTopoMakePciParent(xml, parent, physNetNodes[0])); + } + TRACE(NCCL_GRAPH, "Selected parent %s with path %d", (*parent)->name, path); + return ncclSuccess; +} + +ncclResult_t ncclTopoMakeVNics(ncclComm_t comm, struct ncclXml* xml, ncclResult_t (*makeVDevice)(int*, ncclNetVDeviceProps_t*), ncclResult_t (*getProperties)(int, ncclNetProperties_t*), int physicalDevs) { + int* placedDevs = NULL; + struct ncclXmlNode** physNetNodes = NULL; + if (physicalDevs == 0) return ncclSuccess; + + ncclCalloc(&physNetNodes, physicalDevs); + ncclResult_t res = ncclSuccess; + + ncclNetProperties_t* props = NULL; + ncclCalloc(&props, physicalDevs); + for (int i = 0; i < physicalDevs; i++) { + NCCLCHECKGOTO(getProperties(i, props + i), res, out); + struct ncclXmlNode* physNetNode; + NCCLCHECKGOTO(xmlFindTagKv(xml, "net", &physNetNode, "name", props[i].name), res, out); + physNetNodes[i] = physNetNode; + TRACE(NCCL_GRAPH, "Found physical ncclNet node %d %s", i, props[i].name); + } + + // By default, don't merge any devices + int mergeLevel; + mergeLevel = PATH_PORT; + char* mergeLevelEnv; + mergeLevelEnv = getenv("NCCL_NET_MERGE_LEVEL"); + if (mergeLevelEnv) kvConvertToInt(mergeLevelEnv, &mergeLevel, nicPathKvList); + char* forceMerge; + forceMerge = getenv("NCCL_NET_FORCE_MERGE"); + NCCLCHECK(ncclCalloc(&placedDevs, physicalDevs)); + memset(placedDevs, 0, sizeof(int)*physicalDevs); + + if (forceMerge) { + NCCLCHECKGOTO(ncclTopoForceMerge(comm, xml, forceMerge, placedDevs, props, physNetNodes, physicalDevs, makeVDevice), res, out); + } + NCCLCHECKGOTO(ncclTopoAutoMerge(comm, xml, mergeLevel, placedDevs, props, physNetNodes, physicalDevs, makeVDevice), res, out); + +out: + free(physNetNodes); + free(props); + if (placedDevs) free(placedDevs); + return res; +} + +static ncclResult_t ncclTopoPopulateNics(ncclComm_t comm, ncclXml* xml, int startIndex, int endIndex, ncclResult_t (*getProperties)(int, ncclNetProperties_t*), const char* netName, int coll, int keep, int virtualNics) { + for (int n = startIndex; n < endIndex; n++) { + ncclNetProperties_t props; + NCCLCHECK(getProperties(n, &props)); + struct ncclXmlNode* netNode = NULL; + struct ncclXmlNode* parent = NULL; + if (virtualNics) { + struct ncclXmlNode* net = NULL; + NCCLCHECK(xmlFindTagKv(xml, "net", &net, "name", props.name)); + // In the event of multithreaded use case, we need to re-discover the shared parent of the given devices for this vNIC + // Only run this if the net doesn't exist locally - this may alter the XML state + if (net == NULL) NCCLCHECK(ncclTopoGetVNicParent(xml, getProperties, &props.vProps, &parent)); + } + + NCCLCHECK(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode, parent)); + + const char* colAttr; + NCCLCHECK(xmlGetAttr(netNode, "coll", &colAttr)); + + // If coll == 0 but the netNode is tagged as coll, don't update the keep value + if (colAttr == NULL || coll != 0 || strcmp(colAttr,"1") != 0) NCCLCHECK(xmlSetAttrInt(netNode, "keep", keep)); + NCCLCHECK(xmlSetAttrInt(netNode, "dev", n)); + NCCLCHECK(xmlInitAttrInt(netNode, "latency", props.latency)); + NCCLCHECK(xmlInitAttrInt(netNode, "speed", props.speed)); + NCCLCHECK(xmlInitAttrInt(netNode, "port", props.port)); + NCCLCHECK(xmlInitAttrUint64(netNode, "guid", props.guid)); + NCCLCHECK(xmlInitAttrInt(netNode, "maxconn", props.maxComms)); + bool gdrSupport = (props.ptrSupport & NCCL_PTR_CUDA) || (comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF)); + INFO(NCCL_NET,"NET/%s : GPU Direct RDMA %s for HCA %d '%s'", netName, gdrSupport ? "Enabled" : "Disabled", n, props.name); + NCCLCHECK(xmlInitAttrInt(netNode, "gdr", gdrSupport)); + // Only set coll if it's not 0 + if (coll) NCCLCHECK(xmlInitAttrInt(netNode, "coll", coll)); + + const char* keepAttr; + NCCLCHECK(xmlGetAttr(netNode, "coll", &colAttr)); + NCCLCHECK(xmlGetAttr(netNode, "keep", &keepAttr)); + INFO(NCCL_GRAPH, "ncclTopoPopulateNics : Filled %s in topo with pciPath=%s keep=%s coll=%s", + props.name, props.pciPath, keepAttr, colAttr); + } + + return ncclSuccess; +} + +struct ncclTopoNetState { + int nVirtualNics; + int nPhysicalNics; + const char* name; +}; + +// Calls to network plugin APIs should be protected. This function should be called inside a per-process lock. +static ncclResult_t ncclTopoProcessNet(ncclComm_t comm, ncclXml* xml, int coll, const char* dumpXmlFile, ncclTopoNetState* state, ncclResult_t (*getProperties)(int, ncclNetProperties_t*), ncclResult_t (*makeVDevice)(int*, ncclNetVDeviceProps_t*), ncclResult_t (*devices)(int*), const char* netName) { + int usePhysicalDevices = (dumpXmlFile || makeVDevice == NULL); + if (state->nPhysicalNics == -1) NCCLCHECK(devices(&state->nPhysicalNics)); + // Enumerate physical devices + NCCLCHECK(ncclTopoPopulateNics(comm, xml, 0, state->nPhysicalNics, getProperties, netName, coll, 1, 0)); + if (!usePhysicalDevices) { + if (state->nVirtualNics == -1) { + NCCLCHECK(ncclTopoMakeVNics(comm, xml, makeVDevice, getProperties, state->nPhysicalNics)); + int nDevs; + NCCLCHECK(devices(&nDevs)); + state->nVirtualNics = nDevs - state->nPhysicalNics; + } + // Remove keep=1 for physical collnets + if (state->nVirtualNics > 0) { + NCCLCHECK(ncclTopoPopulateNics(comm, xml, 0, state->nPhysicalNics, getProperties, netName, coll, 0, 0)); + // Populate new devices + NCCLCHECK(ncclTopoPopulateNics(comm, xml, state->nPhysicalNics, state->nPhysicalNics+state->nVirtualNics, getProperties, netName, coll, 1, 1)); + } + } + + return ncclSuccess; +} + +static pthread_mutex_t netLock = PTHREAD_MUTEX_INITIALIZER; +ncclTopoNetState netStates[NCCL_NET_MAX_PLUGINS] = {}; +ncclTopoNetState collNetStates[NCCL_NET_MAX_PLUGINS] = {}; +ncclResult_t ncclTopoGetSharedState(ncclTopoNetState** state, const char* name, ncclTopoNetState* states) { + INFO(NCCL_GRAPH, "Retrieving state for %s", name); + for (int i = 0; i < NCCL_NET_MAX_PLUGINS; i++) { + // Empty slot + if (states[i].name == NULL) { + states[i].nVirtualNics = -1; + states[i].nPhysicalNics = -1; + states[i].name = strdup(name); + *state = states + i; + INFO(NCCL_GRAPH, "Initialized state %d for %s", i, name); + return ncclSuccess; + // Found my slot + } else if (strcmp(states[i].name, name) == 0) { + *state = states + i; + return ncclSuccess; + } + } + WARN("NET/TOPO : Couldn't find net with name %s", name); + return ncclInternalError; +} + +ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** system, const char* dumpXmlFile) { ncclResult_t ret = ncclSuccess; struct ncclXml* xml; char* mem = NULL; int* localRanks = NULL; - int netDevCount = 0; struct ncclXml* rankXml; int localRank = -1, nLocalRanks = 0; + int netLockHeld = 0; NCCLCHECK(xmlAlloc(&xml, NCCL_TOPO_XML_MAX_NODES)); const char* xmlTopoFile = ncclGetEnv("NCCL_TOPO_FILE"); if (xmlTopoFile) { @@ -848,47 +1365,25 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy NCCLCHECKGOTO(xmlSetAttrInt(node, "rank", comm->rank), ret, fail); NCCLCHECKGOTO(xmlInitAttrInt(node, "gdr", comm->peerInfo[comm->rank].gdrSupport), ret, fail); } + // Auto-detect NICs if needed. net/collnet share the same xml/graph nodes, // so we start with collnet so that it has precedence. + pthread_mutex_lock(&netLock); + netLockHeld = 1; + INFO(NCCL_GRAPH, "TOPO/NET : Importing network plugins to topology"); + ncclTopoNetState* state; + state = NULL; if (collNetSupport(comm)) { - NCCLCHECKGOTO(collNetDevices(comm, &netDevCount), ret, fail); - for (int n=0; ndmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF)); - INFO(NCCL_NET,"NET/%s : GPU Direct RDMA %s for HCA %d '%s'", comm->ncclNet->name, gdrSupport ? "Enabled" : "Disabled", n, props.name); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "gdr", gdrSupport), ret, fail); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "coll", 1), ret, fail); - } - } - if (netDevCount == 0) { - NCCLCHECKGOTO(comm->ncclNet->devices(&netDevCount), ret, fail); - } - for (int n=0; nncclNet->getProperties(n, &props), ret, fail); - comm->netDeviceType = props.netDeviceType; - struct ncclXmlNode* netNode; - NCCLCHECKGOTO(ncclTopoFillNet(xml, props.pciPath, props.name, &netNode), ret, fail); - NCCLCHECKGOTO(xmlSetAttrInt(netNode, "keep", 1), ret, fail); - NCCLCHECKGOTO(xmlSetAttrInt(netNode, "dev", n), ret, fail); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "speed", props.speed), ret, fail); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "port", props.port), ret, fail); - NCCLCHECKGOTO(xmlInitAttrFloat(netNode, "latency", props.latency), ret, fail); - NCCLCHECKGOTO(xmlInitAttrUint64(netNode, "guid", props.guid), ret, fail); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "maxconn", props.maxComms), ret, fail); - bool gdrSupport = (props.ptrSupport & NCCL_PTR_CUDA) || (comm->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF)); - INFO(NCCL_NET,"NET/%s : GPU Direct RDMA %s for HCA %d '%s'", comm->ncclNet->name, gdrSupport ? "Enabled" : "Disabled", n, props.name); - NCCLCHECKGOTO(xmlInitAttrInt(netNode, "gdr", gdrSupport), ret, fail); + NCCLCHECKGOTO(ncclTopoGetSharedState(&state, comm->ncclCollNet->name, collNetStates), ret, fail); + NCCLCHECKGOTO(ncclTopoProcessNet(comm, xml, 1, dumpXmlFile, state, + comm->ncclCollNet->getProperties, comm->ncclCollNet->makeVDevice, comm->ncclCollNet->devices, comm->ncclCollNet->name), ret, fail); } + NCCLCHECKGOTO(ncclTopoGetSharedState(&state, comm->ncclNet->name, netStates), ret, fail); + // [RCCL] Disabled virtual devices + NCCLCHECKGOTO(ncclTopoProcessNet(comm, xml, 0, dumpXmlFile, state, + comm->ncclNet->getProperties, nullptr /*comm->ncclNet->makeVDevice*/, comm->ncclNet->devices, comm->ncclNet->name), ret, fail); + pthread_mutex_unlock(&netLock); + netLockHeld = 0; // Remove XML branches which don't have a node with keep="1" (typically when importing a topology) NCCLCHECKGOTO(ncclTopoTrimXml(xml), ret, fail); @@ -932,19 +1427,21 @@ ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** sy NCCLCHECKGOTO(ncclTopoFuseXml(xml, peerXml), ret, fail); } - xmlTopoFile = ncclGetEnv("NCCL_TOPO_DUMP_FILE"); - if (xmlTopoFile && comm->rank == ncclParamTopoDumpFileRank()) { - INFO(NCCL_ENV, "NCCL_TOPO_DUMP_FILE set by environment to %s", xmlTopoFile); - NCCLCHECKGOTO(ncclTopoDumpXmlToFile(xmlTopoFile, xml), ret, fail); + if (dumpXmlFile && comm->rank == ncclParamTopoDumpFileRank()) { + INFO(NCCL_ENV, "NCCL_TOPO_DUMP_FILE set by environment to %s", dumpXmlFile); + NCCLCHECKGOTO(ncclTopoDumpXmlToFile(dumpXmlFile, xml), ret, fail); } - NCCLCHECKGOTO(ncclTopoGetSystemFromXml(xml, system, comm->peerInfo[comm->rank].hostHash), ret, fail); + // Only update our topo tracking structure if we aren't dumping (separate steps) + if (dumpXmlFile == NULL) NCCLCHECKGOTO(ncclTopoGetSystemFromXml(xml, system, comm->peerInfo[comm->rank].hostHash), ret, fail); + exit: if (!comm->MNNVL && localRanks) free(localRanks); if (mem) free(mem); free(xml); return ret; fail: + if (netLockHeld) pthread_mutex_unlock(&netLock); goto exit; } diff --git a/src/graph/topo.h b/src/graph/topo.h index 1fb6af0641..fdfca50d75 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -85,6 +85,9 @@ extern const char* topoLinkTypeStr[]; // Connection through the network #define PATH_NET 8 +// New type of path which should precede PATH_PIX +#define PATH_PORT PATH_NVL + // Disconnected #define PATH_DIS 9 extern const char* topoPathTypeStr[]; @@ -114,6 +117,7 @@ struct ncclTopoLinkList { #define NCCL_TOPO_ID_LOCAL_ID_MASK 0x00ffffffffffffff #define NCCL_TOPO_ID_SYSTEM_ID(id) (id >> 56) #define NCCL_TOPO_ID_LOCAL_ID(id) (id & NCCL_TOPO_ID_LOCAL_ID_MASK) +#define NCCL_TOPO_LOCAL_NIC_ID(numaid, busid) (((int64_t)numaid << 56) + busid) #define NCCL_TOPO_ID(systemid, localid) (((int64_t)systemid << 56) + (localid & NCCL_TOPO_ID_LOCAL_ID_MASK)) #define RCCL_TOPO_CR8G 1 diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index e1f3a1b1cf..c86fe56a3e 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -32,23 +32,87 @@ static int getNthreads(const char* name, int env, int min, int max, int def, int return nt; } -ncclResult_t parseList(const char* str, const char* elems[], int nelems, int* list) { - int def, set; - if (str[0] == '^') { - def = 1; set = 0; str++; - } else { - def = 0; set = 1; +// Parse a map of prefixes to a list of elements. The first prefix is +// optional and, if not present, the list of elements will be applied +// to all prefixes. Only the first list of elements can lack a +// prefix. Prefixes (if present) are followed by a colon. Lists of +// elements are comma delimited. Mappings of prefix to the lists of +// elements are semi-colon delimited. +// +// For example: +// +// NCCL_ALGO="ring,collnetdirect;allreduce:tree,collnetdirect;broadcast:ring" +// Enable ring and collnetdirect for all functions, then select tree +// and collnetdirect for allreduce and ring for broadcast. +// +// NCCL_PROTO="LL,Simple;allreduce:^LL" +// Enable LL and Simple for all functions, but everything except LL +// for allreduce. +// +// NCCL_PROTO="^LL128;allreduce:LL128" +// Enable everything but LL128, but only LL128 for allreduce. +ncclResult_t parseList(const char* str, const char* prefixElems[], int nprefixes, const char* elems[], int nelems, int* list) { + char* fullStr = strdup(str); + char* tmpFullStr; + char* fullToken = strtok_r(fullStr, ";", &tmpFullStr); + while (fullToken) { + char* subToken = strdup(fullToken); + char* tmpSubStr; + char* prefix = strtok_r(subToken, ":", &tmpSubStr); + char* elemList = strtok_r(NULL, ":", &tmpSubStr); + if (elemList == NULL) { + if (fullToken != fullStr) { + // It makes no sense for any entry other than the first to not have a prefix, + // because then all the prefixes before the prefix-less entry would be + // overwritten. + WARN("All entries except the first must have a prefix: \"%s\"", str); + return ncclInvalidUsage; + } + elemList = prefix; + prefix = NULL; + } + + int unset, set; + if (elemList[0] == '^') { + unset = 1; set = 0; elemList++; + } else { + unset = 0; set = 1; + } + + bool foundPrefix = false; + for (int p=0; p= 90 ? HOPPER_COMPCAP_IDX : minCompCap >= 80 ? AMPERE_COMPCAP_IDX : VOLTA_COMPCAP_IDX; - int cpuArch, cpuVendor, cpuModel; - NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel)); int index2 = nNodes <= 2 ? nNodes-1 : 2; // LL: for single node, we look at GPU type; for multi-node, we look at CPU type - int index1 = nNodes == 1 ? compCapIndex : cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0; + int index1 = nNodes == 1 ? compCapIndex : + (comm->cpuVendor == NCCL_TOPO_CPU_VENDOR_AMD || comm->cpuVendor == NCCL_TOPO_CPU_VENDOR_MIXED) ? 1 : 0; double llMaxBw = llMaxBws[index1][index2]; double perChMaxTreeBw = perChMaxTreeBws[compCapIndex][index2]; double perChMaxRingLL128Bw = perChMaxRingLL128Bws[compCapIndex][index2]; double perChMaxTreeLL128Bw = perChMaxTreeLL128Bws[compCapIndex][index2]; // De-penalize Tree/Simple latency on Power systems to favor Tree than Ring - //if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; + //if (comm->cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; float ppn = (float)nRanks / nNodes; int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS]; @@ -419,7 +482,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw); if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), graphs[a]->nChannels*perChMaxTreeLL128Bw); if (a == NCCL_ALGO_TREE && graphs[a]->pattern == NCCL_TOPO_PATTERN_TREE) busBw *= .85; - if (a == NCCL_ALGO_PAT) busBw *= .85; + if (a == NCCL_ALGO_PAT) busBw *= .75; if (a == NCCL_ALGO_COLLNET_DIRECT && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_CHAIN && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE) { @@ -456,10 +519,6 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom busBw *= ratio; } comm->bandwidths[coll][a][p] = busBw; - /* Ring bandwidth backup */ - if (a == NCCL_ALGO_RING) - comm->ringbdw[coll][p] = comm->bandwidths[coll][NCCL_ALGO_RING][p]; - comm->latencies[coll][a][p] = baseLat[a][p]; float intraLat = rcclTuningModel[comm->topo->tuning].hwLat[intraHw[a]][a][p]; float interLat = ppn == 1 ? rcclTuningModel[comm->topo->tuning].hwLat[NCCL_HW_NET][NCCL_ALGO_TREE][p] : rcclTuningModel[comm->topo->tuning].hwLat[NCCL_HW_NET][a][p]; @@ -514,42 +573,79 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom // Protocols/Algorithms enable/disable, and user overrides. // All are enabled except ll128 which is enabled by default only in certain cases. - int protoEnable[NCCL_NUM_PROTOCOLS] = { 1, 2, 1 }; - int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1, 1, 1, 1 }; + int protoEnable[NCCL_NUM_FUNCTIONS*NCCL_NUM_PROTOCOLS]; + int algoEnable[NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS]; + for (int f=0; fnNodes == 1) algoEnable[NCCL_ALGO_NVLS_TREE] = 0; - - // Disable CollNet if it is not supported - if (comm->collNetSupport == 0) { - algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; - algoEnable[NCCL_ALGO_COLLNET_CHAIN] = 0; - if (nNodes > 1) algoEnable[NCCL_ALGO_NVLS] = 0; - // If user has hard set NCCL_ALGO=COLLNET, ignore it - if (algoEnable[NCCL_ALGO_RING] == 0 && algoEnable[NCCL_ALGO_TREE] == 0 && - algoEnable[NCCL_ALGO_NVLS] == 0 && algoEnable[NCCL_ALGO_NVLS_TREE] == 0) { - algoEnable[NCCL_ALGO_RING] = algoEnable[NCCL_ALGO_TREE] = 1; + if (comm->rank == 0 && (algoStr||protoStr)) { + constexpr int strLength = 1024; + char funcAlgoProtoTuningStr[strLength]; + int offset = 0; + offset += snprintf(funcAlgoProtoTuningStr+offset, std::max(0, strLength-offset), "\n Function | "); + for (int p=0; ptopo, &nvsCount)); + + for (int f=0; fnNodes == 1 && a == NCCL_ALGO_NVLS_TREE) disable = 1; + // Disable Collnet+Direct, Collnet+Chain or Collnet+NVLS if collnet is not supported. + if (comm->collNetSupport == 0 && + (a == NCCL_ALGO_COLLNET_DIRECT || + a == NCCL_ALGO_COLLNET_CHAIN || + (a == NCCL_ALGO_NVLS && comm->nNodes > 1))) disable = 1; + // Disable CollNet+Direct if not on an NVSwitch system + if (nvsCount == 0 && a == NCCL_ALGO_COLLNET_DIRECT) disable = 1; + if (disable) algoEnable[f*NCCL_NUM_ALGORITHMS+a] = 0; } - } else { - // Disable CollNet+Direct if not on an NVSwitch system - int nvsCount = 0; - NCCLCHECK(ncclTopoGetNvsCount(comm->topo, &nvsCount)); - if (nvsCount == 0) algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; } for (int c=0; ctopo->nodes[GPU].nodes[0].gpu.gcn, "gfx12")) ? 0 : protoEnable[p]; + int pEnable = (p == NCCL_PROTO_LL && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx12")) ? 0 : protoEnable[c*NCCL_NUM_PROTOCOLS+p]; if (p == NCCL_PROTO_LL128) { #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) #if defined(ENABLE_LL128) @@ -575,66 +671,51 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom #endif } if (pEnable == 0) comm->bandwidths[c][a][p] = 0; - if (algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0; - //if (a == NCCL_ALGO_RING && pEnable == 0) comm->ringbdw[c][p] = 0; - } - - for (int c = 0; c < NCCL_NUM_FUNCTIONS; c++) { - bool available = false; - for (int a = 0; a < NCCL_NUM_ALGORITHMS; a++) - for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) - if (comm->bandwidths[c][a][p] != 0) { - available = true; - goto check_avail; - } - check_avail: - if (available == false) { - /* at least set ring algo available */ - for (int p = 0; p < NCCL_NUM_PROTOCOLS; p++) - comm->bandwidths[c][NCCL_ALGO_RING][p] = comm->ringbdw[c][p]; - } + if (algoEnable[c*NCCL_NUM_ALGORITHMS+a] == 0) comm->bandwidths[c][a][p] = 0; } if (comm->rank == 0) { - char line[1024]; + constexpr int lineLen = 1024; + char line[lineLen]; + int offset = 0; for (int block=0; block= NCCL_NUM_ALGORITHMS) continue; - sprintf(line+strlen(line), " %14s %14s %14s |", "", ncclAlgoStr[a], ""); + offset += snprintf(line+offset, std::max(0, lineLen-offset), " %14s %14s %14s |", "", ncclAlgoStr[a], ""); } INFO(NCCL_TUNING, "%s", line); - sprintf(line, " Protocol |"); + offset = snprintf(line, lineLen, " Protocol |"); for (int ba=0; ba<3; ba++) { for (int p=0; p= NCCL_NUM_ALGORITHMS) continue; for (int p=0; pmaxThreads[a][p]); + offset += snprintf(line+offset, std::max(0, lineLen-offset), " %14d |", comm->maxThreads[a][p]); } } INFO(NCCL_TUNING, "%s", line); for (int c=0; c= NCCL_NUM_ALGORITHMS) continue; for (int p=0; platencies[c][a][p], comm->bandwidths[c][a][p]); + offset += snprintf(line+offset, std::max(0, lineLen-offset), "%8.1f/%6.1f |", comm->latencies[c][a][p], comm->bandwidths[c][a][p]); } } INFO(NCCL_TUNING, "%s", line); } } } - + // Set per-thread amount of work before we increase nThreads and nChannels for (int a=0; athreadThresholds[a][NCCL_PROTO_LL] = NCCL_LL_THREAD_THRESHOLD; @@ -678,19 +759,10 @@ static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][23] = { { .9, .9, .9, .9, .9, .9, .9, .8, .7, .6, .6, .5, .5, .5, .5, .6, .7, .8, .7, .7, .8, .9, .9 } }; -ncclResult_t ncclTopoGetAlgoTime(struct ncclComm* comm, int coll, int algorithm, int protocol, size_t nBytes, int numPipeOps, float* time, bool* backup) { +ncclResult_t ncclTopoGetAlgoTime(struct ncclComm* comm, int coll, int algorithm, int protocol, size_t nBytes, int numPipeOps, float* time) { float bw = comm->bandwidths[coll][algorithm][protocol]; float lat = comm->latencies[coll][algorithm][protocol]; - if (backup) { - *backup = false; - if (algorithm == NCCL_ALGO_RING && bw == 0.0f) { - /* try back up RING algorithm */ - bw = comm->ringbdw[coll][protocol]; - *backup = true; - } - } - if (bw == 0) { *time = -1.0; return ncclSuccess; } diff --git a/src/graph/xml.cc b/src/graph/xml.cc index f96d14ca0d..f81a854014 100644 --- a/src/graph/xml.cc +++ b/src/graph/xml.cc @@ -20,6 +20,9 @@ #include #endif +// Arbitrarily large number for constructing virtual topology string +#define NCCL_MAX_XML_DEPTH 1024 + /*******************/ /* XML File Parser */ /*******************/ @@ -442,7 +445,7 @@ static ncclResult_t getBcmLinks(const char* busId, int* nlinks, char** peers) { ncclResult_t ncclTopoGetStrFromSys(const char* path, const char* fileName, char* strValue) { char filePath[PATH_MAX]; - sprintf(filePath, "%s/%s", path, fileName); + snprintf(filePath, sizeof(filePath), "%s/%s", path, fileName); int offset = 0; FILE* file; if ((file = fopen(filePath, "r")) != NULL) { @@ -974,7 +977,7 @@ ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct nccl // where sysPath/subsystem points to. ncclResult_t ncclTopoGetSubsystem(const char* sysPath, char* subSys) { char subSysPath[PATH_MAX]; - sprintf(subSysPath, "%s/subsystem", sysPath); + snprintf(subSysPath, sizeof(subSysPath), "%s/subsystem", sysPath); char* path = realpath(subSysPath, NULL); if (path == NULL) { subSys[0] = '\0'; @@ -987,8 +990,9 @@ ncclResult_t ncclTopoGetSubsystem(const char* sysPath, char* subSys) { return ncclSuccess; } -ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode) { +ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode, struct ncclXmlNode* forceParent) { NCCLCHECK(xmlFindTagKv(xml, "net", netNode, "name", netName)); + if (*netNode != NULL) return ncclSuccess; const char* pciSysPath = pciPath; @@ -997,13 +1001,15 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha NCCLCHECK(ncclTopoGetSubsystem(pciSysPath, subSystem)); // This is not a PCI device (virtual, usb, ...). if (strcmp(subSystem, "pci") != 0) { - INFO(NCCL_GRAPH, "Topology detection: network path %s is not a PCI device (%s). Attaching to first CPU", pciSysPath, subSystem); + INFO(NCCL_NET|NCCL_GRAPH, "Topology detection: network path %s is not a PCI device (%s). Attaching to first CPU", pciSysPath, subSystem); pciSysPath = NULL; } } struct ncclXmlNode* parent = NULL; - if (pciSysPath) { + if (forceParent) { + parent = forceParent; + } else if (pciSysPath) { int offset; for (offset=strlen(pciSysPath)-1; pciSysPath[offset] != '/'; offset--); char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; diff --git a/src/graph/xml.h b/src/graph/xml.h index b91afc349e..c1885b9c5d 100644 --- a/src/graph/xml.h +++ b/src/graph/xml.h @@ -52,7 +52,7 @@ ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXm /* Auto-detect functions */ ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode); -ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode); +ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode, struct ncclXmlNode* forceParent=NULL); /* Remove unneeded parts */ ncclResult_t ncclTopoTrimXml(struct ncclXml* xml); @@ -136,6 +136,13 @@ static ncclResult_t xmlGetAttrFloat(struct ncclXmlNode* node, const char* attrNa return ncclSuccess; } +static ncclResult_t xmlGetAttrFloatDefault(struct ncclXmlNode* node, const char* attrName, float* value, float defaultValue) { + const char* str; + NCCLCHECK(xmlGetAttr(node, attrName, &str)); + *value = str ? strtof(str, NULL) : defaultValue; + return ncclSuccess; +} + static ncclResult_t xmlFindTag(struct ncclXml* xml, const char* tagName, struct ncclXmlNode** node) { *node = NULL; for (int i=0; imaxIndex; i++) { @@ -212,6 +219,24 @@ static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, c return ncclSuccess; } +static ncclResult_t xmlPrintNodeRecursive(struct ncclXmlNode* node, const char* name) { + while (node) { + char line[1024*8]; + int cursor = 0; + snprintf(line, sizeof(line), "name); + for (int i = 0; i < node->nAttrs; i++) { + cursor = strlen(line); + snprintf(line + cursor, sizeof(line) - cursor, " %s=%s", node->attrs[i].key, node->attrs[i].value); + } + cursor = strlen(line); + snprintf(line + cursor, sizeof(line) - cursor, ">"); + INFO(NCCL_GRAPH, "%s", line); + node = node->parent; + } + return ncclSuccess; +} + + static ncclResult_t xmlSetAttrIfUnset(struct ncclXmlNode* node, const char* attrName, const char* value) { int index; NCCLCHECK(xmlGetAttrIndex(node, attrName, &index)); diff --git a/src/group.cc b/src/group.cc index f5d2d12e7e..c18b5acb05 100644 --- a/src/group.cc +++ b/src/group.cc @@ -344,7 +344,7 @@ static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclComm** g /* reset everything */ while (!ncclIntruQueueEmpty(asyncJobsPtr)) { struct ncclAsyncJob* job = ncclIntruQueueDequeue(asyncJobsPtr); - if (job->comm && !job->comm->config.blocking) + if (!job->destroyFlag && job->comm && !job->comm->config.blocking) (void) ncclCommSetAsyncError(job->comm, error); if (job->undo) job->undo(job); if (job->destructor) job->destructor((void*)job); @@ -413,7 +413,6 @@ fail: } static ncclResult_t groupLaunch(struct ncclAsyncJob *job_, ncclSimInfo_t* simInfo = NULL) { - int savedDev; ncclResult_t ret = ncclSuccess; struct ncclGroupJob *gjob = (struct ncclGroupJob*) job_; struct ncclComm *groupCommHeadMain = *gjob->groupCommHeadPtr; @@ -422,8 +421,6 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_, ncclSimInfo_t* simInf bool *groupAbortFlag = gjob->abortFlagPtr; - CUDACHECKGOTO(cudaGetDevice(&savedDev), ret, fail); - if (!simInfo && groupCommPreconnectHeadMain != nullptr) { struct ncclComm* comm = groupCommPreconnectHeadMain; do { @@ -475,12 +472,19 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_, ncclSimInfo_t* simInf } comm = comm->groupNext; } while (comm); - NCCLCHECKGOTO(asyncJobLaunch(&asyncCollJobs, groupAbortFlag), ret, fail); while (!ncclIntruQueueEmpty(&asyncCollJobs)) { struct ncclAsyncJob* job = ncclIntruQueueDequeue(&asyncCollJobs); if (job->destructor) job->destructor((void*)job); } + + // done with all buffer allocation, start registration and enqueue + comm = groupCommHeadMain; + do { + CUDACHECKGOTO(cudaSetDevice(comm->cudaDev), ret, fail); + NCCLCHECKGOTO(ncclTasksRegAndEnqueue(comm), ret, fail); + comm = comm->groupNext; + } while (comm); } if ((!simInfo) && (groupCommHeadMain != nullptr)) { @@ -497,6 +501,9 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_, ncclSimInfo_t* simInf while (groupCommHeadMain != nullptr) { struct ncclComm* comm = groupCommHeadMain; struct ncclComm* next = comm->groupNext; + // Poll for callbacks sent to us from other threads. Typically these free + // resources from to our memory pools and UB + NCCLCHECKGOTO(ncclCommPollCallbacks(comm, /*waitSome=*/false), ret, fail); (void) ncclGroupCommLeave(comm); if (!comm->config.blocking) { (void) ncclCommSetAsyncError(comm, ret); @@ -504,8 +511,6 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_, ncclSimInfo_t* simInf groupCommHeadMain = next; } - CUDACHECK(cudaSetDevice(savedDev)); - exit: return ret; fail: @@ -588,7 +593,10 @@ ncclResult_t ncclGroupEndInternal(ncclSimInfo_t* simInfo) { ret = ncclInProgress; } else { /* blocking group */ + int savedDev; + CUDACHECKGOTO(cudaGetDevice(&savedDev), ret, fail); NCCLCHECKGOTO(groupLaunch(&ncclGroupJobMainPtr->base, internalSimInfoPtr), ret, fail); + CUDACHECKGOTO(cudaSetDevice(savedDev), ret, fail); if (simInfo) memcpy((void*)simInfo, (void*)internalSimInfoPtr, realSize); groupResetJobState(ncclGroupJobMainPtr); } diff --git a/src/include/collectives.h b/src/include/collectives.h index 93fd17c7c6..d9e653c760 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -12,6 +12,7 @@ #include "nccl.h" #include "nccl_common.h" #include "device.h" +#define NCCL_MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. // CHUNKSIZE must be a multiple of SLICESIZE #define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4) @@ -25,6 +26,7 @@ #define REDUCE_SLICESTEPS 1 #define REDUCE_CHUNKSTEPS 1 #define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above +#define NCCL_MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. #define ALLTOALL_PIVOT_SLICESTEPS 2 #define ALLTOALL_PIVOT_CHUNKSTEPS 4 @@ -38,15 +40,11 @@ inline int ncclTypeSize(ncclDataType_t type) { switch (type) { case ncclInt8: case ncclUint8: -#if defined(RCCL_FLOAT8) - case ncclFp8E4M3: - case ncclFp8E5M2: -#endif + case ncclFloat8e4m3: + case ncclFloat8e5m2: return 1; case ncclFloat16: -#if defined(RCCL_BFLOAT16) case ncclBfloat16: -#endif return 2; case ncclInt32: case ncclUint32: @@ -75,6 +73,319 @@ struct ncclConnFifo { #include +class RingAlgorithm { +protected: + int refCount; + int nRanks; + int nStepsPerLoop; + int chunkSteps; + int sliceSteps; + ssize_t sliceSize; + ssize_t loopSize; + ssize_t channelSize; + uint8_t *sendbuff; + uint8_t *recvbuff; + void *sendMhandle; + void *recvMhandle; + void *srecvMhandle; +public: + // this ring class is used by proxy thread to retrieve the send and recv buffer, size as well as corresponding + // mem handle based on the current step of the proxy args. The derived ring algo class is AR, AG, and BC which + // would be allocated during enqueue stage and copied to proxy side through shared memory. For each copy, we will + // increase the refCount by incRefCount() since the same ring algo object can be referenced multiple times for send + // and recv progress. After all steps are done, we decrease the refCount and only delete the ring object when + // refCount == 0. + virtual void getNextSendAddr(int curStep, uint8_t **sendbuffOut, size_t *sizeOut, void **mhandleOut) = 0; + virtual void getNextRecvAddr(int curStep, uint8_t **recvbuffOut, size_t *sizeOut, void **mhandleOut) = 0; + int incRefCount() { + return __atomic_add_fetch(&refCount, 1, __ATOMIC_RELAXED); + } + int decRefCount() { + return __atomic_sub_fetch(&refCount, 1, __ATOMIC_RELEASE); + } + RingAlgorithm() { refCount = 0; } + virtual ~RingAlgorithm() {}; +}; + +class RingARAlgorithm : public RingAlgorithm { +private: + int ringIndex; + int elemSize; + ssize_t chunkSize; + int slicePerChunk; +public: + void getNextSendAddr(int curStep, uint8_t **sendbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int curLoopStage = (curStep % nStepsPerLoop) / chunkSteps; + int chunkStage = curLoopStage % nRanks; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t elemOffset = curLoop * loopSize; + ssize_t remSize = channelSize - elemOffset; + ssize_t chunkOffset; + ssize_t sliceOffset; + ssize_t curSliceSize; + ssize_t curChunkSize; + ssize_t size; + ssize_t nelem; + int chunkId; + + if (remSize < loopSize) { + curChunkSize = alignUp(divUp(remSize / elemSize, nRanks), 16 / elemSize) * elemSize; + } else { + curChunkSize = chunkSize; + } + chunkId = (ringIndex + nRanks - 1 - chunkStage) % nRanks; + chunkOffset = chunkId * curChunkSize; + nelem = std::min(remSize - chunkOffset, curChunkSize); + curSliceSize = std::max(divUp(nelem / elemSize, 16 * slicePerChunk) * 16, sliceSize / elemSize / 32) * elemSize; + sliceOffset = sliceStage * curSliceSize; + + if (nelem <= sliceOffset) { + *sendbuffOut = sendbuff; + *mhandleOut = sendMhandle; + } else { + if (curLoopStage == 0) { + *sendbuffOut = sendbuff + elemOffset + chunkOffset + sliceOffset; + *mhandleOut = sendMhandle; + } else { + *sendbuffOut = recvbuff + elemOffset + chunkOffset + sliceOffset; + *mhandleOut = srecvMhandle; + } + } + size = std::min(curSliceSize, nelem - sliceOffset); + *sizeOut = size < 0 ? 0 : size; + return; + } + + void getNextRecvAddr(int curStep, uint8_t **recvbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int curLoopStage = ((curStep + chunkSteps) % nStepsPerLoop) / chunkSteps; + int chunkStage = curLoopStage % nRanks; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t elemOffset = curLoop * loopSize; + ssize_t remSize = channelSize - elemOffset; + ssize_t chunkOffset; + ssize_t sliceOffset; + ssize_t curSliceSize; + ssize_t curChunkSize; + ssize_t size; + ssize_t nelem; + int chunkId; + + if (remSize < loopSize) { + curChunkSize = alignUp(divUp(remSize / elemSize, nRanks), 16 / elemSize) * elemSize; + } else { + curChunkSize = chunkSize; + } + + if (curLoopStage == 0) { + chunkId = (ringIndex + 1) % nRanks; + } else { + chunkId = (ringIndex + nRanks - 1 - chunkStage) % nRanks; + } + + chunkOffset = chunkId * curChunkSize; + nelem = std::min(remSize - chunkOffset, curChunkSize); + curSliceSize = std::max(divUp(nelem / elemSize, 16 * slicePerChunk) * 16, sliceSize / elemSize / 32) * elemSize; + sliceOffset = sliceStage * curSliceSize; + if (nelem <= sliceOffset) { + *recvbuffOut = recvbuff; + } else { + *recvbuffOut = recvbuff + elemOffset + chunkOffset + sliceOffset; + } + if (sizeOut) { + size = std::min(curSliceSize, nelem - sliceOffset); + *sizeOut = size < 0 ? 0 : size; + } + *mhandleOut = recvMhandle; + return; + } + + RingARAlgorithm(const void *sendbuff, void *recvbuff, int nRanks, int ringIndex, int chunkSteps, int sliceSteps, size_t chunkSize, size_t sliceSize, size_t gridOffset, size_t channelSize, int elemSize, void *sendMhandle, void *recvMhandle, void *srecvMhandle) { + this->ringIndex = ringIndex; + this->nRanks = nRanks; + this->nStepsPerLoop = 2 * (nRanks - 1) * chunkSteps; + this->chunkSteps = chunkSteps; + this->sliceSteps = sliceSteps; + this->chunkSize = chunkSize; + this->sliceSize = sliceSize; + this->loopSize = nRanks * chunkSize; + this->sendbuff = (uint8_t*)sendbuff + gridOffset; + this->recvbuff = (uint8_t*)recvbuff + gridOffset; + this->channelSize = channelSize; + this->elemSize = elemSize; + this->sendMhandle = sendMhandle; + this->recvMhandle = recvMhandle; + this->srecvMhandle = srecvMhandle; + this->slicePerChunk = chunkSteps / sliceSteps; + } + ~RingARAlgorithm() {} +}; + +class RingAGAlgorithm : public RingAlgorithm { +private: + int *ringRanks; + int elemSize; + ssize_t sendSize; + int slicePerChunk; +public: + void getNextSendAddr(int curStep, uint8_t **sendbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int chunkStage = (curStep % nStepsPerLoop) / chunkSteps; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t sliceOffset; + ssize_t curSliceSize; + ssize_t offset; + ssize_t elemOffset = curLoop * loopSize; + ssize_t chunkSize = std::min(loopSize, channelSize - elemOffset); + ssize_t size; + int rankDest; + uint8_t *buff; + void *mhandle; + + curSliceSize = std::max(divUp(chunkSize / elemSize, 16 * slicePerChunk) * 16, sliceSize / elemSize / 32) * elemSize; + sliceOffset = sliceStage * curSliceSize; + if (chunkStage == 0) { + rankDest = ringRanks[0]; + offset = elemOffset + sliceOffset; + buff = sendbuff + offset; + mhandle = sendMhandle; + } else { + rankDest = ringRanks[nRanks - chunkStage]; + offset = elemOffset + rankDest * sendSize + sliceOffset; + buff = recvbuff + offset; + mhandle = srecvMhandle; + } + *sendbuffOut = buff; + size = std::min(curSliceSize, channelSize - elemOffset - sliceOffset); + *sizeOut = size < 0 ? 0 : size; + *mhandleOut = mhandle; + return; + } + + void getNextRecvAddr(int curStep, uint8_t **recvbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int chunkStage = ((curStep + chunkSteps) % nStepsPerLoop) / chunkSteps; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t sliceOffset; + ssize_t curSliceSize; + ssize_t offset; + ssize_t elemOffset = curLoop * loopSize; + ssize_t chunkSize = std::min(loopSize, channelSize - elemOffset); + ssize_t size; + int rankDest; + + curSliceSize = std::max(divUp(chunkSize / elemSize, 16 * slicePerChunk) * 16, sliceSize / elemSize / 32) * elemSize; + sliceOffset = sliceStage * curSliceSize; + if (chunkStage == 0) { + rankDest = ringRanks[1]; + } else { + rankDest = ringRanks[nRanks - chunkStage]; + } + offset = elemOffset + rankDest * sendSize + sliceOffset; + *recvbuffOut = recvbuff + offset; + if (sizeOut) { + size = std::min(sliceSize, channelSize - elemOffset - sliceOffset); + *sizeOut = size < 0 ? 0 : size; + } + *mhandleOut = recvMhandle; + } + + RingAGAlgorithm(const void *sendbuff, void *recvbuff, int nRanks, int *ringRanks, int chunkSteps, int sliceSteps, size_t chunkSize, size_t sliceSize, size_t gridOffset, size_t channelSize, int elemSize, size_t sendSize, void *sendMhandle, void *recvMhandle, void *srecvMhandle) { + this->ringRanks = ringRanks; + this->nRanks = nRanks; + this->nStepsPerLoop = (nRanks - 1) * chunkSteps; + this->chunkSteps = chunkSteps; + this->sliceSteps = sliceSteps; + this->elemSize = elemSize; + this->sliceSize = sliceSize; + this->loopSize = chunkSize; + this->sendSize = sendSize; + this->channelSize = channelSize; + this->sendbuff = (uint8_t*)sendbuff + gridOffset; + this->recvbuff = (uint8_t*)recvbuff + gridOffset; + this->sendMhandle = sendMhandle; + this->recvMhandle = recvMhandle; + this->srecvMhandle = srecvMhandle; + this->slicePerChunk = chunkSteps / sliceSteps; + } + ~RingAGAlgorithm() {} +}; + +class RingBCAlgorithm : public RingAlgorithm { +private: + int root; + int rank; + int nextRank; +public: + void getNextSendAddr(int curStep, uint8_t **sendbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t sliceOffset = sliceStage * sliceSize; + ssize_t offset; + ssize_t elemOffset = curLoop * loopSize; + ssize_t size; + uint8_t *buff; + void *mhandle; + + offset = elemOffset + sliceOffset; + if (offset >= channelSize) { + buff = sendbuff; + mhandle = sendMhandle; + } else if (rank == root) { + buff = sendbuff + offset; + mhandle = sendMhandle; + } else { + buff = recvbuff + offset; + mhandle = srecvMhandle; + } + *sendbuffOut = buff; + size = std::min(sliceSize, channelSize - offset); + *sizeOut = size < 0 ? 0 : size; + *mhandleOut = mhandle; + return; + } + + void getNextRecvAddr(int curStep, uint8_t **recvbuffOut, size_t *sizeOut, void **mhandleOut) { + int curLoop = curStep / nStepsPerLoop; + int sliceStage = (curStep % chunkSteps) / sliceSteps; + ssize_t sliceOffset = sliceStage * sliceSize; + ssize_t offset; + ssize_t elemOffset = curLoop * loopSize; + ssize_t size; + offset = elemOffset + sliceOffset; + if (offset >= channelSize) { + *recvbuffOut = recvbuff; + } else { + *recvbuffOut = recvbuff + offset; + } + if (sizeOut) { + size = std::min(sliceSize, channelSize - offset); + *sizeOut = size < 0 ? 0 : size; + } + *mhandleOut = recvMhandle; + return; + } + + RingBCAlgorithm(const void* sendbuff, void* recvbuff, int rank, int root, int nRanks, int *ringRanks, int chunkSteps, int sliceSteps, size_t chunkSize, size_t sliceSize, size_t gridOffset, size_t channelSize, void *sendMhandle, void *recvMhandle, void *srecvMhandle) { + this->root = root; + this->rank = rank; + this->nextRank = ringRanks[1]; + this->nStepsPerLoop = chunkSteps; + this->chunkSteps = chunkSteps; + this->sliceSteps = sliceSteps; + this->sliceSize = sliceSize; + this->loopSize = chunkSize; + this->channelSize = channelSize; + this->sendbuff = (uint8_t*)sendbuff + gridOffset; + this->recvbuff = (uint8_t*)recvbuff + gridOffset; + this->sendMhandle = sendMhandle; + this->recvMhandle = recvMhandle; + this->srecvMhandle = srecvMhandle; + } + ~RingBCAlgorithm() {} +}; + template class PatRSAlgorithm{ size_t offset; @@ -540,10 +851,10 @@ restart: int sendDataRank = (rank + nranks + s) % nranks; outIx = sendDataRank * count + offset; recvDim = s ? firstBitSet(s, nrPow2) : -1; - s -= (1<> (recvDim+1); recvOffset = (foffset%postFreq)*nelem; recvStepOffset = foffset / postFreq; diff --git a/src/include/comm.h b/src/include/comm.h index d22ca71e69..2bb73c7170 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -205,13 +205,16 @@ struct ncclTaskColl { int32_t algorithm:8, protocol:8; uint32_t isCollnet:1, isNvls:1; uint32_t devFuncId:30; - enum ncclRegBufferType regBufType; + int regBufType; uint64_t opCount; // number of elements in planner->ipcMemQueue associated with this collective int nCleanupQueueElts; void* sendMhandle; void* recvMhandle; + void** sendNetHandles; + void** recvNetHandles; + void** srecvNetHandles; // index for IPC record lookup uintptr_t sendbuffOffset; uintptr_t recvbuffOffset; @@ -246,6 +249,7 @@ struct ncclKernelPlan { struct ncclKernelPlan* next; bool persistent; // aka captured in a graph + bool isHostCbEnq; enum ncclDevWorkStorageType workStorageType; bool kernelSpecialized; void *kernelFn; @@ -377,6 +381,7 @@ struct ncclKernelPlanner { struct ncclIntruQueue collTaskQueue; struct ncclIntruQueue collWorkQueue; + struct ncclIntruQueue tmpCollWorkQueue; struct ncclIntruQueue collCleanupQueue; ////////////////////////////////////////////////////////////////////////////// @@ -494,6 +499,8 @@ struct ncclComm { // Counter for tracking CUDA launches (P2P and collectives included) uint64_t opCount; + // Collective operation counter + uint64_t collOpCount; // Channels for collectives int nChannels; // connection nChannels @@ -517,7 +524,6 @@ struct ncclComm { ssize_t threadThresholds[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; float latencies[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; float bandwidths[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; - float ringbdw[NCCL_NUM_FUNCTIONS][NCCL_NUM_PROTOCOLS]; int maxThreads[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; uint64_t minMaxLLRange[RCCL_TUNABLE_COLLS][NCCL_NUM_PROTOCOLS - 1][RCCL_PROTOCOL_ENTRY_SIZE]; @@ -569,7 +575,7 @@ struct ncclComm { int proxyRefCountOld; /* store proxy post-atomic-sub refcount */ // Whether this communicator uses collNet int collNetSupport; - bool collNetRegSupport; + bool isOneRPN; uint8_t collNetSupportMatrix[4/*sum,prod,max,min*/][ncclNumTypes]; bool intraNodeP2pSupport; int* collNetHeads; @@ -597,6 +603,7 @@ struct ncclComm { // Subset of those in groupNext list. Holds 0x1 if not needing preconnect. struct ncclComm* preconnectNext; int persistentRefs; // number of persistent plan-lists capturing this comm + int noncapturedRefs; // number of non-captured hostStreamPlanCallback on the stream struct P2pSchedulePair { int sendRank; int recvRank; } *p2pSchedule; struct ncclKernelPlanner planner; @@ -664,12 +671,20 @@ struct ncclComm { // buffer registration cache struct ncclRegCache regCache; - uint64_t endMagic; + int isAllNvlink; + bool useNetPXN; + bool useGdr; + int splitCount; // Unroll factor for comm [RCCL] int unroll; + + uint64_t endMagic; }; +static_assert(offsetof(struct ncclComm, startMagic) == 0, "startMagic must be the first field of ncclComm"); +static_assert(offsetof(struct ncclComm, endMagic) == sizeof(struct ncclComm) - sizeof(uint64_t), "endMagic must be the last field of ncclComm"); + enum ncclLaunchMode { ncclLaunchModeInvalid=0, ncclLaunchModeParallel, diff --git a/src/include/debug.h b/src/include/debug.h index 491ac3e123..4e50cbf5a7 100644 --- a/src/include/debug.h +++ b/src/include/debug.h @@ -38,4 +38,6 @@ extern char ncclLastError[]; void ncclSetThreadName(pthread_t thread, const char *fmt, ...); +void ncclResetDebugInit(); + #endif diff --git a/src/include/device.h b/src/include/device.h index 4c202be14a..5f745f1e82 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -98,24 +98,18 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK #define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 4 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS) -#define NCCL_DIRECT_WRITE 0x01 -#define NCCL_DIRECT_READ 0x02 +#define NCCL_P2P_WRITE 0x01 +#define NCCL_P2P_READ 0x02 #define NCCL_DIRECT_NIC 0x04 -#define NCCL_IPC_WRITE 0x08 -#define NCCL_IPC_READ 0x10 -#define NCCL_NVLS_MIN_POLL 0x20 +#define NCCL_NVLS_MIN_POLL 0x80 // Number of named barriers supported by CUDA #define NCCL_MAX_GROUPS (NCCL_MAX_NTHREADS/WARP_SIZE) -#define NCCL_MAX_COLLNET_SIZE (1L << 29) - -enum ncclRegBufferType { - NCCL_REGULAR_BUFFER = 0, - NCCL_IPC_REG_BUFFER = 1, - NCCL_NVLS_REG_BUFFER = 2, - NCCL_COLLNET_REG_BUFFER = 3 -}; +#define NCCL_REGULAR_BUFFER 0x00 +#define NCCL_IPC_REG_BUFFER 0x01 +#define NCCL_NVLS_REG_BUFFER 0x02 +#define NCCL_NET_REG_BUFFER 0x04 struct ncclConnInfo { // Regular comm mechanism @@ -159,8 +153,6 @@ struct ncclConnector { struct ncclTransportComm* transportComm; void* transportResources; struct ncclConnInfo conn; - int sendMemSameProcess; - int recvMemSameProcess; }; struct ncclRing { @@ -249,8 +241,7 @@ struct alignas(16) ncclDevWorkP2p { uint8_t sendChunkSize_u32fp8, recvChunkSize_u32fp8; uint8_t sendProtoLL:1, recvProtoLL:1; - uint8_t sendRegistered:1, recvRegistered:1; - + uint8_t sendNetReg:1, recvNetReg:1; uint8_t sendIpcReg:1, recvIpcReg:1; uint8_t sendConnIndex:2, recvConnIndex:2; @@ -299,7 +290,7 @@ struct alignas(16) ncclDevWorkColl { // nChannels == (channelHi - channelLo) + 1 uint32_t channelLo:8, channelHi:8; uint32_t nWarps:8; - uint32_t redOpArgIsPtr:1, regUsed:2, oneNode:1, direct:4; + uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1; uint32_t root:30, connIndex:2; uint16_t pivotA2ANumBiRings; void* recvbuff; @@ -504,7 +495,7 @@ struct ncclDevComm { int nNodes; int buffSizes[NCCL_NUM_PROTOCOLS]; int p2pChunkSize; - int isNvlink; + int isAllNvlink; int p2pnChannelsPerPeer; // Work fifo return credits @@ -665,13 +656,7 @@ inline bool ncclNvlsSupported(int devRedOp, int type) { case ncclInt64: case ncclUint64: case ncclFloat16: -#if defined(RCCL_BFLOAT16) case ncclBfloat16: -#endif -#if defined(RCCL_FLOAT8) - case ncclFp8E4M3: - case ncclFp8E5M2: -#endif return devRedOp == ncclDevSum || devRedOp == ncclDevMinMax; case ncclFloat: case ncclDouble: diff --git a/src/include/enqueue.h b/src/include/enqueue.h index 1bb5a604f5..a715a44c43 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -25,5 +25,16 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKernelPlan* plan); ncclResult_t ncclLaunchFinish(struct ncclComm* comm); ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool* needConnect, ncclSimInfo_t* simInfo); +ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm); + +static inline size_t ncclFuncSendCount(ncclFunc_t func, int nRanks, size_t count) { + return func == ncclFuncReduceScatter ? nRanks*count : count; +} +static inline size_t ncclFuncRecvCount(ncclFunc_t func, int nRanks, size_t count) { + return func == ncclFuncAllGather ? nRanks*count : count; +} +rccl_static_inline size_t ncclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count) { + return func == ncclFuncAllGather || func == ncclFuncReduceScatter ? nRanks*count : count; +} #endif // End include guard diff --git a/src/include/graph.h b/src/include/graph.h index 7d7aaa4f91..31b1221ba7 100644 --- a/src/include/graph.h +++ b/src/include/graph.h @@ -20,7 +20,7 @@ ncclResult_t ncclTopoCudaPath(int cudaDev, char** path); struct ncclTopoSystem; // Build the topology -ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** system); +ncclResult_t ncclTopoGetSystem(struct ncclComm* comm, struct ncclTopoSystem** system, const char* dumpXmlFile=NULL); ncclResult_t ncclTopoSortSystem(struct ncclTopoSystem* system); ncclResult_t ncclTopoPrint(struct ncclTopoSystem* system); @@ -34,13 +34,14 @@ ncclResult_t ncclTopoComputeCommCPU(struct ncclComm* comm); // Query topology ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoGraph* graph, int channelId, int peerRank, int64_t* id, int* dev, int* proxyRank); -ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int rank1, int rank2, int* p2p, int *read, int* intermediateRank); +ncclResult_t ncclTopoCheckP2p(struct ncclComm* comm, struct ncclTopoSystem* system, int rank1, int rank2, int* p2p, int *read, int* intermediateRank); ncclResult_t ncclTopoCheckMNNVL(struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* ret); -ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int64_t netId, int read, int* useGdr); +ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int rank, int64_t netId, int read, int* useGdr); #define MAX_XGMI_INTER_GPUS 4 ncclResult_t ncclTopoGetIntraNetDev(struct ncclTopoSystem* system, int rank, struct ncclTopoGraph* graph, int channelId, int type, int64_t* id, int* dev); ncclResult_t ncclTopoGetLinkType(struct ncclTopoSystem* system, int cudaDev1, int cudaDev2, bool* isXGMI, int maxInter=MAX_XGMI_INTER_GPUS, int nInter=0, int *inter=nullptr); -ncclResult_t ncclTopoNeedFlush(struct ncclTopoSystem* system, int64_t busId, int* flush); +ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int netDev, int rank, int* flush); +ncclResult_t ncclTopoIsGdrAvail(struct ncclTopoSystem* system, int rank, bool *avail); ncclResult_t ncclTopoCheckNet(struct ncclTopoSystem* system, int rank1, int rank2, int* net); int ncclPxnDisable(struct ncclComm* comm); ncclResult_t ncclTopoGetPxnRanks(struct ncclComm* comm, int** intermediateRanks, int* nranks); @@ -128,6 +129,6 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph); ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph** graphs); -ncclResult_t ncclTopoGetAlgoTime(struct ncclComm* comm, int coll, int algorithm, int protocol, size_t nBytes, int numPipeOps, float* time, bool* backup=nullptr); +ncclResult_t ncclTopoGetAlgoTime(struct ncclComm* comm, int coll, int algorithm, int protocol, size_t nBytes, int numPipeOps, float* time); #endif diff --git a/src/include/ibvwrap.h b/src/include/ibvwrap.h index c3709584c3..3a4c42bb21 100644 --- a/src/include/ibvwrap.h +++ b/src/include/ibvwrap.h @@ -12,6 +12,8 @@ #ifndef NCCL_IBVWRAP_H_ #define NCCL_IBVWRAP_H_ +#include +#include #ifdef NCCL_BUILD_RDMA_CORE #include #else @@ -89,4 +91,14 @@ static inline ncclResult_t wrap_ibv_post_recv(struct ibv_qp *qp, struct ibv_recv ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event); +// converts a GID into a readable string. On success, returns a non-null pointer to gidStr. +// NULL is returned if there was an error, with errno set to indicate the error. +// errno = ENOSPC if the converted string would exceed strLen. +static inline const char* ibvGetGidStr(union ibv_gid* gid, char* gidStr, size_t strLen) { + // GID is a 16B handle, to convert it to a readable form, we use inet_ntop + // sizeof(ibv_gid) == sizeof(struct in6_addr), so using AF_INET6 + static_assert(sizeof(union ibv_gid) == sizeof(struct in6_addr), "the sizeof struct ibv_gid must be the size of struct in6_addr"); + return inet_ntop(AF_INET6, gid->raw, gidStr, strLen); +} + #endif //End include guard diff --git a/src/include/nccl_common.h b/src/include/nccl_common.h index f4d60a9b7d..6a9d18d3b5 100644 --- a/src/include/nccl_common.h +++ b/src/include/nccl_common.h @@ -32,7 +32,8 @@ typedef enum { NCCL_BOOTSTRAP = 0x1000, NCCL_REG = 0x2000, NCCL_PROFILE = 0x4000, - NCCL_VERBS = 0x8000, + NCCL_RAS = 0x8000, + NCCL_VERBS = 0x10000, NCCL_ALL = ~0 } ncclDebugLogSubSys; diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index 467d9fdb89..f165aa1bf0 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -13,6 +13,9 @@ #include #define NCCL_NET_HANDLE_MAXSIZE 128 +//Maximum value NCCL can accept for maxP2pBytes and maxCollBytes net properties +#define NCCL_MAX_NET_SIZE_BYTES (1*1024*1024*1024*1024L) +#define NCCL_NET_OPTIONAL_RECV_COMPLETION 0x1 #define NCCL_PTR_HOST 0x1 #define NCCL_PTR_CUDA 0x2 @@ -21,6 +24,161 @@ // Maximum number of requests per comm object #define NCCL_NET_MAX_REQUESTS 32 +// Max number of ncclNet objects which can live in the same process +#define NCCL_NET_MAX_PLUGINS 3 + +#define NCCL_NET_MAX_DEVS_PER_NIC_V9 4 +#define NCCL_NET_MAX_DEVS_PER_NIC NCCL_NET_MAX_DEVS_PER_NIC_V9 + +typedef struct { + int ndevs; + int devs[NCCL_NET_MAX_DEVS_PER_NIC_V9]; +} ncclNetVDeviceProps_v9_t; +typedef ncclNetVDeviceProps_v9_t ncclNetVDeviceProps_t; + +typedef struct { + char* name; // Used mostly for logging. + char* pciPath; // Path to the PCI device in /sys. + uint64_t guid; // Unique identifier for the NIC chip. Important for + // cards with multiple PCI functions (Physical or virtual). + int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF] + int regIsGlobal; // regMr is not tied to a particular comm + int forceFlush; // Force a flush on receives + int speed; // Port speed in Mbps. + int port; // Port number. + float latency; // Network latency + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives. + ncclNetDeviceType netDeviceType; // Network offload type + int netDeviceVersion; // Version number for network offload + ncclNetVDeviceProps_v9_t vProps; + size_t maxP2pBytes; // Max transfer size for point-to-point operations + size_t maxCollBytes; // Max transfer size for collective operations +} ncclNetProperties_v9_t; +typedef ncclNetProperties_v9_t ncclNetProperties_t; + +typedef struct { + // Name of the network (mainly for logs) + const char* name; + // Initialize the network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create a connection. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Connect to a handle and return a sending comm object for that peer. + // This call must not block for the connection to be established, and instead + // should return successfully with sendComm == NULL with the expectation that + // it will be called again until sendComm != NULL. + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm); + // Finalize connection establishment after remote peer has called connect. + // This call must not block for the connection to be established, and instead + // should return successfully with recvComm == NULL with the expectation that + // it will be called again until recvComm != NULL. + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm); + // Register/Deregister memory. Comm can be either a sendComm or a recvComm. + // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + // Asynchronous send to a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*isend)(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request); + // Asynchronous recv from a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*irecv)(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* sizes); + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); + + // Create a virtual NIC given the specified properties, which can be accessed at device index d + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclNet_v9_t; + +typedef ncclNet_v9_t ncclNet_t; + +#define NCCL_NET_PLUGIN_SYMBOL ncclNetPlugin_v9 + +typedef struct { + void* mhandle; + void* address; + size_t size; +} ncclNetSGE_v9_t; + +typedef struct { + // Name of the collective network (mainly for logs) + const char* name; + // Initialize the collective network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters capable of doing collective operations. + // If ndev returns 0, all other functions might be set to NULL. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create connections. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Create a group for collective operations. handles have been created + // using listen() above. rank indicates caller's rank in the collective network. + ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); + // Returns whether a reduction operation on a data type is supported. + // 1 for supported, 0 otherwise. + ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); + // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* collComm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* collComm, void* mhandle); + // Performs an asynchronous allreduce operation on the collective group. + // May return request == NULL if the call cannot be performed (or would block). + ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); + ncclResult_t (*iallgather)(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v9_t* recvParts, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + void* sendMhandle, void** request); + ncclResult_t (*ireducescatter)(void* collComm, int nSendParts, ncclNetSGE_v9_t* sendParts, void* recvData, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + ncclDataType_t dataType, ncclRedOp_t redOp, + void* recvMhandle, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* size); + // Close and free collective comm objects + ncclResult_t (*closeColl)(void* collComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Create a virtual NIC given the specified properties, which can be accessed at device index d + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclCollNet_v9_t; + +typedef ncclCollNet_v9_t ncclCollNet_t; + +#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v9 + typedef struct { char* name; // Used mostly for logging. char* pciPath; // Path to the PCI device in /sys. @@ -37,8 +195,6 @@ typedef struct { int netDeviceVersion; // Version number for network offload } ncclNetProperties_v8_t; -typedef ncclNetProperties_v8_t ncclNetProperties_t; - typedef struct { // Name of the network (mainly for logs) const char* name; @@ -94,10 +250,6 @@ typedef struct { ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); } ncclNet_v8_t; -typedef ncclNet_v8_t ncclNet_t; - -#define NCCL_NET_PLUGIN_SYMBOL ncclNetPlugin_v8 - typedef struct { void* mhandle; void* address; @@ -151,10 +303,6 @@ typedef struct { ncclResult_t (*closeListen)(void* listenComm); } ncclCollNet_v8_t; -typedef ncclCollNet_v8_t ncclCollNet_t; - -#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v8 - typedef struct { char* name; // Used mostly for logging. char* pciPath; // Path to the PCI device in /sys. diff --git a/src/include/nccl_profiler.h b/src/include/nccl_profiler.h index 556a0f6e45..a8164d075e 100644 --- a/src/include/nccl_profiler.h +++ b/src/include/nccl_profiler.h @@ -16,9 +16,133 @@ enum { ncclProfileProxyOp = (1 << 3), // proxy operation event type ncclProfileProxyStep = (1 << 4), // proxy step event type ncclProfileProxyCtrl = (1 << 5), // proxy control event type - ncclProfileNumEvents = ( 6), }; +typedef struct { + uint8_t type; // event type descriptor: ncclProfileColl, ... + void* parentObj; // pointer to the profiler parent object (for coll is the group) + int rank; // originating rank + union { + struct { + const char* name; + uint64_t commHash; + uint64_t seqNumber; + const char* func; + void const* sendBuff; + void* recvBuff; + size_t count; + int root; + const char* datatype; + size_t trafficBytes; + uint8_t nMaxChannels; + uint8_t nWarps; + const char* algo; + const char* proto; + } coll; + + struct { + const char* name; + uint64_t commHash; + const char* func; + void* buff; + const char* datatype; + size_t count; + int peer; + } p2p; + + struct { + pid_t pid; // pid of the originating process + uint8_t channelId; // channel id for this proxy operation + int peer; // remote rank for send/recv + int nSteps; // number of steps for this proxy operation + int chunkSize; // amount of data transferred by this proxy operation + int isSend; + } proxyOp; + + struct { + int step; + } proxyStep; + }; +} ncclProfilerEventDescr_v2_t; + +typedef enum { + ncclProfilerProxyOpSendPosted, + ncclProfilerProxyOpSendRemFifoWait, + ncclProfilerProxyOpSendTransmitted, + ncclProfilerProxyOpSendDone, + ncclProfilerProxyOpRecvPosted, + ncclProfilerProxyOpRecvReceived, + ncclProfilerProxyOpRecvTransmitted, + ncclProfilerProxyOpRecvDone, + + /* Legacy proxy profiler states */ + ncclProfilerProxyStepSendGPUWait, + ncclProfilerProxyStepSendWait, + ncclProfilerProxyStepRecvWait, + ncclProfilerProxyStepRecvFlushWait, + ncclProfilerProxyStepRecvGPUWait, + + /* Legacy proxy control states */ + ncclProfilerProxyCtrlIdle, + ncclProfilerProxyCtrlActive, + ncclProfilerProxyCtrlSleep, + ncclProfilerProxyCtrlWakeup, + ncclProfilerProxyCtrlAppend, + ncclProfilerProxyCtrlAppendEnd, +} ncclProfilerEventState_v2_t; + +typedef union { + struct { + size_t transSize; + int steps; + } proxyOp; + + struct { + int appendedProxyOps; + } proxyCtrl; +} ncclProfilerEventStateArgs_v2_t; + +typedef struct { + const char* name; + + // init - initialize the profiler plugin + // Input + // - context : opaque profiler context object for separating profiler behavior across comms + // Output + // - eActivationMask: bitmask of active events set by the plugin + ncclResult_t (*init)(void** context, int* eActivationMask); + + // startEvent - initialize and start a new event for the supplied event descriptor inside the eventset + // Input + // - context: opaque profiler context object + // - eDescr : pointer to ncclProfilerEventDescr_t object + // Output + // - eHandle: return event handle for supplied event descriptor object + ncclResult_t (*startEvent)(void* context, void** eHandle, ncclProfilerEventDescr_v2_t* eDescr); + + // stopEvent - stop/finalize an event inside and event set + // Input + // - eHandle: handle to event object + ncclResult_t (*stopEvent)(void* eHandle); + + // recordEventState - record event state transitions and event attribute updates + // Input + // - eHandle : handle to event object created through startEvent + // - eStateArgs: optional argument used to capture event attribute updates associated with the state transition + // - eState : event state transition + ncclResult_t (*recordEventState)(void* eHandle, ncclProfilerEventState_v2_t eState, ncclProfilerEventStateArgs_v2_t* eStateArgs); + + // finalize - finalize the profiler plugin + // Input + // - context: opaque profiler context object + ncclResult_t (*finalize)(void* context); +} ncclProfiler_v2_t; + +typedef ncclProfilerEventDescr_v2_t ncclProfilerEventDescr_t; +typedef ncclProfilerEventState_v2_t ncclProfilerEventState_t; +typedef ncclProfilerEventStateArgs_v2_t ncclProfilerEventStateArgs_t; +typedef ncclProfiler_v2_t ncclProfiler_t; + typedef struct { uint8_t type; // event type descriptor: ncclProfileColl, ... void* parentObj; // pointer to the profiler parent object (for coll is the group) @@ -69,42 +193,8 @@ typedef struct { }; } ncclProfilerEventDescr_v1_t; -typedef enum { - ncclProfilerProxyOpSendPosted, - ncclProfilerProxyOpSendRemFifoWait, - ncclProfilerProxyOpSendTransmitted, - ncclProfilerProxyOpSendDone, - ncclProfilerProxyOpRecvPosted, - ncclProfilerProxyOpRecvReceived, - ncclProfilerProxyOpRecvTransmitted, - ncclProfilerProxyOpRecvDone, - - /* Legacy proxy profiler states */ - ncclProfilerProxyStepSendGPUWait, - ncclProfilerProxyStepSendWait, - ncclProfilerProxyStepRecvWait, - ncclProfilerProxyStepRecvFlushWait, - ncclProfilerProxyStepRecvGPUWait, - - /* Legacy proxy control states */ - ncclProfilerProxyCtrlIdle, - ncclProfilerProxyCtrlActive, - ncclProfilerProxyCtrlSleep, - ncclProfilerProxyCtrlWakeup, - ncclProfilerProxyCtrlAppend, - ncclProfilerProxyCtrlAppendEnd, -} ncclProfilerEventState_v1_t; - -typedef union { - struct { - size_t transSize; - int steps; - } proxyOp; - - struct { - int appendedProxyOps; - } proxyCtrl; -} ncclProfilerEventStateArgs_v1_t; +typedef ncclProfilerEventState_v2_t ncclProfilerEventState_v1_t; +typedef ncclProfilerEventStateArgs_v2_t ncclProfilerEventStateArgs_v1_t; typedef struct { const char* name; @@ -142,9 +232,4 @@ typedef struct { ncclResult_t (*finalize)(void* context); } ncclProfiler_v1_t; -typedef ncclProfilerEventDescr_v1_t ncclProfilerEventDescr_t; -typedef ncclProfilerEventState_v1_t ncclProfilerEventState_t; -typedef ncclProfilerEventStateArgs_v1_t ncclProfilerEventStateArgs_t; -typedef ncclProfiler_v1_t ncclProfiler_t; - #endif diff --git a/src/include/nccl_tuner.h b/src/include/nccl_tuner.h index 5cd02149f9..6e61118b9c 100644 --- a/src/include/nccl_tuner.h +++ b/src/include/nccl_tuner.h @@ -11,6 +11,55 @@ #include "nccl.h" #include "nccl_common.h" +// API to be implemented by external tuner +typedef struct { + // Name of the tuner + const char* name; + + // Initializes tuner states. + // Inputs: + // - nRanks: number of ranks in current communicator. Each communicator initialize its own tuner. + // - nNodes: number of nodes in current communicator. + // - logFunction: a logFunction can be useful to integrate logging together with NCCL core. + // Outputs: + // - context: tuner context object + ncclResult_t (*init)(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void **context); + + // Gets info (algo, protocol, number of ctas and threads) for a given collective. + // Inputs: + // - context: tuner context object + // - collType: collective type , e.g., allreduce, allgather… + // - nBytes: collective size in bytes + // - numPipeOps: number of operations in the group + // - numAlgo: number of algorithms in collCostTable + // - numProto: number of protocols in collCostTable + // - regBuff: can register user buffer + // + // Outputs: + // - nChannels: number of channels (hence SMs) to be used. + // + // InOut: + // - collCostTable: collective cost table, generated by NCCL core, containing algo|proto|time entries for collType. + // NCCL core sets ignored algo/proto cost table entries to -1.0 (NCCL_ALGO_PROTO_IGNORE). + // + // If getCollInfo() does not return ncclSuccess, NCCL will fall back to the + // default tuning for the given collective. + // Also, the plugin is allowed to not set any output, or set only the + // algorithm and protocol, but not only the algorithm or only the protocol. + // Unset fields will be set automatically by NCCL. + ncclResult_t (*getCollInfo)(void* context, ncclFunc_t collType, size_t nBytes, + int numPipeOps, float** collCostTable, int numAlgo, int numProto, + int regBuff, int* nChannels); + + // Terminates the plugin and cleans up any resources that the plugin allocated. + // context: tuner context object + ncclResult_t (*destroy)(void* context); +} ncclTuner_v4_t; + +typedef ncclTuner_v4_t ncclTuner_t; + +#define NCCL_TUNER_PLUGIN_SYMBOL "ncclTunerPlugin_v4" + // API to be implemented by external tuner typedef struct { // Name of the tuner @@ -55,10 +104,6 @@ typedef struct { ncclResult_t (*destroy)(void* context); } ncclTuner_v3_t; -typedef ncclTuner_v3_t ncclTuner_t; - -#define NCCL_TUNER_PLUGIN_SYMBOL "ncclTunerPlugin_v3" - // API to be implemented by external tuner typedef struct { // Name of the tuner diff --git a/src/include/net_device.h b/src/include/net_device.h index 7bb2968c05..5fae9b5424 100644 --- a/src/include/net_device.h +++ b/src/include/net_device.h @@ -25,6 +25,7 @@ typedef struct { } ncclNetDeviceHandle_v7_t; typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t; -typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_t; +typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_v9_t; +typedef ncclNetDeviceHandle_v9_t ncclNetDeviceHandle_t; #endif diff --git a/src/include/nvmlwrap.h b/src/include/nvmlwrap.h index 7dee7d4aef..72fbf9ce2a 100644 --- a/src/include/nvmlwrap.h +++ b/src/include/nvmlwrap.h @@ -302,7 +302,7 @@ extern ncclNvmlDevicePairInfo ncclNvmlDevicePairs[ncclNvmlMaxDevices][ncclNvmlMa struct ncclNvmlCCStatus { bool CCEnabled; - bool multiGpuCCEnabled; + bool multiGpuProtectedPCIE; }; // All ncclNvmlFoo() functions call ncclNvmlEnsureInitialized() implicitly. diff --git a/src/include/profiler.h b/src/include/profiler.h index 36774dc848..2b7efe0f69 100644 --- a/src/include/profiler.h +++ b/src/include/profiler.h @@ -36,9 +36,9 @@ ncclResult_t ncclProfilerStartRecvProxyOpEvent(int sub, struct ncclProxyArgs* ar ncclResult_t ncclProfilerStopProxyOpEvent(int sub, struct ncclProxyArgs* args); // Proxy Step Start/Stop Event Wrappers -ncclResult_t ncclProfilerStartSendProxyStepEvents(int sub, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi); -ncclResult_t ncclProfilerStartRecvProxyStepEvents(int sub, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi); -ncclResult_t ncclProfilerStopProxyStepEvents(int sub, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi); +ncclResult_t ncclProfilerStartSendProxyStepEvent(int sub, struct ncclProxyArgs* args, int stepId); +ncclResult_t ncclProfilerStartRecvProxyStepEvent(int sub, struct ncclProxyArgs* args, int stepId); +ncclResult_t ncclProfilerStopProxyStepEvent(int sub, struct ncclProxyArgs* args, int stepId); // Proxy Control Start/Stop Events Wrappers ncclResult_t ncclProfilerStartProxyCtrlEvent(void* profilerContext, void** eHandle); @@ -46,7 +46,7 @@ ncclResult_t ncclProfilerStopProxyCtrlEvent(void* eHandle); // Record Event Wrappers ncclResult_t ncclProfilerRecordProxyOpEventState(int sub, struct ncclProxyArgs* args, int steps, size_t transSize, ncclProfilerEventState_t eState); -ncclResult_t ncclProfilerRecordProxyStepEventStates(int sub, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi, ncclProfilerEventState_t eState); +ncclResult_t ncclProfilerRecordProxyStepEventState(int sub, struct ncclProxyArgs* args, int stepId, ncclProfilerEventState_t eState); ncclResult_t ncclProfilerRecordProxyCtrlEventState(void*eHandle, int appended, ncclProfilerEventState_t eState); // Profiler utility functions diff --git a/src/include/proxy.h b/src/include/proxy.h index 43f7b55947..52b46b8557 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -17,6 +17,7 @@ #include #include "shmutils.h" #include "p2p.h" +#include "collectives.h" typedef enum : uint8_t { ncclPatternRing, @@ -60,7 +61,11 @@ struct ncclProxyOp { uint32_t connIndex:2; int next; int nsteps; - int chunkSize; + size_t chunkSize; + size_t sliceSize; + size_t loopSize; + size_t loopOffset; + size_t channelSize; uint8_t sliceSteps; uint8_t chunkSteps; uint8_t channelId; @@ -69,16 +74,17 @@ struct ncclProxyOp { uint8_t /*ncclFunc_t*/ coll; uint8_t /*ncclPattern_t*/ pattern; uint8_t protocol; + uint8_t algorithm; uint8_t reg; - // collnet buffer reg handles + // collnet/p2p/coll buffer reg handles void* sendMhandle; void* recvMhandle; uint8_t* sendbuff; uint8_t* recvbuff; - + int isOneRPN; + RingAlgorithm *ringAlgo; int nextRank; int prevRank; - union ncclProxyOpSpecifics specifics; // Profiler plugin @@ -100,19 +106,21 @@ struct ncclProxyOp { struct ncclProxySubArgs { struct ncclProxyConnection* connection; int reg; - // p2p mhandle - void* mhandle; // collnet handles void* sendMhandle; void* recvMhandle; uint8_t* sendbuff; uint8_t* recvbuff; size_t offset; + ssize_t loopSize; + ssize_t loopOffset; int channelId; int nsteps; ssize_t nbytes; + ssize_t chunkSize; int peer; - + int isOneRPN; + RingAlgorithm *ringAlgo; int groupSize; // Number of consecutive sub operations sharing the same recvComm uint64_t base; uint64_t posted; @@ -121,11 +129,14 @@ struct ncclProxySubArgs { uint64_t transmitted; uint64_t done; uint64_t end; + int regBufferReady; void* requests[NCCL_STEPS]; // Profiler plugin int eActivationMask; int rank; + pid_t pid; + void* profilerContext; void* taskEventHandle; void* opEventHandle; void* stepEventHandles[NCCL_STEPS]; @@ -145,10 +156,11 @@ struct ncclProxyArgs { proxyProgressFunc_t progress; int nsubs; int done; + int onePPN; uint64_t opCount; int sliceSteps; int chunkSteps; - int chunkSize; + size_t chunkSize; size_t totalSendSize; size_t totalRecvSize; size_t sendSizePerRound; @@ -158,6 +170,7 @@ struct ncclProxyArgs { uint8_t /*ncclPattern_t*/ pattern; uint8_t /*ncclFunc_t*/ coll; uint8_t protocol; + uint8_t algorithm; int state; char* sharedBuff[NCCL_STEPS]; int sharedSize[NCCL_STEPS]; @@ -165,10 +178,6 @@ struct ncclProxyArgs { int idle; uint64_t hdp_flushed; - // Profiler plugin - pid_t pid; - void* profilerContext; - // Element linking struct ncclProxyArgs* next; struct ncclProxyArgs* nextPeer; diff --git a/src/include/ras.h b/src/include/ras.h new file mode 100644 index 0000000000..7909b3dc89 --- /dev/null +++ b/src/include/ras.h @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_RAS_H_ +#define NCCL_RAS_H_ + +#include "socket.h" + +// Structure used to communicate data about NCCL ranks from NCCL threads to RAS. +struct rasRankInit { + union ncclSocketAddress addr; + pid_t pid; + int cudaDev; + int nvmlDev; +}; + +ncclResult_t ncclRasCommInit(struct ncclComm* comm, struct rasRankInit* myRank); +ncclResult_t ncclRasCommFini(const struct ncclComm* comm); +ncclResult_t ncclRasAddRanks(struct rasRankInit* ranks, int nranks); + +#endif // !NCCL_RAS_H_ diff --git a/src/include/register.h b/src/include/register.h index 7c60535d9a..740a645f43 100644 --- a/src/include/register.h +++ b/src/include/register.h @@ -6,6 +6,9 @@ #include #include +int64_t ncclParamLocalRegister(); +int64_t ncclParamGraphRegister(); + enum { NET_REG_COMPLETE = 0x01, NVLS_REG_COMPLETE = 0x02, @@ -20,16 +23,21 @@ struct ncclPeerRegIpcAddr { uintptr_t* hostPeerRmtAddrs; }; +struct ncclRegNetHandles { + void* handle; + struct ncclProxyConnector* proxyConn; + struct ncclRegNetHandles* next; +}; + struct ncclReg { // common attributes size_t pages; - int refs; + int localRefs; + int graphRefs; uintptr_t addr; uint32_t state; // net reg - int nDevs; - int devs[MAXCHANNELS]; - void** handles; + struct ncclRegNetHandles* netHandleHead; // nvls reg uintptr_t baseAddr; size_t baseSize; @@ -50,11 +58,12 @@ struct ncclRegCache { struct ncclReg **slots; int capacity, population; uintptr_t pageSize; - void* sComms[MAXCHANNELS]; - void* rComms[MAXCHANNELS]; }; ncclResult_t ncclRegCleanup(struct ncclComm* comm); ncclResult_t ncclRegFind(struct ncclComm* comm, const void* data, size_t size, struct ncclReg** reg); +ncclResult_t ncclCommGraphRegister(const ncclComm_t comm, void* buff, size_t size, void** handle); +ncclResult_t ncclCommGraphDeregister(const ncclComm_t comm, struct ncclReg *handle); +ncclResult_t ncclRegLocalIsValid(struct ncclReg *reg, bool *isValid); #endif diff --git a/src/include/shmutils.h b/src/include/shmutils.h index 43e8afb79a..097b4c6577 100644 --- a/src/include/shmutils.h +++ b/src/include/shmutils.h @@ -10,7 +10,7 @@ #include "nccl.h" typedef void* ncclShmHandle_t; -ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle); +ncclResult_t ncclShmOpen(char* shmPath, size_t shmPathSize, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle); ncclResult_t ncclShmClose(ncclShmHandle_t handle); ncclResult_t ncclShmUnlink(ncclShmHandle_t handle); diff --git a/src/include/socket.h b/src/include/socket.h index 60a4138752..f0a3237cee 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -17,9 +17,6 @@ #define MAX_IFS 16 #define MAX_IF_NAME_SIZE 16 -#define SLEEP_INT 1000 // connection retry sleep interval in usec -#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) -#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) #define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) #define NCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL @@ -39,9 +36,10 @@ enum ncclSocketState { ncclSocketStateConnectPolling = 5, ncclSocketStateConnected = 6, ncclSocketStateReady = 7, - ncclSocketStateClosed = 8, - ncclSocketStateError = 9, - ncclSocketStateNum = 10 + ncclSocketStateTerminating = 8, + ncclSocketStateClosed = 9, + ncclSocketStateError = 10, + ncclSocketStateNum = 11 }; enum ncclSocketType { @@ -49,14 +47,14 @@ enum ncclSocketType { ncclSocketTypeBootstrap = 1, ncclSocketTypeProxy = 2, ncclSocketTypeNetSocket = 3, - ncclSocketTypeNetIb = 4 + ncclSocketTypeNetIb = 4, + ncclSocketTypeRasNetwork = 5 }; struct ncclSocket { int fd; int acceptFd; - int timedOutRetries; - int refusedRetries; + int errorRetries; union ncclSocketAddress addr; volatile uint32_t* abortFlag; int asyncFlag; @@ -64,15 +62,18 @@ struct ncclSocket { int salen; uint64_t magic; enum ncclSocketType type; + int customRetry; + int finalizeCounter; // Used to keep track of initial handshake for async sockets. + char finalizeBuffer[sizeof(uint64_t)]; // Used to keep track of initial handshake for async sockets. }; -const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm = 1); +const char *ncclSocketToString(const union ncclSocketAddress *addr, char *buf, const int numericHostForm = 1); ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs); int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs); // Initialize a socket -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0); +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0, int customRetry = 0); // Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call ncclResult_t ncclSocketListen(struct ncclSocket* sock); ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr); @@ -88,11 +89,12 @@ ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 -ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); +ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* closed = NULL); ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketSendRecv(struct ncclSocket* sendSock, void* sendPtr, int sendSize, struct ncclSocket* recvSock, void* recvPtr, int recvSize); ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking); +ncclResult_t ncclSocketShutdown(struct ncclSocket* sock, int how); ncclResult_t ncclSocketClose(struct ncclSocket* sock); #endif diff --git a/src/include/transport.h b/src/include/transport.h index 7cd139290e..90b591d35c 100644 --- a/src/include/transport.h +++ b/src/include/transport.h @@ -29,7 +29,6 @@ extern struct ncclTransport netTransport; extern struct ncclTransport collNetTransport; extern struct ncclTransport* ncclTransports[]; - // Forward declarations struct ncclRing; struct ncclConnector; @@ -100,16 +99,16 @@ struct ncclTransport { }; ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex); -ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL, bool* needsProxy=NULL); +ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, bool* needsProxy=NULL); ncclResult_t ncclTransportCheckP2pType(struct ncclComm* comm, bool* intraNodeP2pSupport, bool* directMode); ncclResult_t ncclNvlsInit(struct ncclComm* comm); ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent); ncclResult_t ncclNvlsBufferSetup(struct ncclComm* comm); ncclResult_t ncclNvlsTreeConnect(struct ncclComm* comm); -ncclResult_t ncclNvlsGraphRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts); -ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv); -ncclResult_t ncclNvlsDeregBuffer(CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size); +ncclResult_t ncclNvlsGraphRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts); +ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv); +ncclResult_t ncclNvlsDeregBuffer(struct ncclComm* comm, CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size); ncclResult_t ncclNvlsFree(struct ncclComm* comm); enum { collNetRecv=0, collNetSend=1 }; @@ -128,4 +127,13 @@ ncclResult_t ncclCollNetSetup(ncclComm_t comm, ncclComm_t parent, struct ncclTop ncclResult_t ncclCollNetChainBufferSetup(ncclComm_t comm); ncclResult_t ncclCollNetDirectBufferSetup(ncclComm_t comm); +ncclResult_t ncclNetDeregBuffer(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* handle); +ncclResult_t ncclNetLocalRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, struct ncclConnector** peerConns, int nPeers, int* outRegBufFlag, void** outHandle); +ncclResult_t ncclNetGraphRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, struct ncclConnector** peerConns, int nPeers, int* outRegBufFlag, void** outHandle, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts); + +ncclResult_t ncclRegisterP2pIpcBuffer(struct ncclComm* comm, void* userbuff, size_t size, int peerRank, int* regFlag, void** regAddr, struct ncclIntruQueue* cleanupQueue); +ncclResult_t ncclRegisterP2pNetBuffer(struct ncclComm* comm, void* userbuff, size_t size, struct ncclConnector* conn, int* regFlag, void** handle, struct ncclIntruQueue* cleanupQueue); +ncclResult_t ncclRegisterCollBuffers(struct ncclComm* comm, struct ncclTaskColl* info, void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], struct ncclIntruQueue* cleanupQueue, bool* regNeedConnect); +ncclResult_t ncclRegisterCollNvlsBuffers(struct ncclComm* comm, struct ncclTaskColl* info, void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], struct ncclIntruQueue* cleanupQueue, bool* regNeedConnect); + #endif diff --git a/src/include/utils.h b/src/include/utils.h index 5a1b749a76..383f678c87 100644 --- a/src/include/utils.h +++ b/src/include/utils.h @@ -49,8 +49,7 @@ inline uint64_t clockNano() { return uint64_t(ts.tv_sec)*1000*1000*1000 + ts.tv_nsec; } -/* get any bytes of random data from /dev/urandom, return 0 if it succeeds; else - * return -1 */ +/* get any bytes of random data from /dev/urandom, return ncclSuccess (0) if it succeeds. */ inline ncclResult_t getRandomData(void* buffer, size_t bytes) { ncclResult_t ret = ncclSuccess; if (bytes > 0) { diff --git a/src/init.cc b/src/init.cc index 7a1aa312b6..e149d0ce9f 100644 --- a/src/init.cc +++ b/src/init.cc @@ -24,6 +24,7 @@ #include "npkit/npkit.h" #endif #include "tuner.h" +#include "ras.h" #include #include #include @@ -377,6 +378,8 @@ static ncclResult_t commFree(ncclComm_t comm) { if (comm == NULL) return ncclSuccess; + NCCLCHECK(ncclRasCommFini(comm)); + /* in commReclaim, we have guaranteed only last rank which calls ncclCommDestroy() will * free all intra-process communicators; therefore, we only need to focus on local * resource cleanup in commFree(). */ @@ -388,7 +391,7 @@ static ncclResult_t commFree(ncclComm_t comm) { } } - CUDACHECK(cudaMemPoolDestroy(comm->memPool)); + if (comm->memPool) CUDACHECK(cudaMemPoolDestroy(comm->memPool)); delete[] comm->userRedOps; @@ -690,11 +693,6 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in ncclIntruQueueConstruct(&comm->eventCallbackQueue); - // setup intraComm0 and intraRanks 0 to default values to ensure proper cleanup of the communicator - comm->intraComm0 = comm; - comm->intraRank = 0; - comm->intraRanks = 1; - return ncclSuccess; } @@ -704,6 +702,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { struct ncclDevCommAndChannels tmpCommAndChans; struct ncclDevCommAndChannels *devCommAndChans = NULL; struct ncclNvmlCCStatus ccStatus; + bool ccEnable = false; NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->deviceStream), ret, fail); NCCLCHECKGOTO(ncclCudaCallocAsync(&devCommAndChans, 1, comm->sharedRes->deviceStream.cudaStream), ret, fail); @@ -717,7 +716,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.comm.node = comm->node; tmpCommAndChans.comm.nNodes = comm->nNodes; tmpCommAndChans.comm.abortFlag = comm->abortFlagDev; - tmpCommAndChans.comm.isNvlink = ncclTopoPathAllNVLink(comm->topo); + tmpCommAndChans.comm.isAllNvlink = comm->isAllNvlink; tmpCommAndChans.comm.p2pnChannelsPerPeer = comm->p2pnChannelsPerPeer; for (int p=0; p < NCCL_NUM_PROTOCOLS; p++) { tmpCommAndChans.comm.buffSizes[p] = comm->buffSizes[p]; @@ -729,11 +728,9 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { #if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__) memset(&ccStatus, 0, sizeof(ccStatus)); - if (ncclNvmlGetCCStatus(&ccStatus) == ncclSuccess && ccStatus.CCEnabled) { + ccEnable = (ncclSuccess == ncclNvmlGetCCStatus(&ccStatus)) && (ccStatus.CCEnabled || ccStatus.multiGpuProtectedPCIE); + if (ccEnable) { comm->workFifoBytes = 0; - if (ccStatus.multiGpuCCEnabled == false && comm->rank == 0) { - WARN("CC On, Multi-GPU CC Off (No inter-GPU communication protection)"); - } } else { comm->workFifoBytes = ncclParamWorkFifoBytes(); if (0 != (comm->workFifoBytes & (comm->workFifoBytes-1))) { @@ -752,7 +749,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { #endif if (comm->rank == 0) { - INFO(NCCL_INIT, "CC %s, Multi-GPU CC %s, workFifoBytes %d", ccStatus.CCEnabled ? "On" : "Off", ccStatus.multiGpuCCEnabled ? "On" : "Off", comm->workFifoBytes); + INFO(NCCL_INIT, "CC %s, workFifoBytes %d", ccEnable ? "On" : "Off", comm->workFifoBytes); } if (ncclGdrCopy != NULL && ncclParamGdrCopyFifoEnable() == 1) { @@ -961,9 +958,6 @@ NCCL_PARAM(P2pPciChunkSize, "P2P_PCI_CHUNKSIZE", (1 << 17)); /* 128 kB */ NCCL_PARAM(P2pNvlChunkSize, "P2P_NVL_CHUNKSIZE", (1 << 19)); /* 512 kB */ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { - int cpuArch, cpuVendor, cpuModel; - NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel)); - int64_t envs[NCCL_NUM_PROTOCOLS] = { ncclParamLlBuffSize(), ncclParamLl128BuffSize(), ncclParamBuffSize() }; int defaults[NCCL_NUM_PROTOCOLS] = { DEFAULT_LL_BUFFSIZE, DEFAULT_LL128_BUFFSIZE, DEFAULT_BUFFSIZE }; @@ -972,7 +966,7 @@ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { } if (comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize(); - else if (ncclTopoPathAllNVLink(comm->topo)) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); + else if (comm->isAllNvlink) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); else comm->p2pChunkSize = ncclParamP2pPciChunkSize(); // Make sure P2P chunksize is not larger than coll chunksize. @@ -1218,6 +1212,14 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p } while(0); timers[TIMER_INIT_TOPO] = clockNano(); + + // Dump XML if requested by user + const char* dumpXmlFile; + dumpXmlFile = ncclGetEnv("NCCL_TOPO_DUMP_FILE"); + if (dumpXmlFile) { + NCCLCHECKGOTO(ncclTopoGetSystem(comm, NULL, dumpXmlFile), ret, fail); + } + // Topo detection / System graph creation NCCLCHECKGOTO(ncclTopoGetSystem(comm, &comm->topo), ret, fail); // save nRanks to ncclTopoSystem as indicator of multi-node @@ -1556,9 +1558,9 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p INFO(NCCL_INIT, "Communicator has %d nodes which is less than CollNet node threshold %d, disabling CollNet", comm->nNodes, collNetNodeThreshold); comm->collNetSupport = 0; } - // As long as there is more than 1 rank on any node, we need to disable collnet reg - comm->collNetRegSupport = (comm->maxLocalRanks == 1); } + comm->isAllNvlink = ncclTopoPathAllNVLink(comm->topo); + comm->isOneRPN = (comm->maxLocalRanks == 1); NCCLCHECKGOTO(ncclCalloc(&rings, nranks*MAXCHANNELS), ret, fail); @@ -1845,6 +1847,7 @@ struct ncclCommInitRankAsyncJob { // for ncclCommSplit struct ncclComm* parent; int color, key; + int splitCount; // name of the function calling char funcName[NCCL_COMMINIT_FUNCNAME_LEN]; }; @@ -1958,13 +1961,14 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { timers[TIMER_INIT_ALLOC] = clockNano(); NCCLCHECKGOTO(commAlloc(comm, job->parent, job->nranks, job->myrank), res, fail); timers[TIMER_INIT_ALLOC] = clockNano() - timers[TIMER_INIT_ALLOC]; - // obtain a unique hash for the comm, re-using part of the parent's hash, commHash is a 64bit struct (=16 hex), add the color + // obtain a unique hash for the comm, re-using part of the parent's hash, commHash is a 64bit struct (=16 hex), + // add unique split counter and the color ncclUniqueId tmpId; memset(&tmpId,0,sizeof(ncclUniqueId));// must set 0 here to avoid undefined bits - snprintf((char*)&tmpId, NCCL_UNIQUE_ID_BYTES, "%016lx-%d", job->parent->commHash, job->color); + snprintf((char*)&tmpId, NCCL_UNIQUE_ID_BYTES, "%016lx-%d-%d", job->parent->commHash, job->splitCount, job->color); comm->commHash = getHash(tmpId.internal, NCCL_UNIQUE_ID_BYTES); - INFO(NCCL_INIT, "%s comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx parent %p color %d key %d- Init START", job->funcName, - comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, job->parent, job->color, job->key); + INFO(NCCL_INIT, "%s comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx parent %p splitCount %d color %d key %d- Init START", job->funcName, + comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, job->parent, job->splitCount, job->color, job->key); timers[TIMER_INIT_BOOTSTRAP] = clockNano(); NCCLCHECKGOTO(bootstrapSplit(comm->commHash, comm, job->parent, job->color, job->key, parentRanks), res, fail); timers[TIMER_INIT_BOOTSTRAP] = clockNano() - timers[TIMER_INIT_BOOTSTRAP]; @@ -2059,8 +2063,8 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { /* unlink child abort flag. */ __atomic_store_n(&job->parent->childAbortFlag, NULL, __ATOMIC_RELEASE); TRACE_CALL("ncclCommSplit(%p, %d, %d, %p, %d, %d)", job->parent, job->color, job->key, comm, comm->rank, comm->nRanks); - INFO(NCCL_INIT, "%s comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx parent %p color %d key %d - Init COMPLETE", job->funcName, - comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, job->parent, job->color, job->key); + INFO(NCCL_INIT, "%s comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx parent %p splitCount %d color %d key %d - Init COMPLETE", job->funcName, + comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, job->parent, job->splitCount, job->color, job->key); } else { // the name for the replay tool is ncclCommInitRank for all the variations TRACE_CALL("ncclCommInitRank(%p, %d, 0x%llx, %d, %d)", comm, comm->nRanks, commIdHash, comm->rank, comm->cudaDev); @@ -2301,8 +2305,8 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, int nId comm->startMagic = comm->endMagic = NCCL_MAGIC; // Used to detect comm corruption. *comm->abortFlagRefCount = 1; NCCLCHECKGOTO(parseCommConfig(comm, config), res, fail); - /* start with ncclInternalError and will be changed to ncclSuccess if init succeeds. */ - comm->initState = ncclInternalError; + /* start with ncclInProgress and will be changed to ncclSuccess if init succeeds. */ + comm->initState = ncclInProgress; *newcomm = comm; NCCLCHECKGOTO(ncclCalloc(&job, 1), res, fail); @@ -2337,6 +2341,7 @@ exit: NCCLCHECK(Recorder::instance().record(rrCommInitDev, nranks, myrank, commId, comm, cudaDev)); return ncclGroupErrCheck(res); fail: + if (job) ncclCommInitJobFree(job); if (comm) { free(comm->abortFlag); if (comm->abortFlagDev) (void)ncclCudaHostFree((void*)comm->abortFlagDev); @@ -2436,7 +2441,7 @@ ncclResult_t ncclCommInitAll_impl(ncclComm_t* comms, int ndev, const int* devlis NCCLCHECKGOTO(ncclGroupEnd(), ret, fail); exit: - cudaSetDevice(oldDev); + (void)cudaSetDevice(oldDev); free(gpuFlags); return ret; fail: @@ -2517,14 +2522,9 @@ fail: static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { struct ncclCommFinalizeAsyncJob* job = (struct ncclCommFinalizeAsyncJob*) job_; ncclComm_t comm = job->comm; - int savedDevice; - int commDevice = comm->cudaDev; ncclResult_t ret = ncclSuccess; - CUDACHECKGOTO(cudaGetDevice(&savedDevice), ret, fail); - if (savedDevice != commDevice) { - CUDACHECKGOTO(cudaSetDevice(commDevice), ret, fail); - } + CUDACHECKGOTO(cudaSetDevice(comm->cudaDev), ret, fail); TRACE(NCCL_INIT, "Destroying comm %p rank %d abortFlag %d asyncResult %d", comm, comm->rank, *comm->abortFlag, comm->asyncResult); @@ -2554,10 +2554,6 @@ static ncclResult_t commDestroySync(struct ncclAsyncJob* job_) { WARN("ncclProxyStop: comm %p (rank = %d) destroys proxy resource error %d", comm, comm->rank, ret); } - if (savedDevice != commDevice) { - CUDACHECKGOTO(cudaSetDevice(savedDevice), ret, fail); - } - exit: return ret; fail: @@ -2565,30 +2561,18 @@ fail: } static ncclResult_t commCleanup(ncclComm_t comm) { - int savedDevice; - int commDevice = comm->cudaDev; bool mscclEnabledForTopo = comm->topo->mscclEnabled; - CUDACHECK(cudaGetDevice(&savedDevice)); - if (savedDevice != commDevice) { - CUDACHECK(cudaSetDevice(commDevice)); - } - + CUDACHECK(cudaSetDevice(comm->cudaDev)); if (comm->tuner != NULL) { NCCLCHECK(comm->tuner->destroy(comm->tunerContext)); NCCLCHECK(ncclTunerPluginUnload(comm)); } - if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) { NCCLCHECK(mscclTeardown(comm->rank)); } - NCCLCHECK(commFree(comm)); - if (savedDevice != commDevice) { - CUDACHECK(cudaSetDevice(savedDevice)); - } - #if defined(ENABLE_NPKIT) // Dump NPKit events and shutdown const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR"); @@ -2731,6 +2715,7 @@ ncclResult_t ncclCommDestroy_impl(ncclComm_t comm) { NVTX3_FUNC_WITH_PARAMS(CommDestroy, CommInitRankSchema, payload) TRACE(NCCL_INIT, "comm %p rank %d nRanks %d cudaDev %d busId %lx", comm, rank, nranks, cudaDev, comm->busId); + NCCLCHECK(ncclGroupStartInternal()); // Try and prevent a double free of the comm struct (user error) if (comm->rank == -1 || comm->nRanks == -1 || comm->cudaDev == -1 || comm->busId == -1) { WARN("comm %p has already been destroyed", comm); @@ -2745,6 +2730,8 @@ ncclResult_t ncclCommDestroy_impl(ncclComm_t comm) { NCCLCHECKGOTO(ncclAsyncLaunch((struct ncclAsyncJob*)job, commReclaim, NULL, free, comm), res, fail); exit: + ncclGroupErrCheck(res); + NCCLCHECK(ncclGroupEndInternal()); return res; fail: goto exit; @@ -2757,7 +2744,7 @@ ncclResult_t ncclCommAbort_impl(ncclComm_t comm) { NVTX3_FUNC_RANGE_IN(nccl_domain); return ncclSuccess; } - + NCCLCHECK(ncclGroupStartInternal()); // Ask anything that might still be running on the device to quit if (comm->childAbortFlag != nullptr) { __atomic_store_n(comm->childAbortFlag, 1, __ATOMIC_RELEASE); @@ -2785,6 +2772,8 @@ ncclResult_t ncclCommAbort_impl(ncclComm_t comm) { NCCLCHECKGOTO(ncclAsyncLaunch((struct ncclAsyncJob*)job, commReclaim, NULL, free, comm), res, fail); exit: + ncclGroupErrCheck(res); + NCCLCHECK(ncclGroupEndInternal()); return ncclSuccess; fail: goto exit; @@ -2851,14 +2840,15 @@ ncclResult_t ncclCommSplit_impl(ncclComm_t comm, int color, int key, ncclComm_t NCCLCHECKGOTO(parseCommConfig(childComm, config), res, fail); } - /* start with ncclInternalError and will be changed to ncclSuccess if init succeeds. */ - childComm->initState = ncclInternalError; + /* start with ncclInProgress and will be changed to ncclSuccess if init succeeds. */ + childComm->initState = ncclInProgress; } NCCLCHECKGOTO(ncclCalloc(&job, 1), res, fail); job->comm = childComm; job->newcomm = newcomm; job->parent = comm; + job->splitCount = ++comm->splitCount; job->color = color; job->key = key; job->cudaDev = comm->cudaDev; @@ -2870,13 +2860,13 @@ exit: // TODO: further integrate overloaded record header // !recording at sink Recorder::instance().record(rrCommSplit, color, key, (ncclUniqueId*)comm, config, *newcomm); - cudaSetDevice(oldDev); + (void)cudaSetDevice(oldDev); (void)ncclGroupErrCheck(res); NCCLCHECK(ncclGroupEndInternal()); return res; fail: if (childComm) { - if (comm && !comm->config.splitShare) { + if (!comm->config.splitShare) { free(childComm->abortFlag); if (childComm->abortFlagDev) ncclCudaHostFree(childComm->abortFlagDev); free(childComm->abortFlagRefCount); @@ -2990,14 +2980,12 @@ ncclResult_t ncclMemAlloc_impl(void **ptr, size_t size) { CUDACHECK(cudaGetDevice(&cudaDev)); CUCHECK(cuDeviceGet(¤tDev, cudaDev)); - if (CUPFN(cuMulticastCreate) != NULL) - CUCHECK(cuDeviceGetAttribute(&mcSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, currentDev)); - if (mcSupport) { + if (ncclCuMemEnable()) { int requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // Query device to see if FABRIC handle support is available flag = 0; - (void) CUPFN(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev));; + (void) CUPFN(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev)); if (flag) requestedHandleTypes |= CU_MEM_HANDLE_TYPE_FABRIC; memprop.type = CU_MEM_ALLOCATION_TYPE_PINNED; memprop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; @@ -3008,18 +2996,24 @@ ncclResult_t ncclMemAlloc_impl(void **ptr, size_t size) { CUCHECK(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, currentDev)); if (flag) memprop.allocFlags.gpuDirectRDMACapable = 1; CUCHECK(cuMemGetAllocationGranularity(&memGran, &memprop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - - /* mc property */ CUDACHECK(cudaGetDeviceCount(&dcnt)); - mcprop.size = size; - /* device cnt is a dummy value right now, it might affect mc granularity in the future. */ - mcprop.numDevices = dcnt; - mcprop.handleTypes = requestedHandleTypes; - mcprop.flags = 0; - CUCHECK(cuMulticastGetGranularity(&mcGran, &mcprop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); - /* only size needs to be aligned to mcGran */ - ALIGN_SIZE(size, mcGran); + if (CUPFN(cuMulticastCreate) != NULL) CUCHECK(cuDeviceGetAttribute(&mcSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, currentDev)); + if (mcSupport) { + /* mc property */ + mcprop.size = size; + /* device cnt is a dummy value right now, it might affect mc granularity in the future. */ + mcprop.numDevices = dcnt; + mcprop.handleTypes = requestedHandleTypes; + mcprop.flags = 0; + CUCHECK(cuMulticastGetGranularity(&mcGran, &mcprop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + + /* only size needs to be aligned to mcGran */ + ALIGN_SIZE(size, mcGran); + } else { + ALIGN_SIZE(size, memGran); + } + if (requestedHandleTypes & CU_MEM_HANDLE_TYPE_FABRIC) { /* First try cuMemCreate() with FABRIC handle support and then remove if it fails */ CUresult err = CUPFN(cuMemCreate(&handle, size, &memprop, 0)); @@ -3046,6 +3040,7 @@ ncclResult_t ncclMemAlloc_impl(void **ptr, size_t size) { accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; CUCHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &accessDesc, 1)); } + if (0 == p2p && i != cudaDev) INFO(NCCL_ALLOC, "P2P not supported between GPU%d and GPU%d", cudaDev, i); } goto exit; } @@ -3074,18 +3069,13 @@ ncclResult_t ncclMemFree_impl(void *ptr) { CUDACHECK(cudaGetDevice(&saveDevice)); #if CUDART_VERSION >= 12010 CUdevice ptrDev = 0; - int mcSupport = 0; if (ptr == NULL) goto fallback; - if (ncclCudaLibraryInit() != ncclSuccess) goto fallback; CUCHECKGOTO(cuPointerGetAttribute((void*)&ptrDev, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)ptr), ret, fail); - if (CUPFN(cuMulticastCreate) != NULL) - CUCHECKGOTO(cuDeviceGetAttribute(&mcSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, ptrDev), ret, fail); - CUDACHECKGOTO(cudaSetDevice((int)ptrDev), ret, fail); - if (mcSupport) { + if (ncclCuMemEnable()) { NCCLCHECKGOTO(ncclCuMemFree(ptr), ret, fail); goto exit; } diff --git a/src/misc/cudawrap.cc b/src/misc/cudawrap.cc index 03e3bde992..e5fec1e46c 100644 --- a/src/misc/cudawrap.cc +++ b/src/misc/cudawrap.cc @@ -11,7 +11,7 @@ // This env var (NCCL_CUMEM_ENABLE) toggles cuMem API usage NCCL_PARAM(CuMemEnable, "CUMEM_ENABLE", -2); -NCCL_PARAM(CuMemHostEnable, "CUMEM_HOST_ENABLE", 0); +NCCL_PARAM(CuMemHostEnable, "CUMEM_HOST_ENABLE", -1); // Handle type used for cuMemCreate() CUmemAllocationHandleType ncclCuMemHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; @@ -35,9 +35,6 @@ int ncclIsCuMemSupported() { // Query device to see if CUMEM VMM support is available CUCHECKGOTO(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, currentDev), ret, error); if (!flag) return 0; - // Query device to see if CUMEM RDMA support is available - CUCHECKGOTO(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, currentDev), ret, error); - if (!flag) return 0; error: return (ret == ncclSuccess); #endif @@ -49,11 +46,31 @@ int ncclCuMemEnable() { return param >= 0 ? param : (param == -2 && ncclCuMemSupported); } +static int ncclCumemHostEnable = -1; int ncclCuMemHostEnable() { + if (ncclCumemHostEnable != -1) + return ncclCumemHostEnable; #if CUDART_VERSION < 12020 - return 0; + ncclCumemHostEnable = 0; + return ncclCumemHostEnable; #else - return ncclParamCuMemHostEnable(); + ncclResult_t ret = ncclSuccess; + int cudaDriverVersion; + int paramValue = -1; + CUDACHECKGOTO(cudaDriverGetVersion(&cudaDriverVersion), ret, error); + if (cudaDriverVersion < 12020) { + ncclCumemHostEnable = 0; + } + else { + paramValue = ncclParamCuMemHostEnable(); + if (paramValue != -1) + ncclCumemHostEnable = paramValue; + else + ncclCumemHostEnable = (cudaDriverVersion >= 12060) ? 1 : 0; + } + return ncclCumemHostEnable; +error: + return (ret == ncclSuccess); #endif } @@ -218,10 +235,9 @@ static void initOnceFunc() { // Determine whether we support the cuMem APIs or not ncclCuMemSupported = ncclIsCuMemSupported(); -#if 12020 <= CUDART_VERSION && CUDART_VERSION <= 12030 - /* To use cuMem* for host memory allocation, we need to create context on each - * visible device. This is workaround needed in CUDA 12.3 which is fixed in 12.4. */ - if (ncclCuMemSupported && ncclCuMemHostEnable()) { + /* To use cuMem* for host memory allocation, we need to create context on each visible device. + * This is a workaround needed in CUDA 12.2 and CUDA 12.3 which is fixed in 12.4. */ + if (ncclCuMemSupported && ncclCuMemHostEnable() && 12020 <= driverVersion && driverVersion <= 12030) { int deviceCnt, saveDevice; cudaGetDevice(&saveDevice); cudaGetDeviceCount(&deviceCnt); @@ -231,7 +247,6 @@ static void initOnceFunc() { } cudaSetDevice(saveDevice); } -#endif initResult = ret; return; error: diff --git a/src/misc/ibvwrap.cc b/src/misc/ibvwrap.cc index eb4e52b606..698465ca48 100644 --- a/src/misc/ibvwrap.cc +++ b/src/misc/ibvwrap.cc @@ -8,6 +8,7 @@ #include #include +#include "ibvcore.h" #include "ibvsymbols.h" static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; @@ -53,7 +54,7 @@ ncclResult_t wrap_ibv_symbols(void) { } \ int ret = container.call; \ if (ret == ENOTSUP || ret == EOPNOTSUPP) { \ - INFO(NCCL_NET, "Call to " name " failed with error %s errno %d", strerror(ret), ret); \ + INFO(NCCL_NET, "Call to " name " not supported"); \ *supported = 0; \ return ncclSuccess; \ } else if (ret != success_retval) { \ @@ -87,6 +88,14 @@ ncclResult_t wrap_ibv_symbols(void) { container.call; \ return ncclSuccess; +NCCL_PARAM(IbMQpRetryAll, "IB_MQP_RETRY_ALL", 0); +NCCL_PARAM(IbMQpRetryCnt, "IB_MQP_RETRY_CNT", 34); +NCCL_PARAM(IbMQpRetryTimeout, "IB_MQP_RETRY_SLEEP_MSEC", 100); // in milliseconds + +#define IBV_ERR_EQ(e, code) (e == code || e == (-code)) +#define IBV_MQP_RETRY_ERRNO(e) (IBV_ERR_EQ(e, ETIMEDOUT)) +#define IBV_MQP_RETRY_ERRNO_ALL(e) (ncclParamIbMQpRetryAll() ? (e != 0) : IBV_MQP_RETRY_ERRNO(e)) + ncclResult_t wrap_ibv_fork_init() { IBV_INT_CHECK(ibvSymbols, ibv_internal_fork_init, ibv_internal_fork_init(), -1, "ibv_fork_init"); } @@ -202,8 +211,87 @@ ncclResult_t wrap_ibv_create_qp(struct ibv_qp **ret, struct ibv_pd *pd, struct i IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_create_qp, ibv_internal_create_qp(pd, qp_init_attr), *ret, NULL, "ibv_create_qp"); } -ncclResult_t wrap_ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ - IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_modify_qp, ibv_internal_modify_qp(qp, attr, attr_mask), 0, "ibv_modify_qp"); +static void ibvQpStateName(enum ibv_qp_state state, char* msg, const size_t len) { + switch (state) { + case (IBV_QPS_RESET): snprintf(msg, len, "RESET"); break; + case (IBV_QPS_INIT): snprintf(msg, len, "INIT"); break; + case (IBV_QPS_RTR): snprintf(msg, len, "RTR"); break; + case (IBV_QPS_RTS): snprintf(msg, len, "RTS"); break; + case (IBV_QPS_SQD): snprintf(msg, len, "SQD"); break; + case (IBV_QPS_SQE): snprintf(msg, len, "SQE"); break; + case (IBV_QPS_ERR): snprintf(msg, len, "ERR"); break; + case (IBV_QPS_UNKNOWN): snprintf(msg, len, "UNKNOWN"); break; + default: snprintf(msg, len, "NOT RECOGNIZED (%d)", state); break; + } +} + +#define QP_ATTR(attr, userAttr, userFlag, mask) ((userFlag & mask) ? (userAttr) : (attr)) + +static void ibvModifyQpLog(struct ibv_qp* qp, enum ibv_qp_state qpState, struct ibv_qp_attr* userAttr, int userFlag, char* msg, size_t msgLen) { + ncclResult_t res; + int portNum = -1, gidIndex = -1; + char localGidName[INET6_ADDRSTRLEN], remoteGidName[INET6_ADDRSTRLEN]; + const char *localGidRes = NULL, *remoteGidRes = NULL; + + char nextState[32], currState[32]; + ibvQpStateName(qp->state, currState, sizeof(currState)); + ibvQpStateName(qpState, nextState, sizeof(nextState)); + char devName[IBV_SYSFS_NAME_MAX] = ""; + snprintf(devName, sizeof(devName), "%s", (qp->pd->context) ? wrap_ibv_get_device_name(qp->pd->context->device) : "N/A"); + + struct ibv_qp_attr attr; + struct ibv_qp_init_attr init_attr; + int attr_mask = IBV_QP_PORT | IBV_QP_AV; + res = wrap_ibv_query_qp(qp, &attr, attr_mask, &init_attr); + struct ibv_qp_attr *qpAttr = (res == ncclSuccess) ? &attr : NULL; + + // port info, portAttr can be NULL if not given by the user and query_qp failed + struct ibv_qp_attr *portAttr = QP_ATTR(qpAttr, userAttr, userFlag, IBV_QP_PORT); + portNum = portAttr ? portAttr->port_num : -1; + + // address info, avAttr can be NULL if not given by the user and query_qp failed + struct ibv_qp_attr *avAttr = QP_ATTR(qpAttr, userAttr, userFlag, IBV_QP_AV); + if (avAttr && avAttr->ah_attr.is_global) { + union ibv_gid *remoteGid = &avAttr->ah_attr.grh.dgid; + remoteGidRes = ibvGetGidStr(remoteGid, remoteGidName, sizeof(remoteGidName)); + // we need pd->context to retrieve local GID, skip if not there + if (!qp->pd->context) goto print; + gidIndex = avAttr->ah_attr.grh.sgid_index; + union ibv_gid localGid; + NCCLCHECKGOTO(wrap_ibv_query_gid(qp->pd->context, portNum, gidIndex, &localGid), res, print); + localGidRes = ibvGetGidStr(&localGid, localGidName, sizeof(localGidName)); + } + +print: + snprintf(msg, msgLen, "on dev %s:%d, curr state %s, next state %s, local GID index %d, local GID %s, remote GID %s", + devName, portNum, currState, nextState, gidIndex, localGidRes ? localGidName : "N/A", remoteGidRes ? remoteGidName : "N/A"); + return; +} + +ncclResult_t wrap_ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask) { + char qpMsg[1024]; + int ret = 0, attempts = 0; + int maxCnt = (int)ncclParamIbMQpRetryCnt() + 1; // number of attempts = number of retry + 1 + int timeOut = (int)ncclParamIbMQpRetryTimeout(); + CHECK_NOT_NULL(ibvSymbols, ibv_internal_modify_qp); + do { + if (attempts > 0) { + unsigned int sleepTime = timeOut * attempts; + ibvModifyQpLog(qp, attr->qp_state, attr, attr_mask, qpMsg, sizeof(qpMsg)); + INFO(NCCL_NET, "Call to ibv_modify_qp failed with %d %s, %s, retrying %d/%d after %u msec of sleep", ret, strerror(ret), qpMsg, attempts, maxCnt, sleepTime); + // sleep before retrying + struct timespec tv = {.tv_sec = sleepTime / 1000, .tv_nsec = (sleepTime % 1000) * ((long)1e6)}; + nanosleep(&tv, NULL); + } + ret = ibvSymbols.ibv_internal_modify_qp(qp, attr, attr_mask); + attempts++; + } while (IBV_MQP_RETRY_ERRNO_ALL(ret) && attempts < maxCnt); + if (ret != 0) { + ibvModifyQpLog(qp, attr->qp_state, attr, attr_mask, qpMsg, sizeof(qpMsg)); + WARN("Call to ibv_modify_qp failed with %d %s, %s", ret, strerror(ret), qpMsg); + return ncclSystemError; + } + return ncclSuccess; } ncclResult_t wrap_ibv_query_ece(struct ibv_qp *qp, struct ibv_ece *ece, int* supported) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ diff --git a/src/misc/ipcsocket.cc b/src/misc/ipcsocket.cc index 2d17f47e69..23746b3c5c 100644 --- a/src/misc/ipcsocket.cc +++ b/src/misc/ipcsocket.cc @@ -189,14 +189,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d fd %d to UDS socket %s", hdr, hdrLen, sendFd, temp); - msg.msg_control = control_un.control; - msg.msg_controllen = sizeof(control_un.control); + if (sendFd != -1) { + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof(control_un.control); - cmptr = CMSG_FIRSTHDR(&msg); - cmptr->cmsg_len = CMSG_LEN(sizeof(int)); - cmptr->cmsg_level = SOL_SOCKET; - cmptr->cmsg_type = SCM_RIGHTS; - memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd)); + cmptr = CMSG_FIRSTHDR(&msg); + cmptr->cmsg_len = CMSG_LEN(sizeof(int)); + cmptr->cmsg_level = SOL_SOCKET; + cmptr->cmsg_type = SCM_RIGHTS; + memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd)); + } msg.msg_name = (void *)&cliaddr; msg.msg_namelen = sizeof(struct sockaddr_un); diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index 7299820455..bd5fd07aa7 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -112,9 +112,8 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) // Connect MSCCL connections mscclSetIsCallerFlag(); - int highestTransportType = TRANSPORT_P2P; bool needsProxy = false; - NCCLCHECK(ncclTransportP2pSetup(comm, NULL, 0, &highestTransportType, &needsProxy)); + NCCLCHECK(ncclTransportP2pSetup(comm, NULL, 0, &needsProxy)); status.needsProxy |= needsProxy; mscclClearIsCallerFlag(); @@ -273,11 +272,11 @@ static ncclResult_t hostToDevRedOp( break; #endif #if defined(RCCL_FLOAT8) - case ncclFp8E4M3: + case ncclFloat8e4m3: opFull->op = ncclDevPreMulSum; fp8_e4m3 = (rccl_float8)(float(1.0/comm->nRanks)); break; - case ncclFp8E5M2: + case ncclFloat8e5m2: opFull->op = ncclDevPreMulSum; fp8_e5m2 = (rccl_bfloat8)(float(1.0/comm->nRanks)); break; diff --git a/src/misc/nvmlwrap.cc b/src/misc/nvmlwrap.cc index f441af80b1..66ba2d4c85 100644 --- a/src/misc/nvmlwrap.cc +++ b/src/misc/nvmlwrap.cc @@ -311,19 +311,19 @@ ncclResult_t ncclNvmlGetCCStatus(struct ncclNvmlCCStatus *status) { status->CCEnabled = false; if (ccInfo.settingV12040.multiGpuMode == NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE) - status->multiGpuCCEnabled = true; + status->multiGpuProtectedPCIE = true; else - status->multiGpuCCEnabled = false; + status->multiGpuProtectedPCIE = false; } else if (pfn_nvmlSystemGetConfComputeState != NULL) { NVMLTRY(nvmlSystemGetConfComputeState, &ccInfo.settingV12020); if (ccInfo.settingV12020.ccFeature == NVML_CC_SYSTEM_FEATURE_ENABLED) status->CCEnabled = true; else status->CCEnabled = false; - status->multiGpuCCEnabled = false; + status->multiGpuProtectedPCIE = false; } else { status->CCEnabled = false; - status->multiGpuCCEnabled = false; + status->multiGpuProtectedPCIE = false; } return ncclSuccess; } diff --git a/src/misc/profiler.cc b/src/misc/profiler.cc index 2fa5cbaa57..aba1234d84 100644 --- a/src/misc/profiler.cc +++ b/src/misc/profiler.cc @@ -17,9 +17,110 @@ static pthread_mutex_t profilerLock = PTHREAD_MUTEX_INITIALIZER; static int profilerPluginRefCount; static void* profilerPluginLib; static ncclProfiler_t* ncclProfiler; +static ncclProfiler_v2_t ncclProfiler_v1_as_v2; +static ncclProfiler_v1_t* ncclProfiler_v1; + +static uint8_t ncclStringToFunc(const char* func) { + if (0 == strcmp(func, "AllGather")) return ncclFuncAllGather; + if (0 == strcmp(func, "AllReduce")) return ncclFuncAllReduce; + if (0 == strcmp(func, "Broadcast")) return ncclFuncBroadcast; + if (0 == strcmp(func, "Recv")) return ncclFuncRecv; + if (0 == strcmp(func, "Reduce")) return ncclFuncReduce; + if (0 == strcmp(func, "ReduceScatter")) return ncclFuncReduceScatter; + if (0 == strcmp(func, "SendRecv")) return ncclFuncSendRecv; + return ncclFuncSend; +} + +static uint8_t ncclStringToAlgo(const char* algo) { + if (0 == strcmp(algo, "TREE")) return NCCL_ALGO_TREE; + if (0 == strcmp(algo, "RING")) return NCCL_ALGO_RING; + if (0 == strcmp(algo, "COLLNET_DIRECT")) return NCCL_ALGO_COLLNET_DIRECT; + if (0 == strcmp(algo, "COLLNET_CHAIN")) return NCCL_ALGO_COLLNET_CHAIN; + if (0 == strcmp(algo, "NVLS")) return NCCL_ALGO_NVLS; + if (0 == strcmp(algo, "NVLS_TREE")) return NCCL_ALGO_NVLS_TREE; + return NCCL_ALGO_PAT; +} + +static uint8_t ncclStringToProto(const char* proto) { + if (0 == strcmp(proto, "LL")) return NCCL_PROTO_LL; + if (0 == strcmp(proto, "LL128")) return NCCL_PROTO_LL128; + return NCCL_PROTO_SIMPLE; +} + +static uint8_t ncclStringToDatatype(const char* dt) { + if (0 == strcmp(dt, "ncclInt8")) return ncclInt8; + if (0 == strcmp(dt, "ncclInt32")) return ncclInt32; + if (0 == strcmp(dt, "ncclUint32")) return ncclUint32; + if (0 == strcmp(dt, "ncclInt64")) return ncclInt64; + if (0 == strcmp(dt, "ncclUint64")) return ncclUint64; + if (0 == strcmp(dt, "ncclFloat16")) return ncclFloat16; + if (0 == strcmp(dt, "ncclFloat32")) return ncclFloat32; +#if defined(__CUDA_BF16_TYPES_EXIST__) + if (0 == strcmp(dt, "ncclBfloat16")) return ncclBfloat16; +#endif + return ncclFloat64; +} + +static ncclResult_t ncclProfiler_v1_as_v2_startEvent(void* context, void** eHandle, ncclProfilerEventDescr_v2_t* eDescr) { + ncclProfilerEventDescr_v1_t eDescr_v1 = { 0 }; + eDescr_v1.type = eDescr->type; + eDescr_v1.parentObj = eDescr->parentObj; + eDescr_v1.rank = eDescr->rank; + switch(eDescr->type) { + case ncclProfileGroup: break; + case ncclProfileColl: { + eDescr_v1.coll.name = eDescr->coll.name; + eDescr_v1.coll.commHash = eDescr->coll.commHash; + eDescr_v1.coll.seqNumber = eDescr->coll.seqNumber; + eDescr_v1.coll.func = ncclStringToFunc(eDescr->coll.func); + eDescr_v1.coll.sendBuff = eDescr->coll.sendBuff; + eDescr_v1.coll.recvBuff = eDescr->coll.recvBuff; + eDescr_v1.coll.count = eDescr->coll.count; + eDescr_v1.coll.root = eDescr->coll.root; + eDescr_v1.coll.datatype = ncclStringToDatatype(eDescr->coll.datatype); + eDescr_v1.coll.op = 0; // removed in v2 + eDescr_v1.coll.trafficBytes = eDescr->coll.trafficBytes; + eDescr_v1.coll.nMaxChannels = eDescr->coll.nMaxChannels; + eDescr_v1.coll.nWarps = eDescr->coll.nWarps; + eDescr_v1.coll.algo = ncclStringToAlgo(eDescr->coll.algo); + eDescr_v1.coll.proto = ncclStringToProto(eDescr->coll.proto); + } break; + case ncclProfileP2p: { + eDescr_v1.p2p.name = eDescr->p2p.name; + eDescr_v1.p2p.commHash = eDescr->p2p.commHash; + eDescr_v1.p2p.func = ncclStringToFunc(eDescr->p2p.func); + eDescr_v1.p2p.buff = eDescr->p2p.buff; + eDescr_v1.p2p.count = eDescr->p2p.count; + eDescr_v1.p2p.datatype = ncclStringToDatatype(eDescr->p2p.datatype); + eDescr_v1.p2p.peer = eDescr->p2p.peer; + } break; + case ncclProfileProxyOp: { + eDescr_v1.proxyOp.pid = eDescr->proxyOp.pid; + eDescr_v1.proxyOp.channelId = eDescr->proxyOp.channelId; + eDescr_v1.proxyOp.peer = eDescr->proxyOp.peer; + eDescr_v1.proxyOp.nSteps = eDescr->proxyOp.nSteps; + eDescr_v1.proxyOp.chunkSize = eDescr->proxyOp.chunkSize; + eDescr_v1.proxyOp.isSend = eDescr->proxyOp.isSend; + } break; + case ncclProfileProxyStep: { + eDescr_v1.proxyStep.step = eDescr->proxyStep.step; + } break; + case ncclProfileProxyCtrl: break; + default:; + } + return ncclProfiler_v1->startEvent(context, eHandle, &eDescr_v1); +} + +static ncclResult_t ncclProfiler_v1_as_v2_init(void** context, int* eActivationMask) { + ncclProfiler_v1->init(context, eActivationMask); + ncclProfiler_v1_as_v2.startEvent = ncclProfiler_v1_as_v2_startEvent; + ncclProfiler_v1_as_v2.stopEvent = ncclProfiler_v1->stopEvent; + ncclProfiler_v1_as_v2.recordEventState = ncclProfiler_v1->recordEventState; + ncclProfiler_v1_as_v2.finalize = ncclProfiler_v1->finalize; + return ncclSuccess; +} #define MAX_STR_LEN 256 -#define NCCL_PROFILER_PLUGIN_SYMBOL "ncclProfiler_v1" static void* tryOpenLib(char* name, int *err, char* errStr) { if (nullptr == name || strlen(name) == 0) { @@ -34,7 +135,7 @@ static void* tryOpenLib(char* name, int *err, char* errStr) { if (nullptr == handle) { strncpy(errStr, dlerror(), MAX_STR_LEN); errStr[MAX_STR_LEN] = 0; - if (strstr(errStr, name) && strstr(errStr, "No such file or directory")) { + if (name && strstr(errStr, name) && strstr(errStr, "No such file or directory")) { *err = ENOENT; } } @@ -117,10 +218,21 @@ static ncclResult_t ncclProfilerPluginLoad(void) { goto fail; } - ncclProfiler = (ncclProfiler_t*)dlsym(profilerPluginLib, NCCL_PROFILER_PLUGIN_SYMBOL); + ncclProfiler = (ncclProfiler_v2_t*)dlsym(profilerPluginLib, "ncclProfiler_v2"); if (ncclProfiler == nullptr) { - INFO(NCCL_INIT|NCCL_ENV, "PROFILER/Plugin: failed to find " NCCL_PROFILER_PLUGIN_SYMBOL "."); - goto fail; + INFO(NCCL_INIT|NCCL_ENV, "PROFILER/Plugin: failed to find ncclProfiler_v2."); + ncclProfiler_v1 = (ncclProfiler_v1_t*)dlsym(profilerPluginLib, "ncclProfiler_v1"); + if (ncclProfiler_v1 == nullptr) { + INFO(NCCL_INIT|NCCL_ENV, "PROFILER/Plugin: failed to find ncclProfiler_v1."); + goto fail; + } else { + ncclProfiler = &ncclProfiler_v1_as_v2; + ncclProfiler_v1_as_v2.name = ncclProfiler_v1->name; + ncclProfiler_v1_as_v2.init = ncclProfiler_v1_as_v2_init; + INFO(NCCL_INIT|NCCL_ENV, "PROFILER/Plugin: loaded ncclProfiler_v1."); + } + } else { + INFO(NCCL_INIT|NCCL_ENV, "PROFILER/Plugin: loaded ncclProfiler_v2."); } ++profilerPluginRefCount; @@ -248,7 +360,7 @@ ncclResult_t ncclProfilerStartGroupEvent(struct ncclKernelPlan* plan) { eActivationMaskGroup = __atomic_load_n(&eActivationMask, __ATOMIC_RELAXED); if (__builtin_expect(ncclProfiler != NULL, 0)) { if (eActivationMaskGroup & (ncclProfileColl | ncclProfileP2p | ncclProfileProxyOp | ncclProfileProxyStep)) { - ncclProfilerEventDescr_v1_t eDescr = { 0 }; + ncclProfilerEventDescr_t eDescr = { 0 }; eDescr.type = ncclProfileGroup; ncclProfiler->startEvent(plan->comm->profilerContext, &plan->groupEventHandle, &eDescr); } @@ -280,20 +392,17 @@ ncclResult_t ncclProfilerStartTaskEvents(struct ncclKernelPlan* plan) { eDescr.coll.name = plan->comm->commName; eDescr.coll.commHash = plan->comm->commHash; eDescr.coll.seqNumber = plan->comm->seqNumber[ct->func]++; - eDescr.coll.func = ct->func; + eDescr.coll.func = ncclFuncToString(ct->func); eDescr.coll.sendBuff = ct->sendbuff; eDescr.coll.recvBuff = ct->recvbuff; eDescr.coll.count = ct->count; eDescr.coll.root = ct->root; - eDescr.coll.datatype = ct->datatype; - eDescr.coll.op = ct->opHost; + eDescr.coll.datatype = ncclDatatypeToString(ct->datatype); eDescr.coll.trafficBytes = ct->trafficBytes; eDescr.coll.nMaxChannels = ct->nMaxChannels; eDescr.coll.nWarps = ct->nWarps; - eDescr.coll.algo = ct->algorithm; - eDescr.coll.proto = ct->protocol; - eDescr.coll.isCollnet = ct->isCollnet; - eDescr.coll.isNvls = ct->isNvls; + eDescr.coll.algo = ncclAlgoToString(ct->algorithm); + eDescr.coll.proto = ncclProtoToString(ct->protocol); ncclProfiler->startEvent(plan->comm->profilerContext, &ct->eventHandle, &eDescr); // update collective task with group event activation mask @@ -308,10 +417,10 @@ ncclResult_t ncclProfilerStartTaskEvents(struct ncclKernelPlan* plan) { eDescr.rank = plan->comm->rank; eDescr.p2p.name = plan->comm->commName; eDescr.p2p.commHash = plan->comm->commHash; - eDescr.p2p.func = pt->func; + eDescr.p2p.func = ncclFuncToString(pt->func); eDescr.p2p.buff = pt->buff; eDescr.p2p.count = pt->count; - eDescr.p2p.datatype = pt->datatype; + eDescr.p2p.datatype = ncclDatatypeToString(pt->datatype); eDescr.p2p.peer = pt->root; ncclProfiler->startEvent(plan->comm->profilerContext, &pt->eventHandle, &eDescr); @@ -346,6 +455,11 @@ ncclResult_t ncclProfilerStopTaskEvents(struct ncclKernelPlan* plan) { return ncclSuccess; } +// Bellow we set the proxy descriptor step number to DIVUP(step, args->sliceSteps). +// The reason is that for some ncclOp (e.g. AllReduce) one network transfer is +// made of sliceSteps steps rather than one step. In the profiler we are still +// interested in whole network transfers though, so we account for this when +// computing the actual network step number. ncclResult_t ncclProfilerStartSendProxyOpEvent(int s, struct ncclProxyArgs* args) { TIME_START_EVENT(proxyOpStart); struct ncclProxySubArgs* sub = &args->subs[s]; @@ -355,13 +469,13 @@ ncclResult_t ncclProfilerStartSendProxyOpEvent(int s, struct ncclProxyArgs* args eDescr.type = ncclProfileProxyOp; eDescr.parentObj = sub->taskEventHandle; eDescr.rank = sub->rank; - eDescr.proxyOp.pid = args->pid; + eDescr.proxyOp.pid = sub->pid; eDescr.proxyOp.channelId = sub->channelId; eDescr.proxyOp.peer = sub->peer; - eDescr.proxyOp.nSteps = sub->nsteps; - eDescr.proxyOp.chunkSize = args->chunkSize; + eDescr.proxyOp.nSteps = DIVUP(sub->nsteps, args->sliceSteps); + eDescr.proxyOp.chunkSize = args->chunkSize * args->sliceSteps; eDescr.proxyOp.isSend = 1; - ncclProfiler->startEvent(args->profilerContext, &sub->opEventHandle, &eDescr); + ncclProfiler->startEvent(sub->profilerContext, &sub->opEventHandle, &eDescr); } } TIME_STOP_EVENT(proxyOpStart); @@ -377,13 +491,13 @@ ncclResult_t ncclProfilerStartRecvProxyOpEvent(int s, struct ncclProxyArgs* args eDescr.type = ncclProfileProxyOp; eDescr.parentObj = sub->taskEventHandle; eDescr.rank = sub->rank; - eDescr.proxyOp.pid = args->pid; + eDescr.proxyOp.pid = sub->pid; eDescr.proxyOp.channelId = sub->channelId; eDescr.proxyOp.peer = sub->peer; - eDescr.proxyOp.nSteps = sub->nsteps; - eDescr.proxyOp.chunkSize = args->chunkSize; + eDescr.proxyOp.nSteps = DIVUP(sub->nsteps, args->sliceSteps); + eDescr.proxyOp.chunkSize = args->chunkSize * args->sliceSteps; eDescr.proxyOp.isSend = 0; - ncclProfiler->startEvent(args->profilerContext, &sub->opEventHandle, &eDescr); + ncclProfiler->startEvent(sub->profilerContext, &sub->opEventHandle, &eDescr); } } TIME_STOP_EVENT(proxyOpStart); @@ -401,53 +515,50 @@ ncclResult_t ncclProfilerStopProxyOpEvent(int s, struct ncclProxyArgs* args) { return ncclSuccess; } -ncclResult_t ncclProfilerStartSendProxyStepEvents(int s, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi) { +ncclResult_t ncclProfilerStartSendProxyStepEvent(int s, struct ncclProxyArgs* args, int stepId) { TIME_START_EVENT(proxyStepStart); struct ncclProxySubArgs* sub = &args->subs[s]; if (__builtin_expect(ncclProfiler != NULL, 0)) { if (sub->opEventHandle && (sub->eActivationMask & ncclProfileProxyStep)) { - for (uint64_t step = stepLo; step < stepHi; step++) { - ncclProfilerEventDescr_t eDescr = { 0 }; - eDescr.type = ncclProfileProxyStep; - eDescr.parentObj = sub->opEventHandle; - eDescr.rank = sub->rank; - eDescr.proxyStep.step = step; - ncclProfiler->startEvent(args->profilerContext, &sub->stepEventHandles[step%NCCL_STEPS], &eDescr); - } + int step_ = DIVUP(stepId, args->sliceSteps); + ncclProfilerEventDescr_t eDescr = { 0 }; + eDescr.type = ncclProfileProxyStep; + eDescr.parentObj = sub->opEventHandle; + eDescr.rank = sub->rank; + eDescr.proxyStep.step = step_; + ncclProfiler->startEvent(sub->profilerContext, &sub->stepEventHandles[step_%NCCL_STEPS], &eDescr); } } TIME_STOP_EVENT(proxyStepStart); return ncclSuccess; } -ncclResult_t ncclProfilerStartRecvProxyStepEvents(int s, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi) { +ncclResult_t ncclProfilerStartRecvProxyStepEvent(int s, struct ncclProxyArgs* args, int stepId) { TIME_START_EVENT(proxyStepStart); struct ncclProxySubArgs* sub = &args->subs[s]; if (__builtin_expect(ncclProfiler != NULL, 0)) { if (sub->opEventHandle && (sub->eActivationMask & ncclProfileProxyStep)) { - for (uint64_t step = stepLo; step < stepHi; step++) { - ncclProfilerEventDescr_t eDescr = { 0 }; - eDescr.type = ncclProfileProxyStep; - eDescr.parentObj = sub->opEventHandle; - eDescr.rank = sub->rank; - eDescr.proxyStep.step = step; - ncclProfiler->startEvent(args->profilerContext, &sub->stepEventHandles[step%NCCL_STEPS], &eDescr); - } + int step_ = DIVUP(stepId, args->sliceSteps); + ncclProfilerEventDescr_t eDescr = { 0 }; + eDescr.type = ncclProfileProxyStep; + eDescr.parentObj = sub->opEventHandle; + eDescr.rank = sub->rank; + eDescr.proxyStep.step = step_; + ncclProfiler->startEvent(sub->profilerContext, &sub->stepEventHandles[step_%NCCL_STEPS], &eDescr); } } TIME_STOP_EVENT(proxyStepStart); return ncclSuccess; } -ncclResult_t ncclProfilerStopProxyStepEvents(int s, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi) { +ncclResult_t ncclProfilerStopProxyStepEvent(int s, struct ncclProxyArgs* args, int stepId) { TIME_START_EVENT(proxyStepStop); struct ncclProxySubArgs* sub = &args->subs[s]; if (__builtin_expect(ncclProfiler != NULL, 0)) { - for (uint64_t step = stepLo; step < stepHi; step++) { - if (sub->stepEventHandles[step%NCCL_STEPS]) { - ncclProfiler->stopEvent(sub->stepEventHandles[step%NCCL_STEPS]); - sub->stepEventHandles[step%NCCL_STEPS] = NULL; - } + int step_ = DIVUP(stepId, args->sliceSteps); + if (sub->stepEventHandles[step_%NCCL_STEPS]) { + ncclProfiler->stopEvent(sub->stepEventHandles[step_%NCCL_STEPS]); + sub->stepEventHandles[step_%NCCL_STEPS] = NULL; } } TIME_STOP_EVENT(proxyStepStop); @@ -485,8 +596,8 @@ ncclResult_t ncclProfilerRecordProxyOpEventState(int s, struct ncclProxyArgs* ar TIME_START_EVENT(proxyOpRecord); struct ncclProxySubArgs* sub = &args->subs[s]; if (__builtin_expect(ncclProfiler != NULL, 0) && sub->opEventHandle) { - ncclProfilerEventStateArgs_t a = { 0 }; - a.proxyOp.steps = steps; + ncclProfilerEventStateArgs_t a = { }; + a.proxyOp.steps = DIVUP(steps, args->sliceSteps); a.proxyOp.transSize = transSize; ncclProfiler->recordEventState(sub->opEventHandle, eState, &a); } @@ -494,14 +605,13 @@ ncclResult_t ncclProfilerRecordProxyOpEventState(int s, struct ncclProxyArgs* ar return ncclSuccess; } -ncclResult_t ncclProfilerRecordProxyStepEventStates(int s, struct ncclProxyArgs* args, uint64_t stepLo, uint64_t stepHi, ncclProfilerEventState_t eState) { +ncclResult_t ncclProfilerRecordProxyStepEventState(int s, struct ncclProxyArgs* args, int stepId, ncclProfilerEventState_t eState) { TIME_START_EVENT(proxyStepRecord); struct ncclProxySubArgs* sub = &args->subs[s]; if (__builtin_expect(ncclProfiler != NULL, 0) && sub->opEventHandle) { - for (uint64_t step = stepLo; step < stepHi; step++) { - if (sub->stepEventHandles[step%NCCL_STEPS]) { - ncclProfiler->recordEventState(sub->stepEventHandles[step%NCCL_STEPS], eState, 0); - } + int step_ = DIVUP(stepId, args->sliceSteps); + if (sub->stepEventHandles[step_%NCCL_STEPS]) { + ncclProfiler->recordEventState(sub->stepEventHandles[step_%NCCL_STEPS], eState, 0); } } TIME_STOP_EVENT(proxyStepRecord); @@ -511,7 +621,7 @@ ncclResult_t ncclProfilerRecordProxyStepEventStates(int s, struct ncclProxyArgs* ncclResult_t ncclProfilerRecordProxyCtrlEventState(void* eHandle, int appended, ncclProfilerEventState_t eState) { TIME_START_EVENT(proxyCtrlRecord); if (__builtin_expect(ncclProfiler != NULL, 0) && eHandle && __atomic_load_n(&eActivationMask, __ATOMIC_RELAXED) & ncclProfileProxyCtrl) { - ncclProfilerEventStateArgs_t args = { 0 }; + ncclProfilerEventStateArgs_t args = { }; args.proxyCtrl.appendedProxyOps = appended; ncclProfiler->recordEventState(eHandle, eState, &args); } diff --git a/src/misc/shmutils.cc b/src/misc/shmutils.cc index daf3b338db..eb9cd10156 100644 --- a/src/misc/shmutils.cc +++ b/src/misc/shmutils.cc @@ -45,7 +45,7 @@ static void shmHandleInit(int fd, char* shmPath, size_t shmSize, size_t realShmS return; } -ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle) { +ncclResult_t ncclShmOpen(char* shmPath, size_t shmPathSize, size_t shmSize, void** shmPtr, void** devShmPtr, int refcount, ncclShmHandle_t* handle) { int fd = -1; char* hptr = NULL; void* dptr = NULL; @@ -62,7 +62,7 @@ ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** de * refcount references; when the peer attaches, it should pass -1 to reduce one reference count. When it * goes down to 0, unlink should be called in order to delete shared memory file. */ if (shmPath[0] == '\0') { - sprintf(shmPath, "/dev/shm/nccl-XXXXXX"); + snprintf(shmPath, shmPathSize, "/dev/shm/nccl-XXXXXX"); retry_mkstemp: fd = mkstemp(shmPath); if (fd < 0) { @@ -70,7 +70,7 @@ ncclResult_t ncclShmOpen(char* shmPath, size_t shmSize, void** shmPtr, void** de INFO(NCCL_ALL, "mkstemp: Failed to create %s, error: %s (%d) - retrying", shmPath, strerror(errno), errno); goto retry_mkstemp; } - WARN("Error: failed to create shared memory file %p, error %s (%d)", shmPath, strerror(errno), errno); + WARN("Error: failed to create shared memory file %s, error %s (%d)", shmPath, strerror(errno), errno); ret = ncclSystemError; goto fail; } diff --git a/src/misc/socket.cc b/src/misc/socket.cc index dd65598d6f..41246e0913 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -14,6 +14,18 @@ #include #include #include "param.h" +#include + +NCCL_PARAM(RetryCnt, "SOCKET_RETRY_CNT", 34); +NCCL_PARAM(RetryTimeOut, "SOCKET_RETRY_SLEEP_MSEC", 100); +static void msleep(unsigned int time_msec) { + const long c_1e6 = 1e6; + struct timespec tv = (struct timespec){ + .tv_sec = time_msec / 1000, + .tv_nsec = (time_msec % 1000) * c_1e6, + }; + nanosleep(&tv, NULL); +} RCCL_PARAM(SocketReuseAddr, "SOCKET_REUSEADDR", 0); RCCL_PARAM(SocketLinger, "SOCKET_LINGER", -1); @@ -31,8 +43,13 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr return ncclSuccess; } if (bytes == -1) { + if ((op == NCCL_SOCKET_SEND && errno == EPIPE) || (op == NCCL_SOCKET_RECV && errno == ECONNRESET)) { + *closed = 1; + return ncclSuccess; + } if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { - WARN("socketProgressOpt: Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + WARN("socketProgressOpt: Call to %s %s failed : %s", (op == NCCL_SOCKET_RECV ? "recv from" : "send to"), + ncclSocketToString(&sock->addr, line), strerror(errno)); return ncclRemoteError; } else { bytes = 0; @@ -43,17 +60,22 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr INFO(NCCL_NET, "socketProgressOpt: abort called"); return ncclInternalError; } - } while (bytes > 0 && (*offset) < size); + } while (sock->asyncFlag == 0 && bytes > 0 && (*offset) < size); return ncclSuccess; } -static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { +static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* pclosed = NULL) { int closed; NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed)); if (closed) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); - return ncclRemoteError; + if (pclosed) { + *pclosed = closed; + return ncclSuccess; + } else { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); + return ncclRemoteError; + } } return ncclSuccess; } @@ -68,9 +90,9 @@ static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int s * * Output: "IPv4/IPv6 address" */ -const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm /*= 1*/) { +const char *ncclSocketToString(const union ncclSocketAddress *addr, char *buf, const int numericHostForm /*= 1*/) { if (buf == NULL || addr == NULL) return NULL; - struct sockaddr *saddr = &addr->sa; + const struct sockaddr *saddr = &addr->sa; if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; } char host[NI_MAXHOST], service[NI_MAXSERV]; /* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned. @@ -375,10 +397,9 @@ ncclResult_t ncclSocketListen(struct ncclSocket* sock) { if (socketToPort(&sock->addr)) { // Port is forced by env. Make sure we get the port. int opt = 1; -#if defined(SO_REUSEPORT) - SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); -#else SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); +#if defined(SO_REUSEPORT) + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); #endif } @@ -417,6 +438,15 @@ static ncclResult_t socketTryAccept(struct ncclSocket* sock) { sock->fd = accept(sock->acceptFd, (struct sockaddr*)&sock->addr, &socklen); if (sock->fd != -1) { sock->state = ncclSocketStateAccepted; + } else if (errno == ENETDOWN || errno == EPROTO || errno == ENOPROTOOPT || errno == EHOSTDOWN || + errno == ENONET || errno == EHOSTUNREACH || errno == EOPNOTSUPP || errno == ENETUNREACH) { + /* per accept's man page, for linux sockets, the following errors might be already pending errors + * and should be considered as EAGAIN. To avoid infinite loop in case of errors, we use the retry count*/ + if (++sock->errorRetries == ncclParamRetryCnt()) { + WARN("socketTryAccept: exceeded error retry count (%d), %s", sock->errorRetries, strerror(errno)); + return ncclSystemError; + } + INFO(NCCL_ALL, "Call to accept returned %s, retrying", strerror(errno)); } else if (errno != EAGAIN && errno != EWOULDBLOCK) { WARN("socketTryAccept: Accept failed: %s", strerror(errno)); return ncclSystemError; @@ -424,72 +454,118 @@ static ncclResult_t socketTryAccept(struct ncclSocket* sock) { return ncclSuccess; } +static ncclResult_t socketSetFlags(struct ncclSocket* sock) { + const int one = 1; + /* Set socket as non-blocking if async or if we need to be able to abort */ + if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { + int flags; + SYSCHECK(flags = fcntl(sock->fd, F_GETFL), "fcntl"); + SYSCHECK(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); + } + SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); + return ncclSuccess; +} + static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { uint64_t magic; enum ncclSocketType type; - int received = 0; - const int one = 1; - SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); + int received; + // once accepted, linux sockets do NOT inherit file status flags such as O_NONBLOCK (BSD ones do) + NCCLCHECK(socketSetFlags(sock)); - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); - if (received == 0) return ncclSuccess; - NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); - if (magic != sock->magic) { - WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); - close(sock->fd); - sock->fd = -1; - // Ignore spurious connection and accept again - sock->state = ncclSocketStateAccepting; - return ncclSuccess; - } else { - received = 0; - NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); - if (type != sock->type) { - WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); - sock->state = ncclSocketStateError; + if (sock->asyncFlag == 0 || sock->finalizeCounter < sizeof(magic)) { + if (sock->asyncFlag == 0) { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + } else { + received = sock->finalizeCounter; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(magic), &received)); + sock->finalizeCounter = received; + if (received < sizeof(magic)) return ncclSuccess; + memcpy(&magic, sock->finalizeBuffer, sizeof(magic)); + } + if (magic != sock->magic) { + WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); close(sock->fd); sock->fd = -1; - return ncclInternalError; - } else { - sock->state = ncclSocketStateReady; + // Ignore spurious connection and accept again + sock->state = ncclSocketStateAccepting; + return ncclSuccess; } } + if (sock->asyncFlag == 0) { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); + } else { + received = sock->finalizeCounter - sizeof(magic); + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(type), &received)); + sock->finalizeCounter = received + sizeof(magic); + if (received < sizeof(type)) return ncclSuccess; + memcpy(&type, sock->finalizeBuffer, sizeof(type)); + } + if (type != sock->type) { + WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); + sock->state = ncclSocketStateError; + close(sock->fd); + sock->fd = -1; + return ncclInternalError; + } else { + sock->state = ncclSocketStateReady; + } return ncclSuccess; } -static ncclResult_t socketStartConnect(struct ncclSocket* sock) { - /* blocking/non-blocking connect() is determined by asyncFlag. */ - int ret = connect(sock->fd, &sock->addr.sa, sock->salen); - - if (ret == 0) { +static ncclResult_t socketResetFd(struct ncclSocket* sock) { + ncclResult_t ret = ncclSuccess; + int fd = -1; + SYSCHECKGOTO(fd = socket(sock->addr.sa.sa_family, SOCK_STREAM, 0), "socket", ret, cleanup); + // if sock->fd is valid, close it and reuse its number + if (sock->fd != -1) { + SYSCHECKGOTO(dup2(fd, sock->fd), "dup2", ret, cleanup); + SYSCHECKGOTO(close(fd), "close", ret, cleanup); + } else { + sock->fd = fd; + } + NCCLCHECKGOTO(socketSetFlags(sock), ret, exit); +exit: + return ret; +cleanup: + // cleanup fd, leave sock->fd untouched + if (fd != -1) { + (void)close(fd); + } + goto exit; +} +static ncclResult_t socketConnectCheck(struct ncclSocket* sock, int errCode, const char funcName[]) { + if (errCode == 0) { sock->state = ncclSocketStateConnected; - return ncclSuccess; - } else if (errno == EINPROGRESS) { + } else if (errCode == EINPROGRESS) { sock->state = ncclSocketStateConnectPolling; - return ncclSuccess; - } else if (errno == ECONNREFUSED) { - if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries); - return ncclRemoteError; + } else if (errCode == ETIMEDOUT || errCode == EHOSTUNREACH || errCode == ECONNREFUSED) { + if (sock->customRetry == 0) { + if (sock->errorRetries++ == ncclParamRetryCnt()) { + sock->state = ncclSocketStateError; + WARN("%s: connect returned %s, exceeded error retry count (%d)", funcName, strerror(errCode), sock->errorRetries); + return ncclRemoteError; + } + unsigned int sleepTime = sock->errorRetries * ncclParamRetryTimeOut(); + INFO(NCCL_ALL, "%s: connect returned %s, retrying (%d/%ld) after sleep for %u msec", funcName, strerror(errCode), sock->errorRetries, ncclParamRetryCnt(), sleepTime); + msleep(sleepTime); } - usleep(SLEEP_INT); - if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - return ncclSuccess; - } else if (errno == ETIMEDOUT) { - if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries); - return ncclRemoteError; - } - usleep(SLEEP_INT); - return ncclSuccess; + NCCLCHECK(socketResetFd(sock)); /* in case of failure in connect, socket state is unspecified */ + sock->state = ncclSocketStateConnecting; } else { char line[SOCKET_NAME_MAXLEN+1]; sock->state = ncclSocketStateError; - WARN("socketStartConnect: Connect to %s failed : %s", ncclSocketToString(&sock->addr, line), strerror(errno)); + WARN("%s: Connect to %s failed : %s", funcName, ncclSocketToString(&sock->addr, line), strerror(errCode)); return ncclSystemError; } + return ncclSuccess; +} +static ncclResult_t socketStartConnect(struct ncclSocket* sock) { + /* blocking/non-blocking connect() is determined by asyncFlag. */ + int ret = connect(sock->fd, &sock->addr.sa, sock->salen); + return socketConnectCheck(sock, (ret == -1) ? errno : 0, __func__); } static ncclResult_t socketPollConnect(struct ncclSocket* sock) { @@ -514,33 +590,7 @@ static ncclResult_t socketPollConnect(struct ncclSocket* sock) { /* check socket status */ SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); - - if (ret == 0) { - sock->state = ncclSocketStateConnected; - } else if (ret == ECONNREFUSED) { - if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries); - return ncclRemoteError; - } - if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - usleep(SLEEP_INT); - sock->state = ncclSocketStateConnecting; - } else if (ret == ETIMEDOUT) { - if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries); - return ncclRemoteError; - } - usleep(SLEEP_INT); - sock->state = ncclSocketStateConnecting; - } else if (ret != EINPROGRESS) { - sock->state = ncclSocketStateError; - char line[SOCKET_NAME_MAXLEN+1]; - WARN("socketPollConnect: Connect to %s returned %d(%s) errno %d(%s)", ncclSocketToString(&sock->addr, line), ret, strerror(ret), errno, strerror(errno)); - return ncclSystemError; - } - return ncclSuccess; + return socketConnectCheck(sock, ret, __func__); } ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { @@ -553,12 +603,24 @@ ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { } static ncclResult_t socketFinalizeConnect(struct ncclSocket* sock) { - int sent = 0; - NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); - if (sent == 0) return ncclSuccess; - NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); - sent = 0; - NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + int sent; + if (sock->asyncFlag == 0) { + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + } else { + if (sock->finalizeCounter < sizeof(sock->magic)) { + sent = sock->finalizeCounter; + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + sock->finalizeCounter = sent; + if (sent < sizeof(sock->magic)) return ncclSuccess; + } + sent = sock->finalizeCounter - sizeof(sock->magic); + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + sock->finalizeCounter = sent + sizeof(sock->magic); + if (sent < sizeof(sock->type)) return ncclSuccess; + } sock->state = ncclSocketStateReady; return ncclSuccess; } @@ -601,7 +663,6 @@ ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) { ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { char line[SOCKET_NAME_MAXLEN+1]; - const int one = 1; if (sock == NULL) { WARN("ncclSocketConnect: pass NULL socket"); @@ -619,9 +680,8 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { } TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line)); - SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - sock->state = ncclSocketStateConnecting; + sock->finalizeCounter = 0; do { NCCLCHECK(socketProgressState(sock)); } while (sock->asyncFlag == 0 && @@ -667,6 +727,7 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen memcpy(sock, listenSock, sizeof(struct ncclSocket)); sock->acceptFd = listenSock->fd; sock->state = ncclSocketStateAccepting; + sock->finalizeCounter = 0; } do { @@ -697,12 +758,11 @@ exit: return ret; } -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) { +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry) { ncclResult_t ret = ncclSuccess; if (sock == NULL) goto exit; - sock->timedOutRetries = 0; - sock->refusedRetries = 0; + sock->errorRetries = 0; sock->abortFlag = abortFlag; sock->asyncFlag = asyncFlag; sock->state = ncclSocketStateInitialized; @@ -710,6 +770,7 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad sock->type = type; sock->fd = -1; sock->acceptFd = -1; + sock->customRetry = customRetry; if (addr) { /* IPv4/IPv6 support */ @@ -721,17 +782,11 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); ret = ncclInternalError; - goto fail; + goto exit; } sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - - /* Connect to a hostname / port */ - sock->fd = socket(family, SOCK_STREAM, 0); - if (sock->fd == -1) { - WARN("ncclSocketInit: Socket creation failed : %s", strerror(errno)); - ret = ncclSystemError; - goto fail; - } + // in case of error, we close the fd before returning as it's unclear if the caller has to use ncclSocketClose for cleanup + NCCLCHECKGOTO(socketResetFd(sock), ret, fail); // [RCCL] Runtime socket options if (rcclParamSocketReuseAddr()) { @@ -746,14 +801,6 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad } else { memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); } - - /* Set socket as non-blocking if async or if we need to be able to abort */ - if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { - int flags; - SYSCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), "fcntl", ret, fail); - SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), "fcntl", ret, fail); - } - exit: return ret; fail: @@ -764,12 +811,12 @@ fail: goto exit; } -ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { +ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* closed) { if (sock == NULL) { WARN("ncclSocketProgress: pass NULL socket"); return ncclInvalidArgument; } - NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + NCCLCHECK(socketProgress(op, sock, ptr, size, offset, closed)); return ncclSuccess; } @@ -802,7 +849,7 @@ ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { WARN("ncclSocketRecv: pass NULL socket"); return ncclInvalidArgument; } - if (sock->state != ncclSocketStateReady) { + if (sock->state != ncclSocketStateReady && sock->state != ncclSocketStateTerminating) { WARN("ncclSocketRecv: socket state (%d) is not ready", sock->state); return ncclInternalError; } @@ -816,7 +863,8 @@ ncclResult_t ncclSocketSendRecv(struct ncclSocket* sendSock, void* sendPtr, int WARN("ncclSocketSendRecv: invalid socket %p/%p", sendSock, recvSock); return ncclInternalError; } - if (sendSock->state != ncclSocketStateReady || recvSock->state != ncclSocketStateReady) { + if (sendSock->state != ncclSocketStateReady || + (recvSock->state != ncclSocketStateReady && recvSock->state != ncclSocketStateTerminating)) { WARN("ncclSocketSendRecv: socket state (%d/%d) is not ready", sendSock->state, recvSock->state); return ncclInternalError; } @@ -860,9 +908,20 @@ ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int return ncclSuccess; } -ncclResult_t ncclSocketClose(struct ncclSocket* sock) { +// Make it possible to close just one part of a socket. +ncclResult_t ncclSocketShutdown(struct ncclSocket* sock, int how) { if (sock != NULL) { if (sock->fd >= 0) { + shutdown(sock->fd, how); + } + sock->state = ncclSocketStateTerminating; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketClose(struct ncclSocket* sock) { + if (sock != NULL) { + if (sock->state > ncclSocketStateNone && sock->state < ncclSocketStateNum && sock->fd >= 0) { /* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected * by refcount of fd, but close() is. close() won't close a fd and send FIN packet if * the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful diff --git a/src/misc/tuner.cc b/src/misc/tuner.cc index 830470ded4..61f57c78b9 100644 --- a/src/misc/tuner.cc +++ b/src/misc/tuner.cc @@ -16,9 +16,11 @@ pthread_mutex_t tunerPluginLock = PTHREAD_MUTEX_INITIALIZER; static int tunerPluginRefCount; static void* tunerPluginLib = nullptr; -static ncclTuner_v3_t* tunerSymbol = nullptr; +static ncclTuner_v4_t* tunerSymbol = nullptr; +static ncclTuner_v3_t* ncclTuner_v3 = nullptr; static ncclTuner_v2_t* ncclTuner_v2 = nullptr; -static ncclTuner_v3_t ncclTuner_v2_as_v3; +static ncclTuner_v4_t ncclTuner_v2_as_v4; +static ncclTuner_v4_t ncclTuner_v3_as_v4; static int hasNvlsSupport(float** collCostTable) { // Requirements for support of different algorithms: @@ -39,7 +41,20 @@ static int hasCollNetSupport(float** collCostTable) { return (table[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] == NCCL_ALGO_PROTO_IGNORE) ? 0 : 1; } -static ncclResult_t ncclTuner_v2_as_v3_getCollInfo(void* context, ncclFunc_t collType, size_t nBytes, int numPipeOps, float** collCostTable, int numAlgo __attribute__((unused)), int numProto __attribute__((unused)), int* nChannels) { +static ncclResult_t ncclTuner_v3_as_v4_getCollInfo(void* context, ncclFunc_t collType, size_t nBytes, int numPipeOps, float** collCostTable, int numAlgo, int numProto, int regBuff __attribute__((unused)), int* nChannels) { + NCCLCHECK(ncclTuner_v3->getCollInfo(context, collType, nBytes, numPipeOps, collCostTable, numAlgo, numProto, nChannels)); + return ncclSuccess; +} + +static ncclResult_t ncclTuner_v3_as_v4_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void** context) { + NCCLCHECK(ncclTuner_v3->init(nRanks, nNodes, logFunction, context)); + ncclTuner_v3_as_v4.name = ncclTuner_v3->name; + ncclTuner_v3_as_v4.getCollInfo = ncclTuner_v3_as_v4_getCollInfo; + ncclTuner_v3_as_v4.destroy = ncclTuner_v3->destroy; + return ncclSuccess; +} + +static ncclResult_t ncclTuner_v2_as_v4_getCollInfo(void* context, ncclFunc_t collType, size_t nBytes, int numPipeOps, float** collCostTable, int numAlgo __attribute__((unused)), int numProto __attribute__((unused)), int regBuff __attribute__((unused)), int* nChannels) { int algorithm = NCCL_ALGO_UNDEF; int protocol = NCCL_PROTO_UNDEF; int nvlsSupport = hasNvlsSupport(collCostTable); @@ -53,11 +68,11 @@ static ncclResult_t ncclTuner_v2_as_v3_getCollInfo(void* context, ncclFunc_t col return ncclSuccess; } -static ncclResult_t ncclTuner_v2_as_v3_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void** context) { +static ncclResult_t ncclTuner_v2_as_v4_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void** context) { NCCLCHECK(ncclTuner_v2->init(nRanks, nNodes, logFunction, context)); - ncclTuner_v2_as_v3.name = ncclTuner_v2->name; - ncclTuner_v2_as_v3.getCollInfo = ncclTuner_v2_as_v3_getCollInfo; - ncclTuner_v2_as_v3.destroy = ncclTuner_v2->destroy; + ncclTuner_v2_as_v4.name = ncclTuner_v2->name; + ncclTuner_v2_as_v4.getCollInfo = ncclTuner_v2_as_v4_getCollInfo; + ncclTuner_v2_as_v4.destroy = ncclTuner_v2->destroy; return ncclSuccess; } @@ -198,18 +213,26 @@ ncclResult_t ncclTunerPluginLoad(struct ncclComm* comm) { goto fail; } - tunerSymbol = (ncclTuner_v3_t*)dlsym(tunerPluginLib, "ncclTunerPlugin_v3"); + tunerSymbol = (ncclTuner_v4_t*)dlsym(tunerPluginLib, "ncclTunerPlugin_v4"); if (tunerSymbol == nullptr) { - INFO(NCCL_ENV|NCCL_TUNING, "TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol."); - ncclTuner_v2 = (ncclTuner_v2_t*)dlsym(tunerPluginLib, "ncclTunerPlugin_v2"); - if (ncclTuner_v2 == nullptr) { - INFO(NCCL_ENV|NCCL_TUNING, "TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead."); - dlclose(tunerPluginLib); - goto fail; + INFO(NCCL_ENV|NCCL_TUNING, "TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol."); + ncclTuner_v3 = (ncclTuner_v3_t*)dlsym(tunerPluginLib, "ncclTunerPlugin_v3"); + if (ncclTuner_v3 == nullptr) { + INFO(NCCL_ENV|NCCL_TUNING, "TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol."); + ncclTuner_v2 = (ncclTuner_v2_t*)dlsym(tunerPluginLib, "ncclTunerPlugin_v2"); + if (ncclTuner_v2 == nullptr) { + INFO(NCCL_ENV|NCCL_TUNING, "TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead."); + dlclose(tunerPluginLib); + goto fail; + } else { + ncclTuner_v2_as_v4.init = ncclTuner_v2_as_v4_init; + ncclTuner_v2_as_v4.name = ncclTuner_v2->name; + tunerSymbol = &ncclTuner_v2_as_v4; + } } else { - ncclTuner_v2_as_v3.init = ncclTuner_v2_as_v3_init; - ncclTuner_v2_as_v3.name = ncclTuner_v2->name; - tunerSymbol = &ncclTuner_v2_as_v3; + ncclTuner_v3_as_v4.init = ncclTuner_v3_as_v4_init; + ncclTuner_v3_as_v4.name = ncclTuner_v3->name; + tunerSymbol = &ncclTuner_v3_as_v4; } } diff --git a/src/nccl.h.in b/src/nccl.h.in index f658e9c9b2..e6c51e05db 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -306,6 +306,12 @@ const char* ncclGetLastError(ncclComm_t comm); const char* pncclGetLastError(ncclComm_t comm); /*! @endcond */ +/* Reload environment variables that determine logging. */ +void ncclResetDebugInit(); +/*! @cond include_hidden */ +void pncclResetDebugInit(); +/*! @endcond */ + /*! @brief Checks whether the comm has encountered any asynchronous errors @details Query whether the provided communicator has encountered any asynchronous errors @return Result code. See @ref rccl_result_code for more details. @@ -407,13 +413,10 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, ncclBfloat16 = 9, -#if defined(RCCL_FLOAT8) - ncclFp8E4M3 = 10, - ncclFp8E5M2 = 11, - ncclNumTypes = 12 } ncclDataType_t; -#else - ncclNumTypes = 10 } ncclDataType_t; -#endif + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, + ncclNumTypes = 12 +} ncclDataType_t; /*! @} */ /*! @defgroup rccl_api_custom_redop Custom Reduction Operator diff --git a/src/net.cc b/src/net.cc index a06e0d1d53..815f8c4ba6 100644 --- a/src/net.cc +++ b/src/net.cc @@ -16,20 +16,95 @@ //#include //#include -static ncclNet_v8_t ncclNet_v5_as_v8; -static ncclNet_v8_t ncclNet_v6_as_v8; -static ncclNet_v8_t ncclNet_v7_as_v8; +static ncclNet_v9_t ncclNet_v5_as_v9; +static ncclNet_v9_t ncclNet_v6_as_v9; +static ncclNet_v9_t ncclNet_v7_as_v9; +static ncclNet_v9_t ncclNet_v8_as_v9; static ncclNet_v5_t *ncclNet_v5; static ncclNet_v6_t *ncclNet_v6; static ncclNet_v7_t *ncclNet_v7; -static ncclCollNet_v8_t ncclCollNet_v5_as_v8; -static ncclCollNet_v8_t ncclCollNet_v6_as_v8; -static ncclCollNet_v8_t ncclCollNet_v7_as_v8; +static ncclNet_v8_t *ncclNet_v8; +static ncclCollNet_v9_t ncclCollNet_v5_as_v9; +static ncclCollNet_v9_t ncclCollNet_v6_as_v9; +static ncclCollNet_v9_t ncclCollNet_v7_as_v9; +static ncclCollNet_v9_t ncclCollNet_v8_as_v9; static ncclCollNet_v5_t *ncclCollNet_v5; static ncclCollNet_v6_t *ncclCollNet_v6; static ncclCollNet_v7_t *ncclCollNet_v7; +static ncclCollNet_v8_t *ncclCollNet_v8; -static ncclResult_t ncclNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +#define MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. +#define MAX_COLLNET_SIZE (512*1024*1024L) //Set for initial collent plugins when size was not dynamically queried + +static ncclResult_t ncclNet_v8_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { + ncclNetProperties_v8_t p8; + ncclResult_t ans = ncclNet_v8->getProperties(dev, &p8); + if (ans != ncclSuccess) return ans; + props->name = p8.name; + props->pciPath = p8.pciPath; + props->guid = p8.guid; + props->ptrSupport = p8.ptrSupport; + props->regIsGlobal = p8.regIsGlobal; + props->forceFlush = 0; + props->speed = p8.speed; + props->port = p8.port; + props->maxComms = p8.maxComms; + props->maxRecvs = p8.maxRecvs; + props->latency = p8.latency; + props->netDeviceType = p8.netDeviceType; + props->netDeviceVersion = p8.netDeviceVersion; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; + return ncclSuccess; +} + +static ncclResult_t ncclNet_v8_as_v9_isend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { + int sizeInt; + if (size > MAX_NET_SIZE) return ncclInternalError; + sizeInt = (int)size; + ncclResult_t ans = ncclNet_v8->isend(sendComm, data, sizeInt, tag, mhandle, request); + return ans; +} + +static ncclResult_t ncclNet_v8_as_v9_irecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { + int sizesInt[NCCL_PROXY_MAX_SUBS]; + //reset to NULL if optional receive completion is set + if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) *request = NULL; + for (int i=0; i MAX_NET_SIZE) return ncclInternalError; + sizesInt[i] = (int) sizes[i]; + } + ncclResult_t ans = ncclNet_v8->irecv(recvComm, n, data, sizesInt, tags, mhandles, request); + return ans; +} + +static ncclResult_t ncclNet_v8_as_v9_init(ncclDebugLogger_t logfn) { + NCCLCHECK(ncclNet_v8->init(logfn)); + ncclNet_v8_as_v9.name = ncclNet_v8->name; + ncclNet_v8_as_v9.devices = ncclNet_v8->devices; + ncclNet_v8_as_v9.getProperties = ncclNet_v8_as_v9_getProperties; + ncclNet_v8_as_v9.listen = ncclNet_v8->listen; + ncclNet_v8_as_v9.connect = ncclNet_v8->connect; + ncclNet_v8_as_v9.accept = ncclNet_v8->accept; + ncclNet_v8_as_v9.regMr = ncclNet_v8->regMr; + ncclNet_v8_as_v9.regMrDmaBuf = ncclNet_v8->regMrDmaBuf; + ncclNet_v8_as_v9.deregMr = ncclNet_v8->deregMr; + ncclNet_v8_as_v9.isend = ncclNet_v8_as_v9_isend; + ncclNet_v8_as_v9.irecv = ncclNet_v8_as_v9_irecv; + ncclNet_v8_as_v9.iflush = ncclNet_v8->iflush; + ncclNet_v8_as_v9.test = ncclNet_v8->test; + ncclNet_v8_as_v9.closeSend = ncclNet_v8->closeSend; + ncclNet_v8_as_v9.closeRecv = ncclNet_v8->closeRecv; + ncclNet_v8_as_v9.closeListen = ncclNet_v8->closeListen; + ncclNet_v8_as_v9.getDeviceMr = ncclNet_v8->getDeviceMr; + ncclNet_v8_as_v9.irecvConsumed = ncclNet_v8->irecvConsumed; + ncclNet_v8_as_v9.makeVDevice = NULL; + return ncclSuccess; +} + +static ncclResult_t ncclNet_v7_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v7_t p7; ncclResult_t ans = ncclNet_v7->getProperties(dev, &p7); if (ans != ncclSuccess) return ans; @@ -38,6 +113,7 @@ static ncclResult_t ncclNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8 props->guid = p7.guid; props->ptrSupport = p7.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p7.speed; props->port = p7.port; props->maxComms = p7.maxComms; @@ -45,38 +121,63 @@ static ncclResult_t ncclNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8 props->latency = p7.latency; props->netDeviceType = p7.netDeviceType; props->netDeviceVersion = p7.netDeviceVersion; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclNet_v7_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclNet_v7_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclNet_v7->regMr(comm, data, (int) size, type, mhandle); } -static ncclResult_t ncclNet_v7_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclNet_v7_as_v9_isend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { + int sizeInt; + if (size > MAX_NET_SIZE) return ncclInternalError; + sizeInt = (int)size; + ncclResult_t ans = ncclNet_v7->isend(sendComm, data, sizeInt, tag, mhandle, request); + return ans; +} + +static ncclResult_t ncclNet_v7_as_v9_irecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { + int sizesInt[NCCL_PROXY_MAX_SUBS]; + //reset to NULL if optional receive completion is set + if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) *request = NULL; + for (int i=0; i MAX_NET_SIZE) return ncclInternalError; + sizesInt[i] = (int) sizes[i]; + } + ncclResult_t ans = ncclNet_v7->irecv(recvComm, n, data, sizesInt, tags, mhandles, request); + return ans; +} + +static ncclResult_t ncclNet_v7_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclNet_v7->init(logfn)); - ncclNet_v7_as_v8.name = ncclNet_v7->name; - ncclNet_v7_as_v8.devices = ncclNet_v7->devices; - ncclNet_v7_as_v8.getProperties = ncclNet_v7_as_v8_getProperties; // ncclNet_v5->getProperties; - ncclNet_v7_as_v8.listen = ncclNet_v7->listen; - ncclNet_v7_as_v8.connect = ncclNet_v7->connect; - ncclNet_v7_as_v8.accept = ncclNet_v7->accept; - ncclNet_v7_as_v8.regMr = ncclNet_v7_as_v8_regMr; - ncclNet_v7_as_v8.regMrDmaBuf = ncclNet_v7->regMrDmaBuf; - ncclNet_v7_as_v8.deregMr = ncclNet_v7->deregMr; - ncclNet_v7_as_v8.isend = ncclNet_v7->isend; - ncclNet_v7_as_v8.irecv = ncclNet_v7->irecv; - ncclNet_v7_as_v8.iflush = ncclNet_v7->iflush; - ncclNet_v7_as_v8.test = ncclNet_v7->test; - ncclNet_v7_as_v8.closeSend = ncclNet_v7->closeSend; - ncclNet_v7_as_v8.closeRecv = ncclNet_v7->closeRecv; - ncclNet_v7_as_v8.closeListen = ncclNet_v7->closeListen; - ncclNet_v7_as_v8.getDeviceMr = ncclNet_v7->getDeviceMr; - ncclNet_v7_as_v8.irecvConsumed = ncclNet_v7->irecvConsumed; + ncclNet_v7_as_v9.name = ncclNet_v7->name; + ncclNet_v7_as_v9.devices = ncclNet_v7->devices; + ncclNet_v7_as_v9.getProperties = ncclNet_v7_as_v9_getProperties; // ncclNet_v5->getProperties; + ncclNet_v7_as_v9.listen = ncclNet_v7->listen; + ncclNet_v7_as_v9.connect = ncclNet_v7->connect; + ncclNet_v7_as_v9.accept = ncclNet_v7->accept; + ncclNet_v7_as_v9.regMr = ncclNet_v7_as_v9_regMr; + ncclNet_v7_as_v9.regMrDmaBuf = ncclNet_v7->regMrDmaBuf; + ncclNet_v7_as_v9.deregMr = ncclNet_v7->deregMr; + ncclNet_v7_as_v9.isend = ncclNet_v7_as_v9_isend; + ncclNet_v7_as_v9.irecv = ncclNet_v7_as_v9_irecv; + ncclNet_v7_as_v9.iflush = ncclNet_v7->iflush; + ncclNet_v7_as_v9.test = ncclNet_v7->test; + ncclNet_v7_as_v9.closeSend = ncclNet_v7->closeSend; + ncclNet_v7_as_v9.closeRecv = ncclNet_v7->closeRecv; + ncclNet_v7_as_v9.closeListen = ncclNet_v7->closeListen; + ncclNet_v7_as_v9.getDeviceMr = ncclNet_v7->getDeviceMr; + ncclNet_v7_as_v9.irecvConsumed = ncclNet_v7->irecvConsumed; + ncclNet_v7_as_v9.makeVDevice = NULL; return ncclSuccess; } -static ncclResult_t ncclNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +static ncclResult_t ncclNet_v6_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclNet_v6->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -85,6 +186,7 @@ static ncclResult_t ncclNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8 props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -92,46 +194,71 @@ static ncclResult_t ncclNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8 props->latency = p6.latency; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclNet_v6_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclNet_v6_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclNet_v6->regMr(comm, data, (int) size, type, mhandle); } -static ncclResult_t ncclNet_v6_as_v8_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { +static ncclResult_t ncclNet_v6_as_v9_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { return ncclNet_v6->connect(dev, handle, sendComm); } -static ncclResult_t ncclNet_v6_as_v8_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { +static ncclResult_t ncclNet_v6_as_v9_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { return ncclNet_v6->accept(listenComm, recvComm); } -static ncclResult_t ncclNet_v6_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclNet_v6_as_v9_isend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { + int sizeInt; + if (size > MAX_NET_SIZE) return ncclInternalError; + sizeInt = (int)size; + ncclResult_t ans = ncclNet_v6->isend(sendComm, data, sizeInt, tag, mhandle, request); + return ans; +} + +static ncclResult_t ncclNet_v6_as_v9_irecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { + int sizesInt[NCCL_PROXY_MAX_SUBS]; + //reset to NULL if optional receive completion is set + if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) *request = NULL; + for (int i=0; i MAX_NET_SIZE) return ncclInternalError; + sizesInt[i] = (int) sizes[i]; + } + ncclResult_t ans = ncclNet_v6->irecv(recvComm, n, data, sizesInt, tags, mhandles, request); + return ans; +} + +static ncclResult_t ncclNet_v6_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclNet_v6->init(logfn)); - ncclNet_v6_as_v8.name = ncclNet_v6->name; - ncclNet_v6_as_v8.devices = ncclNet_v6->devices; - ncclNet_v6_as_v8.getProperties = ncclNet_v6_as_v8_getProperties; // ncclNet_v5->getProperties; - ncclNet_v6_as_v8.listen = ncclNet_v6->listen; - ncclNet_v6_as_v8.connect = ncclNet_v6_as_v8_connect; - ncclNet_v6_as_v8.accept = ncclNet_v6_as_v8_accept; - ncclNet_v6_as_v8.regMr = ncclNet_v6_as_v8_regMr; - ncclNet_v6_as_v8.regMrDmaBuf = ncclNet_v6->regMrDmaBuf; - ncclNet_v6_as_v8.deregMr = ncclNet_v6->deregMr; - ncclNet_v6_as_v8.isend = ncclNet_v6->isend; - ncclNet_v6_as_v8.irecv = ncclNet_v6->irecv; - ncclNet_v6_as_v8.iflush = ncclNet_v6->iflush; - ncclNet_v6_as_v8.test = ncclNet_v6->test; - ncclNet_v6_as_v8.closeSend = ncclNet_v6->closeSend; - ncclNet_v6_as_v8.closeRecv = ncclNet_v6->closeRecv; - ncclNet_v6_as_v8.closeListen = ncclNet_v6->closeListen; - ncclNet_v6_as_v8.getDeviceMr = NULL; - ncclNet_v6_as_v8.irecvConsumed = NULL; + ncclNet_v6_as_v9.name = ncclNet_v6->name; + ncclNet_v6_as_v9.devices = ncclNet_v6->devices; + ncclNet_v6_as_v9.getProperties = ncclNet_v6_as_v9_getProperties; + ncclNet_v6_as_v9.listen = ncclNet_v6->listen; + ncclNet_v6_as_v9.connect = ncclNet_v6_as_v9_connect; + ncclNet_v6_as_v9.accept = ncclNet_v6_as_v9_accept; + ncclNet_v6_as_v9.regMr = ncclNet_v6_as_v9_regMr; + ncclNet_v6_as_v9.regMrDmaBuf = ncclNet_v6->regMrDmaBuf; + ncclNet_v6_as_v9.deregMr = ncclNet_v6->deregMr; + ncclNet_v6_as_v9.isend = ncclNet_v6_as_v9_isend; + ncclNet_v6_as_v9.irecv = ncclNet_v6_as_v9_irecv; + ncclNet_v6_as_v9.iflush = ncclNet_v6->iflush; + ncclNet_v6_as_v9.test = ncclNet_v6->test; + ncclNet_v6_as_v9.closeSend = ncclNet_v6->closeSend; + ncclNet_v6_as_v9.closeRecv = ncclNet_v6->closeRecv; + ncclNet_v6_as_v9.closeListen = ncclNet_v6->closeListen; + ncclNet_v6_as_v9.getDeviceMr = NULL; + ncclNet_v6_as_v9.irecvConsumed = NULL; + ncclNet_v6_as_v9.makeVDevice = NULL; return ncclSuccess; } -static ncclResult_t ncclNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +static ncclResult_t ncclNet_v5_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclNet_v5->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -140,6 +267,7 @@ static ncclResult_t ncclNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8 props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -147,48 +275,73 @@ static ncclResult_t ncclNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8 props->latency = p6.latency; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclNet_v5_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclNet_v5_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclNet_v5->regMr(comm, data, (int) size, type, mhandle); } -static ncclResult_t ncclNet_v5_as_v8_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { +static ncclResult_t ncclNet_v5_as_v9_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { return ncclNet_v5->connect(dev, handle, sendComm); } -static ncclResult_t ncclNet_v5_as_v8_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { +static ncclResult_t ncclNet_v5_as_v9_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { return ncclNet_v5->accept(listenComm, recvComm); } +static ncclResult_t ncclNet_v5_as_v9_isend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { + int sizeInt; + if (size > MAX_NET_SIZE) return ncclInternalError; + sizeInt = (int)size; + ncclResult_t ans = ncclNet_v5->isend(sendComm, data, sizeInt, tag, mhandle, request); + return ans; +} + +static ncclResult_t ncclNet_v5_as_v9_irecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { + int sizesInt[NCCL_PROXY_MAX_SUBS]; + //reset to NULL if optional receive completion is set + if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) *request = NULL; + for (int i=0; i MAX_NET_SIZE) return ncclInternalError; + sizesInt[i] = (int) sizes[i]; + } + ncclResult_t ans = ncclNet_v5->irecv(recvComm, n, data, sizesInt, tags, mhandles, request); + return ans; +} + // We use a wrapper around the v5 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclNet_v5_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclNet_v5_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclNet_v5->init(logfn)); - ncclNet_v5_as_v8.name = ncclNet_v5->name; - ncclNet_v5_as_v8.devices = ncclNet_v5->devices; - ncclNet_v5_as_v8.getProperties = ncclNet_v5_as_v8_getProperties; - ncclNet_v5_as_v8.listen = ncclNet_v5->listen; - ncclNet_v5_as_v8.connect = ncclNet_v5_as_v8_connect; - ncclNet_v5_as_v8.accept = ncclNet_v5_as_v8_accept; - ncclNet_v5_as_v8.regMr = ncclNet_v5_as_v8_regMr; - ncclNet_v5_as_v8.regMrDmaBuf = NULL; - ncclNet_v5_as_v8.deregMr = ncclNet_v5->deregMr; - ncclNet_v5_as_v8.isend = ncclNet_v5->isend; - ncclNet_v5_as_v8.irecv = ncclNet_v5->irecv; - ncclNet_v5_as_v8.iflush = ncclNet_v5->iflush; - ncclNet_v5_as_v8.test = ncclNet_v5->test; - ncclNet_v5_as_v8.closeSend = ncclNet_v5->closeSend; - ncclNet_v5_as_v8.closeRecv = ncclNet_v5->closeRecv; - ncclNet_v5_as_v8.closeListen = ncclNet_v5->closeListen; - ncclNet_v5_as_v8.getDeviceMr = NULL; - ncclNet_v5_as_v8.irecvConsumed = NULL; + ncclNet_v5_as_v9.name = ncclNet_v5->name; + ncclNet_v5_as_v9.devices = ncclNet_v5->devices; + ncclNet_v5_as_v9.getProperties = ncclNet_v5_as_v9_getProperties; + ncclNet_v5_as_v9.listen = ncclNet_v5->listen; + ncclNet_v5_as_v9.connect = ncclNet_v5_as_v9_connect; + ncclNet_v5_as_v9.accept = ncclNet_v5_as_v9_accept; + ncclNet_v5_as_v9.regMr = ncclNet_v5_as_v9_regMr; + ncclNet_v5_as_v9.regMrDmaBuf = NULL; + ncclNet_v5_as_v9.deregMr = ncclNet_v5->deregMr; + ncclNet_v5_as_v9.isend = ncclNet_v5_as_v9_isend; + ncclNet_v5_as_v9.irecv = ncclNet_v5_as_v9_irecv; + ncclNet_v5_as_v9.iflush = ncclNet_v5->iflush; + ncclNet_v5_as_v9.test = ncclNet_v5->test; + ncclNet_v5_as_v9.closeSend = ncclNet_v5->closeSend; + ncclNet_v5_as_v9.closeRecv = ncclNet_v5->closeRecv; + ncclNet_v5_as_v9.closeListen = ncclNet_v5->closeListen; + ncclNet_v5_as_v9.getDeviceMr = NULL; + ncclNet_v5_as_v9.irecvConsumed = NULL; + ncclNet_v5_as_v9.makeVDevice = NULL; return ncclSuccess; } -static ncclResult_t ncclCollNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +static ncclResult_t ncclCollNet_v5_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclCollNet_v5->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -197,6 +350,7 @@ static ncclResult_t ncclCollNet_v5_as_v8_getProperties(int dev, ncclNetPropertie props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -204,38 +358,52 @@ static ncclResult_t ncclCollNet_v5_as_v8_getProperties(int dev, ncclNetPropertie props->latency = p6.latency; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclCollNet_v5_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclCollNet_v5_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclCollNet_v5->regMr(comm, data, (int) size, type, mhandle); } +static ncclResult_t ncclCollNet_v5_as_v9_iallreduce(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { + int countInt; + if (count > MAX_NET_SIZE) return ncclInternalError; + countInt = (int)count; + ncclResult_t ans = ncclCollNet_v5->iallreduce(collComm, sendData, recvData, countInt, dataType, redOp, + sendMhandle, recvMhandle, request); + return ans; +} + // We use a wrapper around the v5 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclCollNet_v5_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclCollNet_v5_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclCollNet_v5->init(logfn)); - ncclCollNet_v5_as_v8.name = ncclCollNet_v5->name; - ncclCollNet_v5_as_v8.devices = ncclCollNet_v5->devices; - ncclCollNet_v5_as_v8.getProperties = ncclCollNet_v5_as_v8_getProperties; - ncclCollNet_v5_as_v8.listen = ncclCollNet_v5->listen; - ncclCollNet_v5_as_v8.connect = ncclCollNet_v5->connect; - ncclCollNet_v5_as_v8.reduceSupport = ncclCollNet_v5->reduceSupport; - ncclCollNet_v5_as_v8.regMr = ncclCollNet_v5_as_v8_regMr; - ncclCollNet_v5_as_v8.regMrDmaBuf = NULL; - ncclCollNet_v5_as_v8.deregMr = ncclCollNet_v5->deregMr; - ncclCollNet_v5_as_v8.iallreduce = ncclCollNet_v5->iallreduce; - ncclCollNet_v5_as_v8.iallgather = nullptr; - ncclCollNet_v5_as_v8.ireducescatter = nullptr; - ncclCollNet_v5_as_v8.iflush = ncclCollNet_v5->iflush; - ncclCollNet_v5_as_v8.test = ncclCollNet_v5->test; - ncclCollNet_v5_as_v8.closeColl = ncclCollNet_v5->closeColl; - ncclCollNet_v5_as_v8.closeListen = ncclCollNet_v5->closeListen; + ncclCollNet_v5_as_v9.name = ncclCollNet_v5->name; + ncclCollNet_v5_as_v9.devices = ncclCollNet_v5->devices; + ncclCollNet_v5_as_v9.getProperties = ncclCollNet_v5_as_v9_getProperties; + ncclCollNet_v5_as_v9.listen = ncclCollNet_v5->listen; + ncclCollNet_v5_as_v9.connect = ncclCollNet_v5->connect; + ncclCollNet_v5_as_v9.reduceSupport = ncclCollNet_v5->reduceSupport; + ncclCollNet_v5_as_v9.regMr = ncclCollNet_v5_as_v9_regMr; + ncclCollNet_v5_as_v9.regMrDmaBuf = NULL; + ncclCollNet_v5_as_v9.deregMr = ncclCollNet_v5->deregMr; + ncclCollNet_v5_as_v9.iallreduce = ncclCollNet_v5_as_v9_iallreduce; + ncclCollNet_v5_as_v9.iallgather = nullptr; + ncclCollNet_v5_as_v9.ireducescatter = nullptr; + ncclCollNet_v5_as_v9.iflush = ncclCollNet_v5->iflush; + ncclCollNet_v5_as_v9.test = ncclCollNet_v5->test; + ncclCollNet_v5_as_v9.closeColl = ncclCollNet_v5->closeColl; + ncclCollNet_v5_as_v9.closeListen = ncclCollNet_v5->closeListen; return ncclSuccess; } -static ncclResult_t ncclCollNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +static ncclResult_t ncclCollNet_v6_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclCollNet_v6->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -244,6 +412,7 @@ static ncclResult_t ncclCollNet_v6_as_v8_getProperties(int dev, ncclNetPropertie props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -251,38 +420,52 @@ static ncclResult_t ncclCollNet_v6_as_v8_getProperties(int dev, ncclNetPropertie props->latency = p6.latency; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclCollNet_v6_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclCollNet_v6_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclCollNet_v6->regMr(comm, data, (int) size, type, mhandle); } +static ncclResult_t ncclCollNet_v6_as_v9_iallreduce(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { + int countInt; + if (count > MAX_NET_SIZE) return ncclInternalError; + countInt = (int)count; + ncclResult_t ans = ncclCollNet_v6->iallreduce(collComm, sendData, recvData, countInt, dataType, redOp, + sendMhandle, recvMhandle, request); + return ans; +} + // We use a wrapper around the v6 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclCollNet_v6_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclCollNet_v6_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclCollNet_v6->init(logfn)); - ncclCollNet_v6_as_v8.name = ncclCollNet_v6->name; - ncclCollNet_v6_as_v8.devices = ncclCollNet_v6->devices; - ncclCollNet_v6_as_v8.getProperties = ncclCollNet_v6_as_v8_getProperties; - ncclCollNet_v6_as_v8.listen = ncclCollNet_v6->listen; - ncclCollNet_v6_as_v8.connect = ncclCollNet_v6->connect; - ncclCollNet_v6_as_v8.reduceSupport = ncclCollNet_v6->reduceSupport; - ncclCollNet_v6_as_v8.regMr = ncclCollNet_v6_as_v8_regMr; - ncclCollNet_v6_as_v8.regMrDmaBuf = ncclCollNet_v6->regMrDmaBuf; - ncclCollNet_v6_as_v8.deregMr = ncclCollNet_v6->deregMr; - ncclCollNet_v6_as_v8.iallreduce = ncclCollNet_v6->iallreduce; - ncclCollNet_v6_as_v8.iallgather = nullptr; - ncclCollNet_v6_as_v8.ireducescatter = nullptr; - ncclCollNet_v6_as_v8.iflush = ncclCollNet_v6->iflush; - ncclCollNet_v6_as_v8.test = ncclCollNet_v6->test; - ncclCollNet_v6_as_v8.closeColl = ncclCollNet_v6->closeColl; - ncclCollNet_v6_as_v8.closeListen = ncclCollNet_v6->closeListen; + ncclCollNet_v6_as_v9.name = ncclCollNet_v6->name; + ncclCollNet_v6_as_v9.devices = ncclCollNet_v6->devices; + ncclCollNet_v6_as_v9.getProperties = ncclCollNet_v6_as_v9_getProperties; + ncclCollNet_v6_as_v9.listen = ncclCollNet_v6->listen; + ncclCollNet_v6_as_v9.connect = ncclCollNet_v6->connect; + ncclCollNet_v6_as_v9.reduceSupport = ncclCollNet_v6->reduceSupport; + ncclCollNet_v6_as_v9.regMr = ncclCollNet_v6_as_v9_regMr; + ncclCollNet_v6_as_v9.regMrDmaBuf = ncclCollNet_v6->regMrDmaBuf; + ncclCollNet_v6_as_v9.deregMr = ncclCollNet_v6->deregMr; + ncclCollNet_v6_as_v9.iallreduce = ncclCollNet_v6_as_v9_iallreduce; + ncclCollNet_v6_as_v9.iallgather = nullptr; + ncclCollNet_v6_as_v9.ireducescatter = nullptr; + ncclCollNet_v6_as_v9.iflush = ncclCollNet_v6->iflush; + ncclCollNet_v6_as_v9.test = ncclCollNet_v6->test; + ncclCollNet_v6_as_v9.closeColl = ncclCollNet_v6->closeColl; + ncclCollNet_v6_as_v9.closeListen = ncclCollNet_v6->closeListen; return ncclSuccess; } -static ncclResult_t ncclCollNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { +static ncclResult_t ncclCollNet_v7_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { ncclNetProperties_v7_t p7; ncclResult_t ans = ncclCollNet_v7->getProperties(dev, &p7); if (ans != ncclSuccess) return ans; @@ -291,6 +474,7 @@ static ncclResult_t ncclCollNet_v7_as_v8_getProperties(int dev, ncclNetPropertie props->guid = p7.guid; props->ptrSupport = p7.ptrSupport; props->regIsGlobal = 0; + props->forceFlush = 0; props->speed = p7.speed; props->port = p7.port; props->maxComms = p7.maxComms; @@ -298,47 +482,150 @@ static ncclResult_t ncclCollNet_v7_as_v8_getProperties(int dev, ncclNetPropertie props->latency = p7.latency; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; return ncclSuccess; } -static ncclResult_t ncclCollNet_v7_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { +static ncclResult_t ncclCollNet_v7_as_v9_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { if (size >= 1UL<<31) return ncclInternalError; return ncclCollNet_v7->regMr(comm, data, (int) size, type, mhandle); } +static ncclResult_t ncclCollNet_v7_as_v9_iallreduce(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { + int countInt; + if (count > MAX_NET_SIZE) return ncclInternalError; + countInt = (int)count; + ncclResult_t ans = ncclCollNet_v7->iallreduce(collComm, sendData, recvData, countInt, dataType, redOp, + sendMhandle, recvMhandle, request); + return ans; +} + // We use a wrapper around the v7 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclCollNet_v7_as_v8_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclCollNet_v7_as_v9_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclCollNet_v7->init(logfn)); - ncclCollNet_v7_as_v8.name = ncclCollNet_v7->name; - ncclCollNet_v7_as_v8.devices = ncclCollNet_v7->devices; - ncclCollNet_v7_as_v8.getProperties = ncclCollNet_v7_as_v8_getProperties; - ncclCollNet_v7_as_v8.listen = ncclCollNet_v7->listen; - ncclCollNet_v7_as_v8.connect = ncclCollNet_v7->connect; - ncclCollNet_v7_as_v8.reduceSupport = ncclCollNet_v7->reduceSupport; - ncclCollNet_v7_as_v8.regMr = ncclCollNet_v7_as_v8_regMr; - ncclCollNet_v7_as_v8.regMrDmaBuf = ncclCollNet_v7->regMrDmaBuf; - ncclCollNet_v7_as_v8.deregMr = ncclCollNet_v7->deregMr; - ncclCollNet_v7_as_v8.iallreduce = ncclCollNet_v7->iallreduce; - ncclCollNet_v7_as_v8.iallgather = nullptr; - ncclCollNet_v7_as_v8.ireducescatter = nullptr; - ncclCollNet_v7_as_v8.iflush = ncclCollNet_v7->iflush; - ncclCollNet_v7_as_v8.test = ncclCollNet_v7->test; - ncclCollNet_v7_as_v8.closeColl = ncclCollNet_v7->closeColl; - ncclCollNet_v7_as_v8.closeListen = ncclCollNet_v7->closeListen; + ncclCollNet_v7_as_v9.name = ncclCollNet_v7->name; + ncclCollNet_v7_as_v9.devices = ncclCollNet_v7->devices; + ncclCollNet_v7_as_v9.getProperties = ncclCollNet_v7_as_v9_getProperties; + ncclCollNet_v7_as_v9.listen = ncclCollNet_v7->listen; + ncclCollNet_v7_as_v9.connect = ncclCollNet_v7->connect; + ncclCollNet_v7_as_v9.reduceSupport = ncclCollNet_v7->reduceSupport; + ncclCollNet_v7_as_v9.regMr = ncclCollNet_v7_as_v9_regMr; + ncclCollNet_v7_as_v9.regMrDmaBuf = ncclCollNet_v7->regMrDmaBuf; + ncclCollNet_v7_as_v9.deregMr = ncclCollNet_v7->deregMr; + ncclCollNet_v7_as_v9.iallreduce = ncclCollNet_v7_as_v9_iallreduce; + ncclCollNet_v7_as_v9.iallgather = nullptr; + ncclCollNet_v7_as_v9.ireducescatter = nullptr; + ncclCollNet_v7_as_v9.iflush = ncclCollNet_v7->iflush; + ncclCollNet_v7_as_v9.test = ncclCollNet_v7->test; + ncclCollNet_v7_as_v9.closeColl = ncclCollNet_v7->closeColl; + ncclCollNet_v7_as_v9.closeListen = ncclCollNet_v7->closeListen; + return ncclSuccess; +} + +static ncclResult_t ncclCollNet_v8_as_v9_getProperties(int dev, ncclNetProperties_v9_t* props) { + ncclNetProperties_v8_t p8; + ncclResult_t ans = ncclCollNet_v8->getProperties(dev, &p8); + if (ans != ncclSuccess) return ans; + props->name = p8.name; + props->pciPath = p8.pciPath; + props->guid = p8.guid; + props->ptrSupport = p8.ptrSupport; + props->regIsGlobal = p8.regIsGlobal; + props->forceFlush = 0; + props->speed = p8.speed; + props->port = p8.port; + props->maxComms = p8.maxComms; + props->maxRecvs = p8.maxRecvs; + props->latency = p8.latency; + props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->vProps.ndevs = 1; + props->vProps.devs[0] = dev; + props->maxP2pBytes = MAX_NET_SIZE; + props->maxCollBytes = MAX_COLLNET_SIZE; + return ncclSuccess; +} + +static ncclResult_t ncclCollNet_v8_as_v9_iallreduce(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { + int countInt; + if (count > MAX_NET_SIZE) return ncclInternalError; + countInt = (int)count; + ncclResult_t ans = ncclCollNet_v8->iallreduce(collComm, sendData, recvData, countInt, dataType, redOp, + sendMhandle, recvMhandle, request); + return ans; +} + +static ncclResult_t ncclCollNet_v8_as_v9_iallgather (void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v9_t* recvParts, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + void* sendMhandle, void** request) { + ncclNetSGE_v8_t recvPartsInt; + if (nRecvParts > 1) return ncclInternalError; + if (recvParts->size > MAX_COLLNET_SIZE) return ncclInternalError; + recvPartsInt.mhandle = recvParts->mhandle; + recvPartsInt.address = recvParts->address; + recvPartsInt.size = (int)recvParts->size; + ncclResult_t ans = ncclCollNet_v8->iallgather(collComm, sendData, nRecvParts, &recvPartsInt, + bytesPerRank, windowOffset, windowBytes, + sendMhandle, request); + return ans; +} + +static ncclResult_t ncclCollNet_v8_as_v9_ireducescatter(void* collComm, int nSendParts, ncclNetSGE_v9_t* sendParts, void* recvData, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + ncclDataType_t dataType, ncclRedOp_t redOp, + void* recvMhandle, void** request) { + ncclNetSGE_v8_t sendPartsInt; + if (nSendParts > 1) return ncclInternalError; + if (sendParts->size > MAX_COLLNET_SIZE) return ncclInternalError; + sendPartsInt.mhandle = sendParts->mhandle; + sendPartsInt.address = sendParts->address; + sendPartsInt.size = (int)sendParts->size; + ncclResult_t ans = ncclCollNet_v8->ireducescatter(collComm, nSendParts, &sendPartsInt, + recvData, bytesPerRank, windowOffset, windowBytes, + dataType, redOp, + recvMhandle, request); + return ans; +} + +// We use a wrapper around the v8 init to copy over the struct contents +// post-init since they may not be initialized before hand. +static ncclResult_t ncclCollNet_v8_as_v9_init(ncclDebugLogger_t logfn) { + NCCLCHECK(ncclCollNet_v8->init(logfn)); + ncclCollNet_v8_as_v9.name = ncclCollNet_v8->name; + ncclCollNet_v8_as_v9.devices = ncclCollNet_v8->devices; + ncclCollNet_v8_as_v9.getProperties = ncclCollNet_v8_as_v9_getProperties; + ncclCollNet_v8_as_v9.listen = ncclCollNet_v8->listen; + ncclCollNet_v8_as_v9.connect = ncclCollNet_v8->connect; + ncclCollNet_v8_as_v9.reduceSupport = ncclCollNet_v8->reduceSupport; + ncclCollNet_v8_as_v9.regMr = ncclCollNet_v8->regMr; + ncclCollNet_v8_as_v9.regMrDmaBuf = ncclCollNet_v8->regMrDmaBuf; + ncclCollNet_v8_as_v9.deregMr = ncclCollNet_v8->deregMr; + ncclCollNet_v8_as_v9.iallreduce = ncclCollNet_v8_as_v9_iallreduce; + ncclCollNet_v8_as_v9.iallgather = ncclCollNet_v8_as_v9_iallgather; + ncclCollNet_v8_as_v9.ireducescatter = ncclCollNet_v8_as_v9_ireducescatter; + ncclCollNet_v8_as_v9.iflush = ncclCollNet_v8->iflush; + ncclCollNet_v8_as_v9.test = ncclCollNet_v8->test; + ncclCollNet_v8_as_v9.closeColl = ncclCollNet_v8->closeColl; + ncclCollNet_v8_as_v9.closeListen = ncclCollNet_v8->closeListen; return ncclSuccess; } static pthread_mutex_t netLock = PTHREAD_MUTEX_INITIALIZER; -ncclNet_t* ncclNets[3] = { nullptr, &ncclNetIb, &ncclNetSocket }; -ncclCollNet_t* ncclCollNets[3] = { nullptr, nullptr, nullptr }; +ncclNet_t* ncclNets[NCCL_NET_MAX_PLUGINS] = { nullptr, &ncclNetIb, &ncclNetSocket }; +ncclCollNet_t* ncclCollNets[NCCL_NET_MAX_PLUGINS] = { nullptr, nullptr, nullptr }; enum ncclNetState { ncclNetStateInit = 0, ncclNetStateEnabled = 1, ncclNetStateDisabled = 2 }; -enum ncclNetState ncclNetStates[3] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit }; -enum ncclNetState ncclCollNetStates[3] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit }; +enum ncclNetState ncclNetStates[NCCL_NET_MAX_PLUGINS] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit }; +enum ncclNetState ncclCollNetStates[NCCL_NET_MAX_PLUGINS] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit }; #define MAX_STR_LEN 255 @@ -444,72 +731,93 @@ ncclResult_t ncclNetPluginLoad(struct ncclComm* comm) { goto fail; } - ncclNets[0] = (ncclNet_v8_t*)dlsym(netPluginLib, "ncclNetPlugin_v8"); + ncclNets[0] = (ncclNet_v9_t*)dlsym(netPluginLib, "ncclNetPlugin_v9"); if (ncclNets[0] == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin_v8 symbol."); - // Try v7 plugin - ncclNet_v7 = (ncclNet_v7_t*)dlsym(netPluginLib, "ncclNetPlugin_v7"); - if (ncclNet_v7 == nullptr) { - // Try v6 plugin - ncclNet_v6 = (ncclNet_v6_t*)dlsym(netPluginLib, "ncclNetPlugin_v6"); - if (ncclNet_v6 == nullptr) { - // Try v5 plugin - ncclNet_v5 = (ncclNet_v5_t*)dlsym(netPluginLib, "ncclNetPlugin_v5"); - if (ncclNet_v5 == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin symbol (>= v5). ncclNetPlugin symbols v4 and lower are not supported."); - goto fail; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin_v9 symbol."); + ncclNet_v8 = (ncclNet_v8_t*)dlsym(netPluginLib, "ncclNetPlugin_v8"); + if (ncclNet_v8 == nullptr) { + // Try v7 plugin + ncclNet_v7 = (ncclNet_v7_t*)dlsym(netPluginLib, "ncclNetPlugin_v7"); + if (ncclNet_v7 == nullptr) { + // Try v6 plugin + ncclNet_v6 = (ncclNet_v6_t*)dlsym(netPluginLib, "ncclNetPlugin_v6"); + if (ncclNet_v6 == nullptr) { + // Try v5 plugin + ncclNet_v5 = (ncclNet_v5_t*)dlsym(netPluginLib, "ncclNetPlugin_v5"); + if (ncclNet_v5 == nullptr) { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin symbol (>= v5). ncclNetPlugin symbols v4 and lower are not supported."); + goto fail; + } else { + ncclNets[0] = &ncclNet_v5_as_v9; + ncclNet_v5_as_v9.init = ncclNet_v5_as_v9_init; + // Set the name right away to allow for NCCL_NET=... to work + ncclNet_v5_as_v9.name = ncclNet_v5->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v5)", ncclNets[0]->name); + } } else { - ncclNets[0] = &ncclNet_v5_as_v8; - ncclNet_v5_as_v8.init = ncclNet_v5_as_v8_init; + ncclNets[0] = &ncclNet_v6_as_v9; + ncclNet_v6_as_v9.init = ncclNet_v6_as_v9_init; // Set the name right away to allow for NCCL_NET=... to work - ncclNet_v5_as_v8.name = ncclNet_v5->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v5)", ncclNets[0]->name); + ncclNet_v6_as_v9.name = ncclNet_v6->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v6)", ncclNets[0]->name); } } else { - ncclNets[0] = &ncclNet_v6_as_v8; - ncclNet_v6_as_v8.init = ncclNet_v6_as_v8_init; + ncclNets[0] = &ncclNet_v7_as_v9; + ncclNet_v7_as_v9.init = ncclNet_v7_as_v9_init; // Set the name right away to allow for NCCL_NET=... to work - ncclNet_v6_as_v8.name = ncclNet_v6->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v6)", ncclNets[0]->name); + ncclNet_v7_as_v9.name = ncclNet_v7->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v7)", ncclNets[0]->name); } } else { - ncclNets[0] = &ncclNet_v7_as_v8; - ncclNet_v7_as_v8.init = ncclNet_v7_as_v8_init; + ncclNets[0] = &ncclNet_v8_as_v9; + ncclNet_v8_as_v9.init = ncclNet_v8_as_v9_init; // Set the name right away to allow for NCCL_NET=... to work - ncclNet_v7_as_v8.name = ncclNet_v7->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v7)", ncclNets[0]->name); + ncclNet_v8_as_v9.name = ncclNet_v8->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v8)", ncclNets[0]->name); } + } else { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v9)", ncclNets[0]->name); } // Check for CollNet - ncclCollNets[0] = (ncclCollNet_v8_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v8"); + ncclCollNets[0] = (ncclCollNet_v9_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v9"); if (ncclCollNets[0] == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin_v8 symbol."); - ncclCollNet_v7 = (ncclCollNet_v7_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v7"); - if (ncclCollNet_v7 == nullptr) { - ncclCollNet_v6 = (ncclCollNet_v6_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v6"); - if (ncclCollNet_v6 == nullptr) { - ncclCollNet_v5 = (ncclCollNet_v5_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v5"); - if (ncclCollNet_v5 == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin symbol (>= v5). ncclCollNetPlugin symbols v4 and lower are not supported."); + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin_v9 symbol."); + ncclCollNet_v8 = (ncclCollNet_v8_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v8"); + if (ncclCollNet_v8 == nullptr) { + ncclCollNet_v7 = (ncclCollNet_v7_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v7"); + if (ncclCollNet_v7 == nullptr) { + ncclCollNet_v6 = (ncclCollNet_v6_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v6"); + if (ncclCollNet_v6 == nullptr) { + ncclCollNet_v5 = (ncclCollNet_v5_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v5"); + if (ncclCollNet_v5 == nullptr) { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin symbol (>= v5). ncclCollNetPlugin symbols v4 and lower are not supported."); + } else { + ncclCollNets[0] = &ncclCollNet_v5_as_v9; + ncclCollNet_v5_as_v9.init = ncclCollNet_v5_as_v9_init; + ncclCollNet_v5_as_v9.name = ncclCollNet_v5->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v5)", ncclCollNets[0]->name); + } } else { - ncclCollNets[0] = &ncclCollNet_v5_as_v8; - ncclCollNet_v5_as_v8.init = ncclCollNet_v5_as_v8_init; - ncclCollNet_v5_as_v8.name = ncclCollNet_v5->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v5)", ncclCollNets[0]->name); + ncclCollNets[0] = &ncclCollNet_v6_as_v9; + ncclCollNet_v6_as_v9.init = ncclCollNet_v6_as_v9_init; + ncclCollNet_v6_as_v9.name = ncclCollNet_v6->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v6)", ncclCollNets[0]->name); } } else { - ncclCollNets[0] = &ncclCollNet_v6_as_v8; - ncclCollNet_v6_as_v8.init = ncclCollNet_v6_as_v8_init; - ncclCollNet_v6_as_v8.name = ncclCollNet_v6->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v6)", ncclCollNets[0]->name); + ncclCollNets[0] = &ncclCollNet_v7_as_v9; + ncclCollNet_v7_as_v9.init = ncclCollNet_v7_as_v9_init; + ncclCollNet_v7_as_v9.name = ncclCollNet_v7->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v7)", ncclCollNets[0]->name); } } else { - ncclCollNets[0] = &ncclCollNet_v7_as_v8; - ncclCollNet_v7_as_v8.init = ncclCollNet_v7_as_v8_init; - ncclCollNet_v7_as_v8.name = ncclCollNet_v7->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v7)", ncclCollNets[0]->name); + ncclCollNets[0] = &ncclCollNet_v8_as_v9; + ncclCollNet_v8_as_v9.init = ncclCollNet_v8_as_v9_init; + ncclCollNet_v8_as_v9.name = ncclCollNet_v8->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v8)", ncclCollNets[0]->name); } + } else { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded collnet plugin %s (v9)", ncclCollNets[0]->name); } ++netPluginRefCount; @@ -540,6 +848,8 @@ ncclResult_t ncclNetPluginUnload(struct ncclComm* comm) { ncclCollNets[0] = nullptr; netPluginStatus = netPluginLoadReady; comm->netPluginLoaded = 0; + for (int i = 0; i < NCCL_NET_MAX_PLUGINS; ++i) + ncclCollNetStates[i] = ncclNetStates[i] = ncclNetStateInit; } pthread_mutex_unlock(&netPluginLock); return ncclSuccess; @@ -562,7 +872,7 @@ ncclResult_t ncclNetCheckDeviceVersion(struct ncclComm* comm, ncclNet_t* net, in return ncclInternalError; } default: - WARN("Unknown device code index"); + WARN("Unknown device code index %d \n", type); return ncclInternalError; } @@ -720,8 +1030,9 @@ cleanup1: int ncclNetVersion(struct ncclComm* comm) { return - (comm->ncclNet == &ncclNet_v5_as_v8) ? 5 : - (comm->ncclNet == &ncclNet_v6_as_v8) ? 6 : - (comm->ncclNet == &ncclNet_v7_as_v8) ? 7 : - 8; + (comm->ncclNet == &ncclNet_v5_as_v9) ? 5 : + (comm->ncclNet == &ncclNet_v6_as_v9) ? 6 : + (comm->ncclNet == &ncclNet_v7_as_v9) ? 7 : + (comm->ncclNet == &ncclNet_v8_as_v9) ? 8 : + 9; } diff --git a/src/proxy.cc b/src/proxy.cc index a0b86889ca..67cf6cfd1b 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -378,7 +378,11 @@ static ncclResult_t ncclProxyOpToArgs(struct ncclProxyOp* op, struct ncclProxyAr sub->channelId = op->channelId; sub->nsteps = op->nsteps; sub->nbytes = op->nbytes; + sub->chunkSize = op->chunkSize; sub->offset = 0; + sub->loopSize = op->loopSize; + sub->loopOffset = op->loopOffset; + sub->isOneRPN = op->isOneRPN; sub->peer = op->peer; sub->reg = op->reg; sub->sendMhandle = op->sendMhandle; @@ -388,8 +392,9 @@ static ncclResult_t ncclProxyOpToArgs(struct ncclProxyOp* op, struct ncclProxyAr sub->eActivationMask = op->eActivationMask; sub->taskEventHandle = op->taskEventHandle; sub->rank = op->rank; - args->pid = op->pid; - args->profilerContext = op->profilerContext; + sub->pid = op->pid; + sub->profilerContext = op->profilerContext; + sub->ringAlgo = op->ringAlgo; args->nsubs = subIndex+1; if (subIndex) { if ((args->sliceSteps != op->sliceSteps) || @@ -418,6 +423,7 @@ static ncclResult_t ncclProxyOpToArgs(struct ncclProxyOp* op, struct ncclProxyAr args->pattern = op->pattern; args->protocol = op->protocol; args->coll = op->coll; + args->algorithm = op->algorithm; args->specifics = op->specifics; args->state = ncclProxyOpReady; args->progress = op->connection->tcomm->proxyProgress; @@ -503,6 +509,7 @@ static ncclResult_t ncclLocalOpAppend(struct ncclComm* comm, struct ncclProxyCon } if (op->next != -1) __builtin_prefetch(pool->ops+op->next); // Prefetch next free op memcpy(op, proxyOp, sizeof(struct ncclProxyOp)); + if (proxyOp->ringAlgo) proxyOp->ringAlgo->incRefCount(); op->next = -1; op->connection = proxyConn->connection; if (proxyOps->nextOps == -1) { @@ -628,13 +635,15 @@ ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* op, bool } break; case ncclPatternPatUp: { // Run full algorithm to count the number of steps for each peer. - int *nstepsSend, *nstepsRecv; - const int rank = comm->rank, nranks = comm->nRanks; - NCCLCHECK(ncclCalloc(&nstepsSend, log2Up(nranks))); - NCCLCHECK(ncclCalloc(&nstepsRecv, log2Up(nranks))); + ncclResult_t result = ncclSuccess; const ssize_t size = op->nbytes/comm->nRanks; - PatRSAlgorithm algo(op->chunkSize, NCCL_STEPS, 0, size, size, op->chunkSize, rank, nranks); int last = 0; + int *nstepsSend = NULL, *nstepsRecv = NULL; + const int rank = comm->rank, nranks = comm->nRanks; + PatRSAlgorithm algo(op->chunkSize, NCCL_STEPS, 0, size, size, op->chunkSize, rank, nranks); + NCCLCHECKGOTO(ncclCalloc(&nstepsSend, log2Up(nranks)), result, exit_pat_up); + NCCLCHECKGOTO(ncclCalloc(&nstepsRecv, log2Up(nranks)), result, exit_pat_up); + while (last == 0) { int recvDim, sendDim, recvOffset, sendOffset, sendStepOffset, postRecv, postSend, nelem; size_t inpIx, outIx; @@ -646,24 +655,30 @@ ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* op, bool if (nstepsSend[i]) { int sendPeer = (rank + (1<nsteps = nstepsSend[i]; - NCCLCHECK(SaveProxy(comm, channel, proxySend, sendPeer, op, 0, justInquire)); + NCCLCHECKGOTO(SaveProxy(comm, channel, proxySend, sendPeer, op, 0, justInquire), result, exit_pat_up); } if (nstepsRecv[i]) { int recvPeer = (rank - (1<nsteps = nstepsRecv[i]; - NCCLCHECK(SaveProxy(comm, channel, proxyRecv, recvPeer, op, 0, justInquire)); + NCCLCHECKGOTO(SaveProxy(comm, channel, proxyRecv, recvPeer, op, 0, justInquire), result, exit_pat_up); } } + exit_pat_up: + free(nstepsSend); + free(nstepsRecv); + NCCLCHECK(result); } break; case ncclPatternPatDown: { // Run full algorithm to count the number of steps for each peer. - int *nstepsSend, *nstepsRecv; - const int rank = comm->rank, nranks = comm->nRanks; - NCCLCHECK(ncclCalloc(&nstepsSend, log2Up(nranks))); - NCCLCHECK(ncclCalloc(&nstepsRecv, log2Up(nranks))); + ncclResult_t result = ncclSuccess; const ssize_t size = op->nbytes/comm->nRanks; - PatAGAlgorithm algo(op->chunkSize, NCCL_STEPS, 0, size, size, op->chunkSize, rank, nranks); int last = 0; + int *nstepsSend = NULL, *nstepsRecv = NULL; + const int rank = comm->rank, nranks = comm->nRanks; + PatAGAlgorithm algo(op->chunkSize, NCCL_STEPS, 0, size, size, op->chunkSize, rank, nranks); + NCCLCHECKGOTO(ncclCalloc(&nstepsSend, log2Up(nranks)), result, exit_pat_down); + NCCLCHECKGOTO(ncclCalloc(&nstepsRecv, log2Up(nranks)), result, exit_pat_down); + while (last == 0) { int recvDim, sendDim, recvOffset, sendOffset, recvStepOffset, postRecv, postSend, nelem; size_t inpIx, outIx; @@ -675,14 +690,18 @@ ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* op, bool if (nstepsSend[i]) { int sendPeer = (rank - (1<nsteps = nstepsSend[i]; - NCCLCHECK(SaveProxy(comm, channel, proxySend, sendPeer, op, 0, justInquire)); + NCCLCHECKGOTO(SaveProxy(comm, channel, proxySend, sendPeer, op, 0, justInquire), result, exit_pat_down); } if (nstepsRecv[i]) { int recvPeer = (rank + (1<nsteps = nstepsRecv[i]; - NCCLCHECK(SaveProxy(comm, channel, proxyRecv, recvPeer, op, 0, justInquire)); + NCCLCHECKGOTO(SaveProxy(comm, channel, proxyRecv, recvPeer, op, 0, justInquire), result, exit_pat_down); } } + exit_pat_down: + free(nstepsSend); + free(nstepsRecv); + NCCLCHECK(result); } break; case ncclPatternSend: case ncclPatternRecv: { @@ -764,23 +783,17 @@ static ncclResult_t ncclProxyGetPostedOps(struct ncclProxyState* proxyState, int if (state->active == NULL) { pthread_mutex_lock(&pool->mutex); - while (pool->nextOps == -1 && !state->stop) { + if (pool->nextOps == -1 && !state->stop) { ncclProfilerStartProxyCtrlEvent(proxyState->profilerContext, &eHandle); ncclProfilerRecordProxyCtrlEventState(eHandle, 0, ncclProfilerProxyCtrlSleep); pthread_cond_wait(&pool->cond, &pool->mutex); ncclProfilerRecordProxyCtrlEventState(eHandle, 0, ncclProfilerProxyCtrlWakeup); ncclProfilerStopProxyCtrlEvent(eHandle); } - if (state->stop) { // We might have been woken up to stop. - pthread_mutex_unlock(&pool->mutex); - return ncclSuccess; - } } - state->nextOps = pool->nextOps; pool->nextOps = pool->nextOpsEnd = -1; pthread_mutex_unlock(&pool->mutex); - if (state->nextOps == -1) return ncclInternalError; process_nextops: ncclProfilerStartProxyCtrlEvent(proxyState->profilerContext, &eHandle); @@ -919,7 +932,7 @@ void* ncclProxyProgress(void *proxyState_) { * ncclParamProgressAppendOpFreq(). If they are equal, we will append proxy ops. This will decrease the * frequency of calling ncclProxyGetPostedOps() and reduce the perf impact. */ int proxyOpAppendCounter = 0; - while (state->stop == 0 || (state->stop == 1 && state->active)) { + do { int idle = 1; ncclResult_t ret = progressOps(proxyState, state, state->active, &idle); if (ret != ncclSuccess) { @@ -932,12 +945,11 @@ void* ncclProxyProgress(void *proxyState_) { if (lastIdle == 0 && idle == 1) ncclProfilerRecordProxyCtrlEventState(eHandle, 0, ncclProfilerProxyCtrlIdle); if (lastIdle == 1 && idle == 0) ncclProfilerRecordProxyCtrlEventState(eHandle, 0, ncclProfilerProxyCtrlActive); ncclProfilerStopProxyCtrlEvent(eHandle); - if (idle || (++proxyOpAppendCounter == ncclParamProgressAppendOpFreq())) { + if (idle || !state->active || (++proxyOpAppendCounter == ncclParamProgressAppendOpFreq())) { int added = 0; proxyOpAppendCounter = 0; TIME_START(3); - if (state->stop == 0) - ret = ncclProxyGetPostedOps(proxyState, &added); + ret = ncclProxyGetPostedOps(proxyState, &added); if (added) { TIME_STOP(3); } else { TIME_CANCEL(3); } if (ret != ncclSuccess) { __atomic_store_n(&proxyState->asyncResult, ret, __ATOMIC_RELEASE); @@ -948,7 +960,7 @@ void* ncclProxyProgress(void *proxyState_) { } } lastIdle = idle; - } + } while (state->stop == 0 || (state->stop == 1 && state->active)); return NULL; } @@ -1120,7 +1132,7 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in strncpy(poolPath+sizeof("/dev/shm/nccl-")-1, resp.devShmPath, sizeof("XXXXXX")-1); struct ncclProxyOps* proxyOps = sharedProxyState->proxyOps + proxyConn->tpLocalRank; if (proxyOps->pool == NULL) { - NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, -1, &proxyOps->handle)); + NCCLCHECK(ncclShmOpen(poolPath, sizeof(poolPath), sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, -1, &proxyOps->handle)); proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1; } } @@ -1323,7 +1335,7 @@ static ncclResult_t proxyProgressInit(struct ncclProxyState* proxyState) { char shmPath[sizeof("/dev/shm/nccl-XXXXXX")]; shmPath[0] = '\0'; - NCCLCHECK(ncclShmOpen(shmPath, size, (void**)&pool, NULL, proxyState->tpLocalnRanks, &state->handle)); + NCCLCHECK(ncclShmOpen(shmPath, sizeof(shmPath), size, (void**)&pool, NULL, proxyState->tpLocalnRanks, &state->handle)); // Init pool pool->nextOps = -1; @@ -1402,7 +1414,7 @@ static ncclResult_t proxyQueryFd(struct ncclProxyState* proxyState, int rank, vo ncclResult_t ret = ncclSuccess; NCCLCHECKGOTO(ncclIpcSocketInit(&ipcSock, proxyState->tpRank, hash^1, proxyState->abortFlag), ret, exit); - NCCLCHECKGOTO(ncclIpcSocketSendMsg(&ipcSock, &rmtFd, sizeof(int), rmtFd, rank, hash), ret, exit); + NCCLCHECKGOTO(ncclIpcSocketSendMsg(&ipcSock, &rmtFd, sizeof(int), -1, rank, hash), ret, exit); exit: NCCLCHECK(ncclIpcSocketClose(&ipcSock)); return ncclSuccess; @@ -1634,7 +1646,7 @@ void* ncclProxyService(void* _args) { if (pollfds[s].fd == -1) continue; // Progress all ops for this ncclProxyLocalPeer - if (stop == PROXY_ABORT && ncclCuMemEnable() && ncclCuMemHostEnable() && !proxyState->directMode) closeConn = 1; + if (stop == PROXY_ABORT && ncclCuMemEnable() && ncclCuMemHostEnable() && !proxyState->directMode && __atomic_load_n(&proxyState->stop, __ATOMIC_ACQUIRE)) closeConn = 1; ncclProxyAsyncOp* op = peer->asyncOps; while (op != nullptr) { ncclProxyAsyncOp* opnext = op->next; /* in case op is freed in proxyProgressAsync */ @@ -1724,11 +1736,17 @@ static ncclResult_t proxyUDSRecvReq(struct ncclProxyState* proxyState, int reqFd NCCLCHECK(ncclIpcSocketRecvMsg(&proxyState->ipcSock, &hdr, sizeof(hdr), &rmtFd)); if (hdr.type == ncclProxyMsgGetFd) { - // cuMem API support + // cuMem API support for non-UB case, and rmtFd is not used since UDS proxy thread need to export + // fd from handle and send it back to the main thread to import the buffer. We just need to close + // this dummy rmtFd. uint64_t handle = *(uint64_t*)hdr.data; INFO(NCCL_PROXY, "proxyUDSRecvReq::ncclProxyMsgGetFd rank %d opId %p handle=0x%lx", hdr.rank, hdr.opId, handle); + close(rmtFd); return proxyGetFd(proxyState, hdr.rank, hdr.opId, handle); } else if (hdr.type == ncclProxyMsgQueryFd) { + // remote main thread registers buffer into this rank, it querys rmtFd of this rank through UDS + // and the rmtFd is returned unchanged back to remote main thread which will use rmtFd to call into + // proxy service thread for buffer registration. INFO(NCCL_PROXY, "proxyUDSRecvReq::proxyQueryFd rank %d opId %p rmtFd %d", hdr.rank, hdr.opId, rmtFd); return proxyQueryFd(proxyState, hdr.rank, hdr.opId, rmtFd); } @@ -1775,7 +1793,7 @@ void* ncclProxyServiceUDS(void* _args) { } } - ncclIpcSocketClose(&proxyState->ipcSock); + (void)ncclIpcSocketClose(&proxyState->ipcSock); INFO(NCCL_PROXY, "[Proxy Service UDS] exit: stop %d abortFlag %d", proxyState->stop, *proxyState->abortFlag); return NULL; } @@ -1832,15 +1850,10 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { struct ncclProxyState* sharedProxyState = comm->proxyState; if ((comm->proxyRefCountOld = ncclAtomicRefCountDecrement(&sharedProxyState->refCount)) == 0) { - if (comm->proxyState->threadUDS) { - // UDS support - __atomic_store_n(&comm->proxyState->stop, 1, __ATOMIC_RELEASE); - } - if (*comm->abortFlag == 0 && sharedProxyState->peerAddresses) { struct ncclSocket sock; int type = ncclProxyMsgStop; - ncclSocketInit(&sock, sharedProxyState->peerAddresses + comm->topParentRanks[comm->rank], comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag); + NCCLCHECK(ncclSocketInit(&sock, sharedProxyState->peerAddresses + comm->topParentRanks[comm->rank], comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); if (ncclSocketConnect(&sock) == ncclSuccess) { (void)ncclSocketSend(&sock, &type, sizeof(int)); } @@ -1867,6 +1880,8 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { } } } + // Now we notify proxy service and UDS thread to exit. + __atomic_store_n(&comm->proxyState->stop, 1, __ATOMIC_RELEASE); } } diff --git a/src/ras/client.cc b/src/ras/client.cc new file mode 100644 index 0000000000..8061cef4e6 --- /dev/null +++ b/src/ras/client.cc @@ -0,0 +1,318 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nccl.h" +#define NCCL_RAS_CLIENT // Only pull client-specific definitions from the header file below. +#include "ras_internal.h" + +#define STR2(v) #v +#define STR(v) STR2(v) + +// Local timeout increment compared to the '-t' argument, in seconds. +#define TIMEOUT_INCREMENT 1 + +static const char* hostName = "localhost"; +static const char* port = STR(NCCL_RAS_CLIENT_PORT); +static int timeout = -1; +static bool verbose = false; +static int sock = -1; + +static void printUsage(const char* argv0) { + fprintf(stderr, + "Usage: %s [OPTION]...\n" + "Query the state of a running NCCL job.\n" + "\nOptions:\n" + " -h, --host=HOST Host name or IP address of the RAS client socket of the\n" + " NCCL job to connect to (localhost by default)\n" + " -p, --port=PORT TCP port of the RAS client socket of the NCCL job\n" + " (" STR(NCCL_RAS_CLIENT_PORT) " by default)\n" + " -t, --timeout=SECS Maximum time for the local NCCL process to wait for\n" + " responses from other NCCL processes\n" + " (" STR(RAS_COLLECTIVE_LEG_TIMEOUT_SEC) " secs by default; 0 disables the timeout)\n" + " -v, --verbose Increase the verbosity level of the RAS output\n" + " --help Print this help and exit\n" + " --version Print the version number and exit\n", argv0); +} + +static void parseArgs(int argc, char** argv) { + int c; + int optIdx = 0; + struct option longOpts[] = { + {"host", required_argument, NULL, 'h'}, + {"port", required_argument, NULL, 'p'}, + {"timeout", required_argument, NULL, 't'}, + {"verbose", no_argument, NULL, 'v'}, + {"help", no_argument, NULL, 'e'}, + {"version", no_argument, NULL, 'r'}, + {0} + }; + + while ((c = getopt_long(argc, argv, "h:p:t:v", longOpts, &optIdx)) != -1) { + switch (c) { + case 'h': + hostName = optarg; + break; + case 'p': + port = optarg; + break; + case 't': { + char* endPtr = nullptr; + timeout = strtol(optarg, &endPtr, 10); + if (timeout < 0 || !endPtr || *endPtr != '\0') { + fprintf(stderr, "Invalid timeout: %s\n", optarg); + exit(1); + } + break; + } + case 'v': + verbose = true; + break; + case 'e': + printUsage(argv[0]); + exit(0); + case 'r': + fprintf(stderr, "NCCL RAS client version " STR(NCCL_MAJOR) "." STR(NCCL_MINOR) "." + STR(NCCL_PATCH) NCCL_SUFFIX "\n"); + exit(0); + default: + printUsage(argv[0]); + exit(1); + } + } +} + +static ssize_t socketWrite(int fd, const void* buf, size_t count) { + size_t done = 0; + do { + ssize_t ret; + ret = write(fd, ((const char*)buf)+done, count-done); + if (ret == -1) { + if (errno != EINTR) + return -1; + continue; + } + done += ret; + } while (done < count); + + return done; +} + +// Reads a message from RAS. Assumes that the message ends with '\n' (will continue reading until the terminating +// newline, unless false is passed as untilNewLine). +// Terminates the buffer with '\0'. Returns the number of bytes read (excluding the added terminating '\0'). +static ssize_t rasRead(int fd, void* buf, size_t count, bool untilNewline = true) { + char* bufChar = (char*)buf; + size_t done = 0; + do { + ssize_t ret; + ret = read(fd, bufChar+done, count-1-done); + if (ret == -1) { + if (errno != EINTR) + return -1; + continue; + } + if (ret == 0) + break; // EOF + done += ret; + } while (untilNewline && (done == 0 || bufChar[done-1] != '\n')); + bufChar[done] = '\0'; + + return done; +} + +static int connectToNCCL() { + struct addrinfo hints = {0}; + struct addrinfo* addrInfo = nullptr; + int ret; + char msgBuf[1024]; + int bytes; + struct timeval tv = {TIMEOUT_INCREMENT, 0}; + +retry: + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + if ((ret = getaddrinfo(hostName, port, &hints, &addrInfo)) != 0) { + fprintf(stderr, "Resolving %s:%s: %s\n", hostName, port, gai_strerror(ret)); + goto fail; + } + for (struct addrinfo* ai = addrInfo; ai; ai = ai->ai_next) { + char hostBuf[NI_MAXHOST], portBuf[NI_MAXSERV]; + int err; + sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (sock == -1) { + perror("socket"); + continue; + } + // Initially start with a small, 1-sec timeout to quickly eliminate non-responsive processes... + if (timeout && (setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof tv) != 0 || + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv) != 0)) { + perror("setsockopt"); + // Non-fatal; fall through. + } + if (connect(sock, ai->ai_addr, ai->ai_addrlen) == 0) + break; + err = errno; + if (getnameinfo(ai->ai_addr, ai->ai_addrlen, hostBuf, sizeof(hostBuf), portBuf, sizeof(portBuf), + NI_NUMERICHOST | NI_NUMERICSERV) != 0) { + strcpy(hostBuf, hostName); + strcpy(portBuf, port); + } + fprintf(stderr, "Connecting to %s:%s: %s\n", hostBuf, portBuf, strerror(err)); + close(sock); + sock = -1; + } + freeaddrinfo(addrInfo); + addrInfo = nullptr; + + if (sock == -1) { + fprintf(stderr, "Failed to connect to the NCCL RAS service!\n" + "Please make sure that the NCCL job has the RAS service enabled and that\n" + "%s.\n", + (strcmp(hostName, "localhost") || strcmp(port, STR(NCCL_RAS_CLIENT_PORT)) ? + "the host/port arguments are correct and match NCCL_RAS_ADDR" : + "the RAS client was started on a node where the NCCL job is running")); + goto fail; + } + + // Exchange the RAS client handshake. + strcpy(msgBuf, "CLIENT PROTOCOL " STR(NCCL_RAS_CLIENT_PROTOCOL) "\n"); + if (socketWrite(sock, msgBuf, strlen(msgBuf)) != strlen(msgBuf)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + goto timeout; + } + perror("write to socket"); + goto fail; + } + bytes = rasRead(sock, msgBuf, sizeof(msgBuf)); + if (bytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + goto timeout; + } + perror("read socket"); + goto fail; + } + if (bytes == 0) { + fprintf(stderr, "NCCL unexpectedly closed the connection\n"); + goto fail; + } + if (strncasecmp(msgBuf, "SERVER PROTOCOL ", strlen("SERVER PROTOCOL "))) { + fprintf(stderr, "Unexpected response from NCCL: %s\n", msgBuf); + goto fail; + } + if (strtol(msgBuf+strlen("SERVER PROTOCOL "), nullptr, 10) != NCCL_RAS_CLIENT_PROTOCOL) { + fprintf(stderr, "NCCL RAS protocol version mismatch (NCCL: %s; RAS client: %d)!\n" + "Will try to continue in spite of that...\n", msgBuf+strlen("SERVER PROTOCOL "), NCCL_RAS_CLIENT_PROTOCOL); + } + + if (timeout >= 0) { + snprintf(msgBuf, sizeof(msgBuf), "TIMEOUT %d\n", timeout); + if (socketWrite(sock, msgBuf, strlen(msgBuf)) != strlen(msgBuf)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + goto timeout; + } + perror("write to socket"); + goto fail; + } + bytes = rasRead(sock, msgBuf, sizeof(msgBuf)); + if (bytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + goto timeout; + } + perror("read socket"); + goto fail; + } + if (bytes == 0) { + fprintf(stderr, "NCCL unexpectedly closed the connection\n"); + goto fail; + } + if (strcasecmp(msgBuf, "OK\n")) { + fprintf(stderr, "Unexpected response from NCCL: %s\n", msgBuf); + goto fail; + } + } + if (timeout) { + // Increase the socket timeout to accommodate NCCL timeout. + tv.tv_sec += (timeout > 0 ? timeout : RAS_COLLECTIVE_LEG_TIMEOUT_SEC) + RAS_COLLECTIVE_EXTRA_TIMEOUT_SEC; + if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv) != 0) { + perror("setsockopt"); + // Non-fatal; fall through. + } + } + + return 0; +fail: + if (addrInfo) + freeaddrinfo(addrInfo); + if (sock != -1) + (void)close(sock); + return 1; +timeout: + fprintf(stderr, "Connection timed out; retrying...\n"); + (void)close(sock); + goto retry; +} + +int getNCCLStatus() { + char msgBuf[4096]; + int bytes; + snprintf(msgBuf, sizeof(msgBuf), "%sSTATUS\n", (verbose ? "VERBOSE " : "")); + if (socketWrite(sock, msgBuf, strlen(msgBuf)) != strlen(msgBuf)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + fprintf(stderr, "Connection timed out\n"); + else + perror("write to socket"); + return 1; + } + for (;;) { + bytes = rasRead(sock, msgBuf, sizeof(msgBuf), /*untileNewLine*/false); + if (bytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + fprintf(stderr, "Connection timed out\n"); + else + perror("read socket"); + return 1; + } + if (bytes == 0) // EOF + break; + if (fwrite(msgBuf, 1, bytes, stdout) != bytes) { + fprintf(stderr, "fwrite to stdout failed!\n"); + return 1; + } + if (fflush(stdout) != 0) { + perror("fflush stdout"); + return 1; + } + } + return 0; +} + +int main(int argc, char** argv) { + parseArgs(argc, argv); + + if (connectToNCCL()) + return 1; + + if (getNCCLStatus()) { + (void)close(sock); + return 1; + } + + if (close(sock) == -1) { + perror("close socket"); + return 1; + } + return 0; +} diff --git a/src/ras/client_support.cc b/src/ras/client_support.cc new file mode 100644 index 0000000000..414a1ed94f --- /dev/null +++ b/src/ras/client_support.cc @@ -0,0 +1,1755 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#define NDEBUG // Comment out duriyng development only! +#include +#include +#include + +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "nccl.h" +#include "utils.h" +#include "ras_internal.h" + +// Outlier count above which we don't print individual details about each of them. +#define RAS_CLIENT_DETAIL_THRESHOLD 10 +// Fraction of the count of the total above which we don't consider another set to be an outlier. +#define RAS_CLIENT_OUTLIER_FRACTION 0.25 +// Fraction of the count of the total below which a set is considered to be an outlier. +#define RAS_CLIENT_VERBOSE_OUTLIER_FRACTION 0.5 + +#define STR2(v) #v +#define STR(v) STR2(v) + +// The RAS client listening socket of this RAS thread (normally port 28028). +int rasClientListeningSocket = -1; + +// Auxiliary structure used when processing the results. Helps with statistics gathering and sorting. +struct rasValCount { + uint64_t value; // The observed value. + int count; // The number of occurences of this value in the results. + int firstIdx; // The index of the first occurence of this value in the results. +}; + +// Used in rasAuxComm below. The values are bitmasks so that they can be combined. +typedef enum { + RAS_ACS_UNKNOWN = 1, // Set if a peer did not provide info about a given communicator. + RAS_ACS_INIT = 2, + RAS_ACS_RUNNING = 4, + RAS_ACS_FINALIZE = 8, + RAS_ACS_ABORT = 16 +} rasACStatus; + +// Used in rasAuxComm below. The values are bitmasks so that they can be combined (with the exception of RAS_ACE_OK). +typedef enum { + RAS_ACE_OK = 0, + RAS_ACE_MISMATCH = 1, + RAS_ACE_ERROR = 2, + RAS_ACE_INCOMPLETE = 4 +} rasACError; + +// Auxiliary structure used when processing the results. Helps with sorting and includes additional statistics +// on the number of peers and nodes for a communicator. +struct rasAuxComm { + struct rasCollComms::comm* comm; + int nPeers; + int nNodes; + int ranksPerNodeMin; + int ranksPerNodeMax; + unsigned int status; // Bitmask of rasACStatus values. + unsigned int errors; // Bitmask of rasACError values. + uint64_t firstCollOpCount; // collOpCount of the first rank, to compare against. +}; + +// Connected RAS clients. +struct rasClient* rasClients; +int nRasClients; + +// Minimum byte count to increment the output buffer size by if it's too small. +#define RAS_OUT_INCREMENT 4096 + +// Internal buffer for storing the formatted results. +static char* rasOutBuffer = nullptr; +static int nRasOutBuffer = 0; // Does _not_ include the terminating '\0' (which _is_ present in the buffer). +static int rasOutBufferSize = 0; + +// We use them all over the place; no point in wasting the stack... +static char lineBuf[1024]; // Temporary buffer used for printing at most 10 (RAS_CLIENT_DETAIL_THRESHOLD) rank numbers + // or for printing the local GPU devices, which can't be more than 64 (NCCL_MAX_LOCAL_RANKS) + // small numbers (times two if the NVML mask is different than the CUDA mask). + // Still, 1024 should normally be plenty (verbose output may make things more difficult, + // but we do check for overflows, so it will just be trimmed). + +static ncclResult_t getNewClientEntry(struct rasClient** pClient); +static void rasClientEnqueueMsg(struct rasClient* client, char* msg, size_t msgLen); +static void rasClientTerminate(struct rasClient* client); + +static ncclResult_t rasClientRun(struct rasClient* client); +static ncclResult_t rasClientRunInit(struct rasClient* client); +static ncclResult_t rasClientRunConns(struct rasClient* client); +static ncclResult_t rasClientRunComms(struct rasClient* client); +static void rasClientBreakDownErrors(struct rasClient* client, struct rasCollComms::comm* comm, + const int* peerIdxConv, int ncclErrors[ncclNumResults], bool isAsync = false); + +static void rasOutAppend(const char* format, ...) __attribute__ ((format(printf, 1, 2))); +static void rasOutExtract(char* buffer); +static int rasOutLength(); +static void rasOutReset(); + +static int rasPeersNGpuCompare(const void* e1, const void* e2); +static int rasPeersNProcsCompare(const void* e1, const void* e2); +static int rasPeersHostPidCompare(const void* e1, const void* e2); +static int ncclSocketsHostCompare(const void* p1, const void* p2); +static int rasValCountsCompareRev(const void* p1, const void* p2); +static int rasAuxCommsCompareRev(const void* p1, const void* p2); +static int rasCommRanksPeerCompare(const void* p1, const void* p2); +static int rasCommRanksCollOpCompare(const void* p1, const void* p2); + +static const char* rasCommRankGpuToString(const struct rasCollComms::comm::rank* rank, char* buf, size_t size); +static const char* ncclErrorToString(ncclResult_t err); +static const char* ncclSocketToHost(const union ncclSocketAddress* addr, char* buf, size_t size); +static bool rasCountIsOutlier(int count, bool verbose, int totalCount = -1); + + +/////////////////////////////////// +// General rasClients functions. // +/////////////////////////////////// + +// Creates a listening socket for clients to connect to. +ncclResult_t rasClientInitSocket() { + ncclResult_t ret = ncclSuccess; + const char* clientAddr = "localhost:" STR(NCCL_RAS_CLIENT_PORT); + union ncclSocketAddress addr; + const int opt = 1; + if (const char* env = ncclGetEnv("NCCL_RAS_ADDR")) + clientAddr = env; + NCCLCHECKGOTO(ncclSocketGetAddrFromString(&addr, clientAddr), ret, fail); + SYSCHECKGOTO(rasClientListeningSocket = socket(addr.sa.sa_family, SOCK_STREAM, 0), "socket", ret, fail); + SYSCHECKGOTO(setsockopt(rasClientListeningSocket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), + "setsockopt", ret, fail); +#if defined(SO_REUSEPORT) + SYSCHECKGOTO(setsockopt(rasClientListeningSocket, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), + "setsockopt", ret, fail); +#endif + SYSCHECKGOTO(bind(rasClientListeningSocket, &addr.sa, (addr.sa.sa_family == AF_INET ? sizeof(struct sockaddr_in) : + sizeof(struct sockaddr_in6))), "bind", ret, fail); + SYSCHECKGOTO(listen(rasClientListeningSocket, 16384), "listen", ret, fail); + INFO(NCCL_INIT|NCCL_RAS, "RAS client listening socket at %s", ncclSocketToString(&addr, rasLine)); +exit: + return ret; +fail: + INFO(NCCL_INIT|NCCL_RAS, "RAS failed to establish a client listening socket at %s", clientAddr); + if (rasClientListeningSocket != -1) { + (void)close(rasClientListeningSocket); + rasClientListeningSocket = -1; + } + goto exit; +} + +// Accepts a new RAS client connection. The acceptance process may need to continue in the main event loop. +ncclResult_t rasClientAcceptNewSocket() { + ncclResult_t ret = ncclSuccess; + struct rasClient* client = nullptr; + union ncclSocketAddress addr; + socklen_t addrlen = sizeof(addr); + int flags; + + NCCLCHECKGOTO(getNewClientEntry(&client), ret, fail); + + SYSCHECKGOTO(client->sock = accept(rasClientListeningSocket, (struct sockaddr*)&addr, &addrlen), "accept", ret, fail); + + SYSCHECKGOTO(flags = fcntl(client->sock, F_GETFL), "fcntl", ret, fail); + SYSCHECKGOTO(fcntl(client->sock, F_SETFL, flags | O_NONBLOCK), "fcntl", ret, fail); + + NCCLCHECKGOTO(rasGetNewPollEntry(&client->pfd), ret, fail); + rasPfds[client->pfd].fd = client->sock; + rasPfds[client->pfd].events = POLLIN; + client->status = RAS_CLIENT_CONNECTED; +exit: + return ret; +fail: + if (client && client->sock != -1) + (void)close(client->sock); + goto exit; +} + +// Returns the index of the first available entry in the rasClients array, enlarging the array if necessary. +static ncclResult_t getNewClientEntry(struct rasClient** pClient) { + struct rasClient* client; + int i; + for (i = 0; i < nRasClients; i++) + if (rasClients[i].status == RAS_CLIENT_CLOSED) + break; + if (i == nRasClients) { + NCCLCHECK(ncclRealloc(&rasClients, nRasClients, nRasClients+RAS_INCREMENT)); + nRasClients += RAS_INCREMENT; + } + + client = rasClients+i; + memset(client, '\0', sizeof(*client)); + client->sock = client->pfd = -1; + ncclIntruQueueConstruct(&client->sendQ); + client->timeout = RAS_COLLECTIVE_LEG_TIMEOUT; + client->collIdx = -1; + + *pClient = client; + return ncclSuccess; +} + +// Allocates a message of the desired length for sending. +// Behind the scenes uses rasMsgAlloc. +// Must use rasClientFreeMsg to free. +static ncclResult_t rasClientAllocMsg(char** msg, size_t msgLen) { + return rasMsgAlloc((struct rasMsg**)msg, msgLen); +} + +// To be used only with messages allocated with rasClientAllocMsg, i.e., for messages meant for sending. +static void rasClientFreeMsg(char* msg) { + rasMsgFree((struct rasMsg*)msg); +} + +// Enqueues a message for sending to a RAS client. The message *must* have been allocated using rasClientAllocMsg. +static void rasClientEnqueueMsg(struct rasClient* client, char* msg, size_t msgLen) { + // Get to the metadata of this message. + struct rasMsgMeta* meta = (struct rasMsgMeta*)((char*)msg - offsetof(struct rasMsgMeta, msg)); + meta->offset = 0; + meta->length = msgLen; + ncclIntruQueueEnqueue(&client->sendQ, meta); + assert(client->status != RAS_CLIENT_CLOSED && client->status < RAS_CLIENT_FINISHED); + rasPfds[client->pfd].events |= POLLOUT; +} + +// Terminates a connection with a RAS client. +static void rasClientTerminate(struct rasClient* client) { + (void)close(client->sock); + client->sock = -1; + client->status = RAS_CLIENT_CLOSED; + rasPfds[client->pfd].fd = -1; + rasPfds[client->pfd].events = rasPfds[client->pfd].revents = 0; + client->pfd = -1; + while (struct rasMsgMeta* meta = ncclIntruQueueTryDequeue(&client->sendQ)) { + free(meta); + } +} + + +////////////////////////////////////////////////////////////////////// +// Functions related to the asynchronous operations of RAS clients. // +////////////////////////////////////////////////////////////////////// + +// Invoked when an asynchronous operation that a client was waiting on completes. Finds the right client and +// reinvokes rasClientRun. +ncclResult_t rasClientResume(struct rasCollective* coll) { + int collIdx = coll-rasCollectives; + int i; + struct rasClient* client = nullptr; + for (i = 0; i < nRasClients; i++) { + client = rasClients+i; + if (client->status != RAS_CLIENT_CLOSED && client->collIdx == collIdx) { + break; + } + } + if (i == nRasClients) { + INFO(NCCL_RAS, "RAS failed to find a matching client!"); + rasCollFree(coll); + goto exit; + } + + NCCLCHECK(rasClientRun(client)); +exit: + return ncclSuccess; +} + +// Handles a ready client FD from the main event loop. +void rasClientEventLoop(int clientIdx, int pollIdx) { + struct rasClient* client = rasClients+clientIdx; + bool closed = false; + + if (client->status == RAS_CLIENT_CONNECTED) { + char* cmd; + char* cmdEnd; + if (rasPfds[pollIdx].revents & POLLIN) { + if (client->recvOffset < sizeof(client->recvBuffer)) { + ssize_t nRecv; + nRecv = recv(client->sock, client->recvBuffer+client->recvOffset, + sizeof(client->recvBuffer) - client->recvOffset, MSG_DONTWAIT); + if (nRecv == 0) { + closed = true; + } else if (nRecv == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + if (errno == ECONNRESET) + INFO(NCCL_RAS, "RAS socket closed by the client on receive; terminating it"); + else + INFO(NCCL_RAS, "RAS unexpected error from recv; terminating the client socket"); + closed = true; + } + } else { // nRecv > 0 + client->recvOffset += nRecv; + } + } else { // client->recvOffset == sizeof(client->recvBuffer) + rasPfds[client->pfd].events &= ~POLLIN; // No room to receive for now. + } + } // if (rasPfds[pollIdx].revents & POLLIN) + if (closed) { + rasClientTerminate(client); + return; + } + cmd = client->recvBuffer; + while ((cmdEnd = (char*)memchr(cmd, '\n', client->recvOffset - (cmd-client->recvBuffer))) != nullptr) { + char* msg; + int msgLen; + *cmdEnd = '\0'; // Replaces '\n'. + if (cmdEnd > cmd && cmdEnd[-1] == '\r') + cmdEnd[-1] = '\0'; // Replaces '\r' (e.g., in case of a telnet connection). + + if (strncasecmp(cmd, "client protocol ", strlen("client protocol ")) == 0) { + // We ignore the protocol version for now; we just send our version back. + snprintf(rasLine, sizeof(rasLine), "SERVER PROTOCOL " STR(NCCL_RAS_CLIENT_PROTOCOL) "\n"); + msgLen = strlen(rasLine); + if (rasClientAllocMsg(&msg, msgLen) != ncclSuccess) { + rasClientTerminate(client); + return; + } + // We don't copy the terminating '\0', hence memcpy rather than strcpy. + memcpy(msg, rasLine, msgLen); + rasClientEnqueueMsg(client, msg, msgLen); + } else if (strncasecmp(cmd, "timeout ", strlen("timeout ")) == 0) { + char* endPtr = nullptr; + int timeout = strtol(cmd+strlen("timeout "), &endPtr, 10); + if (timeout < 0 || !endPtr || *endPtr != '\0') { + snprintf(rasLine, sizeof(rasLine), "ERROR: Invalid timeout value %s\n", cmd+strlen("timeout ")); + } else { + client->timeout = timeout * CLOCK_UNITS_PER_SEC; + strcpy(rasLine, "OK\n"); + } + msgLen = strlen(rasLine); + if (rasClientAllocMsg(&msg, msgLen) != ncclSuccess) { + rasClientTerminate(client); + return; + } + // We don't copy the terminating '\0', hence memcpy rather than strcpy. + memcpy(msg, rasLine, msgLen); + rasClientEnqueueMsg(client, msg, msgLen); + } else if (strcasecmp(cmd, "status") == 0) { + client->status = RAS_CLIENT_INIT; + (void)rasClientRun(client); + } else if (strcasecmp(cmd, "verbose status") == 0) { + client->status = RAS_CLIENT_INIT; + client->verbose = 1; + (void)rasClientRun(client); + } else { + snprintf(rasLine, sizeof(rasLine), "ERROR: Unknown command %s\n", cmd); + msgLen = strlen(rasLine); + if (rasClientAllocMsg(&msg, msgLen) != ncclSuccess) + return; // It should be non-fatal if we don't return a response... + // We don't copy the terminating '\0', hence memcpy rather than strcpy. + memcpy(msg, rasLine, msgLen); + rasClientEnqueueMsg(client, msg, msgLen); + } + + cmd = cmdEnd+1; + } // while newline found + + if (cmd == client->recvBuffer) { + if (client->recvOffset == sizeof(client->recvBuffer)) { + // We didn't find any newlines and the buffer is full. + INFO(NCCL_RAS, "RAS excessively long input line; terminating the client socket"); + rasClientTerminate(client); + return; + } + // Otherwise it's an incomplete command; we need to wait for the rest of it. + } else { // cmd > client->recvBuffer + // Shift whatever remains (if anything) to the beginning of the buffer. + memmove(client->recvBuffer, cmd, client->recvOffset - (cmd-client->recvBuffer)); + client->recvOffset -= cmd-client->recvBuffer; + } + } // if (client->status == RAS_CLIENT_CONNECTED) + + if (rasPfds[pollIdx].revents & POLLOUT) { + struct rasMsgMeta* meta; + while ((meta = ncclIntruQueueHead(&client->sendQ)) != nullptr) { + ssize_t nSend; + nSend = send(client->sock, ((char*)&meta->msg)+meta->offset, meta->length-meta->offset, + MSG_DONTWAIT | MSG_NOSIGNAL); + if (nSend < 1) { + if (nSend == -1 && errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + if (errno == EPIPE) + INFO(NCCL_RAS, "RAS socket closed by the client on send; terminating it"); + else + INFO(NCCL_RAS, "RAS unexpected error from send; terminating the client socket"); + closed = true; + } + break; + } + + meta->offset += nSend; + if (meta->offset < meta->length) + break; + + ncclIntruQueueDequeue(&client->sendQ); + free(meta); + } // while (meta) + + if (closed) { + rasClientTerminate(client); + return; + } + + if (!meta) { + rasPfds[client->pfd].events &= ~POLLOUT; // Nothing more to send for now. + if (client->status == RAS_CLIENT_FINISHED) + rasClientTerminate(client); + } + } // if (rasPfds[pollIdx].revents & POLLOUT) +} + + +////////////////////////////////////////////////////////// +// Functions driving data gathering for the RAS client. // +////////////////////////////////////////////////////////// + +// Main function that drives the whole data gathering process and sends it back to the client. +// There are multiple asynchronous aspects of it (getting the data on connections and on communicators), so the +// function may exit early and needs to be reinvoked when the asynchronous responses arrive or the timeout expires. +// The state tracking the progress of such operations is kept in the rasClient. +static ncclResult_t rasClientRun(struct rasClient* client) { + ncclResult_t ret = ncclSuccess; + + switch (client->status) { + case RAS_CLIENT_INIT: + NCCLCHECKGOTO(rasClientRunInit(client), ret, exit); +#if 0 // Commented out for now to focus the summary status report on the information most relevant to the users. + // To be revisited with future extensions to RAS. + client->status = RAS_CLIENT_CONNS; + if (ret == ncclInProgress) { + ret = ncclSuccess; + break; + } + case RAS_CLIENT_CONNS: + assert(client->collIdx != -1); + NCCLCHECKGOTO(rasClientRunConns(client), ret, exit); +#endif + client->status = RAS_CLIENT_COMMS; + if (ret == ncclInProgress) { + ret = ncclSuccess; + break; + } + case RAS_CLIENT_COMMS: + assert(client->collIdx != -1); + NCCLCHECKGOTO(rasClientRunComms(client), ret, exit); + client->status = RAS_CLIENT_FINISHED; + break; + default: + WARN("Invalid client status %d", client->status); + ret = ncclInternalError; + goto exit; + } +exit: + return ret; +} + +// Sends to the client the initial data that can be obtained locally -- version info, stats on rasPeers, +// dump of rasDeadPeers. Initiates the RAS_COLL_CONNS collective operation. +static ncclResult_t rasClientRunInit(struct rasClient* client) { + ncclResult_t ret = ncclSuccess; + char* msg = nullptr; + int msgLen; + struct rasPeerInfo* peersReSorted = nullptr; + int totalGpus, totalNodes, firstNGpusNode, firstNGpusGlobal, firstNPeersGlobal; + bool consistentNGpusNode, consistentNGpusGlobal, consistentNPeersGlobal; + int firstIdx, nPeers; + struct rasValCount valCounts[NCCL_MAX_LOCAL_RANKS]; + int nValCounts; + static int cudaDriver = -1, cudaRuntime = -1; + + rasOutReset(); + rasOutAppend("NCCL version " STR(NCCL_MAJOR) "." STR(NCCL_MINOR) "." STR(NCCL_PATCH) NCCL_SUFFIX + " compiled with CUDA " STR(CUDA_MAJOR) "." STR(CUDA_MINOR) "\n"); + if (cudaRuntime == -1) + cudaRuntimeGetVersion(&cudaRuntime); + if (cudaDriver == -1) + cudaDriverGetVersion(&cudaDriver); + rasOutAppend("CUDA runtime version %d, driver version %d\n\n", cudaRuntime, cudaDriver); + msgLen = rasOutLength(); + NCCLCHECKGOTO(rasClientAllocMsg(&msg, msgLen), ret, fail); + rasOutExtract(msg); + rasClientEnqueueMsg(client, msg, msgLen); + msg = nullptr; + + rasOutReset(); + totalGpus = totalNodes = 0; + firstNGpusNode = 0; // #GPUs on the first peer of a node. + firstNGpusGlobal = 0; // #GPUs on peerIdx 0. + consistentNGpusNode = true; // Whether #GPUs/peer is consistent between the peers *on any one node*. + consistentNGpusGlobal = true; // Whether #GPUs/peer is consistent between the peers *on all nodes*. + consistentNPeersGlobal = true; // Whether #peers/node is consistent between all nodes. + nPeers = 0; // #peers on a node. + firstNPeersGlobal = 0; + for (int peerIdx = 0; peerIdx < nRasPeers; peerIdx++) { + int nGpus = __builtin_popcountll(rasPeers[peerIdx].cudaDevs); + totalGpus += nGpus; + if (peerIdx == 0) { + totalNodes = 1; + nPeers = 1; + firstNGpusGlobal = firstNGpusNode = nGpus; + } else { // peerIdx > 0 + if (nGpus != firstNGpusGlobal) + consistentNGpusGlobal = false; + if (!ncclSocketsSameNode(&rasPeers[peerIdx].addr, &rasPeers[peerIdx-1].addr)) { + totalNodes++; + if (firstNPeersGlobal == 0) + firstNPeersGlobal = nPeers; + else if (nPeers != firstNPeersGlobal) + consistentNPeersGlobal = false; + nPeers = 1; + firstNGpusNode = nGpus; + } else { // Same node. + if (nGpus != firstNGpusNode) + consistentNGpusNode = false; + nPeers++; + } // Same node + } // peerIdx > 0 + if (peerIdx == nRasPeers-1) { + if (firstNPeersGlobal == 0) + firstNPeersGlobal = nPeers; + else if (nPeers != firstNPeersGlobal) + consistentNPeersGlobal = false; + } + } // for (peerIdx) + + rasOutAppend("Job summary\n" + "===========\n\n"); + + if (consistentNGpusNode && consistentNGpusGlobal && consistentNPeersGlobal) { + rasOutAppend(" Nodes Processes GPUs Processes GPUs\n" + "(total) per node per process (total) (total)\n" + "%7d" " %9d" " %11d" " %9d" " %7d\n", + totalNodes, firstNPeersGlobal, firstNGpusGlobal, nRasPeers, totalGpus); + } else { + // Gather the stats on the number of processes per node. However, that number is not a property of a peer, + // but of a group of peers, so calculating it is more involved. We make a copy of rasPeers and creatively + // misuse it: cudaDevs of each element will be repurposed to store the number of processes on the node. + NCCLCHECKGOTO(ncclCalloc(&peersReSorted, nRasPeers), ret, fail); + memcpy(peersReSorted, rasPeers, nRasPeers * sizeof(*peersReSorted)); + + firstIdx = 0; + nPeers = 0; + for (int peerIdx = 0; peerIdx < nRasPeers; peerIdx++) { + if (peerIdx == 0) { + nPeers = 1; + firstIdx = 0; + } else { // peerIdx > 0 + if (!ncclSocketsSameNode(&peersReSorted[peerIdx].addr, &peersReSorted[peerIdx-1].addr)) { + for (int i = firstIdx; i < peerIdx; i++) { + // Go back and update the number of processes of all the elements of that node. + peersReSorted[i].cudaDevs = nPeers; + } + nPeers = 1; + firstIdx = peerIdx; + } else { + nPeers++; + } + } // peerIdx > 0 + if (peerIdx == nRasPeers-1) { + // Last iteration of the loop. + for (int i = firstIdx; i < nRasPeers; i++) { + peersReSorted[i].cudaDevs = nPeers; + } + } + } // for (peerIdx) + + // Re-sort it now using the number of processes on the node (cudaDevs) as the primary key, host IP as the + // secondary, and process id as the tertiary. + qsort(peersReSorted, nRasPeers, sizeof(*peersReSorted), rasPeersNProcsCompare); + + // Calculate the distribution of different numbers of peers per node. + nValCounts = 0; + for (int peerIdx = 0; peerIdx < nRasPeers;) { + if (peerIdx == 0 || peersReSorted[peerIdx].cudaDevs != peersReSorted[peerIdx-1].cudaDevs) { + valCounts[nValCounts].value = peersReSorted[peerIdx].cudaDevs; + valCounts[nValCounts].count = 1; + valCounts[nValCounts].firstIdx = peerIdx; + nValCounts++; + } else { + valCounts[nValCounts-1].count++; + } + // Advance peerIdx to the next node. + peerIdx += peersReSorted[peerIdx].cudaDevs; + } + // valCounts is currently sorted by value (the number of peers per node). Sort it by the count (most frequent + // number of peers first). + qsort(valCounts, nValCounts, sizeof(*valCounts), rasValCountsCompareRev); + + // Print it out, the most frequent peer counts first. + if (consistentNGpusNode && consistentNGpusGlobal) { + rasOutAppend(" Nodes Processes GPUs\n" + " per node per process\n"); + for (int i = 0; i < nValCounts; i++) { + struct rasValCount* vc = valCounts+i; + rasOutAppend("%7d %9ld %11d\n", + vc->count, vc->value, firstNGpusGlobal); + } + } else { + rasOutAppend(" Nodes Processes\n" + " per node\n"); + for (int i = 0; i < nValCounts; i++) { + struct rasValCount* vc = valCounts+i; + rasOutAppend("%7d %9ld\n", + vc->count, vc->value); + } + + // We calculate and print the GPUs/process separately. This is required for !consistentNGpusNode and + // it also makes our life easier above for !consistentNGpusGlobal (which could require a larger valCounts). + + // Sort peers by the GPU count, to simplify data extraction. + memcpy(peersReSorted, rasPeers, nRasPeers * sizeof(*peersReSorted)); + // GPU count is the primary key, host IP is the secondary, and process id is the tertiary. + qsort(peersReSorted, nRasPeers, sizeof(*peersReSorted), rasPeersNGpuCompare); + + // Calculate the distribution of different numbers of GPUs per peer. + nValCounts = 0; + for (int peerIdx = 0; peerIdx < nRasPeers; peerIdx++) { + if (peerIdx == 0 || __builtin_popcountll(peersReSorted[peerIdx].cudaDevs) != + __builtin_popcountll(peersReSorted[peerIdx-1].cudaDevs)) { + valCounts[nValCounts].value = __builtin_popcountll(peersReSorted[peerIdx].cudaDevs); + valCounts[nValCounts].count = 1; + valCounts[nValCounts].firstIdx = peerIdx; + nValCounts++; + } else { + valCounts[nValCounts-1].count++; + } + } + // valCounts is currently sorted by value (number of GPUs per peer). Sort it by the count (most frequent + // GPU counts first). + qsort(valCounts, nValCounts, sizeof(*valCounts), rasValCountsCompareRev); + + // Print it out, the most frequent GPU counts first. + rasOutAppend("\n" + " Processes GPUs\n" + " per process\n"); + for (int i = 0; i < nValCounts; i++) { + struct rasValCount* vc = valCounts+i; + rasOutAppend(" %9d %11ld\n", + vc->count, vc->value); + } + } + rasOutAppend("\n" + " Nodes Processes GPUs\n" + "(total) (total) (total)\n" + "%7d" " %9d" " %11d\n", + totalNodes, nRasPeers, totalGpus); + + if (consistentNGpusNode && consistentNGpusGlobal) { + // In this simpler case, also print the node outliers. + for (int i = 1; i < nValCounts; i++) { + struct rasValCount* vc = valCounts+i; + // We assume that the most frequent group is correct; for the remaining ones, we try to provide more info, + // provided that they meet our definition of an outlier. + if (rasCountIsOutlier(vc->count, client->verbose, totalNodes)) { + rasOutAppend("\nThe outlier node%s:\n", (vc->count > 1 ? "s" : "")); + // peersReSorted is sorted by the node IP address (not port!) as the secondary key and the pid as + // the tertiary, which comes in handy when printing... + for (int peerIdx = vc->firstIdx; peerIdx < vc->count*vc->value + vc->firstIdx; peerIdx += vc->value) { + lineBuf[0] = '\0'; + for (int j = 0; j < vc->value; j++) { + snprintf(lineBuf+strlen(lineBuf), sizeof(lineBuf)-strlen(lineBuf), "%s%d", + (j > 0 ? "," : ""), peersReSorted[j].pid); + } + rasOutAppend(" Node %s running process%s %s\n", + ncclSocketToHost(&peersReSorted[peerIdx].addr, rasLine, sizeof(rasLine)), + (vc->value > 1 ? "es" : ""), lineBuf); + } // for (peerIdx) + } // if (rasCountIsOutlier(vc->count)) + } // for (i) + } // !consistentNPeersGlobal + } // !consistentNGpusNode || !consistentNGpusGlobal || !consistentNPeersGlobal + +#if 0 // Commented out for now to focus the summary status report on the information most relevant to the users. + // To be revisited with future extensions to RAS. + rasOutAppend("\nGathering data about the RAS network (timeout %lds)...", client->timeout / CLOCK_UNITS_PER_SEC); + msgLen = rasOutLength(); + NCCLCHECKGOTO(rasClientAllocMsg(&msg, msgLen), ret, fail); + rasOutExtract(msg); + rasClientEnqueueMsg(client, msg, msgLen); + msg = nullptr; + { + struct rasCollRequest collReq; + bool allDone = false; + rasCollReqInit(&collReq); + collReq.timeout = client->timeout; + collReq.type = RAS_COLL_CONNS; + NCCLCHECKGOTO(rasNetSendCollReq(&collReq, rasCollDataLength(RAS_COLL_CONNS), &allDone, &client->collIdx), + ret, fail); + if (!allDone) + ret = ncclInProgress; // We need to wait for async. responses. + } +#endif + rasOutAppend("\nCommunicators..."); + msgLen = rasOutLength(); + NCCLCHECKGOTO(rasClientAllocMsg(&msg, msgLen), ret, fail); + rasOutExtract(msg); + rasClientEnqueueMsg(client, msg, msgLen); + msg = nullptr; + { + struct rasCollRequest collReq; + bool allDone = false; + rasCollReqInit(&collReq); + collReq.timeout = client->timeout; + collReq.type = RAS_COLL_COMMS; + NCCLCHECKGOTO(rasNetSendCollReq(&collReq, rasCollDataLength(RAS_COLL_COMMS), &allDone, &client->collIdx), + ret, fail); + if (!allDone) + ret = ncclInProgress; + } +exit: + free(peersReSorted); + return ret; +fail: + goto exit; +} + +#if 0 // Commented out for now to focus the summary status report on the information most relevant to the users. + // To be revisited with future extensions to RAS. +// Processes the response from the RAS_COLL_CONNS collective operation and sends the data to the client (for now +// primarily the list of missing processes). Initiates the RAS_COLL_COMMS collective operation. +static ncclResult_t rasClientRunConns(struct rasClient* client) { + ncclResult_t ret = ncclSuccess; + char* msg = nullptr; + int msgLen; + struct rasCollective* coll = rasCollectives+client->collIdx; + struct rasCollConns* connsData = (struct rasCollConns*)coll->data; + int expected; + struct rasPeerInfo* peersBuf = nullptr; + + assert(coll->nFwdSent == coll->nFwdRecv); + client->collIdx = -1; + + rasOutReset(); + rasOutAppend(" obtained a result in %.2fs\n", (clockNano()-coll->startTime)/1e9); + if (coll->nLegTimeouts > 0) { + rasOutAppend(" Warning: encountered %d communication timeout%s while gathering data\n", coll->nLegTimeouts, + (coll->nLegTimeouts > 1 ? "s" : "")); + } + + expected = nRasPeers - nRasDeadPeers; + if (coll->nPeers != expected) { + int missing = expected - coll->nPeers; + rasOutAppend(" Warning: missing data from %d process%s (received from %d, expected %d)\n", + missing, (missing > 1 ? "es" : ""), coll->nPeers, expected); + if (missing <= RAS_CLIENT_DETAIL_THRESHOLD) { + // Extract a list of missing peers. We don't want to print it right away because it would be sorted + // by address (including port, which isn't meaningful to end users). + int nPeersBuf = 0; + NCCLCHECKGOTO(ncclCalloc(&peersBuf, missing), ret, fail); + // Ensure both arrays are sorted (rasPeers already is, by addr); makes finding missing records a breeze. + qsort(coll->peers, coll->nPeers, sizeof(*coll->peers), &ncclSocketsCompare); + for (int rasPeerIdx = 0, collPeerIdx = 0; rasPeerIdx < nRasPeers || collPeerIdx < coll->nPeers;) { + int cmp; + if (rasPeerIdx < nRasPeers && collPeerIdx < coll->nPeers) + cmp = ncclSocketsCompare(&rasPeers[rasPeerIdx].addr, coll->peers+collPeerIdx); + else + cmp = (rasPeerIdx < nRasPeers ? -1 : 1); + + if (cmp == 0) { + rasPeerIdx++; + collPeerIdx++; + } else if (cmp < 0) { + memcpy(peersBuf+(nPeersBuf++), rasPeers+rasPeerIdx, sizeof(*peersBuf)); + rasPeerIdx++; + } else { // cmp > 0 + // Process not found in rasPeers -- shouldn't happen. + collPeerIdx++; + } // cmp > 0 + } // for (rasPeerIdx, collPeerIdx) + + // Sort the output by host and pid, not host and port. + qsort(peersBuf, nPeersBuf, sizeof(*peersBuf), rasPeersHostPidCompare); + rasOutAppend(" The missing process%s:\n", (missing > 1 ? "es" : "")); + for (int peerIdx = 0; peerIdx < nPeersBuf; peerIdx++) { + rasOutAppend(" Process %d on node %s managing GPU%s %s\n", peersBuf[peerIdx].pid, + ncclSocketToHost(&peersBuf[peerIdx].addr, rasLine, sizeof(rasLine)), + (__builtin_popcountll(peersBuf[peerIdx].cudaDevs) > 1 ? "s" : ""), + rasGpuDevsToString(peersBuf[peerIdx].cudaDevs, peersBuf[peerIdx].nvmlDevs, lineBuf, + sizeof(lineBuf))); + } + if (nPeersBuf != missing) + rasOutAppend(" [could not find information on %d process%s]\n", + missing-nPeersBuf, (missing-nPeersBuf > 1 ? "es" : "")); + } // if (expected - coll->nPeers <= RAS_CLIENT_DETAIL_THRESHOLD) + } // if (coll->nPeers != expected) + + if (connsData->nConns > 0) { + rasOutAppend(" Collected data about %d unidirectional connection%s\n", + connsData->nConns, (connsData->nConns > 1 ? "s" : "")); + rasOutAppend(" Travel times (valid only if system clocks are synchronized between nodes):\n" + " Minimum %fs, maximum %fs, average %fs\n", + connsData->travelTimeMin/1e9, connsData->travelTimeMax/1e9, + connsData->travelTimeSum/(1e9*connsData->travelTimeCount)); + } else { + rasOutAppend(" No connection data collected!\n"); + } + if (connsData->nNegativeMins > 0) { + rasOutAppend(" Warning: negative travel times were observed across %d connection%s,\n" + " indicating that the system clocks are *not* synchronized.\n" + " Ordering of events based on local timestamps should be considered unreliable\n", + connsData->nNegativeMins, (connsData->nNegativeMins > 1 ? "s" : "")); + if (connsData->nNegativeMins <= RAS_CLIENT_DETAIL_THRESHOLD) { + rasOutAppend(" The affected connection%s:\n", (connsData->nNegativeMins > 1 ? "s" : "")); + for (int i = 0; i < connsData->nNegativeMins; i++) { + struct rasCollConns::negativeMin* negativeMin = connsData->negativeMins+i; + int sourcePeerIdx = rasPeerFind(&negativeMin->source); + int destPeerIdx = rasPeerFind(&negativeMin->dest); + if (sourcePeerIdx != -1 && destPeerIdx != -1) + rasOutAppend(" From node %s process %d to node %s process %d: observed travel time of %fs\n", + ncclSocketToHost(&negativeMin->source, rasLine, sizeof(rasLine)), rasPeers[sourcePeerIdx].pid, + ncclSocketToHost(&negativeMin->dest, lineBuf, sizeof(lineBuf)), rasPeers[destPeerIdx].pid, + negativeMin->travelTimeMin/1e9); + } + } + } + rasCollFree(coll); + + rasOutAppend("\nGathering data about the NCCL communicators (timeout %lds)...", + client->timeout / CLOCK_UNITS_PER_SEC); + msgLen = rasOutLength(); + NCCLCHECKGOTO(rasClientAllocMsg(&msg, msgLen), ret, fail); + rasOutExtract(msg); + rasClientEnqueueMsg(client, msg, msgLen); + msg = nullptr; + { + struct rasCollRequest collReq; + bool allDone = false; + rasCollReqInit(&collReq); + collReq.timeout = client->timeout; + collReq.type = RAS_COLL_COMMS; + NCCLCHECKGOTO(rasNetSendCollReq(&collReq, rasCollDataLength(RAS_COLL_COMMS), &allDone, &client->collIdx), + ret, fail); + if (!allDone) + ret = ncclInProgress; + } +exit: + free(peersBuf); + return ret; +fail: + goto exit; +} +#endif + +// Processes the response from the RAS_COLL_COMMS collective operation and sends the data to the client: +// statistics on the communicators, missing data from ranks, inconsistent collective operation counts, +// initialization and asynchronous errors, and inconsistent initialization/termination status. +static ncclResult_t rasClientRunComms(struct rasClient* client) { + ncclResult_t ret = ncclSuccess; + char* msg = nullptr; + int msgLen; + struct rasCollective* coll = rasCollectives+client->collIdx; + struct rasCollComms* commsData = (struct rasCollComms*)coll->data; + struct rasCollComms::comm* comm; + struct rasCollComms::comm::rank* ranksReSorted = nullptr; + struct rasValCount* valCounts = nullptr; + int nValCounts; + struct rasValCount* collOpCounts = nullptr; + struct rasAuxComm* auxComms = nullptr; + int maxCommSize; + int* peerIdxConv = nullptr; + int vcIdx; + int nPeersMissing; + uint64_t* peerNvmlDevs = nullptr; + const char*const statusStr[] = { "UNKNOWN", "INIT", "RUNNING", "FINALIZE", "ABORT" }; + const char*const errorStr[] = { + // Listing them all like this, while a bit of a hassle, is less effort than formatting in a temporary buffer. + "OK", + "MISMATCH", + "ERROR", + "ERROR,MISMATCH", + "INCOMPLETE", + "INCOMPLETE,MISMATCH", + "INCOMPLETE,ERROR", + "INCOMPLETE,ERROR,MISMATCH" + }; + + assert(coll->nFwdSent == coll->nFwdRecv); + client->collIdx = -1; + + rasOutReset(); + rasOutAppend(" (%.2fs)\n=============\n\n", (clockNano()-coll->startTime)/1e9); + + // Calculate the number of missing peers early as we rely on it for other things. + nPeersMissing = nRasPeers - nRasDeadPeers - coll->nPeers; + + // Sort the communicators by size. As the structure is inconvenient to move around due to the elements being + // of variable length, we create an auxiliary array that includes pointers to individual elements and simply sort + // that array while keeping the data intact. + NCCLCHECKGOTO(ncclCalloc(&auxComms, commsData->nComms), ret, fail); + // While initializing the just allocated array, also find out the size of the largest communicator so that we know + // how much memory to allocate for another temporary array. + maxCommSize = 0; + comm = commsData->comms; + for (int commIdx = 0; commIdx < commsData->nComms; commIdx++) { + if (maxCommSize < comm->commNRanks) + maxCommSize = comm->commNRanks; + auxComms[commIdx].comm = comm; + comm = (struct rasCollComms::comm*)(((char*)(comm+1)) + comm->nRanks * sizeof(*comm->ranks)); + } + NCCLCHECKGOTO(ncclCalloc(&ranksReSorted, maxCommSize), ret, fail); + + // For convenience, create a translation table from rasCollective's peerIdx to rasPeers peerIdx. + NCCLCHECKGOTO(ncclCalloc(&peerIdxConv, coll->nPeers), ret, fail); + for (int peerIdx = 0; peerIdx < coll->nPeers; peerIdx++) + peerIdxConv[peerIdx] = rasPeerFind(coll->peers+peerIdx); + // Sort coll->peers to match the ordering of rasPeers -- we may need it later... + qsort(coll->peers, coll->nPeers, sizeof(*coll->peers), &ncclSocketsCompare); + + // Fill in the remaining fields of auxComm's. + for (int commIdx = 0; commIdx < commsData->nComms; commIdx++) { + struct rasAuxComm* auxComm = auxComms+commIdx; + int nRanks = 0; + comm = auxComm->comm; + + if (comm->commNRanks > comm->nRanks) { + // There are two possibilities here. Either we are missing the data on some ranks because the processes are + // unreachable, or the processes _are_ reachable but didn't report to be part of this communicator (which + // could definitely happen if some processes have already called ncclCommDestroy or ncclCommAbort). Because we + // currently don't collect data about missing ranks, we can't reliably distinguish these two cases. + // For now we rely on an approximation: if we _know_ that some peers failed to respond, we mark this + // as an INCOMPLETE error; otherwise as a MISMATCH warning. + if (nPeersMissing > 0 || nRasDeadPeers > 0) + auxComm->errors |= RAS_ACE_INCOMPLETE; + else { + auxComm->errors |= RAS_ACE_MISMATCH; + auxComm->status |= RAS_ACS_UNKNOWN; + } + } + + memcpy(ranksReSorted, comm->ranks, comm->nRanks * sizeof(*ranksReSorted)); + // Convert ranksReSorted' peerIdx to rasPeers and sort by it -- that way we will have the ranks sorted + // by process _and_ node, which makes counting easy. + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) + ranksReSorted[rankIdx].peerIdx = peerIdxConv[ranksReSorted[rankIdx].peerIdx]; + qsort(ranksReSorted, comm->nRanks, sizeof(*ranksReSorted), rasCommRanksPeerCompare); + + // Count the peers and nodes, get the status/error indicators. + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + struct rasCollComms::comm::rank* rank = ranksReSorted+rankIdx; + if (rankIdx == 0) { + auxComm->nPeers = auxComm->nNodes = 1; + auxComm->ranksPerNodeMin = NCCL_MAX_LOCAL_RANKS; + auxComm->ranksPerNodeMax = 0; + auxComm->firstCollOpCount = rank->collOpCount; + nRanks = 1; + } else { // rankIdx > 0 + if (rank->peerIdx != rank[-1].peerIdx) { + auxComm->nPeers++; + if (!ncclSocketsSameNode(&rasPeers[rank->peerIdx].addr, &rasPeers[rank[-1].peerIdx].addr)) { + auxComm->nNodes++; + if (auxComm->ranksPerNodeMin > nRanks) + auxComm->ranksPerNodeMin = nRanks; + if (auxComm->ranksPerNodeMax < nRanks) + auxComm->ranksPerNodeMax = nRanks; + nRanks = 0; + } + } // if (rank->peerIdx != rank[-1].peerIdx) + nRanks++; + } // rankIdx > 0 + if (rankIdx == comm->nRanks-1) { + // Last iteration of the loop. + if (auxComm->ranksPerNodeMin > nRanks) + auxComm->ranksPerNodeMin = nRanks; + if (auxComm->ranksPerNodeMax < nRanks) + auxComm->ranksPerNodeMax = nRanks; + } + + if (rank->status.abortFlag) + auxComm->status |= RAS_ACS_ABORT; + else if (rank->status.finalizeCalled || rank->status.destroyFlag) { + // destroyFlag is set by ncclCommDestroy and ncclCommAbort. finalizeCalled appears to be set by + // ncclCommFinalize only. According to the docs, ncclCommDestroy *can* be called without calling + // ncclCommFinalize first. The code structure here ensures that we attribute destroyFlag properly + // as a finalize state indicator (and ignore it in case of ncclCommAbort). + auxComm->status |= RAS_ACS_FINALIZE; + } + else if (rank->status.initState == ncclSuccess) + auxComm->status |= RAS_ACS_RUNNING; + else // rank->initState != ncclSuccess + auxComm->status |= RAS_ACS_INIT; + + if (rank->collOpCount != auxComm->firstCollOpCount) + auxComm->errors |= RAS_ACE_MISMATCH; + if (rank->status.initState != ncclSuccess && rank->status.initState != ncclInProgress) + auxComm->errors |= RAS_ACE_ERROR; + if (rank->status.asyncError != ncclSuccess && rank->status.asyncError != ncclInProgress) + auxComm->errors |= RAS_ACE_ERROR; + } // for (rankIdx) + + if (__builtin_popcount(auxComm->status) > 1) { + // We've got a status mismatch between ranks. + auxComm->errors |= RAS_ACE_MISMATCH; + } + } // for (commIdx) + // Sort it by size/nNodes/status/errors/missing ranks. + qsort(auxComms, commsData->nComms, sizeof(*auxComms), &rasAuxCommsCompareRev); + + // Calculate the distribution of different communicator sizes. + NCCLCHECKGOTO(ncclCalloc(&valCounts, commsData->nComms), ret, fail); + nValCounts = 0; + for (int commIdx = 0; commIdx < commsData->nComms; commIdx++) { + if (commIdx == 0 || + auxComms[commIdx].comm->commNRanks != auxComms[commIdx-1].comm->commNRanks || + auxComms[commIdx].nNodes != auxComms[commIdx-1].nNodes || + // __builtin_clz returns the number of leading 0-bits, which is a proxy for the index of the highest 1-bit. + __builtin_clz(auxComms[commIdx].status) != __builtin_clz(auxComms[commIdx-1].status) || + auxComms[commIdx].errors != auxComms[commIdx-1].errors) { + valCounts[nValCounts].value = 0; // We have many distinguishing values but only one field to store them. + // It doesn't really matter, given that we can extract them via firstIdx. + valCounts[nValCounts].count = 1; + valCounts[nValCounts].firstIdx = commIdx; + nValCounts++; + } else { + valCounts[nValCounts-1].count++; + } + } + + rasOutAppend("Group Comms Nodes Ranks Ranks Ranks Status Errors\n" + " # in group per comm per node per comm in group\n"); + if (commsData->nComms == 0) + rasOutAppend("No communicator data collected!\n"); + + // Allocate an auxiliary structure used for counting the number of ranks (unique GPUs) in a group. + NCCLCHECKGOTO(ncclCalloc(&peerNvmlDevs, coll->nPeers), ret, fail); + + // Print it out, the largest communicators first. + for (int vcIdx = 0; vcIdx < nValCounts; vcIdx++) { + struct rasValCount* vc = valCounts+vcIdx; + struct rasAuxComm* auxComm = auxComms+vc->firstIdx; + int ranksPerNodeMin, ranksPerNodeMax; + int ranksTotal; + + ranksPerNodeMin = NCCL_MAX_LOCAL_RANKS; + ranksPerNodeMax = 0; + memset(peerNvmlDevs, '\0', coll->nPeers * sizeof(*peerNvmlDevs)); + // We don't group comms by ranksPerNodeMin/Max, so the values may differ between comms in one group. + // Calculate the group's min/max. + // Also calculate the number of unique ranks in the group. + for (int commIdx = 0; commIdx < vc->count; commIdx++) { + if (ranksPerNodeMin > auxComm[commIdx].ranksPerNodeMin) + ranksPerNodeMin = auxComm[commIdx].ranksPerNodeMin; + if (ranksPerNodeMax < auxComm[commIdx].ranksPerNodeMax) + ranksPerNodeMax = auxComm[commIdx].ranksPerNodeMax; + for (int rankIdx = 0; rankIdx < auxComm[commIdx].comm->nRanks; rankIdx++) { + struct rasCollComms::comm::rank* rank = auxComm[commIdx].comm->ranks+rankIdx; + peerNvmlDevs[rank->peerIdx] |= (1UL << rank->nvmlDev); + } + } + ranksTotal = 0; + for (int peerIdx = 0; peerIdx < coll->nPeers; peerIdx++) + ranksTotal += __builtin_popcountll(peerNvmlDevs[peerIdx]); + if (ranksPerNodeMin == ranksPerNodeMax) + snprintf(rasLine, sizeof(rasLine), "%d", ranksPerNodeMin); + else + snprintf(rasLine, sizeof(rasLine), "%d-%d", ranksPerNodeMin, ranksPerNodeMax); + rasOutAppend("%5d %8d %8d %8s %8d %8d %8s %6s\n", + vcIdx, vc->count, auxComm->nNodes, rasLine, auxComm->comm->commNRanks, ranksTotal, + // __builtin_clz returns the number of leading 0-bits. This makes it possible to translate the + // status (which is a bitmask) into an array index. + statusStr[(sizeof(unsigned int)*8-1)-__builtin_clz(auxComm->status)], errorStr[auxComm->errors]); + } + + rasOutAppend("\nErrors\n" + "======\n\n"); + + if (nPeersMissing > 0) { + rasOutAppend("INCOMPLETE\n" + " Missing communicator data from %d job process%s\n", nPeersMissing, (nPeersMissing > 1 ? "es" : "")); + if (rasCountIsOutlier(nPeersMissing, client->verbose)) { + // Extract a list of missing peers. We don't want to print it right away because it would be sorted + // by address (including port, which isn't meaningful to end users). + struct rasPeerInfo* peersBuf = nullptr; + int nPeersBuf; + + // Both rasPeers and coll->peers are sorted by address (the latter we sorted above) which makes comparing + // them much easier. + NCCLCHECKGOTO(ncclCalloc(&peersBuf, nPeersMissing), ret, fail); + nPeersBuf = 0; + for (int rasPeerIdx = 0, collPeerIdx = 0; rasPeerIdx < nRasPeers || collPeerIdx < coll->nPeers;) { + int cmp; + if (rasPeerIdx < nRasPeers && collPeerIdx < coll->nPeers) + cmp = ncclSocketsCompare(&rasPeers[rasPeerIdx].addr, coll->peers+collPeerIdx); + else + cmp = (rasPeerIdx < nRasPeers ? -1 : 1); + + if (cmp == 0) { + rasPeerIdx++; + collPeerIdx++; + } else if (cmp < 0) { + // Process missing from coll->peers. Don't report dead ones though, as they are not included + // in nPeersMissing and are reported separately below. + if (!rasPeerIsDead(&rasPeers[rasPeerIdx].addr)) { + assert(nPeersBuf < nPeersMissing); + memcpy(peersBuf+(nPeersBuf++), rasPeers+rasPeerIdx, sizeof(*peersBuf)); + } + rasPeerIdx++; + } else { // cmp > 0 + // Process not found in rasPeers -- shouldn't happen, unless during a race? + collPeerIdx++; + } // cmp > 0 + } // for (rasPeerIdx, collPeerIdx) + + // Sort the output by host and pid. + qsort(peersBuf, nPeersBuf, sizeof(*peersBuf), rasPeersHostPidCompare); + for (int peerIdx = 0; peerIdx < nPeersBuf; peerIdx++) { + rasOutAppend(" Process %d on node %s managing GPU%s %s\n", peersBuf[peerIdx].pid, + ncclSocketToHost(&peersBuf[peerIdx].addr, rasLine, sizeof(rasLine)), + (__builtin_popcountll(peersBuf[peerIdx].cudaDevs) > 1 ? "s" : ""), + rasGpuDevsToString(peersBuf[peerIdx].cudaDevs, peersBuf[peerIdx].nvmlDevs, lineBuf, + sizeof(lineBuf))); + } + if (nPeersBuf != nPeersMissing) + rasOutAppend(" [could not find information on %d process%s]\n", + nPeersMissing-nPeersBuf, (nPeersMissing-nPeersBuf > 1 ? "es" : "")); + free(peersBuf); + } // if (rasCountIsOutlier(nPeersMissing)) + rasOutAppend("\n"); + } + + if (nRasDeadPeers > 0) { + rasOutAppend("DEAD\n" + " %d job process%s considered dead (unreachable via the RAS network)\n", nRasDeadPeers, + (nRasDeadPeers > 1 ? "es are" : " is")); + if (rasCountIsOutlier(nRasDeadPeers, client->verbose)) { + struct rasPeerInfo* peersReSorted = nullptr; + int nPeersReSorted = 0; + NCCLCHECKGOTO(ncclCalloc(&peersReSorted, nRasDeadPeers), ret, fail); + for (int i = 0; i < nRasDeadPeers; i++) { + int peerIdx = rasPeerFind(rasDeadPeers+i); + if (peerIdx != -1) + memcpy(peersReSorted+(nPeersReSorted++), rasPeers+peerIdx, sizeof(*peersReSorted)); + } + // Sort the output by host and pid, not host and port. + qsort(peersReSorted, nPeersReSorted, sizeof(*peersReSorted), rasPeersHostPidCompare); + for (int peerIdx = 0; peerIdx < nPeersReSorted; peerIdx++) { + rasOutAppend(" Process %d on node %s managing GPU%s %s\n", peersReSorted[peerIdx].pid, + ncclSocketToHost(&peersReSorted[peerIdx].addr, rasLine, sizeof(rasLine)), + (__builtin_popcountll(peersReSorted[peerIdx].cudaDevs) > 1 ? "s" : ""), + rasGpuDevsToString(peersReSorted[peerIdx].cudaDevs, peersReSorted[peerIdx].nvmlDevs, lineBuf, + sizeof(lineBuf))); + } + if (nPeersReSorted != nRasDeadPeers) + rasOutAppend(" [could not find information on %d process%s]\n", + nRasDeadPeers-nPeersReSorted, (nRasDeadPeers-nPeersReSorted > 1 ? "es" : "")); + free(peersReSorted); + } // if (rasCountIsOutlier(nRasDeadPeers) + rasOutAppend("\n"); + } + + for (vcIdx = 0; vcIdx < nValCounts; vcIdx++) { + struct rasValCount* vc; + vc = valCounts+vcIdx; + for (int commIdx = vc->firstIdx; commIdx < vc->count + vc->firstIdx; commIdx++) { + struct rasAuxComm* auxComm = auxComms+commIdx; + comm = auxComm->comm; + + if (auxComm->errors & RAS_ACE_INCOMPLETE) { + int nRanksMissing = comm->commNRanks - comm->nRanks; + rasOutAppend("#%d-%d (%016lx) INCOMPLETE\n" + " Missing communicator data from %d rank%s\n", vcIdx, commIdx - vc->firstIdx, + comm->commHash, nRanksMissing, (nRanksMissing > 1 ? "s" : "")); + if (rasCountIsOutlier(nRanksMissing, client->verbose)) { + lineBuf[0] = '\0'; + // rankIdx indexes the comm->ranks array; in principle it should be the same as commRank, with the + // exception of the missing ranks... + for (int commRank = 0, rankIdx = 0; commRank < comm->commNRanks; commRank++) { + if (rankIdx < comm->nRanks && comm->ranks[rankIdx].commRank == commRank) { + rankIdx++; + } else { + snprintf(lineBuf+strlen(lineBuf), sizeof(lineBuf)-strlen(lineBuf), "%s%d", + (rankIdx == commRank ? "" : ","), commRank); + } + } // for (commRank) + rasOutAppend(" The missing rank%s: %s\n", (nRanksMissing > 1 ? "s" : ""), lineBuf); + } // if (rasCountIsOutlier(nRanksMissing)) + rasOutAppend("\n"); + } // if (auxComm->errors & RAS_ACE_INCOMPLETE) + + if (auxComm->errors & RAS_ACE_ERROR) { + int ncclErrors[ncclNumResults]; + int nErrors; + rasOutAppend("#%d-%d (%016lx) ERROR\n", vcIdx, commIdx - vc->firstIdx, comm->commHash); + + memset(ncclErrors, '\0', sizeof(ncclErrors)); + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) + ncclErrors[comm->ranks[rankIdx].status.initState]++; + nErrors = comm->nRanks - (ncclErrors[ncclSuccess] + ncclErrors[ncclInProgress]); + if (nErrors > 0) { + rasOutAppend(" Initialization error%s on %d rank%s\n", + (nErrors > 1 ? "s" : ""), nErrors, (nErrors > 1 ? "s" : "")); + rasClientBreakDownErrors(client, comm, peerIdxConv, ncclErrors); + } + + memset(ncclErrors, '\0', sizeof(ncclErrors)); + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) + ncclErrors[comm->ranks[rankIdx].status.asyncError]++; + nErrors = comm->nRanks - (ncclErrors[ncclSuccess] + ncclErrors[ncclInProgress]); + if (nErrors > 0) { + rasOutAppend(" Asynchronous error%s on %d rank%s\n", + (nErrors > 1 ? "s" : ""), nErrors, (nErrors > 1 ? "s" : "")); + rasClientBreakDownErrors(client, comm, peerIdxConv, ncclErrors, /*isAsync*/true); + } + rasOutAppend("\n"); + } // if (auxComm->errors & RAS_ACE_ERROR) + } // for (commIdx) + } // for (vcIdx) + + rasOutAppend("Warnings\n" + "========\n\n"); + + if (coll->nLegTimeouts > 0) { + rasOutAppend("TIMEOUT\n" + " Encountered %d communication timeout%s while gathering communicator data\n\n", + coll->nLegTimeouts, (coll->nLegTimeouts > 1 ? "s" : "")); + } + + for (int vcIdx = 0; vcIdx < nValCounts; vcIdx++) { + struct rasValCount* vc = valCounts+vcIdx; + for (int commIdx = vc->firstIdx; commIdx < vc->count + vc->firstIdx; commIdx++) { + bool inconsistent; + struct rasAuxComm* auxComm = auxComms+commIdx; + comm = auxComm->comm; + + if (auxComm->errors & RAS_ACE_MISMATCH) { + rasOutAppend("#%d-%d (%016lx) MISMATCH\n", vcIdx, commIdx - vc->firstIdx, comm->commHash); + + if (collOpCounts == nullptr) { + // Allocating comm->commNRanks elements ensures that we won't need to reallocate, because the valCounts + // array is reverse-sorted by commNRanks. On the other hand, for this purpose allocating commNRanks + // elements may be massively overpessimistic... + NCCLCHECKGOTO(ncclCalloc(&collOpCounts, comm->commNRanks), ret, fail); + } + + if (__builtin_popcount(auxComm->status) > 1) { + rasOutAppend(" Communicator ranks have different status\n"); + + // We need to sort the ranks by status. However, status is normally calculated from other fields. + // We will copy the ranks and reuse collOpCount to store it. + memcpy(ranksReSorted, comm->ranks, comm->nRanks * sizeof(*ranksReSorted)); + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + struct rasCollComms::comm::rank* rank = ranksReSorted+rankIdx; + + if (rank->status.abortFlag) + rank->collOpCount = RAS_ACS_ABORT; + else if (rank->status.finalizeCalled || rank->status.destroyFlag) + rank->collOpCount = RAS_ACS_FINALIZE; + else if (rank->status.initState == ncclSuccess) + rank->collOpCount = RAS_ACS_RUNNING; + else + rank->collOpCount = RAS_ACS_INIT; + } + qsort(ranksReSorted, comm->nRanks, sizeof(*ranksReSorted), rasCommRanksCollOpCompare); + // Calculate the frequency of different status values. + int nCollOpCounts = 0; + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + if (rankIdx == 0 || ranksReSorted[rankIdx].collOpCount != ranksReSorted[rankIdx-1].collOpCount) { + // __builtin_clz returns the number of leading 0-bits. This makes it possible to translate the + // status (which is a bitmask) into an array index. + collOpCounts[nCollOpCounts].value = (sizeof(unsigned int)*8-1) - __builtin_clz(ranksReSorted[rankIdx].collOpCount); + collOpCounts[nCollOpCounts].count = 1; + collOpCounts[nCollOpCounts].firstIdx = rankIdx; + nCollOpCounts++; + } else { + collOpCounts[nCollOpCounts-1].count++; + } + } + if (comm->nRanks < comm->commNRanks) { + // Add a "fake" element corresponding to the missing entries. The statusStr array contains the "UNKNOWN" + // string at index 0. + collOpCounts[nCollOpCounts].value = 0; + collOpCounts[nCollOpCounts].count = comm->commNRanks - comm->nRanks; + collOpCounts[nCollOpCounts].firstIdx = -1; // "Fake" entry identifier. + nCollOpCounts++; + } + // Sort by that frequency (most frequent first). + qsort(collOpCounts, nCollOpCounts, sizeof(*collOpCounts), rasValCountsCompareRev); + + for (int coc = 0; coc < nCollOpCounts; coc++) { + struct rasValCount* vcc = collOpCounts+coc; + if (vcc->count > 1) + rasOutAppend(" %d ranks have status %s\n", vcc->count, statusStr[vcc->value]); + if (rasCountIsOutlier(vcc->count, client->verbose, comm->commNRanks)) { + if (vcc->firstIdx != -1) { + // ranksReSorted is sorted by rank as the secondary key, which comes in handy when printing... + for (int rankIdx = vcc->firstIdx; rankIdx < vcc->count+vcc->firstIdx; rankIdx++) { + int peerIdx = peerIdxConv[ranksReSorted[rankIdx].peerIdx]; + if (peerIdx != -1) { + if (vcc->count > 1) + rasOutAppend(" Rank %d -- GPU %s managed by process %d on node %s\n", + ranksReSorted[rankIdx].commRank, + rasCommRankGpuToString(ranksReSorted+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + else + rasOutAppend(" Rank %d has status %s -- GPU %s managed by process %d on node %s\n", + ranksReSorted[rankIdx].commRank, statusStr[vcc->value], + rasCommRankGpuToString(ranksReSorted+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + } else { // peerIdx == -1 + if (vcc->count > 1) + rasOutAppend(" Rank %d -- [process information not found]\n", ranksReSorted[rankIdx].commRank); + else + rasOutAppend(" Rank %d has status %s -- [process information not found]\n", + ranksReSorted[rankIdx].commRank, statusStr[vcc->value]); + } // peerIdx == -1 + } // for (rankIdx) + } else { + // UNKNOWN ranks. Format a string with their rank numbers (we don't know anything more). + lineBuf[0] = '\0'; + // rankIdx indexes the comm->ranks array; in principle it should be the same as commRank, with the + // exception of the missing ranks... + for (int commRank = 0, rankIdx = 0; commRank < comm->commNRanks; commRank++) { + if (rankIdx < comm->nRanks && comm->ranks[rankIdx].commRank == commRank) { + rankIdx++; + } else { + snprintf(lineBuf+strlen(lineBuf), sizeof(lineBuf)-strlen(lineBuf), "%s%d", + (rankIdx == commRank ? "" : ","), commRank); + } + } // for (commRank) + if (vcc->count > 1) { + rasOutAppend(" The unknown ranks: %s\n", lineBuf); + } else { + rasOutAppend(" Rank %s has status %s\n", lineBuf, statusStr[vcc->value]); + } + } + } // if (rasCountIsOutlier(vcc->count)) + } // for (coc) + } // if (__builtin_popcount(auxComm->status) > 1) + + inconsistent = false; + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + if (comm->ranks[rankIdx].collOpCount != auxComm->firstCollOpCount) { + inconsistent = true; + break; + } + } + if (inconsistent) { + rasOutAppend(" Communicator ranks have different collective operation counts\n"); + + // Sort the ranks by collOpCount and rank for easy counting. + memcpy(ranksReSorted, comm->ranks, comm->nRanks * sizeof(*ranksReSorted)); + qsort(ranksReSorted, comm->nRanks, sizeof(*ranksReSorted), rasCommRanksCollOpCompare); + // Calculate the frequency of different collOpCount values. + int nCollOpCounts = 0; + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + if (rankIdx == 0 || ranksReSorted[rankIdx].collOpCount != ranksReSorted[rankIdx-1].collOpCount) { + collOpCounts[nCollOpCounts].value = ranksReSorted[rankIdx].collOpCount; + collOpCounts[nCollOpCounts].count = 1; + collOpCounts[nCollOpCounts].firstIdx = rankIdx; + nCollOpCounts++; + } else { + collOpCounts[nCollOpCounts-1].count++; + } + } + // Sort by that frequency (most frequent first). + qsort(collOpCounts, nCollOpCounts, sizeof(*collOpCounts), rasValCountsCompareRev); + + for (int coc = 0; coc < nCollOpCounts; coc++) { + struct rasValCount* vcc = collOpCounts+coc; + if (vcc->count > 1) + rasOutAppend(" %d ranks have launched up to operation %ld\n", vcc->count, vcc->value); + if (rasCountIsOutlier(vcc->count, client->verbose, comm->commNRanks)) { + // ranksReSorted is sorted by rank as the secondary key, which comes in handy when printing... + for (int rankIdx = vcc->firstIdx; rankIdx < vcc->count+vcc->firstIdx; rankIdx++) { + int peerIdx = peerIdxConv[ranksReSorted[rankIdx].peerIdx]; + if (peerIdx != -1) { + if (vcc->count > 1) + rasOutAppend(" Rank %d -- GPU %s managed by process %d on node %s\n", + ranksReSorted[rankIdx].commRank, + rasCommRankGpuToString(ranksReSorted+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + else + rasOutAppend(" Rank %d has launched up to operation %ld -- GPU %s managed by process %d on node %s\n", + ranksReSorted[rankIdx].commRank, vcc->value, + rasCommRankGpuToString(ranksReSorted+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + } else { // peerIdx == -1 + if (vcc->count > 1) + rasOutAppend(" Rank %d -- [process information not found]\n", ranksReSorted[rankIdx].commRank); + else + rasOutAppend(" Rank %d has launched up to operation %ld -- [process information not found]\n", + ranksReSorted[rankIdx].commRank, vcc->value); + } // peerIdx == -1 + } // for (rankIdx) + } // if (rasCountIsOutlier(vcc->count)) + } // for (coc) + } // if (inconsistent) + rasOutAppend("\n"); + } // if (auxComm->errors & RAS_ACE_MISMATCH) + } // for (commIdx) + } // for (vcIdx) + rasCollFree(coll); + + msgLen = rasOutLength(); + NCCLCHECKGOTO(rasClientAllocMsg(&msg, msgLen), ret, fail); + rasOutExtract(msg); + rasClientEnqueueMsg(client, msg, msgLen); + msg = nullptr; +exit: + free(peerNvmlDevs); + free(collOpCounts); + free(valCounts); + free(peerIdxConv); + free(ranksReSorted); + free(auxComms); + return ret; +fail: + goto exit; +} + +static void rasClientBreakDownErrors(struct rasClient* client, struct rasCollComms::comm* comm, + const int* peerIdxConv, int ncclErrors[ncclNumResults], bool isAsync) { + for (;;) { + int maxCount = 0; + ncclResult_t maxCountIdx = ncclSuccess; + for (int i = ncclUnhandledCudaError; i < ncclInProgress; i++) { + if (maxCount < ncclErrors[i]) { + maxCount = ncclErrors[i]; + maxCountIdx = (ncclResult_t)i; + } + } // for (i) + if (maxCountIdx == ncclSuccess) + break; + if (maxCount > 1) + rasOutAppend(" %d ranks reported %s\n", maxCount, ncclErrorToString(maxCountIdx)); + if (rasCountIsOutlier(maxCount, client->verbose)) { + for (int rankIdx = 0; rankIdx < comm->nRanks; rankIdx++) { + if ((isAsync ? comm->ranks[rankIdx].status.asyncError : comm->ranks[rankIdx].status.initState) == maxCountIdx) { + int peerIdx = peerIdxConv[comm->ranks[rankIdx].peerIdx]; + if (peerIdx != -1) { + if (maxCount > 1) + rasOutAppend(" Rank %d -- GPU %s managed by process %d on node %s\n", + comm->ranks[rankIdx].commRank, + rasCommRankGpuToString(comm->ranks+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + else + rasOutAppend(" Rank %d reported %s -- GPU %s managed by process %d on node %s\n", + comm->ranks[rankIdx].commRank, ncclErrorToString(maxCountIdx), + rasCommRankGpuToString(comm->ranks+rankIdx, lineBuf, sizeof(lineBuf)), + rasPeers[peerIdx].pid, + ncclSocketToHost(&rasPeers[peerIdx].addr, rasLine, sizeof(rasLine))); + } else { // peerIdx == -1 + if (maxCount > 1) + rasOutAppend(" Rank %d -- [process information not found]\n", comm->ranks[rankIdx].commRank); + else + rasOutAppend(" Rank %d reported %s -- [process information not found]\n", + comm->ranks[rankIdx].commRank, ncclErrorToString(maxCountIdx)); + } // peerIdx == -1 + } // if rank's error matches + } // for (rankIdx) + } // if (rasCountIsOutlier(maxCount)) + ncclErrors[maxCountIdx] = 0; + } // for (;;) +} + + +////////////////////////////////////////////////////////////////////// +// Functions related to the handling of the internal output buffer. // +////////////////////////////////////////////////////////////////////// + +// Appends a printf-formatted string to the output buffer. +// Unlike with INFO or WARN messages, the caller should terminate lines with '\n' as appropriate. +static void rasOutAppend(const char* format, ...) { + ncclResult_t ret; // Ignored. + va_list vargs; + int needed; + va_start(vargs, format); + needed = vsnprintf(rasOutBuffer+nRasOutBuffer, rasOutBufferSize-nRasOutBuffer, format, vargs); + va_end(vargs); + + if (needed < 0) // Output error (whatever that might be...) + return; + + // The +1 below accounts for the terminating '\0'. + if (needed + 1 > rasOutBufferSize-nRasOutBuffer) { + int newBufferSize = ROUNDUP(nRasOutBuffer+needed+1, RAS_OUT_INCREMENT); + NCCLCHECKGOTO(ncclRealloc(&rasOutBuffer, rasOutBufferSize, newBufferSize), ret, exit); + rasOutBufferSize = newBufferSize; + + va_start(vargs, format); + needed = vsnprintf(rasOutBuffer+nRasOutBuffer, rasOutBufferSize-nRasOutBuffer, format, vargs); + va_end(vargs); + + if (needed < 0) // Output error (whatever that might be...) + return; + } + + nRasOutBuffer += needed; + assert(nRasOutBuffer <= rasOutBufferSize); +exit: + ; +} + +// Copies the output data from an internal buffer to a user-supplied one, including the terminating '\0'. +// The user buffer must already be allocated and be at least rasOutLength() bytes long (which includes +// the terminating '\0'). +static void rasOutExtract(char* buffer) { + if (rasOutBuffer) + memcpy(buffer, rasOutBuffer, rasOutLength()); +} + +// Returns the current length of the used portion of the output buffer, *not* including the terminating '\0'. +static int rasOutLength() { + return nRasOutBuffer; +} + +// Resets the output buffer position to the beginning (effectively clearing the buffer). +static void rasOutReset() { + ncclResult_t ret; // Ignored. + nRasOutBuffer = 0; + if (rasOutBuffer == nullptr) { + NCCLCHECKGOTO(ncclCalloc(&rasOutBuffer, RAS_OUT_INCREMENT), ret, exit); + rasOutBufferSize = RAS_OUT_INCREMENT; + } +exit: + ; +} + + +/////////////////////////////////////////////////////////////////// +// Various sorting callbacks used when grouping/formatting data. // +/////////////////////////////////////////////////////////////////// + +// Sorting callback for rasPeerInfo elements. Sorts by the number of bits set in cudaDevs. Uses the host IP as the +// secondary key and the process id as the tertiary key. +static int rasPeersNGpuCompare(const void* e1, const void* e2) { + const struct rasPeerInfo* p1 = (const struct rasPeerInfo*)e1; + const struct rasPeerInfo* p2 = (const struct rasPeerInfo*)e2; + int c1 = __builtin_popcountll(p1->cudaDevs); + int c2 = __builtin_popcountll(p2->cudaDevs); + + if (c1 == c2) { + // Host IP address is the secondary key. + int cmp = ncclSocketsHostCompare(&p1->addr, &p2->addr); + if (cmp == 0) { + // Process ID is the tertiary key. + cmp = (p1->pid < p2->pid ? -1 : (p1->pid > p2->pid ? 1 : 0)); + } + return cmp; + } else { + return (c1 < c2 ? -1 : 1); + } +} + +// Sorting callback for rasPeerInfo elements. Sorts by the number of peers per node, which we store in cudaDevs. +// Uses the host IP as the secondary key and the process id as the tertiary key. +static int rasPeersNProcsCompare(const void* e1, const void* e2) { + const struct rasPeerInfo* p1 = (const struct rasPeerInfo*)e1; + const struct rasPeerInfo* p2 = (const struct rasPeerInfo*)e2; + + if (p1->cudaDevs == p2->cudaDevs) { + // Host IP address is the secondary key. + int cmp = ncclSocketsHostCompare(&p1->addr, &p2->addr); + if (cmp == 0) { + // Process ID is the tertiary key. + cmp = (p1->pid < p2->pid ? -1 : (p1->pid > p2->pid ? 1 : 0)); + } + return cmp; + } else { + return (p1->cudaDevs < p2->cudaDevs ? -1 : 1); + } +} + +// Sorting callback for rasPeerInfo elements. Sorts by the host IP and the process id as the secondary key (rather +// than the port). +static int rasPeersHostPidCompare(const void* e1, const void* e2) { + const struct rasPeerInfo* p1 = (const struct rasPeerInfo*)e1; + const struct rasPeerInfo* p2 = (const struct rasPeerInfo*)e2; + + int cmp = ncclSocketsHostCompare(&p1->addr, &p2->addr); + if (cmp == 0) { + // Process ID is the secondary key. + cmp = (p1->pid < p2->pid ? -1 : (p1->pid > p2->pid ? 1 : 0)); + } + return cmp; +} + +// Sorting callback for ncclSocketAddress. Unlike the ncclSocketsCompare, it ignores the port. +static int ncclSocketsHostCompare(const void* p1, const void* p2) { + const union ncclSocketAddress* a1 = (const union ncclSocketAddress*)p1; + const union ncclSocketAddress* a2 = (const union ncclSocketAddress*)p2; + // AF_INET (2) is less than AF_INET6 (10). + int family = a1->sa.sa_family; + if (family != a2->sa.sa_family) { + if (family > 0 && a2->sa.sa_family > 0) + return (family < a2->sa.sa_family ? -1 : 1); + else // Put empty addresses at the end (not that it matters...). + return (family > 0 ? -1 : 1); + } + + int cmp; + if (family == AF_INET) { + cmp = memcmp(&a1->sin.sin_addr, &a2->sin.sin_addr, sizeof(a1->sin.sin_addr)); + } + else if (family == AF_INET6) { + cmp = memcmp(&a1->sin6.sin6_addr, &a2->sin6.sin6_addr, sizeof(a1->sin6.sin6_addr)); + } else { + // The only remaining valid case are empty addresses. + assert(family == 0); + cmp = 0; // Two empty addresses are equal... + } + + return cmp; +} + +// Sorting callback for rasValCount elements. Sorts by the count, largest first. Value is the secondary key. +static int rasValCountsCompareRev(const void* p1, const void* p2) { + const struct rasValCount* r1 = (const struct rasValCount*)p1; + const struct rasValCount* r2 = (const struct rasValCount*)p2; + + if (r1->count == r2->count) { + return (r1->value > r2->value ? -1 : (r1->value < r2->value ? 1: 0)); + } else { + return (r1->count > r2->count ? -1 : 1); + } +} + +// Sorting callback for rasAuxComm elements. +// Sorts the comms by the rank count (commNRanks), nNodes as secondary key, status as the tertiary, and errors as +// the quaternary. Sorts in reverse (largest first). +// The final key is the comm's nRanks, sorted in reverse to the other keys, so comms with the largest number +// of ranks *missing* will be first. +static int rasAuxCommsCompareRev(const void* p1, const void* p2) { + const struct rasAuxComm* c1 = (const struct rasAuxComm*)p1; + const struct rasAuxComm* c2 = (const struct rasAuxComm*)p2; + + if (c1->comm->commNRanks == c2->comm->commNRanks) { + if (c1->nNodes == c2->nNodes) { + // We don't want to compare the status values directly because they could be bitmasks and we are only + // interested in the highest bit set. + // __builtin_clz returns the number of leading 0-bits, so in our case the value will be the *smallest* + // if RAS_ACS_ABORT (8) is set and the *largest* if only RAS_ACS_INIT (1) is set, so we reverse the + // comparison to get the desired sorting order. + int s1 = __builtin_clz(c1->status); + int s2 = __builtin_clz(c2->status); + if (s1 == s2) { + if (c1->errors == c2->errors) { + if (c1->comm->nRanks == c2->comm->nRanks) { + return 0; + } else { + return (c1->comm->nRanks < c2->comm->nRanks ? -1 : 1); + } + } else { + return (c1->errors > c2->errors ? -1 : 1); + } + } else { + return (s1 < s2 ? -1 : 1); + } + } else { + return (c1->nNodes > c2->nNodes ? -1 : 1); + } + } else { + return (c1->comm->commNRanks > c2->comm->commNRanks ? -1 : 1); + } +} + +// Sorting callback for rasCollComms::comm::rank elements. Sorts by the peerIdx. +static int rasCommRanksPeerCompare(const void* p1, const void* p2) { + const struct rasCollComms::comm::rank* r1 = (const struct rasCollComms::comm::rank*)p1; + const struct rasCollComms::comm::rank* r2 = (const struct rasCollComms::comm::rank*)p2; + + return (r1->peerIdx < r2->peerIdx ? -1 : (r1->peerIdx > r2->peerIdx ? 1 : 0)); +} + +// Sorting callback for rasCollComms::comm::rank elements. Sorts by the collOpCount, with rank as the secondary key. +static int rasCommRanksCollOpCompare(const void* p1, const void* p2) { + const struct rasCollComms::comm::rank* r1 = (const struct rasCollComms::comm::rank*)p1; + const struct rasCollComms::comm::rank* r2 = (const struct rasCollComms::comm::rank*)p2; + + if (r1->collOpCount == r2->collOpCount) { + // Use the rank as the secondary key. + return (r1->commRank < r2->commRank ? -1 : (r1->commRank > r2->commRank ? 1 : 0)); + } else { + return (r1->collOpCount < r2->collOpCount ? -1 : 1); + } +} + + +//////////////////////////////////////////////////////////// +// String formatting functions for various types of data. // +//////////////////////////////////////////////////////////// + +// Coverts a GPU mask(s) to a string. If the CUDA mask is different from the NVML mask, both are printed. +const char* rasGpuDevsToString(uint64_t cudaDevs, uint64_t nvmlDevs, char* buf, size_t size) { + bool first = true; + buf[0] = '\0'; + for (int i = 0; i < NCCL_MAX_LOCAL_RANKS; i++) + if (cudaDevs & (1UL << i)) { + snprintf(buf+strlen(buf), size-strlen(buf), "%s%d", (first ? "" : ","), i); + first = false; + } + if (cudaDevs != nvmlDevs) { + snprintf(buf+strlen(buf), size-strlen(buf), " (NVML "); + first = true; + for (int i = 0; i < NCCL_MAX_LOCAL_RANKS; i++) + if (nvmlDevs & (1UL << i)) { + snprintf(buf+strlen(buf), size-strlen(buf), "%s%d", (first ? "" : ","), i); + first = false; + } + snprintf(buf+strlen(buf), size-strlen(buf), ")"); + } + return buf; +} + +// Formats a GPU string based on the rasCollComms's rank. If the CUDA id is different from the NVML id, both are +// printed. +static const char* rasCommRankGpuToString(const struct rasCollComms::comm::rank* rank, char* buf, size_t size) { + snprintf(buf, size, "%d", rank->cudaDev); + if (rank->cudaDev != rank->nvmlDev) { + snprintf(buf+strlen(buf), size-strlen(buf), " (NVML %d)", rank->nvmlDev); + } + return buf; +} + +// Converts a NCCL error result to a string. +static const char* ncclErrorToString(ncclResult_t err) { + switch (err) { + case ncclUnhandledCudaError : return "Unhandled CUDA error"; + case ncclSystemError : return "System error"; + case ncclInternalError : return "Internal error"; + case ncclInvalidArgument : return "Invalid argument"; + case ncclInvalidUsage : return "Invalid usage"; + case ncclRemoteError : return "Remote process error"; + case ncclInProgress : return "NCCL operation in progress"; + default : return "Unexpected error"; + } +} + +// Converts the IP number of a NCCL address to a string (the port part is ignored and no DNS resolution is attempted). +static const char* ncclSocketToHost(const union ncclSocketAddress* addr, char* buf, size_t size) { + if (addr->sa.sa_family > 0) + return inet_ntop(addr->sa.sa_family, + (addr->sa.sa_family == AF_INET ? (void*)&addr->sin.sin_addr : (void*)&addr->sin6.sin6_addr), + buf, size); + else { + if (size > 0) + buf[0] = '\0'; + return buf; + } +} + +// Determines if the given count constitutes an outlier. +static bool rasCountIsOutlier(int count, bool verbose, int totalCount) { + if (count == 1) + return true; // A single rank is always considered an outlier... + if (verbose) { + return (totalCount != -1 ? count < totalCount * RAS_CLIENT_VERBOSE_OUTLIER_FRACTION : true); + } else { + return count <= RAS_CLIENT_DETAIL_THRESHOLD && + (totalCount == -1 || count <= totalCount * RAS_CLIENT_OUTLIER_FRACTION); + } +} diff --git a/src/ras/collectives.cc b/src/ras/collectives.cc new file mode 100644 index 0000000000..201144f1a0 --- /dev/null +++ b/src/ras/collectives.cc @@ -0,0 +1,762 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#define NDEBUG // Comment out duriyng development only! +#include +#include + +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "nccl.h" +#include "utils.h" +#include "ras_internal.h" + +// The number of recent collectives to keep track of. Completely arbitrary. +#define COLL_HISTORY_SIZE 64 + +// An entry in the rasCollHistory array keeping track of recently completed collectives (to make it possible to +// identify and drop duplicates arriving over different links). +struct rasCollHistoryEntry { + union ncclSocketAddress rootAddr; + uint64_t rootId; +}; + +// Array keeping track of recently completed collectives (to avoid infinite loops). LRU-based replacement. +static struct rasCollHistoryEntry rasCollHistory[COLL_HISTORY_SIZE]; +static int nRasCollHistory, rasCollHistNextIdx; + +// Monotonically increased to ensure that each collective originating locally has a unique Id. +static uint64_t rasCollLastId; + +// Array keeping track of ongoing collective operations (apart from broadcasts, which have no response so require +// no such tracking). +struct rasCollective* rasCollectives; +static int nRasCollectives; + +static ncclResult_t getNewCollEntry(struct rasCollective** pColl); +static ncclResult_t rasLinkSendCollReq(struct rasLink* link, struct rasCollective* coll, + const struct rasCollRequest* req, size_t reqLen, int fromConnIdx); +static ncclResult_t rasConnSendCollReq(struct rasConnection* conn, const struct rasCollRequest* req, size_t reqLen); +static ncclResult_t rasCollReadyResp(struct rasCollective* coll); +static ncclResult_t rasConnSendCollResp(struct rasConnection* conn, + const union ncclSocketAddress* rootAddr, uint64_t rootId, + const union ncclSocketAddress* peers, int nPeers, + const char* data, int nData, int nLegTimeouts); + +static ncclResult_t rasCollConnsInit(char** pData, int* pNData); +static ncclResult_t rasCollConnsMerge(struct rasCollective* coll, struct rasMsg* msg); + +static ncclResult_t rasCollCommsInit(char** pData, int* pNData); +static ncclResult_t rasCollCommsMerge(struct rasCollective* coll, struct rasMsg* msg); +static int ncclCommsCompare(const void* p1, const void* p2); + + +/////////////////////////////////////////////////////////////////////////////////////// +// Functions related to the initialization of collectives and the message exchanges. // +/////////////////////////////////////////////////////////////////////////////////////// + +// Returns the index of the first available entry in the rasCollectives array, enlarging the array if necessary. +static ncclResult_t getNewCollEntry(struct rasCollective** pColl) { + struct rasCollective* coll; + int i; + for (i = 0; i < nRasCollectives; i++) + if (rasCollectives[i].type == RAS_MSG_NONE) + break; + if (i == nRasCollectives) { + NCCLCHECK(ncclRealloc(&rasCollectives, nRasCollectives, nRasCollectives+RAS_INCREMENT)); + nRasCollectives += RAS_INCREMENT; + } + + coll = rasCollectives+i; + memset(coll, '\0', sizeof(*coll)); + coll->startTime = clockNano(); + coll->fromConnIdx = -1; + // We are unlikely to use the whole array, but at least we won't need to realloc. + NCCLCHECK(ncclCalloc(&coll->fwdConns, nRasConns)); + + *pColl = coll; + return ncclSuccess; +} + +// Initializes a collective request by giving it a unique ID. +void rasCollReqInit(struct rasCollRequest* req) { + memcpy(&req->rootAddr, &rasNetListeningSocket.addr, sizeof(req->rootAddr)); + req->rootId = ++rasCollLastId; +} + +// Sends a collective request message through all regular RAS network connections (effectively, broadcasts it). +// Also used for re-broadcasts (on peers receiving the request over the network). +// Checking for duplicates is the responsibility of the caller. +// For collectives other than broadcasts, initializes a rasCollective structure and fills it with local data, +// in preparation for collective response messages. +// pAllDone indicates on return if the collective operation is already finished, which is unusual, but possible +// in scenarios such as a total of two peers. +// pCollIdx provides on return an index of the allocated rasCollective structure to track this collective (unless +// it's a broadcast, which require no such tracking). +ncclResult_t rasNetSendCollReq(const struct rasCollRequest* req, size_t reqLen, bool* pAllDone, int* pCollIdx, + int fromConnIdx) { + struct rasCollective* coll = nullptr; + if (req->type >= RAS_COLL_CONNS) { + // Keep track of this collective operation so that we can handle the responses appropriately. + NCCLCHECK(getNewCollEntry(&coll)); + if (pCollIdx) + *pCollIdx = coll-rasCollectives; + memcpy(&coll->rootAddr, &req->rootAddr, sizeof(coll->rootAddr)); + coll->rootId = req->rootId; + coll->type = req->type; + coll->timeout = req->timeout; + coll->fromConnIdx = fromConnIdx; + if (ncclCalloc(&coll->peers, 1) == ncclSuccess) { + memcpy(coll->peers, &rasNetListeningSocket.addr, sizeof(*coll->peers)); + coll->nPeers = 1; + } + + // Collective-specific initialization of accumulated data (using local data for now). + if (req->type == RAS_COLL_CONNS) + (void)rasCollConnsInit(&coll->data, &coll->nData); + else if (req->type == RAS_COLL_COMMS) + (void)rasCollCommsInit(&coll->data, &coll->nData); + } else { // req->type < RAS_COLL_CONNS + // Add the info to the collective message history. + nRasCollHistory = std::min(nRasCollHistory+1, COLL_HISTORY_SIZE); + memcpy(&rasCollHistory[rasCollHistNextIdx].rootAddr, &req->rootAddr, + sizeof(rasCollHistory[rasCollHistNextIdx].rootAddr)); + rasCollHistory[rasCollHistNextIdx].rootId = req->rootId; + rasCollHistNextIdx = (rasCollHistNextIdx + 1) % COLL_HISTORY_SIZE; + + // Collective-specific message handling. + if (req->type == RAS_BC_DEADPEER) { + bool done = false; + rasMsgHandleBCDeadPeer(req, &done); + if (done) + goto exit; + } + } // req->type < RAS_COLL_CONNS + + for (int connIdx = 0; connIdx < nRasConns; connIdx++) + rasConns[connIdx].linkFlag = false; + + (void)rasLinkSendCollReq(&rasNextLink, coll, req, reqLen, fromConnIdx); + (void)rasLinkSendCollReq(&rasPrevLink, coll, req, reqLen, fromConnIdx); + + if (coll && pAllDone) + *pAllDone = (coll->nFwdSent == coll->nFwdRecv); +exit: + return ncclSuccess; +} + +// Sends the collective message through all connections associated with this link (with the exception of the one +// the message came from, if any). +static ncclResult_t rasLinkSendCollReq(struct rasLink* link, struct rasCollective* coll, + const struct rasCollRequest* req, size_t reqLen, int fromConnIdx) { + for (int i = 0; i < link->nConns; i++) { + struct rasLinkConn* linkConn = link->conns+i; + if (linkConn->connIdx != -1 && linkConn->connIdx != fromConnIdx) { + struct rasConnection* conn = rasConns+linkConn->connIdx; + if (!conn->linkFlag) { + // We send collective messages through fully established and operational connections only. + if (conn->sockIdx != -1 && rasSockets[conn->sockIdx].status == RAS_SOCK_READY && !conn->experiencingDelays) { + if (rasConnSendCollReq(conn, req, reqLen) == ncclSuccess && coll != nullptr) + coll->fwdConns[coll->nFwdSent++] = linkConn->connIdx; + } // if (conn->sockIdx != -1 && RAS_SOCK_READY) + conn->linkFlag = true; + } // if (!conn->linkFlag) + } // if (linkConn->connIdx != -1 && linkConn->connIdx != fromConnIdx) + } // for (i) + + return ncclSuccess; +} + +// Sends a collective message down a particular connection. +static ncclResult_t rasConnSendCollReq(struct rasConnection* conn, const struct rasCollRequest* req, size_t reqLen) { + struct rasMsg* msg = nullptr; + int msgLen = rasMsgLength(RAS_MSG_COLLREQ) + reqLen; + + NCCLCHECK(rasMsgAlloc(&msg, msgLen)); + msg->type = RAS_MSG_COLLREQ; + memcpy(&msg->collReq, req, reqLen); + + rasConnEnqueueMsg(conn, msg, msgLen); + + return ncclSuccess; +} + +// Handles the RAS_MSG_COLLREQ collective message request on the receiver side. Primarily deals with duplicates and +// re-broadcasts the message to local peers, though in case of a very limited RAS network it might be done right away, +// in which case it can immediately send the response. +ncclResult_t rasMsgHandleCollReq(struct rasMsg* msg, struct rasSocket* sock) { + bool allDone = false; + int collIdx = -1; + assert(sock->connIdx != -1); + + // First check if we've already handled this request (through another connection). + for (int i = 0; i < nRasCollHistory; i++) { + // In principle we can use i to index the array but we convert it so that we check the most recent entries first. + int collHistIdx = (rasCollHistNextIdx + COLL_HISTORY_SIZE - 1 - i) % COLL_HISTORY_SIZE; + if (memcmp(&msg->collReq.rootAddr, &rasCollHistory[collHistIdx].rootAddr, sizeof(msg->collReq.rootAddr)) == 0 && + msg->collReq.rootId == rasCollHistory[collHistIdx].rootId) { + if (msg->collReq.type >= RAS_COLL_CONNS) { + // Send an empty response so that the sender can account for it. The non-empty response has already been + // sent through the connection that we received the request through first. + NCCLCHECK(rasConnSendCollResp(rasConns+sock->connIdx, &msg->collReq.rootAddr, msg->collReq.rootId, + /*peers*/nullptr, /*nPeers*/0, /*data*/nullptr, /*nData*/0, /*nLegTimeouts*/0)); + } + goto exit; + } + } // for (i) + + if (msg->collReq.type >= RAS_COLL_CONNS) { + // Check if we're currently handling this collective request. + for (int i = 0; i < nRasCollectives; i++) { + struct rasCollective* coll = rasCollectives+i; + if (coll->type != RAS_MSG_NONE && + memcmp(&msg->collReq.rootAddr, &coll->rootAddr, sizeof(msg->collReq.rootAddr)) == 0 && + msg->collReq.rootId == coll->rootId) { + assert(msg->collReq.type == coll->type); + + // Send an empty response so that the sender can account for it. The non-empty response will be + // sent through the connection that we received the request through first. + NCCLCHECK(rasConnSendCollResp(rasConns+sock->connIdx, &msg->collReq.rootAddr, msg->collReq.rootId, + /*peers*/nullptr, /*nPeers*/0, /*data*/nullptr, /*nData*/0, /*nLegTimeouts*/0)); + goto exit; + } // if match + } // for (i) + } // if (msg->collReq.type >= RAS_COLL_CONNS) + + // Re-broadcast the message to my peers (minus the one it came from) and handle it locally. + NCCLCHECK(rasNetSendCollReq(&msg->collReq, rasCollDataLength(msg->collReq.type), &allDone, &collIdx, sock->connIdx)); + + if (msg->collReq.type >= RAS_COLL_CONNS && allDone) { + assert(collIdx != -1); + // We are a leaf process -- send the response right away. This can probably trigger only for the case of a total + // of two peers, and hence just one RAS connection, or during communication issues, because normally every peer + // has more than one connection so there should always be _some_ other peer to forward the request to. + NCCLCHECK(rasCollReadyResp(rasCollectives+collIdx)); + } +exit: + return ncclSuccess; +} + +// Sends a collective response back to the process we received the collective request from. +// Invoked when we are finished waiting for the collective responses from other peers (i.e., either there weren't +// any peers (unlikely), the peers sent their responses (likely), or we timed out. +static ncclResult_t rasCollReadyResp(struct rasCollective* coll) { + if (coll->fromConnIdx != -1) { + // For remotely-initiated collectives, send the response back. + NCCLCHECK(rasConnSendCollResp(rasConns+coll->fromConnIdx, &coll->rootAddr, coll->rootId, + coll->peers, coll->nPeers, coll->data, coll->nData, coll->nLegTimeouts)); + + // Add the identifying info to the collective message history. + nRasCollHistory = std::min(nRasCollHistory+1, COLL_HISTORY_SIZE); + memcpy(&rasCollHistory[rasCollHistNextIdx].rootAddr, &coll->rootAddr, + sizeof(rasCollHistory[rasCollHistNextIdx].rootAddr)); + rasCollHistory[rasCollHistNextIdx].rootId = coll->rootId; + rasCollHistNextIdx = (rasCollHistNextIdx + 1) % COLL_HISTORY_SIZE; + + rasCollFree(coll); + } else { + // For locally-initiated collectives, invoke the client code again (which will release it, once finished). + NCCLCHECK(rasClientResume(coll)); + } + return ncclSuccess; +} + +// Sends a collective response via the connection we originally received the request from. The message should be +// a cumulative response from this process and all the processes that we forwarded the request to. +static ncclResult_t rasConnSendCollResp(struct rasConnection* conn, + const union ncclSocketAddress* rootAddr, uint64_t rootId, + const union ncclSocketAddress* peers, int nPeers, + const char* data, int nData, int nLegTimeouts) { + struct rasMsg* msg = nullptr; + int msgLen = rasMsgLength(RAS_MSG_COLLRESP) + nPeers*sizeof(*peers); + int dataOffset = 0; + + if (nData > 0) { + ALIGN_SIZE(msgLen, alignof(int64_t)); + dataOffset = msgLen; + msgLen += nData; + } + + NCCLCHECK(rasMsgAlloc(&msg, msgLen)); + msg->type = RAS_MSG_COLLRESP; + memcpy(&msg->collResp.rootAddr, rootAddr, sizeof(msg->collResp.rootAddr)); + msg->collResp.rootId = rootId; + msg->collResp.nLegTimeouts = nLegTimeouts; + msg->collResp.nPeers = nPeers; + msg->collResp.nData = nData; + if (nPeers) + memcpy(msg->collResp.peers, peers, nPeers*sizeof(*msg->collResp.peers)); + if (nData) + memcpy(((char*)msg)+dataOffset, data, nData); + + rasConnEnqueueMsg(conn, msg, msgLen); + + return ncclSuccess; +} + +// Handles the collective response on the receiver side. Finds the corresponding rasCollective structure, merges +// the data from the response into the accumulated data. If all the responses have been accounted for, sends the +// accumulated response back. +ncclResult_t rasMsgHandleCollResp(struct rasMsg* msg, struct rasSocket* sock) { + int collIdx; + struct rasCollective* coll = nullptr; + char line[SOCKET_NAME_MAXLEN+1]; + + for (collIdx = 0; collIdx < nRasCollectives; collIdx++) { + coll = rasCollectives+collIdx; + if (coll->type != RAS_MSG_NONE && + memcmp(&msg->collResp.rootAddr, &coll->rootAddr, sizeof(msg->collResp.rootAddr)) == 0 && + msg->collResp.rootId == coll->rootId) + break; + } + if (collIdx == nRasCollectives) { + INFO(NCCL_RAS, "RAS failed to find a matching ongoing collective for response %s:%ld from %s!", + ncclSocketToString(&msg->collResp.rootAddr, line), msg->collResp.rootId, + ncclSocketToString(&sock->sock.addr, rasLine)); + goto exit; + } + + coll->nLegTimeouts += msg->collResp.nLegTimeouts; + assert(sock->connIdx != -1); + // Account for the received response in our collective operation tracking. + for (int i = 0; i < coll->nFwdSent; i++) { + if (coll->fwdConns[i] == sock->connIdx) { + coll->fwdConns[i] = -1; + break; + } + } + coll->nFwdRecv++; + if (msg->collResp.nData > 0) { + // Collective-specific merging of the response into locally accumulated data. + if (coll->type == RAS_COLL_CONNS) + NCCLCHECK(rasCollConnsMerge(coll, msg)); + else if (coll->type == RAS_COLL_COMMS) + NCCLCHECK(rasCollCommsMerge(coll, msg)); + } + // We merge the peers after merging the data, so that the data merge function can rely on peers being unchanged. + if (msg->collResp.nPeers > 0) { + NCCLCHECK(ncclRealloc(&coll->peers, coll->nPeers, coll->nPeers + msg->collResp.nPeers)); + memcpy(coll->peers+coll->nPeers, msg->collResp.peers, msg->collResp.nPeers * sizeof(*coll->peers)); + coll->nPeers += msg->collResp.nPeers; + } + + // If we received all the data we were waiting for, send our response back. + if (coll->nFwdSent == coll->nFwdRecv) + NCCLCHECK(rasCollReadyResp(coll)); +exit: + return ncclSuccess; +} + +// Removes a connection from all ongoing collectives. Called when a connection is experiencing a delay or is being +// terminated. +void rasCollsPurgeConn(int connIdx) { + for (int i = 0; i < nRasCollectives; i++) { + struct rasCollective* coll = rasCollectives+i; + if (coll->type != RAS_MSG_NONE) { + char line[SOCKET_NAME_MAXLEN+1]; + if (coll->fromConnIdx == connIdx) { + INFO(NCCL_RAS, "RAS purging collective %s:%ld because it comes from %s", + ncclSocketToString(&coll->rootAddr, line), coll->rootId, + ncclSocketToString(&rasConns[connIdx].addr, rasLine)); + rasCollFree(coll); + } else { + for (int j = 0; j < coll->nFwdSent; j++) { + if (coll->fwdConns[j] == connIdx) { + coll->fwdConns[j] = -1; + coll->nFwdRecv++; + coll->nLegTimeouts++; + INFO(NCCL_RAS, "RAS not waiting for response from %s to collective %s:%ld " + "(nFwdSent %d, nFwdRecv %d, nLegTimeouts %d)", + ncclSocketToString(&rasConns[connIdx].addr, rasLine), ncclSocketToString(&coll->rootAddr, line), + coll->rootId, coll->nFwdSent, coll->nFwdRecv, coll->nLegTimeouts); + if (coll->nFwdSent == coll->nFwdRecv) + (void)rasCollReadyResp(coll); + break; + } + } // for (j) + } // coll->fromConnIdx != connIdx + } // !RAS_MSG_NONE + } // for (i) +} + +// Frees a rasCollective entry and any memory associated with it. +void rasCollFree(struct rasCollective* coll) { + free(coll->fwdConns); + coll->fwdConns = nullptr; + free(coll->peers); + coll->peers = nullptr; + free(coll->data); + coll->data = nullptr; + coll->fromConnIdx = -1; + coll->type = RAS_MSG_NONE; +} + +// Invoked from the main RAS thread loop to handle timeouts of the collectives. +// We obviously want to have a reasonable *total* timeout that the RAS client can rely on, but we don't have strict +// global coordination. So we have, in effect, two timeouts: soft (5s) and hard (10s). Soft equals the keep-alive +// timeout. +// When sending collective requests, we skip any connections that are experiencing delays. After the 5s timeout, we +// check again the status of all outstanding connections and if any is now delayed, we give up on it. +// That works fine for directly observable delays, but if the problematic connection is further away from us, all +// we can do is trust that the other peers will "do the right thing soon". However, if there is a cascade of +// problematic connections, they could still exceed the 5s total. So after 10s we give up waiting no matter what +// and send back whatever we have. Unfortunately, the peer that the RAS client is connected to will in all likelihood +// time out first, so at that point any delayed responses that eventually arrive are likely to be too late... +void rasCollsHandleTimeouts(int64_t now, int64_t* nextWakeup) { + for (int collIdx = 0; collIdx < nRasCollectives; collIdx++) { + struct rasCollective* coll = rasCollectives+collIdx; + if (coll->type == RAS_MSG_NONE || coll->timeout == 0) + continue; + + if (now - coll->startTime > coll->timeout) { + // We've exceeded the leg timeout. For all outstanding responses, check their connections. + if (!coll->timeoutWarned) { + INFO(NCCL_RAS, "RAS collective %s:%ld timeout warning (%lds) -- %d responses missing", + ncclSocketToString(&coll->rootAddr, rasLine), coll->rootId, + (now - coll->startTime) / CLOCK_UNITS_PER_SEC, coll->nFwdSent - coll->nFwdRecv); + coll->timeoutWarned = true; + } + for (int i = 0; i < coll->nFwdSent; i++) { + if (coll->fwdConns[i] != -1) { + struct rasConnection* conn = rasConns+coll->fwdConns[i]; + char line[SOCKET_NAME_MAXLEN+1]; + if (!conn->experiencingDelays && conn->sockIdx != -1) { + struct rasSocket* sock = rasSockets+conn->sockIdx; + // Ensure that the connection is fully established and operational, and that the socket hasn't been + // re-created during the handling of the collective (which would suggest that the request may have been + // lost). + if (sock->status == RAS_SOCK_READY && sock->createTime < coll->startTime) + continue; + } + // In all other cases we declare a timeout so that we can (hopefully) recover. + INFO(NCCL_RAS, "RAS not waiting for response from %s to collective %s:%ld " + "(nFwdSent %d, nFwdRecv %d, nLegTimeouts %d)", + ncclSocketToString(&conn->addr, rasLine), ncclSocketToString(&coll->rootAddr, line), + coll->rootId, coll->nFwdSent, coll->nFwdRecv, coll->nLegTimeouts); + coll->fwdConns[i] = -1; + coll->nFwdRecv++; + coll->nLegTimeouts++; + } // if (coll->fwdConns[i] != -1) + } // for (i) + if (coll->nFwdSent == coll->nFwdRecv) { + (void)rasCollReadyResp(coll); + } else { + // At least some of the delays are *not* due to this process' connections experiencing delays, i.e., they + // must be due to delays at other processes. Presumably those processes will give up waiting soon and the + // (incomplete) responses will arrive shortly, so we should wait a little longer. + if (now - coll->startTime > coll->timeout + RAS_COLLECTIVE_EXTRA_TIMEOUT) { + // We've exceeded even the longer timeout, which is unexpected. Try to return whatever we have (though + // the originator of the collective, if it's not us, may have timed out already anyway). + INFO(NCCL_RAS, "RAS collective %s:%ld timeout error (%lds) -- giving up on %d missing responses", + ncclSocketToString(&coll->rootAddr, rasLine), coll->rootId, + (now - coll->startTime) / CLOCK_UNITS_PER_SEC, coll->nFwdSent - coll->nFwdRecv); + coll->nLegTimeouts += coll->nFwdSent - coll->nFwdRecv; + coll->nFwdRecv = coll->nFwdSent; + (void)rasCollReadyResp(coll); + } else { + *nextWakeup = std::min(*nextWakeup, coll->startTime+coll->timeout+RAS_COLLECTIVE_EXTRA_TIMEOUT); + } + } // conn->nFwdRecv < conn->nFwdSent + } else { + *nextWakeup = std::min(*nextWakeup, coll->startTime+coll->timeout); + } + } // for (collIdx) +} + + +///////////////////////////////////////////////////////////////////////// +// Functions related to the handling of the RAS_COLL_CONNS collective. // +///////////////////////////////////////////////////////////////////////// + +// Initializes the accumulated data with just the local data for now. +// For this particular collective, we keep some reduced statistical data (min/max/avg travel time) as well +// as connection-specific info in case we observed a negative min travel time (which, ideally, shouldn't happen, +// but the system clocks may not be perfectly in sync). +static ncclResult_t rasCollConnsInit(char** pData, int* pNData) { + struct rasCollConns connsData = {.travelTimeMin = INT64_MAX, .travelTimeMax = INT64_MIN}; + struct rasCollConns* pConnsData; + + // Update the statistical data first and in the process also calculate how much connection-specific space we + // will need. + for (int i = 0; i < nRasConns; i++) { + struct rasConnection* conn = rasConns+i; + if (conn->inUse && conn->travelTimeCount > 0) { + if (connsData.travelTimeMin > conn->travelTimeMin) + connsData.travelTimeMin = conn->travelTimeMin; + if (connsData.travelTimeMax < conn->travelTimeMax) + connsData.travelTimeMax = conn->travelTimeMax; + connsData.travelTimeSum += conn->travelTimeSum; + connsData.travelTimeCount += conn->travelTimeCount; + connsData.nConns++; + if (conn->travelTimeMin < 0) + connsData.nNegativeMins++; + } + } + + *pNData = sizeof(connsData) + connsData.nNegativeMins*sizeof(*connsData.negativeMins); + NCCLCHECK(ncclCalloc(pData, *pNData)); + pConnsData = (struct rasCollConns*)*pData; + memcpy(pConnsData, &connsData, sizeof(*pConnsData)); + if (connsData.nNegativeMins > 0) { + for (int i = 0, negMinsIdx = 0; i < nRasConns; i++) { + struct rasConnection* conn = rasConns+i; + if (conn->inUse && conn->travelTimeMin < 0) { + struct rasCollConns::negativeMin* negativeMin = pConnsData->negativeMins+negMinsIdx; + memcpy(&negativeMin->source, &rasNetListeningSocket.addr, sizeof(negativeMin->source)); + memcpy(&negativeMin->dest, &conn->addr, sizeof(negativeMin->dest)); + negativeMin->travelTimeMin = conn->travelTimeMin; + negMinsIdx++; + } + assert(negMinsIdx <= connsData.nNegativeMins); + } + } + + return ncclSuccess; +} + +// Merges incoming collective RAS_COLL_CONNS response message into the local accumulated data. +static ncclResult_t rasCollConnsMerge(struct rasCollective* coll, struct rasMsg* msg) { + struct rasCollConns* collData; + struct rasCollConns* msgData; + int dataOffset = rasMsgLength(RAS_MSG_COLLRESP) + msg->collResp.nPeers*sizeof(*msg->collResp.peers); + ALIGN_SIZE(dataOffset, alignof(int64_t)); + + msgData = (struct rasCollConns*)(((char*)msg) + dataOffset); + collData = (struct rasCollConns*)coll->data; + + // Merge the stats. + if (collData->travelTimeMin > msgData->travelTimeMin) + collData->travelTimeMin = msgData->travelTimeMin; + if (collData->travelTimeMax < msgData->travelTimeMax) + collData->travelTimeMax = msgData->travelTimeMax; + collData->travelTimeSum += msgData->travelTimeSum; + collData->travelTimeCount += msgData->travelTimeCount; + collData->nConns += msgData->nConns; + + // Append the info about negative minimums. + if (msgData->nNegativeMins > 0) { + int nData = sizeof(*collData) + + (collData->nNegativeMins+msgData->nNegativeMins) * sizeof(*collData->negativeMins); + NCCLCHECK(ncclRealloc(&coll->data, coll->nData, nData)); + collData = (struct rasCollConns*)coll->data; + memcpy(coll->data+coll->nData, msgData->negativeMins, + msgData->nNegativeMins * sizeof(*collData->negativeMins)); + coll->nData = nData; + collData->nNegativeMins += msgData->nNegativeMins; + } + + return ncclSuccess; +} + + +///////////////////////////////////////////////////////////////////////// +// Functions related to the handling of the RAS_COLL_COMMS collective. // +///////////////////////////////////////////////////////////////////////// + +// Initializes the accumulated data with just the local data for now. +// For this particular collective, we keep for every communicator information about every rank, to help identify +// the missing ones and the discrepancies between the ones that did respond. +static ncclResult_t rasCollCommsInit(char** pData, int* pNData) { + struct rasCollComms* commsData; + int nComms = 0, nRanks = 0; + std::lock_guard lock(ncclCommsMutex); + + // Start by counting the communicators so that we know how much space to allocate. + // We also need to sort the comms array, to make the subsequent merging easier, both between the ranks (in case + // of multiple GPUs per process) and between the peers. + if (!ncclCommsSorted) { + qsort(ncclComms, nNcclComms, sizeof(*ncclComms), &ncclCommsCompare); + ncclCommsSorted = true; + } + for (int i = 0; i < nNcclComms; i++) { + if (ncclComms[i] == nullptr) // nullptr's are always at the end after sorting. + break; + if (i == 0) { + nComms = 1; + } else if (ncclComms[i]->commHash != ncclComms[i-1]->commHash) { + nComms++; + } + nRanks++; + } + + // rasNetCollCommsData has nested variable-length arrays, which makes the size calculation and subsequent + // pointer manipulations somewhat unwieldy... + *pNData = sizeof(*commsData) + nComms * sizeof(*commsData->comms) + nRanks * sizeof(*commsData->comms[0].ranks); + NCCLCHECK(ncclCalloc(pData, *pNData)); + commsData = (struct rasCollComms*)*pData; + commsData->nComms = nComms; + + // comm points at the space in the accumulated data where the info about the current communicator is to be stored. + struct rasCollComms::comm* comm = commsData->comms; + for (int i = 0; i < nNcclComms; i++) { + struct rasCollComms::comm::rank* rank; + ncclResult_t asyncError; + if (ncclComms[i] == nullptr) + break; + if (i == 0 || ncclComms[i]->commHash != ncclComms[i-1]->commHash) { + if (i > 0) + comm = (struct rasCollComms::comm*)(((char*)(comm+1)) + comm->nRanks * sizeof(*comm->ranks)); + comm->commHash = ncclComms[i]->commHash; + comm->commNRanks = ncclComms[i]->nRanks; + comm->nRanks = 0; + } else if (ncclComms[i]->nRanks != ncclComms[i-1]->nRanks) { + INFO(NCCL_RAS, "RAS encountered inconsistent communicator data: size %d != %d -- " + "possible commHash collision (0x%lx)", ncclComms[i-1]->nRanks, ncclComms[i]->nRanks, comm->commHash); + continue; // Short of failing, the best we can do is skip... + } else if (ncclComms[i]->rank == ncclComms[i-1]->rank) { + INFO(NCCL_RAS, "RAS encountered duplicate data for rank %d -- possible commHash collision (0x%lx)", + ncclComms[i]->rank, comm->commHash); + continue; // Short of failing, the best we can do is skip... + } + if (comm->nRanks == comm->commNRanks) { + INFO(NCCL_RAS, + "RAS encountered more ranks than the communicator size (%d) -- possible commHash collision (0x%lx)", + comm->commNRanks, comm->commHash); + continue; // Short of failing, the best we can do is skip... + } + rank = comm->ranks+comm->nRanks; + rank->commRank = ncclComms[i]->rank; + // rasNetSendCollReq initializes coll->peers[0] to our rasNetListeningSocket.addr, so peerIdx is initially + // always 0. It will increase after we send this response back to the peer we got the request from. + rank->peerIdx = 0; + rank->collOpCount = ncclComms[i]->collOpCount; + rank->status.initState = ncclComms[i]->initState; + if (ncclCommGetAsyncError(ncclComms[i], &asyncError) == ncclSuccess) + rank->status.asyncError = asyncError; + rank->status.finalizeCalled = (ncclComms[i]->finalizeCalled != 0); + rank->status.destroyFlag = (ncclComms[i]->destroyFlag != 0); + rank->status.abortFlag = (__atomic_load_n(ncclComms[i]->abortFlag, __ATOMIC_ACQUIRE) != 0); + rank->cudaDev = ncclComms[i]->cudaDev; + rank->nvmlDev = ncclComms[i]->nvmlDev; + comm->nRanks++; + } + assert(nComms == 0 || ((char*)(comm->ranks+comm->nRanks)) - (char*)commsData <= *pNData); + + return ncclSuccess; +} + +// Merges incoming collective RAS_COLL_COMMS response message into the local accumulated data. +static ncclResult_t rasCollCommsMerge(struct rasCollective* coll, struct rasMsg* msg) { + struct rasCollComms* collData; + struct rasCollComms* msgData; + int dataOffset = rasMsgLength(RAS_MSG_COLLRESP) + msg->collResp.nPeers*sizeof(*msg->collResp.peers); + ALIGN_SIZE(dataOffset, alignof(int64_t)); + + msgData = (struct rasCollComms*)(((char*)msg) + dataOffset); + collData = (struct rasCollComms*)coll->data; + + if (msgData->nComms > 0) { + struct rasCollComms* newData = nullptr; + + // Allocate the new buffer pessimistically (sized as the sum of the two old ones). + NCCLCHECK(ncclCalloc((char**)&newData, coll->nData + msg->collResp.nData)); + struct rasCollComms::comm* collComm = collData->comms; + struct rasCollComms::comm* msgComm = msgData->comms; + struct rasCollComms::comm* newComm = newData->comms; + + for (int collIdx = 0, msgIdx = 0; collIdx < collData->nComms || msgIdx < msgData->nComms; newData->nComms++) { + int cmp; + if (collIdx < collData->nComms && msgIdx < msgData->nComms) + cmp = (collComm->commHash < msgComm->commHash ? -1 : (collComm->commHash > msgComm->commHash ? 1 : 0)); + else + cmp = (collIdx < collData->nComms ? -1 : 1); + + if (cmp == 0 && collComm->commNRanks != msgComm->commNRanks) { + INFO(NCCL_RAS, "RAS encountered inconsistent communicator data: size %d != %d -- " + "possible commHash collision (0x%lx)", collComm->commNRanks, msgComm->commNRanks, collComm->commHash); + cmp = (collComm->commNRanks < msgComm->commNRanks ? -1 : 1); + // We try to preserve both separately, although the input data might already be messed up anyway... + } + + if (cmp == 0) { + // Merge the comms. + newComm->commHash = collComm->commHash; + newComm->commNRanks = collComm->commNRanks; + if (collComm->nRanks + msgComm->nRanks > collComm->commNRanks) { + INFO(NCCL_RAS, + "RAS encountered more ranks (%d) than the communicator size (%d) -- possible commHash collision (0x%lx)", + collComm->nRanks + msgComm->nRanks, newComm->commNRanks, newComm->commHash); + // We'll skip the extras in the loop below. + } else { + newComm->nRanks = collComm->nRanks + msgComm->nRanks; + } + // Merge the ranks. + for (int newRankIdx = 0, collRankIdx = 0, msgRankIdx = 0; + collRankIdx < collComm->nRanks || msgRankIdx < msgComm->nRanks; + newRankIdx++) { + int cmpRank; + if (newRankIdx == newComm->commNRanks) + break; // Short of failing, the best we can do is skip... + if (collRankIdx < collComm->nRanks && msgRankIdx < msgComm->nRanks) + cmpRank = (collComm->ranks[collRankIdx].commRank < msgComm->ranks[msgRankIdx].commRank ? -1 : + (collComm->ranks[collRankIdx].commRank > msgComm->ranks[msgRankIdx].commRank ? 1 : 0)); + else + cmpRank = (collRankIdx < collComm->nRanks ? -1 : 1); + + // There shouldn't be any overlaps in ranks between different sources. + if (cmpRank == 0) { + INFO(NCCL_RAS, "RAS encountered duplicate data for rank %d -- possible commHash collision (0x%lx)", + collComm->ranks[collRankIdx].commRank, newComm->commHash); + msgRankIdx++; // Short of failing, the best we can do is skip... + } + memcpy(newComm->ranks+newRankIdx, (cmpRank <= 0 ? collComm->ranks+collRankIdx++ : + msgComm->ranks+msgRankIdx++), sizeof(*newComm->ranks)); + if (cmpRank > 0) { + // peerIdx values from msgComm need to shift after merge. + newComm->ranks[newRankIdx].peerIdx += coll->nPeers; + } + } // for (newRankIdx) + newComm = (struct rasCollComms::comm*)(((char*)(newComm+1)) + newComm->nRanks * sizeof(*newComm->ranks)); + collComm = (struct rasCollComms::comm*)(((char*)(collComm+1)) + collComm->nRanks * sizeof(*collComm->ranks)); + collIdx++; + msgComm = (struct rasCollComms::comm*)(((char*)(msgComm+1)) + msgComm->nRanks * sizeof(*msgComm->ranks)); + msgIdx++; + } else if (cmp < 0) { + // Copy from collComm. + int commSize = sizeof(*collComm) + collComm->nRanks * sizeof(*collComm->ranks); + memcpy(newComm, collComm, commSize); + newComm = (struct rasCollComms::comm*)(((char*)(newComm)) + commSize); + collComm = (struct rasCollComms::comm*)(((char*)(collComm)) + commSize); + collIdx++; + } else { // cmp > 0 + // Copy from msgComm. + int commSize = sizeof(*msgComm) + msgComm->nRanks * sizeof(*msgComm->ranks); + memcpy(newComm, msgComm, commSize); + for (int i = 0; i < newComm->nRanks; i++) { + // peerIdx values from msgComm need to shift after merge. + newComm->ranks[i].peerIdx += coll->nPeers; + } + newComm = (struct rasCollComms::comm*)(((char*)(newComm)) + commSize); + msgComm = (struct rasCollComms::comm*)(((char*)(msgComm)) + commSize); + msgIdx++; + } // cmp > 0 + } // for (collIdx and msgIdx) + + free(coll->data); + coll->data = (char*)newData; + // newComm points at the next element beyond the last one -- exactly what we need. + coll->nData = ((char*)newComm) - (char*)newData; + } // if (msgData->nComms > 0) + + return ncclSuccess; +} + +// Sorting callback for the ncclComms array. +static int ncclCommsCompare(const void* p1, const void* p2) { + const ncclComm** pc1 = (const ncclComm**)p1; + const ncclComm** pc2 = (const ncclComm**)p2; + + // Put nullptr's at the end. + if (*pc1 == nullptr || *pc2 == nullptr) + return (*pc1 != nullptr ? -1 : (*pc2 != nullptr ? 1 : 0)); + + if ((*pc1)->commHash == (*pc2)->commHash) { + return ((*pc1)->rank < (*pc2)->rank ? -1 : ((*pc1)->rank > (*pc2)->rank ? 1 : 0)); + } else { + return ((*pc1)->commHash < (*pc2)->commHash ? -1 : 1); + } +} diff --git a/src/ras/peers.cc b/src/ras/peers.cc new file mode 100644 index 0000000000..f2692d3e17 --- /dev/null +++ b/src/ras/peers.cc @@ -0,0 +1,960 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#define NDEBUG // Comment out during development only! +#include + +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "nccl.h" +#include "ras_internal.h" + + +// All the known peer NCCL processes. The array is sorted by addr to ensure locality (within a node and hopefully +// also within a DC). The array may grow over time and it *includes* dead peers. +struct rasPeerInfo* rasPeers; +int nRasPeers; +// Hash of the rasPeers array, for figuring out when to sync with a remote peer. +uint64_t rasPeersHash; +// Index of this process within the rasPeers array (may change over time as the array grows). +static int myPeerIdx = -1; + +// Addresses of all the dead peers, sorted. In principle we could instead have a flag in rasPeerInfo for this, +// but we expect rasPeers to be largely static (and large at scale!) and rasDeadPeers to be fairly dynamic and +// much smaller, so we prefer to keep the dead info separately so that we don't end up sending the possibly large +// rasPeerInfo array around all the time. +union ncclSocketAddress* rasDeadPeers; +// The number of dead peers. +int nRasDeadPeers; +// The array size (may be larger than nRasDeadPeers). +static int rasDeadPeersSize; +// Hash of the rasDeadPeers array, for figuring out when to sync with a remote peer. +uint64_t rasDeadPeersHash; + +static ncclResult_t rasRanksConvertToPeers(struct rasRankInit* ranks, int nranks, + struct rasPeerInfo** rankPeers, int *nRankPeers, int* newNRasPeers); +static ncclResult_t rasPeersUpdate(struct rasPeerInfo* rankPeers, int* nRankPeers, int newNRasPeers = -1); + +static ncclResult_t rasNetUpdatePeers(const struct rasPeerInfo* newPeers, int nNewPeers, bool updateDeadPeers, + struct rasRankInit* ranks = nullptr, int nranks = 0, int fromConnIdx = -1); +static ncclResult_t rasLinkPropagateUpdate(struct rasLink* link, const struct rasPeerInfo* newPeers, int nNewPeers, + bool updateDeadPeers, struct rasRankInit* ranks, int nranks, + int fromConnIdx); +static ncclResult_t rasConnPropagateUpdate(struct rasConnection* conn, const struct rasPeerInfo* newPeers, + int nNewPeers, bool updateDeadPeers, struct rasRankInit* ranks, int nranks); +ncclResult_t rasMsgHandlePeersUpdate(struct rasMsg* msg, struct rasSocket* sock); + +static ncclResult_t rasLinkReinitConns(struct rasLink* link); + +static ncclResult_t rasDeadPeersUpdate(union ncclSocketAddress* updatePeers, int* nUpdatePeers); +static ncclResult_t getNewDeadEntry(union ncclSocketAddress** pAddr); + +static int rasAddrRankInitCompare(const void* k, const void* e); +static int rasAddrPeerInfoCompare(const void* k, const void* e); +static int rasRanksCompare(const void* e1, const void* e2); + +static void rasPeersDump(); +static void rasDeadPeersDump(); +static char* rasPeerDump(const struct rasPeerInfo* peer, char* result, size_t nres); + + +///////////////////////////////////////////////////////////////////////////// +// Functions related to the handling of local RAS_ADD_RANKS notifications. // +///////////////////////////////////////////////////////////////////////////// + +// Handles RAS_ADD_RANKS notification -- adds new ranks to the internal list of all RAS peers, reconfigures RAS +// network connections, and notifies the peers. +ncclResult_t rasLocalHandleAddRanks(struct rasRankInit* ranks, int nranks) { + ncclResult_t ret = ncclSuccess; + + INFO(NCCL_RAS, "RAS handling local addRanks request (old nRasPeers %d)", nRasPeers); + + // Convert the input rasRankInit structures into our internal rasPeerInfo. + struct rasPeerInfo* rankPeers = nullptr; + int nRankPeers; + int newNRasPeers; + NCCLCHECKGOTO(rasRanksConvertToPeers(ranks, nranks, &rankPeers, &nRankPeers, &newNRasPeers), ret, fail); + + // Update local rasPeers. + NCCLCHECKGOTO(rasPeersUpdate(rankPeers, &nRankPeers, newNRasPeers), ret, fail); + + INFO(NCCL_RAS, "RAS finished local processing of addRanks request (new nRasPeers %d, nRankPeers %d)", + nRasPeers, nRankPeers); + // Print peers only if something changed and we're the "root". + if (nRankPeers > 0 && memcmp(&ranks[0].addr, &rasNetListeningSocket.addr, sizeof(ranks[0].addr)) == 0) + rasPeersDump(); + + // Propagate the changes through our RAS network links. + NCCLCHECKGOTO(rasNetUpdatePeers(rankPeers, nRankPeers, /*updateDeadPeers*/false, ranks, nranks), ret, fail); + +exit: + if (rankPeers) + free(rankPeers); + free(ranks); + return ret; +fail: + goto exit; +} + +// Converts the rasRankInit structure into rasPeerInfo. This skips empty elements (in case of errors), orders +// elements by the address/cudaDev, and merges elements with duplicate addresses (in case of multiple CUDA devices per +// process). In the process we also calculate how large the merged rasPeers array will need to be. +static ncclResult_t rasRanksConvertToPeers(struct rasRankInit* ranks, int nranks, + struct rasPeerInfo** rankPeers, int *nRankPeers, int* newNRasPeers) { + ncclResult_t ret = ncclSuccess; + int peerIdx, rankPeerIdx; + + // Handy when checking for empty (in case of errors) addresses. + union ncclSocketAddress emptyAddr; + memset(&emptyAddr, '\0', sizeof(emptyAddr)); + + // Begin by sorting the array by address and cudaDev (to match the rasPeers order). + qsort(ranks, nranks, sizeof(*ranks), &rasRanksCompare); + + // We over-allocate peers here because to get an accurate count we would need to loop over the ranks first... + // nRankPeers will hold the actual count of used elements. + *rankPeers = nullptr; + NCCLCHECKGOTO(ncclCalloc(rankPeers, nranks), ret, fail); + + peerIdx = rankPeerIdx = 0; + *newNRasPeers = nRasPeers; + for (int rankIdx = 0; rankIdx < nranks; rankIdx++) { + const struct rasRankInit* rank = ranks+rankIdx; + struct rasPeerInfo* rankPeer = *rankPeers+rankPeerIdx; + + if (memcmp(&emptyAddr, &rank->addr, sizeof(emptyAddr)) == 0) { + // Skip empty rank entries. + continue; + } + + // First check if the rank doesn't need to be merged into the previous entry in rankPeers + // (possible if there are multiple ranks with the same address). + if (rankPeerIdx > 0 && memcmp(&rank->addr, &rankPeer[-1].addr, sizeof(rank->addr)) == 0) { + // Merge into the previous entry in peers. + rankPeer[-1].cudaDevs |= (1UL << rank->cudaDev); + rankPeer[-1].nvmlDevs |= (1UL << rank->nvmlDev); + continue; + } + + // Add a new entry to rankPeers. + assert(rankPeerIdx < nranks); + memcpy(&rankPeer->addr, &rank->addr, sizeof(rankPeer->addr)); + rankPeer->pid = rank->pid; + rankPeer->cudaDevs = (1UL << rank->cudaDev); + rankPeer->nvmlDevs = (1UL << rank->nvmlDev); + rankPeerIdx++; + + // Also check if there is already an entry with that address in the global rasPeers so that the caller can know how + // many more entries will be needed. + const struct rasPeerInfo* rasPeer = rasPeers+peerIdx; + int cmp = 0; + while (peerIdx < nRasPeers) { + cmp = ncclSocketsCompare(&rank->addr, &rasPeer->addr); + if (cmp <= 0) + break; + peerIdx++; + rasPeer++; + } + if (peerIdx == nRasPeers) { + // The current rank is "greater than" all existing peers, so it will need a new entry. We stay in the loop so + // that we don't need to handle the remaining ranks separately. + (*newNRasPeers)++; + continue; + } + if (cmp < 0) { + (*newNRasPeers)++; + } else { + // Duplicates (cmp == 0) between the rank array and the peers array will be merged. + assert(rank->pid == rasPeer->pid); + } + } + assert(peerIdx <= nRasPeers); + *nRankPeers = rankPeerIdx; + +exit: + return ret; +fail: + if (*rankPeers) { + free(*rankPeers); + *rankPeers = nullptr; + } + goto exit; +} + +// Updates the rasPeers array with the new data. The new data gets updated in the process as well: any data that +// wasn't actually new is purged, so as to minimize the amount of data we forward to our peers. +// On a successful return, nRankPeers contains the number of entries that were updated. +static ncclResult_t rasPeersUpdate(struct rasPeerInfo* rankPeers, int* nRankPeers, int newNRasPeers) { + ncclResult_t ret = ncclSuccess; + int rankPeerIdxDst; + int rankPeerIdx, peerIdx; + + if (newNRasPeers == -1) { + // First calculate the new size of rasPeers. + newNRasPeers = nRasPeers; + for (rankPeerIdx = peerIdx = 0; rankPeerIdx < *nRankPeers; rankPeerIdx++) { + struct rasPeerInfo* rankPeer = rankPeers+rankPeerIdx; + struct rasPeerInfo* rasPeer = rasPeers+peerIdx; + int cmp = 1; + + while (peerIdx < nRasPeers) { + cmp = ncclSocketsCompare(&rankPeer->addr, &rasPeer->addr); + + if (cmp < 0) { + // rankPeer will go in front of rasPeer. + newNRasPeers++; + break; + } + + peerIdx++; + rasPeer++; + + if (cmp == 0) + break; + } + if (cmp > 0) // No more rasPeer entries -- rankPeer will go at the end. + newNRasPeers++; + } + } + + // If needed, allocate a new, larger rasPeers array. + struct rasPeerInfo* newRasPeers; + int myNewPeerIdx; + if (newNRasPeers > nRasPeers) { + NCCLCHECKGOTO(ncclCalloc(&newRasPeers, newNRasPeers), ret, fail); + } else { + newRasPeers = rasPeers; + } + + // Now merge the rankPeers into newRasPeers. In the process, modify rankPeers to become a "diff" between + // the old rasPeers and newRasPeers -- this will be the data structure to broadcast on the RAS network. + myNewPeerIdx = -1; + int newPeerIdx; + for (newPeerIdx = rankPeerIdx = peerIdx = 0; rankPeerIdx < *nRankPeers || peerIdx < nRasPeers;) { + struct rasPeerInfo* rankPeer = rankPeers+rankPeerIdx; + struct rasPeerInfo* rasPeer = rasPeers+peerIdx; + struct rasPeerInfo* newRasPeer = newRasPeers+newPeerIdx; + + if (rankPeerIdx < *nRankPeers) { + if (peerIdx < nRasPeers) { + int cmp = ncclSocketsCompare(&rankPeer->addr, &rasPeer->addr); + + if (cmp < 0) { + // rankPeer needs to occur before rasPeer -- that's possible only if we are adding new entries. + assert(newRasPeers != rasPeers); + // Add new entry to newRasPeers. + assert(newPeerIdx < newNRasPeers); + memcpy(newRasPeer, rankPeer, sizeof(*newRasPeer)); + newPeerIdx++; + rankPeerIdx++; + } + else { + // cmp >= 0 -- Start by copying peer to newRasPeer, if needed. + if (newRasPeers != rasPeers) { + assert(newPeerIdx < newNRasPeers); + memcpy(newRasPeer, rasPeer, sizeof(*newRasPeer)); + } + else { // in-place + assert(newRasPeer == rasPeer); + } + + if (cmp == 0) { + // The address of rankPeer is the same as that of newRasPeer -- merge into it. + // First though calculate what GPUs from rankPeer are actually new (if any). + uint64_t newDevs = rankPeer->cudaDevs & ~newRasPeer->cudaDevs; + newRasPeer->cudaDevs |= rankPeer->cudaDevs; + // Update rankPeer->devs with the newly added devs only -- we'll clean it up at the end. + rankPeer->cudaDevs = newDevs; + // Repeat for nvmlDevs... + newDevs = rankPeer->nvmlDevs & ~newRasPeer->nvmlDevs; + newRasPeer->nvmlDevs |= rankPeer->nvmlDevs; + rankPeer->nvmlDevs = newDevs; + rankPeerIdx++; + } + // Given that we might've added new entries, we need to update myPeerIdx as well. + if (myPeerIdx == peerIdx) + myNewPeerIdx = newPeerIdx; + peerIdx++; + newPeerIdx++; + } + } else { // peerIdx == nRasPeers + // No more rasPeers -- add a new entry based on rank. + assert(newPeerIdx < newNRasPeers); + memcpy(newRasPeer, rankPeer, sizeof(*newRasPeer)); + // If this is the first time this function is run, myPeerIdx will need to be set. It's more work in that + // case as we need to compare the addresses of each peer until we find one. + if (myPeerIdx == -1 && memcmp(&newRasPeer->addr, &rasNetListeningSocket.addr, sizeof(newRasPeer->addr)) == 0) + myNewPeerIdx = newPeerIdx; + newPeerIdx++; + rankPeerIdx++; + } + } else { // rankPeerIdx == *nRankPeers + // No more rankPeers -- copy the rasPeer over if needed. + if (newRasPeers != rasPeers) { + assert(newPeerIdx < newNRasPeers); + memcpy(newRasPeer, rasPeer, sizeof(*newRasPeer)); + } + else { // in-place at the end. + assert(newRasPeer == rasPeer); + } + if (myPeerIdx == peerIdx) + myNewPeerIdx = newPeerIdx; + peerIdx++; + newPeerIdx++; + } + } + assert(newPeerIdx == newNRasPeers); + + if (newRasPeers != rasPeers) { + if (rasPeers) + free(rasPeers); + rasPeers = newRasPeers; + nRasPeers = newNRasPeers; + assert(myNewPeerIdx != -1); + myPeerIdx = myNewPeerIdx; + } else { + assert(myNewPeerIdx == myPeerIdx); + } + rasPeersHash = getHash((const char*)rasPeers, nRasPeers*sizeof(*rasPeers)); + + // Purge from rankPeers all entries that didn't actually contribute any new GPUs. + for (rankPeerIdx = rankPeerIdxDst = 0; rankPeerIdx < *nRankPeers; rankPeerIdx++) { + struct rasPeerInfo* rankPeer = rankPeers+rankPeerIdx; + if (rankPeer->cudaDevs != 0) { + if (rankPeerIdxDst != rankPeerIdx) { + memcpy(rankPeers+rankPeerIdxDst, rankPeer, sizeof(*rankPeers)); + } + rankPeerIdxDst++; + } + } + assert(rankPeerIdxDst <= *nRankPeers); + *nRankPeers = rankPeerIdxDst; + +exit: + return ret; +fail: + goto exit; +} + +// Searches through rasPeers given the peer address. Returns the index of the found entry in the rasPeers +// array or -1 if not found. +int rasPeerFind(const union ncclSocketAddress* addr) { + struct rasPeerInfo* peer = (struct rasPeerInfo*)bsearch(addr, rasPeers, nRasPeers, sizeof(*rasPeers), + rasAddrPeerInfoCompare); + return (peer ? peer-rasPeers : -1); +} + + +///////////////////////////////////////////////////////////////////////////////// +// Functions related to the propagation of peers updates over the RAS network. // +///////////////////////////////////////////////////////////////////////////////// + +// Propagates information about new peers through the RAS network links. +// ranks -- if provided -- lists all the peers who are already aware of this update (because they are the members +// of the new communicator being established), and who thus don't need to be notified. updatedDeadPeers can +// be used, however, to request at least the propagation of rasDeadPeers to such peers. +// fromConnIdx -- if provided -- identified the connection used to receive this update; there's no need to +// propagate the update back through it. +// Reconfigures the RAS network to accommodate the newly added peers, by modifying the links and establishing new +// connections as needed. +static ncclResult_t rasNetUpdatePeers(const struct rasPeerInfo* newPeers, int nNewPeers, bool updateDeadPeers, + struct rasRankInit* ranks, int nranks, int fromConnIdx) { + ncclResult_t ret = ncclSuccess; + + // Do we actually have anything to do? + if (nNewPeers == 0 && !updateDeadPeers) + goto exit; + + // Start by propagating the update through the RAS network links. We consider any errors during this process + // to be non-fatal (we can re-sync later around a keep-alive exchange). + (void)rasLinkPropagateUpdate(&rasNextLink, newPeers, nNewPeers, updateDeadPeers, ranks, nranks, fromConnIdx); + (void)rasLinkPropagateUpdate(&rasPrevLink, newPeers, nNewPeers, updateDeadPeers, ranks, nranks, fromConnIdx); + + // Calculate new link peers and open new connections if needed. + NCCLCHECKGOTO(rasLinkReinitConns(&rasNextLink), ret, fail); + NCCLCHECKGOTO(rasLinkReinitConns(&rasPrevLink), ret, fail); + +exit: + return ret; +fail: + goto exit; +} + +// Sends a peers update through all the connections associated with a particular link. See rasNetUpdatePeers +// for the explanation of the function arguments. +static ncclResult_t rasLinkPropagateUpdate(struct rasLink* link, const struct rasPeerInfo* newPeers, int nNewPeers, + bool updateDeadPeers, struct rasRankInit* ranks, int nranks, + int fromConnIdx) { + for (int i = 0; i < link->nConns; i++) { + struct rasLinkConn* linkConn = link->conns+i; + // Note that we don't send the update via the connection that we received this notification from in the first + // place (while it wouldn't loop indefinitely, it would add a needless extra exchange). + if (linkConn->connIdx != -1 && linkConn->connIdx != fromConnIdx) { + struct rasConnection* conn = rasConns+linkConn->connIdx; + // Failed propagations are not considered fatal (we will retry after a keep-alive). + (void)rasConnPropagateUpdate(conn, newPeers, nNewPeers, updateDeadPeers, ranks, nranks); + } + } + + return ncclSuccess; +} + +// Sends a peers update down a particular connection. See rasNetUpdatePeers for the explanation of the function +// arguments. +static ncclResult_t rasConnPropagateUpdate(struct rasConnection* conn, const struct rasPeerInfo* newPeers, + int nNewPeers, bool updateDeadPeers, struct rasRankInit* ranks, int nranks) { + if (conn->sockIdx != -1 && rasSockets[conn->sockIdx].status == RAS_SOCK_READY) { + // If we have the rank info, check if the peer on the other side of this connection has participated in the new + // communicator. + int connRank = -1; + if (ranks && !updateDeadPeers) { + struct rasRankInit* rank = (struct rasRankInit*)bsearch(&conn->addr, ranks, nranks, sizeof(*ranks), + rasAddrRankInitCompare); + if (rank) + connRank = rank-ranks; + } + if (connRank < 0) { + // It did not participate or we don't know -- we should send an update to that peer then. + NCCLCHECK(rasConnSendPeersUpdate(conn, newPeers, nNewPeers)); + } + } + + return ncclSuccess; +} + +// Sends a RAS_MSG_PEERSUPDATE message, which can include both the rasPeers (preferably only the newly added peers +// rather than the complete rasPeers array, to save on the network bandwidth) and rasDeadPeers (sent in its entirety +// if at all, as it's assumed to be a lot smaller than rasPeers). +ncclResult_t rasConnSendPeersUpdate(struct rasConnection* conn, const struct rasPeerInfo* peers, int nPeers) { + struct rasMsg* msg = nullptr; + int msgLen; + int deadPeersOffset = 0; + int nDeadPeers; + + if (conn->lastSentPeersHash == rasPeersHash || conn->lastRecvPeersHash == rasPeersHash) { + nPeers = 0; + } + if (conn->lastSentDeadPeersHash == rasDeadPeersHash || conn->lastRecvDeadPeersHash == rasDeadPeersHash) { + nDeadPeers = 0; + } else { + // We expect the rasDeadPeers array to be much smaller than rasPeers so if we send it, we send it in full. + nDeadPeers = nRasDeadPeers; + } + + if (nPeers == 0 && nDeadPeers == 0) + goto exit; + + msgLen = rasMsgLength(RAS_MSG_PEERSUPDATE) + nPeers*sizeof(*peers); + if (nDeadPeers > 0) { + ALIGN_SIZE(msgLen, alignof(union ncclSocketAddress)); + deadPeersOffset = msgLen; + msgLen += nDeadPeers*sizeof(*rasDeadPeers); + } + + NCCLCHECK(rasMsgAlloc(&msg, msgLen)); + msg->type = RAS_MSG_PEERSUPDATE; + msg->peersUpdate.peersHash = rasPeersHash; + msg->peersUpdate.nPeers = nPeers; + msg->peersUpdate.deadPeersHash = rasDeadPeersHash; + msg->peersUpdate.nDeadPeers = nDeadPeers; + memcpy(msg->peersUpdate.peers, peers, nPeers * sizeof(msg->peersUpdate.peers[0])); + memcpy(((char*)msg)+deadPeersOffset, rasDeadPeers, nDeadPeers * sizeof(*rasDeadPeers)); + + if (nPeers > 0) + conn->lastSentPeersHash = rasPeersHash; + if (nDeadPeers > 0) + conn->lastSentDeadPeersHash = rasDeadPeersHash; + + INFO(NCCL_RAS, "RAS sending a peersUpdate to %s (nPeers %d, nDeadPeers %d)", + ncclSocketToString(&conn->addr, rasLine), nPeers, nDeadPeers); + + rasConnEnqueueMsg(conn, msg, msgLen); +exit: + return ncclSuccess; +} + +// Handles the RAS_MSG_PEERSUPDATE message on the receiver side. The received data is merged into the local +// rasPeers and rasDeadPeers arrays. If the checksums of the resulting arrays don't match those from the message, +// sends its own RAS_MSG_PEERSUPDATE back to the source, to ensure a sync. +// Subsequently propagates the update to its own peers. +ncclResult_t rasMsgHandlePeersUpdate(struct rasMsg* msg, struct rasSocket* sock) { + ncclResult_t ret = ncclSuccess; + struct rasMsg* newMsg = nullptr; + int newMsgLen = 0; + assert(sock->connIdx != -1); + struct rasConnection* conn = rasConns+sock->connIdx; + int nPeers, nDeadPeers; + int deadPeersOffset = 0; + bool updatePeers, updateDeadPeers; + + INFO(NCCL_RAS, "RAS handling peersUpdate from %s (peersHash 0x%lx, deadPeersHash 0x%lx, nPeers %d, nDeadPeers %d)", + ncclSocketToString(&sock->sock.addr, rasLine), msg->peersUpdate.peersHash, msg->peersUpdate.deadPeersHash, + msg->peersUpdate.nPeers, msg->peersUpdate.nDeadPeers); + INFO(NCCL_RAS, "RAS my old rasPeersHash 0x%lx, rasDeadPeersHash 0x%lx, nRasPeers %d, nRasDeadPeers %d", + rasPeersHash, rasDeadPeersHash, nRasPeers, nRasDeadPeers); + conn->lastRecvPeersHash = msg->peersUpdate.peersHash; + conn->lastRecvDeadPeersHash = msg->peersUpdate.deadPeersHash; + + // Prepare ours to send back. We don't enqueue it right away because we want to make sure first that we need + // to send it. We'll find out by comparing the hash values after the merge. + // We want to prepare the message pre-merge though because post-merge it will include the just received new peers, + // and it's pointless to send those back to where they just came from. + // nPeers and nDeadPeers are used primarily for message length calculations, so they have to assume the worst-case + // scenario (e.g., no overlap in case of nDeadPeers). + nPeers = (msg->peersUpdate.peersHash != rasPeersHash ? nRasPeers : 0); + nDeadPeers = (msg->peersUpdate.deadPeersHash != rasDeadPeersHash ? nRasDeadPeers+msg->peersUpdate.nDeadPeers : 0); + if (nPeers > 0 || nDeadPeers > 0) { + newMsgLen = rasMsgLength(RAS_MSG_PEERSUPDATE) + nPeers*sizeof(*rasPeers); + if (nDeadPeers > 0) { + ALIGN_SIZE(newMsgLen, alignof(union ncclSocketAddress)); + newMsgLen += nDeadPeers*sizeof(*rasDeadPeers); + } + NCCLCHECKGOTO(rasMsgAlloc(&newMsg, newMsgLen), ret, fail); + newMsg->type = RAS_MSG_PEERSUPDATE; + // Note that after rasPeersUpdate below we may still decide not to send the peers. + memcpy(newMsg->peersUpdate.peers, rasPeers, nPeers * sizeof(newMsg->peersUpdate.peers[0])); + newMsg->peersUpdate.nPeers = nPeers; + + if (nDeadPeers > 0) { + // Calculate the offset where dead peers are stored in the received message. We do it before the peers + // update because it could modify msg->peersUpdate.nPeers... + deadPeersOffset = rasMsgLength(RAS_MSG_PEERSUPDATE) + msg->peersUpdate.nPeers * sizeof(msg->peersUpdate.peers[0]); + ALIGN_SIZE(deadPeersOffset, alignof(union ncclSocketAddress)); + } + + if (nPeers > 0) + NCCLCHECKGOTO(rasPeersUpdate(msg->peersUpdate.peers, &msg->peersUpdate.nPeers), ret, fail); + else + msg->peersUpdate.nPeers = 0; + if (nDeadPeers > 0) + NCCLCHECKGOTO(rasDeadPeersUpdate((union ncclSocketAddress*)(((char*)msg)+deadPeersOffset), + &msg->peersUpdate.nDeadPeers), ret, fail); + else + msg->peersUpdate.nDeadPeers = 0; + + INFO(NCCL_RAS, "RAS finished local processing of peersUpdate " + "(new nRasPeers %d, nRasDeadPeers %d, nPeers %d, nDeadPeers %d)", + nRasPeers, nRasDeadPeers, msg->peersUpdate.nPeers, msg->peersUpdate.nDeadPeers); + if (msg->peersUpdate.nPeers > 0) + rasPeersDump(); + if (msg->peersUpdate.nDeadPeers > 0) + rasDeadPeersDump(); + + // If post-merge the hashes are still different, send our (dead) peers back. + updatePeers = (conn->lastSentPeersHash != rasPeersHash && conn->lastRecvPeersHash != rasPeersHash); + updateDeadPeers = (conn->lastSentDeadPeersHash != rasDeadPeersHash && + conn->lastRecvDeadPeersHash != rasDeadPeersHash); + if (updatePeers || updateDeadPeers) { + newMsg->peersUpdate.peersHash = rasPeersHash; + newMsg->peersUpdate.deadPeersHash = rasDeadPeersHash; + if (updatePeers) { + assert(nPeers > 0); + conn->lastSentPeersHash = rasPeersHash; + } else { + // If hashes match, make sure that we don't send the rasPeers back. + newMsg->peersUpdate.nPeers = 0; + } + + // We need to recalculate the message size from scratch now that both rasPeers and rasDeadPeers may have changed. + newMsgLen = rasMsgLength(RAS_MSG_PEERSUPDATE) + newMsg->peersUpdate.nPeers * sizeof(*rasPeers); + + if (updateDeadPeers) { + assert(nRasDeadPeers > 0); + conn->lastSentDeadPeersHash = rasDeadPeersHash; + + ALIGN_SIZE(newMsgLen, alignof(union ncclSocketAddress)); + deadPeersOffset = newMsgLen; + newMsgLen += nRasDeadPeers*sizeof(*rasDeadPeers); + + memcpy(((char*)newMsg)+deadPeersOffset, rasDeadPeers, nDeadPeers * sizeof(*rasDeadPeers)); + conn->lastSentDeadPeersHash = rasDeadPeersHash; + newMsg->peersUpdate.nDeadPeers = nRasDeadPeers; + } else { + newMsg->peersUpdate.nDeadPeers = 0; + } + + INFO(NCCL_RAS, "RAS sending back a peersUpdate (nPeers %d, nDeadPeers %d)", + newMsg->peersUpdate.nPeers, newMsg->peersUpdate.nDeadPeers); + + rasConnEnqueueMsg(conn, newMsg, newMsgLen); + newMsg = nullptr; + } // if (updatePeers || updateDeadPeers) + + // Propagate the changes through our RAS network links. + NCCLCHECKGOTO(rasNetUpdatePeers(msg->peersUpdate.peers, msg->peersUpdate.nPeers, updateDeadPeers, nullptr, 0, + sock->connIdx), ret, fail); + } + +exit: + rasMsgFree(newMsg); + return ret; +fail: + goto exit; +} + + +////////////////////////////////////////////////////////////////////////////////////////// +// Functions related to the (re-)configuration of RAS connections after a peers update. // +////////////////////////////////////////////////////////////////////////////////////////// + +// Reinitializes the connection(s) of a particular link, following a peers update. +// Adding new peers can affect the calculation of the link's primary connection and also the fallbacks. +// The newly added peers could also shift all the existing peerIdx values, invalidating the values in RasLinkConn +// structures, so it's better to drop it all and recalculate from scratch. +// We recalculate the primary peer; if an active connection to it already exists, then we're done. If there +// is no connection, we create one. If a connection exists but is experiencing delays then we add a fallback and +// the process repeats. +// External conns are dropped from the links as well (they will be re-created via keepAlive messages as needed). +static ncclResult_t rasLinkReinitConns(struct rasLink* link) { + struct rasLinkConn* linkConn; + struct rasConnection* conn = nullptr; + int newPeerIdx = myPeerIdx; + + if (link->connsSize == 0) { + link->connsSize = RAS_INCREMENT; + NCCLCHECK(ncclCalloc(&link->conns, link->connsSize)); + } + link->nConns = 0; + + // Establish a connection for this link. We iterate as long as the connections we find are experiencing delays. + while (newPeerIdx != -1) { + if (link->nConns == link->connsSize) { + NCCLCHECK(ncclRealloc(&link->conns, link->connsSize, link->connsSize+RAS_INCREMENT)); + link->connsSize += RAS_INCREMENT; + } + + newPeerIdx = rasLinkCalculatePeer(link, newPeerIdx, /*isFallback*/link->nConns > 1); + if (newPeerIdx == -1) { + INFO(NCCL_RAS, "RAS link %d: no more fallbacks to add (nConns %d)", link->direction, link->nConns); + if (link->nConns > 0) + break; + } + linkConn = link->conns+link->nConns; + linkConn->peerIdx = newPeerIdx; + linkConn->connIdx = (newPeerIdx != -1 ? rasConnFind(&rasPeers[newPeerIdx].addr) : -1); + linkConn->external = false; + + // If the calculated connection does not exist, then we are at the end of the chain and this is the last iteration. + // Depending on the circumstances, we may first need to create that connection. + if (linkConn->connIdx == - 1) { + if (link->nConns == 0) { + if (linkConn->peerIdx != -1) { + INFO(NCCL_RAS, "RAS link %d: %s primary connection with %s", + link->direction, (myPeerIdx < linkConn->peerIdx ? "opening new" : "calculated deferred"), + ncclSocketToString(&rasPeers[linkConn->peerIdx].addr, rasLine)); + // We try to initiate primary connections from the side with a lower address (and thus an earlier peer index) + // to avoid races and the creation of duplicate connections. + if (myPeerIdx < linkConn->peerIdx) { + NCCLCHECK(rasConnCreate(&rasPeers[linkConn->peerIdx].addr, &linkConn->connIdx)); + } + else { // If we didn't initiate the connection, start the timeout. + link->lastUpdatePeersTime = clockNano(); + } + } // if (linkConn->peerIdx != -1) + } else { // link->nConns > 0 + INFO(NCCL_RAS, "RAS link %d: opening new fallback connection %d with %s", + link->direction, link->nConns, ncclSocketToString(&rasPeers[linkConn->peerIdx].addr, rasLine)); + NCCLCHECK(rasConnCreate(&rasPeers[newPeerIdx].addr, &linkConn->connIdx)); + } // link->nConns > 0 + } else { // linkConn->connIdx != -1 + if (link->nConns == 0) { + INFO(NCCL_RAS, "RAS link %d: calculated existing primary connection with %s", + link->direction, ncclSocketToString(&rasPeers[linkConn->peerIdx].addr, rasLine)); + } else { + INFO(NCCL_RAS, "RAS link %d: calculated existing fallback connection %d with %s", + link->direction, link->nConns, ncclSocketToString(&rasPeers[linkConn->peerIdx].addr, rasLine)); + } + } + link->nConns++; + if (linkConn->connIdx == -1) + break; + conn = rasConns+linkConn->connIdx; + + // We check if the connection already went through the fallback calculation; if so, we'll need to create a new + // fallback in the next iteration, to ensure that RAS will keep retrying. + if (!conn->experiencingDelays) + break; + + INFO(NCCL_RAS, "RAS connection experiencingDelays %d, startRetryTime %.2fs, socket status %d", + conn->experiencingDelays, (clockNano()-conn->startRetryTime)/1e9, + (conn->sockIdx == -1 ? -1 : rasSockets[conn->sockIdx].status)); + } + + return ncclSuccess; +} + +// Calculates the index of the peer on the RAS network. Can also be used to calculate the index of the next fallback +// peer. +// In the simplest case we want to try the "next closest" fallback, although we still need to check for and skip +// any dead peers. +// For fallbacks to fallbacks, we also apply a more pessimistic policy. We skip all the remaining RAS threads that +// are on the same node as the previous fallback (unless it's the same node that we're running on or we have strong +// indications that the node is up). We do that to avoid having to excessively wait iterating through, say, 8 +// processes when a whole node might be down. +int rasLinkCalculatePeer(const struct rasLink* link, int peerIdx, bool isFallback) { + int newPeerIdx = (peerIdx + link->direction + nRasPeers) % nRasPeers; + do { + if (isFallback && !ncclSocketsSameNode(&rasPeers[peerIdx].addr, &rasNetListeningSocket.addr)) { + // peerIdx is a fallback and it is not running on the same node as us. + int tryPeerIdx = newPeerIdx; + int tryConnIdx = -1; + + // Try to skip the remaining peers on the same node as peerIdx. We may end up skipping over some peers that + // are alive, which is fine -- they will still have connectivity with the rest of the RAS network, just a + // little suboptimal one. + while (ncclSocketsSameNode(&rasPeers[tryPeerIdx].addr, &rasPeers[peerIdx].addr)) { + if (!rasPeerIsDead(&rasPeers[tryPeerIdx].addr)) { + tryConnIdx = rasConnFind(&rasPeers[tryPeerIdx].addr); + if (tryConnIdx != -1) { + struct rasConnection* tryConn = rasConns+tryConnIdx; + // Check if the connection is fully established and operational, i.e., if the underlying socket + // is ready and there's been recent communication on it. + if (tryConn->sockIdx != -1 && rasSockets[tryConn->sockIdx].status == RAS_SOCK_READY && + !tryConn->experiencingDelays) { + // We convinced ourselves that the node is not down. We don't adjust newPeerIdx in + // this case. This is the only case when tryConnIdx != -1 after this loop. + break; + } + } // if (tryConnIdx != -1) + } // if (!rasPeerIsDead(&rasPeers[tryPeerIdx].addr)) + + tryConnIdx = -1; + tryPeerIdx = (tryPeerIdx + nRasPeers + link->direction) % nRasPeers; + if (tryPeerIdx == myPeerIdx) + break; + } + + if (tryConnIdx == -1) + newPeerIdx = tryPeerIdx; + if (tryPeerIdx == myPeerIdx) + break; + } // if (isFallback && !ncclSocketsSameNode(&rasPeers[peerIdx].addr, &rasNetListeningSocket.addr)) + + if (rasPeerIsDead(&rasPeers[newPeerIdx].addr)) { + newPeerIdx = (newPeerIdx + nRasPeers + link->direction) % nRasPeers; + } + else + break; + } while (newPeerIdx != myPeerIdx); + + return (newPeerIdx != myPeerIdx ? newPeerIdx : -1); +} + + +////////////////////////////////////////////////////// +// Functions related to the handling of dead peers. // +////////////////////////////////////////////////////// + +// Marks a peer as dead in the local rasDeadPeers array. Any propagation, reconfiguration, etc., needs to be +// handled outside of this function. +ncclResult_t rasPeerDeclareDead(const union ncclSocketAddress* addr) { + union ncclSocketAddress* deadAddr; + + if (!rasPeerIsDead(addr)) { + NCCLCHECK(getNewDeadEntry(&deadAddr)); + memcpy(deadAddr, addr, sizeof(*deadAddr)); + qsort(rasDeadPeers, nRasDeadPeers, sizeof(*rasDeadPeers), &ncclSocketsCompare); + + rasDeadPeersHash = getHash((const char*)rasDeadPeers, nRasDeadPeers*sizeof(*rasDeadPeers)); + + INFO(NCCL_RAS, "RAS declaring peer %s as DEAD; rasDeadPeersHash 0x%lx", + ncclSocketToString(addr, rasLine), rasDeadPeersHash); + } + return ncclSuccess; +} + +// Invoked when an incoming RAS_MSG_PEERSUPDATE includes info on dead peers. Updates the rasDeadPeers array. +// Any propagation needs to be handled outside of this function, though it *does* disconnect any connections +// with the newly dead peers. +// On return, nUpdatePeers contains the number of newly added dead entries. +static ncclResult_t rasDeadPeersUpdate(union ncclSocketAddress* updatePeers, int* nUpdatePeers) { + static union ncclSocketAddress* newPeers = nullptr; + static union ncclSocketAddress* oldPeers; + + if (*nUpdatePeers == 0) + return ncclSuccess; + + // Pessimistically estimate the new size of rasDeadPeers. + int nNewPeers = nRasDeadPeers + *nUpdatePeers; + if (nNewPeers > rasDeadPeersSize) { + nNewPeers = ROUNDUP(nNewPeers, RAS_INCREMENT); + + NCCLCHECK(ncclCalloc(&newPeers, nNewPeers)); + oldPeers = rasDeadPeers; + } else { + // We don't need to allocate a new array in this case. We just shift the existing content to the end of the + // array to make room in the front for merging. + oldPeers = rasDeadPeers+(rasDeadPeersSize-nRasDeadPeers); + memmove(oldPeers, rasDeadPeers, nRasDeadPeers*sizeof(*rasDeadPeers)); + newPeers = rasDeadPeers; + } + + // Merge updatePeers with oldPeers into newPeers. + int oldPeersIdx, updatePeersIdx, newPeersIdx; + for (oldPeersIdx = updatePeersIdx = newPeersIdx = 0; oldPeersIdx < nRasDeadPeers || updatePeersIdx < *nUpdatePeers;) { + int cmp; + if (oldPeersIdx < nRasDeadPeers && updatePeersIdx < *nUpdatePeers) { + cmp = ncclSocketsCompare(oldPeers+oldPeersIdx, updatePeers+updatePeersIdx); + } else { + cmp = (oldPeersIdx < nRasDeadPeers ? -1 : 1); + } + + memmove(newPeers+newPeersIdx++, (cmp <= 0 ? oldPeers+oldPeersIdx : updatePeers+updatePeersIdx), sizeof(*newPeers)); + if (cmp <= 0) + oldPeersIdx++; + if (cmp > 0) { + rasConnDisconnect(updatePeers+updatePeersIdx); + } + if (cmp >= 0) + updatePeersIdx++; + } + *nUpdatePeers = newPeersIdx - nRasDeadPeers; + nRasDeadPeers = newPeersIdx; + + if (newPeers != rasDeadPeers) { + free(rasDeadPeers); + rasDeadPeers = newPeers; + rasDeadPeersSize = nNewPeers; + } + + rasDeadPeersHash = getHash((const char*)rasDeadPeers, nRasDeadPeers*sizeof(*rasDeadPeers)); + + return ncclSuccess; +} + +// Returns the index of the first available entry in the rasDeadPeers array, enlarging the array if necessary. +static ncclResult_t getNewDeadEntry(union ncclSocketAddress** pAddr) { + if (nRasDeadPeers == rasDeadPeersSize) { + NCCLCHECK(ncclRealloc(&rasDeadPeers, rasDeadPeersSize, rasDeadPeersSize+RAS_INCREMENT)); + rasDeadPeersSize += RAS_INCREMENT; + } + + *pAddr = rasDeadPeers+(nRasDeadPeers++); + return ncclSuccess; +} + +// Checks whether a peer is dead by looking it up in the rasDeadPeers array. +bool rasPeerIsDead(const union ncclSocketAddress* addr) { + return (rasDeadPeers != nullptr && + bsearch(addr, rasDeadPeers, nRasDeadPeers, sizeof(*rasDeadPeers), ncclSocketsCompare) != nullptr); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Auxiliary functions -- primarily sorting/searching callbacks, plus some debug output support. // +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Searching callback for struct rasRankInit. Compares the ncclSocketAddress key against a rasRankInit element. +static int rasAddrRankInitCompare(const void* k, const void* e) { + const union ncclSocketAddress* key = (const union ncclSocketAddress*)k; + const struct rasRankInit* elem = (const struct rasRankInit*)e; + + return ncclSocketsCompare(key, &elem->addr); +} + +// Searching callback for struct rasPeerInfo. Compares the ncclSocketAddress key against a rasPeerInfo element. +static int rasAddrPeerInfoCompare(const void* k, const void* e) { + const union ncclSocketAddress* key = (const union ncclSocketAddress*)k; + const struct rasPeerInfo* elem = (const struct rasPeerInfo*)e; + + return ncclSocketsCompare(key, &elem->addr); +} + +// Sorting callback for struct rasRankInit. addr is the primary key; cudaDev is secondary. +static int rasRanksCompare(const void* e1, const void* e2) { + const struct rasRankInit* r1 = (const struct rasRankInit*)e1; + const struct rasRankInit* r2 = (const struct rasRankInit*)e2; + int cmp = ncclSocketsCompare(&r1->addr, &r2->addr); + if (cmp == 0) { + if (r1->addr.sa.sa_family == 0) // Bail out in case of empty addresses... + return 0; + assert(r1->pid == r2->pid); + cmp = (r1->cudaDev < r2->cudaDev ? -1 : (r1->cudaDev > r2->cudaDev ? 1 : 0)); + assert(cmp != 0); // There should be no complete duplicates within the rank array. + } + return cmp; +} + +// Sorting callback for ncclSocketAddress. We want to sort by the address family (IPv4 first), then the address, +// then port. Unfortunately, that's not the order of how they are laid out in memory, so one big memcmp won't do. +// memcmp is still useful though for individual elements in the network byte order. +int ncclSocketsCompare(const void* p1, const void* p2) { + const union ncclSocketAddress* a1 = (const union ncclSocketAddress*)p1; + const union ncclSocketAddress* a2 = (const union ncclSocketAddress*)p2; + // AF_INET (2) is less than AF_INET6 (10). + int family = a1->sa.sa_family; + if (family != a2->sa.sa_family) { + if (family > 0 && a2->sa.sa_family > 0) + return (family < a2->sa.sa_family ? -1 : 1); + else // Put empty addresses at the end (not that it matters...). + return (family > 0 ? -1 : 1); + } + + int cmp; + if (family == AF_INET) { + if ((cmp = memcmp(&a1->sin.sin_addr, &a2->sin.sin_addr, sizeof(a1->sin.sin_addr))) == 0) { + cmp = memcmp(&a1->sin.sin_port, &a2->sin.sin_port, sizeof(a1->sin.sin_port)); + } + } + else if (family == AF_INET6) { + if ((cmp = memcmp(&a1->sin6.sin6_addr, &a2->sin6.sin6_addr, sizeof(a1->sin6.sin6_addr))) == 0) { + cmp = memcmp(&a1->sin6.sin6_port, &a2->sin6.sin6_port, sizeof(a1->sin6.sin6_port)); + } + } else { + // The only remaining valid case are empty addresses. + assert(family == 0); + cmp = 0; // Two empty addresses are equal... + } + + return cmp; +} + +// Returns true if two socket addresses are from the same node (actually, the same network interface on one node). +bool ncclSocketsSameNode(const union ncclSocketAddress* a1, const union ncclSocketAddress* a2) { + // AF_INET (2) is less than AF_INET6 (10). + int family = a1->sa.sa_family; + if (family != a2->sa.sa_family) + return false; + + if (family == AF_INET) + return (memcmp(&a1->sin.sin_addr, &a2->sin.sin_addr, sizeof(a1->sin.sin_addr)) == 0); + else if (family == AF_INET6) + return (memcmp(&a1->sin6.sin6_addr, &a2->sin6.sin6_addr, sizeof(a1->sin6.sin6_addr)) == 0); + else + return true; // Two empty addresses are equal... +} + +// Debug output routine: dumps the rasPeers array. +static void rasPeersDump() { + for (int p = 0; p < nRasPeers; p++) { + const struct rasPeerInfo* peer = rasPeers+p; + INFO(NCCL_RAS, "RAS peer %d: %s%s", p, rasPeerDump(peer, rasLine, sizeof(rasLine)), (p == myPeerIdx ? " [this process]" : "")); + } + if (nRasPeers > 0) + INFO(NCCL_RAS, "RAS peersHash 0x%lx", rasPeersHash); +} + +// Debug output routine: dumps the rasDeadPeers array. +static void rasDeadPeersDump() { + for (int p = 0; p < nRasDeadPeers; p++) { + int deadPeerIdx = rasPeerFind(rasDeadPeers+p); + INFO(NCCL_RAS, "RAS dead peer %d: %s", p, + (deadPeerIdx >= 0 ? rasPeerDump(rasPeers+deadPeerIdx, rasLine, sizeof(rasLine)) : + ncclSocketToString(rasDeadPeers+p, rasLine))); + } + if (nRasDeadPeers > 0) + INFO(NCCL_RAS, "RAS deadPeersHash 0x%lx", rasDeadPeersHash); +} + +// Debug output routine: dumps part of an individual element from the rasPeers array. +static char* rasPeerDump(const struct rasPeerInfo* peer, char* result, size_t nres) { + char line[SOCKET_NAME_MAXLEN+1], line2[1024]; + snprintf(result, nres, "socket %s, pid %d, GPU%s %s", ncclSocketToString(&peer->addr, line), peer->pid, + (__builtin_popcountll(peer->cudaDevs) > 1 ? "s" : ""), + rasGpuDevsToString(peer->cudaDevs, peer->nvmlDevs, line2, sizeof(line2))); + return result; +} diff --git a/src/ras/ras.cc b/src/ras/ras.cc new file mode 100644 index 0000000000..4905d7a69c --- /dev/null +++ b/src/ras/ras.cc @@ -0,0 +1,668 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#define NDEBUG // Comment out during development only! +#include +#include +#include +#include +#include + +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "nccl.h" +#include "utils.h" +#include "ras_internal.h" + +// Type of a notification from a local NCCL thread. +typedef enum { + RAS_ADD_RANKS = 0, + RAS_TERMINATE = 1 +} rasNotificationType; + +// Used for communication from local NCCL threads to the RAS thread. +struct rasNotification { + rasNotificationType type; + union { + struct { + struct rasRankInit* ranks; + int nranks; + } addRanks; + }; +}; +static_assert(sizeof(struct rasNotification) <= PIPE_BUF, "The rasNotification structure is too large"); + +// These ensure that we get only one RAS port/thread per process. +static std::mutex rasInitMutex; +static bool rasInitialized = false; +static int rasInitRefCount = 0; + +// The RAS network listening socket of this RAS thread (random port). +struct ncclSocket rasNetListeningSocket; + +static pthread_t rasThread; + +// Used for communication from regular NCCL threads to the RAS thread. +static std::mutex rasNotificationMutex; +static int rasNotificationPipe[2] = {-1, -1}; + +// Data for the main poll() in the RAS thread. +struct pollfd* rasPfds; +static int nRasPfds; + +// We use it all over the place; no point in wasting the stack... +char rasLine[SOCKET_NAME_MAXLEN+1]; + +// An array holding the addresses of all NCCL communicators. Modified by the NCCL threads (hence the mutex), read by +// the RAS thread. +std::mutex ncclCommsMutex; +struct ncclComm** ncclComms = nullptr; +int nNcclComms = 0; +bool ncclCommsSorted = false; // Whether the array is currently sorted. We sort by the comms' commHash and rank. + +static ncclResult_t rasLocalNotify(const struct rasNotification* msg); +static ncclResult_t rasLocalHandle(); +static void rasLocalHandleTerminate(); + +static ncclResult_t rasMsgHandleConnInit(const struct rasMsg* msg, struct rasSocket* sock); +static ncclResult_t rasMsgHandleConnInitAck(const struct rasMsg* msg, struct rasSocket* sock); +static ncclResult_t rasNetSendNack(struct rasSocket* sock); + +static void* rasThreadMain(void*); + +NCCL_PARAM(RasTimeoutFactor, "RAS_TIMEOUT_FACTOR", 1); + +////////////////////////////////////////////////// +// Functions invoked from regular NCCL threads. // +////////////////////////////////////////////////// + +// Invoked by regular NCCL threads on every comm initialization. This is the first function to call. +// The myRank structure should be passed with the addr element initialized to the IP address of the bootstrap +// network interface to use. On a successful return, the address will be updated with the port number of the +// RAS network listening socket. +ncclResult_t ncclRasCommInit(struct ncclComm* comm, struct rasRankInit* myRank) { + ncclResult_t ret = ncclSuccess; + if (!rasInitialized) { + std::lock_guard lock(rasInitMutex); + if (!rasInitialized) { + union ncclSocketAddress addr; + + memcpy(&addr, &myRank->addr, sizeof(addr)); + (addr.sa.sa_family == AF_INET ? addr.sin.sin_port : addr.sin6.sin6_port) = htons(0); + NCCLCHECKGOTO(ncclSocketInit(&rasNetListeningSocket, &addr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, + /*abortFlag*/nullptr, /*asyncFlag*/1), ret, fail); + NCCLCHECKGOTO(ncclSocketListen(&rasNetListeningSocket), ret, fail); + INFO(NCCL_RAS, "RAS network listening socket at %s", + ncclSocketToString(&rasNetListeningSocket.addr, rasLine)); + + (void)rasClientInitSocket(); + + SYSCHECKGOTO(pipe(rasNotificationPipe), "pipe", ret, fail); + + PTHREADCHECKGOTO(pthread_create(&rasThread, nullptr, &rasThreadMain, nullptr), "pthread_create", ret, fail); + ncclSetThreadName(rasThread, "NCCL RAS"); + (void)pthread_detach(rasThread); + + rasInitialized = true; + } + } + ncclAtomicRefCountIncrement(&rasInitRefCount); + + { + std::lock_guard lock(ncclCommsMutex); + + int i; + for (i = 0; i < nNcclComms; i++) { + if (ncclComms[i] == nullptr) + break; + } + if (i == nNcclComms) { + NCCLCHECK(ncclRealloc(&ncclComms, nNcclComms, nNcclComms+RAS_INCREMENT*8)); + nNcclComms += RAS_INCREMENT*8; + } + ncclComms[i] = comm; + ncclCommsSorted = false; + } + + if (myRank != nullptr) + memcpy(&myRank->addr, &rasNetListeningSocket.addr, sizeof(myRank->addr)); + +exit: + return ret; +fail: + if (rasNotificationPipe[1] != 0) + (void)close(rasNotificationPipe[1]); + if (rasNotificationPipe[0] != 0) + (void)close(rasNotificationPipe[0]); + (void)close(rasClientListeningSocket); + (void)ncclSocketClose(&rasNetListeningSocket); + goto exit; +} + +// Invoked by regular NCCL threads on every comm termination. +ncclResult_t ncclRasCommFini(const struct ncclComm* comm) { + if (!rasInitialized) + return ncclSuccess; + { + std::lock_guard lock(ncclCommsMutex); + for (int i = 0; i < nNcclComms; i++) { + if (ncclComms[i] == comm) { + ncclComms[i] = nullptr; + ncclCommsSorted = false; + break; + } + } + } + if (ncclAtomicRefCountDecrement(&rasInitRefCount) == 0) { + struct rasNotification msg; + msg.type = RAS_TERMINATE; + NCCLCHECK(rasLocalNotify(&msg)); + } + return ncclSuccess; +} + +// Invoked by regular NCCL threads on every (non-split) comm initialization. Provides info on all the ranks within +// the communicator. +ncclResult_t ncclRasAddRanks(struct rasRankInit* ranks, int nranks) { + struct rasNotification msg; + msg.type = RAS_ADD_RANKS; + msg.addRanks.ranks = ranks; + msg.addRanks.nranks = nranks; + NCCLCHECK(rasLocalNotify(&msg)); + return ncclSuccess; +} + +// Internal function running on regular NCCL threads -- asynchronously notifies the RAS thread. +static ncclResult_t rasLocalNotify(const struct rasNotification* msg) { + if (!rasInitialized) + return ncclSuccess; + + // Take an exclusive lock here to avoid multiplexing between multiple user threads (not sure if it's + // strictly required, but it won't hurt)... + std::lock_guard lock(rasNotificationMutex); + size_t done = 0; + while (done < sizeof(*msg)) { + ssize_t written; + SYSCHECK(written = write(rasNotificationPipe[1], (char*)msg + done, sizeof(*msg) - done), "write"); + done += written; + } + return ncclSuccess; +} + + +///////////////////////////////////////////////////////////////////////////////// +// Functions related to the handling of local notifications from NCCL threads. // +///////////////////////////////////////////////////////////////////////////////// + +// Handles asynchronous local notifications arriving from regular NCCL threads. +static ncclResult_t rasLocalHandle() { + struct rasNotification msg; + + size_t done = 0; + while (done < sizeof(msg)) { + ssize_t nread; + SYSCHECK(nread = read(rasNotificationPipe[0], (char*)&msg + done, sizeof(msg) - done), "read"); + if (nread == 0) // EOF + return ncclSystemError; + done += nread; + } + + if (msg.type == RAS_ADD_RANKS) { + NCCLCHECK(rasLocalHandleAddRanks(msg.addRanks.ranks, msg.addRanks.nranks)); + } else if (msg.type == RAS_TERMINATE) { + rasLocalHandleTerminate(); + } else { + WARN("RAS received unknown notification type %d", msg.type); + return ncclInternalError; + } + + return ncclSuccess; +} + +// Handles local RAS_TERMINATE notification. +static void rasLocalHandleTerminate() { + INFO(NCCL_RAS, "RAS handling local termination request"); + // For now we don't do anything. +} + + +//////////////////////////////////////////////// +// Generic functions related to RAS messages. // +//////////////////////////////////////////////// + +// Allocates a RAS message of the desired length for sending. +// Behind the scenes allocates encapsulating rasMsgMeta structure, which includes local metadata stored in front +// of the message. +// Must use rasMsgFree to free. +ncclResult_t rasMsgAlloc(struct rasMsg** msg, size_t msgLen) { + struct rasMsgMeta* meta = nullptr; + NCCLCHECK(ncclCalloc((char**)&meta, offsetof(struct rasMsgMeta, msg) + msgLen)); + *msg = &meta->msg; + // coverity[leaked_storage:FALSE] => rasMsgFree is used to free it + return ncclSuccess; +} + +// To be used only with messages allocated with rasMsgAlloc. I.e., it should be used for sent messages, not +// for received ones. +void rasMsgFree(struct rasMsg* msg) { + if (msg) { + struct rasMsgMeta* meta = (struct rasMsgMeta*)((char*)msg - offsetof(struct rasMsgMeta, msg)); + free(meta); + } +} + +// Enqueues a message for sending down a RAS connection. +void rasConnEnqueueMsg(struct rasConnection* conn, struct rasMsg* msg, size_t msgLen, bool front) { + // Get to the metadata of this message. + struct rasMsgMeta* meta = (struct rasMsgMeta*)((char*)msg - offsetof(struct rasMsgMeta, msg)); + bool ready = false; + + meta->enqueueTime = clockNano(); + meta->offset = 0; + meta->length = msgLen; + + if (front) + ncclIntruQueueEnqueueFront(&conn->sendQ, meta); + else + ncclIntruQueueEnqueue(&conn->sendQ, meta); + + if (conn->sockIdx != -1) { + struct rasSocket* sock = rasSockets+conn->sockIdx; + if (sock->status == RAS_SOCK_READY || (sock->status == RAS_SOCK_HANDSHAKE && msg->type == RAS_MSG_CONNINIT)) { + rasPfds[sock->pfd].events |= POLLOUT; + ready = true; + } + } + if (!ready) { + // It's not a bug, unless it's for things like keep-alive messages... + INFO(NCCL_RAS, "RAS enqueued message type %d on a non-ready connection with %s " + "(experiencingDelays %d, startRetryTime %.2fs, socket status %d)", + msg->type, ncclSocketToString(&conn->addr, rasLine), + conn->experiencingDelays, (conn->startRetryTime ? (clockNano()-conn->startRetryTime)/1e9 : 0.0), + (conn->sockIdx == -1 ? -1 : rasSockets[conn->sockIdx].status)); + } +} + +// Attempts to send the queued RAS messages to another RAS thread. +ncclResult_t rasConnSendMsg(struct rasConnection* conn, int* closed, bool* allSent) { + struct ncclSocket* sock = &rasSockets[conn->sockIdx].sock; + struct rasMsgMeta* meta; + *closed = 0; + while ((meta = ncclIntruQueueHead(&conn->sendQ)) != nullptr) { + if (rasSockets[conn->sockIdx].status == RAS_SOCK_HANDSHAKE && meta->msg.type != RAS_MSG_CONNINIT) { + // We don't send anything beyond the handshake at this point. + meta = nullptr; + break; + } + if (meta->offset < sizeof(meta->length)) { + // Send the length of the message. + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, sock, &meta->length, sizeof(meta->length), &meta->offset, closed)); + if (*closed) + return ncclSuccess; + if (meta->offset < sizeof(meta->length)) + break; + } + // Send the body of the message. + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, sock, ((char*)&meta->msg)-sizeof(meta->length), + meta->length+sizeof(meta->length), &meta->offset, closed)); + if (*closed) + return ncclSuccess; + if (meta->offset < meta->length+sizeof(meta->length)) + break; + ncclIntruQueueDequeue(&conn->sendQ); + free(meta); + } + + *allSent = !meta; + + return ncclSuccess; +} + +// Attempts to receive a message through a RAS socket. +ncclResult_t rasMsgRecv(struct rasSocket* sock, struct rasMsg** msg, int* closed) { + *closed = 0; + if (sock->recvOffset < sizeof(sock->recvLength)) { + // Receive the length of the message. + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &sock->sock, &sock->recvLength, sizeof(sock->recvLength), + &sock->recvOffset, closed)); + if (*closed || sock->recvOffset < sizeof(sock->recvLength)) + return ncclSuccess; + NCCLCHECK(ncclCalloc((char**)&sock->recvMsg, sock->recvLength)); + } + // Receive the body of the message. + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &sock->sock, ((char*)sock->recvMsg)-sizeof(sock->recvLength), + sock->recvLength+sizeof(sock->recvLength), &sock->recvOffset, closed)); + if (*closed || sock->recvOffset < sock->recvLength+sizeof(sock->recvLength)) + return ncclSuccess; + + *msg = sock->recvMsg; + sock->recvMsg = nullptr; + sock->recvOffset = sock->recvLength = 0; + + return ncclSuccess; +} + + +////////////////////////////////////////////////////////////////// +// Functions related to the handling of specific message types. // +////////////////////////////////////////////////////////////////// + +// Invoked from the main RAS thread to dispatch incoming messages to the appropriate handler. +ncclResult_t rasMsgHandle(struct rasMsg* msg, struct rasSocket* sock) { + if (msg->type == RAS_MSG_CONNINIT) { + NCCLCHECK(rasMsgHandleConnInit(msg, sock)); + } else if (msg->type == RAS_MSG_CONNINITACK) { + NCCLCHECK(rasMsgHandleConnInitAck(msg, sock)); + } else if (msg->type == RAS_MSG_KEEPALIVE) { + NCCLCHECK(rasMsgHandleKeepAlive(msg, sock)); + } else if (msg->type == RAS_MSG_PEERSUPDATE) { + NCCLCHECK(rasMsgHandlePeersUpdate(msg, sock)); + } else if (msg->type == RAS_MSG_COLLREQ) { + NCCLCHECK(rasMsgHandleCollReq(msg, sock)); + } else if (msg->type == RAS_MSG_COLLRESP) { + NCCLCHECK(rasMsgHandleCollResp(msg, sock)); + } else { + WARN("RAS received unknown message type (%d) from %s", msg->type, ncclSocketToString(&sock->sock.addr, rasLine)); + return ncclInternalError; + } + + return ncclSuccess; +} + +// Handles the first message sent over a RAS socket as part of the handshake. +static ncclResult_t rasMsgHandleConnInit(const struct rasMsg* msg, struct rasSocket* sock) { + ncclResult_t ret = ncclSuccess; + struct rasConnection* conn = nullptr; + int connIdx, peerIdx; + struct rasMsg* newMsg = nullptr; + int newMsgLen; + char line[SOCKET_NAME_MAXLEN+1]; + + INFO(NCCL_RAS, "RAS handling connInit from %s (version %d, listeningAddr %s, peersHash 0x%lx, deadPeersHash 0x%lx)", + ncclSocketToString(&sock->sock.addr, rasLine), msg->connInit.ncclVersion, + ncclSocketToString(&msg->connInit.listeningAddr, line), msg->connInit.peersHash, msg->connInit.deadPeersHash); + + if (msg->connInit.ncclVersion != NCCL_VERSION_CODE) { + // Close any such sockets immediately! This is basically unrecoverable... + WARN("NCCL version mismatch with remote peer %s (local: %d, remote %d)", + ncclSocketToString(&sock->sock.addr, rasLine), NCCL_VERSION_CODE, msg->connInit.ncclVersion); + rasNetSendNack(sock); + rasSocketTerminate(sock, /*finalize*/true); + ret = ncclInvalidUsage; + goto exit; + } + + if (rasPeerIsDead(&msg->connInit.listeningAddr)) { + // A peer long declared dead is suddenly alive again?! + INFO(NCCL_RAS, "RAS connection from peer %s that is considered dead!", + ncclSocketToString(&msg->connInit.listeningAddr, rasLine)); + rasNetSendNack(sock); + rasSocketTerminate(sock, /*finalize*/true); + goto exit; + } + + // Check for any existing connection with that RAS thread (could happen due to a network issue, or possibly a race). + connIdx = rasConnFind(&msg->connInit.listeningAddr); + if (connIdx != -1) { + conn = rasConns+connIdx; + + INFO(NCCL_RAS, + "RAS found a matching existing connection (sendQ %sempty, experiencingDelays %d, startRetryTime %.2fs)", + (ncclIntruQueueEmpty(&conn->sendQ) ? "" : "not "), + conn->experiencingDelays, (conn->startRetryTime ? (clockNano()-conn->startRetryTime)/1e9 : 0.0)); + + if (conn->sockIdx != -1) { + struct rasSocket* connSock = rasSockets+conn->sockIdx; + INFO(NCCL_RAS, "RAS found an alternative existing socket (status %d, createTime %.2fs)", + connSock->status, (clockNano()-connSock->createTime)/1e9); + // In general we prefer to keep the newer connection, but "newer" can be a relative term: we may have + // a race where both sides attempt to establish a connection at roughly the same time, so the other side's + // incoming connection ends up looking newer than the locally-initiated one -- for *both* of them. + // If each side closed the "old" one, both would end up being closed. + // As we normally try to initiate connections from the side with a lower address (precisely to avoid such + // situations), we'll follow the same logic here: the "lower" side will reject the new connection (as it + // came from the "wrong" side), whereas the "higher" side will keep the new one (as it came from the correct + // side) and terminate the old one (that it presumably just opened). + if (ncclSocketsCompare(&rasNetListeningSocket.addr, &conn->addr) < 0) { + INFO(NCCL_RAS, "RAS terminating the new socket"); + rasSocketTerminate(sock, /*finalize*/true); + goto exit; + } else { + INFO(NCCL_RAS, "RAS keeping the new socket and terminating the existing one"); + rasSocketTerminate(connSock); + } + } + } + if (!conn) { + NCCLCHECK(getNewConnEntry(&conn)); + memcpy(&conn->addr, &msg->connInit.listeningAddr, sizeof(conn->addr)); + connIdx = conn - rasConns; + } + + sock->status = RAS_SOCK_READY; + // rasConnResume will reset any experiencingDelays, startRetryTime, etc. + + conn->sockIdx = sock-rasSockets; + sock->connIdx = connIdx; + memcpy(&sock->sock.addr, &msg->connInit.listeningAddr, sizeof(sock->sock.addr)); + + // Make sure that the connection is part of the right links forming the RAS network. At this point we only + // update the expected (non-external) connections; external ones will be added during keep-alive handling. + peerIdx = rasPeerFind(&conn->addr); + // Note: it's possible for peerIdx to be -1 at this point if, due to races, the connInit arrives before + // the peers update. + if (peerIdx != -1) { + (void)rasLinkUpdateConn(&rasNextLink, connIdx, peerIdx); + (void)rasLinkUpdateConn(&rasPrevLink, connIdx, peerIdx); + } + + // Send a confirmation to the server that requested the connection (so that the resilience code can mark + // the connection as live). + newMsgLen = rasMsgLength(RAS_MSG_CONNINITACK); + NCCLCHECK(rasMsgAlloc(&newMsg, newMsgLen)); + newMsg->type = RAS_MSG_CONNINITACK; + newMsg->connInitAck.nack = 0; + rasConnEnqueueMsg(conn, newMsg, newMsgLen, /*front*/true); + + conn->lastRecvPeersHash = msg->connInit.peersHash; + conn->lastRecvDeadPeersHash = msg->connInit.deadPeersHash; + + if (msg->connInit.peersHash != rasPeersHash || msg->connInit.deadPeersHash != rasDeadPeersHash) { + // Send my rasPeers and request the same in return. + INFO(NCCL_RAS, "RAS connInit hash mismatch (my peersHash 0x%lx, deadPeersHash 0x%lx); sending my (dead) peers", + rasPeersHash, rasDeadPeersHash); + NCCLCHECK(rasConnSendPeersUpdate(conn, rasPeers, nRasPeers)); + } +exit: + return ret; +} + +// Handles the second message sent over a RAS socket as part of the handshake. +static ncclResult_t rasMsgHandleConnInitAck(const struct rasMsg* msg, struct rasSocket* sock) { + INFO(NCCL_RAS, "RAS handling connInitAck from %s (nack %d)", + ncclSocketToString(&sock->sock.addr, rasLine), msg->connInitAck.nack); + + if (msg->connInitAck.nack) { + // The remote peer doesn't want to talk to us. The easiest way to prevent it is by declaring it dead. + // We make a copy of the address because rasConnDisconnect will terminate the rasSocket. + union ncclSocketAddress addr; + memcpy(&addr, &sock->sock.addr, sizeof(addr)); + rasConnDisconnect(&addr); + (void)rasPeerDeclareDead(&addr); + + return ncclSuccess; + } + + sock->status = RAS_SOCK_READY; + // rasConnResume will reset any experiencingDelays, startRetryTime, etc. + + return ncclSuccess; +} + +// Handles the deadPeer broadcast. +void rasMsgHandleBCDeadPeer(const struct rasCollRequest* req, bool* pDone) { + INFO(NCCL_RAS, "RAS handling deadPeer (addr %s)", ncclSocketToString(&req->deadPeer.addr, rasLine)); + + if (!rasPeerIsDead(&req->deadPeer.addr)) { + rasConnDisconnect(&req->deadPeer.addr); + (void)rasPeerDeclareDead(&req->deadPeer.addr); + *pDone = false; + } else { + INFO(NCCL_RAS, "RAS already knew it was dead"); + // No point in re-broadcasting what's already known. + *pDone = true; + } +} + +// Attempts to immediately send a fatal NACK connInitAck response to a socket. A bit of a hack (as it doesn't +// follow our usual message queuing and polling convention) but, since this can be invoked only for newly opened +// connections, and the message is tiny, it should be OK. We can't use the regular path because the socket is +// about to be terminated. +static ncclResult_t rasNetSendNack(struct rasSocket* sock) { + struct rasMsg msg; + int length = rasMsgLength(RAS_MSG_CONNINITACK); + int closed = 0; + int offset; + + INFO(NCCL_RAS, "RAS sending NACK to %s", ncclSocketToString(&sock->sock.addr, rasLine)); + + msg.type = RAS_MSG_CONNINITACK; + msg.connInitAck.nack = 1; + offset = 0; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &sock->sock, &length, sizeof(length), &offset, &closed)); + if (closed || offset < sizeof(length)) + return ncclSuccess; + offset = 0; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &sock->sock, &msg, length, &offset, &closed)); + // We are closing this socket anyway -- it doesn't matter to us if we succeeded or not. + + return ncclSuccess; +} + + +///////////////////////////////////////////////////////////////// +// Functions related to the main event loop of the RAS thread. // +///////////////////////////////////////////////////////////////// + +// Main function of the RAS thread. +static void* rasThreadMain(void*) { + ncclResult_t ret = ncclSuccess; // Unused. + int pfd; + int rasNetListeningSocketFd; + + INFO(NCCL_RAS, "RAS thread started"); + + // Initialize the global pollfd with the file descriptors we already have (the pipe and the listening socket). + NCCLCHECKGOTO(rasGetNewPollEntry(&pfd), ret, fail); + rasPfds[pfd].fd = rasNotificationPipe[0]; + rasPfds[pfd].events = POLLIN; + + NCCLCHECKGOTO(rasGetNewPollEntry(&pfd), ret, fail); + NCCLCHECKGOTO(ncclSocketGetFd(&rasNetListeningSocket, &rasNetListeningSocketFd), ret, fail); + rasPfds[pfd].fd = rasNetListeningSocketFd; + rasPfds[pfd].events = POLLIN; + + NCCLCHECKGOTO(rasGetNewPollEntry(&pfd), ret, fail); + rasPfds[pfd].fd = rasClientListeningSocket; + rasPfds[pfd].events = POLLIN; + + // Main event loop of the RAS thread. + for (int64_t nextWakeup=0;;) { + int timeout, nEvents; + int64_t now = clockNano(); + if (nextWakeup > 0) { + // The "1" below helps avoid round-downs and especially zeroes. + if (nextWakeup > now) + timeout = (nextWakeup - now) / (CLOCK_UNITS_PER_SEC / 1000) + 1; + else + timeout = 1; + } else { + timeout = 1000; // 1 second. + } + + nEvents = poll(rasPfds, nRasPfds, timeout); + + nextWakeup = clockNano()+CLOCK_UNITS_PER_SEC; + if (nEvents == -1 && errno != EINTR) + INFO(NCCL_RAS, "RAS continuing in spite of an unexpected error from poll: %s", strerror(errno)); + + // Handle any poll-related events. + for (int pollIdx = 0; pollIdx < nRasPfds && nEvents > 0; pollIdx++) { + if (rasPfds[pollIdx].revents) { + nEvents--; + if (rasPfds[pollIdx].fd == rasNotificationPipe[0]) { + (void)rasLocalHandle(); + } else if (rasPfds[pollIdx].fd == rasNetListeningSocketFd) { + (void)rasNetAcceptNewSocket(); + } else if (rasPfds[pollIdx].fd == rasClientListeningSocket) { + (void)rasClientAcceptNewSocket(); + } else { + // Check if it's one of the RAS sockets. + int sockIdx; + for (sockIdx = 0; sockIdx < nRasSockets; sockIdx++) { + struct rasSocket* sock = rasSockets+sockIdx; + if (sock->status != RAS_SOCK_CLOSED && rasPfds[pollIdx].fd == sock->sock.fd) { + rasSockEventLoop(sockIdx, pollIdx); + break; + } + } // for (sockIdx) + + if (sockIdx == nRasSockets) { + // Try a client socket instead. + for (int clientIdx = 0; clientIdx < nRasClients; clientIdx++) { + struct rasClient* client = rasClients+clientIdx; + if (client->status != RAS_CLIENT_CLOSED && rasPfds[pollIdx].fd == client->sock) { + rasClientEventLoop(clientIdx, pollIdx); + break; + } + } // for (clientIdx) + } // if (sockIdx == nRasSockets) + } // dynamic fds + } // if (revents) + } // for (pollIdx) + + now = clockNano(); + + rasSocksHandleTimeouts(now, &nextWakeup); + + rasConnsHandleTimeouts(now, &nextWakeup); + + rasNetHandleTimeouts(now, &nextWakeup); + + rasCollsHandleTimeouts(now, &nextWakeup); + } // for (;;) + +fail: + WARN("fatal error - RAS thread terminating"); + std::lock_guard lock(rasInitMutex); + (void)close(rasNotificationPipe[1]); + (void)close(rasNotificationPipe[0]); + (void)close(rasClientListeningSocket); + (void)ncclSocketClose(&rasNetListeningSocket); + rasInitialized = false; + return nullptr; +} + +// Returns the index of the first available entry in the rasPfds array, enlarging the array if necessary. +ncclResult_t rasGetNewPollEntry(int* index) { + int i; + for (i = 0; i < nRasPfds; i++) + if (rasPfds[i].fd == -1) + break; + if (i == nRasPfds) { + NCCLCHECK(ncclRealloc(&rasPfds, nRasPfds, nRasPfds+RAS_INCREMENT)); + nRasPfds += RAS_INCREMENT; + for (int j = i; j < nRasPfds; j++) + rasPfds[j].fd = -1; + } + + memset(rasPfds+i, '\0', sizeof(*rasPfds)); + rasPfds[i].fd = -1; + + *index = i; + return ncclSuccess; +} diff --git a/src/ras/ras_internal.h b/src/ras/ras_internal.h new file mode 100644 index 0000000000..68cac0b44b --- /dev/null +++ b/src/ras/ras_internal.h @@ -0,0 +1,512 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_RAS_INTERNAL_H_ +#define NCCL_RAS_INTERNAL_H_ + +#define NCCL_RAS_CLIENT_PORT 28028 +#define NCCL_RAS_CLIENT_PROTOCOL 2 + +#define RAS_COLLECTIVE_LEG_TIMEOUT_SEC 5 +#define RAS_COLLECTIVE_EXTRA_TIMEOUT_SEC RAS_COLLECTIVE_LEG_TIMEOUT_SEC + +// End of the client section; everything below is meant for the NCCL threads only. +#ifndef NCCL_RAS_CLIENT + +#include + +#include "nccl.h" +#include "ras.h" +#include "socket.h" +#include "utils.h" + +// Type of a RAS network or client message. +typedef enum { + RAS_MSG_CONNINIT = 1, + RAS_MSG_CONNINITACK = 2, + RAS_MSG_KEEPALIVE = 3, + RAS_MSG_PEERSUPDATE = 4, + RAS_MSG_COLLREQ = 5, + RAS_MSG_COLLRESP = 6, +} rasMsgType; + +// Type of a RAS network collective message. +typedef enum { + RAS_MSG_NONE = 0, + RAS_BC_DEADPEER = 1, + // Broadcast operations above this line; collective operations below (1000 is the demarcation line). + RAS_COLL_CONNS = 1001, // Collect data about all RAS connections. + RAS_COLL_COMMS = 1002, // Collect data about all communicators. +} rasCollectiveType; + +// Payload of a collective request message (RAS_MSG_COLLREQ). +struct rasCollRequest { + union ncclSocketAddress rootAddr; + uint64_t rootId; + + int64_t timeout; + rasCollectiveType type; + union { + struct { + union ncclSocketAddress addr; + } deadPeer; + struct { + } conns; + struct { + } comms; + }; +}; + +// Payload of a collective response message (RAS_MSG_COLLRESP). +struct rasCollResponse { + union ncclSocketAddress rootAddr; + uint64_t rootId; + + int nLegTimeouts; // If >0, indicates incomplete data. + int nPeers; + int nData; // Size of data in bytes. + union ncclSocketAddress peers[0]; // Variable length. + // The peersAddrs array is followed by: + //alignas(int64_t) char data[0]; // Variable length, collective-dependent. +}; + +// Describes a peer NCCL process. Every RAS thread keeps an (identical) array of them, one entry for each +// NCCL process. +struct rasPeerInfo { + union ncclSocketAddress addr; + pid_t pid; + uint64_t cudaDevs; // Bitmask. Conveniently, NCCL_MAX_LOCAL_RANKS == 64. + uint64_t nvmlDevs; // Same, but not affected by CUDA_VISIBLE_DEVICES. +}; + +// Describes a RAS message. Every message is preceded by a (32-bit) message length. All data in the host +// byte order. Depending on the message type, the length of the message will vary. +struct rasMsg { + rasMsgType type; + union { + struct { + int ncclVersion; + union ncclSocketAddress listeningAddr; + uint64_t peersHash; + uint64_t deadPeersHash; + } connInit; // Sent by the connecting side as the first message. + struct { + int nack; // If non-0, we should stop trying to reconnect. + } connInitAck; // Response from the accepting side to the above. + struct { + uint64_t peersHash; + uint64_t deadPeersHash; + int linkMask; // What links at the destination peer should the connection be part of + // (bit 0: nextLink; bit 1: prevLink). + struct timespec realTime; // Wallclock time at the source, for statistical purposes (in principle there's + // no guarantee that the nodes have synchronized clocks so we can't really rely + // on it for anything important).. + int nack; // If non-0, it means that this message is a response to an unexpected keepAlive message. + } keepAlive; + struct { + uint64_t peersHash; + uint64_t deadPeersHash; + int nPeers; + int nDeadPeers; + struct rasPeerInfo peers[0]; // Variable length. + // The peers array is followed by the following: + //union ncclSocketAddress deadPeers[0]; // Variable length. + } peersUpdate; + struct { + int protocol; // Protocol version, sent to the client. + } clientInit; + struct { + int nData; + char data[0]; // Variable length. + } clientDump; + struct rasCollRequest collReq; // Variable length. + struct rasCollResponse collResp; // Variable length. + }; +}; + +// Returns the size of the collective portion of a collective request message. +static inline size_t rasCollDataLength(rasCollectiveType type) { + struct rasCollRequest* data; + switch (type) { + case RAS_BC_DEADPEER: + return offsetof(struct rasCollRequest, deadPeer) + sizeof(data->deadPeer); + case RAS_COLL_CONNS: + return offsetof(struct rasCollRequest, conns) + sizeof(data->conns); + case RAS_COLL_COMMS: + return offsetof(struct rasCollRequest, comms) + sizeof(data->comms); + case RAS_MSG_NONE: + return 0; + }; + return 0; +} + +// Returns the size for a message of a particular type. +static inline size_t rasMsgLength(rasMsgType type, rasCollectiveType collType = RAS_MSG_NONE) { + struct rasMsg* msg; + switch (type) { + case RAS_MSG_CONNINIT: + return offsetof(struct rasMsg, connInit) + sizeof(msg->connInit); + case RAS_MSG_CONNINITACK: + return offsetof(struct rasMsg, connInitAck) + sizeof(msg->connInitAck); + case RAS_MSG_KEEPALIVE: + return offsetof(struct rasMsg, keepAlive) + sizeof(msg->keepAlive); + case RAS_MSG_PEERSUPDATE: + return offsetof(struct rasMsg, peersUpdate) + sizeof(msg->peersUpdate); + case RAS_MSG_COLLREQ: + return offsetof(struct rasMsg, collReq) + rasCollDataLength(collType); + case RAS_MSG_COLLRESP: + return offsetof(struct rasMsg, collResp) + sizeof(msg->collResp); + }; + return 0; +} + +// How much to enlarge any RAS array by if we run out of space. +#define RAS_INCREMENT 4 + +// Our clock has nanosecond resolution. +#define CLOCK_UNITS_PER_SEC 1000000000L + +// Keep-alive messages are sent no sooner than a second after the last message was sent down a particular connection. +#define RAS_KEEPALIVE_INTERVAL (1*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// If no message arrives in 5 seconds via a particular connection that uses keep-alive messages, generate a warning +// and try alternative connections. +#define RAS_KEEPALIVE_TIMEOUT_WARN (5*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Abort a socket that uses keep-alive messages if no message arrives in 20 seconds. +// We will try to re-establish communication via that connection (until RAS_PEER_DEAD_TIMEOUT). +#define RAS_KEEPALIVE_TIMEOUT_ERROR RAS_STUCK_TIMEOUT + +// Retry connecting on failing sockets (ECONNREFUSED, etc.) once a second. +#define RAS_CONNECT_RETRY (1*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// If we can't connect in 5 seconds, we generate a warning and try alternative connections. +#define RAS_CONNECT_WARN RAS_KEEPALIVE_TIMEOUT_WARN + +// Abort a busy socket (one we are trying to send on, or one that was being established) if there's been +// no sign of progress in 20 second. We will try to re-establish communication (up to RAS_PEER_DEAD_TIMEOUT). +#define RAS_STUCK_TIMEOUT (20*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Terminate ad-hoc connections that have not been used in 60 seconds. +#define RAS_IDLE_TIMEOUT (60*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// If the socket is closed by peer within 5 seconds from the idle timeout, do not attempt to re-establish. +#define RAS_IDLE_GRACE_PERIOD (5*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Declare a peer as dead and don't retry communicating with it if we couldn't reach it for 60 seconds. +#define RAS_PEER_DEAD_TIMEOUT (60*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Abort a leg of a collective operation if the response takes more than 5 seconds to arrive *and* one of the +// connections experiences delays. +#define RAS_COLLECTIVE_LEG_TIMEOUT (RAS_COLLECTIVE_LEG_TIMEOUT_SEC*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Abort a whole collective operation after at most RAS_COLLECTIVE_LEG_TIMEOUT+RAS_COLLECTIVE_EXTRA_TIMEOUT (10s). +#define RAS_COLLECTIVE_EXTRA_TIMEOUT (RAS_COLLECTIVE_EXTRA_TIMEOUT_SEC*CLOCK_UNITS_PER_SEC*ncclParamRasTimeoutFactor()) + +// Structure used for tracking the progress of sending a RAS message. +struct rasMsgMeta { + struct rasMsgMeta* next; + int64_t enqueueTime; + int offset; // Progress sending the message (including the message size itself (an int, which is sent first)). + int length; // Length of the message (*excluding* the message size). + struct rasMsg msg; // Variable length. +}; + +// Describes an ongoing collective RAS operation (apart from broadcasts, which don't need a response). +// For every collective operation, each participating RAS thread will create its own. +struct rasCollective { + union ncclSocketAddress rootAddr; + uint64_t rootId; + + rasCollectiveType type; + + int64_t timeout; + bool timeoutWarned; + + int64_t startTime; // For timeout calculations. + int fromConnIdx; // The connection we received the request from. + + int* fwdConns; // Indices of the connections we forwarded the request to; replaced by -1 as the responses arrive. + int nFwdSent; // Count of the above (local process only). + int nFwdRecv; // Count of the responses received or timeouts (local process only). + + int nLegTimeouts; // Collective (from this process and the responses we received). + + union ncclSocketAddress* peers; // Collective (from this process and the responses we received). + int nPeers; + + char* data; // Collective (from this process and the responses we received). + int nData; +}; + +// Collective data in RAS_COLL_CONNS responses. +struct rasCollConns { + int64_t travelTimeMin; + int64_t travelTimeMax; + int64_t travelTimeSum; + int64_t travelTimeCount; + int nConns; + int nNegativeMins; + struct negativeMin { + union ncclSocketAddress source; + union ncclSocketAddress dest; + int64_t travelTimeMin; + } negativeMins[0]; // Variable length. +}; + +// Collective data in RAS_COLL_COMMS responses. +struct rasCollComms { + int nComms; + struct comm { + uint64_t commHash; + int commNRanks; + int nRanks; // number of elements in the array below, *not* in the communicator. + struct rank { + int commRank; + int peerIdx; // Index within rasCollective->peers, *not* rasPeers. + uint64_t collOpCount; + struct { + ncclResult_t initState:4; + ncclResult_t asyncError:4; + bool finalizeCalled:1; + bool destroyFlag:1; + bool abortFlag:1; + } status; + char cudaDev; + char nvmlDev; + } ranks[0]; // Variable length. Sorted by commRank. Optimized for 1 GPU/process. + } comms[0]; // Variable length. Sorted by commHash. +}; + +// Holds data needed to keep track of a connection belonging to a RAS network link (either the primary one +// or one of the fallbacks). +struct rasLinkConn { + int peerIdx; // Index in the rasPeers array of the peer this entry describes. Could be -1 (an entry initiated + // by an as of yet unknown peer -- should be a temporary situation that resolves via peer updates). + int connIdx; // Index in the rasConns array of the connection to the above peer. Could be -1 (a placeholder + // for a connection to be started by the remote peer). + bool external; // true if the entry exists only due to an external request (requested by a remote peer, most + // likely as part of fault recovery). Such connections are kept as fallbacks even if there's a + // valid primary connection, in order to ensure that keep-alive messages are sent. +}; + +// Describes a link that forms the backbone of the RAS network. Links focus on direction (previous/next in +// case of 1-D topology) rather than a particular destination. The are implemented using rasConnections, but +// they are persistent through the life of the RAS threads, whereas rasConnections can be terminated if the RAS +// network is reconfigured or a peer dies. +struct rasLink { + int direction; // 1 for nextLink, -1 for prevLink. + + // Index 0 is the primary connection; any additional ones are fallbacks (that get created if we are having + // problems with the primary connection). The elements are de-facto ordered (highest-preference ones have + // the lowest indices). + struct rasLinkConn* conns; + int nConns; + int connsSize; // Array size; could be larger than nConns. + + // Keep track of a timeout in case we did not create a connection during the last peers update (because we expect + // the peer on the other side to do so) but that peer failed to initiate. + int64_t lastUpdatePeersTime; +}; + +// Describes a connection to another peer on the RAS network. It is meant to be more persistent than a volatile +// socket (described by the rasSocket structure), which can be affected by transient network issues. +struct rasConnection { + bool inUse; + + union ncclSocketAddress addr; + + // Index of the current rasSocket in the rasSockets array. Note that multiple rasSocket entries may point back + // to a single entry here, for sockets that are in the process of being terminated and re-established. + // We use indices, not pointers, because the arrays holding these structures can be re-alloced at run time. + // -1 if there is no such socket. + int sockIdx; + + // We keep the rasPeersHash of remote connections to minimize the number of needless exchanges. + // There is a subtle difference in the meaning of lastSentPeersHash and lastRecvPeersHash. + // lastSentPeersHash stores *our* rasPeersHash from the time we last sent a peers *update* through this connection + // (which is different than sending just the hash, like we do in KEEPALIVE, etc.). + // lastRecvPeersHash stores the latest known rasPeersHash of the peer (received via KEEPALIVE, etc.). + uint64_t lastSentPeersHash; + uint64_t lastRecvPeersHash; + + // Same but for rasDeadPeersHash. + uint64_t lastSentDeadPeersHash; + uint64_t lastRecvDeadPeersHash; + + // Queue of messages to send. + struct ncclIntruQueue sendQ; + + // Used for keeping track of timeouts that may extend beyond the lifetime of a socket. + // The timeout starts when the connection is being created (and is turned off when the initialization is completed + // successfully) or when we detect a problem, such as a socket timeout (in the latter case, we may need to + // retroactively calculate the start time). + // A value of 0 indicates that they are not currently in use. + int64_t startRetryTime; + int64_t lastRetryTime; + + bool experiencingDelays; // A flag indicating that the connection is currently subject to RAS_KEEPALIVE_TIMEOUT_WARN + // or RAS_CONNECT_WARN timeout. If set, the warnings have been issued and the fallbacks + // have been initiated if needed. + bool linkFlag; // Used within rasNet* calls to mark whether this connection was already handled when iterating over + // multiple links (since a connection can belong to more than one link). + // The below four fields are for statistical purposes only. + int64_t travelTimeMin; + int64_t travelTimeMax; + int64_t travelTimeSum; + int64_t travelTimeCount; +}; + +// Status of a RAS socket. +typedef enum { + RAS_SOCK_CLOSED = 0, + RAS_SOCK_CONNECTING = 1, + RAS_SOCK_HANDSHAKE = 2, + RAS_SOCK_READY = 3, + RAS_SOCK_TERMINATING = 4 +} rasSocketStatus; + +// Describes a socket implementing communication between two peers. +struct rasSocket { + struct ncclSocket sock; + + rasSocketStatus status; + + int pfd; // Index in the rasPfds array. + + // Index of the corresponding entry in the rasConns array. + // We use indices, not pointers, because the arrays holding these structures can be re-alloced at run time. + // -1 if there is no connection (normal condition on the accept side before the connInit message). + int connIdx; + + int64_t createTime; + int64_t lastSendTime; + int64_t lastRecvTime; + + // Data on the message currently being received. + int recvOffset; + int recvLength; + struct rasMsg* recvMsg; +}; + +// Status of a RAS client. +typedef enum { + RAS_CLIENT_CLOSED = 0, + RAS_CLIENT_CONNECTED = 1, + RAS_CLIENT_INIT = 2, + RAS_CLIENT_CONNS = 3, + RAS_CLIENT_COMMS = 4, + RAS_CLIENT_FINISHED = 99 +} rasClientStatus; + +// Describes a RAS client. +struct rasClient { + int sock; + + rasClientStatus status; + + int pfd; // Index in the rasPfds array. + + char recvBuffer[1024]; + int recvOffset; + + // Queue of messages to send. + struct ncclIntruQueue sendQ; + + int verbose; + int64_t timeout; + + // State stored during asynchronous operations such as collectives. + int collIdx; // Index to the onging rasCollective. +}; + + +// ras.cc +extern struct pollfd* rasPfds; +extern struct ncclSocket rasNetListeningSocket; +extern std::mutex ncclCommsMutex; +extern struct ncclComm** ncclComms; +extern int nNcclComms; +extern bool ncclCommsSorted; +extern char rasLine[SOCKET_NAME_MAXLEN+1]; + +int64_t ncclParamRasTimeoutFactor(); +ncclResult_t rasMsgAlloc(struct rasMsg** msg, size_t msgLen); +void rasMsgFree(struct rasMsg* msg); +void rasConnEnqueueMsg(struct rasConnection* conn, struct rasMsg* msg, size_t msgLen, bool front = false); +ncclResult_t rasConnSendMsg(struct rasConnection* conn, int* closed, bool* allSent); +ncclResult_t rasMsgRecv(struct rasSocket* sock, struct rasMsg** msg, int* closed); +ncclResult_t rasMsgHandle(struct rasMsg* msg, struct rasSocket* sock); +void rasMsgHandleBCDeadPeer(const struct rasCollRequest* req, bool* pDone); +ncclResult_t rasGetNewPollEntry(int* index); + + +// rasnet.cc +extern struct rasLink rasNextLink, rasPrevLink; +extern struct rasConnection* rasConns; +extern int nRasConns; +extern struct rasSocket *rasSockets; +extern int nRasSockets; + +ncclResult_t getNewConnEntry(struct rasConnection** pConn); +ncclResult_t rasConnCreate(const union ncclSocketAddress* addr, int* pConnIdx); +int rasConnFind(const union ncclSocketAddress* addr); +void rasConnsHandleTimeouts(int64_t now, int64_t* nextWakeup); +void rasConnDisconnect(const union ncclSocketAddress* addr); +ncclResult_t rasNetAcceptNewSocket(); +void rasSocksHandleTimeouts(int64_t now, int64_t* nextWakeup); +void rasSocketTerminate(struct rasSocket* sock, bool finalize = false, uint64_t startRetryOffset = 0, + bool retry = true); +void rasSockEventLoop(int sockIdx, int pollIdx); +void rasNetHandleTimeouts(int64_t now, int64_t* nextWakeup); +ncclResult_t rasMsgHandleKeepAlive(const struct rasMsg* msg, struct rasSocket* sock); +ncclResult_t rasLinkUpdateConn(struct rasLink* link, int connIdx, int peerIdx, bool external = false, + bool insert = false, bool pretend = false, int* pLinkIdx = nullptr); + +// peers.cc +extern struct rasPeerInfo* rasPeers; +extern int nRasPeers; +extern uint64_t rasPeersHash; +extern union ncclSocketAddress* rasDeadPeers; +extern int nRasDeadPeers; +extern uint64_t rasDeadPeersHash; + +ncclResult_t rasLocalHandleAddRanks(struct rasRankInit* ranks, int nranks); +int rasPeerFind(const union ncclSocketAddress* addr); +ncclResult_t rasConnSendPeersUpdate(struct rasConnection* conn, const struct rasPeerInfo* peers, int nPeers); +ncclResult_t rasMsgHandlePeersUpdate(struct rasMsg* msg, struct rasSocket* sock); +int rasLinkCalculatePeer(const struct rasLink* link, int peerIdx, bool isFallback = false); +ncclResult_t rasPeerDeclareDead(const union ncclSocketAddress* addr); +bool rasPeerIsDead(const union ncclSocketAddress* addr); +int ncclSocketsCompare(const void* p1, const void* p2); +bool ncclSocketsSameNode(const union ncclSocketAddress* a1, const union ncclSocketAddress* a2); + + +// collectives.cc +extern struct rasCollective* rasCollectives; + +void rasCollReqInit(struct rasCollRequest* req); +ncclResult_t rasNetSendCollReq(const struct rasCollRequest* req, size_t reqLen, bool* pAllDone = nullptr, + int* pCollIdx = nullptr, int fromConnIdx = -1); +ncclResult_t rasMsgHandleCollReq(struct rasMsg* msg, struct rasSocket* sock); +ncclResult_t rasMsgHandleCollResp(struct rasMsg* msg, struct rasSocket* sock); +void rasCollsPurgeConn(int connIdx); +void rasCollFree(struct rasCollective* coll); +void rasCollsHandleTimeouts(int64_t now, int64_t* nextWakeup); + +// client_support.cc +extern int rasClientListeningSocket; +extern struct rasClient* rasClients; +extern int nRasClients; +ncclResult_t rasClientInitSocket(); +ncclResult_t rasClientAcceptNewSocket(); +ncclResult_t rasClientResume(struct rasCollective* coll); +void rasClientEventLoop(int clientIdx, int pollIdx); +const char* rasGpuDevsToString(uint64_t cudaDevs, uint64_t nvmlDevs, char* buf, size_t size); + +#endif // !NCCL_RAS_CLIENT + +#endif // !NCCL_RAS_INTERNAL_H_ diff --git a/src/ras/rasnet.cc b/src/ras/rasnet.cc new file mode 100644 index 0000000000..441ad192c0 --- /dev/null +++ b/src/ras/rasnet.cc @@ -0,0 +1,1189 @@ +/************************************************************************* + * Copyright (c) 2016-2024, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#define NDEBUG // Comment out during development only! +#include + +#include "ras_internal.h" + +// Links forming the backbone of the RAS network (currently a ring). +struct rasLink rasNextLink = {1}, rasPrevLink = {-1}; + +// Connections on the RAS network. +struct rasConnection* rasConns; +int nRasConns; + +// Sockets implementing the RAS network. +struct rasSocket *rasSockets; +int nRasSockets; + +// Magic file descriptor number when we want poll() to ignore an entry. Anything negative would do, but +// I didn't want to use -1 because it has a special meaning for us. +#define POLL_FD_IGNORE -2 + +static void rasConnOpen(struct rasConnection* conn); +static ncclResult_t rasConnPrepare(struct rasConnection* conn); +static void rasConnTerminate(struct rasConnection* conn); + +static ncclResult_t getNewSockEntry(struct rasSocket** pSock); + +static ncclResult_t rasLinkHandleNetTimeouts(struct rasLink* link, int64_t now, int64_t* nextWakeup); +static void rasConnHandleNetTimeouts(int connIdx, int64_t now, int64_t* nextWakeup); +static void rasConnSendKeepAlive(struct rasConnection* conn, bool nack = false); + +static ncclResult_t rasLinkAddFallback(struct rasLink* link, int connIdx); +static void rasConnResume(struct rasConnection* conn); +static void rasLinkSanitizeFallbacks(struct rasLink* link); +static void rasLinkDropConn(struct rasLink* link, int connIdx, int linkIdx = -1); +static int rasLinkFindConn(const struct rasLink* link, int connIdx); + + +/////////////////////////////////////////////// +// Functions related to the RAS connections. // +/////////////////////////////////////////////// + +// Allocates an entry in the rasConns array, enlarging the array if necessary. +ncclResult_t getNewConnEntry(struct rasConnection** pConn) { + struct rasConnection* conn; + int i; + for (i = 0; i < nRasConns; i++) + if (!rasConns[i].inUse) + break; + if (i == nRasConns) { + NCCLCHECK(ncclRealloc(&rasConns, nRasConns, nRasConns+RAS_INCREMENT)); + nRasConns += RAS_INCREMENT; + } + + conn = rasConns+i; + memset(conn, '\0', sizeof(*conn)); + conn->inUse = true; + conn->sockIdx = -1; + ncclIntruQueueConstruct(&conn->sendQ); + conn->travelTimeMin = INT64_MAX; + conn->travelTimeMax = INT64_MIN; + + *pConn = conn; + return ncclSuccess; +} + +// Creates a new RAS network connection to a remote peer address. +ncclResult_t rasConnCreate(const union ncclSocketAddress* addr, int* pConnIdx) { + ncclResult_t ret = ncclSuccess; + struct rasConnection* conn = nullptr; + + // First check if a connection entry for this peer already exists. + int connIdx = rasConnFind(addr); + if (connIdx != -1) { + conn = rasConns+connIdx; + } + + if (conn && conn->sockIdx != -1) { + // An entry exists and has a socket associated with it -- nothing left for us to do. + if (pConnIdx) + *pConnIdx = connIdx; + goto exit; + } + + if (!conn) { + NCCLCHECKGOTO(getNewConnEntry(&conn), ret, exit); + memcpy(&conn->addr, addr, sizeof(conn->addr)); + // We are establishing a new connection -- start the timeout. + conn->startRetryTime = clockNano(); + connIdx = conn - rasConns; + } + + if (pConnIdx) + *pConnIdx = connIdx; + + rasConnOpen(conn); + +exit: + return ret; +} + +// Opens a connection to a remote peer. +static void rasConnOpen(struct rasConnection* conn) { + ncclResult_t ret; // Not used. + struct rasSocket* sock; + bool closeSocketOnFail = false; + int ready; + + NCCLCHECKGOTO(getNewSockEntry(&sock), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&sock->sock, &conn->addr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, + /*asyncFlag*/1, /*customRetry*/1), ret, fail); + closeSocketOnFail = true; + NCCLCHECKGOTO(ncclSocketConnect(&sock->sock), ret, fail); + NCCLCHECKGOTO(ncclSocketReady(&sock->sock, &ready), ret, fail); + + NCCLCHECKGOTO(rasGetNewPollEntry(&sock->pfd), ret, fail); + + // We delay the initialization of sockIdx, connIdx and status until this point so that in case of failures + // we don't need to clean them up. + conn->sockIdx = sock-rasSockets; + sock->connIdx = conn-rasConns; + rasPfds[sock->pfd].fd = sock->sock.fd; + + // We ignore the possibly ready status of the socket at this point and consider it CONNECTING because + // there are other things we want to do before sending the CONNINIT, such as adding the connection to + // the network links, etc. + sock->status = RAS_SOCK_CONNECTING; + rasPfds[sock->pfd].events = (POLLIN | POLLOUT); + if (sock->sock.state == ncclSocketStateConnecting) + rasPfds[sock->pfd].fd = POLL_FD_IGNORE; // Don't poll on this socket before connect(). + +exit: + conn->lastRetryTime = clockNano(); + // We deliberately ignore ret as this function will be retried later if needed. + return; +fail: + if (closeSocketOnFail) + (void)ncclSocketClose(&sock->sock); + goto exit; +} + +// Sends an initial RAS message to the peer after connecting to it. +static ncclResult_t rasConnPrepare(struct rasConnection* conn) { + struct rasMsg* msg = nullptr; + int msgLen = rasMsgLength(RAS_MSG_CONNINIT); + + // The first message the RAS threads exchange provides the listening address of the connecting thread + // and the NCCL version to ensure that users aren't mixing things up. + NCCLCHECK(rasMsgAlloc(&msg, msgLen)); + msg->type = RAS_MSG_CONNINIT; + msg->connInit.ncclVersion = NCCL_VERSION_CODE; + memcpy(&msg->connInit.listeningAddr, &rasNetListeningSocket.addr, sizeof(msg->connInit.listeningAddr)); + msg->connInit.peersHash = rasPeersHash; + msg->connInit.deadPeersHash = rasDeadPeersHash; + // We don't update lastSent[Dead]PeersHash because we aren't actually sending the peers themselves here. + + rasConnEnqueueMsg(conn, msg, msgLen, /*front*/true); + + // We'll finish the initialization in rasMsgHandleConnInitAck, after the other side responds. + return ncclSuccess; +} + +// Searches through rasConns for a connection with a provided address. +int rasConnFind(const union ncclSocketAddress* addr) { + // rasConns is not sorted (given the number of indices, it would be a massive hassle to keep it that way) + // so binary search won't do... + for (int i = 0; i < nRasConns; i++) { + struct rasConnection* conn = rasConns+i; + if (conn->inUse && memcmp(&conn->addr, addr, sizeof(conn->addr)) == 0) + return i; + } + + return -1; +} + +// Handles any connection-related timeouts. Many timeouts affect the underlying sockets and thus have been handled +// in the socket timeout handler earlier by terminating the problematic sockets. If a socket connection doesn't +// exist or needs to be re-established (due to having just been terminated), we handle that here. +// This is also where we declare peers as dead, etc. +// Invoked from the main RAS event loop. +void rasConnsHandleTimeouts(int64_t now, int64_t* nextWakeup) { + for (int connIdx = 0; connIdx < nRasConns; connIdx++) { + struct rasConnection* conn = rasConns+connIdx; + + if (!conn->inUse) + continue; + + if (conn->sockIdx != -1) { + struct rasSocket* sock = rasSockets+conn->sockIdx; + bool sockTerminated = false; + + // Retry the socket connections that have been refused. + if (sock->status == RAS_SOCK_CONNECTING && sock->sock.state == ncclSocketStateConnecting) { + if (now - sock->lastSendTime > RAS_CONNECT_RETRY) { + int ready; + if (ncclSocketReady(&sock->sock, &ready) != ncclSuccess) { + INFO(NCCL_RAS, "Unexpected error from ncclSocketReady; terminating the socket connection with %s", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/true); + // We will retry below in the same loop. + sockTerminated = true; + } else { + // We update lastSendTime even if !ready because we need it up-to-date for timeout calculations. + sock->lastSendTime = clockNano(); + if (!ready && sock->sock.state == ncclSocketStateConnecting) + *nextWakeup = std::min(*nextWakeup, sock->lastSendTime+RAS_CONNECT_RETRY); + else + rasPfds[sock->pfd].fd = sock->sock.fd; // Enable the handling via the main loop. + } // if (ncclSocketReady) + } else { + *nextWakeup = std::min(*nextWakeup, sock->lastSendTime+RAS_CONNECT_RETRY); + } + } // if (sock->status == RAS_SOCK_CONNECTING && sock->sock.state == ncclSocketStateConnecting) + + // For connections that have data to send but that we've been unable to send a message on for a while, + // consider their sockets lost and terminate them. + if (!sockTerminated && !ncclIntruQueueEmpty(&conn->sendQ) && sock->status == RAS_SOCK_READY) { + if (now - std::max(sock->lastSendTime, ncclIntruQueueHead(&conn->sendQ)->enqueueTime) > RAS_STUCK_TIMEOUT) { + INFO(NCCL_RAS, "RAS send stuck timeout error (%lds) on socket connection with %s", + (now - std::max(sock->lastSendTime, ncclIntruQueueHead(&conn->sendQ)->enqueueTime)) / + CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/false, RAS_STUCK_TIMEOUT); + // We will retry below in the same loop. + } else { + *nextWakeup = std::min(*nextWakeup, std::max(sock->lastSendTime, + ncclIntruQueueHead(&conn->sendQ)->enqueueTime)+RAS_STUCK_TIMEOUT); + } + } // if (!ncclIntruQueueEmpty(&conn->sendQ) && sock->status == RAS_SOCK_READY) + } // if (conn->sockIdx != -1) + + // For connections that are being (re-)established, irrespective of whether there's a valid socket associated + // with them (conn->startIdx != -1), we need to check if any connection-level timeout has expired. + if (conn->startRetryTime) { + // If we've been trying to open a connection for too long (60s), give up and mark the peer as dead + // so that we don't try again. + if (now - conn->startRetryTime > RAS_PEER_DEAD_TIMEOUT) { + struct rasCollRequest bCast; + INFO(NCCL_RAS, "RAS connect retry timeout (%lds) on socket connection with %s", + (now-conn->startRetryTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&conn->addr, rasLine)); + + // Broadcast the info about a dead peer to everybody. This will handle it locally as well, including + // declaring the peer dead and terminating the connection. + rasCollReqInit(&bCast); + bCast.type = RAS_BC_DEADPEER; + memcpy(&bCast.deadPeer.addr, &conn->addr, sizeof(bCast.deadPeer.addr)); + (void)rasNetSendCollReq(&bCast, rasCollDataLength(RAS_BC_DEADPEER)); + + continue; + } else { + *nextWakeup = std::min(*nextWakeup, conn->startRetryTime+RAS_PEER_DEAD_TIMEOUT); + } + + // RAS_STUCK_TIMEOUT has already been handled in the socket function (we'll pick it up later via + // the conn->sockIdx == -1 test). + + // We print warnings after the same time as with keep-alive (5s), and we pessimistically immediately try + // to establish fallback connections. + if (now - conn->startRetryTime > RAS_CONNECT_WARN) { + if (!conn->experiencingDelays) { + INFO(NCCL_RAS, "RAS connect timeout warning (%lds) on socket connection with %s", + (now-conn->startRetryTime) / CLOCK_UNITS_PER_SEC, ncclSocketToString(&conn->addr, rasLine)); + + // See if the connection was meant to be a part of a RAS link and if so, try to initiate fallback + // connection(s). At this point, it's mostly just a precaution; we will continue trying to establish + // the primary connection until RAS_PEER_DEAD_TIMEOUT expires. + conn->experiencingDelays = true; + (void)rasLinkAddFallback(&rasNextLink, connIdx); + (void)rasLinkAddFallback(&rasPrevLink, connIdx); + // rasConns may have been reallocated by the above calls. + conn = rasConns+connIdx; + + // Stop collectives from waiting for a response over it. + rasCollsPurgeConn(connIdx); + } // if (!conn->experiencingDelays) + } else { + *nextWakeup = std::min(*nextWakeup, conn->startRetryTime+RAS_CONNECT_WARN); + } + + // If a socket was terminated (or never opened, due to some error), try to open it now. + // We retry once a second. + if (conn->sockIdx == -1) { + if (now - conn->lastRetryTime > RAS_CONNECT_RETRY) { + INFO(NCCL_RAS, "RAS trying to reconnect with %s (experiencingDelays %d, startRetryTime %.2fs)", + ncclSocketToString(&conn->addr, rasLine), conn->experiencingDelays, + (conn->startRetryTime ? (now-conn->startRetryTime)/1e9 : 0.0)); + rasConnOpen(conn); + } + if (conn->sockIdx == -1) + *nextWakeup = std::min(*nextWakeup, conn->lastRetryTime+RAS_CONNECT_RETRY); + } + } // if (conn->startRetryTime) + } // for (connIdx) +} + +// Checks if we have a connection to a given peer and if so, terminates it. The connection is removed from the +// RAS links, though fallbacks are initiated if necessary. Typically called just before declaring a peer dead. +void rasConnDisconnect(const union ncclSocketAddress* addr) { + int connIdx = rasConnFind(addr); + if (connIdx != -1) { + (void)rasLinkAddFallback(&rasNextLink, connIdx); + (void)rasLinkAddFallback(&rasPrevLink, connIdx); + rasLinkDropConn(&rasNextLink, connIdx); + rasLinkDropConn(&rasPrevLink, connIdx); + + rasConnTerminate(rasConns+connIdx); + } +} + +// Terminates a connection and frees the rasConns entry. +static void rasConnTerminate(struct rasConnection* conn) { + int connIdx = conn - rasConns; + + // Make sure there are no lingering rasSockets pointing to it. + for (int i = 0; i < nRasSockets; i++) { + struct rasSocket* sock = rasSockets+i; + if (sock->status != RAS_SOCK_CLOSED && sock->connIdx == connIdx) + rasSocketTerminate(sock, /*finalize*/true); + } + + // Also check any ongoing collectives. + rasCollsPurgeConn(connIdx); + + while (struct rasMsgMeta* meta = ncclIntruQueueTryDequeue(&conn->sendQ)) { + free(meta); + } + + INFO(NCCL_RAS, "RAS terminating a connection with %s", ncclSocketToString(&conn->addr, rasLine)); + + conn->inUse = false; + conn->sockIdx = -1; // Should be that way already, but just to be extra sure... +} + + +/////////////////////////////////////////// +// Functions related to the RAS sockets. // +/////////////////////////////////////////// + +// Accepts a new RAS network socket connection. The socket is not usable until after the handshake, as a +// corresponding rasConnection can't be established without knowing the peer's address. +ncclResult_t rasNetAcceptNewSocket() { + ncclResult_t ret = ncclSuccess; + struct rasSocket* sock; + int ready; + bool socketInitialized = false; + NCCLCHECKGOTO(getNewSockEntry(&sock), ret, fail); + + NCCLCHECKGOTO(ncclSocketInit(&sock->sock, nullptr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, + /*asyncFlag*/1), ret, fail); + socketInitialized = true; + NCCLCHECKGOTO(ncclSocketAccept(&sock->sock, &rasNetListeningSocket), ret, fail); + NCCLCHECKGOTO(ncclSocketReady(&sock->sock, &ready), ret, fail); + + if (sock->sock.fd != -1) { + NCCLCHECKGOTO(rasGetNewPollEntry(&sock->pfd), ret, fail); + rasPfds[sock->pfd].fd = sock->sock.fd; + rasPfds[sock->pfd].events = POLLIN; // Initially we'll just wait for a handshake from the other side. This also + // helps the code tell the sides apart. + sock->status = RAS_SOCK_CONNECTING; + + INFO(NCCL_RAS, "RAS new incoming socket connection from %s", ncclSocketToString(&sock->sock.addr, rasLine)); + } + +exit: + return ret; +fail: + if (socketInitialized) + NCCLCHECK(ncclSocketClose(&sock->sock)); + goto exit; +} + +// Returns the index of the first available entry in the rasConns array, enlarging the array if necessary. +static ncclResult_t getNewSockEntry(struct rasSocket** pSock) { + struct rasSocket* sock; + int i; + for (i = 0; i < nRasSockets; i++) + if (rasSockets[i].status == RAS_SOCK_CLOSED) + break; + if (i == nRasSockets) { + NCCLCHECK(ncclRealloc(&rasSockets, nRasSockets, nRasSockets+RAS_INCREMENT)); + nRasSockets += RAS_INCREMENT; + } + + sock = rasSockets+i; + memset(sock, '\0', sizeof(*sock)); + sock->pfd = -1; + sock->connIdx = -1; + sock->createTime = sock->lastSendTime = sock->lastRecvTime = clockNano(); + + *pSock = sock; + return ncclSuccess; +} + +// Invoked from the main RAS event loop to handle RAS socket timeouts. +void rasSocksHandleTimeouts(int64_t now, int64_t* nextWakeup) { + for (int sockIdx = 0; sockIdx < nRasSockets; sockIdx++) { + struct rasSocket* sock = rasSockets+sockIdx; + + if (sock->status == RAS_SOCK_CLOSED) + continue; + + // For socket connections that are still being established, give up on the ones that take too long to initialize. + if (sock->status == RAS_SOCK_CONNECTING || sock->status == RAS_SOCK_HANDSHAKE) { + if (now - sock->createTime > RAS_STUCK_TIMEOUT) { + if (sock->connIdx == -1) { + INFO(NCCL_RAS, "RAS init timeout error (%lds) on incoming socket connection from %s", + (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine)); + } else { + struct rasConnection* conn = rasConns+sock->connIdx; + INFO(NCCL_RAS, "RAS init timeout error (%lds) on socket connection with %s " + "(experiencingDelays %d, startRetryTime %.2fs, socket status %d)", + (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine), + conn->experiencingDelays, (conn->startRetryTime ? (now-conn->startRetryTime)/1e9 : 0.0), + sock->status); + } + rasSocketTerminate(sock, /*finalize*/true); + // We may retry later. + continue; + } else { + *nextWakeup = std::min(*nextWakeup, sock->createTime+RAS_STUCK_TIMEOUT); + } + } // if (sock->status == RAS_SOCK_CONNECTING || sock->status == RAS_SOCK_HANDSHAKE) + + // For sockets that are being terminated, force finalization of the ones that haven't made progress in too long. + if (sock->status == RAS_SOCK_TERMINATING) { + if (now - std::max(sock->lastSendTime, sock->lastRecvTime) > RAS_STUCK_TIMEOUT) { + INFO(NCCL_RAS, "RAS termination stuck timeout error (%lds) on socket connection with %s", + (now-std::max(sock->lastSendTime, sock->lastRecvTime)) / CLOCK_UNITS_PER_SEC, + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/true); + // This socket is presumably already being re-established, if needed. + continue; + } else { + *nextWakeup = std::min(*nextWakeup, std::max(sock->lastSendTime, sock->lastRecvTime)+RAS_STUCK_TIMEOUT); + } + } // if (sock->status == RAS_SOCK_TERMINATING) + + // Terminate sockets that haven't been used in a good while. In principle this shouldn't trigger for anything + // important due to shorter timeouts on RAS network connections, but in case of weird situations like process + // suspend, rasSocketTerminate will do additional checking. + if (sock->status == RAS_SOCK_READY) { + if (now - std::max(sock->lastSendTime, sock->lastRecvTime) > RAS_IDLE_TIMEOUT) { + INFO(NCCL_RAS, "RAS idle timeout (%lds) on socket connection with %s", + (now - std::max(sock->lastSendTime, sock->lastRecvTime)) / CLOCK_UNITS_PER_SEC, + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/false, /*startRetryOffset*/0, /*retry*/false); + continue; + // The RAS network timeout handler will terminate the conn it was associated with, if any. + } else { + *nextWakeup = std::min(*nextWakeup, std::max(sock->lastSendTime, sock->lastRecvTime)+RAS_IDLE_TIMEOUT); + } + } // if (sock->status == RAS_SOCK_READY) + } // for (sockIdx) +} + +// Handles the termination of a RAS socket. +// We try to do it in stages for established sockets (in READY state). We shut down just the sending side +// for them and change their state to TERMINATING, so that we can still receive data that may be in the buffers. +// Once we get an EOF when receiving data, we finalize the termination. +// For not fully established sockets, we can terminate immediately as there's no useful data to extract. +void rasSocketTerminate(struct rasSocket* sock, bool finalize, uint64_t startRetryOffset, bool retry) { + assert(sock->status != RAS_SOCK_CLOSED); + if (sock->connIdx != -1) { + struct rasConnection* conn = rasConns+sock->connIdx; + // If the sockIdx of the connection points back to us, it means that we are the current socket of this + // connection, so we have additional work to do before we can terminate it. + if (conn->sockIdx == sock-rasSockets) { + // Reset it to indicate there's no valid socket associated with that connection anymore. + conn->sockIdx = -1; + + // Don't attempt to retry on sockets that have been unused for so long that the remote peer probably + // deliberately closed them. Make an exception for sockets that are part of the RAS network links. + if ((retry && + clockNano() - std::max(sock->lastSendTime, sock->lastRecvTime) < RAS_IDLE_TIMEOUT - RAS_IDLE_GRACE_PERIOD) || + rasLinkFindConn(&rasNextLink, sock->connIdx) != -1 || rasLinkFindConn(&rasPrevLink, sock->connIdx) != -1) { + // For connections that were fine until now, the connection-level timeout starts at termination, and possibly + // even earlier, depending on what event trigerred the termination -- if it was another timeout expiring, then + // we need to include that timeout as well. + if (conn->startRetryTime == 0) { + conn->startRetryTime = conn->lastRetryTime = clockNano() - startRetryOffset; + } + + // We also filter through the sendQ, eliminating any messages that won't need to be sent when the socket + // connection is re-established (that's essentially the server init and keep-alives). + // As ncclIntruQueue can't be iterated, we transfer the content in bulk to a temporary and then filter the + // messages as we move them back one-by-one. + struct ncclIntruQueue sendQTmp; + ncclIntruQueueConstruct(&sendQTmp); + ncclIntruQueueTransfer(&sendQTmp, &conn->sendQ); + while (struct rasMsgMeta* meta = ncclIntruQueueTryDequeue(&sendQTmp)) { + if (meta->msg.type != RAS_MSG_CONNINIT && meta->msg.type != RAS_MSG_CONNINITACK && + meta->msg.type != RAS_MSG_KEEPALIVE) { + if (meta->offset != 0) { + // Reset the progress of any partially-sent messages (they will need to be resent from the beginning; + // in principle that could apply to the first message only). + meta->offset = 0; + } + ncclIntruQueueEnqueue(&conn->sendQ, meta); + } else { // RAS_MSG_CONNINIT || RAS_MSG_CONNINITACK || RAS_MSG_KEEPALIVE + free(meta); + } + } // while (meta) + } // if (retry) + + // Stop collectives from waiting for a response over this connection. + rasCollsPurgeConn(sock->connIdx); + } // if (conn->sockIdx == sock-rasSockets) + } // if (sock->connIdx != -1) + + if (sock->status != RAS_SOCK_CONNECTING && sock->connIdx != -1 && !finalize && (rasPfds[sock->pfd].events & POLLIN)) { + if (sock->status != RAS_SOCK_TERMINATING) { + // The receiving side is still open -- close just the sending side. + (void)ncclSocketShutdown(&sock->sock, SHUT_WR); + rasPfds[sock->pfd].events &= ~POLLOUT; // Nothing more to send. + // The timeout for this socket starts ticking now... + sock->lastSendTime = clockNano(); + sock->status = RAS_SOCK_TERMINATING; + } + // Else it must be in RAS_SOCK_TERMINATING state already -- in that case we do nothing here and instead + // we wait for an EOF on the receiving side or for a timeout. + } else { + // Either the caller requested finalization or we cannot receive on it. + (void)ncclSocketClose(&sock->sock); + sock->status = RAS_SOCK_CLOSED; + rasPfds[sock->pfd].fd = -1; + rasPfds[sock->pfd].events = rasPfds[sock->pfd].revents = 0; + sock->pfd = sock->connIdx = -1; + sock->recvOffset = sock->recvLength = 0; + free(sock->recvMsg); + sock->recvMsg = nullptr; + } +} + +// Handles a ready socket FD from the main event loop. +void rasSockEventLoop(int sockIdx, int pollIdx) { + struct rasSocket* sock = rasSockets+sockIdx; + + if (sock->status == RAS_SOCK_CONNECTING) { + int ready; + // Socket is not yet fully established. Continue the OS or NCCL-level handshake. + if (ncclSocketReady(&sock->sock, &ready) != ncclSuccess) { + INFO(NCCL_RAS, "RAS unexpected error from ncclSocketReady; terminating the socket connection with %s", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock); + // We may retry further down. + } else { + if (ready) { + // We can tell the connect-side based on what events is set to. + bool connectSide = (rasPfds[pollIdx].events & POLLOUT); + (connectSide ? sock->lastSendTime : sock->lastRecvTime) = clockNano(); + sock->status = RAS_SOCK_HANDSHAKE; + if (connectSide) { + assert(sock->connIdx != -1); + if (rasConns[sock->connIdx].sockIdx == sockIdx) { + if (rasConnPrepare(rasConns+sock->connIdx) != ncclSuccess) { + INFO(NCCL_RAS, "RAS unexpected error from rasConnPrepare; terminating the socket connection with %s", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock); + // We may retry further down. + } + } else { + // The connection this socket is associated with no longer considers it to be the current one. + // This could possibly happen due to a race condition. Simply terminate it. + INFO(NCCL_RAS, "RAS connected with %s via a socket that's no longer current!", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock); + } + } // if (connectSide) + } else { // !ready + if (sock->sock.state == ncclSocketStateConnecting) + rasPfds[sock->pfd].fd = POLL_FD_IGNORE; // Don't poll on this socket before connect(). + } + } // if (ncclSocketReady) + } else { // RAS_SOCK_HANDSHAKE || RAS_SOCK_READY || RAS_SOCK_TERMINATING. + // The extra test for TERMINATING is there to take care of a race when the handling of one socket + // results in another socket being terminated, but one that already has revents waiting from poll. + if (sock->status != RAS_SOCK_TERMINATING && (rasPfds[pollIdx].revents & POLLOUT)) { + int closed = 0; + bool allSent = false; + assert(sock->connIdx != -1); + struct rasConnection* conn = rasConns+sock->connIdx; + assert(conn->sockIdx == sockIdx); + if (rasConnSendMsg(conn, &closed, &allSent) != ncclSuccess) { + INFO(NCCL_RAS, "RAS unexpected error from rasConnSendMsg; terminating the socket connection with %s", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock); + // We may retry further down. + } else if (closed) { + INFO(NCCL_RAS, "RAS socket connection with %s closed by peer on send; terminating it", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock); + // We may retry further down. + } else { + sock->lastSendTime = clockNano(); + if (allSent) + rasPfds[sock->pfd].events &= ~POLLOUT; // Nothing more to send for now. + } + } + if (rasPfds[pollIdx].revents & POLLIN) { + struct rasMsg* msg; + do { + int closed = 0; + msg = nullptr; + if (rasMsgRecv(sock, &msg, &closed) != ncclSuccess) { + INFO(NCCL_RAS, "RAS unexpected error from rasMsgRecv; terminating the socket connection with %s", + ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/true); + // We may retry further down. + } else if (closed) { + const char* socketType; + if (sock->connIdx == -1) + socketType = "incoming"; + else if (rasConns[sock->connIdx].sockIdx != sockIdx) + socketType = "old"; + else if (sock->status == RAS_SOCK_HANDSHAKE) + socketType = "new"; + else + socketType = "current"; + INFO(NCCL_RAS, "RAS %s socket connection with %s closed by peer on receive; terminating it", + socketType, ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/true); + // We may retry further down. + } else { + sock->lastRecvTime = clockNano(); + if (msg) { + (void)rasMsgHandle(msg, sock); + free(msg); + // Message handlers can terminate a socket in certain cases; we need to check for + // that here so that we don't try to receive from a closed socket. + // No handlers are currently believed to create new sockets but better to be safe than sorry + // and re-init the sock variable. + sock = rasSockets+sockIdx; + if (sock->status == RAS_SOCK_CLOSED) + break; + } + if (sock->connIdx != -1) { + struct rasConnection* conn = rasConns+sock->connIdx; + if (conn->sockIdx == sockIdx && (conn->startRetryTime || conn->experiencingDelays)) + rasConnResume(conn); + } + } + } while (msg); + } // if (POLLIN) + } // RAS_SOCK_HANDSHAKE || RAS_SOCK_READY || RAS_SOCK_TERMINATING +} + + +//////////////////////////////////////////////////////////////// +// Functions related to the handling of RAS network timeouts. // +//////////////////////////////////////////////////////////////// + +// Invoked from the main RAS event loop to handle RAS network timeouts. +void rasNetHandleTimeouts(int64_t now, int64_t* nextWakeup) { + // A connection can belong to multiple links but, when it comes to various timeouts, we want to handle each + // connection just once. We solve that with a simple flag within a connection. This also allows us to distinguish + // connections that are part of a link from those that are not. + for (int connIdx = 0; connIdx < nRasConns; connIdx++) + rasConns[connIdx].linkFlag = false; + + (void)rasLinkHandleNetTimeouts(&rasNextLink, now, nextWakeup); + (void)rasLinkHandleNetTimeouts(&rasPrevLink, now, nextWakeup); + + for (int connIdx = 0; connIdx < nRasConns; connIdx++) { + struct rasConnection* conn = rasConns+connIdx; + if (conn->inUse && !conn->linkFlag) { + // The connection is not part of any link. Check if it should be terminated. + if (conn->sockIdx == -1 && ncclIntruQueueEmpty(&conn->sendQ)) { + rasConnTerminate(conn); + continue; + } + } + } +} + +// Checks for and handles timeouts at the link level; primarily the keep-alives for link connections. +static ncclResult_t rasLinkHandleNetTimeouts(struct rasLink* link, int64_t now, int64_t* nextWakeup) { + for (int i = 0; i < link->nConns; i++) { + struct rasLinkConn* linkConn = link->conns+i; + if (linkConn->connIdx != -1) { + if (!rasConns[linkConn->connIdx].linkFlag) { + rasConnHandleNetTimeouts(linkConn->connIdx, now, nextWakeup); + // rasConns may have been reallocated by the above call, which is why we don't have a conn variable here. + // For the same reason we re-init linkConn. + linkConn = link->conns+i; + rasConns[linkConn->connIdx].linkFlag = true; + } + } else if (i == 0 && link->lastUpdatePeersTime != 0) { + // This triggers when rasLinkReinitConns didn't create the primary connection because we have a higher address + // than the peer. If that peer fails to initiate within RAS_CONNECT_WARN, we need to take action. + if (now - link->lastUpdatePeersTime > RAS_CONNECT_WARN) { + INFO(NCCL_RAS, "RAS peer connect timeout warning (%lds) on socket connection from %s", + (now-link->lastUpdatePeersTime) / CLOCK_UNITS_PER_SEC, + ncclSocketToString(&rasPeers[linkConn->peerIdx].addr, rasLine)); + NCCLCHECK(rasConnCreate(&rasPeers[linkConn->peerIdx].addr, &linkConn->connIdx)); + if (linkConn->connIdx != -1) { + rasConns[linkConn->connIdx].linkFlag = true; + } + // We used to connect to the first fallback but I think trying to connect to the calculated primary first + // in this case is more intuitive. + //(void)rasLinkTryFallback(link, -1); + link->lastUpdatePeersTime = 0; + } else { + *nextWakeup = std::min(*nextWakeup, link->lastUpdatePeersTime+RAS_CONNECT_WARN); + } + } // if (i == 0 && link->lastUpdatePeerTime != 0) + } // for (i) + + return ncclSuccess; +} + +// Handles the sending of keep-alive messages and related timeouts for connections that are part of the RAS links. +static void rasConnHandleNetTimeouts(int connIdx, int64_t now, int64_t* nextWakeup) { + struct rasConnection* conn = rasConns+connIdx; + if (conn->sockIdx != -1) { + struct rasSocket* sock = rasSockets+conn->sockIdx; + + if (sock->status == RAS_SOCK_READY) { + // Send a regular keep-alive message if we haven't sent anything in a while and we don't have anything queued. + if (ncclIntruQueueEmpty(&conn->sendQ)) { + if (now - sock->lastSendTime > RAS_KEEPALIVE_INTERVAL) { + rasConnSendKeepAlive(conn); + } else { + *nextWakeup = std::min(*nextWakeup, sock->lastSendTime+RAS_KEEPALIVE_INTERVAL); + } + } + + // For short timeouts print a warning but also pessimistically immediately try to establish fallback connections. + if (now - sock->lastRecvTime > RAS_KEEPALIVE_TIMEOUT_WARN) { + if (!conn->experiencingDelays) { + INFO(NCCL_RAS, "RAS keep-alive timeout warning (%lds) on socket connection with %s", + (now-sock->lastRecvTime) / CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine)); + + // At this point, it's mostly just a precaution; we will continue with the primary connection until + // RAS_PEER_DEAD_TIMEOUT expires. + conn->experiencingDelays = true; + (void)rasLinkAddFallback(&rasNextLink, connIdx); + (void)rasLinkAddFallback(&rasPrevLink, connIdx); + // rasConns and rasSockets may have been reallocated by the above calls. + conn = rasConns+connIdx; + sock = rasSockets+conn->sockIdx; + + // Stop collectives from waiting for a response over it. + rasCollsPurgeConn(connIdx); + } + } else { + *nextWakeup = std::min(*nextWakeup, sock->lastRecvTime+RAS_KEEPALIVE_TIMEOUT_WARN); + } + + // For long timeouts we need to act. + if (now - sock->lastRecvTime > RAS_KEEPALIVE_TIMEOUT_ERROR) { + INFO(NCCL_RAS, "RAS keep-alive timeout error (%lds) on socket connection with %s", + (now-sock->lastRecvTime) / CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine)); + rasSocketTerminate(sock, /*finalize*/true, RAS_KEEPALIVE_TIMEOUT_ERROR); + *nextWakeup = now; // Retry will be in the next iteration of the main loop so ensure we don't wait. + } else { + *nextWakeup = std::min(*nextWakeup, sock->lastRecvTime+RAS_KEEPALIVE_TIMEOUT_ERROR); + } + } // if (sock->status == RAS_SOCK_READY) + } // if (conn->sockIdx != -1) +} + +// Sends a keep-alive message to a peer on the RAS network. +static void rasConnSendKeepAlive(struct rasConnection* conn, bool nack) { + struct rasMsg* msg = nullptr; + int msgLen = rasMsgLength(RAS_MSG_KEEPALIVE); + if (rasMsgAlloc(&msg, msgLen) == ncclSuccess) { + int linkIdx; + msg->type = RAS_MSG_KEEPALIVE; + msg->keepAlive.peersHash = rasPeersHash; + msg->keepAlive.deadPeersHash = rasDeadPeersHash; + msg->keepAlive.nack = (nack ? 1 : 0); + + linkIdx = rasLinkFindConn(&rasNextLink, conn-rasConns); + if (linkIdx != -1 && !rasNextLink.conns[linkIdx].external) + msg->keepAlive.linkMask |= 2; // Our rasNextLink should be the peer's rasPrevLink. + linkIdx = rasLinkFindConn(&rasPrevLink, conn-rasConns); + if (linkIdx != -1 && !rasPrevLink.conns[linkIdx].external) + msg->keepAlive.linkMask |= 1; // Our rasPrevLink should be the peer's rasNextLink. + + (void)clock_gettime(CLOCK_REALTIME, &msg->keepAlive.realTime); + + rasConnEnqueueMsg(conn, msg, msgLen); + } +} + +// Handles incoming keep-alive messages. +ncclResult_t rasMsgHandleKeepAlive(const struct rasMsg* msg, struct rasSocket* sock) { + struct timespec currentTime; + int64_t travelTime; + int peerIdx; + + assert(sock->connIdx != -1); + struct rasConnection* conn = rasConns+sock->connIdx; + SYSCHECK(clock_gettime(CLOCK_REALTIME, ¤tTime), "clock_gettime"); + travelTime = (currentTime.tv_sec-msg->keepAlive.realTime.tv_sec)*1000*1000*1000 + + (currentTime.tv_nsec-msg->keepAlive.realTime.tv_nsec); + + if (msg->keepAlive.peersHash != conn->lastRecvPeersHash) { + conn->lastRecvPeersHash = msg->keepAlive.peersHash; + } + if (msg->keepAlive.deadPeersHash != conn->lastRecvDeadPeersHash) { + conn->lastRecvDeadPeersHash = msg->keepAlive.deadPeersHash; + } + + // Make sure that the connection is part of the appropriate links forming the RAS network. In particular, this + // will add any externally-requested connections to the appropriate links (or remove existing ones, if no longer + // needed). + peerIdx = rasPeerFind(&conn->addr); + // Note: it's possible for peerIdx to be -1 at this point if, due to races, the keepAlive arrives before + // the peers update. + (void)rasLinkUpdateConn(&rasNextLink, (msg->keepAlive.linkMask & 1) ? sock->connIdx : -1, peerIdx, /*external*/true); + (void)rasLinkUpdateConn(&rasPrevLink, (msg->keepAlive.linkMask & 2) ? sock->connIdx : -1, peerIdx, /*external*/true); + + // If the keep-alive message is from a peer that doesn't actually need this connection (i.e., for that peer the + // connection is just an external fallback), we should check if *we* still need it. It might be that we don't, + // and because we stopped sending the keep-alives, our peer doesn't know about it. rasLinkUpdateConn calls above + // will have wiped any external fallbacks, so anything that remains must be needed. + if (!msg->keepAlive.nack && msg->keepAlive.linkMask == 0) { + if (rasLinkFindConn(&rasNextLink, sock->connIdx) == -1 && rasLinkFindConn(&rasPrevLink, sock->connIdx) == -1) { + // We don't need this connection either. Notify the peer about it. To avoid an infinite loop, we set the + // special nack flag in the message to distinguish it from regular keep-alives. + rasConnSendKeepAlive(conn, /*nack*/true); + } + } + + if (conn->travelTimeMin > travelTime) + conn->travelTimeMin = travelTime; + if (conn->travelTimeMax < travelTime) + conn->travelTimeMax = travelTime; + conn->travelTimeSum += travelTime; + conn->travelTimeCount++; + + if (msg->keepAlive.peersHash != rasPeersHash || msg->keepAlive.deadPeersHash != rasDeadPeersHash) { + // This could happen due to a short-lived race condition between the peers propagation + // process and the periodic keep-alive messages (perhaps we'll see it regularly at scale?). + // Just in case there's some unforeseen problem with the peers propagation though, exchange with the + // remote to get everybody in sync. + INFO(NCCL_RAS, "RAS keepAlive hash mismatch from %s (peersHash 0x%lx, deadPeersHash 0x%lx)", + ncclSocketToString(&sock->sock.addr, rasLine), msg->keepAlive.peersHash, msg->keepAlive.deadPeersHash); + INFO(NCCL_RAS, "RAS my peersHash 0x%lx, deadPeersHash 0x%lx", rasPeersHash, rasDeadPeersHash); + NCCLCHECK(rasConnSendPeersUpdate(conn, rasPeers, nRasPeers)); + } + return ncclSuccess; +} + + +/////////////////////////////////////////////////////////////////////////////// +// Functions related to the RAS links and recovery from connection failures. // +/////////////////////////////////////////////////////////////////////////////// + +// Checks if the connection (that we just detected some problem with) is part of the RAS link and if so, +// tries to initiate a(nother) fallback connection if needed. +// External connections are generally ignored by this whole process: in particular, we don't add fallbacks for +// timing out external connections. However, we will use an active external connection if it would be a better +// option than whatever we can come up with. +static ncclResult_t rasLinkAddFallback(struct rasLink* link, int connIdx) { + int peerIdx = -1; + int linkIdx = -1; + int firstExtLinkIdx = -1; + int newPeerIdx; + + // First check if the connection is part of this link. In the process also check if any of the link's connections + // might be active -- if so, there's no need to initiate any more fallbacks and we can bail out. + for (int i = 0; i < link->nConns; i++) { + struct rasLinkConn* linkConn = link->conns+i; + + if (linkConn->peerIdx == -1) { + // Such elements are always at the very end of the array and we can't use them so we can just as well break. + break; + } + + // Check for any other connection that might be a viable fallback (basically, anything that is not experiencing + // delays). + if (linkConn->connIdx != -1 && linkConn->connIdx != connIdx) { + struct rasConnection* conn = rasConns+linkConn->connIdx; + if (!conn->experiencingDelays) { + if (!linkConn->external) + goto exit; // We don't need to do anything if there's a non-external connection. + else if (linkConn->peerIdx != -1) { + // Record the location of the first potentially viable external connection in the chain; we may prefer it + // over anything we can come up with. + if (firstExtLinkIdx == -1) + firstExtLinkIdx = i; + if (linkIdx != -1) + break; // Break out of the loop if we already have all the data we might need. + } // linkConn->external && linkConn->peerIdx != -1 + } // if (!conn->experiencingDelays) + } // if (linkConn->connIdx != -1) + + if (linkConn->connIdx == connIdx) { + if (linkConn->external) + goto exit; // We don't add fallbacks for external connections... + peerIdx = linkConn->peerIdx; + linkIdx = i; + // We are not breaking out of the loop here because we want to check for active connections on *all* potentially + // viable elements (in particular, there could be some external ones beyond this one). + } + } + + if (linkIdx == -1) + goto exit; + + // We found an existing element so the connection is part of the link. No existing non-external connections of this + // link are active, so a fallback is needed. + assert(peerIdx != -1); + newPeerIdx = rasLinkCalculatePeer(link, peerIdx, /*isFallback*/linkIdx > 0); + // In principle we want to add (at most) one fallback. However, if the found fallback connection already exists + // and is also experiencing delays, we need to keep iterating. + while (newPeerIdx != -1) { + int newConnIdx = rasConnFind(&rasPeers[newPeerIdx].addr); + // If we previously found a potential external fallback connection, check if it's better than what we just found. + if (firstExtLinkIdx != -1) { + linkIdx = -1; + // Calculate the index that the newly found fallback would have (pretend mode). + NCCLCHECK(rasLinkUpdateConn(link, newConnIdx, newPeerIdx, /*external*/false, /*insert*/true, /*pretend*/true, + &linkIdx)); + assert(linkIdx != -1); + if (firstExtLinkIdx < linkIdx) { + // The external connection *is* better -- use it as a fallback instead and be done. + link->conns[firstExtLinkIdx].external = false; + goto exit; + } + } + NCCLCHECK(rasLinkUpdateConn(link, newConnIdx, newPeerIdx, /*external*/false, /*insert*/true, /*pretend*/false, + &linkIdx)); + if (firstExtLinkIdx != -1 && linkIdx <= firstExtLinkIdx) + firstExtLinkIdx++; // Adjust if we inserted a new conn at a lower index. + + INFO(NCCL_RAS, "RAS link %d: %s fallback connection %d with %s", + link->direction, (newConnIdx == -1 ? "opening new" : "calculated existing"), + linkIdx, ncclSocketToString(&rasPeers[newPeerIdx].addr, rasLine)); + // Note that we don't follow here our convention of "lower address is the one establishing connections" -- + // that convention is for optimizing regular operations, but we don't want to take chances during fault + // recovery. It may temporarily result in duplicate connections, but we have a mechanism to deal with those. + if (newConnIdx == -1) + NCCLCHECK(rasConnCreate(&rasPeers[newPeerIdx].addr, &link->conns[linkIdx].connIdx)); + + struct rasConnection* conn = rasConns+link->conns[linkIdx].connIdx; + // If the fallback connection is also experiencing delays, we need to keep trying. + if (!conn->experiencingDelays) + break; + INFO(NCCL_RAS, "RAS connection experiencingDelays %d, startRetryTime %.2fs, socket status %d", + conn->experiencingDelays, (conn->startRetryTime ? (clockNano()-conn->startRetryTime)/1e9 : 0.0), + (conn->sockIdx == -1 ? -1 : rasSockets[conn->sockIdx].status)); + + newPeerIdx = rasLinkCalculatePeer(link, newPeerIdx, /*isFallback*/true); + } + if (newPeerIdx == -1) + INFO(NCCL_RAS, "RAS link %d: no more fallbacks to add (nConns %d)", link->direction, link->nConns); +exit: + return ncclSuccess; +} + +// Invoked when we receive a message over a connection that was just activated or was experiencing delays. +// Cleans up the fallbacks, timers, etc, as appropriate. +static void rasConnResume(struct rasConnection* conn) { + if (conn->sockIdx != -1 && rasSockets[conn->sockIdx].status == RAS_SOCK_READY) { + INFO(NCCL_RAS, "RAS %s connection with %s (sendQ %sempty, experiencingDelays %d, startRetryTime %.2fs)", + (conn->experiencingDelays && conn->startRetryTime == 0 ? "recovered" : "established"), + ncclSocketToString(&conn->addr, rasLine), (ncclIntruQueueEmpty(&conn->sendQ) ? "" : "not "), + conn->experiencingDelays, (conn->startRetryTime ? (clockNano()-conn->startRetryTime)/1e9 : 0.0)); + + conn->experiencingDelays = false; + + conn->startRetryTime = conn->lastRetryTime = 0; + + rasLinkSanitizeFallbacks(&rasNextLink); + rasLinkSanitizeFallbacks(&rasPrevLink); + + if (!ncclIntruQueueEmpty(&conn->sendQ)) + rasPfds[rasSockets[conn->sockIdx].pfd].events |= POLLOUT; + } +} + +// Checks if the primary connection is fully established and if so, purges the fallbacks (as they are no longer needed). +static void rasLinkSanitizeFallbacks(struct rasLink* link) { + if (link->nConns > 0 && link->conns[0].connIdx != -1) { + struct rasConnection* conn = rasConns+link->conns[0].connIdx; + if (conn->sockIdx != -1 && rasSockets[conn->sockIdx].status == RAS_SOCK_READY && !conn->experiencingDelays) { + // We have a good primary. Simply drop all the fallbacks (the external ones will get recreated via the + // keepAlive messages). + for (int i = 1; i < link->nConns; i++) { + INFO(NCCL_RAS, "RAS link %d: dropping %sfallback connection %d with %s", + link->direction, (link->conns[i].external ? "external " : ""), i, + ncclSocketToString(&rasConns[link->conns[i].connIdx].addr, rasLine)); + } + link->nConns = 1; + link->lastUpdatePeersTime = 0; + } + } +} + +// Attempt to drop a connection from a link. +static void rasLinkDropConn(struct rasLink* link, int connIdx, int linkIdx) { + if (linkIdx == -1) + linkIdx = rasLinkFindConn(link, connIdx); + if (linkIdx != -1) { + if (linkIdx == 0) { + INFO(NCCL_RAS, "RAS link %d: dropping primary connection with %s", + link->direction, ncclSocketToString(&rasConns[connIdx].addr, rasLine)); + } else { + INFO(NCCL_RAS, "RAS link %d: dropping %sfallback connection %d with %s", + link->direction, (link->conns[linkIdx].external ? "external " : ""), linkIdx, + ncclSocketToString(&rasConns[connIdx].addr, rasLine)); + } + memmove(link->conns+linkIdx, link->conns+linkIdx+1, (link->nConns-(linkIdx+1))*sizeof(*link->conns)); + if (link->nConns > 1) + link->nConns--; + else { + link->conns[0].peerIdx = link->conns[0].connIdx = -1; + } + + if (linkIdx == 0) { + // First ensure that the conn becoming the primary is not marked as external (we don't want to lose it if + // the remote peer loses interest in it). + link->conns[0].external = false; + if (link->conns[0].connIdx != -1) { + INFO(NCCL_RAS, "RAS link %d: former fallback connection 1 with %s is the new primary", + link->direction, ncclSocketToString(&rasConns[link->conns[0].connIdx].addr, rasLine)); + } + rasLinkSanitizeFallbacks(link); + } + } +} + +// Checks if a given connection is a member of this link and if so, returns its entry index. +// Returns -1 if connection not found. +static int rasLinkFindConn(const struct rasLink* link, int connIdx) { + for (int i = 0; i < link->nConns; i++) { + if (link->conns[i].connIdx == connIdx) + return i; + } + return -1; +} + +// Note: the behavior of this function has become super-complex and so it should be considered for refactoring. +// Searches for and updates an entry in a RAS network link. The conns array is de-facto sorted by peerIdx: it is +// ordered by preference, though peerIdx values can wrap around (given the ring/torus topology) and they can also +// be -1 (the latter are stored at the end). +// external provides an updated value for the entry's external field. A false value, if requested, is always set; +// a true value, however, is only set if a new entry is added (external == true implies insert), i.e., if an entry +// already exists and the function is invoked with external == true, the new value will be ignored. +// If insert is set, it will, if necessary, insert a new entry if one is not already there. +// If pretend is set, it will not modify the array and will just set *pLinkIdx as appropriate. +// pLinkIdx is a pointer to an (optional) result where the index of the added/updated entry is stored. +// -1 can be passed as peerIdx if unknown (possible in case of race conditions, and only if external). +// -1 can be passed as connIdx if unknown or, if insert is *not* set, to indicate that the entry is to be removed +// (the entry's external must match the argument external for it to be removed). +ncclResult_t rasLinkUpdateConn(struct rasLink* link, int connIdx, int peerIdx, bool external, bool insert, + bool pretend, int* pLinkIdx) { + int i, oldLinkIdx = -1; + + if (external && connIdx != -1) + insert = true; + + if (connIdx != -1) { + // Start by checking if we already have an element with this connIdx. + oldLinkIdx = rasLinkFindConn(link, connIdx); + if (oldLinkIdx != -1) { + struct rasLinkConn* linkConn = link->conns+oldLinkIdx; + if (linkConn->peerIdx != -1) + assert(linkConn->peerIdx == peerIdx); + + if (linkConn->peerIdx == peerIdx) { + if (!external && !pretend) + linkConn->external = false; // Ensure that external is cleared if so requested. + if (pLinkIdx) + *pLinkIdx = oldLinkIdx; + goto exit; // Nothing more to do if both connIdx and peerIdx are up to date. + } + + // Otherwise (linkConn->peerIdx == -1 && peerIdx != -1) we have a conn that, due to -1 peerIdx, is in a wrong + // place in the array -- we need to find the right spot. linkConn->peerIdx == -1 can only happen for external + // connections. + assert(external); + } + } + + if (peerIdx != -1) { + // Search for the right spot in the conns array. + for (i = 0; i < link->nConns; i++) { + struct rasLinkConn* linkConn = link->conns+i; + if (peerIdx != -1 && linkConn->peerIdx == peerIdx) { + // The exact conn element already exists. + if (connIdx == -1 && !insert) { + // Drop the connection from the link. + if (linkConn->external == external) { + if (!pretend) + rasLinkDropConn(link, linkConn->connIdx, i); + else if (pLinkIdx) + *pLinkIdx = i; + } + } else { // connIdx != -1 || insert + if (!pretend) { + if (linkConn->connIdx != -1) + assert(linkConn->connIdx == connIdx); + else + linkConn->connIdx = connIdx; + if (!external) + linkConn->external = false; // Ensure that external is cleared if so requested. + if (i == 0) { + // We received a connection from the remote peer that matches the primary connection we've been + // waiting for. + rasLinkSanitizeFallbacks(link); + } + } // if (!pretend) + if (pLinkIdx) + *pLinkIdx = i; + } // connIdx != -1 || insert + + goto exit; + } // if (peerIdx != -1 && linkConn->peerIdx == peerIdx) + if (!insert) + continue; + // Ensure that the i-1 index is also valid. + if (i == 0) + continue; + // Conns with peerIdx == -1 are stored at the end, so anything else needs to go before them. + if (peerIdx != -1 && linkConn->peerIdx == -1) + break; + // Detect a roll-over and handle it specially. + if (link->direction * (link->conns[i-1].peerIdx - linkConn->peerIdx) > 0) { + if (link->direction * (peerIdx - link->conns[i-1].peerIdx) > 0 || + link->direction * (peerIdx - linkConn->peerIdx) < 0) + break; + } else { // Regular, monotonic case with the peerIdx value between two existing elements. + if (link->direction * (peerIdx - link->conns[i-1].peerIdx) > 0 && + link->direction * (peerIdx - linkConn->peerIdx) < 0) + break; + } + } // for (i) + } else { + // If peerIdx == -1, insert the new element at the very end. This can only happen for external connections. + assert(external && oldLinkIdx == -1); + i = link->nConns; + } + if (!insert) + goto exit; + + // i holds the index at which to insert a new element. + if (pretend) { + if (pLinkIdx) + *pLinkIdx = i; + goto exit; + } + + if (oldLinkIdx == -1) { + struct rasLinkConn* linkConn; + if (link->nConns == link->connsSize) { + NCCLCHECK(ncclRealloc(&link->conns, link->connsSize, link->connsSize+RAS_INCREMENT)); + link->connsSize += RAS_INCREMENT; + } + linkConn = link->conns+i; + // Shift existing conns with indices >= i to make room for the new one. + memmove(linkConn+1, linkConn, (link->nConns-i)*sizeof(*link->conns)); + linkConn->peerIdx = peerIdx; + linkConn->connIdx = connIdx; + linkConn->external = external; + if (external) { + INFO(NCCL_RAS, "RAS link %d: adding external fallback connection %d with %s", link->direction, i, + ncclSocketToString((connIdx != -1 ? &rasConns[connIdx].addr : &rasPeers[peerIdx].addr), rasLine)); + } + link->nConns++; + } + else { // oldLinkIdx > -1 + // We already have the conn, we just need to move it to a new spot. + struct rasLinkConn* linkConn = link->conns+i; + assert(i <= oldLinkIdx); // We can only get here if linkConn->peerIdx == -1 && peerIdx != -1. + if (i != oldLinkIdx) { + struct rasLinkConn tmp; + struct rasLinkConn* linkConnNext = link->conns+i+1; // Just to silence the compiler. + // Move the existing conn from index oldLinkIdx to a (lower) index i, shifting the existing conns + // with indices in the range [i, oldLinkIdx). + memcpy(&tmp, link->conns+oldLinkIdx, sizeof(tmp)); + memmove(linkConnNext, linkConn, (oldLinkIdx-i)*sizeof(*linkConn)); + memcpy(linkConn, &tmp, sizeof(*linkConn)); + } + if (!external) + linkConn->external = false; // Ensure that external is cleared if so requested. + } // oldLinkIdx > -1 + if (pLinkIdx) + *pLinkIdx = i; +exit: + return ncclSuccess; +} diff --git a/src/register.cc b/src/register.cc deleted file mode 100644 index 5661c47511..0000000000 --- a/src/register.cc +++ /dev/null @@ -1,244 +0,0 @@ -/************************************************************************* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "argcheck.h" // Need some checks here since we access comm -#include "nccl.h" -#include "comm.h" -#include "net.h" -#include "register.h" -#include "transport.h" -#include "api_trace.h" -#ifdef ENABLE_MSCCLPP -#include "mscclpp/mscclpp_nccl.h" -#endif - -using namespace rccl; - -ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) { - struct ncclRegCache* cache = &comm->regCache; - ncclDebugNoWarn = NCCL_NET; - for (int d=0; dnDevs; d++) { - if (reg->handles[d] != NULL) NCCLCHECK(comm->ncclNet->deregMr(cache->sComms[reg->devs[d]], reg->handles[d])); - } - reg->nDevs = 0; - free(reg->handles); - reg->handles = NULL; - ncclDebugNoWarn = 0; - return ncclSuccess; -} - -ncclResult_t ncclNetRegister(struct ncclComm* comm, void* addr, size_t size, struct ncclReg* reg) { - struct ncclRegCache* cache = &comm->regCache; - int netCount = 0; - if (comm->topo != NULL) NCCLCHECK(ncclTopoGetNetCount(comm->topo, &netCount)); - if (netCount == 0) return ncclSuccess; - - ncclResult_t ret = ncclSuccess; - - // Find local devices for p2p operations - for (int c=0; cp2pnChannels; c++) { - int dev; - if (ncclTopoGetLocalNet(comm->topo, comm->rank, c, NULL, &dev) != ncclSuccess) goto end; // No local net - ncclNetProperties_t props; - NCCLCHECKGOTO(comm->ncclNet->getProperties(dev, &props), ret, end); - if (props.regIsGlobal == 0) { // We need to be sure all NICs support global registration. - reg->nDevs = 0; - break; - } - int found = 0; - for (int d=0; dnDevs; d++) if (reg->devs[d] == dev) found = 1; - if (!found) reg->devs[reg->nDevs++] = dev; - } - - NCCLCHECKGOTO(ncclCalloc(®->handles, reg->nDevs), ret, end); - - ncclDebugNoWarn = NCCL_NET; - for (int d=0; dnDevs; d++) { - int dev = reg->devs[d]; - reg->handles[d] = NULL; - - if (cache->sComms[dev] == NULL) { - // Create a loopback network comm object for that device to register the buffers. - void *lComm = NULL; - ncclNetHandle_t netHandle; - bool connected = false; - NCCLCHECKGOTO(comm->ncclNet->listen(dev, &netHandle, &lComm), ret, end); - while (!connected) { - if (*comm->abortFlag) { - goto end; - } - if (cache->sComms[dev] == NULL) - NCCLCHECKGOTO(comm->ncclNet->connect(dev, &netHandle, cache->sComms+dev, NULL), ret, end); - if (cache->rComms[dev] == NULL) - NCCLCHECKGOTO(comm->ncclNet->accept(lComm, cache->rComms+dev, NULL), ret, end); - connected = (cache->rComms[dev] != NULL) && (cache->sComms[dev] != NULL); - } - NCCLCHECK(comm->ncclNet->closeListen(lComm)); - } - if (comm->ncclNet->regMr(cache->sComms[dev], addr, size, NCCL_PTR_CUDA, reg->handles+d) != ncclSuccess) { - reg->handles[d] = NULL; - NCCLCHECK(ncclNetDeregister(comm, reg)); - reg->nDevs = 0; - goto end; - } - } -end: - INFO(NCCL_INIT, "Register ptr %p size %ld on %d net devices", addr, size, reg->nDevs); - ncclDebugNoWarn = 0; - if (ret != ncclSuccess) NCCLCHECK(ncclNetDeregister(comm, reg)); - return ret; -} - -ncclResult_t ncclRegFind(struct ncclComm* comm, const void* data, size_t size, struct ncclReg** reg) { - struct ncclRegCache* cache = &comm->regCache; - uintptr_t pageSize = cache->pageSize; - uintptr_t addr = (uintptr_t)data & -pageSize; - size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; - - *reg = NULL; - for (int slot=0; /*true*/; slot++) { - if (slot == cache->population || addr < cache->slots[slot]->addr) return ncclSuccess; - if ((addr >= cache->slots[slot]->addr) && - ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { - *reg = cache->slots[slot]; - return ncclSuccess; - } - } -} -NCCL_PARAM(LocalRegister, "LOCAL_REGISTER", 1); - -ncclResult_t ncclRegister(struct ncclComm* comm, void* data, size_t size, void** handle) { - if (!ncclParamLocalRegister()) { - *handle = NULL; - return ncclSuccess; - } - INFO(NCCL_REG, "register comm %p buffer %p size %zi", comm, data, size); - struct ncclRegCache* cache = &comm->regCache; - uintptr_t pageSize = cache->pageSize; - uintptr_t addr = (uintptr_t)data & -pageSize; - size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; - for (int slot=0; /*true*/; slot++) { - if ((slot == cache->population) || (addr < cache->slots[slot]->addr)) { - if (cache->population == cache->capacity) { // must grow cache - cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity; - NCCLCHECK(ncclRealloc(&cache->slots, cache->population, cache->capacity)); - } - memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclReg*)); - NCCLCHECK(ncclCalloc(cache->slots+slot, 1)); - struct ncclReg* regSlot = cache->slots[slot]; - regSlot->addr = addr; - regSlot->pages = pages; - regSlot->refs = 1; - NCCLCHECK(ncclNetRegister(comm, (void*)addr, pages*pageSize, regSlot)); - regSlot->state |= NET_REG_COMPLETE; - cache->population += 1; - *handle = regSlot; - return ncclSuccess; - } else if ((addr >= cache->slots[slot]->addr) && - ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { - cache->slots[slot]->refs++; - *handle = cache->slots[slot]; - return ncclSuccess; - } - } -} - -ncclResult_t ncclRegCleanup(struct ncclComm* comm) { - struct ncclRegCache* cache = &comm->regCache; - for (int i=0; ipopulation; i++) { - INFO(NCCL_INIT, "Cleanup buffer %p pages %lx", (void*)cache->slots[i]->addr, cache->slots[i]->pages); - NCCLCHECK(ncclNetDeregister(comm, cache->slots[i])); - if (cache->slots[i]->state & NVLS_REG_COMPLETE) NCCLCHECK(ncclNvlsDeregBuffer(&cache->slots[i]->mcHandle, cache->slots[i]->regAddr, cache->slots[i]->dev, cache->slots[i]->regSize)); - free(cache->slots[i]); - } - free(cache->slots); - for (int d=0; dsComms[d]) NCCLCHECK(comm->ncclNet->closeSend(cache->sComms[d])); - if (cache->rComms[d]) NCCLCHECK(comm->ncclNet->closeRecv(cache->rComms[d])); - } - return ncclSuccess; -} - -NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size_t size, void** handle); -ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) { - ncclResult_t ret = ncclSuccess; - - NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); - if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister")); - #ifdef ENABLE_MSCCLPP - if (comm->mscclppCompatible) { - if (comm->mscclCompatible && size > 0){ - bool isManagedBuffer = false; - CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(buff))); - if(!isManagedBuffer){ - INFO(NCCL_INIT, "MSCCL++: ncclCommRegister"); - NCCLCHECKGOTO(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle), ret, end); - } - else{ - WARN("MSCCL++: Cannot register user-buffers on managed memory. RCCL user-buffer registration will occur."); - } - } - } - #endif - INFO(NCCL_INIT, "RCCL: ncclCommRegister"); - NCCLCHECKGOTO(ncclRegister(comm, buff, size, handle), ret, end); - -end: - // !recording at sink - NCCLCHECK(Recorder::instance().record(rrCommRegister, comm, *handle, buff, size)); - return ret; -} - -NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); -ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) { - NCCLCHECK(Recorder::instance().record(rrCommDeregister, comm, handle)); - - #ifdef ENABLE_MSCCLPP - if (comm->mscclppCompatible) { - const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle); - if (comm->mscclCompatible && size > 0) { - NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); - return ncclSuccess; - } - } - #endif - - NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); - struct ncclReg* reg = (struct ncclReg*)handle; - struct ncclRegCache* cache = &comm->regCache; - int slot; - int saveDev; - if (handle == NULL) goto exit; - CUDACHECK(cudaGetDevice(&saveDev)); - CUDACHECK(cudaSetDevice(comm->cudaDev)); - for (slot=0; slotpopulation && cache->slots[slot] != reg; slot++); - if (slot == cache->population) { - WARN("Deregister: Could not find handle"); - return ncclInvalidUsage; - } - if (--reg->refs) return ncclSuccess; - NCCLCHECK(ncclNetDeregister(comm, reg)); - if (reg->state & NVLS_REG_COMPLETE) { - NCCLCHECK(ncclNvlsDeregBuffer(®->mcHandle, reg->regAddr, reg->dev, reg->regSize)); - reg->regAddr = (CUdeviceptr)NULL; - } - if (reg->state & COLLNET_REG_COMPLETE) { - NCCLCHECK(ncclCollnetDeregBuffer(comm, reg->collnetProxyconn, reg->collnetHandle)); - } - if (reg->state & IPC_REG_COMPLETE) { - for (int i = 0; i < NCCL_MAX_LOCAL_RANKS; ++i) - if (reg->ipcInfos[i]) - NCCLCHECK(ncclIpcDeregBuffer(comm, reg->ipcInfos[i])); - if (reg->regIpcAddrs.hostPeerRmtAddrs) free(reg->regIpcAddrs.hostPeerRmtAddrs); - if (reg->regIpcAddrs.devPeerRmtAddrs) NCCLCHECK(ncclCudaFree(reg->regIpcAddrs.devPeerRmtAddrs)); - } - free(reg); - memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*)); - cache->population -= 1; - CUDACHECK(cudaSetDevice(saveDev)); -exit: - return ncclSuccess; -} diff --git a/src/register/coll_reg.cc b/src/register/coll_reg.cc new file mode 100644 index 0000000000..ef59514287 --- /dev/null +++ b/src/register/coll_reg.cc @@ -0,0 +1,446 @@ +#include "register.h" +#include "transport.h" +#include "enqueue.h" + +static ncclResult_t registerCheckP2PConnection(struct ncclComm* comm, struct ncclConnector* conn, struct ncclTopoGraph* graph, int peer, bool* needReg) { + if (conn->connected) { + if (conn->conn.flags & (NCCL_P2P_READ | NCCL_P2P_WRITE)) { + *needReg = true; + } else { + // network connection + *needReg = false; + } + } else { + struct ncclPeerInfo* peerInfo = &comm->peerInfo[peer]; + struct ncclPeerInfo* myInfo = &comm->peerInfo[comm->rank]; + int canConnect = 0; + NCCLCHECK(ncclTransports[0]->canConnect(&canConnect, comm, graph, myInfo, peerInfo)); + if (canConnect) { + *needReg = true; + } else { + *needReg = false; + } + } + return ncclSuccess; +} + +ncclResult_t ncclRegisterCollNvlsBuffers( + struct ncclComm* comm, struct ncclTaskColl* info, + void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], + void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], + struct ncclIntruQueue* cleanupQueue, + bool* regNeedConnect + ) { + ncclResult_t result = ncclSuccess; + + info->regBufType = NCCL_REGULAR_BUFFER; + *regNeedConnect = true; + if (!(ncclParamLocalRegister() || (comm->planner.persistent && ncclParamGraphRegister()))) goto exit; +#if CUDART_VERSION >= 11030 + if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) { + if (!comm->nvlsRegSupport || info->opDev.op == ncclDevPreMulSum) goto exit; + int nvlsReged = 0; + int collnetReged = 0; + const void *sendbuff = info->sendbuff; + void *recvbuff = info->recvbuff; + void *recvHandle = NULL, *sendHandle = NULL; + if (info->func == ncclFuncAllGather) sendbuff = NULL; + if (info->func == ncclFuncReduceScatter) recvbuff = NULL; + size_t elementSize = ncclTypeSize(info->datatype); + size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); + size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); + + /* first try graph registration. */ + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclNvlsGraphRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, &nvlsReged, outRegBufSend, outRegBufRecv, cleanupQueue, &info->nCleanupQueueElts); + } + + if (nvlsReged == 0 && ncclParamLocalRegister()) { + ncclNvlsLocalRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, &nvlsReged, outRegBufSend, outRegBufRecv); + } + + if (nvlsReged && comm->nNodes > 1 && info->algorithm == NCCL_ALGO_NVLS) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetSend, &collnetReged, &sendHandle, cleanupQueue, &info->nCleanupQueueElts); + if (collnetReged) ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &collnetReged, &recvHandle, cleanupQueue, &info->nCleanupQueueElts); + } + + if (collnetReged == 0 && ncclParamLocalRegister()) { + ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetSend, &collnetReged, &sendHandle); + if (collnetReged) ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &collnetReged, &recvHandle); + } + } + + if (nvlsReged) { + *regNeedConnect = 0; + /* tweak NVLS channels usage; for registered NVLS buffer, we only need 4/5 channels to + * saturate bandwidth. */ + if (comm->nNodes == 1) { + if (info->func == ncclFuncReduceScatter) + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 5)); + else + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 4)); + } else { + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 6)); + } + info->regBufType |= NCCL_NVLS_REG_BUFFER; + } + + if (collnetReged) { + info->regBufType |= NCCL_NET_REG_BUFFER; + info->sendMhandle = sendHandle; + info->recvMhandle = recvHandle; + } + } + #endif +exit: + return result; +} + +ncclResult_t ncclRegisterCollBuffers( + struct ncclComm* comm, struct ncclTaskColl* info, + void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], + void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], + struct ncclIntruQueue* cleanupQueue, + bool* regNeedConnect + ) { + ncclResult_t result = ncclSuccess; + + info->regBufType = NCCL_REGULAR_BUFFER; + *regNeedConnect = true; + if (!(ncclParamLocalRegister() || (comm->planner.persistent && ncclParamGraphRegister()))) goto exit; +#if CUDART_VERSION >= 11030 + if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) { + /* this part of nvls reg code is temporarily not used and obsolete. */ + if (!comm->nvlsRegSupport || info->opDev.op == ncclDevPreMulSum) goto exit; + int nvlsReged = 0; + int collnetReged = 0; + const void *sendbuff = info->sendbuff; + void *recvbuff = info->recvbuff; + void *recvHandle = NULL, *sendHandle = NULL; + if (info->func == ncclFuncAllGather) sendbuff = NULL; + if (info->func == ncclFuncReduceScatter) recvbuff = NULL; + size_t elementSize = ncclTypeSize(info->datatype); + size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); + size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); + + /* first try local registration. */ + if (ncclParamLocalRegister()) { + ncclNvlsLocalRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, &nvlsReged, outRegBufSend, outRegBufRecv); + } + + if (nvlsReged == 0 && comm->planner.persistent && ncclParamGraphRegister()) { + ncclNvlsGraphRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, &nvlsReged, outRegBufSend, outRegBufRecv, cleanupQueue, &info->nCleanupQueueElts); + } + + if (comm->nNodes > 1 && info->algorithm == NCCL_ALGO_NVLS) { + if (ncclParamLocalRegister()) { + ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetSend, &collnetReged, &sendHandle); + if (collnetReged) ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &collnetReged, &recvHandle); + } + + if (collnetReged == 0 && comm->planner.persistent && ncclParamGraphRegister()) { + ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetSend, &collnetReged, &sendHandle, cleanupQueue, &info->nCleanupQueueElts); + if (collnetReged) ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &collnetReged, &recvHandle, cleanupQueue, &info->nCleanupQueueElts); + } + } + + if (nvlsReged) { + *regNeedConnect = 0; + /* tweak NVLS channels usage; for registered NVLS buffer, we only need 4/5 channels to + * saturate bandwidth. */ + if (comm->nNodes == 1) { + if (info->func == ncclFuncReduceScatter) + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 5)); + else + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 4)); + } else { + info->nMaxChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 6)); + } + info->regBufType |= NCCL_NVLS_REG_BUFFER; + } + + if (collnetReged) { + info->regBufType |= NCCL_NET_REG_BUFFER; + info->sendMhandle = sendHandle; + info->recvMhandle = recvHandle; + } + } else if (info->protocol == NCCL_PROTO_SIMPLE) { + // IPC buffer registration + if (info->func == ncclFuncReduceScatter && info->algorithm != NCCL_ALGO_COLLNET_DIRECT) goto exit; + if (info->algorithm == NCCL_ALGO_RING && ((info->func == ncclFuncAllReduce && info->sendbuff == info->recvbuff) || info->func == ncclFuncReduce)) goto exit; + if ((info->algorithm == NCCL_ALGO_TREE || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) && info->sendbuff == info->recvbuff) goto exit; + if (info->func == ncclFuncAllGather && info->algorithm == NCCL_ALGO_PAT) goto exit; + + int peerRanks[NCCL_MAX_LOCAL_RANKS]; + int nPeers = 0; + size_t elementSize = ncclTypeSize(info->datatype); + size_t sendbuffSize = elementSize*ncclFuncSendCount(info->func, comm->nRanks, info->count); + size_t recvbuffSize = elementSize*ncclFuncRecvCount(info->func, comm->nRanks, info->count); + int regBufFlag = 0; + memset(peerRanks, 0xff, sizeof(int) * NCCL_MAX_LOCAL_RANKS); + + if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { + struct ncclChannel* channel = comm->channels; + int ipcRegFlag = 0, netSendRegFlag = 0, netRecvRegFlag = 0; + void *sendHandle, *recvHandle; + if (info->func != ncclFuncReduceScatter && comm->intraNodeP2pSupport) { + for (int r = 0; r < NCCL_MAX_DIRECT_ARITY; ++r) { + for (int down = 0; down < 2; ++down) { + int peer = down ? channel->collnetDirect.down[r] : channel->collnetDirect.up[r]; + if (peer != -1) { + struct ncclConnector* peerConn = &channel->peers[peer]->recv[0]; + bool needReg = false; + + NCCLCHECK(registerCheckP2PConnection(comm, peerConn, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], peer, &needReg)); + if (needReg) { + bool found = false; + for (int p = 0; p < nPeers; ++p) { + if (peerRanks[p] == peer) { + found = true; + break; + } + } + if (!found) peerRanks[nPeers++] = peer; + } + } + } + } + + if (nPeers > 0) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclIpcGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, &ipcRegFlag, &info->sendbuffOffset, &info->sendbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); + if (ipcRegFlag) ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, &ipcRegFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); + } + if (!ipcRegFlag && ncclParamLocalRegister()) { + ncclIpcLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, &ipcRegFlag, &info->sendbuffOffset, &info->sendbuffRmtAddrs); + if (ipcRegFlag) ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, &ipcRegFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); + } + } + if (ipcRegFlag) { + info->regBufType |= NCCL_IPC_REG_BUFFER; + } + } + + // register collnet buffer + if (info->opDev.op != ncclDevPreMulSum && info->opDev.op != ncclDevSumPostDiv && !(info->func == ncclFuncAllReduce && !comm->isOneRPN)) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclCollnetGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &netSendRegFlag, &sendHandle, cleanupQueue, &info->nCleanupQueueElts); + info->sendMhandle = sendHandle; + if (netSendRegFlag) { + ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &netRecvRegFlag, &recvHandle, cleanupQueue, &info->nCleanupQueueElts); + info->recvMhandle = recvHandle; + } + } + + if ((netSendRegFlag == 0 || netRecvRegFlag == 0) && ncclParamLocalRegister()) { + if (!netSendRegFlag) { + ncclCollnetLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &netSendRegFlag, &sendHandle); + info->sendMhandle = sendHandle; + } + if (netSendRegFlag && !netRecvRegFlag) { + ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &netRecvRegFlag, &recvHandle); + info->recvMhandle = recvHandle; + } + } + } + + if (netSendRegFlag && netRecvRegFlag) { + if (comm->isOneRPN) info->nMaxChannels = 1; + info->regBufType |= NCCL_NET_REG_BUFFER; + } + } else if (info->algorithm == NCCL_ALGO_RING) { + struct ncclReg* recvRegRecord = NULL; + struct ncclReg* sendRegRecord = NULL; + int sendNetPeers = comm->nChannels; + int recvNetPeers = comm->nChannels; + struct ncclConnector** sendNetConns = NULL; + struct ncclConnector** recvNetConns = NULL; + void** sendNetHandles = NULL; + void** recvNetHandles = NULL; + void** srecvNetHandles = NULL; + bool hasRecvNetPeer = false; + bool hasSendNetPeer = false; + + NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvbuffSize, &recvRegRecord)); + if (recvRegRecord == NULL && !(comm->planner.persistent && ncclParamGraphRegister())) goto exit; + NCCLCHECK(ncclRegFind(comm, info->sendbuff, sendbuffSize, &sendRegRecord)); + if (sendRegRecord == NULL && !(comm->planner.persistent && ncclParamGraphRegister())) goto exit; + NCCLCHECK(ncclCalloc(&sendNetConns, comm->nChannels)); + NCCLCHECK(ncclCalloc(&sendNetHandles, comm->nChannels)); + NCCLCHECK(ncclCalloc(&recvNetConns, comm->nChannels)); + NCCLCHECK(ncclCalloc(&recvNetHandles, comm->nChannels)); + NCCLCHECK(ncclCalloc(&srecvNetHandles, comm->nChannels)); + + for (int c = 0; c < comm->nChannels; ++c) { + struct ncclChannel* channel = comm->channels + c; + for (int r = 0; r < 2; ++r) { + int peer; + struct ncclConnector* peerConn; + if (r == 0) { + peer = channel->ring.prev; + peerConn = &channel->peers[peer]->recv[0]; + if (peerConn->conn.flags & NCCL_DIRECT_NIC) { + recvNetConns[c] = peerConn; + hasRecvNetPeer = true; + } + } else { + peer = channel->ring.next; + peerConn = &channel->peers[peer]->send[0]; + if (peerConn->conn.flags & NCCL_DIRECT_NIC) { + sendNetConns[c] = peerConn; + hasSendNetPeer = true; + } + } + if (peerConn->conn.flags & (NCCL_P2P_READ | NCCL_P2P_WRITE)) { + bool found = false; + for (int p = 0; p < nPeers; ++p) { + if (peerRanks[p] == peer) { + found = true; + break; + } + } + if (!found) peerRanks[nPeers++] = peer; + } + } + } + if (nPeers > 0 && comm->intraNodeP2pSupport) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); + } + if (!regBufFlag && ncclParamLocalRegister()) { + ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); + } + } + if (regBufFlag) { + info->regBufType = NCCL_IPC_REG_BUFFER; + } + + // start net registration + regBufFlag = 0; + if (!comm->useNetPXN && comm->useGdr && comm->netDeviceType != NCCL_NET_DEVICE_UNPACK) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + if (hasSendNetPeer) { + ncclNetGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, sendNetConns, sendNetPeers, ®BufFlag, sendNetHandles, cleanupQueue, &info->nCleanupQueueElts); + if (regBufFlag) + ncclNetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, sendNetConns, sendNetPeers, ®BufFlag, srecvNetHandles, cleanupQueue, &info->nCleanupQueueElts); + } + if ((regBufFlag || !hasSendNetPeer) && hasRecvNetPeer) + ncclNetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, recvNetConns, recvNetPeers, ®BufFlag, recvNetHandles, cleanupQueue, &info->nCleanupQueueElts); + } + if (!regBufFlag && ncclParamLocalRegister()) { + if (hasSendNetPeer) { + ncclNetLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, sendNetConns, sendNetPeers, ®BufFlag, sendNetHandles); + if (regBufFlag) + ncclNetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, sendNetConns, sendNetPeers, ®BufFlag, srecvNetHandles); + } + if ((regBufFlag || !hasSendNetPeer) && hasRecvNetPeer) + ncclNetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, recvNetConns, recvNetPeers, ®BufFlag, recvNetHandles); + } + } + + if (regBufFlag) { + info->regBufType |= NCCL_NET_REG_BUFFER; + info->sendNetHandles = sendNetHandles; + info->recvNetHandles = recvNetHandles; + info->srecvNetHandles = srecvNetHandles; + if (comm->isOneRPN && (info->func == ncclFuncAllGather || info->func == ncclFuncBroadcast)) { + info->nMaxChannels = 1; + } + } else { + free(sendNetHandles); + free(recvNetHandles); + free(srecvNetHandles); + } + + free(sendNetConns); + free(recvNetConns); + } else if (info->algorithm == NCCL_ALGO_TREE || info->algorithm == NCCL_ALGO_COLLNET_CHAIN) { + struct ncclReg* recvRegRecord; + int netSendRegFlag = 0, netRecvRegFlag = 0; + void *sendHandle, *recvHandle; + NCCLCHECK(ncclRegFind(comm, info->recvbuff, recvbuffSize, &recvRegRecord)); + if (recvRegRecord == NULL && !(comm->planner.persistent && ncclParamGraphRegister())) goto exit; + if (comm->intraNodeP2pSupport) { + for (int c = 0; c < comm->nChannels; ++c) { + struct ncclChannel* channel = comm->channels + c; + struct ncclTree* tree = NULL; + int peers[NCCL_MAX_TREE_ARITY + 1]; + + if (info->algorithm == NCCL_ALGO_TREE) + tree = &channel->tree; + else + tree = &channel->collnetChain; + for (int p = 0; p < NCCL_MAX_TREE_ARITY; ++p) peers[p] = tree->down[p]; + peers[NCCL_MAX_TREE_ARITY] = tree->up; + for (int p = 0; p < NCCL_MAX_TREE_ARITY + 1; ++p) { + int peer = peers[p]; + bool peerNeedReg = false; + struct ncclConnector* recvConn = NULL; + // P2P transport + if (peer == -1 || peer == comm->nRanks) continue; + recvConn = &channel->peers[peer]->recv[0]; + NCCLCHECK(registerCheckP2PConnection(comm, recvConn, &comm->graphs[info->algorithm], peer, &peerNeedReg)); + + if (peerNeedReg) { + bool found = false; + for (int pindex = 0; pindex < nPeers; ++pindex) { + if (peerRanks[pindex] == peer) { + found = true; + break; + } + } + if (!found) peerRanks[nPeers++] = peer; + } + } + } + if (nPeers > 0) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclIpcGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs, cleanupQueue, &info->nCleanupQueueElts); + } + if (!regBufFlag && ncclParamLocalRegister()) { + ncclIpcLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, peerRanks, nPeers, NCCL_IPC_COLLECTIVE, ®BufFlag, &info->recvbuffOffset, &info->recvbuffRmtAddrs); + } + } + if (regBufFlag) { + info->regBufType = NCCL_IPC_REG_BUFFER; + } + } + + // register collnet chain 1RPN buffer + if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN && info->opDev.op != ncclDevPreMulSum && info->opDev.op != ncclDevSumPostDiv && comm->isOneRPN) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclCollnetGraphRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &netSendRegFlag, &sendHandle, cleanupQueue, &info->nCleanupQueueElts); + info->sendMhandle = sendHandle; + if (netSendRegFlag) { + ncclCollnetGraphRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &netRecvRegFlag, &recvHandle, cleanupQueue, &info->nCleanupQueueElts); + info->recvMhandle = recvHandle; + } + } + + if ((netSendRegFlag == 0 || netRecvRegFlag == 0) && ncclParamLocalRegister()) { + if (!netSendRegFlag) { + ncclCollnetLocalRegisterBuffer(comm, info->sendbuff, sendbuffSize, collNetSend, &netSendRegFlag, &sendHandle); + info->sendMhandle = sendHandle; + } + if (netSendRegFlag && !netRecvRegFlag) { + ncclCollnetLocalRegisterBuffer(comm, info->recvbuff, recvbuffSize, collNetRecv, &netRecvRegFlag, &recvHandle); + info->recvMhandle = recvHandle; + } + } + } + + if (netSendRegFlag && netRecvRegFlag) { + if (comm->isOneRPN) info->nMaxChannels = 1; + info->regBufType |= NCCL_NET_REG_BUFFER; + } + } + + if (info->regBufType == NCCL_IPC_REG_BUFFER && comm->nNodes == 1 && 16 < info->nMaxChannels && info->nMaxChannels <= 24) { + info->nMaxChannels = 16; + } + } +#endif +exit: + return result; +} diff --git a/src/register/register.cc b/src/register/register.cc new file mode 100644 index 0000000000..5bf56d8d34 --- /dev/null +++ b/src/register/register.cc @@ -0,0 +1,219 @@ +/************************************************************************* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "argcheck.h" // Need some checks here since we access comm +#include "nccl.h" +#include "comm.h" +#include "net.h" +#include "register.h" +#include "transport.h" +#include "api_trace.h" +#ifdef ENABLE_MSCCLPP +#include "mscclpp/mscclpp_nccl.h" +#endif + +using namespace rccl; + +ncclResult_t ncclRegFind(struct ncclComm* comm, const void* data, size_t size, struct ncclReg** reg) { + struct ncclRegCache* cache = &comm->regCache; + uintptr_t pageSize = cache->pageSize; + uintptr_t addr = (uintptr_t)data & -pageSize; + size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; + + *reg = NULL; + for (int slot=0; /*true*/; slot++) { + if (slot == cache->population || addr < cache->slots[slot]->addr) return ncclSuccess; + if ((addr >= cache->slots[slot]->addr) && + ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { + *reg = cache->slots[slot]; + return ncclSuccess; + } + } +} +NCCL_PARAM(LocalRegister, "LOCAL_REGISTER", 1); + +ncclResult_t ncclRegLocalIsValid(struct ncclReg *reg, bool *isValid) { + if (reg && isValid) { + if (reg->localRefs) + *isValid = true; + else + *isValid = false; + } + return ncclSuccess; +} + +ncclResult_t ncclRegister(struct ncclComm* comm, void* data, size_t size, bool isGraph, void** handle) { + NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); + struct ncclRegCache* cache = &comm->regCache; + uintptr_t pageSize = cache->pageSize; + uintptr_t addr = (uintptr_t)data & -pageSize; + size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; + + if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(data, comm, "buff", "ncclCommRegister")); + INFO(NCCL_REG, "register comm %p buffer %p size %zi", comm, data, size); + + for (int slot=0; /*true*/; slot++) { + if ((slot == cache->population) || (addr < cache->slots[slot]->addr)) { + if (cache->population == cache->capacity) { // must grow cache + cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity; + NCCLCHECK(ncclRealloc(&cache->slots, cache->population, cache->capacity)); + } + memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclReg*)); + NCCLCHECK(ncclCalloc(cache->slots+slot, 1)); + struct ncclReg* regSlot = cache->slots[slot]; + regSlot->addr = addr; + regSlot->pages = pages; + if (isGraph) regSlot->graphRefs = 1; + else regSlot->localRefs = 1; + cache->population += 1; + *handle = regSlot; + goto exit; + } else if ((addr >= cache->slots[slot]->addr) && + ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { + if (isGraph) cache->slots[slot]->graphRefs++; + else cache->slots[slot]->localRefs++; + *handle = cache->slots[slot]; + goto exit; + } + } + +exit: + return ncclSuccess; +} + +static ncclResult_t regCleanup(struct ncclComm* comm, struct ncclReg* reg) { + if (reg->state & NET_REG_COMPLETE) { + struct ncclRegNetHandles* netHandle = reg->netHandleHead; + struct ncclRegNetHandles* netHandlePrev; + while(netHandle) { + if (ncclNetDeregBuffer(comm, netHandle->proxyConn, netHandle->handle) != ncclSuccess) { + WARN("rank %d deregister NET buffer handle %p proxy rank %d failed\n", comm->rank, netHandle->handle, netHandle->proxyConn->rank); + } + netHandlePrev = netHandle; + netHandle = netHandle->next; + free(netHandlePrev); + } + } + if (reg->state & NVLS_REG_COMPLETE) { + if (ncclNvlsDeregBuffer(comm, ®->mcHandle, reg->regAddr, reg->dev, reg->regSize) != ncclSuccess) { + WARN("rank %d deregister NVLS buffer %p dev %d size %ld failed", comm->rank, (void*)reg->regAddr, reg->dev, reg->regSize); + } + reg->regAddr = (CUdeviceptr)NULL; + } + if (reg->state & COLLNET_REG_COMPLETE) { + if (ncclCollnetDeregBuffer(comm, reg->collnetProxyconn, reg->collnetHandle) != ncclSuccess) { + WARN("rank %d deregister COLLNET buffer handle %p proxy rank %d failed", comm->rank, reg->collnetHandle, reg->collnetProxyconn->rank); + } + } + if (reg->state & IPC_REG_COMPLETE) { + for (int i = 0; i < NCCL_MAX_LOCAL_RANKS; ++i) + if (reg->ipcInfos[i]) { + if (ncclIpcDeregBuffer(comm, reg->ipcInfos[i]) != ncclSuccess) { + WARN("rank %d deregister IPC buffer %p peerRank %d failed", comm->rank, reg->ipcInfos[i]->baseAddr, reg->ipcInfos[i]->peerRank); + } + free(reg->ipcInfos[i]); + } + if (reg->regIpcAddrs.hostPeerRmtAddrs) free(reg->regIpcAddrs.hostPeerRmtAddrs); + if (reg->regIpcAddrs.devPeerRmtAddrs) NCCLCHECK(ncclCudaFree(reg->regIpcAddrs.devPeerRmtAddrs)); + } + return ncclSuccess; +} + +ncclResult_t ncclRegCleanup(struct ncclComm* comm) { + struct ncclRegCache* cache = &comm->regCache; + for (int i = 0; i < cache->population; i++) { + struct ncclReg* reg = cache->slots[i]; + INFO(NCCL_INIT, "Cleanup buffer %p pages %lx", (void*)reg->addr, reg->pages); + NCCLCHECK(regCleanup(comm, reg)); + free(reg); + } + free(cache->slots); + return ncclSuccess; +} + +NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size_t size, void** handle); +ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) { + ncclResult_t ret = ncclSuccess; + + if (!ncclParamLocalRegister()) + *handle = NULL; + else { + #ifdef ENABLE_MSCCLPP + if (comm->mscclppCompatible) { + if (comm->mscclCompatible && size > 0){ + bool isManagedBuffer = false; + CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(buff))); + if(!isManagedBuffer){ + INFO(NCCL_INIT, "MSCCL++: ncclCommRegister"); + NCCLCHECKGOTO(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle), ret, end); + } + else{ + WARN("MSCCL++: Cannot register user-buffers on managed memory. RCCL user-buffer registration will occur."); + } + } + } + #endif + INFO(NCCL_INIT, "RCCL: ncclCommRegister"); + NCCLCHECKGOTO(ncclRegister(comm, buff, size, false, handle), ret, end); + } +end: + // !recording at sink + NCCLCHECK(Recorder::instance().record(rrCommRegister, comm, *handle, buff, size)); + return ret; +} + +ncclResult_t ncclCommGraphRegister(const ncclComm_t comm, void* buff, size_t size, void** handle) { + NCCLCHECK(ncclRegister(comm, buff, size, true, handle)); + return ncclSuccess; +} + +static ncclResult_t commDeregister(struct ncclComm *comm, bool isGraph, struct ncclReg* reg) { + NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); + struct ncclRegCache* cache = &comm->regCache; + int slot; + int saveDev; + if (reg == NULL) goto exit; + CUDACHECK(cudaGetDevice(&saveDev)); + CUDACHECK(cudaSetDevice(comm->cudaDev)); + for (slot = 0; slot < cache->population && cache->slots[slot] != reg; slot++); + if (slot == cache->population) { + WARN("Deregister: Could not find handle"); + return ncclInvalidUsage; + } + if (isGraph) --reg->graphRefs; + else --reg->localRefs; + if (reg->localRefs || reg->graphRefs) return ncclSuccess; + NCCLCHECK(regCleanup(comm, reg)); + free(reg); + memmove(cache->slots + slot, cache->slots + slot + 1, (cache->population - slot - 1) * sizeof(struct ncclReg*)); + cache->population -= 1; + CUDACHECK(cudaSetDevice(saveDev)); +exit: + return ncclSuccess; +} + +NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); +ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void *handle) { + NCCLCHECK(Recorder::instance().record(rrCommDeregister, comm, handle)); + + #ifdef ENABLE_MSCCLPP + if (comm->mscclppCompatible) { + const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle); + if (comm->mscclCompatible && size > 0) { + NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); + return ncclSuccess; + } + } + #endif + + NCCLCHECK(commDeregister(comm, false, (struct ncclReg*)handle)); + return ncclSuccess; +} + +ncclResult_t ncclCommGraphDeregister(const ncclComm_t comm, struct ncclReg *handle) { + NCCLCHECK(commDeregister(comm, true, handle)); + return ncclSuccess; +} diff --git a/src/register/sendrecv_reg.cc b/src/register/sendrecv_reg.cc new file mode 100644 index 0000000000..f82fbd7142 --- /dev/null +++ b/src/register/sendrecv_reg.cc @@ -0,0 +1,35 @@ +#include "register.h" +#include "transport.h" + +ncclResult_t ncclRegisterP2pNetBuffer(struct ncclComm* comm, void* userbuff, size_t size, struct ncclConnector* conn, int* regFlag, void** handle, struct ncclIntruQueue* cleanupQueue) { + ncclResult_t ret = ncclSuccess; + + *regFlag = 0; + if (comm->netDeviceType != NCCL_NET_DEVICE_UNPACK) { + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclNetGraphRegisterBuffer(comm, userbuff, size, &conn, 1, regFlag, handle, cleanupQueue, NULL); + } + if (*regFlag == 0 && ncclParamLocalRegister()) { + ncclNetLocalRegisterBuffer(comm, userbuff, size, &conn, 1, regFlag, handle); + } + } + return ret; +} + +ncclResult_t ncclRegisterP2pIpcBuffer(struct ncclComm* comm, void* userbuff, size_t size, int peerRank, int* regFlag, void** regAddr, struct ncclIntruQueue* cleanupQueue) { + ncclResult_t ret = ncclSuccess; + uintptr_t offset = 0; + uintptr_t* peerRmtAddrs = NULL; + + *regFlag = 0; + if (comm->planner.persistent && ncclParamGraphRegister()) { + ncclIpcGraphRegisterBuffer(comm, userbuff, size, &peerRank, 1, NCCL_IPC_SENDRECV, regFlag, &offset, &peerRmtAddrs, reinterpret_cast(cleanupQueue), NULL); + } + if (*regFlag == 0 && ncclParamLocalRegister()) { + ncclIpcLocalRegisterBuffer(comm, userbuff, size, &peerRank, 1, NCCL_IPC_SENDRECV, regFlag, &offset, &peerRmtAddrs); + } + + if (*regFlag) + *regAddr = (void*)((uintptr_t)peerRmtAddrs + offset); + return ret; +} diff --git a/src/transport.cc b/src/transport.cc index f0f592c2a9..9d082d37e7 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -111,13 +111,13 @@ ncclResult_t ncclTransportCheckP2pType(struct ncclComm* comm, bool* intraNodeP2p } *intraNodeP2pSupport = supportFlag; *directMode = directFlag; + if (comm->rank == 0) INFO(NCCL_INIT, "Check P2P Type intraNodeP2pSupport %d directMode %d", supportFlag, directFlag); return ncclSuccess; } -ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType/*=NULL*/, bool* needsProxy/*=NULL*/) { +ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, bool* needsProxy/*=NULL*/) { // Stream used during transport setup; need for P2P pre-connect + CUDA Graph ncclResult_t ret = ncclSuccess; - int highestType = TRANSPORT_UNDEFINED; // track highest transport type bool needsProxyResult = false; struct ncclConnect** data; // Store intermediate send/recvData structs for connect struct ncclConnect** recvData = NULL; // Points to entries inside data for given recv connection within a channel @@ -162,8 +162,11 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* } //if ((recvMask.masks[0]) || (sendMask.masks[0])) NCCLCHECK(ncclCalloc(data+p, 2*MAXCHANNELS)); - if (count) NCCLCHECKGOTO(ncclCalloc(data+p, 2*MAXCHANNELS), ret, fail); + if (count) { + if (data[p] == NULL) NCCLCHECKGOTO(ncclCalloc(data + p, 2 * MAXCHANNELS), ret, fail); + else memset(data[p], 0, 2 * MAXCHANNELS * sizeof(struct ncclConnect)); + } recvData[p] = data[p]; int sendChannels = 0, recvChannels = 0; int type; @@ -173,7 +176,6 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* //if (recvMask & (1UL<(comm, graph, recvData[p]+recvChannels++, c, recvPeer, connIndex, &type, &proxy), ret, fail); - if (type > highestType) highestType = type; } } TIME_STOP(0); @@ -183,7 +185,6 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* //if (sendMask & (1UL<(comm, graph, sendData[p]+sendChannels++, c, sendPeer, connIndex, &type, &proxy), ret, fail); - if (type > highestType) highestType = type; needsProxyResult |= proxy; } } @@ -262,30 +263,18 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* } TIME_STOP(4); } - - count = 0; - for (int j = 0; j < num; j++) { - if ((recvMask.masks[j]) || (sendMask.masks[j])) { - count++; - } - } - //if (sendMask.masks[0] || recvMask.masks[0]) { - if (count) { - free(data[p]); - data[p] = NULL; - } } - if (ncclParamReportConnectProgress() && comm->rank == 0 && done > 0) { + if (ncclParamReportConnectProgress() && comm->rank == 0 && done > 0) { struct timeval now; gettimeofday(&now, NULL); - if (((now.tv_sec - timeLast.tv_sec)*1.0 + (now.tv_usec-timeLast.tv_usec)*1e-6) > 1) { - float elapsed = (now.tv_sec - timeStart.tv_sec)*1.0 + (now.tv_usec-timeStart.tv_usec)*1e-6; - float remaining = elapsed*(comm->nRanks-done)/done; + if (((now.tv_sec - timeLast.tv_sec) * 1.0 + (now.tv_usec - timeLast.tv_usec) * 1e-6) > 1) { + float elapsed = (now.tv_sec - timeStart.tv_sec) * 1.0 + (now.tv_usec - timeStart.tv_usec) * 1e-6; + float remaining = elapsed * (comm->nRanks - done) / done; printf("%sP2p connect: %g%% Elapsed %d:%02d Remaining %d:%02d ", - timeReported ? "\r" : "", done*100.0/comm->nRanks, ((int)elapsed)/60, ((int)elapsed)%60, ((int)remaining)/60, ((int)remaining)%60); + timeReported ? "\r" : "", done * 100.0 / comm->nRanks, ((int)elapsed) / 60, ((int)elapsed) % 60, ((int)remaining) / 60, ((int)remaining) % 60); fflush(stdout); timeReported = true; - timeLast = now; // struct copy; + timeLast = now; // struct copy; } } } @@ -332,7 +321,6 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* } } - if (highestTransportType != NULL) *highestTransportType = highestType; if (needsProxy != NULL) *needsProxy = needsProxyResult; TIME_PRINT("P2P Setup/Connect"); exit: diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc index 9c64ea9520..2c3bee4d71 100644 --- a/src/transport/coll_net.cc +++ b/src/transport/coll_net.cc @@ -113,6 +113,7 @@ struct sendResources { uint64_t step; struct reqSlot (*reqFifo)[NCCL_STEPS]; int collNetRank; + size_t maxCollBytes; volatile uint32_t* curr_hdp_reg; // Curr GPU in ring (for rdma transport use only) }; @@ -135,6 +136,7 @@ struct recvResources { uint64_t step; struct reqSlot reqFifo[COLLNET_MAX_GROUPS][NCCL_STEPS]; int collNetRank; + size_t maxCollBytes; volatile uint32_t* curr_hdp_reg; // Curr GPU in ring (for rdma transport use only) }; @@ -160,7 +162,7 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; int64_t netId; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &netId, &req.netDev, &proxyRank)); - NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, netId, 1, &req.useGdr)); + NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->rank, netId, 1, &req.useGdr)); send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; send->proxyConn.tpLocalRank = comm->topParentLocalRanks[comm->localRank]; @@ -180,10 +182,10 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; int64_t netId; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &netId, &req.netDev, &proxyRank)); - NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, netId, 0, &req.useGdr)); + NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->rank, netId, 0, &req.useGdr)); recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; // Determine whether we need to flush the GDR buffer on recv or not - if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); + if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm, req.netDev, myInfo->rank, &req.needFlush)); recv->proxyConn.tpLocalRank = comm->topParentLocalRanks[comm->localRank]; NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, myInfo->rank, &recv->proxyConn)); @@ -322,6 +324,13 @@ static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struc connection->collNet = req->collNet; /* DMA-BUF support */ resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); + /* collective size limits*/ + resources->maxCollBytes = props.maxCollBytes; + if((resources->maxCollBytes <= 0) || (resources->maxCollBytes > NCCL_MAX_NET_SIZE_BYTES)) { + WARN("sendProxySetup: collnet plugin returned invalid value for maxCollBytes %ld \ + [allowed range: %ld - %ld] \n", resources->maxCollBytes, 0L, NCCL_MAX_NET_SIZE_BYTES); + return ncclInternalError; + } return ncclSuccess; } @@ -435,6 +444,12 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc connection->collNet = req->collNet; /* DMA-BUF support */ resources->useDmaBuf = resources->useGdr && proxyState->dmaBufSupport && (props.ptrSupport & NCCL_PTR_DMABUF); + resources->maxCollBytes = props.maxCollBytes; + if((resources->maxCollBytes <= 0) || (resources->maxCollBytes > NCCL_MAX_NET_SIZE_BYTES)) { + WARN("sendProxySetup: collnet plugin returned invalid value for maxCollBytes %ld \ + [allowed range: %ld - %ld] \n", resources->maxCollBytes, 0L, NCCL_MAX_NET_SIZE_BYTES); + return ncclInternalError; + } collNetHandle_t* netHandle = (collNetHandle_t*) respBuff; if (respSize != sizeof(collNetHandle_t)) return ncclInternalError; @@ -650,14 +665,14 @@ static size_t calcAlgoOffset(struct ncclProxyArgs* args, int isAllNotOne, int su return offset; } -static int calcRegionOffset( +static ssize_t calcRegionOffset( struct ncclProxyArgs* args, int isRecvNotSend, int sub, uint64_t step, int side // 0=begin, 1=end ) { struct ncclCollNetSharedRes* collNet = args->subs[0].connection->collNet; - int slotSize = collNet->buffSize/NCCL_STEPS; - int chunkSize = args->chunkSize; - int base = isRecvNotSend*NCCL_STEPS + (step%NCCL_STEPS); + ssize_t slotSize = collNet->buffSize/NCCL_STEPS; + ssize_t chunkSize = args->chunkSize; + ssize_t base = isRecvNotSend*NCCL_STEPS + (step%NCCL_STEPS); base *= collNet->nChannels*slotSize; if (args->coll == ncclFuncAllReduce) { return base + (sub+side)*chunkSize; @@ -679,6 +694,165 @@ static constexpr int calcStepsPerGroup(int nGroups) { return NCCL_STEPS; } +static ncclResult_t collNetRegIallreduce(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, int groupStart, ssize_t *nBytesInOut, void **request) { + ssize_t loopSize, winOffset, nBytes; + ssize_t eltSize = ncclTypeSize((ncclDataType_t)args->dtype); + // for UB iallreduce 1RPN case, user's send and recv buffers are both directly accessed by collnet network. + // we can just issue maximal collnet bytes by resources->maxCollBytes for each iallreduce. + // for multi-RPN case, we have to consider pipeline, so each time we only send groupSize * chunkSize (i.e., nBytesInOut) + // sub->loopOffset is data offset to the buffer for this head rank in each loop + // winOffset is used to find actual offset from send and recv buffer for this iallreduce + // loopSize is all bytes sent by all channels and head ranks in each loop. + // send and recv mem handle are retrieved from sub in which user buffer mem handles are stored. + if (sub->isOneRPN) { + winOffset = 0; + nBytes = std::min((size_t)sub->nbytes, resources->maxCollBytes); + loopSize = nBytes; + } else { + winOffset = sub->loopOffset + groupStart * args->chunkSize; + nBytes = std::min(sub->nbytes - winOffset, *nBytesInOut); + loopSize = sub->loopSize; + } + + if (nBytes > 0) { + NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, sub->sendbuff + winOffset, sub->recvbuff + winOffset, nBytes / eltSize, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sub->sendMhandle, sub->recvMhandle, request)); + if (*request) { + // if issued successfully, we need to move the pointer forward and reduce the existing nbytes. + sub->nbytes -= loopSize; + sub->sendbuff += loopSize; + sub->recvbuff += loopSize; + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] registered Iallreduce posted sendbuff %p recvbuff %p size %ld loopSize %ld winOffset %ld isOneRPN %d req %p", (long)sub->transmitted, sub->nsteps, groupStart, sub->sendbuff, sub->recvbuff, nBytes, loopSize, winOffset, sub->isOneRPN, *request); + } + } + *nBytesInOut = nBytes; + return ncclSuccess; +} + +static ncclResult_t collNetIallreduce(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, ssize_t nBytes, ssize_t sendBeg, ssize_t recvBeg, void **request) { + void *sendMhandle = resources->sendMhandles[NCCL_PROTO_SIMPLE]; + void *recvMhandle = resources->recvMhandles[NCCL_PROTO_SIMPLE]; + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + ssize_t eltSize = ncclTypeSize((ncclDataType_t)args->dtype); + // non-UB iallreduce, region is intermediate buffer and sendBeg/recvBeg is the corresponding offset + // for send and recv data. The send and recv mem handle are retrieved from resources. + NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, region + sendBeg, region + recvBeg, nBytes / eltSize, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, request)); + if (*request) + TRACE(NCCL_NET, "sendProxy [%ld/%d] Iallreduce posted size %ld sendBeg %ld recvBeg %ld req %p", (long)sub->transmitted, sub->nsteps, nBytes, sendBeg, recvBeg, *request); + return ncclSuccess; +} + +static ncclResult_t collNetRegIallgather(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, ssize_t nBytesIn, ssize_t allBeg, ssize_t recvBeg, void *recvMhandle, void **request) { + ncclNetSGE_v9_t recvParts; + ssize_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + ssize_t nBytes; + ssize_t winOffset; + void *sendbuff; + // UB iallgather 1RPN logic is the same as iallreduce. + // If iallgather is not 1RPN, we can let collnet network directly access sendbuff but not recvbuff; + // the main reason is non-1RPN case will cause non-contiguous recv data from network, so + // we have to use intermediate buffer "region" to recv data and copy into the recvbuff. + // so allBeg and recvMhandle, which are global window offset of recv buffer and mem handle for region, + // are only used in multi-RPN case. + if (sub->isOneRPN) { + nBytes = std::min((size_t)sub->nbytes, resources->maxCollBytes); + winOffset = sub->offset; + recvParts.mhandle = sub->recvMhandle; + recvParts.address = sub->recvbuff; + } else { + nBytes = nBytesIn; + winOffset = allBeg; + recvParts.mhandle = recvMhandle; + recvParts.address = region + recvBeg; + } + recvParts.size = nBytes; + if (winOffset / sizePerRank == args->specifics.collnetDirect.node) { + sendbuff = sub->sendbuff + winOffset % sizePerRank; + } else { + sendbuff = sub->sendbuff; + } + NCCLCHECK(proxyState->ncclCollNet->iallgather(resources->collNetComm, sendbuff, 1, &recvParts, sizePerRank, winOffset, nBytes, sub->sendMhandle, request)); + if (*request) { + if (sub->isOneRPN) { + sub->recvbuff += nBytes; + sub->nbytes -= nBytes; + sub->offset += nBytes; + } + TRACE(NCCL_NET, "sendProxy [%ld/%d] registered Iallgather posted sizePerRank %ld winOffset %ld recvSize %ld isOneRPN %d request %p", sub->transmitted, sub->nsteps, sizePerRank, winOffset, nBytes, sub->isOneRPN, *request); + } + return ncclSuccess; +} + +static ncclResult_t collNetIallgather(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, ssize_t nBytes, ssize_t allBeg, ssize_t sendBeg, ssize_t recvBeg, void *sendMhandle, void *recvMhandle, void **request) { + ncclNetSGE_v9_t recvParts; + ssize_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + recvParts.mhandle = recvMhandle; + recvParts.address = region + recvBeg; + recvParts.size = nBytes; + // non-UB iallgather, we use intermidate region buffers for both send and recv data. + // sendMhandle and recvMhandle are send and recv mem handles for region, and allBeg is + // the global window offset of recv buffer. sendBeg and recvBeg are offset to the region + // for intermediate data. + NCCLCHECK(proxyState->ncclCollNet->iallgather(resources->collNetComm, region + sendBeg, 1, &recvParts, sizePerRank, allBeg, nBytes, sendMhandle, request)); + if (*request) + TRACE(NCCL_NET, "sendProxy [%ld/%d] Iallgather posted sizePerRank %ld winOffset %ld recvSize %ld request %p", sub->transmitted, sub->nsteps, sizePerRank, allBeg, nBytes, *request); + return ncclSuccess; +} + +static ncclResult_t collNetRegIreducescatter(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, ssize_t nBytesIn, ssize_t allBeg, ssize_t sendBeg, void *sendMhandle, void **request) { + ncclNetSGE_v9_t sendParts; + ssize_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + ssize_t nBytes; + size_t winOffset; + void *recvbuff; + // Similar to iallgather, if ireducescatter is not 1RPN, we can let collnet network + // directly access recvbuff but not sendbuff. We use intermediate buffer "region" to + // send data and directly recv into the recvbuff. + if (sub->isOneRPN) { + nBytes = std::min((size_t)sub->nbytes, resources->maxCollBytes); + winOffset = sub->offset; + sendParts.mhandle = sub->sendMhandle; + sendParts.address = sub->sendbuff; + } else { + nBytes = nBytesIn; + winOffset = allBeg; + sendParts.mhandle = sendMhandle; + sendParts.address = region + sendBeg; + } + sendParts.size = nBytes; + if (winOffset / sizePerRank == args->specifics.collnetDirect.node) { + recvbuff = sub->recvbuff + winOffset % sizePerRank; + } else { + recvbuff = sub->recvbuff; + } + NCCLCHECK(proxyState->ncclCollNet->ireducescatter(resources->collNetComm, 1, &sendParts, recvbuff, sizePerRank, winOffset, nBytes, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sub->recvMhandle, request)); + if (*request) { + if (sub->isOneRPN) { + sub->sendbuff += nBytes; + sub->nbytes -= nBytes; + sub->offset += nBytes; + } + TRACE(NCCL_NET, "sendProxy [%ld/%d] registered Ireducescatter posted sizePerRank %ld winOffset %ld sendSize %ld isOneRPN %d request %p", sub->transmitted, sub->nsteps, sizePerRank, winOffset, nBytes, sub->isOneRPN, *request); + } + return ncclSuccess; +} + +static ncclResult_t collNetIreducescatter(struct ncclProxyState* proxyState, struct sendResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, ssize_t nBytes, ssize_t allBeg, ssize_t sendBeg, ssize_t recvBeg, void *sendMhandle, void *recvMhandle, void **request) { + ncclNetSGE_v9_t sendParts; + ssize_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + sendParts.mhandle = sendMhandle; + sendParts.address = region + sendBeg; + sendParts.size = nBytes; + // non-UB ireducescatter is the same as non-UB iallgather but in the reverse direction. + NCCLCHECK(proxyState->ncclCollNet->ireducescatter(resources->collNetComm, 1, &sendParts, region + recvBeg, sizePerRank, allBeg, nBytes, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, recvMhandle, request)); + if (*request) + TRACE(NCCL_NET, "sendProxy [%ld/%d] Ireducescatter posted sizePerRank %ld winOffset %ld sendSize %ld request %p", sub->transmitted, sub->nsteps, sizePerRank, allBeg, nBytes, *request); + return ncclSuccess; +} + static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) { if (args->state == ncclProxyOpReady) { for (int s=0; snsubs; s++) { @@ -688,6 +862,8 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct sub->base = ROUNDUP(resources->step, args->chunkSteps); sub->posted = sub->received = sub->transmitted = sub->done = 0; resources->step = sub->base + sub->nsteps; + //adjust nsteps for registerd buffers as device signals a single step + if (sub->reg && sub->isOneRPN) sub->nsteps = DIVUP((size_t)sub->nbytes, resources->maxCollBytes); } args->state = ncclProxyOpProgress; args->hdp_flushed = 0; @@ -701,28 +877,30 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct struct sendResources* resources = (struct sendResources*) (sub->connection->transportResources); void* sendMhandle = resources->sendMhandles[p]; void* recvMhandle = resources->recvMhandles[p]; - char* region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[p]); auto reqFifo = resources->reqFifo; int group = s/COLLNET_GROUP_NSUBS; int groupStart = s - (s%COLLNET_GROUP_NSUBS); if (sub->posted < sub->nsteps && sub->posted < sub->done + NCCL_STEPS) { int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; - if (sub->reg == 0) { + if (sub->reg == 0 || (!sub->isOneRPN && args->coll == ncclFuncReduceScatter)) { resources->recvMem->connFifo[buffSlot].offset = calcRegionOffset(args, 0, s, sub->posted, 0); __sync_synchronize(); } volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; - TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] posted offset %d @ %p signal %ld->%ld", long(sub->posted), group, buffSlot, resources->recvMem->connFifo[buffSlot].offset, &resources->recvMem->connFifo[buffSlot].offset, long(*sendHead), long(sub->base + sub->posted + args->sliceSteps - NCCL_STEPS)); + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d/%d] posted offset %d @ %p signal %ld->%ld", long(sub->posted), group, buffSlot, sub->nsteps, resources->recvMem->connFifo[buffSlot].offset, &resources->recvMem->connFifo[buffSlot].offset, long(*sendHead), long(sub->base + sub->posted + args->sliceSteps - NCCL_STEPS)); sub->posted += args->sliceSteps; - *sendHead = sub->base + sub->posted - NCCL_STEPS; + // Only post one credit for registered buffer + if (sub->reg == 0 || !sub->isOneRPN || sub->posted == args->sliceSteps) *sendHead = sub->base + sub->posted - NCCL_STEPS; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } if (sub->received < sub->posted && sub->received < sub->done + calcStepsPerGroup(nGroups)) { int buffSlot = (sub->base+sub->received)%NCCL_STEPS; volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; volatile uint64_t* recvTail = &resources->recvMem->tail; - if ((connFifo[buffSlot].size != -1 || sub->reg) && ((*recvTail > (sub->base+sub->received)))) { + //device progresses tail by only 1 for registered buffers + uint64_t tail = sub->base + (sub->reg && sub->isOneRPN ? 0 : sub->received); + if ((connFifo[buffSlot].size != -1 || sub->reg) && (*recvTail > tail)) { if (args->coll != ncclFuncAllReduce && sub->reg == 0) { int sendBeg = calcRegionOffset(args, 0, s, sub->received, 0); int sendEnd = calcRegionOffset(args, 0, s, sub->received, 1); @@ -744,110 +922,42 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; if (!reqFifo[group][buffSlot].turnIsSendNotRecv) continue; - ssize_t sizePerRank = 0; - size_t allBeg = calcAlgoOffset(args, 1, groupStart, sub->transmitted); - size_t allEnd = calcAlgoOffset(args, 1, s+1, sub->transmitted); - int sendBeg = calcRegionOffset(args, 0, groupStart, sub->transmitted, 0); - int sendEnd = calcRegionOffset(args, 0, s, sub->transmitted, 1); - int recvBeg = calcRegionOffset(args, 1, groupStart, sub->transmitted, 0); - int recvEnd = calcRegionOffset(args, 1, s, sub->transmitted, 1); + ssize_t allBeg = calcAlgoOffset(args, 1, groupStart, sub->transmitted); + ssize_t allEnd = calcAlgoOffset(args, 1, s+1, sub->transmitted); + ssize_t sendBeg = calcRegionOffset(args, 0, groupStart, sub->transmitted, 0); + ssize_t sendEnd = calcRegionOffset(args, 0, s, sub->transmitted, 1); + ssize_t recvBeg = calcRegionOffset(args, 1, groupStart, sub->transmitted, 0); + ssize_t recvEnd = calcRegionOffset(args, 1, s, sub->transmitted, 1); reqFifo[group][buffSlot].size = recvEnd - recvBeg; - size_t eltSize = ncclTypeSize((ncclDataType_t)args->dtype); - if (sendBeg==sendEnd && recvBeg==recvEnd && sub->reg == 0) { + if (sendBeg==sendEnd && recvBeg==recvEnd) { sub->requests[buffSlot] = nullptr; // trivally finished request } else { + ssize_t nBytes = 0; if (args->coll == ncclFuncAllReduce) { + nBytes = sendEnd - sendBeg; if (sub->reg) { - size_t nBytes = std::min(sub->nbytes, NCCL_MAX_COLLNET_SIZE); - int count = (int)(nBytes / eltSize); - NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, sub->sendbuff, sub->recvbuff, count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sub->sendMhandle, sub->recvMhandle, sub->requests + buffSlot)); - if (sub->requests[buffSlot]) { - sub->nbytes -= nBytes; - sub->sendbuff += nBytes; - sub->recvbuff += nBytes; - } + NCCLCHECK(collNetRegIallreduce(proxyState, resources, args, sub, groupStart, &nBytes, &sub->requests[buffSlot])); } else { - int count = (sendEnd - sendBeg) / eltSize; - NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, region + sendBeg, region + recvBeg, count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, sub->requests + buffSlot)); + NCCLCHECK(collNetIallreduce(proxyState, resources, args, sub, nBytes, sendBeg, recvBeg, &sub->requests[buffSlot])); } - } else { - sizePerRank = args->specifics.collnetDirect.sizePerRank; - if (args->coll == ncclFuncAllGather) { - ncclNetSGE_v8_t recvParts; - if (sub->reg) { - size_t nBytes = std::min(sub->nbytes, NCCL_MAX_COLLNET_SIZE); - void *sendbuff; - recvParts.mhandle = sub->recvMhandle; - recvParts.address = sub->recvbuff; - recvParts.size = nBytes; - if (sub->offset / sizePerRank == args->specifics.collnetDirect.node) { - sendbuff = sub->sendbuff + sub->offset % sizePerRank; - } else { - sendbuff = sub->sendbuff; - } - NCCLCHECK(proxyState->ncclCollNet->iallgather( - resources->collNetComm, sendbuff, 1, &recvParts, - sizePerRank, sub->offset, nBytes, - sub->sendMhandle, sub->requests + buffSlot)); - if (sub->requests[buffSlot]) { - sub->recvbuff += nBytes; - sub->nbytes -= nBytes; - sub->offset += nBytes; - } - } else { - recvParts.mhandle = recvMhandle; - recvParts.address = region + recvBeg; - recvParts.size = allEnd - allBeg; - NCCLCHECK(proxyState->ncclCollNet->iallgather( - resources->collNetComm, region + sendBeg, 1, &recvParts, - sizePerRank, allBeg, allEnd - allBeg, - sendMhandle, sub->requests + buffSlot)); - } - } else { - ncclNetSGE_v8_t sendParts; - if (sub->reg) { - size_t nBytes = std::min(sub->nbytes, NCCL_MAX_COLLNET_SIZE); - void *recvbuff; - sendParts.mhandle = sub->sendMhandle; - sendParts.address = sub->sendbuff; - sendParts.size = nBytes; - if (sub->offset / sizePerRank == args->specifics.collnetDirect.node) { - recvbuff = sub->recvbuff + sub->offset % sizePerRank; - } else { - recvbuff = sub->recvbuff; - } - NCCLCHECK(proxyState->ncclCollNet->ireducescatter( - resources->collNetComm, 1, &sendParts, recvbuff, - sizePerRank, sub->offset, nBytes, - (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, - sub->recvMhandle, sub->requests + buffSlot)); - if (sub->requests[buffSlot]) { - sub->sendbuff += nBytes; - sub->nbytes -= nBytes; - sub->offset += nBytes; - } - } else { - sendParts.mhandle = sendMhandle; - sendParts.address = region + sendBeg; - sendParts.size = allEnd - allBeg; - NCCLCHECK(proxyState->ncclCollNet->ireducescatter( - resources->collNetComm, 1, &sendParts, region + recvBeg, - sizePerRank, allBeg, allEnd - allBeg, - (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, - recvMhandle, sub->requests + buffSlot)); - } - } - } - if (sub->requests[buffSlot] == nullptr) continue; - - if (args->coll == ncclFuncAllReduce) { - TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Iallreduce posted, size %d req %p", (long)sub->transmitted, group, buffSlot, int(sendEnd-sendBeg), sub->requests[buffSlot]); } else if (args->coll == ncclFuncAllGather) { - TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Iallgather posted sendSize=%ld recvOffset=%ld recvSize=%ld request=%p", (long)sub->transmitted, group, buffSlot, long(sizePerRank), long(allBeg), long(allEnd-allBeg), sub->requests[buffSlot]); + nBytes = allEnd - allBeg; + if (sub->reg) { + NCCLCHECK(collNetRegIallgather(proxyState, resources, args, sub, nBytes, allBeg, recvBeg, recvMhandle, &sub->requests[buffSlot])); + } else { + NCCLCHECK(collNetIallgather(proxyState, resources, args, sub, nBytes, allBeg, sendBeg, recvBeg, sendMhandle, recvMhandle, &sub->requests[buffSlot])); + } } else { - TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Ireducescatter posted sendOffset=%ld sendSize=%ld recvSize=%ld request=%p", (long)sub->transmitted, group, buffSlot, long(allBeg), long(allEnd-allBeg), long(sizePerRank), sub->requests[buffSlot]); + // reducescatter + nBytes = allEnd - allBeg; + if (sub->reg) { + NCCLCHECK(collNetRegIreducescatter(proxyState, resources, args, sub, nBytes, allBeg, sendBeg, sendMhandle, &sub->requests[buffSlot])); + } else { + NCCLCHECK(collNetIreducescatter(proxyState, resources, args, sub, nBytes, allBeg, sendBeg, recvBeg, sendMhandle, recvMhandle, &sub->requests[buffSlot])); + } } + if (nBytes > 0 && sub->requests[buffSlot] == nullptr) continue; } } sub->transmitted += args->sliceSteps; @@ -881,6 +991,52 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct return ncclSuccess; } +static ncclResult_t collNetRecvFlush(struct ncclProxyState* proxyState, struct recvResources *resources, struct ncclProxyArgs *args, struct ncclProxySubArgs *sub, int groupStart, ssize_t nBytesIn, ssize_t recvBeg, void **request) { + char *region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[NCCL_PROTO_SIMPLE]); + if (sub->reg && (sub->isOneRPN || args->coll != ncclFuncAllGather)) { + ssize_t nBytes, loopSize; + ssize_t offset = sub->offset + groupStart * args->chunkSize; + if (sub->isOneRPN) { + nBytes = std::min((size_t)sub->nbytes, resources->maxCollBytes); + loopSize = nBytes; + } else { + nBytes = std::min(sub->nbytes - sub->loopOffset, nBytesIn); + loopSize = sub->loopSize; + } + if (nBytes > 0) { + if (args->coll == ncclFuncReduceScatter) { + ssize_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + ssize_t groupStartOffset = sub->offset + groupStart * args->chunkSize; + ssize_t groupEndOffset = groupStartOffset + nBytes; + int node = args->specifics.collnetDirect.node; + int startNode = groupStartOffset / sizePerRank; + int lastNode = groupEndOffset / sizePerRank; + if (startNode == node) { + offset = groupStartOffset % sizePerRank; + nBytes = std::min(sizePerRank - offset, nBytes); + } else if (startNode < node && node < lastNode) { + offset = 0; + nBytes = sizePerRank; + } else if (node == lastNode) { + offset = 0; + nBytes = groupEndOffset % sizePerRank; + } else { + // dummy flush + offset = 0; + } + } + NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, sub->recvbuff + offset + sub->loopOffset, nBytes, sub->recvMhandle, request)); + if (*request) { + sub->nbytes -= loopSize; + sub->offset += loopSize; + } + } + } else { + NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, region + recvBeg, nBytesIn, resources->mhandles[NCCL_PROTO_SIMPLE], request)); + } + return ncclSuccess; +} + static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) { if (args->state == ncclProxyOpReady) { for (int s=0; snsubs; s++) { @@ -890,22 +1046,21 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->base = ROUNDUP(resources->step, args->chunkSteps); sub->posted = sub->received = sub->flushed = sub->transmitted = sub->done = 0; resources->step = sub->base + sub->nsteps; + //adjust nsteps for registerd buffers as device signals a single step + if (sub->reg && sub->isOneRPN) sub->nsteps = DIVUP((size_t)sub->nbytes, resources->maxCollBytes); memset(sub->requests, 0, sizeof(sub->requests)); } args->state = ncclProxyOpProgress; } args->idle = 1; if (args->state == ncclProxyOpProgress) { - int p = NCCL_PROTO_SIMPLE; int nGroups = DIVUP(args->nsubs, COLLNET_GROUP_NSUBS); for (int s=0; snsubs; s++) { int group = s/COLLNET_GROUP_NSUBS; int groupStart = s - (s%COLLNET_GROUP_NSUBS); struct ncclProxySubArgs* sub = args->subs+s; struct recvResources* resources = (struct recvResources*) (sub->connection->transportResources); - void* mhandle = resources->mhandles[p]; auto reqFifo = resources->reqFifo; - char* region = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); // Enforce sync between operations of the same group. if (LAST_OF_GROUP(args, s) && (sub->posted < sub->done + calcStepsPerGroup(nGroups)) && (sub->posted < sub->nsteps)) { @@ -919,10 +1074,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (LAST_OF_GROUP(args, s) && (sub->received < sub->posted)) { int buffSlot = (sub->base+sub->received)%NCCL_STEPS; if (!reqFifo[group][buffSlot].turnIsSendNotRecv) { // Buffer is cleared : coll is complete - int recvBeg = calcRegionOffset(args, 1, groupStart, sub->received, 0); - int recvEnd = calcRegionOffset(args, 1, s, sub->received, 1); - int totalSize = recvEnd - recvBeg; - TRACE(NCCL_NET, "recvProxy [%ld/%d/%d] received, size %d chunkSize=%d", (long)sub->received, group, buffSlot, totalSize, args->chunkSize); + ssize_t recvBeg = calcRegionOffset(args, 1, groupStart, sub->received, 0); + ssize_t recvEnd = calcRegionOffset(args, 1, s, sub->received, 1); + ssize_t totalSize = recvEnd - recvBeg; + TRACE(NCCL_NET, "recvProxy [%ld/%d/%d] received, size %ld chunkSize=%ld", (long)sub->received, group, buffSlot, totalSize, args->chunkSize); sub->received += args->sliceSteps; if ((reqFifo[group][buffSlot].size > 0 || sub->reg) && resources->useGdr && resources->needFlush) { // GDRCOPY support @@ -935,37 +1090,7 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct return ncclInternalError; #endif } else { - if (sub->reg) { - size_t nBytes = std::min(sub->nbytes, NCCL_MAX_COLLNET_SIZE); - size_t offset = 0; - if (args->coll == ncclFuncReduceScatter) { - size_t sizePerRank = args->specifics.collnetDirect.sizePerRank; - int node = args->specifics.collnetDirect.node; - int startNode = sub->offset / sizePerRank; - int lastNode = (sub->offset + nBytes) / sizePerRank; - if (startNode == node) { - offset = sub->offset % sizePerRank; - nBytes = std::min(sizePerRank - offset, nBytes); - } else if (startNode < node && node < lastNode) { - nBytes = sizePerRank; - } else if (node == lastNode) { - nBytes = (sub->offset + nBytes) % sizePerRank; - } else { - // no need to flush - nBytes = 0; - } - } - NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, sub->recvbuff + offset, nBytes, sub->recvMhandle, sub->requests+buffSlot)); - if (sub->requests[buffSlot]) { - sub->nbytes -= nBytes; - sub->offset += nBytes; - if (args->coll == ncclFuncAllGather || args->coll == ncclFuncAllReduce) { - sub->recvbuff += nBytes; - } - } - } else { - NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, region+recvBeg, totalSize, mhandle, sub->requests+buffSlot)); - } + NCCLCHECK(collNetRecvFlush(proxyState, resources, args, sub, groupStart, totalSize, recvBeg, &sub->requests[buffSlot])); } } args->idle = 0; @@ -986,14 +1111,19 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct } } if (sub->transmitted < sub->flushed) { - if (sub->reg == 0) { + if (sub->reg == 0 || (!sub->isOneRPN && args->coll == ncclFuncAllGather)) { int buffSlot = (sub->base + sub->transmitted)%NCCL_STEPS; volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; connFifo[buffSlot].offset = calcRegionOffset(args, 1, s, sub->transmitted, 0); __sync_synchronize(); } volatile uint64_t* recvTail = resources->gdcSync ? resources->gdcSync : &resources->recvMem->tail; - *recvTail = sub->base + sub->flushed; + if (sub->reg && sub->isOneRPN) { + // We may have bumped net steps, but reg operations only have a single step w.r.t. the GPU. + if (sub->flushed == sub->nsteps) *recvTail = sub->base + args->sliceSteps; + } else { + *recvTail = sub->base + sub->flushed; + } if (resources->gdcSync) wc_store_fence(); // Flush out WC write sub->transmitted += args->sliceSteps; args->idle = 0; @@ -1005,7 +1135,8 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct bool groupSync = s==0 ? args->subs[args->nsubs-1].done == sub->done : (sub-1)->done > sub->done; volatile uint64_t* sendHead = &resources->sendMem->head; - if (groupSync && sub->done < sub->transmitted && (sub->base+sub->done) < *sendHead) { + int done = sub->reg && sub->isOneRPN ? 0 : sub->done; + if (groupSync && sub->done < sub->transmitted && sub->base + done < *sendHead) { sub->done += args->sliceSteps; args->idle = 0; if (sub->done == sub->nsteps && s == args->nsubs-1) { @@ -1023,24 +1154,22 @@ struct collnetRegInfo { size_t size; }; -ncclResult_t ncclCollnetLocalRegisterBuffer(struct ncclComm* comm, const void* userbuff, size_t buffSize, int type, int* outRegBufFlag, void** outHandle) { +static ncclResult_t collnetRegisterBuffer(struct ncclComm* comm, const void* userbuff, size_t buffSize, int type, struct ncclReg* regRecord, int* outRegBufFlag, void** outHandle) { ncclResult_t ret = ncclSuccess; - struct ncclReg *regRecord = NULL; + if (regRecord) { + if (regRecord->state & COLLNET_REG_COMPLETE) { + // reuse previous registration + *outRegBufFlag = 2; + *outHandle = regRecord->collnetHandle; + INFO(NCCL_REG, "rank %d - COLLNET reuse register userbuff %p (handle %p), buffSize %ld, type %s", comm->rank, userbuff, regRecord->collnetHandle, buffSize, type == collNetRecv ? "Recv" : "Send"); + goto exit; + } else { + /* start register collnet buffer */ + struct collnetRegInfo info = { regRecord->addr, regRecord->pages * comm->regCache.pageSize }; + void* handle = NULL; + struct ncclConnInfo* conn = (type == collNetRecv) ? &comm->channels[0].peers[comm->nRanks]->recv[type].conn : &comm->channels[0].peers[comm->nRanks]->send[type].conn; - *outRegBufFlag = 0; - *outHandle = NULL; - if (comm && userbuff && buffSize > 0) { - NCCLCHECKGOTO(ncclRegFind(comm, userbuff, buffSize, ®Record), ret, fail); - if (regRecord) { - if (regRecord->state & COLLNET_REG_COMPLETE) { - // reuse previous registration - *outRegBufFlag = 2; - *outHandle = regRecord->collnetHandle; - goto exit; - } else { - /* start register collnet buffer */ - struct collnetRegInfo info = {regRecord->addr, regRecord->pages * comm->regCache.pageSize}; - void* handle = NULL; + if (conn->flags & NCCL_DIRECT_NIC) { struct ncclProxyConnector* proxyconn = (type == collNetRecv) ? &comm->channels[0].peers[comm->nRanks]->recv[type].proxyConn : &comm->channels[0].peers[comm->nRanks]->send[type].proxyConn; NCCLCHECKGOTO(ncclProxyCallBlocking(comm, proxyconn, ncclProxyMsgRegister, &info, sizeof(struct collnetRegInfo), &handle, sizeof(void*)), ret, fail); if (handle) { @@ -1048,10 +1177,78 @@ ncclResult_t ncclCollnetLocalRegisterBuffer(struct ncclComm* comm, const void* u regRecord->collnetProxyconn = proxyconn; *outHandle = regRecord->collnetHandle = handle; *outRegBufFlag = 1; + INFO(NCCL_REG, "rank %d - COLLNET register userbuff %p (handle %p), buffSize %ld, type %s", comm->rank, userbuff, handle, buffSize, type == collNetRecv ? "Recv" : "Send"); } + } else { + WARN("rank %d - COLLNET failed to register userbuff %p (handle %p), buffSize %ld, type %s, GDR is not enabled", comm->rank, userbuff, handle, buffSize, type == collNetRecv ? "Recv" : "Send"); } } } +exit: + return ret; +fail: + *outRegBufFlag = 0; + *outHandle = NULL; + goto exit; +} + +ncclResult_t ncclCollnetLocalRegisterBuffer(struct ncclComm* comm, const void* userbuff, size_t buffSize, int type, int* outRegBufFlag, void** outHandle) { + ncclResult_t ret = ncclSuccess; + struct ncclReg *regRecord = NULL; + bool isValid = false; + + *outRegBufFlag = 0; + *outHandle = NULL; + if (comm && userbuff && buffSize > 0) { + NCCLCHECKGOTO(ncclRegFind(comm, userbuff, buffSize, ®Record), ret, fail); + NCCLCHECKGOTO(ncclRegLocalIsValid(regRecord, &isValid), ret, fail); + if (isValid) + NCCLCHECKGOTO(collnetRegisterBuffer(comm, userbuff, buffSize, type, regRecord, outRegBufFlag, outHandle), ret, fail); + } +exit: + return ret; +fail: + *outRegBufFlag = 0; + goto exit; +} + +struct ncclCollnetCleanupCallback { + struct ncclCommCallback base; + struct ncclComm *comm; + struct ncclReg *reg; +}; + +static ncclResult_t cleanupCollnet(struct ncclComm* comm, struct ncclCommCallback* cb) { + struct ncclCollnetCleanupCallback* obj = (struct ncclCollnetCleanupCallback*)cb; + NCCLCHECK(ncclCommGraphDeregister(obj->comm, obj->reg)); + free(obj); + return ncclSuccess; +} + +ncclResult_t ncclCollnetGraphRegisterBuffer(struct ncclComm* comm, const void* userbuff, size_t buffSize, int type, int* outRegBufFlag, void** outHandle, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts) { + ncclResult_t ret = ncclSuccess; + struct ncclCollnetCleanupCallback* record = NULL; + struct ncclReg *regRecord = NULL; + void *baseSend = NULL; + size_t baseSendSize = 0; + + *outRegBufFlag = 0; + if (comm && userbuff && buffSize > 0) { + CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr *)&baseSend, &baseSendSize, (CUdeviceptr)userbuff), ret, fail); + NCCLCHECKGOTO(ncclCommGraphRegister(comm, baseSend, baseSendSize, (void**)®Record), ret, fail); + NCCLCHECKGOTO(collnetRegisterBuffer(comm, userbuff, buffSize, type, regRecord, outRegBufFlag, outHandle), ret, fail); + + if (*outRegBufFlag) { + record = (struct ncclCollnetCleanupCallback*)malloc(sizeof(struct ncclCollnetCleanupCallback)); + record->base.fn = cleanupCollnet; + record->comm = comm; + record->reg = regRecord; + ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)record); + *nCleanupQueueElts += 1; + } else { + NCCLCHECKGOTO(ncclCommGraphDeregister(comm, regRecord), ret, fail); + } + } exit: return ret; @@ -1061,55 +1258,9 @@ fail: goto exit; } -struct ncclCollnetCleanupCallback { - struct ncclCommCallback base; - struct ncclProxyConnector* proxyConn; - void* buffer; - size_t size; - void* mhandle; -}; - -static ncclResult_t cleanupCollnet(struct ncclComm* comm, struct ncclCommCallback* cb) { - struct ncclCollnetCleanupCallback* obj = (struct ncclCollnetCleanupCallback*)cb; - NCCLCHECK(ncclCollnetDeregBuffer(comm, obj->proxyConn, obj->mhandle)); - INFO(NCCL_REG, "rank %d - deregistered collnet buffer handle %p, size %ld, buff %p", comm->rank, obj->mhandle, obj->size, obj->buffer); - free(obj); - return ncclSuccess; -} - -ncclResult_t ncclCollnetGraphRegisterBuffer(struct ncclComm* comm, const void* userbuff, size_t buffSize, int type, int* outRegBufFlag, void** outHandle, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts) { - ncclResult_t ret = ncclSuccess; - void* handle = NULL; - struct ncclRegCache* cache = &comm->regCache; - uintptr_t pageSize = cache->pageSize; - uintptr_t addr = (uintptr_t)userbuff & -pageSize; - size_t size = DIVUP((uintptr_t)userbuff - addr + buffSize, pageSize) * pageSize; - collnetRegInfo info = {addr, size}; - struct ncclCollnetCleanupCallback* record = NULL; - struct ncclProxyConnector* proxyConn = (type == collNetRecv) ? &comm->channels[0].peers[comm->nRanks]->recv[type].proxyConn : &comm->channels[0].peers[comm->nRanks]->send[type].proxyConn; - - *outRegBufFlag = 0; - NCCLCHECKGOTO(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgRegister, &info, sizeof(struct collnetRegInfo), &handle, sizeof(void*)), ret, fail); - record = (struct ncclCollnetCleanupCallback*)malloc(sizeof(struct ncclCollnetCleanupCallback)); - record->base.fn = cleanupCollnet; - record->proxyConn = proxyConn; - record->buffer = (void*)userbuff; - record->size = buffSize; - *outHandle = record->mhandle = handle; - *outRegBufFlag = 1; - ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)record); - *nCleanupQueueElts += 1; - -exit: - return ret; -fail: - *outRegBufFlag = 0; - *outHandle = NULL; - goto exit; -} - ncclResult_t ncclCollnetDeregBuffer(struct ncclComm* comm, struct ncclProxyConnector* proxyconn, void* handle) { NCCLCHECK(ncclProxyCallBlocking(comm, proxyconn, ncclProxyMsgDeregister, &handle, sizeof(void*), NULL, 0)); + INFO(NCCL_REG, "rank %d - COLLNET deregistered buffer handle %p", comm->rank, handle); return ncclSuccess; } @@ -1117,26 +1268,67 @@ static ncclResult_t sendProxyRegBuffer(struct ncclProxyConnection* connection, s void* handle; struct collnetRegInfo* info = (struct collnetRegInfo*)reqBuff; struct sendResources* resources = (struct sendResources*)(connection->transportResources); + ncclResult_t ret = ncclSuccess; + bool needReg = true; assert(reqSize == sizeof(struct collnetRegInfo)); assert(respSize == sizeof(void*)); - if (proxyState->ncclCollNet->regMr(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle) != ncclSuccess) handle = NULL; + +#if CUDART_VERSION >= 11070 + /* DMA-BUF support */ + if (resources->useGdr && resources->useDmaBuf) { + int dmabuf_fd; + CUCHECKGOTO(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)info->buffer, info->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0), ret, peermem); + NCCLCHECKGOTO(proxyState->ncclCollNet->regMrDmaBuf(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, 0ULL, dmabuf_fd, &handle), ret, peermem); + (void)close(dmabuf_fd); + needReg = false; + } +#endif +peermem: + if (needReg) { + NCCLCHECKGOTO(proxyState->ncclCollNet->regMr(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle), ret, fail); + } + +exit: memcpy(respBuff, (void*)&handle, sizeof(void*)); *done = 1; return ncclSuccess; +fail: + handle = NULL; + goto exit; } static ncclResult_t recvProxyRegBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { void* handle; struct collnetRegInfo* info = (struct collnetRegInfo*)reqBuff; struct recvResources* resources = (struct recvResources*)(connection->transportResources); + ncclResult_t ret = ncclSuccess; + bool needReg = true; assert(reqSize == sizeof(struct collnetRegInfo)); assert(respSize == sizeof(void*)); - if (proxyState->ncclCollNet->regMr(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle) != ncclSuccess) handle = NULL; + #if CUDART_VERSION >= 11070 + /* DMA-BUF support */ + if (resources->useGdr && resources->useDmaBuf) { + int dmabuf_fd; + CUCHECKGOTO(cuMemGetHandleForAddressRange((void *)&dmabuf_fd, (CUdeviceptr)info->buffer, info->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0), ret, peermem); + NCCLCHECKGOTO(proxyState->ncclCollNet->regMrDmaBuf(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, 0ULL, dmabuf_fd, &handle), ret, peermem); + (void)close(dmabuf_fd); + needReg = false; + } +#endif +peermem: + if (needReg) { + NCCLCHECKGOTO(proxyState->ncclCollNet->regMr(resources->collNetComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle), ret, fail); + } + +exit: memcpy(respBuff, (void*)&handle, sizeof(void*)); *done = 1; return ncclSuccess; +fail: + handle = NULL; + goto exit; } static ncclResult_t sendProxyDeregBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, int* done) { @@ -1161,13 +1353,6 @@ static ncclResult_t recvProxyDeregBuffer(struct ncclProxyConnection* connection, return ncclSuccess; } -struct ncclTransport collNetTransport = { - "COL", - canConnect, - { sendSetup, sendConnect, sendFree, NULL, sendProxySetup, sendProxyConnect, sendProxyFree, sendProxyProgress, sendProxyRegBuffer, sendProxyDeregBuffer }, - { recvSetup, recvConnect, recvFree, NULL, recvProxySetup, recvProxyConnect, recvProxyFree, recvProxyProgress, recvProxyRegBuffer, recvProxyDeregBuffer } -}; - ncclResult_t ncclCollNetChainBufferSetup(ncclComm_t comm) { ncclResult_t ret = ncclSuccess; char line[1024]; @@ -1203,7 +1388,6 @@ fail: ncclResult_t ncclCollNetDirectBufferSetup(ncclComm_t comm) { ncclResult_t ret = ncclSuccess; - int highestTransportType0 = TRANSPORT_UNDEFINED, highestTransportType1 = TRANSPORT_UNDEFINED; if (comm->collNetSupport == 0) goto exit; @@ -1212,13 +1396,13 @@ ncclResult_t ncclCollNetDirectBufferSetup(ncclComm_t comm) { struct ncclChannel* channelRecv = comm->channels + c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collnetDirect.down, 0), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], 0, &highestTransportType0), ret, fail); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], 0), ret, fail); for (int c = 0; c < comm->nChannels; c++) { struct ncclChannel* channelSend = comm->channels + c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.down, NCCL_MAX_DIRECT_ARITY, channelSend->collnetDirect.up, 1), ret, fail); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], 1, &highestTransportType1), ret, fail); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_COLLNET_DIRECT], 1), ret, fail); INFO(NCCL_INIT, "rank %d Connected CollNet", comm->rank); @@ -1416,3 +1600,10 @@ fail: comm->collNetSupport = 0; goto exit; } + +struct ncclTransport collNetTransport = { + "COL", + canConnect, + { sendSetup, sendConnect, sendFree, NULL, sendProxySetup, sendProxyConnect, sendProxyFree, sendProxyProgress, sendProxyRegBuffer, sendProxyDeregBuffer }, + { recvSetup, recvConnect, recvFree, NULL, recvProxySetup, recvProxyConnect, recvProxyFree, recvProxyProgress, recvProxyRegBuffer, recvProxyDeregBuffer } +}; \ No newline at end of file diff --git a/src/transport/generic.cc b/src/transport/generic.cc index eb6817c9d8..4bb0408f87 100644 --- a/src/transport/generic.cc +++ b/src/transport/generic.cc @@ -6,18 +6,38 @@ #include "comm.h" #include "transport.h" +#include "bootstrap.h" ncclResult_t ncclTransportRingConnect(struct ncclComm* comm) { + struct ringConnInfo { + bool useNetPXN; + bool useGdr; + }; + struct ringConnInfo* ringInfo = NULL; ncclResult_t ret = ncclSuccess; if (comm && comm->nRanks > 1) { + comm->useGdr = true; + comm->useNetPXN = false; for (int c = 0; c < comm->nChannels; c++) { struct ncclChannel* channel = comm->channels + c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, fail); } NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_RING], 0), ret, fail); - INFO(NCCL_INIT, "Connected all rings"); + if (ncclParamLocalRegister() || ncclParamGraphRegister()) { + NCCLCHECK(ncclCalloc(&ringInfo, comm->nRanks)); + ringInfo[comm->rank].useGdr = comm->useGdr; + ringInfo[comm->rank].useNetPXN = comm->useNetPXN; + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, ringInfo, sizeof(struct ringConnInfo)), ret, fail); + for (int i = 0; i < comm->nRanks; ++i) { + if (!ringInfo[i].useGdr) comm->useGdr = false; + if (ringInfo[i].useNetPXN) comm->useNetPXN = true; + if (comm->useGdr == false && comm->useNetPXN == true) break; + } + } + INFO(NCCL_INIT, "Connected all rings, use ring PXN %d GDR %d", comm->useNetPXN, comm->useGdr); } exit: + free(ringInfo); return ret; fail: goto exit; diff --git a/src/transport/net.cc b/src/transport/net.cc index dda9677c21..bd62b719fa 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -17,6 +17,7 @@ #include "profiler.h" #include "transport.h" #include "shm.h" +#include #include "graph.h" #include "graph/topo.h" #if defined(ENABLE_NPKIT) @@ -117,6 +118,7 @@ struct sendNetResources { int netDeviceVersion; ncclNetDeviceType netDeviceType; ncclNetDeviceHandle_t* netDeviceHandle; + size_t maxP2pBytes; volatile uint32_t* curr_hdp_reg; // Curr GPU in ring (for rdma transport use only) }; @@ -150,9 +152,15 @@ struct recvNetResources { int netDeviceVersion; ncclNetDeviceType netDeviceType; ncclNetDeviceHandle_t* netDeviceHandle; + size_t maxP2pBytes; volatile uint32_t* curr_hdp_reg; // Curr GPU in ring (for rdma transport use only) }; +struct netRegInfo { + uintptr_t buffer; + size_t size; +}; + /* Determine if two peers can communicate with NET */ static ncclResult_t canConnect(int* ret, struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { *ret = 1; @@ -186,6 +194,9 @@ struct setupReq { uint32_t* curr_hdp_reg; }; +NCCL_PARAM(NetOptionalRecvCompletion, "NET_OPTIONAL_RECV_COMPLETION", 1); + +static_assert(sizeof(ncclNetHandle_t) + sizeof(int) <= CONNECT_SIZE, "Not large enough ncclConnect to hold ncclNetHandle_t and useGdr flag"); // Forward declaration static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args); @@ -204,8 +215,10 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int64_t netId; if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &netId, &req.netDev)); if (req.netDev < 0) NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, peerInfo->rank, &netId, &req.netDev, &proxyRank)); - NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, netId, 1, &req.useGdr)); + NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->rank, netId, 1, &req.useGdr)); send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; + if (!req.useGdr && connIndex == 0) comm->useGdr = 0; + if (proxyRank != myInfo->rank && connIndex == 0) comm->useNetPXN = true; if (req.useGdr && !IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx90a") && !IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && !IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) { CUDACHECK(hipDeviceGetAttribute((int*)&req.curr_hdp_reg, hipDeviceAttributeHdpMemFlushCntl, myInfo->cudaDev)); send->conn.curr_hdp_reg = req.curr_hdp_reg; @@ -225,6 +238,7 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph proxyRank, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "", comm, comm->nRanks); } *((int*)connectInfo) = comm->topParentRanks[proxyRank]; + memcpy((uint8_t*)connectInfo + sizeof(ncclNetHandle_t), &req.useGdr, sizeof(int)); return ncclSuccess; } @@ -247,11 +261,13 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int64_t netId; if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 0, &netId, &req.netDev)); if (req.netDev < 0) NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, myInfo->rank, &netId, &req.netDev, &proxyRank)); - NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, netId, 0, &req.useGdr)); + NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->rank, netId, 0, &req.useGdr)); + recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; + if (!req.useGdr && connIndex == 0) comm->useGdr = 0; // Determine whether we need to flush the GDR buffer on recv or not if (req.useGdr) { - NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); + NCCLCHECK(ncclTopoNeedFlush(comm, req.netDev, myInfo->rank, &req.needFlush)); CUDACHECK(hipDeviceGetAttribute((int*)&req.curr_hdp_reg, hipDeviceAttributeHdpMemFlushCntl, myInfo->cudaDev)); recv->conn.curr_hdp_reg = req.curr_hdp_reg; } @@ -263,6 +279,7 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph req.tpRank = comm->topParentRanks[myInfo->rank]; req.tpRemoteRank = comm->topParentRanks[peerInfo->rank]; NCCLCHECK(ncclProxyCallBlocking(comm, &recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); + memcpy((uint8_t*)connectInfo + sizeof(ncclNetHandle_t), &req.useGdr, sizeof(int)); INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s comm %p nRanks %02d", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, comm->ncclNet->name, req.netDev, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "", comm, comm->nRanks); return ncclSuccess; @@ -316,8 +333,11 @@ struct netRecvConnectArgs { static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) { struct connectMap* map = (connectMap*) send->transportResources; - void* opId; + int recvUseGdr; + + memcpy(&recvUseGdr, (uint8_t*)connectInfo + sizeof(ncclNetHandle_t), sizeof(int)); + if (!recvUseGdr) send->conn.flags &= ~NCCL_DIRECT_NIC; // map isn't allocated thus this op hasn't been submitted yet if (!map) { @@ -424,6 +444,11 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) { struct connectMap* map = (connectMap*) recv->transportResources; void* opId; + int sendUseGdr; + + memcpy(&sendUseGdr, (uint8_t*)connectInfo + sizeof(ncclNetHandle_t), sizeof(int)); + if (!sendUseGdr) recv->conn.flags &= ~NCCL_DIRECT_NIC; + if (!map) { NCCLCHECK(ncclCalloc(&map, 1)); recv->transportResources = map; @@ -567,7 +592,7 @@ static ncclResult_t sharedNetBuffersInit(struct ncclProxyState* proxyState, int return ncclSuccess; } -static ncclResult_t sharedBuffersGet(struct ncclProxyState* proxyState, int channel, int slot, int* offset, int* size) { +static ncclResult_t sharedBuffersGet(struct ncclProxyState* proxyState, int channel, int slot, int* offset, size_t* size) { // Use different pools for different channels and also separate send/recv. int globalSlot = (channel*NCCL_SHARED_STEPS)+slot; *offset = proxyState->p2pChunkSize * globalSlot; @@ -639,6 +664,13 @@ static ncclResult_t sendProxySetup(struct ncclProxyConnection* connection, struc resources->netDeviceVersion = props.netDeviceVersion; resources->netDeviceType = props.netDeviceType; + /* point-to-point size limits*/ + resources->maxP2pBytes = props.maxP2pBytes; + if((resources->maxP2pBytes <= 0) || (resources->maxP2pBytes > NCCL_MAX_NET_SIZE_BYTES)) { + WARN("sendProxySetup: net plugin returned invalid value for maxP2pBytes %ld \ + [allowed range: %ld - %ld] \n", resources->maxP2pBytes, 0L, NCCL_MAX_NET_SIZE_BYTES); + return ncclInternalError; + } // We don't return any data if (respSize != 0) return ncclInternalError; @@ -671,6 +703,13 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc resources->maxRecvs = props.maxRecvs; resources->netDeviceVersion = props.netDeviceVersion; resources->netDeviceType = props.netDeviceType; + /* point-to-point size limits*/ + resources->maxP2pBytes = props.maxP2pBytes; + if((resources->maxP2pBytes <= 0) || (resources->maxP2pBytes > NCCL_MAX_NET_SIZE_BYTES)) { + WARN("recvProxySetup: net plugin returned invalid value for maxP2pBytes %ld \ + [allowed range: %ld - %ld] \n", resources->maxP2pBytes, 0L, NCCL_MAX_NET_SIZE_BYTES); + return ncclInternalError; + } if (respSize != sizeof(ncclNetHandle_t)) return ncclInternalError; NCCLCHECK(proxyState->ncclNet->listen(req->netDev, respBuff, &resources->netListenComm)); @@ -1030,6 +1069,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str resources->sendMem = (struct ncclSendMem*) NCCL_NET_MAP_GET_POINTER(map, cpu, sendMem); resources->recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, cpu, recvMem); + for (int i = 0; i < NCCL_STEPS; i++) resources->recvMem->connFifo[i].size = -1; for (int p=0; pbuffers[p] = NCCL_NET_MAP_GET_POINTER(map, cpu, buffs[p]); if (resources->buffers[p]) { @@ -1158,7 +1198,6 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct } static_assert(NCCL_STEPS <= NCCL_NET_MAX_REQUESTS, "Not enough net requests to cover for steps"); -#define MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_NET_COLLECT_POLL_CNT) static int g_npkit_net_poll_cnt = 0; @@ -1178,11 +1217,8 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct resources->step = sub->base + sub->nsteps; sub->posted = sub->transmitted = sub->done = 0; ncclProfilerStartSendProxyOpEvent(s, args); - if (sub->reg && sub->nbytes > 0) { - NCCLCHECK(proxyState->ncclNet->regMr(resources->netSendComm, sub->recvbuff, sub->nbytes, NCCL_PTR_CUDA, &sub->mhandle)); - } else { - sub->mhandle = resources->mhandles[args->protocol]; - } + if (!sub->reg) + sub->sendMhandle = resources->mhandles[args->protocol]; } args->state = ncclProxyOpProgress; args->hdp_flushed = 0; @@ -1193,6 +1229,9 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct int maxDepth = std::min(NCCL_STEPS, NCCL_SHARED_STEPS/args->nsubs); for (int s=0; snsubs; s++) { struct ncclProxySubArgs* sub = args->subs+s; + int postedStepId = sub->posted; + int transmittedStepId = sub->transmitted; + int doneStepId = sub->done; if (sub->done == sub->nsteps) continue; struct sendNetResources* resources = (struct sendNetResources*) (sub->connection->transportResources); volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; @@ -1200,7 +1239,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); // Post buffers to the GPU if (sub->posted < sub->nsteps && sub->posted < sub->done + maxDepth) { - ncclProfilerStartSendProxyStepEvents(s, args, sub->posted, sub->posted+args->sliceSteps); + ncclProfilerStartSendProxyStepEvent(s, args, postedStepId); int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; if (resources->shared) { if (!sub->reg) { @@ -1212,12 +1251,13 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct } volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; sub->posted += args->sliceSteps; - // Only post one credit for registered buffer - if (sub->reg == 0 || sub->posted == args->sliceSteps) *sendHead = sub->base + sub->posted - NCCL_STEPS; + *sendHead = sub->base + sub->posted - NCCL_STEPS; if (resources->gdcSync) wc_store_fence(); // Flush out WC write - } else sub->posted += args->sliceSteps; + } else { + sub->posted += args->sliceSteps; + } ncclProfilerRecordProxyOpEventState(s, args, sub->posted, sub->transSize, ncclProfilerProxyOpSendPosted); - ncclProfilerRecordProxyStepEventStates(s, args, sub->posted-args->sliceSteps, sub->posted, ncclProfilerProxyStepSendGPUWait); + ncclProfilerRecordProxyStepEventState(s, args, postedStepId, ncclProfilerProxyStepSendGPUWait); args->idle = 0; continue; } @@ -1225,10 +1265,10 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct if (sub->transmitted < sub->posted && sub->transmitted < sub->done + NCCL_STEPS) { int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; volatile uint64_t* recvTail = &resources->recvMem->tail; - uint64_t tail = sub->base + (sub->reg ? 0 : sub->transmitted); - if ((sub->reg || connFifo[buffSlot].size != -1) && ((*recvTail > tail) || p == NCCL_PROTO_LL)) { + uint64_t tail = sub->base + sub->transmitted; + if (connFifo[buffSlot].size != -1 && (*recvTail > tail || p == NCCL_PROTO_LL)) { // We have something to receive, let's check if it's completely ready. - int size = sub->reg ? std::min(MAX_NET_SIZE, sub->nbytes) : connFifo[buffSlot].size; + int size = connFifo[buffSlot].size; #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_NET_SEND_ENTRY) && defined(ENABLE_NPKIT_EVENT_NET_SEND_EXIT) sub->npKitSizesFifo[buffSlot] = size; #endif @@ -1257,8 +1297,14 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct volatile uint32_t *f2 = &lines[i].flag2; if (f1[0] != flag || f2[0] != flag) { ready = 0; break; } } - } else if (p == NCCL_PROTO_SIMPLE && resources->shared) { - buff = sub->reg ? (char*)sub->recvbuff : localBuff+resources->recvMem->connFifo[buffSlot].offset; + } else if (p == NCCL_PROTO_SIMPLE) { + if (resources->shared) { + buff = sub->reg ? (char*)sub->sendbuff + sub->transmitted * NCCL_MAX_NET_SIZE : localBuff + resources->recvMem->connFifo[buffSlot].offset; + } else if (sub->reg) { + size_t sendSize; + sub->ringAlgo->getNextSendAddr(sub->transmitted, (uint8_t**)&buff, &sendSize, &sub->sendMhandle); + assert(sendSize == size); + } } if (ready) { // flush HDP if not done @@ -1266,12 +1312,12 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct args->hdp_flushed = *recvTail; *resources->curr_hdp_reg = 1; } - ncclProfilerRecordProxyOpEventState(s, args, sub->transmitted + args->sliceSteps, sub->transSize, ncclProfilerProxyOpSendRemFifoWait); + ncclProfilerRecordProxyOpEventState(s, args, sub->transmitted+args->sliceSteps, sub->transSize, ncclProfilerProxyOpSendRemFifoWait); // Data is ready, try to send. // Coverity complains about the size here as pointing to an out-of-scope temporary. Which is nonsense, // since size is a plain integer. // coverity[use_invalid:FALSE] - NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, sub->mhandle, sub->requests+buffSlot)); + NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, sub->sendMhandle, sub->requests+buffSlot)); if (sub->requests[buffSlot] != NULL) { #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_NET_SEND_ENTRY) && defined(ENABLE_NPKIT_EVENT_NET_SEND_EXIT) @@ -1290,11 +1336,11 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct sub->timestamp[buffSlot] = 0; #endif - TRACE(NCCL_NET, "sendProxy [%ld/%d] Isend posted, req %p, size %d, proto %d, myRank %d, channelId %d", sub->transmitted, buffSlot, sub->requests[buffSlot], size, p, proxyState->tpRank, sub->channelId); + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Isend posted, req %p, buff %p, size %d, proto %d, myRank %d, channelId %d, mhandle %p", sub->transmitted, buffSlot, sub->nsteps, sub->requests[buffSlot], buff, size, p, proxyState->tpRank, sub->channelId, sub->sendMhandle); + sub->transSize += size; sub->transmitted += args->sliceSteps; ncclProfilerRecordProxyOpEventState(s, args, sub->transmitted, sub->transSize, ncclProfilerProxyOpSendTransmitted); - ncclProfilerRecordProxyStepEventStates(s, args, sub->transmitted-args->sliceSteps, sub->transmitted, ncclProfilerProxyStepSendWait); - sub->transSize += size; + ncclProfilerRecordProxyStepEventState(s, args, transmittedStepId, ncclProfilerProxyStepSendWait); args->idle = 0; continue; } @@ -1354,41 +1400,24 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct g_npkit_net_poll_cnt = 0; #endif #endif - if (sub->reg) { - if (size < sub->nbytes) { - sub->recvbuff += size; - sub->nbytes -= size; - // Do one more step (at least) - sub->nsteps++; - } else { - // Signal the GPU the send is complete and it can return. - connFifo[sub->base%NCCL_STEPS].size = -1; - } - } // Make sure size is reset to -1 before we update the head. - if (sub->reg == 0) connFifo[buffSlot].size = -1; + connFifo[buffSlot].size = -1; __sync_synchronize(); - TRACE(NCCL_NET, "sendProxy [%ld/%d] request %p done", sub->done, buffSlot, sub->requests[buffSlot]); + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] request %p done", sub->done, buffSlot, sub->nsteps, sub->requests[buffSlot]); sub->done += args->sliceSteps; - ncclProfilerStopProxyStepEvents(s, args, sub->done-args->sliceSteps, sub->done); + ncclProfilerStopProxyStepEvent(s, args, doneStepId); ncclProfilerRecordProxyOpEventState(s, args, sub->done, sub->transSize, ncclProfilerProxyOpSendDone); if (resources->shared == 0) { volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; - if (sub->reg) { - // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU. - if (sub->done == sub->nsteps) *sendHead = sub->base + args->sliceSteps; - } else { - *sendHead = sub->base + sub->done; - } + *sendHead = sub->base + sub->done; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } args->idle = 0; if (sub->done == sub->nsteps) { - if (sub->reg && sub->nbytes > 0) { - NCCLCHECK(proxyState->ncclNet->deregMr(resources->netSendComm, sub->mhandle)); - } args->done++; + if (sub->ringAlgo && sub->ringAlgo->decRefCount() == 0) delete sub->ringAlgo; + sub->ringAlgo = NULL; } } } @@ -1442,14 +1471,11 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct // Set step base for next op resources->step = sub->base + sub->nsteps; sub->posted = sub->received = sub->transmitted = sub->done = 0; + sub->regBufferReady = 0; for (int i=0; ireg && sub->nbytes > 0) { - // Register buffer - NCCLCHECK(proxyState->ncclNet->regMr(resources->netRecvComm, sub->recvbuff, sub->nbytes, NCCL_PTR_CUDA, &sub->mhandle)); - } else { - sub->mhandle = resources->mhandles[args->protocol]; - } + if (!sub->reg) + sub->recvMhandle = resources->mhandles[args->protocol]; } args->state = ncclProxyOpProgress; } @@ -1461,32 +1487,44 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct struct ncclProxySubArgs* subGroup = args->subs+s; int subCount = 0; void* ptrs[NCCL_PROXY_MAX_SUBS]; - int sizes[NCCL_PROXY_MAX_SUBS]; + size_t sizes[NCCL_PROXY_MAX_SUBS]; int tags[NCCL_PROXY_MAX_SUBS]; void* mhandles[NCCL_PROXY_MAX_SUBS]; for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; + int postedStepId = sub->posted; if (sub->posted < sub->nsteps) { if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; } - ncclProfilerStartRecvProxyStepEvents(s+i, args, sub->posted, sub->posted+args->sliceSteps); + ncclProfilerStartRecvProxyStepEvent(s+i, args, postedStepId); struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); - if (sub->reg) maxDepth = 1; int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; - if (p == NCCL_PROTO_SIMPLE && resources->shared) { - if (sub->reg) { - // Wait until CUDA kernel has started before we access the user buffer directly. - if (connFifo[sub->base%NCCL_STEPS].size == -1) continue; - ptrs[subCount] = sub->recvbuff; - sizes[subCount] = std::min(MAX_NET_SIZE, sub->nbytes); + if (p == NCCL_PROTO_SIMPLE) { + if (resources->shared) { + if (sub->reg) { + // Wait until CUDA kernel has started before we access the user buffer directly. + if (!sub->regBufferReady && connFifo[sub->base % NCCL_STEPS].size == -1) continue; + sub->regBufferReady = 1; + ptrs[subCount] = sub->recvbuff + sub->posted * NCCL_MAX_NET_SIZE; + sizes[subCount] = std::min(NCCL_MAX_NET_SIZE, (ssize_t)(sub->nbytes - sub->posted * NCCL_MAX_NET_SIZE)); + } else { + int sharedBuffSlot = sub->posted % maxDepth; + int offset; + NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot * args->nsubs + s + i, &offset, sizes + subCount)); + __atomic_store_n(&connFifo[buffSlot].offset, offset, __ATOMIC_RELAXED); + ptrs[subCount] = localBuff + offset; + } } else { - int sharedBuffSlot = sub->posted%maxDepth; - int offset; - NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s+i, &offset, sizes+subCount)); - __atomic_store_n(&connFifo[buffSlot].offset, offset, __ATOMIC_RELAXED); - ptrs[subCount] = localBuff+offset; + if (sub->reg) { + if (!sub->regBufferReady && connFifo[sub->base % NCCL_STEPS].size == -1) continue; + sub->regBufferReady = 1; + sub->ringAlgo->getNextRecvAddr(sub->posted, (uint8_t**)&ptrs[subCount], &sizes[subCount], &sub->recvMhandle); + } else { + ptrs[subCount] = localBuff + buffSlot * stepSize; + sizes[subCount] = stepSize * args->sliceSteps; + } } } else { ptrs[subCount] = localBuff+buffSlot*stepSize; @@ -1494,7 +1532,7 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct } if (sub->nbytes < sizes[subCount]) sizes[subCount] = sub->nbytes; tags[subCount] = resources->tpRemoteRank; - mhandles[subCount] = sub->mhandle; + mhandles[subCount] = sub->recvMhandle; subCount++; } } @@ -1502,12 +1540,16 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct uint64_t step = subGroup->posted; struct recvNetResources* resources = (struct recvNetResources*) (subGroup->connection->transportResources); void** requestPtr = subGroup->requests+(step%NCCL_STEPS); + bool ignoreCompletion = ncclParamNetOptionalRecvCompletion() && ((args->protocol == NCCL_PROTO_LL128) || (args->protocol == NCCL_PROTO_LL)) && (subCount == 1); + if (ignoreCompletion) *requestPtr = (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION; NCCLCHECK(proxyState->ncclNet->irecv(resources->netRecvComm, subCount, ptrs, sizes, tags, mhandles, requestPtr)); if (*requestPtr) { subGroup->recvRequestsCache[step%NCCL_STEPS] = *requestPtr; subGroup->recvRequestsSubCount = subCount; for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup+i; + int postedStepId = sub->posted; + TRACE(NCCL_NET, "recvProxy [%ld/%ld/%d] Irecv posted, buff %p, size %ld, myRank %d, channelId %d, mhandle %p", sub->posted, (sub->base + sub->posted) % NCCL_STEPS, sub->nsteps, ptrs[i], sizes[i], proxyState->tpRank, sub->channelId, mhandles[i]); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_NET_RECV_ENTRY) && defined(ENABLE_NPKIT_EVENT_NET_RECV_EXIT) NpKit::CollectCpuEvent( @@ -1526,7 +1568,7 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->posted += args->sliceSteps; ncclProfilerRecordProxyOpEventState(s+i, args, sub->posted, sub->transSize, ncclProfilerProxyOpRecvPosted); - ncclProfilerRecordProxyStepEventStates(s+i, args, sub->posted-args->sliceSteps, sub->posted, ncclProfilerProxyStepRecvWait); + ncclProfilerRecordProxyStepEventState(s+i, args, postedStepId, ncclProfilerProxyStepRecvWait); } args->idle = 0; } @@ -1547,7 +1589,6 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (done) { int needFlush = 0; int totalSize = 0; - int subIndex = 0; for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; @@ -1567,27 +1608,15 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct #endif #endif - if (sub->received < sub->nsteps) { - int size = sizes[subIndex++]; - if (sub->reg) { - if (size < sub->nbytes) { - sub->recvbuff += size; - sub->nbytes -= size; - // Do one more step (at least) - sub->nsteps++; - } else { - // Reset connFifo size indicating the GPU was ready to receive. - // There is a __sync_synchronize() later to ensure it is reset before it is set again by the GPU. - struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); - volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; - connFifo[sub->base%NCCL_STEPS].size = -1; - } - } - } - sub->received += args->sliceSteps; + int receivedStepId = sub->received; + int buffSlot = (sub->base + sub->received) % NCCL_STEPS; + struct recvNetResources* resources = (struct recvNetResources*)(sub->connection->transportResources); + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; + connFifo[buffSlot].size = -1; sub->transSize += sizes[i]; + sub->received += args->sliceSteps; ncclProfilerRecordProxyOpEventState(s+i, args, sub->received, sub->transSize, ncclProfilerProxyOpRecvReceived); - ncclProfilerRecordProxyStepEventStates(s+i, args, sub->received-args->sliceSteps, sub->received, ncclProfilerProxyStepRecvFlushWait); + ncclProfilerRecordProxyStepEventState(s+i, args, receivedStepId, ncclProfilerProxyStepRecvFlushWait); if (step < sub->nsteps) { struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); if (resources->useGdr) needFlush |= resources->needFlush; @@ -1629,10 +1658,16 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); int buffSlot = (sub->base+sub->received-args->sliceSteps)%NCCL_STEPS; - ptrs[subCount] = resources->shared ? - (sub->reg ? (char*)sub->recvbuff : localBuff+resources->recvMem->connFifo[buffSlot].offset) : - localBuff+buffSlot*stepSize; - mhandles[subCount] = sub->mhandle; + if (resources->shared) { + ptrs[subCount] = sub->reg ? (char*)sub->recvbuff + step * NCCL_MAX_NET_SIZE : localBuff + resources->recvMem->connFifo[buffSlot].offset; + } else { + if (sub->reg) { + sub->ringAlgo->getNextRecvAddr(step, (uint8_t**)&ptrs[subCount], NULL, &sub->recvMhandle); + } else { + ptrs[subCount] = localBuff + buffSlot * stepSize; + } + } + mhandles[subCount] = sub->recvMhandle; subCount++; } } @@ -1660,19 +1695,16 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (done) { for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; + int transmittedStepId = sub->transmitted; sub->transmitted += args->sliceSteps; ncclProfilerRecordProxyOpEventState(s+i, args, sub->transmitted, sub->transSize, ncclProfilerProxyOpRecvTransmitted); - ncclProfilerRecordProxyStepEventStates(s+i, args, sub->transmitted-args->sliceSteps, sub->transmitted, ncclProfilerProxyStepRecvGPUWait); + ncclProfilerRecordProxyStepEventState(s+i, args, transmittedStepId, ncclProfilerProxyStepRecvGPUWait); if (step < sub->nsteps) { __sync_synchronize(); struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* recvTail = resources->gdcSync ? resources->gdcSync : &resources->recvMem->tail; - if (sub->reg) { - // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU. - if (sub->transmitted == sub->nsteps) *recvTail = sub->base + args->sliceSteps; - } else - *recvTail = sub->base + sub->transmitted; + *recvTail = sub->base + sub->transmitted; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } } @@ -1686,11 +1718,12 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct struct ncclProxySubArgs* subGroup = args->subs+s; for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; + int doneStepId = sub->done; if (sub->done == sub->nsteps) continue; if (sub->transmitted > sub->done) { struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* sendHead = &resources->sendMem->head; - uint64_t done = sub->reg ? sub->base + sub->nsteps : *sendHead; + uint64_t done = *sendHead; while (done > sub->base + sub->done && // LL and LL128 can acknowledge 0-bytes send before they even happen. Don't go past what we transmitted. sub->transmitted > sub->done) { @@ -1701,15 +1734,13 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct subGroup->recvRequestsCache[sub->done%NCCL_STEPS] = NULL; } sub->done += args->sliceSteps; - ncclProfilerStopProxyStepEvents(s+i, args, sub->done-args->sliceSteps, sub->done); + ncclProfilerStopProxyStepEvent(s+i, args, doneStepId); ncclProfilerRecordProxyOpEventState(s+i, args, sub->done, sub->transSize, ncclProfilerProxyOpRecvDone); args->idle = 0; if (sub->done == sub->nsteps) { - struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); - if (sub->reg && sub->nbytes > 0) { - NCCLCHECK(proxyState->ncclNet->deregMr(resources->netRecvComm, sub->mhandle)); - } args->done++; + if (sub->ringAlgo && sub->ringAlgo->decRefCount() == 0) delete sub->ringAlgo; + sub->ringAlgo = NULL; break; } } @@ -1726,9 +1757,228 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct return ncclSuccess; } +ncclResult_t ncclNetDeregBuffer(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* handle) { + NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgDeregister, &handle, sizeof(void*), NULL, 0)); + INFO(NCCL_REG, "rank %d - deregistered net buffer handle %p", comm->rank, handle); + return ncclSuccess; +} + +static ncclResult_t netRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, struct ncclConnector** peerConns, int nPeers, struct ncclReg* regRecord, int* outRegBufFlag, void** outHandle) { + ncclResult_t ret = ncclSuccess; + int gdrFlag = 1; + + if (regRecord) { + for (int p = 0; p < nPeers; ++p) { + struct ncclConnector* peerConn = peerConns[p]; + struct ncclProxyConnector* peerProxyConn = NULL; + struct ncclRegNetHandles* netHandle = NULL; + bool found = false; + if (peerConn == NULL) continue; + peerProxyConn = &peerConn->proxyConn; + netHandle = regRecord->netHandleHead; + while (netHandle) { + if (netHandle->proxyConn == peerProxyConn) { + found = true; + break; + } + netHandle = netHandle->next; + } + if (found) { + *outRegBufFlag = 1; + outHandle[p] = netHandle->handle; + INFO(NCCL_REG, "rank %d - NET reuse buffer %p size %ld (baseAddr %p size %ld) handle %p", comm->rank, userbuff, buffSize, (void*)regRecord->addr, regRecord->pages * comm->regCache.pageSize, netHandle->handle); + } else { + struct netRegInfo info = { regRecord->addr, regRecord->pages * comm->regCache.pageSize }; + void* handle = NULL; + + if (peerConn->conn.flags & NCCL_DIRECT_NIC) { + NCCLCHECKGOTO(ncclProxyCallBlocking(comm, peerProxyConn, ncclProxyMsgRegister, &info, sizeof(struct netRegInfo), &handle, sizeof(void*)), ret, fail); + if (handle) { + struct ncclRegNetHandles* netHandle; + regRecord->state |= NET_REG_COMPLETE; + NCCLCHECK(ncclCalloc(&netHandle, 1)); + netHandle->handle = handle; + netHandle->proxyConn = peerProxyConn; + netHandle->next = regRecord->netHandleHead; + regRecord->netHandleHead = netHandle; + outHandle[p] = handle; + *outRegBufFlag = 1; + INFO(NCCL_REG, "rank %d - NET register userbuff %p (handle %p), buffSize %ld", comm->rank, userbuff, handle, buffSize); + } else { + goto fail; + } + } else { + gdrFlag = 0; + goto fail; + } + } + } + } + +exit: + return ret; +fail: + *outRegBufFlag = 0; + WARN("rank %d failed to NET register userbuff %p buffSize %ld GDR flag %d", comm->rank, userbuff, buffSize, gdrFlag); + goto exit; +} + +ncclResult_t ncclNetLocalRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, struct ncclConnector** peerConns, int nPeers, int* outRegBufFlag, void** outHandle) { + ncclResult_t ret = ncclSuccess; + struct ncclReg *regRecord = NULL; + bool isValid = false; + + *outRegBufFlag = 0; + if (comm && userbuff && buffSize > 0 && nPeers > 0) { + NCCLCHECKGOTO(ncclRegFind(comm, userbuff, buffSize, ®Record), ret, fail); + NCCLCHECKGOTO(ncclRegLocalIsValid(regRecord, &isValid), ret, fail); + if (isValid) + NCCLCHECKGOTO(netRegisterBuffer(comm, userbuff, buffSize, peerConns, nPeers, regRecord, outRegBufFlag, outHandle), ret, fail); + } + +exit: + return ret; +fail: + *outRegBufFlag = 0; + goto exit; +} + +struct ncclNetCleanupCallback { + struct ncclCommCallback base; + struct ncclComm *comm; + struct ncclReg *reg; +}; + +static ncclResult_t cleanupNet(struct ncclComm* comm, struct ncclCommCallback* cb) { + struct ncclNetCleanupCallback* obj = (struct ncclNetCleanupCallback*)cb; + NCCLCHECK(ncclCommGraphDeregister(obj->comm, obj->reg)); + free(obj); + return ncclSuccess; +} + +ncclResult_t ncclNetGraphRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, struct ncclConnector** peerConns, int nPeers, int* outRegBufFlag, void** outHandle, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueElts) { + ncclResult_t ret = ncclSuccess; + struct ncclNetCleanupCallback *record = NULL; + struct ncclReg *regRecord = NULL; + void *baseSend; + size_t baseSendSize; + + *outRegBufFlag = 0; + if (comm && userbuff && buffSize > 0 && nPeers > 0) { + CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr *)&baseSend, &baseSendSize, (CUdeviceptr)userbuff), ret, fail); + NCCLCHECKGOTO(ncclCommGraphRegister(comm, baseSend, baseSendSize, (void**)®Record), ret, fail); + NCCLCHECKGOTO(netRegisterBuffer(comm, userbuff, buffSize, peerConns, nPeers, regRecord, outRegBufFlag, outHandle), ret, fail); + if (*outRegBufFlag) { + NCCLCHECKGOTO(ncclCalloc(&record, 1), ret, fail); + record->base.fn = cleanupNet; + record->comm = comm; + record->reg = regRecord; + ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)record); + if (nCleanupQueueElts) *nCleanupQueueElts += 1; + } else { + NCCLCHECKGOTO(ncclCommGraphDeregister(comm, regRecord), ret, fail); + } + } +exit: + return ret; +fail: + *outRegBufFlag = 0; + goto exit; +} + +static ncclResult_t sendProxyRegBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { + void* handle; + struct netRegInfo* info = (struct netRegInfo*)reqBuff; + struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources); + ncclResult_t ret = ncclSuccess; + bool needReg = true; + + assert(reqSize == sizeof(struct netRegInfo)); + assert(respSize == sizeof(void*)); + +#if CUDART_VERSION >= 11070 + /* DMA-BUF support */ + if (resources->useDmaBuf) { + int dmabuf_fd; + CUCHECKGOTO(cuMemGetHandleForAddressRange((void*)&dmabuf_fd, (CUdeviceptr)info->buffer, info->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0), ret, peermem); + NCCLCHECKGOTO(proxyState->ncclNet->regMrDmaBuf(resources->netSendComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, 0ULL, dmabuf_fd, &handle), ret, peermem); + (void)close(dmabuf_fd); + needReg = false; + } +peermem: +#endif + if (needReg) { + NCCLCHECKGOTO(proxyState->ncclNet->regMr(resources->netSendComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle), ret, fail); + } + +exit: + memcpy(respBuff, (void*)&handle, sizeof(void*)); + *done = 1; + return ncclSuccess; +fail: + handle = NULL; + goto exit; +} + +static ncclResult_t recvProxyRegBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { + void* handle; + struct netRegInfo* info = (struct netRegInfo*)reqBuff; + struct recvNetResources* resources = (struct recvNetResources*)(connection->transportResources); + ncclResult_t ret = ncclSuccess; + bool needReg = true; + + assert(reqSize == sizeof(struct netRegInfo)); + assert(respSize == sizeof(void*)); + +#if CUDART_VERSION >= 11070 + /* DMA-BUF support */ + if (resources->useDmaBuf) { + int dmabuf_fd; + CUCHECKGOTO(cuMemGetHandleForAddressRange((void*)&dmabuf_fd, (CUdeviceptr)info->buffer, info->size, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0), ret, peermem); + NCCLCHECKGOTO(proxyState->ncclNet->regMrDmaBuf(resources->netRecvComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, 0ULL, dmabuf_fd, &handle), ret, peermem); + (void)close(dmabuf_fd); + needReg = false; + } +peermem: +#endif + if (needReg) { + NCCLCHECKGOTO(proxyState->ncclNet->regMr(resources->netRecvComm, (void*)info->buffer, info->size, NCCL_PTR_CUDA, &handle), ret, fail); + } + +exit: + memcpy(respBuff, (void*)&handle, sizeof(void*)); + *done = 1; + return ncclSuccess; +fail: + handle = NULL; + goto exit; +} + +static ncclResult_t sendProxyDeregBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, int* done) { + void* handle; + struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources); + + assert(reqSize == sizeof(void*)); + memcpy(&handle, reqBuff, sizeof(void*)); + NCCLCHECK(proxyState->ncclNet->deregMr(resources->netSendComm, handle)); + *done = 1; + return ncclSuccess; +} + +static ncclResult_t recvProxyDeregBuffer(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, int* done) { + void* handle; + struct recvNetResources* resources = (struct recvNetResources*)(connection->transportResources); + + assert(reqSize == sizeof(void*)); + memcpy(&handle, reqBuff, sizeof(void*)); + NCCLCHECK(proxyState->ncclNet->deregMr(resources->netRecvComm, handle)); + *done = 1; + return ncclSuccess; +} + struct ncclTransport netTransport = { "NET", canConnect, - { sendSetup, sendConnect, sendFree, proxySharedInit, sendProxySetup, sendProxyConnect, sendProxyFree, sendProxyProgress, NULL }, - { recvSetup, recvConnect, recvFree, proxySharedInit, recvProxySetup, recvProxyConnect, recvProxyFree, recvProxyProgress, NULL } + { sendSetup, sendConnect, sendFree, proxySharedInit, sendProxySetup, sendProxyConnect, sendProxyFree, sendProxyProgress, sendProxyRegBuffer, sendProxyDeregBuffer }, + { recvSetup, recvConnect, recvFree, proxySharedInit, recvProxySetup, recvProxyConnect, recvProxyFree, recvProxyProgress, recvProxyRegBuffer, recvProxyDeregBuffer } }; diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index 7b6d655836..8b288bca77 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -45,14 +45,12 @@ struct ncclIbMrCache { }; static int ncclNMergedIbDevs = -1; -#define NCCL_IB_MAX_DEVS_PER_NIC 2 +#define NCCL_IB_MAX_DEVS_PER_NIC 4 #define MAX_MERGED_DEV_NAME (MAXNAMESIZE*NCCL_IB_MAX_DEVS_PER_NIC)+NCCL_IB_MAX_DEVS_PER_NIC struct alignas(64) ncclIbMergedDev { - int ndevs; - int devs[NCCL_IB_MAX_DEVS_PER_NIC]; // Points to an index in ncclIbDevs + ncclNetVDeviceProps_t vProps; int speed; char devName[MAX_MERGED_DEV_NAME]; // Up to NCCL_IB_MAX_DEVS_PER_NIC * name size, and a character for each '+' - int dmaBufSupported; // 0 = uninit, 1 = yes, -1 = no }; struct ncclIbStats { @@ -72,16 +70,20 @@ struct alignas(64) ncclIbDev { ibv_pd* pd; char devName[MAXNAMESIZE]; char* pciPath; + char* virtualPciPath; int realPort; int maxQp; + float latency; struct ncclIbMrCache mrCache; int ar; // ADAPTIVE_ROUTING struct ibv_port_attr portAttr; struct ncclIbStats stats; + int dmaBufSupported; }; -#define MAX_IB_DEVS 32 -struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_DEVS]; +#define MAX_IB_DEVS 32 +#define MAX_IB_VDEVS MAX_IB_DEVS*8 +struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_VDEVS]; struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; static int ncclIbRelaxedOrderingEnabled = 0; @@ -98,7 +100,7 @@ NCCL_PARAM(IbTc, "IB_TC", 0); NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192); NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); -NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", 0); +NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", -1); NCCL_PARAM(IbAsyncEvents,"IB_RETURN_ASYNC_EVENTS",1); NCCL_PARAM(IbEceEnable,"IB_ECE_ENABLE",1); @@ -226,17 +228,17 @@ static void* envIbAddrRange(sa_family_t af, int* mask) { *(maskStrPtr++) = '\0'; if (inet_pton(af, addrStrPtr, ret) == 0) { - WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); return NULL; } *mask = (int)strtol(maskStrPtr, NULL, 10); if (af == AF_INET && *mask > 32) { - WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); *mask = 0; ret = NULL; } else if (af == AF_INET6 && *mask > 128) { - WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); *mask = 0; ret = NULL; } @@ -317,7 +319,7 @@ static bool validGid(union ibv_gid* gid) { static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) { char gidRoceVerStr[16] = { 0 }; char roceTypePath[PATH_MAX] = { 0 }; - sprintf(roceTypePath, "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); + snprintf(roceTypePath, sizeof(roceTypePath), "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); int fd = open(roceTypePath, O_RDONLY); if (fd == -1) { @@ -426,6 +428,16 @@ NCCL_PARAM(IbDisable, "IB_DISABLE", 0); NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1); NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1); +// Returns 0 if this is the path of two VFs of the same physical device +static int ncclIbMatchVfPath(char* path1, char* path2) { + // Merge multi-port NICs into the same PCI device + if (ncclParamIbMergeVfs()) { + return strncmp(path1, path2, strlen(path1)-4) == 0; + } else { + return strncmp(path1, path2, strlen(path1)-1) == 0; + } +} + static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) { char devicePath[PATH_MAX]; snprintf(devicePath, PATH_MAX, "/sys/class/infiniband/%s/device", devName); @@ -433,14 +445,10 @@ static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) if (p == NULL) { WARN("Could not find real path of %s (%s)", devName, devicePath); } else { - // Merge multi-port NICs into the same PCI device - p[strlen(p)-1] = '0'; - // Also merge virtual functions (VF) into the same device - if (ncclParamIbMergeVfs()) p[strlen(p)-3] = p[strlen(p)-4] = '0'; - // And keep the real port aside (the ibv port is always 1 on recent cards) + // Keep the real port aside (the ibv port is always 1 on recent cards) *realPort = 0; for (int d=0; dndevs > 1) { + WARN("NET/IB : Trying to merge multiple devices together when NCCL_IB_MERGE_NICS=0. Please enable it or disable device merging in NCCL."); + return ncclInvalidUsage; + } + + if (props->ndevs == 0) { + WARN("NET/IB : Can't make virtual NIC with 0 devices"); + return ncclInvalidUsage; + } + + if (ncclNMergedIbDevs == MAX_IB_VDEVS) { + WARN("NET/IB : Cannot allocate any more virtual devices (%d)", MAX_IB_VDEVS); + return ncclInvalidUsage; + } + + // Always count up number of merged devices + ncclIbMergedDev* mDev = ncclIbMergedDevs + ncclNMergedIbDevs; + mDev->vProps.ndevs = 0; + mDev->speed = 0; + + for (int i = 0; i < props->ndevs; i++) { + ncclIbDev* dev = ncclIbDevs + props->devs[i]; + if (mDev->vProps.ndevs == NCCL_IB_MAX_DEVS_PER_NIC) return ncclInvalidUsage; + mDev->vProps.devs[mDev->vProps.ndevs++] = props->devs[i]; + mDev->speed += dev->speed; + // Each successive time, copy the name '+' new name + if (mDev->vProps.ndevs > 1) { + snprintf(mDev->devName + strlen(mDev->devName), sizeof(mDev->devName) - strlen(mDev->devName), "+%s", dev->devName); + // First time, copy the plain name + } else { + strncpy(mDev->devName, dev->devName, MAXNAMESIZE); } } - return ncclNMergedIbDevs; + // Check link layers + ncclIbDev* dev0 = ncclIbDevs + props->devs[0]; + for (int i = 1; i < props->ndevs; i++) { + if (props->devs[i] >= ncclNIbDevs) { + WARN("NET/IB : Cannot use physical device %d, max %d", props->devs[i], ncclNIbDevs); + return ncclInvalidUsage; + } + ncclIbDev* dev = ncclIbDevs + props->devs[i]; + if (dev->link != dev0->link) { + WARN("NET/IB : Trying to merge multiple devices together with different link_layer properties %s -> %d, %s -> %d. Try only selecting NICs with one type of link using NCCL_IB_HCA", + dev0->devName, dev0->link, dev->devName, dev->link); + return ncclInvalidUsage; + } + } + + *d = ncclNMergedIbDevs++; + INFO(NCCL_NET, "NET/IB : Made virtual device [%d] name=%s speed=%d ndevs=%d", *d, mDev->devName, mDev->speed, mDev->vProps.ndevs); + return ncclSuccess; +} + +ncclResult_t ncclIbMakeVDevice(int* d, ncclNetVDeviceProps_t* props) { + pthread_mutex_lock(&ncclIbLock); + ncclResult_t res = ncclIbMakeVDeviceInternal(d, props); + pthread_mutex_unlock(&ncclIbLock); + return res; } ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { @@ -534,10 +585,6 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; } - // Should NCCL merge multi-port devices into one? - int mergeNics = ncclParamIbMergeNics(); - -build_ib_list: for (int d=0; dndevs > 1) { - // Print out merged dev info - snprintf(line+strlen(line), 2047-strlen(line), " [%d]={", d); - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDev = mergedDev->devs[i]; - snprintf(line+strlen(line), 2047-strlen(line), "[%d] %s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, - ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", - // Insert comma to delineate - i == (mergedDev->ndevs - 1) ? "" : ", "); - } - snprintf(line+strlen(line), 2047-strlen(line), "}"); - } else { - int ibDev = mergedDev->devs[0]; - snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s", ibDev, ncclIbDevs[ibDev].devName, - ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); - } - } - line[2047] = '\0'; - char addrline[SOCKET_NAME_MAXLEN+1]; - INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", - ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); } + + // Print out all net devices to the user (in the same format as before) + char line[2048]; + line[0] = '\0'; + // Determine whether RELAXED_ORDERING is enabled and possible + ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); + for (int d = 0; d < ncclNIbDevs; d++) { + snprintf(line+strlen(line), sizeof(line)-strlen(line), " [%d]%s:%d/%s", d, ncclIbDevs[d].devName, + ncclIbDevs[d].portNum, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); + } + char addrline[SOCKET_NAME_MAXLEN+1]; + INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", + ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); + pthread_mutex_unlock(&ncclIbLock); } exit: @@ -783,30 +786,26 @@ ncclResult_t ncclIbGdrSupport() { static __thread int ibDmaSupportInitDev; // which device to init, must be thread local static void ibDmaBufSupportInitOnce(){ ncclResult_t res; - // select the appropriate - struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; - // Test each real devices int dev_fail = 0; - NCCLCHECKGOTO(rocmLibraryInit(), res, failure); - - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDev = mergedDev->devs[i]; - struct ibv_pd* pd; - struct ibv_context* ctx = ncclIbDevs[ibDev].context; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); - // Test kernel DMA-BUF support with a dummy call (fd=-1) - (void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); - // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) - dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); - // stop the search and goto failure - if (dev_fail) goto failure; - } - mergedDev->dmaBufSupported = 1; + // This is a physical device, not a virtual one, so select from ibDevs + ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = ncclIbDevs + mergedDev->vProps.devs[0]; + struct ibv_pd* pd; + struct ibv_context* ctx = ibDev->context; + rocmLibraryInit(); + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) + dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + // stop the search and goto failure + if (dev_fail) goto failure; + ibDev->dmaBufSupported = 1; return; failure: - mergedDev->dmaBufSupported = -1; + ibDev->dmaBufSupported = -1; return; } // Detect whether DMA-BUF support is present in the kernel @@ -821,21 +820,20 @@ ncclResult_t ncclIbDmaBufSupport(int dev) { // init the device only once ibDmaSupportInitDev = dev; pthread_once(&onces[dev].once, ibDmaBufSupportInitOnce); - - int dmaBufSupported = ncclIbMergedDevs[dev].dmaBufSupported; + ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = ncclIbDevs + mergedDev->vProps.devs[0]; + int dmaBufSupported = ibDev->dmaBufSupported; if (dmaBufSupported == 1) return ncclSuccess; return ncclSystemError; } #define NCCL_NET_IB_MAX_RECVS 8 -ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { - struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs+dev; - props->name = mergedDev->devName; - props->speed = mergedDev->speed; - - // Take the rest of the properties from an arbitrary sub-device (should be the same) - struct ncclIbDev* ibDev = ncclIbDevs + mergedDev->devs[0]; +ncclResult_t ncclIbGetPhysProperties(int dev, ncclNetProperties_t* props) { + struct ncclIbDev* ibDev = ncclIbDevs + dev; + pthread_mutex_lock(&ibDev->lock); + props->name = ibDev->devName; + props->speed = ibDev->speed; props->pciPath = ibDev->pciPath; props->guid = ibDev->guid; props->ptrSupport = NCCL_PTR_HOST; @@ -846,12 +844,29 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { if (ncclIbDmaBufSupport(dev) == ncclSuccess) { props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF } + props->forceFlush = 0; props->latency = 0; // Not set props->port = ibDev->portNum + ibDev->realPort; props->maxComms = ibDev->maxQp; props->maxRecvs = NCCL_NET_IB_MAX_RECVS; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; + pthread_mutex_unlock(&ibDev->lock); + return ncclSuccess; +} + +ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Requested properties for vNic %d, only %d vNics have been created", dev, ncclNMergedIbDevs); + return ncclInvalidUsage; + } + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + dev; + // Take the rest of the properties from an arbitrary sub-device (should be the same) + NCCLCHECK(ncclIbGetPhysProperties(mergedDev->vProps.devs[0], props)); + props->name = mergedDev->devName; + props->speed = mergedDev->speed; + memcpy(&props->vProps, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); return ncclSuccess; } @@ -908,6 +923,8 @@ enum ncclIbCommState { ncclIbCommStateConnecting = 6, ncclIbCommStateConnected = 7, ncclIbCommStatePendingReady = 8, + ncclIbCommStateSendDevList = 9, + ncclIbCommStateRecvDevList = 10, }; struct ncclIbCommStage { @@ -972,12 +989,12 @@ struct ncclIbListenComm { struct alignas(64) ncclIbSendFifo { uint64_t addr; - int size; + uint64_t size; uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; uint32_t nreqs; uint32_t tag; uint64_t idx; - char padding[24]; + char padding[16]; }; struct ncclIbQp { @@ -1009,7 +1026,7 @@ struct ncclIbMrHandle { }; struct alignas(32) ncclIbNetCommBase { - int ndevs; + ncclNetVDeviceProps_t vProps; bool isSend; struct ncclIbRequest reqs[MAX_REQUESTS]; struct ncclIbQp qps[NCCL_IB_MAX_QPS]; @@ -1020,6 +1037,7 @@ struct alignas(32) ncclIbNetCommBase { int ready; // Track necessary remDevInfo here int nRemDevs; + int nDataQps; struct ncclIbDevInfo remDevs[NCCL_IB_MAX_DEVS_PER_NIC]; // statistics about the comm struct ncclIbStats stats; @@ -1065,7 +1083,6 @@ struct ncclIbRemFifo { struct alignas(16) ncclIbRecvCommDev { struct ncclIbNetCommDevBase base; struct ncclIbGpuFlush gpuFlush; - uint32_t fifoRkey; struct ibv_mr* fifoMr; struct ibv_sge fifoSge; struct ibv_mr* sizesFifoMr; @@ -1073,7 +1090,7 @@ struct alignas(16) ncclIbRecvCommDev { struct ncclIbRecvComm { struct ncclIbNetCommBase base; - struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRemFifo remFifo; int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; int gpuFlushHostMem; @@ -1145,10 +1162,12 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, qpAttr.port_num = ib_port; qpAttr.qp_access_flags = access_flags; NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p", + ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd); return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool override_tc) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool fifoTc) { struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_RTR; @@ -1164,11 +1183,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint qpAttr.ah_attr.grh.flow_label = 0; qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; qpAttr.ah_attr.grh.hop_limit = 255; - if(ncclParamIbFifoTc() && override_tc) { - qpAttr.ah_attr.grh.traffic_class = ncclParamIbFifoTc(); - } else { - qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc(); - } + qpAttr.ah_attr.grh.traffic_class = fifoTc && ncclParamIbFifoTc() != -1 ? ncclParamIbFifoTc() : ncclParamIbTc(); } else { //pick lid if subnet prefixs are same, FLID if they are not if (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) == @@ -1193,6 +1208,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint qpAttr.ah_attr.sl = ncclParamIbSl(); qpAttr.ah_attr.src_path_bits = 0; qpAttr.ah_attr.port_num = info->ib_port; + TRACE(NCCL_NET, "NET/IB : ncclIbRtrQp qpn=%u mtu=%d dst=%u ll=%u port=%u", qp->qp_num, info->mtu, dest_qp_num, info->link_layer, info->ib_port); NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)); return ncclSuccess; } @@ -1239,10 +1255,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet int ready; *sendComm = NULL; - if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; - if (stage->state == ncclIbCommStateSend) goto ib_send; - if (stage->state == ncclIbCommStateConnecting) goto ib_connect; - if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnecting) goto ib_connect; + if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; if (stage->state != ncclIbCommStateStart) { WARN("Error: trying to connect already connected sendComm"); return ncclInternalError; @@ -1263,21 +1281,51 @@ ib_connect_check: // IB Setup struct ncclIbMergedDev* mergedDev; + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existant virtual device %d", dev); + return ncclInternalError; + } + mergedDev = ncclIbMergedDevs + dev; - comm->base.ndevs = mergedDev->ndevs; - comm->base.nqps = ncclParamIbQpsPerConn() * comm->base.ndevs; // We must have at least 1 qp per-device + comm->base.vProps = mergedDev->vProps; comm->base.isSend = true; + stage->state = ncclIbCommStateSendDevList; + stage->offset = 0; + struct ncclIbConnectionMetadata meta; + NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(meta)), ret, fail); + memcpy(stage->buffer, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_send_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->state = ncclIbCommStateRecvDevList; + stage->offset = 0; + +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + stage->offset = 0; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + mergedDev = ncclIbMergedDevs + dev; + comm->base.vProps = mergedDev->vProps; + int localNqps, remoteNqps; + localNqps = ncclParamIbQpsPerConn() * comm->base.vProps.ndevs; // We must have at least 1 qp per-device + remoteNqps = ncclParamIbQpsPerConn() * remoteVProps.ndevs; + comm->base.nqps = remoteNqps > localNqps ? remoteNqps : localNqps; // Select max nqps (local or remote) // Init PD, Ctx for each IB device comm->ar = 1; // Set to 1 for logic - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDevN = mergedDev->devs[i]; + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + int ibDevN = comm->base.vProps.devs[i]; NCCLCHECKGOTO(ncclIbInitCommDevBase(ibDevN, &comm->devs[i].base, &comm->base.stats), ret, fail); - comm->ar = comm->ar && ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled + comm->ar = comm->ar && ncclIbDevs[ibDevN].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled } - struct ncclIbConnectionMetadata meta; - meta.ndevs = comm->base.ndevs; + memset(&meta, 0, sizeof(meta)); + meta.ndevs = comm->base.vProps.ndevs; // Alternate QPs between devices int devIndex; @@ -1296,10 +1344,10 @@ ib_connect_check: } else { meta.qpInfo[q].ece_supported = 0; } - devIndex = (devIndex + 1) % comm->base.ndevs; + devIndex = (devIndex + 1) % comm->base.vProps.ndevs; } - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { ncclIbSendCommDev* commDev = comm->devs + i; ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; @@ -1326,7 +1374,7 @@ ib_connect_check: // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x", - comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey); } @@ -1335,7 +1383,7 @@ ib_connect_check: // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", - comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex, devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey); } @@ -1346,7 +1394,6 @@ ib_connect_check: stage->state = ncclIbCommStateSend; stage->offset = 0; - NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(meta)), ret, fail); memcpy(stage->buffer, &meta, sizeof(meta)); @@ -1367,17 +1414,12 @@ ib_connect: memcpy(&remMeta, stage->buffer, sizeof(ncclIbConnectionMetadata)); comm->base.nRemDevs = remMeta.ndevs; - if (comm->base.nRemDevs != comm->base.ndevs) { - mergedDev = ncclIbMergedDevs + dev; - WARN("NET/IB : Local mergedDev=%s has a different number of devices=%d as remoteDev=%s nRemDevs=%d", - mergedDev->devName, comm->base.ndevs, remMeta.devName, comm->base.nRemDevs); - } int link_layer; link_layer = remMeta.devs[0].link_layer; for (int i = 1; i < remMeta.ndevs; i++) { if (remMeta.devs[i].link_layer != link_layer) { - WARN("NET/IB : Can't merge net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", + WARN("NET/IB : Can't connect net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", i, remMeta.ndevs, link_layer, remMeta.devs[i].link_layer); return ncclInternalError; } @@ -1393,7 +1435,7 @@ ib_connect: comm->remSizesFifo.addr = remMeta.fifoAddr; } - for (int i=0; i < comm->base.ndevs; i++) { + for (int i=0; i < comm->base.vProps.ndevs; i++) { NCCLCHECKGOTO(wrap_ibv_reg_mr(comm->remSizesFifo.mrs+i, comm->devs[i].base.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); } comm->base.nRemDevs = remMeta.ndevs; @@ -1411,6 +1453,8 @@ ib_connect: if (remQpInfo->ece_supported) NCCLCHECKGOTO(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported), ret, fail); + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + remDevInfo->mtu = std::min(remDevInfo->mtu, ibDev->portAttr.active_mtu); NCCLCHECKGOTO(ncclIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false), ret, fail); NCCLCHECKGOTO(ncclIbRtsQp(qp), ret, fail); } @@ -1425,6 +1469,8 @@ ib_connect: } } + comm->base.nDataQps = std::max(comm->base.vProps.ndevs, comm->base.nRemDevs); + comm->base.ready = 1; stage->state = ncclIbCommStateConnected; stage->offset = 0; @@ -1443,6 +1489,50 @@ fail: goto exit; } +NCCL_PARAM(IbWarnRailLocal, "IB_WARN_RAIL_LOCAL", 0); + +ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDeviceProps_t* vProps2) { + ncclNetVDeviceProps_t outVProps = {0}; + ncclNetVDeviceProps_t* minVProps = vProps2; + ncclNetVDeviceProps_t* maxVProps = vProps1; + if (vProps2->ndevs > vProps1->ndevs) { + minVProps = vProps1; + maxVProps = vProps2; + } + + // Find the intersection of devices + for (int i = 0; i < minVProps->ndevs; i++) { + int dev = minVProps->devs[i]; + for (int j = 0; j < maxVProps->ndevs; j++) { + // Found + if (maxVProps->devs[j] == dev) { + outVProps.devs[outVProps.ndevs++] = dev; + } + } + } + + // In the case that at least one side has a fused NIC but there are no matching physical NICs, we should check if the user wants this + if (ncclParamIbWarnRailLocal() && outVProps.ndevs < maxVProps->ndevs) { + char local[128]; + int cursor = 1; + snprintf(local, sizeof(local), "%d", vProps1->devs[0]); + for (int i = 1; i < vProps1->ndevs; i++) { + snprintf(local+cursor, sizeof(local)-cursor, ",%d", vProps1->devs[i]); + cursor += 2; + } + char remote[128]; + snprintf(remote, sizeof(remote), "%d", vProps2->devs[0]); + cursor = 1; + for (int i = 1; i < vProps2->ndevs; i++) { + snprintf(remote+cursor, sizeof(remote)-cursor, ",%d", vProps2->devs[i]); + cursor += 2; + } + INFO(NCCL_NET, "NET/IB : There are mismatched physical devices between local (%s) and remote (%s). To disable this warning, set NCCL_IB_WARN_RAIL_LOCAL=0", local, remote); + } + + return ncclSuccess; +} + NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); RCCL_PARAM(IbGdrFlushGpuMemNoRelaxedOrdering, "GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING", 1); @@ -1454,7 +1544,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle int ready; *recvComm = NULL; - if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; if (stage->state == ncclIbCommStatePendingReady) goto ib_recv_ready; @@ -1470,14 +1562,49 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle NCCLCHECKGOTO(ncclSocketInit(&rComm->base.sock), ret, fail); NCCLCHECKGOTO(ncclSocketAccept(&rComm->base.sock, &lComm->sock), ret, fail); + // Alloc stage->buffer here to be used for all following steps + struct ncclIbConnectionMetadata remMeta; + stage->offset = 0; + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta))); + ib_accept_check: NCCLCHECKGOTO(ncclSocketReady(&rComm->base.sock, &ready), ret, fail); if (!ready) return ncclSuccess; - - struct ncclIbConnectionMetadata remMeta; - stage->state = ncclIbCommStateRecv; + stage->state = ncclIbCommStateRecvDevList; stage->offset = 0; - NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta)), ret, fail); + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + if (lComm->dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existant virtual device %d", lComm->dev); + return ncclInternalError; + } + + // Reduce the physical device list and store in the connection base + struct ncclIbMergedDev* mergedDev; + mergedDev = ncclIbMergedDevs + lComm->dev; + NCCLCHECK(ncclIbCheckVProps(&mergedDev->vProps, &remoteVProps)); + rComm->base.vProps = mergedDev->vProps; + memcpy(stage->buffer, &rComm->base.vProps, sizeof(ncclNetVDeviceProps_t)); + rComm->base.isSend = false; + int localNqps, remoteNqps; + localNqps = ncclParamIbQpsPerConn() * rComm->base.vProps.ndevs; // We must have at least 1 qp per-device + remoteNqps = ncclParamIbQpsPerConn() * remoteVProps.ndevs; + rComm->base.nqps = remoteNqps > localNqps ? remoteNqps : localNqps; // Select max nqps (local or remote) + + stage->offset = 0; + stage->state = ncclIbCommStateSendDevList; + +ib_send_dev_list: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset), ret, fail); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->offset = 0; + stage->state = ncclIbCommStateRecv; ib_recv: NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset), ret, fail); @@ -1488,7 +1615,6 @@ ib_recv: // IB setup // Pre-declare variables because of goto - struct ncclIbMergedDev* mergedDev; struct ncclIbDev* ibDev; int ibDevN; struct ncclIbRecvCommDev* rCommDev; @@ -1496,21 +1622,18 @@ ib_recv: struct ncclIbQp* qp; mergedDev = ncclIbMergedDevs + lComm->dev; - rComm->base.ndevs = mergedDev->ndevs; - rComm->base.nqps = ncclParamIbQpsPerConn() * rComm->base.ndevs; // We must have at least 1 qp per-device - rComm->base.isSend = false; - rComm->base.nRemDevs = remMeta.ndevs; - if (rComm->base.nRemDevs != rComm->base.ndevs) { - WARN("NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", - mergedDev->devName, rComm->base.ndevs, remMeta.devName, rComm->base.nRemDevs); + if (rComm->base.nRemDevs != rComm->base.vProps.ndevs) { + INFO(NCCL_NET, "NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", + mergedDev->devName, rComm->base.vProps.ndevs, remMeta.devName, rComm->base.nRemDevs); } // Metadata to send back to requestor (sender) struct ncclIbConnectionMetadata meta; - for (int i = 0; i < rComm->base.ndevs; i++) { + memset(&meta, 0, sizeof(meta)); + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { rCommDev = rComm->devs + i; - ibDevN = mergedDev->devs[i]; + ibDevN = rComm->base.vProps.devs[i]; NCCLCHECKGOTO(ncclIbInitCommDevBase(ibDevN, &rCommDev->base, &rComm->base.stats), ret, fail); ibDev = ncclIbDevs + ibDevN; NCCLCHECKGOTO(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex), ret, fail); @@ -1541,7 +1664,7 @@ ib_recv: ibDev = ncclIbDevs + ibDevN; NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp), ret, fail); qp->devIndex = devIndex; - devIndex = (devIndex + 1) % rComm->base.ndevs; + devIndex = (devIndex + 1) % rComm->base.vProps.ndevs; // Set the ece (enhanced connection establishment) on this QP before RTR if (remMeta.qpInfo[q].ece_supported) { @@ -1554,23 +1677,22 @@ ib_recv: // Store this in our own qpInfo for returning to the requestor if (meta.qpInfo[q].ece_supported) NCCLCHECKGOTO(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported), ret, fail); + } else { + meta.qpInfo[q].ece_supported = 0; } - bool override_tc = (q == 0) ? true : false; - NCCLCHECKGOTO(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, override_tc), ret, fail); + NCCLCHECKGOTO(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, true), ret, fail); NCCLCHECKGOTO(ncclIbRtsQp(qp->qp), ret, fail); } rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || ncclIbDmaBufSupport(lComm->dev) == ncclSuccess) && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; - for (int i = 0; i < mergedDev->ndevs; i++) { + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { rCommDev = rComm->devs + i; - ibDevN = rCommDev->base.ibDevN; - ibDev = ncclIbDevs + ibDevN; + ibDev = ncclIbDevs + rCommDev->base.ibDevN; // Retain remote fifo info and prepare my RDMA ops - rCommDev->fifoRkey = remMeta.devs[i].fifoRkey; rComm->remFifo.addr = remMeta.fifoAddr; NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; @@ -1606,17 +1728,14 @@ ib_recv: } // Fill Handle - meta.devs[i].lid = ibDev->portAttr.lid; - meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; - meta.devs[i].ib_port = ibDev->portNum; + meta.devs[i].lid = ibDev->portAttr.lid; + meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + meta.devs[i].ib_port = ibDev->portNum; meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; + meta.devs[i].mtu = ibDev->portAttr.active_mtu; meta.devs[i].ibv_dev_index = rCommDev->base.ibDevN; - // Adjust the MTU - remMeta.devs[i].mtu = (enum ibv_mtu) std::min(remMeta.devs[i].mtu, ibDev->portAttr.active_mtu); - meta.devs[i].mtu = remMeta.devs[i].mtu; - // Prepare sizes fifo NCCLCHECKGOTO(wrap_ibv_reg_mr(&rComm->devs[i].sizesFifoMr, rComm->devs[i].base.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); meta.devs[i].fifoRkey = rComm->devs[i].sizesFifoMr->rkey; @@ -1627,9 +1746,9 @@ ib_recv: meta.qpInfo[q].qpn = rComm->base.qps[q].qp->qp_num; meta.qpInfo[q].devIndex = rComm->base.qps[q].devIndex; } - - meta.ndevs = rComm->base.ndevs; + meta.ndevs = rComm->base.vProps.ndevs; strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); + rComm->base.nDataQps = std::max(rComm->base.vProps.ndevs, rComm->base.nRemDevs); stage->state = ncclIbCommStateSend; stage->offset = 0; @@ -1759,7 +1878,7 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui assert(size > 0); struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) malloc(sizeof(struct ncclIbMrHandle)); - for (int i = 0; i < base->ndevs; i++) { + for (int i = 0; i < base->vProps.ndevs; i++) { // Each ncclIbNetCommDevBase is at different offset in send and recv netComms struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); NCCLCHECKGOTO(ncclIbRegMrDmaBufInternal(devComm, data, size, type, offset, fd, mhandleWrapper->mrs + i), ret, fail); @@ -1803,9 +1922,11 @@ returning: } ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { + if (mhandle == NULL) return ncclSuccess; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; - for (int i = 0; i < base->ndevs; i++) { + for (int i = 0; i < base->vProps.ndevs; i++) { // Each ncclIbNetCommDevBase is at different offset in send and recv netComms struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); NCCLCHECK(ncclIbDeregMrInternal(devComm, mhandleWrapper->mrs[i])); @@ -1870,7 +1991,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work const int align = 128; - int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; for (int i = 0; i < nqps; i++) { int qpIndex = comm->base.qpIndex; ncclIbQp* qp = comm->base.qps + qpIndex; @@ -1924,7 +2045,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { return ncclSuccess; } -ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { +ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->base.ready == 0"); return ncclInternalError; } if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } @@ -1954,7 +2075,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; ncclSocketGetAddr(&comm->base.sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkeys[0]=%x", + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); return ncclInternalError; } @@ -1970,7 +2091,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh req->send.offset = 0; // Populate events - int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; int qpIndex = comm->base.qpIndex; // Count down while (nEvents > 0) { @@ -1985,7 +2106,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh } // Store all lkeys - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { req->send.lkeys[i] = mhandleWrapper->mrs[i]->lkey; } @@ -2011,7 +2132,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh return ncclSuccess; } -ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { +ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); @@ -2023,14 +2144,14 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int // Select the next devIndex (local) and QP to use for posting this CTS message // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; - comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.ndevs; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; for (int i=0; ibase.ndevs; j++) + for (int j = 0; j < comm->base.vProps.ndevs; j++) localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; localElem[i].nreqs = n; @@ -2093,7 +2214,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int return ncclSuccess; } -ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { +ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); return ncclInternalError; } if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } @@ -2106,7 +2227,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta req->sock = &comm->base.sock; req->nreqs = n; - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { req->devBases[i] = &comm->devs[i].base; } @@ -2118,7 +2239,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta TIME_START(1); // Select either all QPs, or one qp per-device - const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; // Post recvs struct ibv_recv_wr* bad_wr; @@ -2154,7 +2275,7 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** struct ncclIbMrHandle* mhandle = (struct ncclIbMrHandle*) mhandles[last]; // We don't know which devIndex the recv was on, so we flush on all devices - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); wr.wr_id = req - comm->base.reqs; @@ -2201,7 +2322,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { *done = 0; while (1) { NCCLCHECK(ncclIbStatsCheckFatalCount(&r->base->stats,__func__)); - if (r->events[0] == 0 && r->events[1] == 0) { + if (r->events[0] == 0 && r->events[1] == 0 && r->events[2] == 0 && r->events[3] == 0) { TRACE(NCCL_NET, "r=%p done", r); *done = 1; if (sizes && r->type == NCCL_NET_IB_REQ_RECV) { @@ -2235,13 +2356,13 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { char remoteGidString[INET6_ADDRSTRLEN] = ""; const char* localGidStr = NULL, *remoteGidStr = NULL; if (r->devBases[i]->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { - localGidStr = inet_ntop(AF_INET6, &r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); - remoteGidStr = inet_ntop(AF_INET6, &r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); + localGidStr = ibvGetGidStr(&r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); + remoteGidStr = ibvGetGidStr(&r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); } char line[SOCKET_NAME_MAXLEN+1]; char *hcaName = r->devBases[i]->pd->context->device->name; - WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s hca %s", + WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%u vendor err %u (%s)%s%s%s%s hca %s", ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName); return ncclRemoteError; @@ -2253,7 +2374,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; - TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%ld r=%p type=%d events={%d,%d}, i=%d", + TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%u wr_id=%lu r=%p type=%d events={%d,%d}, i=%d", ncclSocketToString(&addr, line), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i); #endif if (req && req->type == NCCL_NET_IB_REQ_SEND) { @@ -2297,7 +2418,7 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { for (int q = 0; q < comm->base.nqps; q++) if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ncclIbSendCommDev* commDev = comm->devs + i; if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); if (comm->remSizesFifo.mrs[i] != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mrs[i])); @@ -2317,7 +2438,7 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) { for (int q = 0; q < comm->base.nqps; q++) if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ncclIbRecvCommDev* commDev = comm->devs + i; if (comm->flushEnabled) { if (commDev->gpuFlush.gpuFlushGpuMem != nullptr) { @@ -2366,5 +2487,11 @@ ncclNet_t ncclNetIb = { ncclIbCloseRecv, ncclIbCloseListen, NULL /* getDeviceMr */, - NULL /* irecvConsumed */ -}; \ No newline at end of file + NULL /* irecvConsumed */, + ncclIbMakeVDevice +}; + +/* + ncclIbSetProperties, + ncclIbRefreshDevices +*/ diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 73a5d55b00..235dee865a 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -44,6 +44,7 @@ ncclResult_t ncclNetSocketInit(ncclDebugLogger_t logFunction) { ncclNetIfs = ncclFindInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS); if (ncclNetIfs <= 0) { WARN("NET/Socket : no interface found"); + pthread_mutex_unlock(&ncclNetSocketLock); return ncclInternalError; } else { #define MAX_LINE_LEN (2047) @@ -76,7 +77,7 @@ static ncclResult_t ncclNetSocketGetSpeed(char* devName, int* speed) { ncclResult_t ret = ncclSuccess; *speed = 0; char speedPath[PATH_MAX]; - sprintf(speedPath, "/sys/class/net/%s/speed", devName); + snprintf(speedPath, sizeof(speedPath), "/sys/class/net/%s/speed", devName); int fd = -1; SYSCHECKSYNC(open(speedPath, O_RDONLY), "open", fd); if (fd != -1) { @@ -102,6 +103,7 @@ ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { props->guid = dev; props->ptrSupport = NCCL_PTR_HOST; props->regIsGlobal = 0; + props->forceFlush = 0; NCCLCHECK(ncclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; @@ -109,6 +111,7 @@ ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { props->maxRecvs = 1; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; return ncclSuccess; } @@ -297,6 +300,7 @@ fail: ncclResult_t ncclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) { if (dev < 0 || dev >= ncclNetIfs) { // data transfer socket is based on specified dev + WARN("NET/Socket : ncclNetSocketListen dev=%d ncclNetIfs=%d", dev, ncclNetIfs); return ncclInternalError; } ncclResult_t ret = ncclSuccess; @@ -558,16 +562,16 @@ ncclResult_t ncclNetSocketRegMr(void* comm, void* data, size_t size, int type, v } ncclResult_t ncclNetSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } -ncclResult_t ncclNetSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { +ncclResult_t ncclNetSocketIsend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)sendComm; - NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, (struct ncclNetSocketRequest**)request)); + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_SEND, data, (int) size, (struct ncclNetSocketRequest**)request)); return ncclSuccess; } -ncclResult_t ncclNetSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { +ncclResult_t ncclNetSocketIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { struct ncclNetSocketComm* comm = (struct ncclNetSocketComm*)recvComm; if (n != 1) return ncclInternalError; - NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], (struct ncclNetSocketRequest**)request)); + NCCLCHECK(ncclNetSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], (int)sizes[0], (struct ncclNetSocketRequest**)request)); return ncclSuccess; } @@ -632,5 +636,6 @@ ncclNet_t ncclNetSocket = { ncclNetSocketClose, ncclNetSocketCloseListen, NULL /* getDeviceMr */, - NULL /* irecvConsumed */ + NULL /* irecvConsumed */, + NULL /* mergeDevices */ }; diff --git a/src/transport/nvls.cc b/src/transport/nvls.cc index aa9c486b14..582c30a353 100644 --- a/src/transport/nvls.cc +++ b/src/transport/nvls.cc @@ -108,11 +108,12 @@ ncclResult_t nvlsGroupUnbind(struct ncclComm *comm, size_t size, CUmemGenericAll return ncclSuccess; } -ncclResult_t ncclNvlsDeregBuffer(CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size) { +ncclResult_t ncclNvlsDeregBuffer(struct ncclComm* comm, CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size) { CUCHECK(cuMulticastUnbind(*mcHandler, dev, 0/*mcOffset*/, size)); CUCHECK(cuMemUnmap(ptr, size)); CUCHECK(cuMemAddressFree(ptr, size)); CUCHECK(cuMemRelease(*mcHandler)); + INFO(NCCL_NVLS, "rank %d - NVLS deregistered buffer %p on device %d, size %ld", comm->rank, (void*)ptr, dev, size); return ncclSuccess; } @@ -450,11 +451,11 @@ setup: if (comm->localRank == 0) { shmPath[0] = '\0'; - NCCLCHECKGOTO(ncclShmOpen(shmPath, (sizeof(size_t) + typeSize * comm->localRanks) * 2, (void**)&nvlsShmem, NULL, comm->localRanks - 1, &comm->nvlsResources->nvlsShmemHandle), res, fail); + NCCLCHECKGOTO(ncclShmOpen(shmPath, sizeof(shmPath), (sizeof(size_t) + typeSize * comm->localRanks) * 2, (void**)&nvlsShmem, NULL, comm->localRanks - 1, &comm->nvlsResources->nvlsShmemHandle), res, fail); NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shmPath, sizeof(shmPath)), res, fail); } else { NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shmPath, sizeof(shmPath)), res, fail); - NCCLCHECKGOTO(ncclShmOpen(shmPath, (sizeof(size_t) + typeSize * comm->localRanks) * 2, (void**)&nvlsShmem, NULL, -1, &comm->nvlsResources->nvlsShmemHandle), res, fail); + NCCLCHECKGOTO(ncclShmOpen(shmPath, sizeof(shmPath), (sizeof(size_t) + typeSize * comm->localRanks) * 2, (void**)&nvlsShmem, NULL, -1, &comm->nvlsResources->nvlsShmemHandle), res, fail); } /* need 2 pools and a shared counter for shmem-based collectives */ comm->nvlsResources->nvlsShmem.cnt[0] = (size_t*)nvlsShmem; @@ -495,7 +496,7 @@ ncclResult_t ncclNvlsFree(struct ncclComm* comm) { return ncclSuccess; } -ncclResult_t tryRegisterBuffer(struct ncclComm *comm, uintptr_t userBuff, size_t buffSize, CUdeviceptr *regAddr, bool *regUsed) { +ncclResult_t tryRegisterBuffer(struct ncclComm *comm, uintptr_t userBuff, size_t buffSize, CUdeviceptr *regAddr, int *regUsed) { ncclResult_t ret = ncclSuccess; struct ncclReg *regRecord = NULL; CUdeviceptr regPtr = 0; @@ -601,43 +602,33 @@ ncclResult_t tryRegisterBuffer(struct ncclComm *comm, uintptr_t userBuff, size_t } *regAddr = (uintptr_t)regPtr + regData[comm->localRank].offset; - *regUsed = true; + *regUsed = 1; exit: free(regData); return ret; fail: - *regUsed = false; + *regUsed = 0; goto exit; } -ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv) { +static ncclResult_t nvlsRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, struct ncclReg *sendRegRecord, struct ncclReg *recvRegRecord, int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv) { ncclResult_t ret = ncclSuccess; - bool localRegBufUsed = false; + int regBufUsed = 0; struct localRegData *regData = NULL; bool sendNeedReg = false, recvNeedReg = false; CUdeviceptr regSendPtr = 0; CUdeviceptr regRecvPtr = 0; - struct ncclReg *sendRegRecord = NULL; - struct ncclReg *recvRegRecord = NULL; - - *outRegBufUsed = false; NCCLCHECKGOTO(ncclCalloc(®Data, comm->localRanks * 2), ret, fail); - if (sendbuff) { - NCCLCHECKGOTO(ncclRegFind(comm, sendbuff, sendbuffSize, &sendRegRecord), ret, fail); - if (sendRegRecord) { - memcpy(®Data[comm->localRank * 2].reg, sendRegRecord, sizeof(struct ncclReg)); - regData[comm->localRank * 2].offset = (uintptr_t)sendbuff - sendRegRecord->addr; - } + if (sendRegRecord) { + memcpy(®Data[comm->localRank * 2].reg, sendRegRecord, sizeof(struct ncclReg)); + regData[comm->localRank * 2].offset = (uintptr_t)sendbuff - sendRegRecord->addr; } - if (recvbuff) { - NCCLCHECKGOTO(ncclRegFind(comm, recvbuff, recvbuffSize, &recvRegRecord), ret, fail); - if (recvRegRecord) { - memcpy(®Data[comm->localRank * 2 + 1].reg, recvRegRecord, sizeof(struct ncclReg)); - regData[comm->localRank * 2 + 1].offset = (uintptr_t)recvbuff - recvRegRecord->addr; - } + if (recvRegRecord) { + memcpy(®Data[comm->localRank * 2 + 1].reg, recvRegRecord, sizeof(struct ncclReg)); + regData[comm->localRank * 2 + 1].offset = (uintptr_t)recvbuff - recvRegRecord->addr; } NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regData + comm->localRank * 2, regData, sizeof(struct localRegData) * 2), ret, fail); @@ -682,229 +673,127 @@ ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *send } if ((!sendNeedReg || sendbuff == NULL) && (!recvNeedReg || recvbuff == NULL)) { - localRegBufUsed = true; - INFO(NCCL_NVLS, "rank %d reuse local-registered NVLS sendbuff %p, recvbuff %p, sendbuff size %ld, recvbuff size %ld, reg sendbuff %p, reg recvbuff %p", comm->rank, sendbuff, recvbuff, sendbuffSize, recvbuffSize, (void*)regSendPtr, (void*)regRecvPtr); + regBufUsed = 1; + INFO(NCCL_REG, "rank %d reuse registered NVLS sendbuff %p, recvbuff %p, sendbuff size %ld, recvbuff size %ld, reg sendbuff %p, reg recvbuff %p", comm->rank, sendbuff, recvbuff, sendbuffSize, recvbuffSize, (void*)regSendPtr, (void*)regRecvPtr); goto exit; } /* Start Registration. Not found registered buffers, then check whether both send and recv buffer locate * in register request cache. */ - if (sendNeedReg && sendbuff) { - tryRegisterBuffer(comm, (uintptr_t)sendbuff, sendbuffSize, ®SendPtr, &localRegBufUsed); - if (localRegBufUsed == false) goto fail; + if (sendNeedReg && sendbuff && sendbuffSize > 0) { + tryRegisterBuffer(comm, (uintptr_t)sendbuff, sendbuffSize, ®SendPtr, ®BufUsed); + if (regBufUsed == 0) goto fail; } - if (recvNeedReg && recvbuff) { - tryRegisterBuffer(comm, (uintptr_t)recvbuff, recvbuffSize, ®RecvPtr, &localRegBufUsed); - if (localRegBufUsed == false) goto fail; + if (recvNeedReg && recvbuff && recvbuffSize > 0) { + tryRegisterBuffer(comm, (uintptr_t)recvbuff, recvbuffSize, ®RecvPtr, ®BufUsed); + if (regBufUsed == 0) goto fail; } - INFO(NCCL_NVLS, "rank %d successfully local-registered NVLS sendbuff %p, recvbuff %p, sendbuff size %ld, recvbuff size %ld, reg sendbuff %p, reg recvbuff %p", comm->rank, sendbuff, recvbuff, sendbuffSize, recvbuffSize, (void*)regSendPtr, (void*)regRecvPtr); + INFO(NCCL_REG, "rank %d successfully registered NVLS sendbuff %p, recvbuff %p, sendbuff size %ld, recvbuff size %ld, reg sendbuff %p, reg recvbuff %p", comm->rank, sendbuff, recvbuff, sendbuffSize, recvbuffSize, (void*)regSendPtr, (void*)regRecvPtr); exit: *outRegBufSend = (void*)regSendPtr; *outRegBufRecv = (void*)regRecvPtr; - *outRegBufUsed = localRegBufUsed; + *outRegBufUsed = regBufUsed; free(regData); return ncclSuccess; fail: - localRegBufUsed = false; + regBufUsed = 0; + WARN("rank %d failed to NVLS register sendbuff %p sendbuffSize %ld recvbuff %p recvbuffSize %ld", comm->rank, sendbuff, sendbuffSize, recvbuff, recvbuffSize); goto exit; } +ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv) { + struct ncclReg *sendRegRecord = NULL; + struct ncclReg *recvRegRecord = NULL; + bool sendIsValid = false; + bool recvIsValid = false; + + *outRegBufUsed = 0; + if (sendbuff) { + NCCLCHECK(ncclRegFind(comm, sendbuff, sendbuffSize, &sendRegRecord)); + NCCLCHECK(ncclRegLocalIsValid(sendRegRecord, &sendIsValid)); + } else { + sendIsValid = true; + } + if (recvbuff) { + NCCLCHECK(ncclRegFind(comm, recvbuff, recvbuffSize, &recvRegRecord)); + NCCLCHECK(ncclRegLocalIsValid(recvRegRecord, &recvIsValid)); + } else { + recvIsValid = true; + } + + if (sendIsValid && recvIsValid) + NCCLCHECK(nvlsRegisterBuffer(comm, sendbuff, recvbuff, sendbuffSize, recvbuffSize, sendRegRecord, recvRegRecord, outRegBufUsed, outRegBufSend, outRegBufRecv)); + + return ncclSuccess; +} + struct ncclNvlsCleanupCallback { struct ncclCommCallback base; - CUmemGenericAllocationHandle mcHandle; - CUdeviceptr ptr; - int dev; - size_t size; + struct ncclReg *reg; + struct ncclComm *comm; }; static ncclResult_t cleanupNvls(struct ncclComm* comm, struct ncclCommCallback* cb) { struct ncclNvlsCleanupCallback* obj = (struct ncclNvlsCleanupCallback*)cb; - NCCLCHECK(ncclNvlsDeregBuffer(&obj->mcHandle, obj->ptr, obj->dev, obj->size)); - INFO(NCCL_NVLS, "rank %d - deregistered buffer %p on device %d, size %ld", comm->rank, (void*)obj->ptr, obj->dev, obj->size); + NCCLCHECK(ncclCommGraphDeregister(obj->comm, obj->reg)); free(obj); return ncclSuccess; } ncclResult_t ncclNvlsGraphRegisterBuffer( struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, - bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, + int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueEltsAdded ) { - ncclResult_t ret = ncclSuccess; - bool localRegBufUsed = false; struct ncclNvlsCleanupCallback* sendRecord = NULL; struct ncclNvlsCleanupCallback* recvRecord = NULL; - CUdeviceptr regSendPtr = 0; - CUdeviceptr regRecvPtr = 0; - CUmulticastObjectProp mcprop; - CUmemAllocationProp ucprop; - char shareableHandle[NVLS_HANDLE_SIZE]; - CUmemGenericAllocationHandle sendMcHandle, recvMcHandle; - size_t sendGran = 0, recvGran = 0; - bool *regBufFlags = NULL; - struct graphRegData *rdata = NULL; - const void *baseSend = NULL; - const void *baseRecv = NULL; - size_t baseSendSize = 1; - size_t baseRecvSize = 1; - size_t ucgran; + void *baseSend = NULL; + void *baseRecv = NULL; + size_t baseSendSize = 0; + size_t baseRecvSize = 0; + struct ncclReg *sendRegRecord = NULL; + struct ncclReg *recvRegRecord = NULL; - *outRegBufUsed = false; - NCCLCHECKGOTO(ncclCalloc(®BufFlags, comm->localRanks), ret, fail); - NCCLCHECKGOTO(ncclCalloc(&rdata, comm->localRanks), ret, fail); - - if (sendbuffSize > 0 || recvbuffSize > 0) { - /* retrieve base pointer and size */ - if (CUPFN(cuMemGetAddressRange) == nullptr) goto fail; - if (sendbuff != NULL) - CUCHECKGOTO(cuMemGetAddressRange((CUdeviceptr *)&baseSend, &baseSendSize, (CUdeviceptr)sendbuff), ret, fail); - if (recvbuff != NULL) - CUCHECKGOTO(cuMemGetAddressRange((CUdeviceptr *)&baseRecv, &baseRecvSize, (CUdeviceptr)recvbuff), ret, fail); - - memset(&ucprop, 0, sizeof(CUmemAllocationProp)); - ucprop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - ucprop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - ucprop.location.id = comm->cudaDev; - ucprop.requestedHandleTypes = ncclCuMemHandleType; - CUCHECKGOTO(cuMemGetAllocationGranularity(&ucgran, &ucprop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED), ret, fail); - - localRegBufUsed = ((uint64_t)baseSend % ucgran != 0 || (uint64_t)baseRecv % ucgran != 0) ? false : true; - regBufFlags[comm->localRank] = localRegBufUsed; - NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, regBufFlags, sizeof(bool)), ret, fail); - for (int i = 0; i < comm->localRanks; ++i) - if (regBufFlags[i] == false) goto fail; - - memset(&mcprop, 0, sizeof(CUmulticastObjectProp)); - mcprop.numDevices = comm->localRanks; - mcprop.handleTypes = ncclCuMemHandleType; - mcprop.flags = 0; - - if (sendbuff != NULL) { - mcprop.size = baseSendSize; - CUCHECKGOTO(cuMulticastGetGranularity(&sendGran, &mcprop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); - - /* check send buffer offset and size */ - rdata[comm->localRank].offset = (uintptr_t)sendbuff - (uintptr_t)baseSend; - rdata[comm->localRank].size = baseSendSize; - NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, rdata, sizeof(struct graphRegData)), ret, fail); - baseSendSize = rdata[0].size; - for (int i = 1; i < comm->localRanks; ++i) { - if (rdata[0].offset != rdata[i].offset) goto fail; - if (baseSendSize > rdata[i].size) baseSendSize = rdata[i].size; - } - if (baseSendSize % sendGran != 0) goto fail; - - mcprop.size = baseSendSize; - - /* register sendbuff */ - if (comm->localRank == 0) { - NCCLCHECKGOTO(nvlsGroupCreate(comm, &mcprop, comm->localRank, comm->localRanks, &sendMcHandle, shareableHandle), ret, fail); - NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), ret, fail); - } else { - NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), ret, fail); - NCCLCHECKGOTO(nvlsGroupConnect(comm, shareableHandle, comm->localRankToRank[0], &sendMcHandle), ret, fail); - } - - CUCHECKGOTO(cuMulticastAddDevice(sendMcHandle, comm->nvlsResources->dev), ret, fail); - CUCHECKGOTO(cuMulticastBindAddr(sendMcHandle, 0, (CUdeviceptr)baseSend, baseSendSize, 0), ret, fail); - - // Create a VA for the NVLS - CUCHECKGOTO(cuMemAddressReserve(®SendPtr, baseSendSize, sendGran, 0U, 0), ret, fail); - // Map the VA locally - CUCHECKGOTO(cuMemMap(regSendPtr, baseSendSize, 0, sendMcHandle, 0), ret, fail); - CUCHECKGOTO(cuMemSetAccess(regSendPtr, baseSendSize, &comm->nvlsResources->accessDesc, 1), ret, fail); - - sendRecord = (struct ncclNvlsCleanupCallback*)malloc(sizeof(struct ncclNvlsCleanupCallback)); - sendRecord->base.fn = cleanupNvls; - sendRecord->mcHandle = sendMcHandle; - sendRecord->ptr = regSendPtr; - sendRecord->dev = comm->nvlsResources->dev; - sendRecord->size = baseSendSize; - } - - if (recvbuff != NULL) { - mcprop.size = baseRecvSize; - CUCHECKGOTO(cuMulticastGetGranularity(&recvGran, &mcprop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); - - rdata[comm->localRank].offset = (uintptr_t)recvbuff - (uintptr_t)baseRecv; - rdata[comm->localRank].size = baseRecvSize; - NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, rdata, sizeof(struct graphRegData)), ret, fail); - baseRecvSize = rdata[0].size; - for (int i = 1; i < comm->localRanks; ++i) { - if (rdata[0].offset != rdata[i].offset) goto fail; - if (baseRecvSize > rdata[i].size) baseRecvSize = rdata[i].size; - } - if (baseRecvSize % recvGran != 0) goto fail; - - mcprop.size = baseRecvSize; - if (comm->localRank == 0) { - NCCLCHECKGOTO(nvlsGroupCreate(comm, &mcprop, comm->localRank, comm->localRanks, &recvMcHandle, shareableHandle), ret, fail); - NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), ret, fail); - } else { - NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, shareableHandle, NVLS_HANDLE_SIZE), ret, fail); - NCCLCHECKGOTO(nvlsGroupConnect(comm, shareableHandle, comm->localRankToRank[0], &recvMcHandle), ret, fail); - } - - CUCHECKGOTO(cuMulticastAddDevice(recvMcHandle, comm->nvlsResources->dev), ret, fail); - CUCHECKGOTO(cuMulticastBindAddr(recvMcHandle, 0, (CUdeviceptr)baseRecv, baseRecvSize, 0), ret, fail); - - // Create a VA for the NVLS - CUCHECKGOTO(cuMemAddressReserve(®RecvPtr, baseRecvSize, recvGran, 0U, 0), ret, fail); - // Map the VA locally - CUCHECKGOTO(cuMemMap(regRecvPtr, baseRecvSize, 0, recvMcHandle, 0), ret, fail); - CUCHECKGOTO(cuMemSetAccess(regRecvPtr, baseRecvSize, &comm->nvlsResources->accessDesc, 1), ret, fail); - - recvRecord = (struct ncclNvlsCleanupCallback*)malloc(sizeof(struct ncclNvlsCleanupCallback)); - recvRecord->base.fn = cleanupNvls; - recvRecord->mcHandle = recvMcHandle; - recvRecord->ptr = regRecvPtr; - recvRecord->dev = comm->nvlsResources->dev; - recvRecord->size = baseRecvSize; - } - - localRegBufUsed = true; + *outRegBufUsed = 0; + if (sendbuff) { + CUCHECK(cuMemGetAddressRange((CUdeviceptr *)&baseSend, &baseSendSize, (CUdeviceptr)sendbuff)); + NCCLCHECK(ncclCommGraphRegister(comm, baseSend, baseSendSize, (void**)&sendRegRecord)); } -exit: - if (localRegBufUsed == false) { - if (sendRecord) { - ncclNvlsDeregBuffer(&sendRecord->mcHandle, sendRecord->ptr, sendRecord->dev, sendRecord->size); - free(sendRecord); - } + if (recvbuff) { + CUCHECK(cuMemGetAddressRange((CUdeviceptr *)&baseRecv, &baseRecvSize, (CUdeviceptr)recvbuff)); + NCCLCHECK(ncclCommGraphRegister(comm, baseRecv, baseRecvSize, (void**)&recvRegRecord)); + } - if (recvRecord) { - // Yes, it's a dead code. That's fine... - // coverity[dead_error_begin] - ncclNvlsDeregBuffer(&recvRecord->mcHandle, recvRecord->ptr, recvRecord->dev, recvRecord->size); - free(recvRecord); - } - } else { - if (sendRecord) { - *outRegBufSend = (void*)((uintptr_t)regSendPtr + (uintptr_t)sendbuff - (uintptr_t)baseSend); + NCCLCHECK(nvlsRegisterBuffer(comm, baseSend, baseRecv, baseSendSize, baseRecvSize, sendRegRecord, recvRegRecord, outRegBufUsed, outRegBufSend, outRegBufRecv)); + + if (*outRegBufUsed) { + if (sendRegRecord) { + sendRecord = (struct ncclNvlsCleanupCallback*)malloc(sizeof(struct ncclNvlsCleanupCallback)); + sendRecord->base.fn = cleanupNvls; + sendRecord->reg = sendRegRecord; + sendRecord->comm = comm; ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)sendRecord); *nCleanupQueueEltsAdded += 1; } - if (recvRecord) { - *outRegBufRecv = (void*)((uintptr_t)regRecvPtr + (uintptr_t)recvbuff - (uintptr_t)baseRecv); + if (recvRegRecord) { + recvRecord = (struct ncclNvlsCleanupCallback*)malloc(sizeof(struct ncclNvlsCleanupCallback)); + recvRecord->base.fn = cleanupNvls; + recvRecord->reg = recvRegRecord; + recvRecord->comm = comm; ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)recvRecord); *nCleanupQueueEltsAdded += 1; } - - INFO(NCCL_NVLS, "rank %d successfully graph-registered sendbuff %p, recvbuff %p, sendbuff size %ld (register size %ld, sendGran %ld), recvbuff size %ld (register size %ld, recvGran %ld), reg sendbuff %p, reg recvbuff %p", comm->rank, sendbuff, recvbuff, sendbuffSize, baseSendSize, sendGran, recvbuffSize, baseRecvSize, recvGran, (void*)regSendPtr, (void*)regRecvPtr); + } else { + if (sendbuff) NCCLCHECK(ncclCommGraphDeregister(comm, sendRegRecord)); + if (recvbuff) NCCLCHECK(ncclCommGraphDeregister(comm, recvRegRecord)); } - *outRegBufUsed = localRegBufUsed; - free(regBufFlags); - free(rdata); - /* always return success. */ return ncclSuccess; -fail: - localRegBufUsed = false; - goto exit; } #else @@ -936,19 +825,19 @@ ncclResult_t ncclNvlsTreeConnect(struct ncclComm* comm) { ncclResult_t ncclNvlsGraphRegisterBuffer( struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, - bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, + int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv, struct ncclIntruQueue* cleanupQueue, int* nCleanupQueueEltsAdded ) { *outRegBufUsed = false; return ncclSuccess; } -ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, bool *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv) { +ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *sendbuff, void *recvbuff, size_t sendbuffSize, size_t recvbuffSize, int *outRegBufUsed, void **outRegBufSend, void **outRegBufRecv) { *outRegBufUsed = false; return ncclSuccess; } -ncclResult_t ncclNvlsDeregBuffer(CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size) { +ncclResult_t ncclNvlsDeregBuffer(struct ncclComm* comm, CUmemGenericAllocationHandle *mcHandler, CUdeviceptr ptr, int dev, size_t size) { return ncclSuccess; } diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc index eb8c8b73b3..7d561582e7 100644 --- a/src/transport/p2p.cc +++ b/src/transport/p2p.cc @@ -97,6 +97,8 @@ struct p2pCuMemProxyInfo { #include +NCCL_PARAM(LegacyCudaRegister, "LEGACY_CUDA_REGISTER", 0); + /* Convert a PCI busId string into a local cudaDev device index (cf. CUDA_VISIBLE_DEVICES) */ static int busIdToCudaDev(int64_t busId) { int ndev; @@ -132,21 +134,9 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclComm* comm, struct ncclTopoGraph } #endif - // MNNVL support - if (comm->MNNVL && info1->hostHash != info2->hostHash) { - NCCLCHECK(ncclTopoCheckMNNVL(comm->topo, info1, info2, ret)); - if (*ret) return ncclSuccess; - } - - // Rule out different nodes / isolated containers - if (info1->hostHash != info2->hostHash || info1->shmDev != info2->shmDev) { - *ret = 0; - return ncclSuccess; - } - // Check topology / p2p level. int intermediateRank; - NCCLCHECK(ncclTopoCheckP2p(comm->topo, info1->rank, info2->rank, ret, NULL, &intermediateRank)); + NCCLCHECK(ncclTopoCheckP2p(comm, comm->topo, info1->rank, info2->rank, ret, NULL, &intermediateRank)); if (*ret == 0) return ncclSuccess; if (intermediateRank != -1) { if (useMemcpy) *ret = 0; @@ -161,6 +151,12 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclComm* comm, struct ncclTopoGraph return ncclSuccess; } + if (info1->hostHash != comm->peerInfo[comm->rank].hostHash || + info1->hostHash != info2->hostHash) { + // If either peer is non-local then we are done. + return ncclSuccess; + } + // Convert the peer's busId into a local cudaDev index (cf. CUDA_VISIBLE_DEVICES) int cudaDev1 = busIdToCudaDev(info1->busId); int cudaDev2 = busIdToCudaDev(info2->busId); @@ -332,11 +328,11 @@ NCCL_PARAM(P2pDirectDisable, "P2P_DIRECT_DISABLE", 0); #define P2P_SAME_PID(MYINFO, PEERINFO) ((MYINFO->hostHash == PEERINFO->hostHash) && (MYINFO->pidHash == PEERINFO->pidHash)) -static ncclResult_t p2pGetInfo(struct ncclTopoSystem* topo, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* read, int* intermediateRank) { +static ncclResult_t p2pGetInfo(struct ncclComm* comm, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* read, int* intermediateRank) { int p2p; // Queries the topology to see if the GPUs are Ampere and // connected via NVLink, if so we enable P2P Read by default - NCCLCHECK(ncclTopoCheckP2p(topo, info1->rank, info2->rank, &p2p, read, intermediateRank)); + NCCLCHECK(ncclTopoCheckP2p(comm, comm->topo, info1->rank, info2->rank, &p2p, read, intermediateRank)); int readEnable = ncclParamP2pReadEnable(); if (readEnable != -2) *read = readEnable; @@ -388,7 +384,7 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclCalloc(&resources, 1)); send->transportResources = resources; int useRead, intermediateRank; - NCCLCHECK(p2pGetInfo(comm->topo, myInfo, peerInfo, &useRead, &intermediateRank)); + NCCLCHECK(p2pGetInfo(comm, myInfo, peerInfo, &useRead, &intermediateRank)); if (useMemcpy) useRead = 0; resources->next_hdp_reg = 0; @@ -418,7 +414,6 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st info->rank = myInfo->rank; if (P2P_SAME_PID(myInfo, peerInfo) && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { resources->type = P2P_DIRECT; - send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks); } else { @@ -434,8 +429,8 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : "", comm, comm->nRanks); } - send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; } + send->conn.flags |= info->read ? NCCL_P2P_READ : NCCL_P2P_WRITE; } else { resources->type = P2P_INTERMEDIATE; info->rank = intermediateRank; @@ -469,7 +464,7 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclCalloc(&resources, 1)); recv->transportResources = resources; int useRead, intermediateRank; - NCCLCHECK(p2pGetInfo(comm->topo, myInfo, peerInfo, &useRead, &intermediateRank)); + NCCLCHECK(p2pGetInfo(comm, myInfo, peerInfo, &useRead, &intermediateRank)); static_assert(sizeof(struct p2pConnectInfo) <= sizeof(struct ncclConnect), "p2p Connect Info is too big"); struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo; @@ -486,7 +481,6 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st info->rank = myInfo->rank; if (P2P_SAME_PID(myInfo, peerInfo) && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { resources->type = P2P_DIRECT; - recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; } else { if (ncclCuMemEnable()) { // cuMem API support @@ -497,8 +491,8 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st // Legacy CUDA IPC resources->type = P2P_IPC; } - recv->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; } + recv->conn.flags |= info->read ? NCCL_P2P_READ : NCCL_P2P_WRITE; } else { resources->type = P2P_INTERMEDIATE; info->rank = intermediateRank; @@ -843,9 +837,8 @@ static ncclResult_t p2pSendProxyProgress(struct ncclProxyState* proxyState, stru return ncclSuccess; } -ncclResult_t ncclIpcLocalRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, int* peerRanks, int nPeers, ncclIpcRegType type, int* regBufFlag, uintptr_t* offsetOut, uintptr_t** peerRmtAddrsOut) { - ncclResult_t ret = ncclSuccess; - struct ncclReg *regRecord = NULL; +static ncclResult_t ipcRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, int* peerRanks, int nPeers, ncclIpcRegType type, struct ncclReg* regRecord, int* regBufFlag, uintptr_t* offsetOut, uintptr_t** peerRmtAddrsOut, bool* isLegacyIpc) { +ncclResult_t ret = ncclSuccess; struct ncclIpcRegInfo* newInfo = NULL; uintptr_t* peerRmtAddrs = NULL; bool legacyIpcCap = false; @@ -856,123 +849,151 @@ ncclResult_t ncclIpcLocalRegisterBuffer(ncclComm* comm, const void* userbuff, si *regBufFlag = 0; *offsetOut = 0; *peerRmtAddrsOut = NULL; - if (comm && userbuff && buffSize > 0 && nPeers > 0) { - NCCLCHECKGOTO(ncclRegFind(comm, userbuff, buffSize, ®Record), ret, fail); - if (regRecord) { - // buffer was registered by by users, we need to start to register or reuse it - int peerLocalRank; - for (int p = 0; p < nPeers; p++) { - int peerRank = peerRanks[p]; - peerLocalRank = comm->rankToLocalRank[peerRank]; - if (regRecord->ipcInfos[peerLocalRank]) { - // We already have IPC info for peerLocalRank, no need to register it, we can reuse it - *regBufFlag = 1; - INFO(NCCL_REG, "rank %d - IPC local reuse buffer %p size %ld (baseAddr %p size %ld) to peer %d regAddr %p", comm->rank, userbuff, buffSize, (void*)regRecord->addr, regRecord->pages * comm->regCache.pageSize, peerRank, regRecord->ipcInfos[peerLocalRank]->impInfo.rmtRegAddr); - } else { - // Register buffer with peerLocalRank - struct ncclProxyConnector* proxyConn = NULL; - struct p2pIpcExpInfo ipcInfo; + if (isLegacyIpc) *isLegacyIpc = false; + if (regRecord) { + // buffer was registered by by users, we need to start to register or reuse it + int peerLocalRank; + for (int p = 0; p < nPeers; p++) { + int peerRank = peerRanks[p]; + peerLocalRank = comm->rankToLocalRank[peerRank]; + if (regRecord->ipcInfos[peerLocalRank]) { + // We already have IPC info for peerLocalRank, no need to register it, we can reuse it + *regBufFlag = 1; + if (isLegacyIpc) *isLegacyIpc = regRecord->ipcInfos[peerLocalRank]->impInfo.legacyIpcCap; + INFO(NCCL_REG, "rank %d - IPC reuse buffer %p size %ld (baseAddr %p size %ld) to peer %d regAddr %p", comm->rank, userbuff, buffSize, (void*)regRecord->addr, regRecord->pages * comm->regCache.pageSize, peerRank, regRecord->ipcInfos[peerLocalRank]->impInfo.rmtRegAddr); + } else { + // Register buffer with peerLocalRank + struct ncclProxyConnector* proxyConn = NULL; + struct p2pIpcExpInfo ipcInfo; - if (baseAddr == NULL) { - CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr*)&baseAddr, &baseSize, (CUdeviceptr)userbuff), ret, fail); - CUDACHECKGOTO(cuPointerGetAttribute((void*)&legacyIpcCap, CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE, (CUdeviceptr)baseAddr), ret, fail); - } - if (comm->gproxyConn[peerRank].initialized == false) - NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_P2P, 1, peerRank, &comm->gproxyConn[peerRank]), ret, fail); - proxyConn = &comm->gproxyConn[peerRank]; + if (baseAddr == NULL) { + CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr*)&baseAddr, &baseSize, (CUdeviceptr)userbuff), ret, fail); + CUDACHECKGOTO(cuPointerGetAttribute((void*)&legacyIpcCap, CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE, (CUdeviceptr)baseAddr), ret, fail); + } + if (comm->gproxyConn[peerRank].initialized == false) + NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_P2P, 1, peerRank, &comm->gproxyConn[peerRank]), ret, fail); + proxyConn = &comm->gproxyConn[peerRank]; - ipcInfo.legacyIpcCap = legacyIpcCap; - // Get the mem handle for that buffer. It may have been allocated through cudaMalloc in which case we'll - // get the CUDA legacy mem handle, or through cuMem*. - if (ipcInfo.legacyIpcCap) { - // legacy export - if (comm->directMode) goto fail; - CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); - } else if (ncclCuMemEnable()) { + // Get the mem handle for that buffer. It may have been allocated through cudaMalloc in which case we'll + // get the CUDA legacy mem handle, or through cuMem*. + if (ncclCuMemEnable()) { #if CUDART_VERSION >= 11030 - CUmemGenericAllocationHandle handle; - if (CUPFN(cuMemRetainAllocationHandle(&handle, baseAddr)) != CUDA_SUCCESS) { - // if cuMem* export fails, retry legacy export - if (comm->directMode) goto fail; - CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); - ipcInfo.legacyIpcCap = true; + CUmemGenericAllocationHandle handle; + if (CUPFN(cuMemRetainAllocationHandle(&handle, baseAddr)) != CUDA_SUCCESS) { + // if cuMem* export fails, retry legacy export + if (comm->directMode || !ncclParamLegacyCudaRegister()) goto fail; + CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); + ipcInfo.legacyIpcCap = true; + if (isLegacyIpc) *isLegacyIpc = true; + } else { + ipcInfo.legacyIpcCap = false; + if (isLegacyIpc) *isLegacyIpc = false; + // cuMem* export to file descriptor or fabric handle + if (proxyConn->sameProcess) { + memcpy(&ipcInfo.ipcDesc.memHandle, &handle, sizeof(CUmemGenericAllocationHandle)); } else { - // cuMem* export to file descriptor or fabric handle - if (proxyConn->sameProcess) { - memcpy(&ipcInfo.ipcDesc.memHandle, &handle, sizeof(CUmemGenericAllocationHandle)); + if (ncclCuMemHandleType == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { + int expFd = -1; + CUCHECKGOTO(cuMemExportToShareableHandle(&expFd, handle, ncclCuMemHandleType, 0), ret, fail); + NCCLCHECKGOTO(ncclProxyClientQueryFdBlocking(comm, proxyConn, expFd, &ipcInfo.impFd), ret, fail); + SYSCHECKGOTO(close(expFd), "close", ret, fail); } else { - if (ncclCuMemHandleType == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { - int expFd = -1; - CUCHECKGOTO(cuMemExportToShareableHandle(&expFd, handle, ncclCuMemHandleType, 0), ret, fail); - NCCLCHECKGOTO(ncclProxyClientQueryFdBlocking(comm, proxyConn, expFd, &ipcInfo.impFd), ret, fail); - SYSCHECKGOTO(close(expFd), "close", ret, fail); - } else { - // Allow this to silently fail for cases where the user buff cannot be registered - if (CUPFN(cuMemExportToShareableHandle(&ipcInfo.ipcDesc.cuDesc.handle, handle, ncclCuMemHandleType, 0)) != CUDA_SUCCESS) { - CUCHECKGOTO(cuMemRelease(handle), ret, fail); - goto fail; - } + // Allow this to silently fail for cases where the user buff cannot be registered + if (CUPFN(cuMemExportToShareableHandle(&ipcInfo.ipcDesc.cuDesc.handle, handle, ncclCuMemHandleType, 0)) != CUDA_SUCCESS) { + CUCHECKGOTO(cuMemRelease(handle), ret, fail); + goto fail; } } - CUCHECKGOTO(cuMemRelease(handle), ret, fail); } + CUCHECKGOTO(cuMemRelease(handle), ret, fail); + } #endif - } else { - // nothing works, just return - goto fail; - } - - void* rmtRegAddr = NULL; - ipcInfo.size = baseSize; - ipcInfo.offset = regRecord->addr - (uintptr_t)baseAddr; - // Now ipcInfo contains all necessary registration info. Start to register buffer on proxy side - // and get the remote register address back. - if (proxyConn) - NCCLCHECKGOTO(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgRegister, &ipcInfo, sizeof(p2pIpcExpInfo), &rmtRegAddr, sizeof(void*)), ret, fail); - if (rmtRegAddr) { - NCCLCHECKGOTO(ncclCalloc(&newInfo, 1), ret, fail); - assert(regRecord->ipcInfos[peerLocalRank] == NULL); - regRecord->state |= IPC_REG_COMPLETE; - newInfo->peerRank = peerRank; - newInfo->baseAddr = baseAddr; - newInfo->impInfo.rmtRegAddr = rmtRegAddr; - newInfo->impInfo.offset = ipcInfo.offset; - newInfo->impInfo.legacyIpcCap = ipcInfo.legacyIpcCap; - newInfo->ipcProxyconn = proxyConn; - regRecord->ipcInfos[peerLocalRank] = newInfo; - if (regRecord->regIpcAddrs.hostPeerRmtAddrs == NULL) { - NCCLCHECKGOTO(ncclCalloc(®Record->regIpcAddrs.hostPeerRmtAddrs, comm->localRanks), ret, fail); - } - regRecord->regIpcAddrs.hostPeerRmtAddrs[peerLocalRank] = (uintptr_t)rmtRegAddr; - needUpdate = true; - *regBufFlag = 1; - INFO(NCCL_REG, "rank %d - IPC local register buffer %p size %ld (baseAddr %p size %ld) to peer %d regAddr %p offsetOut %ld", comm->rank, userbuff, buffSize, (void*)regRecord->addr, ipcInfo.size, peerRank, rmtRegAddr, (uintptr_t)userbuff - regRecord->addr); - } - } - } - - if (*regBufFlag) { - if (type == NCCL_IPC_COLLECTIVE) { - // for collective, store registered remote buffers into dev memory for future reference - if (regRecord->regIpcAddrs.devPeerRmtAddrs == NULL || needUpdate) { - NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->hostStream), ret, fail); - if (regRecord->regIpcAddrs.devPeerRmtAddrs == NULL) - NCCLCHECKGOTO(ncclCudaCallocAsync(®Record->regIpcAddrs.devPeerRmtAddrs, comm->localRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); - if (needUpdate) - NCCLCHECKGOTO(ncclCudaMemcpyAsync(regRecord->regIpcAddrs.devPeerRmtAddrs, regRecord->regIpcAddrs.hostPeerRmtAddrs, comm->localRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); - NCCLCHECKGOTO(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->sharedRes->deviceStream, &comm->sharedRes->hostStream), ret, fail); - NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->hostStream), ret, fail); - } - peerRmtAddrs = regRecord->regIpcAddrs.devPeerRmtAddrs; + } else if (legacyIpcCap) { + // legacy export + if (comm->directMode || !ncclParamLegacyCudaRegister()) goto fail; + CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); + ipcInfo.legacyIpcCap = true; + if (isLegacyIpc) *isLegacyIpc = true; } else { - assert(nPeers == 1); - // p2p always returns remote addr here since remote buffer addr is passed in ncclDevWorkP2p struct - peerRmtAddrs = (uintptr_t*)regRecord->regIpcAddrs.hostPeerRmtAddrs[peerLocalRank]; + // nothing works, just return + goto fail; + } + + void* rmtRegAddr = NULL; + ipcInfo.size = baseSize; + ipcInfo.offset = regRecord->addr - (uintptr_t)baseAddr; + // Now ipcInfo contains all necessary registration info. Start to register buffer on proxy side + // and get the remote register address back. + if (proxyConn) + NCCLCHECKGOTO(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgRegister, &ipcInfo, sizeof(p2pIpcExpInfo), &rmtRegAddr, sizeof(void*)), ret, fail); + if (rmtRegAddr) { + NCCLCHECKGOTO(ncclCalloc(&newInfo, 1), ret, fail); + assert(regRecord->ipcInfos[peerLocalRank] == NULL); + regRecord->state |= IPC_REG_COMPLETE; + newInfo->peerRank = peerRank; + newInfo->baseAddr = baseAddr; + newInfo->impInfo.rmtRegAddr = rmtRegAddr; + newInfo->impInfo.offset = ipcInfo.offset; + newInfo->impInfo.legacyIpcCap = ipcInfo.legacyIpcCap; + newInfo->ipcProxyconn = proxyConn; + regRecord->ipcInfos[peerLocalRank] = newInfo; + if (regRecord->regIpcAddrs.hostPeerRmtAddrs == NULL) { + NCCLCHECKGOTO(ncclCalloc(®Record->regIpcAddrs.hostPeerRmtAddrs, comm->localRanks), ret, fail); + } + regRecord->regIpcAddrs.hostPeerRmtAddrs[peerLocalRank] = (uintptr_t)rmtRegAddr; + needUpdate = true; + *regBufFlag = 1; + INFO(NCCL_REG, "rank %d - IPC register buffer %p size %ld (baseAddr %p size %ld) to peer %d regAddr %p offsetOut %ld", comm->rank, userbuff, buffSize, (void*)regRecord->addr, ipcInfo.size, peerRank, rmtRegAddr, (uintptr_t)userbuff - regRecord->addr); } - *offsetOut = (uintptr_t)userbuff - regRecord->addr; - *peerRmtAddrsOut = peerRmtAddrs; } } + + if (*regBufFlag) { + if (type == NCCL_IPC_COLLECTIVE) { + // for collective, store registered remote buffers into dev memory for future reference + if (regRecord->regIpcAddrs.devPeerRmtAddrs == NULL || needUpdate) { + NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->hostStream), ret, fail); + if (regRecord->regIpcAddrs.devPeerRmtAddrs == NULL) + NCCLCHECKGOTO(ncclCudaCallocAsync(®Record->regIpcAddrs.devPeerRmtAddrs, comm->localRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); + if (needUpdate) + NCCLCHECKGOTO(ncclCudaMemcpyAsync(regRecord->regIpcAddrs.devPeerRmtAddrs, regRecord->regIpcAddrs.hostPeerRmtAddrs, comm->localRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); + NCCLCHECKGOTO(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->sharedRes->deviceStream, &comm->sharedRes->hostStream), ret, fail); + NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->hostStream), ret, fail); + } + peerRmtAddrs = regRecord->regIpcAddrs.devPeerRmtAddrs; + } else { + assert(nPeers == 1); + // p2p always returns remote addr here since remote buffer addr is passed in ncclDevWorkP2p struct + peerRmtAddrs = (uintptr_t*)regRecord->regIpcAddrs.hostPeerRmtAddrs[peerLocalRank]; + } + *offsetOut = (uintptr_t)userbuff - regRecord->addr; + *peerRmtAddrsOut = peerRmtAddrs; + } + } +exit: + return ret; +fail: + *regBufFlag = 0; + *offsetOut = 0; + *peerRmtAddrsOut = NULL; + if (newInfo) free(newInfo); + WARN("rank %d failed to IPC register userbuff %p buffSize %ld nPeers %d isLegacyIpc %p", comm->rank, userbuff, buffSize, nPeers, isLegacyIpc); + goto exit; +} + +ncclResult_t ncclIpcLocalRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, int* peerRanks, int nPeers, ncclIpcRegType type, int* regBufFlag, uintptr_t* offsetOut, uintptr_t** peerRmtAddrsOut) { + ncclResult_t ret = ncclSuccess; + struct ncclReg *regRecord = NULL; + bool isValid = false; + + *regBufFlag = 0; + *offsetOut = 0; + *peerRmtAddrsOut = NULL; + if (comm && userbuff && buffSize > 0 && nPeers > 0) { + NCCLCHECKGOTO(ncclRegFind(comm, userbuff, buffSize, ®Record), ret, fail); + NCCLCHECKGOTO(ncclRegLocalIsValid(regRecord, &isValid), ret, fail); + if (isValid) + NCCLCHECKGOTO(ipcRegisterBuffer(comm, userbuff, buffSize, peerRanks, nPeers, type, regRecord, regBufFlag, offsetOut, peerRmtAddrsOut, NULL), ret, fail); } exit: @@ -981,149 +1002,56 @@ fail: *regBufFlag = 0; *offsetOut = 0; *peerRmtAddrsOut = NULL; - if (newInfo) free(newInfo); goto exit; } struct ncclIpcCleanupCallback { struct ncclCommCallback base; - bool isAddrs; - union { - struct ncclIpcRegInfo regInfo; - struct ncclPeerRegIpcAddr regIpcAddrs; - }; + struct ncclComm *comm; + struct ncclReg *reg; }; static ncclResult_t cleanupIpc(struct ncclComm* comm, struct ncclCommCallback* cb) { struct ncclIpcCleanupCallback* obj = (struct ncclIpcCleanupCallback*)cb; - if (obj->isAddrs) { - if (obj->regIpcAddrs.hostPeerRmtAddrs) - free(obj->regIpcAddrs.hostPeerRmtAddrs); - if (obj->regIpcAddrs.devPeerRmtAddrs) - NCCLCHECK(ncclCudaFree(obj->regIpcAddrs.devPeerRmtAddrs)); - } else { - NCCLCHECK(ncclIpcDeregBuffer(comm, &obj->regInfo)); - } + NCCLCHECK(ncclCommGraphDeregister(obj->comm, obj->reg)); free(obj); return ncclSuccess; } ncclResult_t ncclIpcGraphRegisterBuffer(ncclComm* comm, const void* userbuff, size_t buffSize, int* peerRanks, int nPeers, ncclIpcRegType type, int* regBufFlag, uintptr_t* offsetOut, uintptr_t** peerRmtAddrsOut, void* cleanupQueuePtr, int* nCleanupQueueElts) { ncclResult_t ret = ncclSuccess; - struct ncclProxyConnector* proxyConn = NULL; - struct p2pIpcExpInfo ipcInfo; void* baseAddr = nullptr; size_t baseSize = 0; struct ncclIntruQueue* cleanupQueue = reinterpret_cast*>(cleanupQueuePtr); - uintptr_t* peerRmtAddrs = NULL; - struct ncclIpcCleanupCallback* addrsRecord = NULL; + bool isLegacyIpc = false; + struct ncclReg *regRecord = NULL; *regBufFlag = 0; - CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr*)&baseAddr, &baseSize, (CUdeviceptr)userbuff), ret, fail); - CUDACHECKGOTO(cuPointerGetAttribute((void*)&ipcInfo.legacyIpcCap, CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE, (CUdeviceptr)baseAddr), ret, fail); - - if (type == NCCL_IPC_COLLECTIVE) { - // collective needs host memory array to hold all remote buffer addrs. - // We need to put this into graph release queue - NCCLCHECKGOTO(ncclCalloc(&addrsRecord, 1), ret, fail); - addrsRecord->base.fn = cleanupIpc; - addrsRecord->isAddrs = true; - NCCLCHECKGOTO(ncclCalloc(&addrsRecord->regIpcAddrs.hostPeerRmtAddrs, comm->localRanks), ret, fail); - } else { - assert(nPeers == 1); - // p2p does not need anything, just returning the remote buffer is enough, but for now, we register - // peer one by one so nPeers must be 1 - } - - for (int p = 0; p < nPeers; ++p) { - int peerRank = peerRanks[p]; - if (comm->gproxyConn[peerRank].initialized == false) - NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_P2P, 1, peerRank, &comm->gproxyConn[peerRank]), ret, fail); - proxyConn = &comm->gproxyConn[peerRank]; - // Same as local registration. Get the mem handle for that buffer. It may have been allocated through - // cudaMalloc in which case we'll get the CUDA legacy mem handle, or through cuMem*. - if (ipcInfo.legacyIpcCap) { - if (comm->directMode) goto fail; - CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); - } else if (ncclCuMemEnable()) { -#if CUDART_VERSION >= 11030 - // cuMem* export - CUmemGenericAllocationHandle handle; - if (pfn_cuMemRetainAllocationHandle(&handle, baseAddr) != CUDA_SUCCESS) { - if (comm->directMode) goto fail; - CUDACHECKGOTO(cudaIpcGetMemHandle(&ipcInfo.ipcDesc.devIpc, baseAddr), ret, fail); - ipcInfo.legacyIpcCap = true; - } else { - if (proxyConn->sameProcess) { - memcpy(&ipcInfo.ipcDesc.memHandle, &handle, sizeof(CUmemGenericAllocationHandle)); - } else { - if (ncclCuMemHandleType == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { - int expFd = -1; - CUCHECKGOTO(cuMemExportToShareableHandle(&expFd, handle, ncclCuMemHandleType, 0), ret, fail); - if (proxyConn->sameProcess) { - ipcInfo.impFd = expFd; - } else { - NCCLCHECKGOTO(ncclProxyClientQueryFdBlocking(comm, proxyConn, expFd, &ipcInfo.impFd), ret, fail); - SYSCHECKGOTO(close(expFd), "close", ret, fail); - } - } else { - CUCHECKGOTO(cuMemExportToShareableHandle(&ipcInfo.ipcDesc.cuDesc.handle, handle, ncclCuMemHandleType, 0), ret, fail); - } - } - CUCHECKGOTO(cuMemRelease(handle), ret, fail); - } -#endif - } else { - goto fail; - } - - void* rmtRegAddr = NULL; - ipcInfo.size = baseSize; - ipcInfo.offset = 0; - NCCLCHECKGOTO(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgRegister, &ipcInfo, sizeof(struct p2pIpcExpInfo), &rmtRegAddr, sizeof(void*)), ret, fail); - if (rmtRegAddr) { + *offsetOut = 0; + *peerRmtAddrsOut = NULL; + if (comm && userbuff && buffSize > 0 && nPeers > 0) { + CUDACHECKGOTO(cuMemGetAddressRange((CUdeviceptr*)&baseAddr, &baseSize, (CUdeviceptr)userbuff), ret, fail); + NCCLCHECKGOTO(ncclCommGraphRegister(comm, baseAddr, baseSize, (void**)®Record), ret, fail); + NCCLCHECKGOTO(ipcRegisterBuffer(comm, userbuff, buffSize, peerRanks, nPeers, type, regRecord, regBufFlag, offsetOut, peerRmtAddrsOut, &isLegacyIpc), ret, fail); + if (*regBufFlag) { struct ncclIpcCleanupCallback* record; NCCLCHECKGOTO(ncclCalloc(&record, 1), ret, fail); record->base.fn = cleanupIpc; - record->isAddrs = false; - record->regInfo.peerRank = peerRank; - record->regInfo.baseAddr = baseAddr; - record->regInfo.impInfo.rmtRegAddr = rmtRegAddr; - record->regInfo.impInfo.offset = 0; - record->regInfo.impInfo.legacyIpcCap = ipcInfo.legacyIpcCap; - record->regInfo.ipcProxyconn = proxyConn; - // store the remote address into host addr array - if (type == NCCL_IPC_COLLECTIVE) - addrsRecord->regIpcAddrs.hostPeerRmtAddrs[comm->rankToLocalRank[peerRank]] = (uintptr_t)rmtRegAddr; - else - peerRmtAddrs = (uintptr_t*)rmtRegAddr; - *regBufFlag = 1; - if (ipcInfo.legacyIpcCap) - ncclIntruQueueEnqueue(&comm->legacyRegCleanupQueue, &record->base); - else - ncclIntruQueueEnqueue(cleanupQueue, &record->base); - if (nCleanupQueueElts) *nCleanupQueueElts += 1; - INFO(NCCL_REG, "rank %d - IPC graph register buffer %p size %ld (baseAddr %p size %ld) to peer %d regAddr %p offsetOut %ld", comm->rank, userbuff, buffSize, baseAddr, ipcInfo.size, peerRank, rmtRegAddr, (uintptr_t)userbuff - (uintptr_t)baseAddr); + record->comm = comm; + record->reg = regRecord; + if (isLegacyIpc) { + ncclIntruQueueEnqueue(&comm->legacyRegCleanupQueue, (struct ncclCommCallback*)record); + } else { + ncclIntruQueueEnqueue(cleanupQueue, (struct ncclCommCallback*)record); + if (nCleanupQueueElts) *nCleanupQueueElts += 1; + } + } else { + NCCLCHECKGOTO(ncclCommGraphDeregister(comm, regRecord), ret, fail); } } - if (type == NCCL_IPC_COLLECTIVE) { - // allocate the dev addr array and copy all previously stored addrs into it. - NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->sharedRes->hostStream), ret, fail); - NCCLCHECKGOTO(ncclCudaCallocAsync(&addrsRecord->regIpcAddrs.devPeerRmtAddrs, comm->localRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); - NCCLCHECKGOTO(ncclCudaMemcpyAsync(addrsRecord->regIpcAddrs.devPeerRmtAddrs, addrsRecord->regIpcAddrs.hostPeerRmtAddrs, comm->nRanks, comm->sharedRes->hostStream.cudaStream), ret, fail); - NCCLCHECKGOTO(ncclStrongStreamWaitStream(ncclCudaGraphNone(), &comm->sharedRes->deviceStream, &comm->sharedRes->hostStream), ret, fail); - NCCLCHECKGOTO(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->hostStream), ret, fail); - peerRmtAddrs = addrsRecord->regIpcAddrs.devPeerRmtAddrs; - if (ipcInfo.legacyIpcCap) - ncclIntruQueueEnqueue(&comm->legacyRegCleanupQueue, &addrsRecord->base); - else - ncclIntruQueueEnqueue(cleanupQueue, &addrsRecord->base); - } - *offsetOut = (uintptr_t)userbuff - (uintptr_t)baseAddr; - *peerRmtAddrsOut = peerRmtAddrs; - exit: + // coverity[leaked_storage:FALSE] => normally, addrsRecord is added to the cleanupQueue return ret; fail: *regBufFlag = 0; diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 0f4cd32fca..5d83ef00cd 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -454,6 +454,7 @@ static ncclResult_t shmRecvProxyProgress(struct ncclProxyState* proxyState, stru } static ncclResult_t shmSendProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { + ncclResult_t result = ncclSuccess; struct shmRequest* req = (struct shmRequest*)reqBuff; /* check message size */ if (reqSize != sizeof(struct shmRequest)) return ncclInternalError; @@ -463,13 +464,18 @@ static ncclResult_t shmSendProxySetup(struct ncclProxyConnection* connection, st struct shmProxyInfo* proxyInfo; NCCLCHECK(ncclCalloc(&proxyInfo, 1)); - NCCLCHECK(ncclShmAllocateShareableBuffer(proxyState->tpRank, req->size, req->legacy, &proxyInfo->desc, &info->buf.hptr, &info->buf.dptr)); + NCCLCHECKGOTO(ncclShmAllocateShareableBuffer(proxyState->tpRank, req->size, req->legacy, &proxyInfo->desc, &info->buf.hptr, &info->buf.dptr), result, fail); memcpy(&info->desc, &proxyInfo->desc, sizeof(ncclShmIpcDesc_t)); connection->transportResources = proxyInfo; - return ncclSuccess; +exit: + return result; +fail: + free(proxyInfo); + goto exit; } static ncclResult_t shmRecvProxySetup(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { + ncclResult_t result = ncclSuccess; struct shmRequest* req = (struct shmRequest*)reqBuff; /* check message size */ if (reqSize != sizeof(struct shmRequest)) return ncclInternalError; @@ -479,10 +485,14 @@ static ncclResult_t shmRecvProxySetup(struct ncclProxyConnection* connection, st struct shmProxyInfo* proxyInfo; NCCLCHECK(ncclCalloc(&proxyInfo, 1)); - NCCLCHECK(ncclShmAllocateShareableBuffer(proxyState->tpRank, req->size, req->legacy, &proxyInfo->desc, &info->buf.hptr, &info->buf.dptr)); + NCCLCHECKGOTO(ncclShmAllocateShareableBuffer(proxyState->tpRank, req->size, req->legacy, &proxyInfo->desc, &info->buf.hptr, &info->buf.dptr), result, fail); memcpy(&info->desc, &proxyInfo->desc, sizeof(ncclShmIpcDesc_t)); connection->transportResources = proxyInfo; - return ncclSuccess; +exit: + return result; +fail: + free(proxyInfo); + goto exit; } static void initCeOperation() { @@ -534,7 +544,7 @@ ncclResult_t ncclShmAllocateShareableBuffer(int tpProxyRank, size_t size, bool l } else { char shmPath[SHM_PATH_MAX] = { '\0' }; desc->shmli.shmSize = size; - NCCLCHECK(ncclShmOpen(shmPath, size, hptr, dptr, 1, &desc->shmli.handle)); + NCCLCHECK(ncclShmOpen(shmPath, sizeof(shmPath), size, hptr, dptr, 1, &desc->shmli.handle)); memcpy(desc->shmli.shmSuffix, shmPath + sizeof("/dev/shm/nccl-") - 1, sizeof(desc->shmli.shmSuffix)); desc->legacy = true; INFO(NCCL_SHM, "MMAP allocated shareable host buffer %s size %zi ptr %p", shmPath, desc->shmli.shmSize, *hptr); @@ -542,7 +552,7 @@ ncclResult_t ncclShmAllocateShareableBuffer(int tpProxyRank, size_t size, bool l #else /* CUDART_VERSION >= 12020 */ char shmPath[SHM_PATH_MAX] = { '\0' }; desc->shmli.shmSize = size; - NCCLCHECK(ncclShmOpen(shmPath, size, hptr, dptr, 1, &desc->shmli.handle)); + NCCLCHECK(ncclShmOpen(shmPath, sizeof(shmPath), size, hptr, dptr, 1, &desc->shmli.handle)); memcpy(desc->shmli.shmSuffix, shmPath + sizeof("/dev/shm/nccl-") - 1, sizeof(desc->shmli.shmSuffix)); desc->legacy = true; INFO(NCCL_SHM, "MMAP allocated shareable host buffer %s size %zi ptr %p", shmPath, size, *hptr); @@ -618,15 +628,15 @@ ncclResult_t ncclShmImportShareableBuffer(struct ncclComm *comm, ncclShmIpcDesc_ INFO(NCCL_SHM, "CUMEM imported shareable host buffer from tpProxyRank %d size %zi ptr %p, granularity %ld", desc->shmci.tpProxyRank, desc->shmci.size, descOut->shmci.ptr, granularity); } else { char shmPath[SHM_PATH_MAX]; - sprintf(shmPath, "/dev/shm/nccl-%s", desc->shmli.shmSuffix); - NCCLCHECK(ncclShmOpen(shmPath, desc->shmli.shmSize, hptr, dptr, -1, &descOut->shmli.handle)); + snprintf(shmPath, sizeof(shmPath), "/dev/shm/nccl-%s", desc->shmli.shmSuffix); + NCCLCHECK(ncclShmOpen(shmPath, sizeof(shmPath), desc->shmli.shmSize, hptr, dptr, -1, &descOut->shmli.handle)); descOut->legacy = true; INFO(NCCL_SHM, "MMAP imported shareable host buffer %s size %zi ptr %p", shmPath, desc->shmli.shmSize, *hptr); } #else /* CUDART_VERSION >= 12020 */ char shmPath[SHM_PATH_MAX]; - sprintf(shmPath, "/dev/shm/nccl-%s", desc->shmli.shmSuffix); - NCCLCHECK(ncclShmOpen(shmPath, desc->shmli.shmSize, hptr, dptr, -1, &descOut->shmli.handle)); + snprintf(shmPath, sizeof(shmPath), "/dev/shm/nccl-%s", desc->shmli.shmSuffix); + NCCLCHECK(ncclShmOpen(shmPath, sizeof(shmPath), desc->shmli.shmSize, hptr, dptr, -1, &descOut->shmli.handle)); descOut->legacy = true; INFO(NCCL_SHM, "MMAP imported shareable host buffer %s size %zi ptr %p", shmPath, desc->shmli.shmSize, *hptr); #endif diff --git a/test/AllGatherTests.cpp b/test/AllGatherTests.cpp index a712673af7..15d79d180c 100644 --- a/test/AllGatherTests.cpp +++ b/test/AllGatherTests.cpp @@ -33,7 +33,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllGather}; - std::vector const dataTypes = {ncclBfloat16, ncclFloat64, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclBfloat16, ncclFloat64, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {586}; diff --git a/test/AllReduceTests.cpp b/test/AllReduceTests.cpp index 3a49411026..9e1d3d4410 100644 --- a/test/AllReduceTests.cpp +++ b/test/AllReduceTests.cpp @@ -14,7 +14,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllReduce}; - std::vector const dataTypes = {ncclFloat32, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat32, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {393216, 384}; @@ -33,7 +33,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllReduce}; - std::vector const dataTypes = {ncclFloat16, ncclFloat64, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat16, ncclFloat64, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclMin}; std::vector const roots = {0}; std::vector const numElements = {12888}; @@ -71,7 +71,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllReduce}; - std::vector const dataTypes = {ncclInt32, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclInt32, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclMax}; std::vector const roots = {0}; std::vector const numElements = {393216, 12888, 384}; diff --git a/test/AllToAllTests.cpp b/test/AllToAllTests.cpp index 4298b2a2f4..c375794327 100644 --- a/test/AllToAllTests.cpp +++ b/test/AllToAllTests.cpp @@ -35,7 +35,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllToAll}; - std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {5685}; diff --git a/test/BroadcastTests.cpp b/test/BroadcastTests.cpp index 32511c39c9..f35c0f9a07 100644 --- a/test/BroadcastTests.cpp +++ b/test/BroadcastTests.cpp @@ -32,7 +32,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollBroadcast}; - std::vector const dataTypes = {ncclBfloat16, ncclFloat64, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclBfloat16, ncclFloat64, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {586}; diff --git a/test/ReduceScatterTests.cpp b/test/ReduceScatterTests.cpp index 23b0f56289..390aa6e288 100644 --- a/test/ReduceScatterTests.cpp +++ b/test/ReduceScatterTests.cpp @@ -32,7 +32,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollReduceScatter}; - std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclMax}; std::vector const roots = {0}; std::vector const numElements = {1048576}; diff --git a/test/ReduceTests.cpp b/test/ReduceTests.cpp index 8fab6942d9..ae54197b16 100644 --- a/test/ReduceTests.cpp +++ b/test/ReduceTests.cpp @@ -32,7 +32,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollReduce}; - std::vector const dataTypes = {ncclFloat16, ncclFloat64, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat16, ncclFloat64, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclMin}; std::vector const roots = {0}; std::vector const numElements = {393216}; @@ -70,7 +70,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollReduce}; - std::vector const dataTypes = {ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclBfloat16, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclMax}; std::vector const roots = {0}; std::vector const numElements = {393216}; diff --git a/test/ScatterTests.cpp b/test/ScatterTests.cpp index d7bd7ab083..80ea4d2bd6 100644 --- a/test/ScatterTests.cpp +++ b/test/ScatterTests.cpp @@ -32,7 +32,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollScatter}; - std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2}; + std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFloat8e4m3, ncclFloat8e5m2}; std::vector const redOps = {ncclSum}; std::vector const roots = {1}; std::vector const numElements = {24658}; diff --git a/test/StandaloneTests.cpp b/test/StandaloneTests.cpp index 9021faa72e..35d4814c3f 100644 --- a/test/StandaloneTests.cpp +++ b/test/StandaloneTests.cpp @@ -223,7 +223,9 @@ namespace RcclUnitTesting NCCLCHECK(ncclGroupStart()); for (int rank = 0; rank < numRanks; rank++) NCCLCHECK(ncclAllReduce(gpuInput[rank], gpuOutput[rank], N, ncclInt, ncclSum, comms[rank], stream[rank])); - NCCLCHECK(ncclGroupEnd()); + ncclResult_t res = ncclGroupEnd(); + + if (res != ncclSuccess) continue; const auto start = Clock::now(); diff --git a/test/common/CollectiveArgs.cpp b/test/common/CollectiveArgs.cpp index 2c97a76e68..ab24677880 100644 --- a/test/common/CollectiveArgs.cpp +++ b/test/common/CollectiveArgs.cpp @@ -195,18 +195,18 @@ namespace RcclUnitTesting scalarsPerRank.Attach(scalarsPerRank.ptr); switch (this->dataType) { - case ncclInt8: ss << scalarsPerRank.I1[this->globalRank]; break; - case ncclUint8: ss << scalarsPerRank.U1[this->globalRank]; break; - case ncclInt32: ss << scalarsPerRank.I4[this->globalRank]; break; - case ncclUint32: ss << scalarsPerRank.U4[this->globalRank]; break; - case ncclInt64: ss << scalarsPerRank.I8[this->globalRank]; break; - case ncclUint64: ss << scalarsPerRank.U8[this->globalRank]; break; - case ncclFp8E4M3: ss << scalarsPerRank.F1[this->globalRank]; break; - case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break; - case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break; - case ncclFp8E5M2: ss << scalarsPerRank.B1[this->globalRank]; break; - case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break; - default: ss << "(UNKNOWN)"; + case ncclInt8: ss << scalarsPerRank.I1[this->globalRank]; break; + case ncclUint8: ss << scalarsPerRank.U1[this->globalRank]; break; + case ncclInt32: ss << scalarsPerRank.I4[this->globalRank]; break; + case ncclUint32: ss << scalarsPerRank.U4[this->globalRank]; break; + case ncclInt64: ss << scalarsPerRank.I8[this->globalRank]; break; + case ncclUint64: ss << scalarsPerRank.U8[this->globalRank]; break; + case ncclFloat8e4m3: ss << scalarsPerRank.F1[this->globalRank]; break; + case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break; + case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break; + case ncclFloat8e5m2: ss << scalarsPerRank.B1[this->globalRank]; break; + case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break; + default: ss << "(UNKNOWN)"; } ss << " "; } diff --git a/test/common/CollectiveArgs.hpp b/test/common/CollectiveArgs.hpp index 182f76325a..aa497ebf01 100644 --- a/test/common/CollectiveArgs.hpp +++ b/test/common/CollectiveArgs.hpp @@ -54,8 +54,8 @@ namespace RcclUnitTesting "ncclFloat32", "ncclFloat64", "ncclBfloat16", - "ncclFp8E4M3", - "ncclFp8E5M2" + "ncclFloat8e4m3", + "ncclFloat8e5m2" }; char const ncclRedOpNames[ncclNumOps][32] = diff --git a/test/common/EnvVars.cpp b/test/common/EnvVars.cpp index 50d4661055..edcda6f457 100644 --- a/test/common/EnvVars.cpp +++ b/test/common/EnvVars.cpp @@ -282,8 +282,8 @@ namespace RcclUnitTesting dataTypes.push_back(ncclFloat32); dataTypes.push_back(ncclFloat64); dataTypes.push_back(ncclBfloat16); - dataTypes.push_back(ncclFp8E4M3); - dataTypes.push_back(ncclFp8E5M2); + dataTypes.push_back(ncclFloat8e4m3); + dataTypes.push_back(ncclFloat8e5m2); } // Build list of possible # GPU ranks based on env vars diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index facf60b342..1a300b1a15 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -14,8 +14,8 @@ namespace RcclUnitTesting { case ncclInt8: return 1; case ncclUint8: return 1; - case ncclFp8E4M3:return 1; - case ncclFp8E5M2:return 1; + case ncclFloat8e4m3:return 1; + case ncclFloat8e5m2:return 1; case ncclInt32: return 4; case ncclUint32: return 4; case ncclInt64: return 8; @@ -150,7 +150,7 @@ namespace RcclUnitTesting { // Due to floating-point math not being commutative, the ordering in which ranks are added will matter. // For lower-precision data types, we initialize all ranks to the same value to avoid this - int valueI = (dataType == ncclFp8E4M3 || dataType == ncclFp8E5M2)? (i % 16) :(globalRank + i) % 256; + int valueI = (dataType == ncclFloat8e4m3 || dataType == ncclFloat8e5m2)? (i % 16) :(globalRank + i) % 256; double valueF = 1.0L/((double)valueI+1.0L); temp.Set(dataType, i, valueI, valueF); } @@ -179,11 +179,11 @@ namespace RcclUnitTesting case ncclUint32: U4[idx] = valueI; break; case ncclInt64: I8[idx] = valueI; break; case ncclUint64: U8[idx] = valueI; break; - case ncclFp8E4M3: F1[idx] = rccl_float8(valueF); break; + case ncclFloat8e4m3: F1[idx] = rccl_float8(valueF); break; case ncclFloat16: F2[idx] = __float2half(static_cast(valueF)); break; case ncclFloat32: F4[idx] = valueF; break; case ncclFloat64: F8[idx] = valueF; break; - case ncclFp8E5M2: B1[idx] = rccl_bfloat8(valueF); break; + case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(valueF); break; case ncclBfloat16: B2[idx] = hip_bfloat16(static_cast(valueF)); break; default: ERROR("Unsupported datatype\n"); @@ -202,11 +202,11 @@ namespace RcclUnitTesting case ncclUint32: valueI = U4[idx]; break; case ncclInt64: valueI = I8[idx]; break; case ncclUint64: valueI = U8[idx]; break; - case ncclFp8E4M3: valueF = float(F1[idx]); break; + case ncclFloat8e4m3: valueF = float(F1[idx]); break; case ncclFloat16: valueF = __half2float(F2[idx]); break; case ncclFloat32: valueF = F4[idx]; break; case ncclFloat64: valueF = F8[idx]; break; - case ncclFp8E5M2: valueF = float(B1[idx]); break; + case ncclFloat8e5m2: valueF = float(B1[idx]); break; case ncclBfloat16: valueF = B2[idx]; break; default: ERROR("Unsupported datatype\n"); @@ -234,11 +234,11 @@ namespace RcclUnitTesting case ncclUint32: U4[idx] *= scalarsPerRank.U4[rank]; break; case ncclInt64: I8[idx] *= scalarsPerRank.I8[rank]; break; case ncclUint64: U8[idx] *= scalarsPerRank.U8[rank]; break; - case ncclFp8E4M3: F1[idx] = rccl_float8(F1[idx] * scalarsPerRank.F1[rank]); break; + case ncclFloat8e4m3: F1[idx] = rccl_float8(F1[idx] * scalarsPerRank.F1[rank]); break; case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx]) * __half2float(scalarsPerRank.F2[rank])); break; case ncclFloat32: F4[idx] *= scalarsPerRank.F4[rank]; break; case ncclFloat64: F8[idx] *= scalarsPerRank.F8[rank]; break; - case ncclFp8E5M2: B1[idx] = rccl_bfloat8(B1[idx] * scalarsPerRank.B1[rank]); break; + case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(B1[idx] * scalarsPerRank.B1[rank]); break; case ncclBfloat16: B2[idx] *= scalarsPerRank.B2[rank]; break; default: ERROR("Unsupported datatype\n"); @@ -269,11 +269,11 @@ namespace RcclUnitTesting case ncclUint32: U4[idx] = ReduceOp(op, U4[idx], inputCpu.U4[idx]); break; case ncclInt64: I8[idx] = ReduceOp(op, I8[idx], inputCpu.I8[idx]); break; case ncclUint64: U8[idx] = ReduceOp(op, U8[idx], inputCpu.U8[idx]); break; - case ncclFp8E4M3: F1[idx] = rccl_float8(ReduceOp(op, float(F1[idx]), float(inputCpu.F1[idx]))); break; + case ncclFloat8e4m3: F1[idx] = rccl_float8(ReduceOp(op, float(F1[idx]), float(inputCpu.F1[idx]))); break; case ncclFloat16: F2[idx] = __float2half(ReduceOp(op, __half2float(F2[idx]), __half2float(inputCpu.F2[idx]))); break; case ncclFloat32: F4[idx] = ReduceOp(op, F4[idx], inputCpu.F4[idx]); break; case ncclFloat64: F8[idx] = ReduceOp(op, F8[idx], inputCpu.F8[idx]); break; - case ncclFp8E5M2: B1[idx] = rccl_bfloat8(ReduceOp(op, float(B1[idx]), float(inputCpu.B1[idx]))); break; + case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(ReduceOp(op, float(B1[idx]), float(inputCpu.B1[idx]))); break; case ncclBfloat16: B2[idx] = ReduceOp(op, B2[idx], inputCpu.B2[idx]); break; default: ERROR("Unsupported datatype\n"); @@ -298,11 +298,11 @@ namespace RcclUnitTesting case ncclUint32: U4[idx] /= divisor; break; case ncclInt64: I8[idx] /= divisor; break; case ncclUint64: U8[idx] /= divisor; break; - case ncclFp8E4M3: F1[idx] = (rccl_float8((float)(F1[idx]) / divisor)); break; + case ncclFloat8e4m3: F1[idx] = (rccl_float8((float)(F1[idx]) / divisor)); break; case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx])/divisor); break; case ncclFloat32: F4[idx] /= divisor; break; case ncclFloat64: F8[idx] /= divisor; break; - case ncclFp8E5M2: B1[idx] = (rccl_bfloat8((float)(B1[idx]) / divisor)); break; + case ncclFloat8e5m2: B1[idx] = (rccl_bfloat8((float)(B1[idx]) / divisor)); break; case ncclBfloat16: B2[idx] = (hip_bfloat16((float)(B2[idx]) / divisor)); break; default: ERROR("Unsupported datatype\n"); @@ -330,11 +330,11 @@ namespace RcclUnitTesting case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break; case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break; case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break; - case ncclFp8E4M3: isMatch = (fabs(float(F1[idx]) - float(expected.F1[idx])) < 9e-2); break; + case ncclFloat8e4m3: isMatch = (fabs(float(F1[idx]) - float(expected.F1[idx])) < 9e-2); break; case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break; case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break; case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break; - case ncclFp8E5M2: isMatch = (fabs(float(B1[idx]) - float(expected.B1[idx])) < 9e-2); break; + case ncclFloat8e5m2: isMatch = (fabs(float(B1[idx]) - float(expected.B1[idx])) < 9e-2); break; case ncclBfloat16: isMatch = (fabs((float)B2[idx] - (float)expected.B2[idx]) < 9e-2); break; default: ERROR("Unsupported datatype\n"); @@ -359,16 +359,16 @@ namespace RcclUnitTesting ERROR("Expected output: %ld. Actual output: %ld at index %lu\n", expected.I8[idx], I8[idx], idx); break; case ncclUint64: ERROR("Expected output: %lu. Actual output: %lu at index %lu\n", expected.U8[idx], U8[idx], idx); break; - case ncclFp8E4M3: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); break; + case ncclFloat8e4m3: + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); case ncclFloat16: ERROR("Expected output: %f. Actual output: %f at index %lu\n", __half2float(expected.F2[idx]), __half2float(F2[idx]), idx); break; case ncclFloat32: ERROR("Expected output: %f. Actual output: %f at index %lu\n", expected.F4[idx], F4[idx], idx); break; case ncclFloat64: ERROR("Expected output: %lf. Actual output: %lf at index %lu\n", expected.F8[idx], F8[idx], idx); break; - case ncclFp8E5M2: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); break; + case ncclFloat8e5m2: + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); case ncclBfloat16: ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B2[idx], (float)B2[idx], idx); break; default: @@ -393,11 +393,11 @@ namespace RcclUnitTesting case ncclUint32: ss << U4[i]; break; case ncclInt64: ss << I8[i]; break; case ncclUint64: ss << U8[i]; break; - case ncclFp8E4M3: ss << (float)F1[i]; break; + case ncclFloat8e4m3: ss << (float)F1[i]; break; case ncclFloat16: ss << __half2float(F2[i]); break; case ncclFloat32: ss << F4[i]; break; case ncclFloat64: ss << F8[i]; break; - case ncclFp8E5M2: ss << (float)B1[i]; break; + case ncclFloat8e5m2: ss << (float)B1[i]; break; case ncclBfloat16: ss << (float)B2[i]; break; default: break; } diff --git a/test/common/PtrUnion.hpp b/test/common/PtrUnion.hpp index 29d78a376a..75c1255d2b 100644 --- a/test/common/PtrUnion.hpp +++ b/test/common/PtrUnion.hpp @@ -44,10 +44,10 @@ namespace RcclUnitTesting int64_t* I8; // ncclInt64 uint64_t* U8; // ncclUint64 __half* F2; // ncclFloat16 - rccl_float8* F1; // ncclFp8E4M3 + rccl_float8* F1; // ncclFloat8e4m3 float* F4; // ncclFloat32 double* F8; // ncclFloat64 - rccl_bfloat8* B1; // ncclFp8E5M2 + rccl_bfloat8* B1; // ncclFloat8e5m2 hip_bfloat16* B2; // ncclBfloat16 constexpr PtrUnion() : ptr(nullptr) {} diff --git a/test/common/TestBed.cpp b/test/common/TestBed.cpp index 5ea0efce11..700a572004 100644 --- a/test/common/TestBed.cpp +++ b/test/common/TestBed.cpp @@ -697,8 +697,8 @@ namespace RcclUnitTesting { //Skipping AllReduce FP8 test on 9 to 16 ranks (gfx90a). if(ev.isGfx90 && numRanks > 8 && funcTypes[ftIdx] == ncclCollAllReduce - && (dataTypes[dtIdx] == ncclFp8E4M3 - || dataTypes[dtIdx] == ncclFp8E5M2)) + && (dataTypes[dtIdx] == ncclFloat8e4m3 + || dataTypes[dtIdx] == ncclFloat8e5m2)) { continue; } diff --git a/tools/RcclReplayer/rcclReplayer.hpp b/tools/RcclReplayer/rcclReplayer.hpp index 486826d94e..81934e1d86 100644 --- a/tools/RcclReplayer/rcclReplayer.hpp +++ b/tools/RcclReplayer/rcclReplayer.hpp @@ -112,10 +112,10 @@ union PtrUnion int64_t* I8; // ncclInt64 uint64_t* U8; // ncclUint64 __half* F2; // ncclFloat16 - rccl_float8* F1; // ncclFp8E4M3 + rccl_float8* F1; // ncclFloat8e4m3 float* F4; // ncclFloat32 double* F8; // ncclFloat64 - rccl_bfloat8* B1; // ncclFp8E5M2 + rccl_bfloat8* B1; // ncclFloat8e5m2 hip_bfloat16* B2; // ncclBfloat16 constexpr PtrUnion() : ptr(nullptr) {} @@ -176,8 +176,8 @@ std::string DataTypeToName(ncclDataType_t const dataType) case ncclFloat32: return "Float32"; case ncclFloat64: return "Float64"; case ncclBfloat16: return "Bfloat16"; - case ncclFp8E4M3: return "Fp8E4M3"; - case ncclFp8E5M2: return "Fp8E5M2"; + case ncclFloat8e4m3: return "Fp8E4M3"; + case ncclFloat8e5m2: return "Fp8E5M2"; default: printf("Unsupported datatype (%d)\n", dataType); exit(0); @@ -197,8 +197,8 @@ size_t DataTypeToBytes(ncclDataType_t const dataType) case ncclFloat32: return 4; case ncclFloat64: return 8; case ncclBfloat16: return 2; - case ncclFp8E4M3: return 1; - case ncclFp8E5M2: return 1; + case ncclFloat8e4m3: return 1; + case ncclFloat8e5m2: return 1; default: printf("Unsupported datatype (%s)\n", DataTypeToName(dataType).c_str()); exit(0); @@ -239,11 +239,11 @@ void SetPtr(PtrUnion& ptrUnion, ncclDataType_t const dataType, int const idx, in case ncclUint32: ptrUnion.U4[idx] = valueI; break; case ncclInt64: ptrUnion.I8[idx] = valueI; break; case ncclUint64: ptrUnion.U8[idx] = valueI; break; - case ncclFp8E4M3: ptrUnion.F1[idx] = rccl_float8(valueF); break; + case ncclFloat8e4m3: ptrUnion.F1[idx] = rccl_float8(valueF); break; case ncclFloat16: ptrUnion.F2[idx] = __float2half(static_cast(valueF)); break; case ncclFloat32: ptrUnion.F4[idx] = valueF; break; case ncclFloat64: ptrUnion.F8[idx] = valueF; break; - case ncclFp8E5M2: ptrUnion.B1[idx] = rccl_bfloat8(valueF); break; + case ncclFloat8e5m2: ptrUnion.B1[idx] = rccl_bfloat8(valueF); break; case ncclBfloat16: ptrUnion.B2[idx] = hip_bfloat16(static_cast(valueF)); break; default: printf("Unsupported datatype (%s)\n", DataTypeToName(dataType).c_str()); @@ -265,11 +265,11 @@ bool IsEqual(PtrUnion const& actual, PtrUnion const& expected, ncclDataType_t co case ncclUint32: isMatch = (actual.U4[idx] == expected.U4[idx]); break; case ncclInt64: isMatch = (actual.I8[idx] == expected.I8[idx]); break; case ncclUint64: isMatch = (actual.U8[idx] == expected.U8[idx]); break; - case ncclFp8E4M3: isMatch = (fabs(float(actual.F1[idx]) - float(expected.F1[idx])) < 9e-2); break; + case ncclFloat8e4m3: isMatch = (fabs(float(actual.F1[idx]) - float(expected.F1[idx])) < 9e-2); break; case ncclFloat16: isMatch = (fabs(__half2float(actual.F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break; case ncclFloat32: isMatch = (fabs(actual.F4[idx] - expected.F4[idx]) < 1e-5); break; case ncclFloat64: isMatch = (fabs(actual.F8[idx] - expected.F8[idx]) < 1e-12); break; - case ncclFp8E5M2: isMatch = (fabs(float(actual.B1[idx]) - float(expected.B1[idx])) < 9e-2); break; + case ncclFloat8e5m2: isMatch = (fabs(float(actual.B1[idx]) - float(expected.B1[idx])) < 9e-2); break; case ncclBfloat16: isMatch = (fabs((float)actual.B2[idx] - (float)expected.B2[idx]) < 9e-2); break; default: printf("Unsupported datatype (%s)\n", DataTypeToName(dataType).c_str()); @@ -290,7 +290,7 @@ bool IsEqual(PtrUnion const& actual, PtrUnion const& expected, ncclDataType_t co printf("[Error Rank = %d] Expected output: %ld. Actual output: %ld at index %lu\n", globalRank, expected.I8[idx], actual.I8[idx], idx); break; case ncclUint64: printf("[Error Rank = %d] Expected output: %lu. Actual output: %lu at index %lu\n", globalRank, expected.U8[idx], actual.U8[idx], idx); break; - case ncclFp8E4M3: + case ncclFloat8e4m3: printf("[Error Rank = %d] Expected output: %f. Actual output: %f at index %lu\n", globalRank, (float)expected.F1[idx], (float)actual.F1[idx], idx); break; case ncclFloat16: printf("[Error Rank = %d] Expected output: %f. Actual output: %f at index %lu\n", globalRank, __half2float(expected.F2[idx]), __half2float(actual.F2[idx]), idx); break; @@ -298,7 +298,7 @@ bool IsEqual(PtrUnion const& actual, PtrUnion const& expected, ncclDataType_t co printf("[Error Rank = %d] Expected output: %f. Actual output: %f at index %lu\n", globalRank, expected.F4[idx], actual.F4[idx], idx); break; case ncclFloat64: printf("[Error Rank = %d] Expected output: %lf. Actual output: %lf at index %lu\n", globalRank, expected.F8[idx], actual.F8[idx], idx); break; - case ncclFp8E5M2: + case ncclFloat8e5m2: printf("[Error Rank = %d] Expected output: %f. Actual output: %f at index %lu\n", globalRank, (float)expected.B1[idx], (float)actual.B1[idx], idx); break; case ncclBfloat16: printf("[Error Rank = %d] Expected output: %f. Actual output: %f at index %lu\n", globalRank, (float)expected.B2[idx], (float)actual.B2[idx], idx); break; @@ -340,11 +340,11 @@ void Reduce(PtrUnion& ptrUnion, PtrUnion const& otherPtrUnion, size_t const numE case ncclUint32: ptrUnion.U4[idx] = ReduceOp(op, ptrUnion.U4[idx], otherPtrUnion.U4[idx]); break; case ncclInt64: ptrUnion.I8[idx] = ReduceOp(op, ptrUnion.I8[idx], otherPtrUnion.I8[idx]); break; case ncclUint64: ptrUnion.U8[idx] = ReduceOp(op, ptrUnion.U8[idx], otherPtrUnion.U8[idx]); break; - case ncclFp8E4M3: ptrUnion.F1[idx] = rccl_float8(ReduceOp(op, float(ptrUnion.F1[idx]), float(otherPtrUnion.F1[idx]))); break; + case ncclFloat8e4m3: ptrUnion.F1[idx] = rccl_float8(ReduceOp(op, float(ptrUnion.F1[idx]), float(otherPtrUnion.F1[idx]))); break; case ncclFloat16: ptrUnion.F2[idx] = __float2half(ReduceOp(op, __half2float(ptrUnion.F2[idx]), __half2float(otherPtrUnion.F2[idx]))); break; case ncclFloat32: ptrUnion.F4[idx] = ReduceOp(op, ptrUnion.F4[idx], otherPtrUnion.F4[idx]); break; case ncclFloat64: ptrUnion.F8[idx] = ReduceOp(op, ptrUnion.F8[idx], otherPtrUnion.F8[idx]); break; - case ncclFp8E5M2: ptrUnion.B1[idx] = rccl_bfloat8(ReduceOp(op, float(ptrUnion.B1[idx]), float(otherPtrUnion.B1[idx]))); break; + case ncclFloat8e5m2: ptrUnion.B1[idx] = rccl_bfloat8(ReduceOp(op, float(ptrUnion.B1[idx]), float(otherPtrUnion.B1[idx]))); break; case ncclBfloat16: ptrUnion.B2[idx] = ReduceOp(op, ptrUnion.B2[idx], otherPtrUnion.B2[idx]); break; default: printf("Unsupported datatype (%s)\n", DataTypeToName(dataType).c_str()); @@ -365,11 +365,11 @@ void DivideByInt(PtrUnion& ptrUnion, ncclDataType_t const dataType, size_t const case ncclUint32: ptrUnion.U4[idx] /= divisor; break; case ncclInt64: ptrUnion.I8[idx] /= divisor; break; case ncclUint64: ptrUnion.U8[idx] /= divisor; break; - case ncclFp8E4M3: ptrUnion.F1[idx] = (rccl_float8((float)(ptrUnion.F1[idx]) / divisor)); break; + case ncclFloat8e4m3: ptrUnion.F1[idx] = (rccl_float8((float)(ptrUnion.F1[idx]) / divisor)); break; case ncclFloat16: ptrUnion.F2[idx] = __float2half(__half2float(ptrUnion.F2[idx])/divisor); break; case ncclFloat32: ptrUnion.F4[idx] /= divisor; break; case ncclFloat64: ptrUnion.F8[idx] /= divisor; break; - case ncclFp8E5M2: ptrUnion.B1[idx] = (rccl_bfloat8((float)(ptrUnion.B1[idx]) / divisor)); break; + case ncclFloat8e5m2: ptrUnion.B1[idx] = (rccl_bfloat8((float)(ptrUnion.B1[idx]) / divisor)); break; case ncclBfloat16: ptrUnion.B2[idx] = (hip_bfloat16((float)(ptrUnion.B2[idx]) / divisor)); break; default: printf("Unsupported datatype (%s)\n", DataTypeToName(dataType).c_str()); diff --git a/tools/ib-test/include/nccl.h b/tools/ib-test/include/nccl.h index 2c86c33269..6e95d87d16 100755 --- a/tools/ib-test/include/nccl.h +++ b/tools/ib-test/include/nccl.h @@ -121,8 +121,8 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, ncclBfloat16 = 9, - ncclFp8E4M3 = 10, - ncclFp8E5M2 = 11, + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, ncclNumTypes = 12 } ncclDataType_t; /* diff --git a/tools/topo_expl/include/nccl.h b/tools/topo_expl/include/nccl.h index 76b74289f3..ae94d20d81 100644 --- a/tools/topo_expl/include/nccl.h +++ b/tools/topo_expl/include/nccl.h @@ -389,13 +389,10 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, ncclBfloat16 = 9, -#if defined(RCCL_FLOAT8) - ncclFp8E4M3 = 10, - ncclFp8E5M2 = 11, - ncclNumTypes = 12 } ncclDataType_t; -#else - ncclNumTypes = 10 } ncclDataType_t; -#endif + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, + ncclNumTypes = 12 +} ncclDataType_t; /*! @} */ /*! @defgroup rccl_api_custom_redop Custom Reduction Operator