diff --git a/projects/rccl/ext-net/example/nccl/net.h b/projects/rccl/ext-net/example/nccl/net.h index f5101aec8b..2f455c60f1 100644 --- a/projects/rccl/ext-net/example/nccl/net.h +++ b/projects/rccl/ext-net/example/nccl/net.h @@ -17,13 +17,14 @@ #define NCCL_PTR_DMABUF 0x4 // Maximum number of requests per comm object -#define NCCL_NET_MAX_REQUESTS 8 +#define NCCL_NET_MAX_REQUESTS 32 typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel; typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys; typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); +#include "net_v8.h" #include "net_v7.h" #include "net_v6.h" #include "net_v5.h" diff --git a/projects/rccl/ext-net/example/nccl/net_device.h b/projects/rccl/ext-net/example/nccl/net_device.h index 32cc519ded..a0b84c7656 100644 --- a/projects/rccl/ext-net/example/nccl/net_device.h +++ b/projects/rccl/ext-net/example/nccl/net_device.h @@ -26,6 +26,7 @@ typedef struct { int needsProxyProgress; } ncclNetDeviceHandle_v7_t; +typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t; typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_t; #endif diff --git a/projects/rccl/ext-net/example/nccl/net_v6.h b/projects/rccl/ext-net/example/nccl/net_v6.h index 21379d3d11..fffaf8c62d 100644 --- a/projects/rccl/ext-net/example/nccl/net_v6.h +++ b/projects/rccl/ext-net/example/nccl/net_v6.h @@ -5,6 +5,8 @@ #ifndef NCCL_NET_V6_H_ #define NCCL_NET_V6_H_ +#define NCCL_NET_MAX_REQUESTS_V6 8 + typedef struct { char* name; // Used mostly for logging. char* pciPath; // Path to the PCI device in /sys. diff --git a/projects/rccl/ext-net/example/nccl/net_v7.h b/projects/rccl/ext-net/example/nccl/net_v7.h index 77d6cb73ee..d607095de3 100644 --- a/projects/rccl/ext-net/example/nccl/net_v7.h +++ b/projects/rccl/ext-net/example/nccl/net_v7.h @@ -22,8 +22,6 @@ typedef struct { int netDeviceVersion; // Version number for network offload } ncclNetProperties_v7_t; -typedef ncclNetProperties_v7_t ncclNetProperties_t; - typedef struct { // Name of the network (mainly for logs) const char* name; diff --git a/projects/rccl/ext-net/example/nccl/net_v8.h b/projects/rccl/ext-net/example/nccl/net_v8.h new file mode 100644 index 0000000000..3161558205 --- /dev/null +++ b/projects/rccl/ext-net/example/nccl/net_v8.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved. + */ + +#ifndef NCCL_NET_V8_H_ +#define NCCL_NET_V8_H_ + +#include "net_device.h" + +typedef struct { + char* name; // Used mostly for logging. + char* pciPath; // Path to the PCI device in /sys. + uint64_t guid; // Unique identifier for the NIC chip. Important for + // cards with multiple PCI functions (Physical or virtual). + int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF] + int regIsGlobal; // regMr is not tied to a particular comm + int speed; // Port speed in Mbps. + int port; // Port number. + float latency; // Network latency + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives. + ncclNetDeviceType netDeviceType; // Network offload type + int netDeviceVersion; // Version number for network offload +} ncclNetProperties_v8_t; + +typedef ncclNetProperties_v8_t ncclNetProperties_t; + +typedef struct { + // Name of the network (mainly for logs) + const char* name; + // Initialize the network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create a connection. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Connect to a handle and return a sending comm object for that peer. + // This call must not block for the connection to be established, and instead + // should return successfully with sendComm == NULL with the expectation that + // it will be called again until sendComm != NULL. + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm); + // Finalize connection establishment after remote peer has called connect. + // This call must not block for the connection to be established, and instead + // should return successfully with recvComm == NULL with the expectation that + // it will be called again until recvComm != NULL. + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm); + // Register/Deregister memory. Comm can be either a sendComm or a recvComm. + // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + // Asynchronous send to a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request); + // Asynchronous recv from a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* sizes); + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); +} ncclNet_v8_t; + +#endif // end include guard diff --git a/projects/rccl/ext-net/example/plugin.c b/projects/rccl/ext-net/example/plugin.c index cc860b0067..128dde9b47 100644 --- a/projects/rccl/ext-net/example/plugin.c +++ b/projects/rccl/ext-net/example/plugin.c @@ -15,15 +15,37 @@ __hidden ncclResult_t pluginDevices(int* ndev) { *ndev = 0; return ncclSuccess; __hidden ncclResult_t pluginPciPath(int dev, char** path) { return ncclInternalError; } __hidden ncclResult_t pluginPtrSupport(int dev, int* supportedTypes) { return ncclInternalError; } -__hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v7_t* props) { - //pluginPciPath(dev, &props.pciPath); - //pluginPtrSupport(dev, &props.ptrSupport); +__hidden ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v8_t* props) { + // Below are default values, if unsure don't change. + + props->name = "Example"; + // Fill for proper topology detection, e.g. /sys/devices/pci0000:00/0000:00:10.0/0000:0b:00.0 + props->pciPath = NULL; + // Only used to detect NICs with multiple PCI attachments. + props->guid = 0; + // Add NCCL_PTR_CUDA if GPU Direct RDMA is supported and regMr can take CUDA pointers. + props->ptrSupport = NCCL_PTR_HOST; + // If you regMr has a fast registration cache, set to 1. If set to 0, user buffer registration may be disabled. + props->regIsGlobal = 0; + // Speed in *Mbps*. 100000 means 100G + props->speed = 100000; + // Port number, used in conjunction with guid + props->port = 0; + // Custom latency (used to help tuning if latency is high. If set to 0, use default NCCL values. + props->latency = 0; + // Maximum number of comm objects we can create. + props->maxComms = 1024*1024; + // Maximum number of receive operations taken by irecv(). + props->maxRecvs = 1; + // Coupling with NCCL network device-side code. + props->netDeviceType = 0; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; return ncclInternalError; } __hidden ncclResult_t pluginListen(int dev, void* handle, void** listenComm) { return ncclInternalError; } -__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v7_t** sendDevComm) { return ncclInternalError; } -__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_v7_t** recvDevComm) { return ncclInternalError; } -__hidden ncclResult_t pluginRegMr(void* collComm, void* data, int size, int type, void** mhandle) { return ncclInternalError; } +__hidden ncclResult_t pluginConnect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm) { return ncclInternalError; } +__hidden ncclResult_t pluginAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm) { return ncclInternalError; } +__hidden ncclResult_t pluginRegMr(void* collComm, void* data, size_t size, int type, void** mhandle) { return ncclInternalError; } __hidden ncclResult_t pluginRegMrDmaBuf(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { return ncclInternalError; } __hidden ncclResult_t pluginDeregMr(void* collComm, void* mhandle) { return ncclInternalError;} __hidden ncclResult_t pluginIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { return ncclInternalError; } @@ -38,7 +60,7 @@ __hidden ncclResult_t pluginGetDeviceMr(void* comm, void* mhandle, void** dptr_m #define PLUGIN_NAME "Plugin" -const ncclNet_v7_t ncclNetPlugin_v7 = { +const ncclNet_v8_t ncclNetPlugin_v8 = { .name = PLUGIN_NAME, .init = pluginInit, .devices = pluginDevices, @@ -60,10 +82,62 @@ const ncclNet_v7_t ncclNetPlugin_v7 = { .irecvConsumed = pluginIrecvConsumed, }; -__hidden ncclResult_t pluginGetProperties_v6(int dev, ncclNetProperties_v6_t* props) { - //pluginPciPath(dev, &props.pciPath); - //pluginPtrSupport(dev, &props.ptrSupport); - return ncclInternalError; +__hidden ncclResult_t pluginGetProperties_v7(int dev, ncclNetProperties_v7_t* props_v7) { + ncclNetProperties_t props; + ncclResult_t ret = pluginGetProperties(dev, &props); + if (ret != ncclSuccess) return ret; + props_v7->name = props.name; + props_v7->pciPath = props.pciPath; + props_v7->guid = props.guid; + props_v7->ptrSupport = props.ptrSupport; + props_v7->speed = props.speed; + props_v7->port = props.port; + props_v7->maxComms = props.maxComms; + props_v7->maxRecvs = props.maxRecvs; + props_v7->netDeviceType = props.netDeviceType; + props_v7->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; +} + +__hidden ncclResult_t pluginRegMr_v7(void* collComm, void* data, int size, int type, void** mhandle) { + return pluginRegMr(collComm, data, size, type, mhandle); +} + +const ncclNet_v7_t ncclNetPlugin_v7 = { + .name = PLUGIN_NAME, + .init = pluginInit, + .devices = pluginDevices, + .getProperties = pluginGetProperties_v7, + .listen = pluginListen, + .connect = pluginConnect, + .accept = pluginAccept, + .regMr = pluginRegMr_v7, + .regMrDmaBuf = pluginRegMrDmaBuf, + .deregMr = pluginDeregMr, + .isend = pluginIsend, + .irecv = pluginIrecv, + .iflush = pluginIflush, + .test = pluginTest, + .closeSend = pluginCloseSend, + .closeRecv = pluginCloseRecv, + .closeListen = pluginCloseListen, + .getDeviceMr = pluginGetDeviceMr, + .irecvConsumed = pluginIrecvConsumed, +}; + +__hidden ncclResult_t pluginGetProperties_v6(int dev, ncclNetProperties_v6_t* props_v6) { + ncclNetProperties_t props; + ncclResult_t ret = pluginGetProperties(dev, &props); + if (ret != ncclSuccess) return ret; + props_v6->name = props.name; + props_v6->pciPath = props.pciPath; + props_v6->guid = props.guid; + props_v6->ptrSupport = props.ptrSupport; + props_v6->speed = props.speed; + props_v6->port = props.port; + props_v6->maxComms = props.maxComms; + props_v6->maxRecvs = props.maxRecvs; + return ncclSuccess; } __hidden ncclResult_t pluginConnect_v6(int dev, void* handle, void** sendComm) { return ncclInternalError; } @@ -77,7 +151,7 @@ const ncclNet_v6_t ncclNetPlugin_v6 = { .listen = pluginListen, .connect = pluginConnect_v6, .accept = pluginAccept_v6, - .regMr = pluginRegMr, + .regMr = pluginRegMr_v7, .regMrDmaBuf = pluginRegMrDmaBuf, .deregMr = pluginDeregMr, .isend = pluginIsend, @@ -98,7 +172,7 @@ const ncclNet_v5_t ncclNetPlugin_v5 = { .listen = pluginListen, .connect = pluginConnect_v6, .accept = pluginAccept_v6, - .regMr = pluginRegMr, + .regMr = pluginRegMr_v7, .deregMr = pluginDeregMr, .isend = pluginIsend, .irecv = pluginIrecv, @@ -110,17 +184,17 @@ const ncclNet_v5_t ncclNetPlugin_v5 = { }; /* v4 Compat */ -static ncclResult_t pluginGetProperties_v4(int dev, ncclNetProperties_v4_t* props) { - ncclNetProperties_v6_t props_v6; - ncclResult_t ret = pluginGetProperties_v6(dev, &props_v6); +static ncclResult_t pluginGetProperties_v4(int dev, ncclNetProperties_v4_t* props_v4) { + ncclNetProperties_t props; + ncclResult_t ret = pluginGetProperties(dev, &props); if (ret != ncclSuccess) return ret; - props->name = props_v6.name; - props->pciPath = props_v6.pciPath; - props->guid = props_v6.guid; - props->ptrSupport = props_v6.ptrSupport; - props->speed = props_v6.speed; - props->port = props_v6.port; - props->maxComms = props_v6.maxComms; + props_v4->name = props.name; + props_v4->pciPath = props.pciPath; + props_v4->guid = props.guid; + props_v4->ptrSupport = props.ptrSupport; + props_v4->speed = props.speed; + props_v4->port = props.port; + props_v4->maxComms = props.maxComms; return ncclSuccess; } static ncclResult_t pluginIsend_v4(void *sendComm, void* data, int size, void *mhandle, void** request) { @@ -157,7 +231,7 @@ const ncclNet_v4_t ncclNetPlugin_v4 = { .listen = pluginListen, .connect = pluginConnect_v4, .accept = pluginAccept_v4, - .regMr = pluginRegMr, + .regMr = pluginRegMr_v7, .deregMr = pluginDeregMr, .isend = pluginIsend_v4, .irecv = pluginIrecv_v4, @@ -202,7 +276,7 @@ const ncclNet_v3_t ncclNetPlugin_v3 = { .listen = pluginListen_v3, .connect = pluginConnect_v3, .accept = pluginAccept_v4, - .regMr = pluginRegMr, + .regMr = pluginRegMr_v7, .deregMr = pluginDeregMr, .isend = pluginIsend_v4, .irecv = pluginIrecv_v4, @@ -223,7 +297,7 @@ const ncclNet_v2_t ncclNetPlugin_v2 = { .listen = pluginListen, .connect = pluginConnect_v4, .accept = pluginAccept_v4, - .regMr = pluginRegMr, + .regMr = pluginRegMr_v7, .deregMr = pluginDeregMr, .isend = pluginIsend_v4, .irecv = pluginIrecv_v4, diff --git a/projects/rccl/makefiles/version.mk b/projects/rccl/makefiles/version.mk index b383eebe80..ab4fd3c9f7 100644 --- a/projects/rccl/makefiles/version.mk +++ b/projects/rccl/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 19 -NCCL_PATCH := 4 +NCCL_MINOR := 20 +NCCL_PATCH := 3 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/projects/rccl/src/Makefile b/projects/rccl/src/Makefile index 7a1881d9d6..b254eac32c 100644 --- a/projects/rccl/src/Makefile +++ b/projects/rccl/src/Makefile @@ -10,7 +10,7 @@ include ../makefiles/version.mk INCEXPORTS := nccl.h nccl_net.h LIBSRCFILES := \ bootstrap.cc channel.cc collectives.cc debug.cc enqueue.cc group.cc \ - init.cc init_nvtx.cc net.cc proxy.cc transport.cc \ + init.cc init_nvtx.cc net.cc proxy.cc transport.cc register.cc \ $(wildcard graph/*.cc) \ $(wildcard misc/*.cc) \ $(wildcard transport/*.cc) diff --git a/projects/rccl/src/bootstrap.cc b/projects/rccl/src/bootstrap.cc index 0c8a338d6e..a1475d375e 100644 --- a/projects/rccl/src/bootstrap.cc +++ b/projects/rccl/src/bootstrap.cc @@ -221,6 +221,7 @@ struct bootstrapState { struct ncclSocket ringSendSocket; union ncclSocketAddress* peerCommAddresses; union ncclSocketAddress* peerProxyAddresses; + uint64_t* peerProxyAddressesUDS; struct unexConn* unexpectedConnections; int cudaDev; int rank; @@ -295,6 +296,7 @@ ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* // Create the service proxy NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks)); + NCCLCHECK(ncclCalloc(&state->peerProxyAddressesUDS, nranks)); // proxy is aborted through a message; don't set abortFlag NCCLCHECK(ncclCalloc(&proxySocket, 1)); @@ -302,7 +304,10 @@ ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* NCCLCHECK(ncclSocketListen(proxySocket)); NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank)); NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress))); - NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses)); + // cuMem UDS support + state->peerProxyAddressesUDS[rank] = getPidHash()+comm->commHash; + NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddressesUDS, sizeof(*state->peerProxyAddressesUDS))); + NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS)); TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks); @@ -355,8 +360,6 @@ ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm* for (int i = 0; i < nranks; ++i) { comm->topParentRanks[i] = parent->topParentRanks[parentRanks[i]]; } - comm->proxyState = parent->sharedRes->proxyState; - ncclAtomicRefCountIncrement(&parent->sharedRes->proxyState->refCount); } else { // Create the service proxy NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddresses, nranks), ret, fail); @@ -366,10 +369,14 @@ ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm* NCCLCHECKGOTO(ncclSocketGetAddr(proxySocket, &tmpAddr), ret, fail); memcpy(state->peerProxyAddresses + rank, &tmpAddr, sizeof(union ncclSocketAddress)); NCCLCHECKGOTO(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)), ret, fail); - NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses), ret, fail); + // cuMem UDS support + NCCLCHECKGOTO(ncclCalloc(&state->peerProxyAddressesUDS, nranks), ret, fail); + state->peerProxyAddressesUDS[rank] = getPidHash()+comm->commHash; + NCCLCHECKGOTO(bootstrapAllGather(state, state->peerProxyAddressesUDS, sizeof(*state->peerProxyAddressesUDS)), ret, fail); + NCCLCHECKGOTO(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS), ret, fail); } - INFO(NCCL_INIT, "bootstrapSplit: rank %d nranks %d color %d key %d prev %d next %d - DONE", rank, nranks, color, key, prev, next); + INFO(NCCL_INIT, "bootstrapSplit: comm %p parent %p rank %d nranks %d color %d key %d prev %d next %d - DONE", comm, parent, rank, nranks, color, key, prev, next); exit: return ret; @@ -568,7 +575,7 @@ ncclResult_t bootstrapClose(void* commState) { struct bootstrapState* state = (struct bootstrapState*)commState; if (state->unexpectedConnections != NULL) { unexpectedFree(state); - if (*state->abortFlag == 0) { + if (__atomic_load_n(state->abortFlag, __ATOMIC_RELAXED) == 0) { WARN("Unexpected connections are not empty"); return ncclInternalError; } @@ -592,6 +599,7 @@ ncclResult_t bootstrapAbort(void* commState) { NCCLCHECK(ncclSocketClose(&state->ringRecvSocket)); free(state->peerCommAddresses); free(state->peerProxyAddresses); + free(state->peerProxyAddressesUDS); free(state); return ncclSuccess; } diff --git a/projects/rccl/src/debug.cc b/projects/rccl/src/debug.cc index 63b3e5bc08..2771d1b708 100644 --- a/projects/rccl/src/debug.cc +++ b/projects/rccl/src/debug.cc @@ -191,6 +191,9 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file va_start(vargs, fmt); len += vsnprintf(buffer+len, sizeof(buffer)-len, fmt, vargs); va_end(vargs); + // vsnprintf may return len > sizeof(buffer) in the case of a truncated output. + // Rewind len so that we can replace the final \0 by \n + if (len > sizeof(buffer)) len = sizeof(buffer)-1; buffer[len++] = '\n'; fwrite(buffer, 1, len, ncclDebugFile); } diff --git a/projects/rccl/src/device/all_gather.h b/projects/rccl/src/device/all_gather.h index 0122499320..702eb97648 100644 --- a/projects/rccl/src/device/all_gather.h +++ b/projects/rccl/src/device/all_gather.h @@ -12,63 +12,50 @@ namespace { template __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclRing *ring = &ncclShmem.channel.ring; const int *ringRanks = ring->userRanks; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLGATHER_CHUNKSTEPS : 1)); - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. - const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); const int nranks = ncclShmem.comm.nRanks; - const ssize_t loopSize = nChannels*int(chunkSize); - const ssize_t size = args->count; + const size_t chunkCount = args->chunkCount; + const size_t channelCount = args->workCount; + const size_t gridOffset = args->workOffset; + const size_t count = args->count; + size_t offset; + size_t dataOffset; + int nelem; + int rankDest; T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 1, Proto, 0> prims (tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t realChunkSize; - if (Proto::Id == NCCL_PROTO_SIMPLE) { - realChunkSize = min(chunkSize, divUp(size-gridOffset,nChannels)); - realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); - } - else if (Proto::Id == NCCL_PROTO_LL) - realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize; - else if (Proto::Id == NCCL_PROTO_LL128) - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128); - realChunkSize = int(realChunkSize); - - ssize_t chunkOffset = gridOffset + int(bid*realChunkSize); - + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { /////////////// begin AllGather steps /////////////// - ssize_t offset; - int nelem = min(realChunkSize, size-chunkOffset); - int rankDest; + nelem = min(chunkCount, channelCount - elemOffset); + dataOffset = gridOffset + elemOffset; // step 0: push data to next GPU rankDest = ringRanks[0]; - offset = chunkOffset + rankDest * size; + offset = dataOffset + rankDest * count; - if (inputBuf + chunkOffset == outputBuf + offset) { // In place - prims.directSend(chunkOffset, offset, nelem); + if (inputBuf + dataOffset == outputBuf + offset) { // In place + prims.directSend(dataOffset, offset, nelem); } else { - prims.directCopySend(chunkOffset, offset, nelem); + prims.directCopySend(dataOffset, offset, nelem); } // k-2 steps: copy to next GPU for (int j=1; j struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nChannels = args->nChannels; struct ncclNvls* nvls = &ncclShmem.channel.nvls; - const ssize_t chunkSize = int(args->lastChunkSize); - const ssize_t size = args->count; - const ssize_t loopSize = nChannels*chunkSize; + const ssize_t count = args->count; const ssize_t rank = ncclShmem.comm.rank; + const size_t chunkCount = args->chunkCount; + size_t gridOffset = args->workOffset; + size_t channelCount = args->workCount; + size_t offset; + int nelem; const int nThreadsBcast = args->regUsed ? (NCCL_MAX_NTHREADS - WARP_SIZE) : 4 * WARP_SIZE; const int nThreadsGather = args->regUsed ? WARP_SIZE : NCCL_MAX_NTHREADS - nThreadsBcast; @@ -122,10 +110,10 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.gather(offset, nvls->nHeads * size, nelem, size, -1, 0); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); + prims.gather(offset, nvls->nHeads * count, nelem, count, -1, 0); } } else if (tid < tidEndBcast) { // Bcast through NVLS @@ -133,9 +121,9 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid - tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL, args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * chunkSize; - int nelem = min(chunkSize, size - offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.send(offset, nelem); } } @@ -150,7 +138,7 @@ struct RunWorkElement +struct RunWorkElement { + template + struct Scatterer { + struct ncclWorkElem* args; + ssize_t chunkSize; + ssize_t railGridOffset; + + template + __device__ __forceinline__ void operator()( + int tid, int tn, int slice, int maxSliceSize, + int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, int32_t* dstSizes + ) { + static_assert(SlicePerChunk==1, "require: SlicePerChunk==1"); + static_assert(MaxDsts<=1 || MaxSrcs<=1, "require: MaxDsts<=1 || MaxSrcs<=1"); + + struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; + int nNodes = ncclShmem.comm.nNodes; + int nRails = direct->nHeads; + int bid = args->bid; + char* inbuf = (char*)args->sendbuff; + char* outbuf = (char*)args->recvbuff; + ssize_t sizePerRank = args->count*sizeof(T); + bool inPlace = (inbuf == outbuf + ncclShmem.comm.rank*sizePerRank); + + ssize_t railAllBeg = min(railGridOffset + bid*chunkSize, nNodes*sizePerRank); + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + int railAllSize = railAllEnd - railAllBeg; + if (tid < nDsts) dstSizes[tid] = railAllSize; + + int src = 0; + int rail; + if (BcastSendNotRecv) { + rail = direct->headRank; + } else { + rail = direct->headRank+1; + if (rail == nRails) rail = 0; + } + do { + int node = railAllBeg/sizePerRank; + int railAllOffset = 0; + while (railAllOffset < railAllSize) { + ssize_t railOneBeg = node*sizePerRank; + ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t railOneOffset = (railAllBeg+railAllOffset) - railOneBeg; + int delta = min(railAllEnd, railOneEnd) - (railAllBeg+railAllOffset); + int rank = ncclShmem.comm.collNetDenseToUserRank[node*nRails + rail]; + ssize_t userOneBeg = rank*sizePerRank + railOneOffset; + int outIsDst = (inPlace && rank == ncclShmem.comm.rank) ? 0 : 1; + reduceCopy + (tid, tn, 0, nullptr, false, + /*nSrcs=*/1, [=]__device__(int s/*==0*/) -> void* { + return (char*)srcPtrs[src] + railAllOffset; + }, + /*nDsts=*/outIsDst+nDsts, [=]__device__(int d) -> void* { + return d < outIsDst ? outbuf + userOneBeg + : (char*)dstPtrs[d-outIsDst] + railAllOffset; + }, + delta); + railAllOffset += delta; + node += 1; + } + src += 1; + rail += 1; + if (rail == nRails) rail = 0; + } while (!BcastSendNotRecv && src < nRails-1); + } + }; + + __device__ __forceinline__ void run(ncclWorkElem *args) { + int tid = threadIdx.x; + const int nChannels = args->nChannels; + struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; + int const &nNodes = ncclShmem.comm.nNodes; + ssize_t chunkSize = int(args->chunkCount); + ssize_t const &sizePerRank = args->count; + + bool isMultiRail = (direct->nHeads > 1); + int nWarps1 = 1; + int nWarps2 = (isMultiRail ? 2 : 1); + int nWarps3 = (isMultiRail ? 2 : 0); + float denom = float(args->nWarps)/float(nWarps1+nWarps2+nWarps3); + nWarps3 = int(denom*nWarps3); + nWarps2 = int(denom*nWarps2); + nWarps1 = args->nWarps - (nWarps2+nWarps3); + + using Proto = ProtoSimple<1, 1>; + + int tn = nWarps1*WARP_SIZE; + if (tid < tn) { + // Phase 1: send to network + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, nullptr, &direct->out, args->sendbuff, nullptr, + /*redOpArg=*/0, 0*Proto::MaxGroupWidth, 1, 1); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + ssize_t railAllBeg = railGridOffset + args->bid*chunkSize; + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + ssize_t railOneBeg = ncclShmem.comm.node*sizePerRank; + ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t beg = max(railAllBeg, railOneBeg); + ssize_t end = min(railAllEnd, railOneEnd); + prims.send(beg-railOneBeg, max(ssize_t(0), end-beg)); + } + return; + } + tid -= tn; + + tn = nWarps2*WARP_SIZE; + if (tid < tn) { + // Phase 2: Recv network -> deposit output + send to bcast + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, &direct->out, direct->heads+1, nullptr, nullptr, + /*redOpArg=*/0, 1*Proto::MaxGroupWidth, 0, 0); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + Scatterer scat; + scat.args = args; + scat.chunkSize = chunkSize; + scat.railGridOffset = railGridOffset; + prims.process(scat); + } + return; + } + tid -= tn; + + tn = nWarps3*WARP_SIZE; + if (tid < tn) { + // Phase 3: Recv bcast -> deposit output + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, direct->heads+1, nullptr, nullptr, nullptr, + /*redOpArg=*/0, 2*Proto::MaxGroupWidth, 0, 0); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + Scatterer scat; + scat.args = args; + scat.chunkSize = chunkSize; + scat.railGridOffset = railGridOffset; + prims.process(scat); + } + return; + } + } +}; diff --git a/projects/rccl/src/device/all_reduce.h b/projects/rccl/src/device/all_reduce.h index bf37dfe962..75e2bed541 100644 --- a/projects/rccl/src/device/all_reduce.h +++ b/projects/rccl/src/device/all_reduce.h @@ -12,84 +12,69 @@ namespace { template __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclRing *ring = &ncclShmem.channel.ring; int ringIx = ring->index; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? ALLREDUCE_CHUNKSTEPS : 1)); + ssize_t chunkCount = args->chunkCount; const int nranks = ncclShmem.comm.nRanks; - const ssize_t loopSize = nChannels*nranks*chunkSize; - const ssize_t size = args->count; - - int minChunkSize; - if (Proto::Id == NCCL_PROTO_LL) - minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T)); - if (Proto::Id == NCCL_PROTO_LL128) { - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. - minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2; - } + const ssize_t loopCount = nranks * chunkCount; + ssize_t offset; + ssize_t gridOffset = args->workOffset; + ssize_t channelCount = args->workCount; + int nelem; + int chunk; Primitives, 1, Proto, 0> prims (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t realChunkSize; - if (Proto::Id == NCCL_PROTO_SIMPLE) { - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks)); - realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); - } - else - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*nranks*minChunkSize)*minChunkSize); - realChunkSize = int(realChunkSize); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + ssize_t remCount = channelCount - elemOffset; + ssize_t chunkOffset; + + if (remCount < loopCount) chunkCount = args->lastChunkCount; - auto calcOffset = [&]__device__(int chunk)->ssize_t { - if (Proto::Id == NCCL_PROTO_SIMPLE) - return gridOffset + bid*nranks*realChunkSize + chunk*realChunkSize; - else - return gridOffset + (chunk*nChannels + bid)*realChunkSize; - }; auto modRanks = [&]__device__(int r)->int { return r - (r >= nranks ? nranks : 0); }; - ssize_t offset; - int nelem; - int chunk; - // step 0: push data to next GPU - chunk = modRanks(ringIx + nranks-1); - offset = calcOffset(chunk); - nelem = min(realChunkSize, size-offset); + chunk = modRanks(ringIx + nranks - 1); + chunkOffset = chunk * chunkCount; + offset = gridOffset + elemOffset + chunkOffset; + nelem = (int)min(chunkCount, remCount - chunkOffset); prims.send(offset, nelem); // k-2 steps: reduce and copy to next GPU - for (int j=2; j __device__ __forceinline__ void runTreeUpDown(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclTree *tree = &ncclShmem.channel.tree; - ssize_t chunkSize = int( - Proto::Id == NCCL_PROTO_SIMPLE ? args->lastChunkSize - /* LL & LL128 */ : Proto::calcBytePerStep()/sizeof(T)); - const ssize_t minChunkSize = int( - Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads-2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) - /* LL & LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))); - const ssize_t loopSize = int(nChannels*chunkSize); - const ssize_t size = args->count; - - if (loopSize > size) - chunkSize = divUp((int)size, int(nChannels*minChunkSize))*int(minChunkSize); + const size_t channelCount = args->workCount; + const size_t gridOffset = args->workOffset; + const size_t chunkCount = args->chunkCount; + size_t offset; + int nelem; { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) Primitives, /*Direct=*/0, Proto, 0> prims (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg); if (tree->up == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true); } } else if (tree->down[0] == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.send(offset, nelem); } } else { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recvReduceSend(offset, nelem); } } @@ -143,23 +120,23 @@ namespace { Primitives, /*Direct=*/1, Proto, 0> prims (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg); if (tree->up == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directSendFromOutput(offset, nelem); } } else if (tree->down[0] == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecv(offset, nelem); } } else { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecvCopySend(offset, nelem); } } @@ -169,19 +146,13 @@ namespace { template __device__ __forceinline__ void runTreeSplit(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclTree *tree = &ncclShmem.channel.tree; - ssize_t chunkSize = int( - Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize - : Proto::calcBytePerStep()/sizeof(T)); - const ssize_t minChunkSize = int( - Proto::Id == NCCL_PROTO_SIMPLE ? (nthreads - 2*WARP_SIZE)*8*(sizeof(uint64_t)/sizeof(T)) : - Proto::Id == NCCL_PROTO_LL ? nthreads*(Proto::calcBytePerGrain()/sizeof(T)) - /* LL128 */ : nthreads*(Proto::calcBytePerGrain()/sizeof(T))/8); - const ssize_t loopSize = int(nChannels*chunkSize); - const ssize_t size = args->count; + const size_t chunkCount = args->chunkCount; + const size_t gridOffset = args->workOffset; + const size_t channelCount = args->workCount; + size_t offset; + int nelem; int nthreadsSplit; if (Proto::Id == NCCL_PROTO_SIMPLE) { @@ -193,16 +164,13 @@ namespace { nthreadsSplit = (nthreads*7/(10*WARP_SIZE))*WARP_SIZE; } - if (loopSize > size) - chunkSize = divUp((int)size, nChannels*int(minChunkSize))*int(minChunkSize); - if (tree->up == -1) { // Reduce and broadcast. Max number of recv is 2, max number of send is 2 Primitives, /*Direct=*/1, Proto, 0> prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->redOpArg); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecvReduceCopySend(offset, offset, nelem, /*doPost=*/true); } } @@ -218,16 +186,16 @@ namespace { Primitives, /*Direct=*/1, Proto, 0> prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->redOpArg, 0*Proto::MaxGroupWidth); if (tree->down[0] == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.send(offset, nelem); } } else { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recvReduceSend(offset, nelem); } } @@ -238,16 +206,16 @@ namespace { prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->redOpArg, 1*Proto::MaxGroupWidth); if (tree->down[0] == -1) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecv(offset, nelem); } } else { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid*int(chunkSize); - int nelem = min(chunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecvCopySend(offset, nelem); } } @@ -282,7 +250,7 @@ struct RunWorkElementbid; const int nChannels = args->nChannels; struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; - const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t chunkSize = args->chunkCount; const ssize_t size = args->count; const ssize_t loopSize = nChannels*direct->nHeads*chunkSize; @@ -378,14 +346,10 @@ template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nChannels = args->nChannels; struct ncclNvls* nvls = &ncclShmem.channel.nvls; - const ssize_t chunkSize = int(args->lastChunkSize); - const ssize_t size = args->count; - const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize; - const int nranks = ncclShmem.comm.nRanks; + ssize_t chunkSize = args->chunkCount; const bool hasOut = nvls->out != -1; + const int nranks = ncclShmem.comm.nRanks; const int totalWarps = NCCL_MAX_NTHREADS/WARP_SIZE; const int bcastWarps = hasOut ? (args->regUsed ? ((totalWarps - 2) >> 1) - 1 : 2) : 0; const int reduceWarps = args->regUsed ? (totalWarps - bcastWarps - 2) : (hasOut ? 3 : nranks <= 6 ? 7 : 5); @@ -401,62 +365,114 @@ struct RunWorkElement; - Primitives, /*Direct=*/0, Proto, 0> - prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, - args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; - int nelem = args->regUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); - prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); - } - } else if (tid < tidEndGather) { - // Gather - using Proto = ProtoSimple<1, 1, COLL_UNROLL>; - Primitives, /*Direct=*/0, Proto, 0> - prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, - args->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; - int nelem = args->regUsed ? 0 :min(nvls->nHeads * chunkSize, size - offset); - prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); - } - } else if (tid < tidEndReduce && nvls->headRank != -1) { - if (!hasOut) { + if (args->oneNode) { + const ssize_t loopCount = nvls->nHeads * chunkSize; + const ssize_t channelCount = args->workCount; + const ssize_t gridOffset = args->workOffset; + ssize_t offset; + int nelem; + + if (tid < tidEndScatter) { + // Scatter + using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, + args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + if (channelCount - elemOffset < loopCount) chunkSize = args->lastChunkCount; + offset = gridOffset + elemOffset; + nelem = args->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndGather) { + // Gather + using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + Primitives, /*Direct=*/0, Proto, 0> + prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, + args->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + if (channelCount - elemOffset < loopCount) chunkSize = args->lastChunkCount; + offset = gridOffset + elemOffset; + nelem = args->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndReduce) { // Reduce, broadcast through NVLS using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>; Primitives, /*Direct=*/1, Proto, 0> prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, NULL, NULL, args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.directRecvDirectSend(offset, offset, nelem); - } - } else { - // Reduce, send to network - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; - Primitives, /*Direct=*/1, Proto, 0> - prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->out, NULL, NULL, - args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 1, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + ssize_t chunkOffset; + if (channelCount - elemOffset < loopCount) chunkSize = args->lastChunkCount; + chunkOffset = elemOffset + nvls->headRank * chunkSize; + offset = gridOffset + chunkOffset; + nelem = min(chunkSize, channelCount - chunkOffset); prims.directRecvDirectSend(offset, offset, nelem); } } - } else if (tid < tidEndBcast && nvls->headRank != -1) { - // Recv from network, broadcast - using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>; - Primitives, /*Direct=*/1, Proto, 0> - prims(tid - tidEndReduce, nThreadsBcast, &nvls->out, &nvls->down, NULL, NULL, - args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.directRecvDirectSend(offset, offset, nelem); + } else { + const int bid = args->bid; + const ssize_t loopSize = args->nChannels * nvls->nHeads * chunkSize; + const ssize_t size = args->count; + + if (tid < tidEndScatter) { + // Scatter + using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, + args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; + int nelem = args->regUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); + prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndGather) { + // Gather + using Proto = ProtoSimple<1, 1, COLL_UNROLL>; + Primitives, /*Direct=*/0, Proto, 0> + prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, + args->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; + int nelem = args->regUsed ? 0 :min(nvls->nHeads * chunkSize, size - offset); + prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndReduce && nvls->headRank != -1) { + if (!hasOut) { + // Reduce, broadcast through NVLS + using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>; + Primitives, /*Direct=*/1, Proto, 0> + prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, NULL, NULL, + args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.directRecvDirectSend(offset, offset, nelem); + } + } else { + // Reduce, send to network + using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>; + Primitives, /*Direct=*/1, Proto, 0> + prims(tid - tidEndGather, nThreadsReduce, &nvls->down, &nvls->out, NULL, NULL, + args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 1, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.directRecvDirectSend(offset, offset, nelem); + } + } + } else if (tid < tidEndBcast && nvls->headRank != -1) { + // Recv from network, broadcast + using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>; + Primitives, /*Direct=*/1, Proto, 0> + prims(tid - tidEndReduce, nThreadsBcast, &nvls->out, &nvls->down, NULL, NULL, + args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; + int nelem = min(chunkSize, size - offset); + prims.directRecvDirectSend(offset, offset, nelem); + } } } } @@ -466,14 +482,13 @@ template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nChannels = args->nChannels; struct ncclNvls* nvls = &ncclShmem.channel.nvls; const int treeUp = nvls->treeUp; const int* treeDown = nvls->treeDown; - const ssize_t chunkSize = int(args->lastChunkSize); - const ssize_t size = args->count; - const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize; + ssize_t chunkCount = args->chunkCount; + const ssize_t loopCount = nvls->nHeads * chunkCount; + const ssize_t channelCount = args->workCount; + const ssize_t gridOffset = args->workOffset; const int nranks = ncclShmem.comm.nRanks; const bool hasUp = treeUp != -1; const int totalWarps = NCCL_MAX_NTHREADS/WARP_SIZE; @@ -481,6 +496,8 @@ struct RunWorkElementregUsed ? (totalWarps - bcastWarps - 2) : (hasUp ? 5 : nranks <= 6 ? 7 : 5); const int scatterWarps = args->regUsed ? 1 : (totalWarps - reduceWarps - bcastWarps + 1) >> 1; const int gatherWarps = args->regUsed ? 1 : (totalWarps - reduceWarps - bcastWarps) >> 1; + ssize_t offset; + int nelem; const int nThreadsScatter = scatterWarps*WARP_SIZE; const int nThreadsGather = gatherWarps*WARP_SIZE; @@ -497,10 +514,11 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; - int nelem = args->regUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); - prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + if (channelCount - elemOffset < loopCount) chunkCount = args->lastChunkCount; + offset = gridOffset + elemOffset; + nelem = args->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + prims.scatter(offset, nelem, chunkCount, chunkCount, -1, 0); } } else if (tid < tidEndGather) { // Gather @@ -508,10 +526,11 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * nvls->nHeads * chunkSize; - int nelem = args->regUsed ? 0 : min(nvls->nHeads * chunkSize, size - offset); - prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + if (channelCount - elemOffset < loopCount) chunkCount = args->lastChunkCount; + offset = gridOffset + elemOffset; + nelem = args->regUsed ? 0 : min(loopCount, channelCount - elemOffset); + prims.gather(offset, nelem, chunkCount, chunkCount, -1, 0); } } else if (tid < tidEndReduce && nvls->headRank != -1) { if (!hasUp) { @@ -520,9 +539,12 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> prims(tid - tidEndGather, nThreadsReduce, treeDown, treeDown, NULL, NULL, args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + ssize_t chunkOffset; + if (channelCount - elemOffset < loopCount) chunkCount = args->lastChunkCount; + chunkOffset = elemOffset + nvls->headRank * chunkCount; + offset = gridOffset + chunkOffset; + nelem = min(chunkCount, channelCount - chunkOffset); prims.directRecvDirectSend(offset, offset, nelem); } } else { @@ -531,9 +553,12 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> prims(tid - tidEndGather, nThreadsReduce, treeDown, &treeUp, NULL, NULL, args->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + ssize_t chunkOffset; + if (channelCount - elemOffset < loopCount) chunkCount = args->lastChunkCount; + chunkOffset = elemOffset + nvls->headRank * chunkCount; + offset = gridOffset + chunkOffset; + nelem = min(chunkCount, channelCount - chunkOffset); prims.directRecvDirectSend(offset, offset, nelem); } } @@ -543,9 +568,12 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> prims(tid - tidEndReduce, nThreadsBcast, &treeUp, treeDown, NULL, NULL, args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + (bid * nvls->nHeads + nvls->headRank) * chunkSize; - int nelem = min(chunkSize, size - offset); + for (ssize_t elemOffset = 0; elemOffset < channelCount; elemOffset += loopCount) { + ssize_t chunkOffset; + if (channelCount - elemOffset < loopCount) chunkCount = args->lastChunkCount; + chunkOffset = elemOffset + nvls->headRank * chunkCount; + offset = gridOffset + chunkOffset; + nelem = min(chunkCount, channelCount - chunkOffset); prims.directRecvDirectSend(offset, offset, nelem); } } @@ -560,7 +588,7 @@ struct RunWorkElementbid; const int nChannels = args->nChannels; ncclTree *tree = &ncclShmem.channel.collnetChain; - ssize_t chunkSize = int(args->lastChunkSize); + ssize_t chunkSize = args->chunkCount; const ssize_t loopSize = int(nChannels*chunkSize); const int nranks = ncclShmem.comm.nRanks; const ssize_t size = args->count; diff --git a/projects/rccl/src/device/broadcast.h b/projects/rccl/src/device/broadcast.h index 15bf841d50..86d45e77ef 100644 --- a/projects/rccl/src/device/broadcast.h +++ b/projects/rccl/src/device/broadcast.h @@ -12,37 +12,25 @@ namespace { template __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclRing *ring = &ncclShmem.channel.ring; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1)); - const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->count; const int rank = ring->userRanks[0]; const int nextRank = ring->userRanks[1]; const int root = args->root; + const size_t chunkCount = args->chunkCount; + const size_t channelCount = args->workCount; + const size_t gridOffset = args->workOffset; + size_t offset; + int nelem; T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 0, Proto, 0> prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t realChunkSize; - if (Proto::Id == NCCL_PROTO_SIMPLE) { - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); - realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); - } - else if (Proto::Id == NCCL_PROTO_LL) - realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize; - else if (Proto::Id == NCCL_PROTO_LL128) - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128); - realChunkSize = int(realChunkSize); - - ssize_t offset = gridOffset + int(bid*realChunkSize); - int nelem = min(realChunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); if (rank == root) { if (inputBuf == outputBuf) { diff --git a/projects/rccl/src/device/common.h b/projects/rccl/src/device/common.h index 97581f738d..8f3cad3283 100644 --- a/projects/rccl/src/device/common.h +++ b/projects/rccl/src/device/common.h @@ -25,6 +25,7 @@ struct ncclShmemGroup { union { unpackGroupShmem unpack; } devicePlugin; + int32_t dstSizes[NCCL_MAX_NVLS_ARITY+1]; }; struct ncclShmemData { diff --git a/projects/rccl/src/device/common_kernel.h b/projects/rccl/src/device/common_kernel.h index bfeb87fdf3..e82c94714e 100644 --- a/projects/rccl/src/device/common_kernel.h +++ b/projects/rccl/src/device/common_kernel.h @@ -28,11 +28,11 @@ inline __device__ int loadInt(int* ptr) { template + typename IntBytes, typename SrcPtrFn, typename DstPtrFn> __device__ __forceinline__ void reduceCopyPacks( int nThreads, int &thread, uint64_t redArg, uint64_t *preOpArgs, bool postOp, - int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn, IntBytes &nBytesBehind, IntBytes &nBytesAhead ) { static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); @@ -66,10 +66,10 @@ __device__ __forceinline__ void reduceCopyPacks( uintptr_t minDsts[MinDsts + !MinDsts]; #pragma unroll for (int s=0; s < MinSrcs; s++) - minSrcs[s] = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + minSrcs[s] = cvta_to_global(srcPtrFn(s)) + threadBytesBehind; #pragma unroll for (int d=0; d < MinDsts; d++) - minDsts[d] = cvta_to_global(dstPtrs[d]) + threadBytesBehind; + minDsts[d] = cvta_to_global(dstPtrFn(d)) + threadBytesBehind; // We dictate loop termination condition according to whether partial hunks // can be handled or not. @@ -114,7 +114,7 @@ __device__ __forceinline__ void reduceCopyPacks( } for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) { - uintptr_t src = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + uintptr_t src = cvta_to_global(srcPtrFn(s)) + threadBytesBehind; BytePack tmp[Unroll]; RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); #pragma unroll Unroll @@ -149,7 +149,7 @@ __device__ __forceinline__ void reduceCopyPacks( } } for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) { - uintptr_t dst = cvta_to_global(dstPtrs[d]) + threadBytesBehind; + uintptr_t dst = cvta_to_global(dstPtrFn(d)) + threadBytesBehind; #pragma unroll Unroll for (int u=0; u < Unroll; u++) { st_global(dst, acc[u]); @@ -183,11 +183,11 @@ __device__ __forceinline__ void reduceCopyPacks( template + typename IntBytes, typename SrcPtrFn, typename DstPtrFn> __device__ __forceinline__ void reduceCopy( int thread, int nThreads, uint64_t redArg, uint64_t *preOpArgs, bool postOp, - int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn, IntBytes nElts ) { static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts, "Multimem pointers cannot exceed respective Min values."); @@ -198,6 +198,9 @@ __device__ __forceinline__ void reduceCopy( // is supported for this redfn/type. constexpr int BigPackSize = (MultimemSrcs == 0) ? 16 : LoadMultimem_BigPackSize::BigPackSize; + if (MaxDsts==0) return; + if (MinDsts==0 && nDsts==0) return; + IntBytes nBytesBehind = 0; IntBytes nBytesAhead = nElts*sizeof(T); @@ -208,20 +211,20 @@ __device__ __forceinline__ void reduceCopy( #endif // Check that all pointers are BigPackSize aligned. bool aligned = true; - if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrs[lane]) % (BigPackSize + !BigPackSize); - if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize); + if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrFn(lane)) % (BigPackSize + !BigPackSize); + if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrFn(lane)) % (BigPackSize + !BigPackSize); aligned = __all_sync(~0u, aligned); if (aligned) { reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, - nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); if (nBytesAhead == 0) return; reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, - nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); if (nBytesAhead == 0) return; } } @@ -229,13 +232,31 @@ __device__ __forceinline__ void reduceCopy( reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, - nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); if (nBytesAhead == 0) return; reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, - nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead); +} + +template +__device__ __forceinline__ void reduceCopy( + int thread, int nThreads, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, + IntBytes nElts + ) { + reduceCopy + (thread, nThreads, redArg, preOpArgs, postOp, + nSrcs, [=]__device__(int i) { return srcPtrs[i]; }, + nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts); } #endif // COMMON_KERNEL_H_ diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index 0b053de17e..43de85d616 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -74,11 +74,11 @@ else: ################################################################################ algos_of_coll = { - "AllGather": ["RING","NVLS"], + "AllGather": ["RING","COLLNET_DIRECT","NVLS"], "AllReduce": all_algos, "Broadcast": ["RING"], "Reduce": ["RING"], - "ReduceScatter": ["RING","NVLS"], + "ReduceScatter": ["RING","COLLNET_DIRECT","NVLS"], "SendRecv": [None] } diff --git a/projects/rccl/src/device/prims_ll.h b/projects/rccl/src/device/prims_ll.h index f341d6fb81..5f59690999 100644 --- a/projects/rccl/src/device/prims_ll.h +++ b/projects/rccl/src/device/prims_ll.h @@ -26,7 +26,7 @@ class Primitives: uint64_t recvConnHead; struct ncclConnInfo* sendConn = NULL; - volatile int* sendConnFifoPtr = NULL; + volatile struct ncclConnFifo* sendConnFifo = NULL; volatile uint64_t* sendConnHeadPtr = NULL; uint64_t sendConnHead; uint64_t sendConnHeadCache; // Cache last seen value @@ -68,9 +68,9 @@ class Primitives: sendConnHeadCache = *sendConnHeadPtr; if (checkAbort(spins, 1)) break; } - if (sendConnFifoPtr) { + if (sendConnFifo) { int size = ((sendConnHead & NCCL_LL_CLEAN_MASK) == NCCL_LL_CLEAN_MASK) ? stepLines*sizeof(union ncclLLFifoLine) : nbytes; - sendConnFifoPtr[sendConnHead%NCCL_STEPS] = size; + sendConnFifo[sendConnHead%NCCL_STEPS].size = size; } sendConnHead += 1; } @@ -315,7 +315,7 @@ class Primitives: sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; - sendConnFifoPtr = sendConn->sizesFifo; + sendConnFifo = sendConn->connFifo; } } @@ -323,7 +323,7 @@ class Primitives: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0, - uint8_t connIndexRecv=0, uint8_t connIndexSend=0, struct ncclWorkElem* e = nullptr, int stepSize_=0 + uint8_t connIndexRecv=0, uint8_t connIndexSend=0, struct ncclWorkElem* e = nullptr, struct ncclWorkElemP2p* p2p = nullptr, int stepSize_=0 ): redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), diff --git a/projects/rccl/src/device/prims_ll128.h b/projects/rccl/src/device/prims_ll128.h index 43e01c485d..698eea68e6 100644 --- a/projects/rccl/src/device/prims_ll128.h +++ b/projects/rccl/src/device/prims_ll128.h @@ -30,7 +30,7 @@ class Primitives: uint64_t recvConnHead; struct ncclConnInfo* sendConn = NULL; - volatile int* sendConnFifoPtr = NULL; + volatile struct ncclConnFifo* sendConnFifo = NULL; volatile uint64_t* sendConnTailPtr = NULL; uint64_t sendConnTail; volatile uint64_t* sendConnHeadPtr = NULL; @@ -71,8 +71,8 @@ class Primitives: sendConnHeadCache = *sendConnHeadPtr; if (checkAbort(spins, wid, 1)) break; } - if (sendConnFifoPtr) { - sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes; + if (sendConnFifo) { + sendConnFifo[sendStep[wid]%NCCL_STEPS].size = nbytes; } sendConnHead += 1; } @@ -350,10 +350,10 @@ class Primitives: sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; - sendConnFifoPtr = sendConn->sizesFifo; + sendConnFifo = sendConn->connFifo; } if (tid >= nthreads-WARP_SIZE && widsizesFifo) { + if (sendConn->connFifo) { sendConnTailPtr = sendConn->tail; sendConnTail = sendConn->step; } diff --git a/projects/rccl/src/device/prims_simple.h b/projects/rccl/src/device/prims_simple.h index 048052eef1..6bf8a1a8af 100644 --- a/projects/rccl/src/device/prims_simple.h +++ b/projects/rccl/src/device/prims_simple.h @@ -20,8 +20,8 @@ class Primitives< RolePostSend = 0x10, RolePostRecv = 0x20, Aborted = 0x40, - OffsFifoEnabled = 0x80, - SizesFifoEnabled = 0x100, + UserBufferMode = 0x80, + ConnFifoEnabled = 0x100, DirectWrite = 0x200, DirectRead = 0x400, ThreadsSynced = 0x800, @@ -39,15 +39,12 @@ class Primitives< int flags; int group; uint64_t step; - int *connOffsFifoPtr; // (flags & OffsFifoEnabled) + struct ncclConnFifo* connFifo = NULL; union { T *userBuff; // (flags & (RoleInput|RoleOutput)) T *connEltsFifo; // !(flags & (RoleInput|RoleOutput)) }; - union { - int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled) - T *directBuff; // !(flags & SizesFifoEnabled) - }; + T *directBuff; uint64_t *connStepPtr; uint64_t connStepCache; // Cache last seen value of (*connStepPtr) void* mhandle; @@ -141,14 +138,16 @@ class Primitives< } if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { - if (isSendNotRecv && (flags & SizesFifoEnabled)) - connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof(T); + if (flags & ConnFifoEnabled) + connFifo[step%NCCL_STEPS].size = nelts*sizeof(T); void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst) : (ncclShmem.groups[group].srcs + Src); - if (flags & OffsFifoEnabled) - ptrs[index] = connEltsFifo + loadInt(connOffsFifoPtr + (step%NCCL_STEPS))/sizeof(T); - else if (isSendNotRecv && DirectSend) { + if (flags & UserBufferMode) { + // Do nothing + } else if ((flags & ConnFifoEnabled) && connFifo[step%NCCL_STEPS].mode == NCCL_MODE_OFFSET) { + ptrs[index] = connEltsFifo + loadInt(&connFifo[step%NCCL_STEPS].offset)/sizeof(T); + } else if (isSendNotRecv && DirectSend) { if (flags & (DirectWrite | NvlsDirectWrite)) { ptrs[index] = directBuff + dstIx + offset; } else if (flags & DirectRead) { // empty send @@ -179,7 +178,9 @@ class Primitives< inline __device__ void postPeer(bool dataStored) { if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { step += StepPerSlice; - if (Send && (flags & RolePostSend) && dataStored) fence_acq_rel_sys(); + if (Send && (flags & RolePostSend) && (dataStored||(flags&ConnFifoEnabled))) { + fence_acq_rel_sys(); + } st_relaxed_sys_global(connStepPtr, step); } } @@ -199,7 +200,7 @@ class Primitives< int slice = 0; int offset = 0; - if (tid < nworkers && offset < nelem) { + if (tid < nworkers && offset < nelem && ((flags & UserBufferMode) == 0)) { // Worker-only loop for non-empty slices. Non-workers and empty slices are // processed in the loop following this if block. The benefit of splitting // the loop like this is we pull two branches out of the critical path. @@ -301,6 +302,55 @@ class Primitives< } } +public: + template + __device__ __forceinline__ void process(Fn &&fn) { + #pragma unroll 1 + for (int slice=0; slice < SlicePerChunk; slice++) { + if (tid < nworkers) { + if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { + bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; + int spins = 0; + while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { + connStepCache = loadStepValue(connStepPtr); + if (checkAbort(spins)) break; + } + void **ptrs = isSendNotRecv ? ncclShmem.groups[group].dsts + : ncclShmem.groups[group].srcs; + if ((flags & ConnFifoEnabled) && connFifo[step%NCCL_STEPS].mode == NCCL_MODE_OFFSET) { + int offset = loadInt(&connFifo[step%NCCL_STEPS].offset); + ptrs[index] = connEltsFifo + offset/sizeof(T); + } else { + ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; + } + } + subBarrier(); + fn.template operator() + (tid, nworkers, slice, stepSize*StepPerSlice, + fan.nrecv(), ncclShmem.groups[group].srcs, + fan.nsend(), ncclShmem.groups[group].dsts, ncclShmem.groups[group].dstSizes); + } + barrier(); + int32_t dstSize = 0; + if (flags & Send*RolePostSend) { + dstSize = ncclShmem.groups[group].dstSizes[index]; + ncclShmem.groups[group].dstSizes[index] = 0; + if (flags & ConnFifoEnabled) connFifo[step%NCCL_STEPS].size = dstSize*sizeof(T); + } + barrier(); + if (flags & (Recv*(RoleWaitRecv|RolePostRecv) | Send*(RoleWaitSend|RolePostSend))) { + step += StepPerSlice; + } + if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { + if (Send && (!Recv || (flags & RolePostSend)) && (dstSize!=0 || (flags&ConnFifoEnabled))) { + fence_acq_rel_sys(); + } + st_relaxed_sys_global(connStepPtr, step); + } + } + } + +private: // Scatter/Gather generic op // skip: my own rank order in the buffer chunks // shift: peer offset to avoid all ranks sending to or receiving from same peer @@ -386,8 +436,11 @@ class Primitives< flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0; connStepPtr = conn->tail; connStepCache = loadStepValue(connStepPtr); - flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; - if (Direct) { + connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; + if (conn->connFifo != nullptr) { + flags |= ConnFifoEnabled; + connFifo = conn->connFifo; + } else if (Direct) { // User buffers have been registered if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { @@ -409,9 +462,6 @@ class Primitives< flags |= NvlsDirectRead; } } - if (flags & OffsFifoEnabled) - connOffsFifoPtr = conn->offsFifo; - connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; } } } @@ -421,6 +471,10 @@ class Primitives< auto *conn = &peer->send[connIndex]; step = conn->step; step = roundUp(step, SlicePerChunk*StepPerSlice); + + connFifo = conn->connFifo; + if (connFifo != nullptr) flags |= ConnFifoEnabled; + if (flags & RolePostSend) { connStepPtr = conn->tail; connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; @@ -430,15 +484,8 @@ class Primitives< flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0; connStepPtr = conn->head; connStepCache = loadStepValue(connStepPtr); - flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; - if (flags & OffsFifoEnabled) - connOffsFifoPtr = conn->offsFifo; connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE]; - - if (conn->sizesFifo != nullptr) { - flags |= SizesFifoEnabled; - connSizesFifoPtr = conn->sizesFifo; - } else if (Direct) { + if (connFifo == nullptr && Direct) { // User buffers have been registered if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { @@ -468,7 +515,7 @@ class Primitives< __device__ Primitives( int tid, int nthreads, int const *recvPeers, int const *sendPeers, void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0, - uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr, int stepSize_=0 + uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr, struct ncclWorkElemP2p* p2p = nullptr, int stepSize_=0 ): tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group), stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_) { @@ -507,6 +554,8 @@ class Primitives< loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, e); loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, e); + if (p2p && p2p->reg) flags |= UserBufferMode; + if (barrierAny(flags & NetDeviceUnpack)) { flags |= AnyNetDeviceUnpack; // g == 0 is the first ThreadPerSync # of threads of this warp @@ -533,10 +582,21 @@ class Primitives< auto *conns = (flags & RolePostSend) ? ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns; conns[index]->step = step; } - + if ((flags & UserBufferMode) && (flags & RoleWaitSend)) { + // Make sure we wait until the proxy has sent data before we return. + // We don't want the next CUDA kernel to overwrite the send buffer which + // was accessed directly. + uint64_t prevStep = step - StepPerSlice; + volatile ssize_t* ptr = &(connFifo[prevStep%NCCL_STEPS].size); + while (*ptr != -1); + } + if ((flags & (AnyNetDeviceUnpack)) && (flags & (RoleWaitRecv))) { ncclNetDeviceSaveHead(netDeviceHandle, group); } + + // Make sure all threads are done writing back conn->step and done using + // ncclShmem.groups[group] barrier(); } diff --git a/projects/rccl/src/device/reduce.h b/projects/rccl/src/device/reduce.h index 627d9b119b..43cae213b2 100644 --- a/projects/rccl/src/device/reduce.h +++ b/projects/rccl/src/device/reduce.h @@ -12,56 +12,39 @@ namespace { template __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const int nthreads = (int)args->nWarps * WARP_SIZE; ncclRing *ring = &ncclShmem.channel.ring; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCE_CHUNKSTEPS : 1)); - const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))); const int nranks = ncclShmem.comm.nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->count; const int rank = ncclShmem.comm.rank; const int prevRank = ring->userRanks[nranks-1]; const int root = args->root; + const size_t chunkCount = args->chunkCount; + const size_t channelCount = args->workCount; + const size_t gridOffset = args->workOffset; + size_t offset; + int nelem; Primitives, 0, Proto, 0> prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); - auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int { - int realChunkSize; - if (Proto::Id == NCCL_PROTO_SIMPLE) { - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); - realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); - } - else if (Proto::Id == NCCL_PROTO_LL) - realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize; - else if (Proto::Id == NCCL_PROTO_LL128) - realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize); - return realChunkSize; - }; - if (prevRank == root) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = calcChunkSize(gridOffset); - ssize_t offset = gridOffset + bid*realChunkSize; - int nelem = min(realChunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.send(offset, nelem); } } else if (rank == root) { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = calcChunkSize(gridOffset); - ssize_t offset = gridOffset + bid*realChunkSize; - int nelem = min(realChunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true); } } else { - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - int realChunkSize = calcChunkSize(gridOffset); - ssize_t offset = gridOffset + bid*realChunkSize; - int nelem = min(realChunkSize, size-offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recvReduceSend(offset, nelem); } } diff --git a/projects/rccl/src/device/reduce_kernel.h b/projects/rccl/src/device/reduce_kernel.h index 66e9516cd3..cbf774338a 100644 --- a/projects/rccl/src/device/reduce_kernel.h +++ b/projects/rccl/src/device/reduce_kernel.h @@ -616,7 +616,7 @@ struct Apply_PostOp, /*EltPerPack=*/1> { BytePack<2*sizeof(T)> tmp; \ asm("multimem.ld_reduce.relaxed.sys.global.add." #ptx_ty " %0, [%1];" \ : "=" PTX_REG_BytePack_field_##pack_field(tmp.pack_field) \ - : "l"(addr & -uintptr_t(sizeof(T)))); \ + : "l"(addr & -uintptr_t(2*sizeof(T)))); \ return tmp.half[(addr/sizeof(T))%2]; \ } \ }; @@ -629,11 +629,11 @@ struct Apply_PostOp, /*EltPerPack=*/1> { if (fn.isMinNotMax) { \ asm("multimem.ld_reduce.relaxed.sys.global.min." #ptx_ty " %0, [%1];" \ : "=" PTX_REG_BytePack_field_##pack_field(tmp.pack_field) \ - : "l"(addr & -uintptr_t(sizeof(T)))); \ + : "l"(addr & -uintptr_t(2*sizeof(T)))); \ } else { \ asm("multimem.ld_reduce.relaxed.sys.global.max." #ptx_ty " %0, [%1];" \ : "=" PTX_REG_BytePack_field_##pack_field(tmp.pack_field) \ - : "l"(addr & -uintptr_t(sizeof(T)))); \ + : "l"(addr & -uintptr_t(2*sizeof(T)))); \ } \ return tmp.half[(addr/sizeof(T))%2]; \ } \ diff --git a/projects/rccl/src/device/reduce_scatter.h b/projects/rccl/src/device/reduce_scatter.h index 6660cc0adc..96a63caeb4 100644 --- a/projects/rccl/src/device/reduce_scatter.h +++ b/projects/rccl/src/device/reduce_scatter.h @@ -12,56 +12,43 @@ namespace { template __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; - const int nthreads = args->nWarps*WARP_SIZE; - const int bid = args->bid; - const int nChannels = args->nChannels; + const uint32_t nthreads = (uint32_t)args->nWarps * WARP_SIZE; ncclRing *ring = &ncclShmem.channel.ring; int const *ringRanks = ring->userRanks; - const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? REDUCESCATTER_CHUNKSTEPS : 1)); - // We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere. - const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2); + const size_t chunkCount = args->chunkCount; const int nranks = ncclShmem.comm.nRanks; - const ssize_t loopSize = nChannels*chunkSize; - const ssize_t size = args->count; + size_t channelCount = args->workCount; + size_t gridOffset = args->workOffset; + size_t offset; + size_t dataOffset; + size_t count = args->count; + uint32_t nelem; + int rankDest; Primitives, 0, Proto, 0> prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t realChunkSize; - if (Proto::Id == NCCL_PROTO_SIMPLE) { - realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels)); - realChunkSize = roundUp(realChunkSize, (nthreads-WARP_SIZE)*sizeof(uint64_t)/sizeof(T)); - } - else if (Proto::Id == NCCL_PROTO_LL) - realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize; - else if (Proto::Id == NCCL_PROTO_LL128) - realChunkSize = min(divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128, chunkSize); - realChunkSize = int(realChunkSize); - - ssize_t chunkOffset = gridOffset + bid*int(realChunkSize); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + nelem = min(chunkCount, channelCount - elemOffset); + dataOffset = gridOffset + elemOffset; /////////////// begin ReduceScatter steps /////////////// - ssize_t offset; - int nelem = min(realChunkSize, size-chunkOffset); - int rankDest; - // step 0: push data to next GPU rankDest = ringRanks[nranks-1]; - offset = chunkOffset + rankDest * size; + offset = dataOffset + rankDest * count; prims.send(offset, nelem); // k-2 steps: reduce and copy to next GPU for (int j=2; j struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { const int tid = threadIdx.x; - const int bid = args->bid; - const int nChannels = args->nChannels; struct ncclNvls* nvls = &ncclShmem.channel.nvls; - const ssize_t chunkSize = int(args->lastChunkSize); - const ssize_t size = args->count; - const ssize_t loopSize = nChannels*chunkSize; + const size_t chunkCount = args->chunkCount; + const size_t count = args->count; const int rank = ncclShmem.comm.rank; const int nranks = ncclShmem.comm.nRanks; + size_t gridOffset = args->workOffset; + size_t channelCount = args->workCount; + size_t offset; + int nelem; /* if we are direct NVLS, we only need to allocate 1 warp to scatter for sync; * if not, based on #ranks, we allocate 7 or 5 warps to reduce to saturate bandwidth @@ -116,10 +104,10 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * chunkSize; - int nelem = min(chunkSize, size - offset); - prims.scatter(offset, nvls->nHeads * size, nelem, size, -1, 0); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); + prims.scatter(offset, nvls->nHeads * count, nelem, count, -1, 0); } } else if (tid < tidEndReduce) { // Reduce through NVLS @@ -127,9 +115,9 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff, args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t offset = gridOffset + bid * chunkSize; - int nelem = min(chunkSize, size - offset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + offset = gridOffset + elemOffset; + nelem = min(chunkCount, channelCount - elemOffset); prims.recv(offset, nelem); } } @@ -140,7 +128,7 @@ struct RunWorkElement, /*Direct=*/0, Proto, 0> prims(tid, nThreadsScatter, nvls->up, nvls->up, NULL, NULL, args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { prims.scatter(0, 0, 0, 0, -1, 0); } @@ -152,10 +140,10 @@ struct RunWorkElement, /*Direct=*/1, Proto, 0> prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, &nvls->down, NULL, args->recvbuff, args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, args); - for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { - ssize_t outOffset = gridOffset + bid * chunkSize; - ssize_t inpOffset = outOffset + rank * size; - int nelem = min(chunkSize, size - outOffset); + for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) { + size_t outOffset = gridOffset + elemOffset; + size_t inpOffset = outOffset + rank * count; + nelem = min(chunkCount, channelCount - elemOffset); prims.directRecvCopy(inpOffset, outOffset, nelem); } @@ -165,3 +153,146 @@ struct RunWorkElement +struct RunWorkElement { + template + struct Scatterer { + struct ncclWorkElem* args; + int chunkSize; + ssize_t railGridOffset; + + template + __device__ __forceinline__ void operator()( + int tid, int tn, int slice, int maxSliceSize, + int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, int32_t* dstSizes + ) { + static_assert(SlicePerChunk==1, "require: SlicePerChunk==1"); + static_assert(MaxDsts<=1 || MaxSrcs<=1, "require: MaxDsts<=1 || MaxSrcs<=1"); + + struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; + int nNodes = ncclShmem.comm.nNodes; + int nRails = direct->nHeads; + int bid = args->bid; + void* inbuf = (void*)args->sendbuff; + ssize_t sizePerRank = args->count; + + ssize_t railAllBeg = min(railGridOffset + bid*chunkSize, nNodes*sizePerRank); + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + int railAllSize = railAllEnd - railAllBeg; + if (tid < nDsts) dstSizes[tid] = railAllSize; + + int dst = 0; + int rail; + if (!ReduceSendNotRecv) { + rail = direct->headRank; + } else { + rail = direct->headRank+1; + if (rail == nRails) rail = 0; + } + do { + int node = railAllBeg/sizePerRank; + int railAllOffset = 0; + while (railAllOffset < railAllSize) { + ssize_t railOneBeg = node*sizePerRank; + ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t railOneOffset = (railAllBeg+railAllOffset) - railOneBeg; + int delta = min(railAllEnd, railOneEnd) - (railAllBeg+railAllOffset); + int rank = ncclShmem.comm.collNetDenseToUserRank[node*nRails + rail]; + ssize_t userOneBeg = rank*sizePerRank + railOneOffset; + reduceCopy + (tid, tn, args->redOpArg, &args->redOpArg, false, + /*nSrcs=*/1+nSrcs, [=]__device__(int s) { + return s==0 ? (T*)inbuf + userOneBeg + : (T*)srcPtrs[s-1] + railAllOffset; + }, + /*nDsts=*/1, [=]__device__(int d/*==0*/) { + return (T*)dstPtrs[dst] + railAllOffset; + }, + delta); + railAllOffset += delta; + node += 1; + } + dst += 1; + rail += 1; + if (rail == nRails) rail = 0; + } while (ReduceSendNotRecv && dst < nRails-1); + } + }; + + __device__ __forceinline__ void run(ncclWorkElem *args) { + int tid = threadIdx.x; + const int nChannels = args->nChannels; + struct ncclDirect* direct = &ncclShmem.channel.collnetDirect; + int const &nNodes = ncclShmem.comm.nNodes; + ssize_t chunkSize = int(args->chunkCount); + ssize_t sizePerRank = args->count; + + if (direct->out == -1) __trap(); + bool isMultiRail = (direct->nHeads > 1); + int nWarps1 = (isMultiRail ? 2 : 0); + int nWarps2 = (isMultiRail ? 2 : 1); + int nWarps3 = 1; + float denom = float(args->nWarps)/float(nWarps1+nWarps2+nWarps3); + nWarps3 = int(denom*nWarps3); + nWarps2 = int(denom*nWarps2); + nWarps1 = args->nWarps - (nWarps2+nWarps3); + + using Proto = ProtoSimple<1, 1>; + + int tn = nWarps1*WARP_SIZE; + if (tid < tn) { + // Phase 1: Scatter inputs to peers + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, nullptr, direct->heads+1, nullptr, nullptr, + args->redOpArg, 0*Proto::MaxGroupWidth, 1, 1); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + Scatterer scat; + scat.args = args; + scat.chunkSize = chunkSize; + scat.railGridOffset = railGridOffset; + prims.process(scat); + } + return; + } + tid -= tn; + + tn = nWarps2*WARP_SIZE; + if (tid < tn) { + // Phase 2: Reduce from peers + local input -> send to network + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, direct->heads+1, &direct->out, nullptr, nullptr, + args->redOpArg, 1*Proto::MaxGroupWidth, 1, 1); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + Scatterer scat; + scat.args = args; + scat.chunkSize = chunkSize; + scat.railGridOffset = railGridOffset; + prims.process(scat); + } + return; + } + tid -= tn; + + tn = nWarps3*WARP_SIZE; + if (tid < tn) { + // Phase 3: recv from network + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, tn, &direct->out, nullptr, nullptr, args->recvbuff, + args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0); + for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) { + ssize_t railAllBeg = railGridOffset + args->bid*chunkSize; + ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank); + ssize_t railOneBeg = ncclShmem.comm.node*sizePerRank; + ssize_t railOneEnd = railOneBeg + sizePerRank; + ssize_t beg = max(railAllBeg, railOneBeg); + ssize_t end = min(railAllEnd, railOneEnd); + prims.recv(beg-railOneBeg, max(ssize_t(0), end-beg), /*postOp=*/true); + } + return; + } + } +}; diff --git a/projects/rccl/src/device/sendrecv.h b/projects/rccl/src/device/sendrecv.h index 5401f0542c..347ac78c56 100644 --- a/projects/rccl/src/device/sendrecv.h +++ b/projects/rccl/src/device/sendrecv.h @@ -26,13 +26,13 @@ struct RunWork { if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; int const peer = args->peer; Primitives, 1, Proto, 1> prims - (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group, 1, 1, nullptr, ncclShmem.comm.p2pChunkSize/sizeof(T)); + (tid, nthreads, nullptr, &peer, buff, nullptr, /*redOpArg(ignored)=*/0, group, 1, 1, nullptr, args, ncclShmem.comm.p2pChunkSize/sizeof(T)); size_t offset = 0; do { int nelem = min(size_t(chunkSize), count-offset); prims.directSend(offset, offset, nelem); offset += nelem; - } while(offset < count); + } while(offset < count && args->reg == 0); } } @@ -45,13 +45,13 @@ struct RunWork { if (args->proto == NCCL_PROTO_LL) chunkSize /= 2; // This is to account for chunkEffectiveSize int const peer = args->peer; Primitives, 1, Proto, 1> prims - (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group, 1, 1, nullptr, ncclShmem.comm.p2pChunkSize/sizeof(T)); + (tid, nthreads, &peer, nullptr, nullptr, buff, /*redOpArg(ignored)=*/0, group, 1, 1, nullptr, args, ncclShmem.comm.p2pChunkSize/sizeof(T)); size_t offset = 0; do { int nelem = min(size_t(chunkSize), count-offset); prims.directRecv(offset, nelem); offset += nelem; - } while(offset < count); + } while(offset < count && args->reg == 0); } } diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index ae56decd94..8ff4e0220a 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -12,21 +12,25 @@ #include "channel.h" #include "cudawrap.h" #include "transport.h" - +#include #include // std::memcpy #include // PRIx64 -enum ncclRegBufferType { - NCCL_REGULAR_BUFFER = 0, - NCCL_IPC_REG_BUFFER = 1, - NCCL_NVLS_REG_BUFFER = 2, - NCCL_REG_BUFFER_NUM = 3 -}; - -static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */); - NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); +static ncclResult_t initCollWorkElem(struct ncclInfo* collInfo, struct ncclWorkElem* work); +static ncclResult_t setCollWorkElem(uint64_t workCount, uint64_t workOffset, size_t lastChunkCount, struct ncclWorkElem* work); +static ncclResult_t initCollWorkElemReg(struct ncclComm* comm, struct ncclWorkElem* work, struct ncclChannel* channel, ncclRegBufferType regBufType, void* regBufSend[], void* regBufRecv[], struct ncclWorkElemReg* workElemReg); +static ncclResult_t computeCollChunkInfo(struct ncclInfo* collInfo, size_t nBytes, int nChannels); +static ncclResult_t initCollProxyOp(struct ncclInfo* collInfo, int channelId, uint64_t opCount, uint32_t nsteps, struct ncclProxyOp* proxyOp); +static ncclResult_t getTunerInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps); +static ncclResult_t topoGetAlgoInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps); +static ncclResult_t getChannnelThreadInfo(struct ncclInfo* collInfo); +static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo); +static ncclResult_t getPatternInfo(struct ncclInfo* collInfo); +static ncclResult_t getLoopInfo(struct ncclInfo* collInfo); +static ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport); + // Returns maximum kernel stack size of all CUDA kernels ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) { ncclResult_t result = ncclSuccess; @@ -66,25 +70,21 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) { static void appendWorkElemColl( struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId, - int funcIndex, struct ncclWorkElem const *elem, int bid - ) { + int funcIndex, struct ncclWorkElem const *elem) { struct ncclKernelPlan::Channel* chan = &plan->channels[channelId]; struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue); if (q && funcIndex == q->work.header.funcIndex && elem->nWarps == q->work.elems[0].nWarps - && chan->nWorkElem < NCCL_MAX_WORK_ELEMENTS) { + && chan->nWorkElem < NCCL_MAX_WORK_ELEMENTS + && ncclWorkTypeColl == q->work.header.type) { int e = chan->nWorkElem++; q->work.elems[e] = *elem; // C++ struct assignment - q->work.elems[e].bid = bid; - q->work.elems[e].isUsed = 1; return; } q = ncclMemoryStackAlloc(&comm->memScoped); q->work.header.type = ncclWorkTypeColl; q->work.header.funcIndex = funcIndex; q->work.elems[0] = *elem; // C++ struct assignment - q->work.elems[0].bid = bid; - q->work.elems[0].isUsed = 1; chan->nWorkElem = 1; chan->nWork += 1; ncclIntruQueueEnqueue(&chan->workQueue, q); @@ -92,16 +92,15 @@ static void appendWorkElemColl( static void appendWorkElemColl( struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId, - int funcIndex, struct ncclWorkElemReg const *elem, int bid - ) { + int funcIndex, struct ncclWorkElemReg const *elem) { struct ncclKernelPlan::Channel* chan = &plan->channels[channelId]; struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue); if (q && funcIndex == q->work.header.funcIndex && elem->elem.nWarps == q->work.regElems[0].elem.nWarps - && chan->nWorkElem < NCCL_MAX_WORK_ELEMENTS_REG) { + && chan->nWorkElem < NCCL_MAX_WORK_ELEMENTS_REG + && ncclWorkTypeRegColl == q->work.header.type) { int e = chan->nWorkElem++; q->work.regElems[e] = *elem; // C++ struct assignment - q->work.regElems[e].elem.bid = bid; q->work.regElems[e].elem.isUsed = 1; return; } @@ -109,7 +108,6 @@ static void appendWorkElemColl( q->work.header.type = ncclWorkTypeRegColl; q->work.header.funcIndex = funcIndex; q->work.regElems[0] = *elem; // C++ struct assignment - q->work.regElems[0].elem.bid = bid; q->work.regElems[0].elem.isUsed = 1; chan->nWorkElem = 1; chan->nWork += 1; @@ -186,23 +184,177 @@ static ncclResult_t addProxyOpIfNeeded(struct ncclComm* comm, struct ncclKernelP return ncclSuccess; } -// Put coll workelem & proxyOp in plan assuming nWorkBudget permits, so please -// ensure *nWorkBudget >= nBids upon entry. -static ncclResult_t addCollToPlan( - struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget, int funcIndex, - struct ncclWorkElem const* workElem, struct ncclProxyOp const* proxyOp, - int nCollChannels, int nBid, size_t bytes, ncclRegBufferType regBufType, void* regBufSend[], void* regBufRecv[] +static ncclResult_t computeCollSteps(struct ncclInfo* collInfo, size_t workCount, uint32_t* steps) { + struct ncclComm* comm = collInfo->comm; + if (collInfo->coll == ncclFuncAllReduce) { + if (collInfo->algorithm == NCCL_ALGO_RING) + *steps = DIVUP(workCount, comm->nRanks * collInfo->chunkCount) * (comm->nRanks - 1) * 2 * collInfo->chunkSteps; + else if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT) + *steps = DIVUP(workCount, comm->channels[0].collnetDirect.nHeads * collInfo->chunkCount) * collInfo->chunkSteps; + else if (collInfo->algorithm == NCCL_ALGO_NVLS || collInfo->algorithm == NCCL_ALGO_NVLS_TREE) + *steps = DIVUP(workCount, comm->channels[0].nvls.nHeads * collInfo->chunkCount) * collInfo->chunkSteps; + else + *steps = DIVUP(workCount, collInfo->chunkCount) * collInfo->chunkSteps; + } else if (collInfo->coll == ncclFuncReduceScatter) { + if (collInfo->algorithm == NCCL_ALGO_RING) + *steps = DIVUP(workCount, collInfo->chunkCount) * (comm->nRanks - 1) * collInfo->chunkSteps; + else + *steps = DIVUP(workCount, collInfo->chunkCount) * collInfo->chunkSteps; + } else if (collInfo->coll == ncclFuncAllGather) { + if (collInfo->algorithm == NCCL_ALGO_RING) + *steps = DIVUP(workCount, collInfo->chunkCount) * (comm->nRanks - 1) * collInfo->chunkSteps; + else + *steps = DIVUP(workCount, collInfo->chunkCount) * collInfo->chunkSteps; + } else { + *steps = DIVUP(workCount, collInfo->chunkCount) * collInfo->chunkSteps; + } + return ncclSuccess; +} + +static ncclResult_t computeCollAlignCount(struct ncclInfo* collInfo, size_t* alignCount) { + if (collInfo->protocol == NCCL_PROTO_SIMPLE) { + *alignCount = NCCL_SIMPLE_ALIGNMENT / ncclTypeSize(collInfo->datatype); + } else if (collInfo->protocol == NCCL_PROTO_LL128) { + *alignCount = NCCL_LL128_ALIGNMENT_PER_WARP / ncclTypeSize(collInfo->datatype) * (collInfo->nThreads / WARP_SIZE); + } else { + *alignCount = NCCL_LL_ALIGNMENT_PER_THREAD / ncclTypeSize(collInfo->datatype) * collInfo->nThreads; + } + return ncclSuccess; +} + +static ncclResult_t computeCollLastChunkInfo(struct ncclInfo* collInfo, size_t workCount, size_t alignCount, size_t* lastChunkCount) { + struct ncclComm* comm = collInfo->comm; + + if (collInfo->coll == ncclFuncAllReduce) { + if (collInfo->algorithm == NCCL_ALGO_RING) { + size_t remCount = workCount % (comm->nRanks * collInfo->chunkCount); + *lastChunkCount = DIVUP(DIVUP(remCount, comm->nRanks), alignCount) * alignCount; + } else if (collInfo->algorithm == NCCL_ALGO_NVLS || collInfo->algorithm == NCCL_ALGO_NVLS_TREE) { + size_t remCount = workCount % (comm->channels[0].nvls.nHeads * collInfo->chunkCount); + *lastChunkCount = DIVUP(DIVUP(remCount, comm->channels[0].nvls.nHeads), alignCount) * alignCount; + } else if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT) { + size_t remCount = workCount % (comm->channels[0].collnetDirect.nHeads * collInfo->chunkCount); + *lastChunkCount = DIVUP(DIVUP(remCount, comm->channels[0].collnetDirect.nHeads), alignCount) * alignCount; + } else { + *lastChunkCount = collInfo->chunkCount; + } + } else { + *lastChunkCount = collInfo->chunkCount; + } + return ncclSuccess; +} + +static ncclResult_t getCollnetLoopInfo(struct ncclInfo* collInfo, int* nstepsPerLoop, int* nchunksPerLoop) { + switch (collInfo->pattern) { + case ncclPatternCollnetChain: + *nstepsPerLoop = *nchunksPerLoop = 1; break; + case ncclPatternNvls: + *nstepsPerLoop = 1; *nchunksPerLoop = collInfo->comm->channels[0].nvls.nHeads; break; + case ncclPatternCollnetDirect: + *nstepsPerLoop = 1; *nchunksPerLoop = collInfo->comm->channels[0].collnetDirect.nHeads; break; + default: + WARN("Unknown collnet pattern %d", collInfo->pattern); + return ncclInternalError; + } + return ncclSuccess; +} + +static ncclResult_t addCollnetCollToPlan( + struct ncclComm* comm, struct ncclKernelPlan* plan, int usableChannels, + struct ncclInfo* collInfo, int* nWorkBudget ) { + ncclResult_t ret = ncclSuccess; struct ncclKernelPlan::Channel *chans = plan->channels; + struct ncclWorkElem workElem; + uint64_t opCount = uint64_t(plan->collOpCount++) << 1 | 0; + ncclRegBufferType regBufType = collInfo->regBufType; + int nChannels = std::min(collInfo->nChannels, usableChannels); + size_t countPerChannel = DIVUP(collInfo->count, nChannels); + uint32_t typeSize = ncclTypeSize(collInfo->datatype); + int steps, nchunksPerLoop, nstepsPerLoop, nLoop; + + NCCLCHECK(computeCollChunkInfo(collInfo, collInfo->nBytes, collInfo->nChannels)); + NCCLCHECKGOTO(initCollWorkElem(collInfo, &workElem), ret, fail); + workElem.nChannels = nChannels; + + NCCLCHECKGOTO(getCollnetLoopInfo(collInfo, &nstepsPerLoop, &nchunksPerLoop), ret, fail); + nLoop = (int)DIVUP(collInfo->nBytes, (size_t)nChannels * nchunksPerLoop * collInfo->chunkSize); + steps = nstepsPerLoop * nLoop * collInfo->chunkSteps; + + for (int bid = 0; bid < nChannels; bid++) { + workElem.bid = bid; + // Add work elem + *nWorkBudget += chans[bid].nWork; + if (regBufType == NCCL_REGULAR_BUFFER) { + appendWorkElemColl(comm, plan, bid, collInfo->workFuncIndex, &workElem); + } else { + struct ncclWorkElemReg workElemReg; + NCCLCHECKGOTO(initCollWorkElemReg(comm, &workElem, &comm->channels[bid], regBufType, collInfo->regBufSend, collInfo->regBufRecv, &workElemReg), ret, fail); + appendWorkElemColl(comm, plan, bid, collInfo->workFuncIndex, &workElemReg); + } + *nWorkBudget -= chans[bid].nWork; // subtract delta of chans[c].nWork + + // Add proxy task. Empty collectives do not make it to the proxy thread + // since they don't imply synchronization for the user like p2p. + if (collInfo->nBytes != 0) { + struct ncclProxyOp proxyOp; + NCCLCHECKGOTO(initCollProxyOp(collInfo, bid, opCount, steps, &proxyOp), ret, fail); + NCCLCHECKGOTO(addProxyOpIfNeeded(comm, plan, &proxyOp), ret, fail); + } + + chans[bid].collBytes += countPerChannel * typeSize; + } + + plan->threadPerBlock = std::max(plan->threadPerBlock, collInfo->nThreads); + if (!plan->kernelSpecialized) { + plan->kernelFn = ncclDevKernelForFunc[collInfo->workFuncIndex]; + plan->kernelSpecialized = ncclDevKernelForFuncIsSpecialized[collInfo->workFuncIndex]; + } + + if (comm->rank == 0) { + TRACE(NCCL_COLL, "collnetColl enqueue coll %s(%s, %s, %s, %s), nChannels %d, count %ld (nbytes %ld), usableChannel %d, chunkCount %d, funcIndex %d, nThreads %d", collInfo->opName, ncclOpToString(collInfo->op), ncclDatatypeToString(collInfo->datatype), ncclAlgoToString(collInfo->algorithm), ncclProtoToString(collInfo->protocol), collInfo->nChannels, collInfo->count, collInfo->workBytes, usableChannels, collInfo->chunkCount, collInfo->workFuncIndex, collInfo->nThreads); + } + +exit: + return ret; +fail: + goto exit; +} + +static ncclResult_t addTunedCollToPlan( + struct ncclComm* comm, struct ncclKernelPlan* plan, int usableChannels, + struct ncclInfo* collInfo, int* nWorkBudget + ) { + ncclResult_t ret = ncclSuccess; + struct ncclKernelPlan::Channel *chans = plan->channels; + struct ncclWorkElem workElem; + uint64_t opCount = uint64_t(plan->collOpCount++) << 1 | 0; + uint64_t workCount; + uint64_t workOffset = 0; + uint32_t typeSize = ncclTypeSize(collInfo->datatype); + ncclRegBufferType regBufType = collInfo->regBufType; + size_t alignCount, lastChunkCount; + int least[/*nBid*/MAXCHANNELS]; + int maxIndexInLeast; + size_t maxBytesInLeast; + int nChannels = std::min(collInfo->nChannels, usableChannels); + int rnChannels = 0; + size_t countPerChannels; + size_t remCount = collInfo->count; + + NCCLCHECKGOTO(computeCollAlignCount(collInfo, &alignCount), ret, fail); + countPerChannels = DIVUP(DIVUP(collInfo->count, nChannels), alignCount) * alignCount; + nChannels = DIVUP(collInfo->count, countPerChannels); + NCCLCHECKGOTO(computeCollChunkInfo(collInfo, collInfo->nBytes, nChannels), ret, fail); + NCCLCHECKGOTO(initCollWorkElem(collInfo, &workElem), ret, fail); // Choose the `nBid` least loaded channels to do the work. This ensures // all bids go to different channels in case they need to synchronize. - int least[/*nBid*/MAXCHANNELS]; least[0] = 0; - int maxIndexInLeast = 0; - size_t maxBytesInLeast = chans[0].collBytes; + maxIndexInLeast = 0; + maxBytesInLeast = chans[0].collBytes; // Initialize least[] such that the first nBid channels are accounted for. - for (int b=1; b < nBid; b++) { + for (int b = 1; b < nChannels; b++) { least[b] = b; if (maxBytesInLeast < chans[b].collBytes) { maxIndexInLeast = b; @@ -210,13 +362,14 @@ static ncclResult_t addCollToPlan( } } // Sort in the rest of the channels. If a channel has less work than the max - // member of least[], replace that member and compute the new max. - for (int c=nBid; c < nCollChannels; c++) { + // member of least[], replace that member and compute the new max. We only + // sort channels when coll algo is not collnet. + for (int c = nChannels; c < usableChannels; c++) { if (chans[c].collBytes < maxBytesInLeast) { least[maxIndexInLeast] = c; maxBytesInLeast = chans[least[0]].collBytes; maxIndexInLeast = 0; - for (int b=1; b < nBid; b++) { + for (int b = 1; b < nChannels; b++) { if (maxBytesInLeast < chans[least[b]].collBytes) { maxIndexInLeast = b; maxBytesInLeast = chans[least[b]].collBytes; @@ -225,61 +378,130 @@ static ncclResult_t addCollToPlan( } } - uint64_t opCount = uint64_t(plan->collOpCount++)<<1 | 0; - bytes /= nBid; - for (int bid=0; bid < nBid; bid++) { + for (int bid = 0; bid < nChannels && remCount > 0; bid++) { int c = least[bid]; - chans[c].collBytes += bytes; + + workCount = std::min(countPerChannels, remCount); + NCCLCHECKGOTO(computeCollLastChunkInfo(collInfo, workCount, alignCount, &lastChunkCount), ret, fail); + NCCLCHECKGOTO(setCollWorkElem(workCount, workOffset, lastChunkCount, &workElem), ret, fail); // Add work elem *nWorkBudget += chans[c].nWork; if (regBufType == NCCL_REGULAR_BUFFER) { - appendWorkElemColl(comm, plan, c, funcIndex, workElem, bid); - } else if (regBufType == NCCL_IPC_REG_BUFFER) { - struct ncclChannel* channel = &comm->channels[c]; - struct ncclWorkElemReg workElemReg; - workElemReg.elem = *workElem; // C++ struct assignment - workElemReg.elem.regUsed = 1; - for (int i=0; i < NCCL_MAX_DIRECT_ARITY; i++) { - int peer = channel->collnetDirect.down[i]; - if (peer == -1) break; - int j = comm->rankToLocalRank[peer]; // Get intra-node slot - workElemReg.dnInputs[i] = regBufSend[j]; // Input buffer of leaf peer - workElemReg.dnOutputs[i] = regBufRecv[j]; // Output buffer of leaf peer - } - for (int i=0; i < NCCL_MAX_DIRECT_ARITY; i++) { - int peer = channel->collnetDirect.up[i]; - if (peer == -1) break; - int j = comm->rankToLocalRank[peer]; - // Output buffer of root peer - workElemReg.upOutputs[i] = regBufRecv[j]; - } - appendWorkElemColl(comm, plan, c, funcIndex, &workElemReg, bid); - } else if (regBufType == NCCL_NVLS_REG_BUFFER) { - struct ncclWorkElemReg workElemReg; - workElemReg.elem = *workElem; // C++ struct assignment - workElemReg.elem.regUsed = 1; - /* NVLS only has one send and recv buffer registered */ - workElemReg.dnInputs[0] = regBufSend[0]; - workElemReg.dnOutputs[0] = regBufRecv[0]; - appendWorkElemColl(comm, plan, c, funcIndex, &workElemReg, bid); + appendWorkElemColl(comm, plan, c, collInfo->workFuncIndex, &workElem); } else { - /* impossible value */ - WARN("Invalid regBufType %d\n", regBufType); - return ncclInvalidArgument; + struct ncclWorkElemReg workElemReg; + NCCLCHECKGOTO(initCollWorkElemReg(comm, &workElem, &comm->channels[c], regBufType, collInfo->regBufSend, collInfo->regBufRecv, &workElemReg), ret, fail); + appendWorkElemColl(comm, plan, c, collInfo->workFuncIndex, &workElemReg); } *nWorkBudget -= chans[c].nWork; // subtract delta of chans[c].nWork // Add proxy task. Empty collectives do not make it to the proxy thread // since they don't imply synchronization for the user like p2p. - if (proxyOp->nsteps != 0) { - struct ncclProxyOp tmp = *proxyOp; // C++ struct assignment - tmp.channelId = c; - tmp.opCount = opCount; - NCCLCHECK(addProxyOpIfNeeded(comm, plan, &tmp)); + if (collInfo->nBytes != 0) { + uint32_t steps; + struct ncclProxyOp proxyOp; + NCCLCHECKGOTO(computeCollSteps(collInfo, workCount, &steps), ret, fail); + NCCLCHECKGOTO(initCollProxyOp(collInfo, c, opCount, steps, &proxyOp), ret, fail); + NCCLCHECKGOTO(addProxyOpIfNeeded(comm, plan, &proxyOp), ret, fail); } + + remCount -= workCount; + chans[c].collBytes += workCount * typeSize; + workOffset += workCount; + rnChannels++; } - return ncclSuccess; + + plan->threadPerBlock = std::max(plan->threadPerBlock, collInfo->nThreads); + if (!plan->kernelSpecialized) { + plan->kernelFn = ncclDevKernelForFunc[collInfo->workFuncIndex]; + plan->kernelSpecialized = ncclDevKernelForFuncIsSpecialized[collInfo->workFuncIndex]; + } + + if (comm->rank == 0) { + TRACE(NCCL_COLL, "tunedColl enqueue coll %s(%s, %s, %s, %s), nChannels %d, count %ld (nbytes %ld), usableChannel %d, chunkCount %d, lastChunkCount %ld, funcIndex %d, nThreads %d", collInfo->opName, ncclOpToString(collInfo->op), ncclDatatypeToString(collInfo->datatype), ncclAlgoToString(collInfo->algorithm), ncclProtoToString(collInfo->protocol), rnChannels, collInfo->count, collInfo->workBytes, usableChannels, collInfo->chunkCount, lastChunkCount, collInfo->workFuncIndex, collInfo->nThreads); + } + +exit: + return ret; +fail: + goto exit; +} + +static ncclResult_t addCBDCollToPlan( + struct ncclComm* comm, struct ncclKernelPlan* plan, int usableChannels, + struct ncclInfo* collInfo, int* nWorkBudget + ) { + ncclResult_t ret = ncclSuccess; + struct ncclKernelPlan::Channel *chans = plan->channels; + size_t enqBytes; + uint64_t opCount = uint64_t(plan->collOpCount++) << 1 | 0; + size_t typeSize = ncclTypeSize(collInfo->datatype); + size_t workBytesTotal = collInfo->count * typeSize; + size_t workCountTotal = collInfo->count; + struct ncclWorkElem workElem; + size_t workOffset = 0; + size_t workCount; + ncclRegBufferType regBufType = collInfo->regBufType; + size_t alignCount; + size_t lastChunkCount; + int rnChannel = 0; + + NCCLCHECKGOTO(computeCollChunkInfo(collInfo, collInfo->aggnBytes, collInfo->nChannels), ret, fail); + NCCLCHECKGOTO(computeCollAlignCount(collInfo, &alignCount), ret, fail); + NCCLCHECKGOTO(initCollWorkElem(collInfo, &workElem), ret, fail); + for (int c = 0; c < usableChannels; c++) { + if (plan->maxBytesPerChannel <= chans[c].collBytes) continue; + if (workBytesTotal == 0) break; + enqBytes = std::min(plan->maxBytesPerChannel - chans[c].collBytes, workBytesTotal); + workCount = std::min(DIVUP(DIVUP(enqBytes, typeSize), alignCount) * alignCount, workCountTotal); + enqBytes = workCount * typeSize; + + NCCLCHECKGOTO(computeCollLastChunkInfo(collInfo, workCount, alignCount, &lastChunkCount), ret, fail); + NCCLCHECKGOTO(setCollWorkElem(workCount, workOffset, lastChunkCount, &workElem), ret, fail); + + // Add work elem + *nWorkBudget += chans[c].nWork; + if (regBufType == NCCL_REGULAR_BUFFER) { + appendWorkElemColl(comm, plan, c, collInfo->workFuncIndex, &workElem); + } else { + struct ncclWorkElemReg workElemReg; + NCCLCHECKGOTO(initCollWorkElemReg(comm, &workElem, &comm->channels[c], regBufType, collInfo->regBufSend, collInfo->regBufRecv, &workElemReg), ret, fail); + appendWorkElemColl(comm, plan, c, collInfo->workFuncIndex, &workElemReg); + } + *nWorkBudget -= chans[c].nWork; // subtract delta of chans[c].nWork + + // Add proxy task. Empty collectives do not make it to the proxy thread + // since they don't imply synchronization for the user like p2p. + if (collInfo->nBytes != 0) { + uint32_t steps; + struct ncclProxyOp proxyOp; + NCCLCHECKGOTO(computeCollSteps(collInfo, workCount, &steps), ret, fail); + NCCLCHECKGOTO(initCollProxyOp(collInfo, c, opCount, steps, &proxyOp), ret, fail); + NCCLCHECKGOTO(addProxyOpIfNeeded(comm, plan, &proxyOp), ret, fail); + } + + workBytesTotal -= enqBytes; + workCountTotal -= workCount; + chans[c].collBytes += enqBytes; + workOffset += workCount; + rnChannel++; + } + + plan->threadPerBlock = std::max(plan->threadPerBlock, collInfo->nThreads); + if (!plan->kernelSpecialized) { + plan->kernelFn = ncclDevKernelForFunc[collInfo->workFuncIndex]; + plan->kernelSpecialized = ncclDevKernelForFuncIsSpecialized[collInfo->workFuncIndex]; + } + + if (comm->rank == 0) { + TRACE(NCCL_COLL, "CBDColl enqueue coll %s(%s, %s, %s, %s), nChannels %d, count %ld (nbytes %ld), usableChannel %d, maxBytesPerChannel %ld, chunkCount %d, lastChunkCount %ld, funcIndex %d, nThreads %d", collInfo->opName, ncclOpToString(collInfo->op), ncclDatatypeToString(collInfo->datatype), ncclAlgoToString(collInfo->algorithm), ncclProtoToString(collInfo->protocol), rnChannel, collInfo->count, collInfo->workBytes, usableChannels, plan->maxBytesPerChannel, collInfo->chunkCount, lastChunkCount, collInfo->workFuncIndex, collInfo->nThreads); + } + +exit: + return ret; +fail: + goto exit; } NCCL_PARAM(P2pLLThreshold, "P2P_LL_THRESHOLD", 16384); @@ -306,13 +528,22 @@ static ncclResult_t addP2pToPlan( &comm->channels[channelId].peers[peer]->send[1].conn : &comm->channels[channelId].peers[peer]->recv[1].conn; info.protocol = ((conn->buffs[NCCL_PROTO_LL] != nullptr) && bytes <= ncclParamP2pLLThreshold()) ? NCCL_PROTO_LL : NCCL_PROTO_SIMPLE; + int reg = 0; + if (info.protocol == NCCL_PROTO_SIMPLE) { + struct ncclReg* regRecord; + NCCLCHECK(ncclRegFind(comm, addr, bytes, ®Record)); + reg = regRecord && regRecord->nDevs ? 1 : 0; + } + struct ncclProxyOp proxyOp = {}; - NCCLCHECK(ncclProxyComputeP2p(&info, &proxyOp)); + // May tune chunksize and set proxyOp.reg=0 if not using the network. + NCCLCHECK(ncclProxyComputeP2p(&info, &proxyOp, reg)); struct ncclWorkElemP2p elem = {0}; elem.proto = info.protocol; elem.peer = peer; elem.nWarps = NCCL_MAX_NTHREADS/WARP_SIZE; + elem.reg = proxyOp.reg; elem.p2pType = isSendNotRecv ? ncclWorkP2pTypeSend : ncclWorkP2pTypeRecv; elem.buffLo32 = uint32_t(reinterpret_cast(addr)); elem.buffHi32 = reinterpret_cast(addr)>>32; @@ -358,22 +589,17 @@ int64_t ncclParamLocalRegister(); NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 1); static ncclResult_t registerIntraNodeBuffers( - struct ncclComm* comm, struct ncclKernelPlan* plan, struct ncclInfo* info, - void* outRegBufSend[NCCL_MAX_LOCAL_RANKS], - void* outRegBufRecv[NCCL_MAX_LOCAL_RANKS], - ncclRegBufferType *outRegBufType + struct ncclComm* comm, struct ncclKernelPlan* plan, struct ncclInfo* info ) { ncclResult_t result = ncclSuccess; - *outRegBufType = NCCL_REGULAR_BUFFER; + info->regBufType = NCCL_REGULAR_BUFFER; #if CUDART_VERSION >= 11030 if ((info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) && comm->nvlsRegSupport) { bool regBufUsed = false; const void *sendbuff = info->sendbuff; void *recvbuff = info->recvbuff; - cudaPointerAttributes sattr, rattr; - bool query = false; - + if (info->coll == ncclFuncAllGather) sendbuff = NULL; else if (info->coll == ncclFuncReduceScatter) @@ -381,30 +607,26 @@ static ncclResult_t registerIntraNodeBuffers( /* first try local registration. */ if (ncclParamLocalRegister()) { - CUDACHECK(cudaPointerGetAttributes(&sattr, info->sendbuff)); - CUDACHECK(cudaPointerGetAttributes(&rattr, info->recvbuff)); - query = true; - if (sattr.type == cudaMemoryTypeDevice && rattr.type == cudaMemoryTypeDevice) - ncclNvlsLocalRegisterBuffer(comm, sendbuff, recvbuff, info->sendbuffSize, info->recvbuffSize, ®BufUsed, outRegBufSend, outRegBufRecv); + ncclNvlsLocalRegisterBuffer(comm, sendbuff, recvbuff, info->sendbuffSize, info->recvbuffSize, ®BufUsed, info->regBufSend, info->regBufRecv); } if (regBufUsed == false && plan->persistent && ncclParamGraphRegister()) { - if (!query) { - CUDACHECK(cudaPointerGetAttributes(&sattr, info->sendbuff)); - CUDACHECK(cudaPointerGetAttributes(&rattr, info->recvbuff)); - } - if (sattr.type == cudaMemoryTypeDevice && rattr.type == cudaMemoryTypeDevice) - ncclNvlsGraphRegisterBuffer(comm, plan, sendbuff, recvbuff, info->sendbuffSize, info->recvbuffSize, ®BufUsed, outRegBufSend, outRegBufRecv); + ncclNvlsGraphRegisterBuffer(comm, plan, sendbuff, recvbuff, info->sendbuffSize, info->recvbuffSize, ®BufUsed, info->regBufSend, info->regBufRecv); } if (regBufUsed) { /* tweak NVLS channels usage; for registered NVLS buffer, we only need 4/5 channels to * saturate bandwidth. */ - if (info->coll == ncclFuncReduceScatter) - info->nChannels = std::min(5, comm->nvlsChannels); - else - info->nChannels = std::min(4, comm->nvlsChannels); - *outRegBufType = NCCL_NVLS_REG_BUFFER; + if (comm->nNodes == 1) { + if (info->coll == ncclFuncReduceScatter) + info->nChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 5)); + else + info->nChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 4)); + } else { + info->nChannels = std::max(comm->config.minCTAs, std::min(comm->config.maxCTAs, 6)); + } + + info->regBufType = NCCL_NVLS_REG_BUFFER; } } else if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT && // limited to CollNetDirect for now comm->intraHighestTransportType == TRANSPORT_P2P && // only when all ranks can p2p each other @@ -441,15 +663,15 @@ static ncclResult_t registerIntraNodeBuffers( // Open handles locally for (int i=0; i < comm->localRanks; i++) { if (i == localRank) { // Skip self - outRegBufSend[i] = nullptr; - outRegBufRecv[i] = nullptr; + info->regBufSend[i] = nullptr; + info->regBufRecv[i] = nullptr; } else { for (int sr=0; sr < 2; sr++) { // Get base address of mapping void* base; CUDACHECK(cudaIpcOpenMemHandle(&base, handles[i].ipc[sr], cudaIpcMemLazyEnablePeerAccess)); // Get real buffer address by adding offset in the mapping - (sr==0 ? outRegBufSend : outRegBufRecv)[i] = (char*)base + handles[i].offset[sr]; + (sr == 0 ? info->regBufSend : info->regBufRecv)[i] = (char*)base + handles[i].offset[sr]; // Enqueue reminder to close memory handle struct ncclPointerList* q = ncclMemoryPoolAlloc(&comm->memPool_ncclPointerList, &comm->memPermanent); q->ptr = base; @@ -457,135 +679,163 @@ static ncclResult_t registerIntraNodeBuffers( } } } - *outRegBufType = NCCL_IPC_REG_BUFFER; + info->regBufType = NCCL_IPC_REG_BUFFER; } fallback: #endif return result; } -static ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport); -static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetSupport, int nvlsSupport, int numPipeOps); +static ncclResult_t getCBDCollnChannel(struct ncclKernelPlan* plan, struct ncclInfo* collInfo, int usableChannels) { + size_t firstEnqBytes; + size_t workBytesTotal = collInfo->workBytes; + struct ncclKernelPlan::Channel *chans = plan->channels; + int typeSize = ncclTypeSize(collInfo->datatype); + size_t maxCount = DIVUP(plan->maxBytesPerChannel, typeSize); + + if (workBytesTotal == 0) { + collInfo->nChannels = 1; + goto exit; + } + + for (int c = 0; c < usableChannels; c++) { + if (plan->maxBytesPerChannel <= chans[c].collBytes) continue; + firstEnqBytes = std::min(plan->maxBytesPerChannel - chans[c].collBytes, workBytesTotal); + firstEnqBytes = DIVUP(firstEnqBytes, typeSize) * typeSize; + collInfo->nChannels = 1 + DIVUP((workBytesTotal - firstEnqBytes) / typeSize, maxCount); + break; + } + +exit: + return ncclSuccess; +} static ncclResult_t scheduleCollTasksToPlan( struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget ) { struct ncclTasks* tasks = &comm->tasks; + size_t totalCBDBytes = tasks->workBytesTotal; + struct ncclInfo* collInfo; - size_t bytePerChannel[/*collNetSupport*/2]; - if (comm->channelSize > 0) { - // Set by user - bytePerChannel[/*collNetSupport=*/0] = comm->channelSize; - bytePerChannel[/*collNetSupport=*/1] = comm->channelSize; - } else { - // Latency increases as scale increases - // We would thus want to increase the chunk size to compensate for the lost efficiency - bytePerChannel[/*collNetSupport=*/0] = NCCL_AGG_CHANNEL_SIZE * std::min(16, comm->nRanks); - bytePerChannel[/*collNetSupport=*/1] = 256<<10; // Hand-tuned - } + if (!ncclIntruQueueEmpty(&tasks->collQueue)) { + int usableChannels = 0, accChannels = 0; - for (int collNetSupport=0; collNetSupport < 2; collNetSupport++) { - while (tasks->collBytesTotal < bytePerChannel[collNetSupport]*comm->nChannels && - bytePerChannel[collNetSupport] > NCCL_MIN_CHANNEL_SIZE) { - // Reduce per-channel size so we utilize all channels. - bytePerChannel[collNetSupport] /= 2; - } - } + tasks->usableChannels = 1; + while (!ncclIntruQueueEmpty(&tasks->collQueue)) { + collInfo = ncclIntruQueueDequeue(&tasks->collQueue); + if (collInfo->count == 0) continue; + if (collInfo->algorithm == NCCL_ALGO_UNDEF) { + struct ncclInfo* aggInfo = ncclMemoryStackAlloc(&comm->memScoped); + struct ncclInfo* nextInfo = collInfo->next; + int nvlsSupport; + int collNetSupport; - while (tasks->nTasksColl != 0) { - struct ncclTaskColl* head = ncclIntruQueueHead(&tasks->collQueue); - struct ncclInfo aggInfo = {}; - aggInfo.comm = comm; - aggInfo.coll = head->func; - aggInfo.datatype = head->datatype; - aggInfo.opFull = head->op; - aggInfo.op = (ncclRedOp_t)(int)head->op.op; - aggInfo.count = head->count; - int nAggChannels = 0; - int nAggOps = 1; - struct ncclTaskColl* aggEnd = head->next; - int nvlsSupport = comm->nvlsSupport && ncclNvlsSupported(aggInfo.opFull.op, aggInfo.datatype); - int collNetSupport = 0; - NCCLCHECK(getCollNetSupport(&aggInfo, &collNetSupport)); + memcpy(aggInfo, collInfo, sizeof(struct ncclInfo)); + while (nextInfo) { + if (nextInfo->coll == aggInfo->coll && nextInfo->opFull.op == aggInfo->opFull.op && nextInfo->datatype == aggInfo->datatype) { + aggInfo->count += nextInfo->count; + nextInfo = nextInfo->next; + } else { + break; + } + } - // Find a range of ops that can be aggregated together. - while (aggEnd != nullptr && - aggEnd->func == aggInfo.coll && - aggEnd->datatype == aggInfo.datatype && - aggEnd->op.op == aggInfo.opFull.op) { - aggInfo.count += aggEnd->count; - int nc = DIVUP(aggEnd->count*ncclTypeSize(aggInfo.datatype), bytePerChannel[collNetSupport]); - nc = std::max(1, std::min(nc, comm->nChannels)); - nAggChannels += nc; - nAggOps++; - aggEnd = aggEnd->next; - } + nvlsSupport = comm->nvlsSupport && ncclNvlsSupported(aggInfo->opFull.op, aggInfo->datatype); + NCCLCHECK(getCollNetSupport(aggInfo, &collNetSupport)); + NCCLCHECK(ncclInfoSetDerived(aggInfo, comm->nRanks)); + NCCLCHECK(getTunerInfo(aggInfo, collNetSupport, nvlsSupport, 1)); + NCCLCHECK(topoGetAlgoInfo(aggInfo, collNetSupport, nvlsSupport, 1)); + NCCLCHECK(getChannnelThreadInfo(aggInfo)); + NCCLCHECK(computeCollWorkFunc(aggInfo)); + NCCLCHECK(getPatternInfo(aggInfo)); - if (nAggOps > 1) { - NCCLCHECK(ncclInfoSetDerived(&aggInfo, comm->nRanks)); - aggInfo.nChannels = std::min(comm->nChannels, nAggChannels); - int opPerChannel = DIVUP(nAggChannels, aggInfo.nChannels); - NCCLCHECK(getAlgoInfo(&aggInfo, collNetSupport, nvlsSupport, opPerChannel)); - } + // Try to assign algo and proto to all possible collectives + nextInfo = collInfo; + while (nextInfo) { + if (nextInfo->coll == aggInfo->coll && nextInfo->opFull.op == aggInfo->opFull.op && nextInfo->datatype == aggInfo->datatype) { + NCCLCHECK(ncclInfoSetDerived(nextInfo, comm->nRanks)); + NCCLCHECK(getTunerInfo(nextInfo, collNetSupport, nvlsSupport, 1)); + nextInfo->algorithm = aggInfo->algorithm; + nextInfo->protocol = aggInfo->protocol; + nextInfo->nThreads = aggInfo->nThreads; + nextInfo->pattern = aggInfo->pattern; + nextInfo->workFuncIndex = aggInfo->workFuncIndex; + nextInfo->aggnBytes = aggInfo->nBytes; - while (head != aggEnd) { - struct ncclInfo info = {}; - info.comm = comm; - info.coll = head->func; - info.sendbuff = head->sendbuff; - info.recvbuff = head->recvbuff; - info.count = head->count; - info.root = head->root; - info.datatype = head->datatype; - info.opFull = head->op; // C++ struct assignment - info.op = (ncclRedOp_t)(int)head->op.op; - info.chunkSteps = head->chunkSteps; - info.sliceSteps = head->sliceSteps; - NCCLCHECK(ncclInfoSetDerived(&info, comm->nRanks)); - if (nAggOps > 1) { - int maxChannels = aggInfo.algorithm == NCCL_ALGO_NVLS || aggInfo.algorithm == NCCL_ALGO_NVLS_TREE ? comm->nvlsChannels : comm->nChannels; - info.nChannels = DIVUP(info.nBytes, bytePerChannel[collNetSupport]); - info.nChannels = std::max(1, std::min(info.nChannels, maxChannels)); - info.algorithm = aggInfo.algorithm; - info.protocol = aggInfo.protocol; - info.nThreads = aggInfo.nThreads; + NCCLCHECK(getChannnelThreadInfo(nextInfo)); + // if possible, start registration + registerIntraNodeBuffers(comm, plan, nextInfo); + // accumulate channels + accChannels += nextInfo->nChannels; + nextInfo = nextInfo->next; + } else { + break; + } + } + } // end of aggInfo + + if (collInfo->algorithm == NCCL_ALGO_NVLS || collInfo->algorithm == NCCL_ALGO_NVLS_TREE) { + usableChannels = std::max(usableChannels, comm->nvlsChannels); + } else { + usableChannels = std::max(usableChannels, comm->collChannels); } - int workFuncIndex; - struct ncclWorkElem workElem = {}; - struct ncclProxyOp proxyOp = {}; - // Check whether algo and proto have been preset (as in aggregation case) - // If so, skip the calculation - if (info.nChannels <= 0 || info.nThreads <= 0) { - NCCLCHECK(getAlgoInfo(&info, collNetSupport, nvlsSupport, 1)); - } - - if (*nWorkBudget < info.nChannels) return ncclSuccess; // Ensure room for addCollToPlan() - - /* if possible, start registration */ - ncclRegBufferType regBufType = NCCL_REGULAR_BUFFER; - void* regBufSend[NCCL_MAX_LOCAL_RANKS]; - void* regBufRecv[NCCL_MAX_LOCAL_RANKS]; - - registerIntraNodeBuffers(comm, plan, &info, regBufSend, regBufRecv, ®BufType); - - NCCLCHECK(computeColl(&info, &workFuncIndex, &workElem, &proxyOp)); - - int maxChannels = info.algorithm == NCCL_ALGO_NVLS || aggInfo.algorithm == NCCL_ALGO_NVLS_TREE ? comm->nvlsChannels : comm->nChannels; - NCCLCHECK(addCollToPlan(comm, plan, nWorkBudget, workFuncIndex, &workElem, &proxyOp, - maxChannels, info.nChannels, info.nBytes, regBufType, regBufSend, regBufRecv)); - tasks->nTasksColl -= 1; - tasks->collBytesTotal -= info.nBytes; - ncclIntruQueueDequeue(&tasks->collQueue); - head = ncclIntruQueueHead(&tasks->collQueue); - - plan->threadPerBlock = std::max(plan->threadPerBlock, info.nThreads); - if (!plan->kernelSpecialized) { - plan->kernelFn = ncclDevKernelForFunc[workFuncIndex]; - plan->kernelSpecialized = ncclDevKernelForFuncIsSpecialized[workFuncIndex]; + if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT || collInfo->algorithm == NCCL_ALGO_COLLNET_CHAIN || (collInfo->algorithm == NCCL_ALGO_NVLS && comm->nNodes > 1)) { + // substract collective which needs to be executed separately + totalCBDBytes -= collInfo->workBytes; + tasks->workBytesTotal -= collInfo->workBytes; + ncclIntruQueueEnqueue(&tasks->collnetQueue, collInfo); + } else if (collInfo->userTuned) { + // substract collective which needs to be executed separately + totalCBDBytes -= collInfo->workBytes; + tasks->workBytesTotal -= collInfo->workBytes; + ncclIntruQueueEnqueue(&tasks->collTunedQueue, collInfo); + } else { + ncclIntruQueueEnqueue(&tasks->collCBDQueue, collInfo); } } + + tasks->usableChannels = std::min(usableChannels, accChannels); } + + /* Calculate maxBytesPerChannel for CBD colls and it should be 16 bytes aligned + * Note: it it not hard upper bound for maxBytes, we can relax it if any optimization + * is needed */ + plan->maxBytesPerChannel = DIVUP(DIVUP(totalCBDBytes, tasks->usableChannels), NCCL_BYTES_ALIGNMENT) * NCCL_BYTES_ALIGNMENT; + // First enqueue CBD colls + while (!ncclIntruQueueEmpty(&tasks->collCBDQueue)) { + // Get nChannels and peek whether the budget allows before we enqueue + collInfo = ncclIntruQueueHead(&tasks->collCBDQueue); + collInfo->nChannels = DIVUP(collInfo->aggnBytes * tasks->usableChannels, totalCBDBytes); + // Haven't got nChannels info yet, relax the budget boundary a bit. + if (*nWorkBudget < collInfo->nChannels) return ncclSuccess; + + collInfo = ncclIntruQueueDequeue(&tasks->collCBDQueue); + NCCLCHECK(addCBDCollToPlan(comm, plan, tasks->usableChannels, collInfo, nWorkBudget)); + tasks->nTasksColl -= 1; + tasks->workBytesTotal -= collInfo->count * ncclTypeSize(collInfo->datatype); + } + + // Then enqueue collnet colls + while (!ncclIntruQueueEmpty(&tasks->collnetQueue)) { + collInfo = ncclIntruQueueHead(&tasks->collnetQueue); + if (*nWorkBudget < collInfo->nChannels) return ncclSuccess; + + collInfo = ncclIntruQueueDequeue(&tasks->collnetQueue); + NCCLCHECK(addCollnetCollToPlan(comm, plan, tasks->usableChannels, collInfo, nWorkBudget)); + tasks->nTasksColl -= 1; + } + + // Finally enqueue user-tuned colls + while (!ncclIntruQueueEmpty(&tasks->collTunedQueue)) { + collInfo = ncclIntruQueueHead(&tasks->collTunedQueue); + if (*nWorkBudget < collInfo->nChannels) return ncclSuccess; + + collInfo = ncclIntruQueueDequeue(&tasks->collTunedQueue); + NCCLCHECK(addTunedCollToPlan(comm, plan, tasks->usableChannels, collInfo, nWorkBudget)); + tasks->nTasksColl -= 1; + } + return ncclSuccess; } @@ -620,12 +870,8 @@ static ncclResult_t scheduleP2pTasksToPlan( // Try to use all channels int nChannelsMax = comm->p2pnChannelsPerPeer; int nChannelsMin = nChannelsMax; - if (comm->nNodes == 1) { - // Try to use all channels, but one channel per operation. - while (nChannelsMin*nRanks > comm->p2pnChannels && nChannelsMin > 1) nChannelsMin /= 2; - // Avoid overloading channels with 8+ operations as we loose the sync warp, hence a bit of bandwidth. - while (nChannelsMax*nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2; - } + // Try to use all channels, but one channel per operation. + while (nChannelsMin*nRanks > comm->p2pnChannels && nChannelsMin > 1) nChannelsMin /= 2; bool fuseOk = false; // We can perform 8 send/recv per round per CTA. Make sure we jump between fused blocks at node boundaries. @@ -654,7 +900,7 @@ static ncclResult_t scheduleP2pTasksToPlan( char* sendPtr = send ? (char*)send->buff : nullptr; ssize_t recvBytes = recv ? recv->bytes : 0; ssize_t sendBytes = send ? send->bytes : 0; - ssize_t minSize = stepSize/8; + ssize_t minSize = comm->nNodes > 1 ? stepSize/2 : stepSize/8; ssize_t maxSize = comm->nNodes > 1 ? stepSize : stepSize*32; ssize_t recvChunkBytesMax = calcP2pChunkSize(recvBytes, nChannelsMin, nChannelsMax, minSize, maxSize); ssize_t sendChunkBytesMax = calcP2pChunkSize(sendBytes, nChannelsMin, nChannelsMax, minSize, maxSize); @@ -825,32 +1071,60 @@ static ncclResult_t uploadProxyOps(struct ncclComm* comm, struct ncclKernelPlan* uint64_t collOpCount = comm->sharedRes->collOpCount; // Advance comm's collOpCount by number of colls in this plan. comm->sharedRes->collOpCount += plan->collOpCount; + + uint64_t p2pOpBump[MAXCHANNELS]; + struct ncclProxyOp* heads[MAXCHANNELS]; + uint64_t headIds[MAXCHANNELS]; + int nHeads = 0; for (int c=0; c < plan->channelUbound; c++) { - struct ncclProxyOp* q = ncclIntruQueueHead(&plan->channels[c].proxyOpQueue); - uint64_t p2pOpCount = comm->sharedRes->p2pOpCount[c]; - uint64_t nextP2pOpCount = p2pOpCount; - while (q != nullptr) { - struct ncclProxyOp* qNext = q->enqNext; - // Ignoring the bottom tag bit, opCount's are zero-based within plan so - // translate them to the tip of the comm's history. - if (q->opCount & 1) { // p2p - // p2pOpCount is monotonic increasing within a plan's channel so just - // remember last value to compute max. - nextP2pOpCount = p2pOpCount + (q->opCount>>1); - nextP2pOpCount += 1; // +1 to ensure next plan doesn't collide - q->opCount = (p2pOpCount<<1) + q->opCount; - } else { // coll - q->opCount = (collOpCount<<1) + q->opCount; - } - NCCLCHECK(ncclProxySaveOp(comm, q, nullptr)); // May overwrite enqNext. - if (!plan->persistent) { - // Non-persistent kernels have their memory reclaimed after upload. - ncclMemoryPoolFree(&plan->memPool_ncclProxyOp, q); - } - q = qNext; + p2pOpBump[c] = 0; + heads[c] = ncclIntruQueueHead(&plan->channels[c].proxyOpQueue); + nHeads += (heads[c] != nullptr) ? 1 : 0; + headIds[c] = (heads[c] != nullptr) ? heads[c]->opCount : uint64_t(-1); + } + + while (nHeads != 0) { + int minChan = -1; + uint64_t minId = uint64_t(-1); + // We store the heads[c]->opCount in headIds[c] specifically to remove indirect + // loads from this loop which speeds it up considerably. + for (int c=0; c < plan->channelUbound; c++) { + uint64_t id = headIds[c]; + id = (id>>1 | id<<63); // Move tag bit to order collectives before p2p's + if (id < minId) { minChan = c; minId = id; } } + + struct ncclProxyOp* q = heads[minChan]; + uint64_t oldId = headIds[minChan]; // same as q->opCount + // Advance heads[c] + heads[minChan] = q->enqNext; + if (q->enqNext == nullptr) nHeads -= 1; + headIds[minChan] = (q->enqNext != nullptr) ? q->enqNext->opCount : uint64_t(-1); + + // Ignoring the bottom tag bit, opCount's are zero-based within plan so + // translate them to the tip of the comm's history. + if (oldId & 1) { // p2p + // opCount is monotonic increasing within a plan's channel so just + // remember last value to compute max. + p2pOpBump[minChan] = (oldId>>1) + 1; // +1 to ensure next plan doesn't collide + q->opCount = (comm->sharedRes->p2pOpCount[minChan]<<1) + oldId; + } else { // coll + q->opCount = (collOpCount<<1) + oldId; + } + + NCCLCHECK(ncclProxySaveOp(comm, q, nullptr)); + q->opCount = oldId; // Restore for next uploadProxyOps() + if (!plan->persistent) { + // Non-persistent kernels upload ops only once so can be free'd here. + ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, q); + } + } + + for (int c=0; c < plan->channelUbound; c++) { + // Erase proxyOpQueue since all ops were free'd back to mempool. + if (!plan->persistent) ncclIntruQueueConstruct(&plan->channels[c].proxyOpQueue); // Advance channel's p2pOpCount by number of p2p's in this plan channel. - comm->sharedRes->p2pOpCount[c] = nextP2pOpCount; + comm->sharedRes->p2pOpCount[c] += p2pOpBump[c]; } return ncclSuccess; } @@ -883,7 +1157,7 @@ static ncclResult_t reclaimPlan(struct ncclComm* comm, struct ncclCommCallback* struct ncclProxyOp* q = ncclIntruQueueHead(&plan->channels[c].proxyOpQueue); while (q != nullptr) { struct ncclProxyOp* q1 = q->enqNext; - ncclMemoryPoolFree(&plan->memPool_ncclProxyOp, q); + ncclMemoryPoolFree(&comm->memPool_ncclProxyOp, q); q = q1; } } @@ -900,7 +1174,6 @@ static ncclResult_t reclaimPlan(struct ncclComm* comm, struct ncclCommCallback* ncclMemoryPoolFree(&comm->memPool_ncclNvlsHandleList, obj); } } - ncclMemoryPoolTakeAll(&comm->memPool_ncclProxyOp, &plan->memPool_ncclProxyOp); ncclMemoryPoolFree(&comm->memPool_ncclKernelPlan, plan); return ncclSuccess; } @@ -1118,7 +1391,7 @@ ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKern ncclResult_t ncclLaunchFinish(struct ncclComm* comm) { ncclResult_t result = ncclSuccess; struct ncclTasks* tasks = &comm->tasks; - tasks->collBytesTotal = 0; // Just in case subtraction during scheduleCollTasksToPlan() doesn't get to 0 + tasks->workBytesTotal = 0; // Just in case subtraction during scheduleCollTasksToPlan() doesn't get to 0 // Deallocate ncclWork's. This frame exists so long as ncclLaunchPrepare // succeeded, and if it ncclLaunchPrepare didn't succeed we wouldn't be here. @@ -1158,42 +1431,50 @@ ncclResult_t ncclLaunchFinish(struct ncclComm* comm) { static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetSupport) { // Translate ncclAvg and PreMulSum ncclRedOp_t netOp = info->op == ncclAvg || info->op >= ncclNumOps ? ncclSum : info->op; - *collNetSupport = info->comm->collNetSupport && info->comm->collNetSupportMatrix[netOp][info->datatype]; + *collNetSupport = info->comm->collNetSupport; + switch (info->coll) { + case ncclFuncAllReduce: + case ncclFuncReduce: + case ncclFuncReduceScatter: + *collNetSupport &= info->comm->collNetSupportMatrix[netOp][info->datatype]; + break; + default: + break; + } return ncclSuccess; } // numPipeOps: number of pipelined ops. Can be greater than 1 in aggregation mode. Used to adjust latency. -static ncclResult_t topoGetAlgoInfo(struct ncclInfo* info, int collNetSupport, int nvlsSupport, int numPipeOps) { - struct ncclComm* comm = info->comm; +static ncclResult_t topoGetAlgoInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps) { + struct ncclComm* comm = collInfo->comm; if (comm->nRanks == 1) { - info->algorithm = NCCL_ALGO_RING; - info->protocol = NCCL_PROTO_SIMPLE; + collInfo->algorithm = NCCL_ALGO_RING; + collInfo->protocol = NCCL_PROTO_SIMPLE; } - else if (info->algorithm == NCCL_ALGO_UNDEF || info->protocol == NCCL_PROTO_UNDEF) { + else if (collInfo->algorithm == NCCL_ALGO_UNDEF || collInfo->protocol == NCCL_PROTO_UNDEF) { float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete. float backupMinTime = 3600000000.0; bool backup = false; int backupAlgo = NCCL_ALGO_UNDEF; // back up algo and proto if no algo/proto is picked up. int backupProto = NCCL_PROTO_UNDEF; // Find algorithm / protocol. - info->algorithm = -1; - info->protocol = -1; + collInfo->algorithm = -1; + collInfo->protocol = -1; int nAlgos = NCCL_NUM_ALGORITHMS; for (int a=0; acoll != ncclFuncAllGather) continue; + if ((a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE) && nvlsSupport != 1) continue; if (a == NCCL_ALGO_NVLS && collNetSupport != 1 && comm->nNodes > 1) continue; /* now we only support single-node NVLS allgather and reducescatter */ - if (a == NCCL_ALGO_NVLS && (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) && comm->nNodes > 1) continue; - if (a == NCCL_ALGO_NVLS_TREE && nvlsSupport != 1) continue; + if (a == NCCL_ALGO_NVLS && (collInfo->coll == ncclFuncAllGather || collInfo->coll == ncclFuncReduceScatter) && comm->nNodes > 1) continue; for (int p=0; p= 0 && time < minTime) { - info->algorithm = a; - info->protocol = p; + collInfo->algorithm = a; + collInfo->protocol = p; minTime = time; } } else { @@ -1206,51 +1487,18 @@ static ncclResult_t topoGetAlgoInfo(struct ncclInfo* info, int collNetSupport, i } } - if (info->algorithm == NCCL_ALGO_UNDEF || info->protocol == NCCL_PROTO_UNDEF) { + if (collInfo->algorithm == NCCL_ALGO_UNDEF || collInfo->protocol == NCCL_PROTO_UNDEF) { if (backupAlgo == NCCL_ALGO_UNDEF || backupProto == NCCL_PROTO_UNDEF) { WARN("Error : no algorithm/protocol available"); return ncclInternalError; } - info->algorithm = backupAlgo; - info->protocol = backupProto; + collInfo->algorithm = backupAlgo; + collInfo->protocol = backupProto; } - //if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); - TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); + if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", collInfo->nBytes, collInfo->algorithm, collInfo->protocol, minTime); + TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", collInfo->nBytes, collInfo->algorithm, collInfo->protocol, minTime); } - int nc = (info->nChannels > 0) ? info->nChannels : comm->nChannels; - int nt = comm->maxThreads[info->algorithm][info->protocol]; - int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol]; - if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { - // CollNet channel tuning - int ncSwitch = 16; - bool flag = true; - while (ncSwitch >= 1 && flag) { - while ((flag = info->nBytes < nc*nt*info->comm->channels[0].collnetDirect.nHeads*threadThreshold) && nc > ncSwitch) { - if (nc == ncSwitch+ncSwitch/2) threadThreshold /= 2; - nc--; - } - ncSwitch /= 2; - } - } else if (info->algorithm == NCCL_ALGO_NVLS || info->algorithm == NCCL_ALGO_NVLS_TREE) { - // NVLS should not need more than 16 channels to get peak BW. - nc = comm->nvlsChannels; - } else { - // Ring/Tree channel tuning - while (info->nBytes < nc*nt*threadThreshold) { - if (nc >= 2) nc--; - else if ((nt % 128) == 0) nt/=2; - else break; - } - } - if (info->protocol == NCCL_PROTO_SIMPLE) { - if (info->algorithm == NCCL_ALGO_RING) nt += WARP_SIZE; // Extra warp for sync - // More threads or sync warps needed due to split thread model - if (info->algorithm == NCCL_ALGO_TREE) nt += 4*WARP_SIZE; - } - nt = nt/WARP_SIZE < 3 ? 3*WARP_SIZE : nt; - info->nChannels = nc; - info->nThreads = nt; return ncclSuccess; } @@ -1258,178 +1506,262 @@ static ncclResult_t topoGetAlgoInfo(struct ncclInfo* info, int collNetSupport, i // Call the plugin first. Let it set algo+proto, and/or nChannels. // Then, topoGetAlgoInfo will set algo/proto if not set, then nChannels and nThreads based on algo/proto. // Finally, nChannels will be overriden by the plugin setting. -static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetSupport, int nvlsSupport, int numPipeOps) { - info->algorithm = NCCL_ALGO_UNDEF; - info->protocol = NCCL_PROTO_UNDEF; - int nChannels = 0; - if (info->comm->tuner != NULL) { - NCCLCHECK(info->comm->tuner->getCollInfo( - info->coll, info->nBytes, +static ncclResult_t getTunerInfo(struct ncclInfo* collInfo, int collNetSupport, int nvlsSupport, int numPipeOps) { + collInfo->algorithm = NCCL_ALGO_UNDEF; + collInfo->protocol = NCCL_PROTO_UNDEF; + collInfo->nChannels = 0; + if (collInfo->comm->tuner != NULL) { + NCCLCHECK(collInfo->comm->tuner->getCollInfo( + collInfo->coll, collInfo->nBytes, collNetSupport, nvlsSupport, numPipeOps, - &info->algorithm, &info->protocol, &nChannels)); + &collInfo->algorithm, &collInfo->protocol, &collInfo->nChannels)); } - NCCLCHECK(topoGetAlgoInfo(info, collNetSupport, nvlsSupport, numPipeOps)); - if (nChannels) info->nChannels = nChannels; // Set by plugin; override default. + + /* We only honor nChannels decision when user sets the nChannels by tuner plugin or the coll picks + * collnet algorithm. For other cases, we need to decide nChannels based on the maxBytesPerChannel */ + if (collInfo->nChannels != 0) + collInfo->userTuned = true; + else + collInfo->userTuned = false; return ncclSuccess; } -static ncclResult_t getPatternInfo(struct ncclInfo* info) { - switch (info->coll) { +/* Compute nChannels and nThreads. */ +static ncclResult_t getChannnelThreadInfo(struct ncclInfo* collInfo) { + struct ncclComm *comm = collInfo->comm; + int nc = comm->collChannels; + int nt = comm->maxThreads[collInfo->algorithm][collInfo->protocol]; + int threadThreshold = comm->threadThresholds[collInfo->algorithm][collInfo->protocol]; + + if (collInfo->nChannels == 0) { + /* not preset by users */ + if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT) { + // CollNet channel tuning + int ncSwitch = 16; + bool flag = true; + while (ncSwitch >= 1 && flag) { + while ((flag = collInfo->nBytes < nc * nt * collInfo->comm->channels[0].collnetDirect.nHeads * threadThreshold) && nc > ncSwitch) { + if (nc == ncSwitch + ncSwitch / 2) threadThreshold /= 2; + nc--; + } + ncSwitch /= 2; + } + } else if (collInfo->algorithm == NCCL_ALGO_NVLS || collInfo->algorithm == NCCL_ALGO_NVLS_TREE) { + // NVLS should not need more than 16 channels to get peak BW. + nc = comm->nvlsChannels; + } else { + // Ring/Tree channel tuning + while (collInfo->nBytes < nc * nt * threadThreshold) { + if (nc >= 2) nc--; + else break; + } + } + collInfo->nChannels = nc; + } else { + nc = collInfo->nChannels; + } + + if (collInfo->nThreads == 0) { + if (collInfo->algorithm != NCCL_ALGO_NVLS && collInfo->algorithm != NCCL_ALGO_NVLS_TREE && + collInfo->algorithm != NCCL_ALGO_COLLNET_DIRECT) { + while (collInfo->nBytes < nc * nt * threadThreshold) { + if (nt % 128 == 0) nt /= 2; + else break; + } + } + + if (collInfo->protocol == NCCL_PROTO_SIMPLE) { + if (collInfo->algorithm == NCCL_ALGO_RING) nt += WARP_SIZE; // Extra warp for sync + // More threads or sync warps needed due to split thread model + if (collInfo->algorithm == NCCL_ALGO_TREE) nt += 4*WARP_SIZE; + } + nt = nt / WARP_SIZE < 3 ? 3 * WARP_SIZE : nt; + collInfo->nThreads = nt; + } + + return ncclSuccess; +} + +static ncclResult_t getPatternInfo(struct ncclInfo* collInfo) { + switch (collInfo->coll) { case ncclFuncBroadcast: - info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeDown : ncclPatternPipelineFrom; break; + collInfo->pattern = collInfo->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeDown : ncclPatternPipelineFrom; break; case ncclFuncReduce: - info->pattern = info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break; + collInfo->pattern = collInfo->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUp : ncclPatternPipelineTo; break; case ncclFuncReduceScatter: case ncclFuncAllGather: - info->pattern = - info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : + collInfo->pattern = + collInfo->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : + collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : ncclPatternRing; break; case ncclFuncAllReduce: - info->pattern = - info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : - info->algorithm == NCCL_ALGO_NVLS_TREE ? ncclPatternNvlsTree : - info->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : - info->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain : - info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : + collInfo->pattern = + collInfo->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : + collInfo->algorithm == NCCL_ALGO_NVLS_TREE ? ncclPatternNvlsTree : + collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : + collInfo->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain : + collInfo->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : ncclPatternRingTwice; break; default: - WARN("Unknown pattern for collective %d algorithm %d", info->coll, info->algorithm); + WARN("Unknown pattern for collective %d algorithm %d", collInfo->coll, collInfo->algorithm); return ncclInternalError; } return ncclSuccess; } -static ncclResult_t getLoopInfo(struct ncclInfo* info) { - switch (info->pattern) { - case ncclPatternTreeUp: - case ncclPatternTreeDown: - case ncclPatternTreeUpDown: - case ncclPatternPipelineFrom: - case ncclPatternPipelineTo: - case ncclPatternCollnetChain: - info->nstepsPerLoop = info->nchunksPerLoop = 1; break; - case ncclPatternNvls: - info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].nvls.nHeads; break; - case ncclPatternCollnetDirect: - info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].collnetDirect.nHeads; break; - case ncclPatternRing: - info->nstepsPerLoop = info->comm->nRanks-1; info->nchunksPerLoop = info->comm->nRanks; break; - case ncclPatternRingTwice: - info->nstepsPerLoop = 2*(info->comm->nRanks-1); info->nchunksPerLoop = info->comm->nRanks; break; - case ncclPatternNvlsTree: - info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].nvls.nHeads; break; - default: - WARN("Unknown pattern %d", info->pattern); - return ncclInternalError; - } +static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo) { + collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, collInfo->algorithm, collInfo->protocol); return ncclSuccess; } -static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */) { - // Set nstepsPerLoop and nchunksPerLoop - NCCLCHECK(getPatternInfo(info)); - NCCLCHECK(getLoopInfo(info)); +static ncclResult_t initCollWorkElem(struct ncclInfo* collInfo, struct ncclWorkElem* work) { + work->sendbuff = collInfo->sendbuff; + work->recvbuff = collInfo->recvbuff; + work->root = collInfo->root; + work->count = collInfo->count; + work->nWarps = collInfo->nThreads / WARP_SIZE; + work->redOpArg = collInfo->opFull.scalarArg; + work->redOpArgIsPtr = collInfo->opFull.scalarArgIsPtr; + work->chunkCount = collInfo->chunkCount; + work->regUsed = 0; + work->isUsed = 1; - work->sendbuff = info->sendbuff; - work->recvbuff = info->recvbuff; - work->root = info->root; - work->count = info->count; - work->nChannels = info->nChannels; - work->nWarps = info->nThreads / WARP_SIZE; - work->redOpArg = info->opFull.scalarArg; - work->redOpArgIsPtr = info->opFull.scalarArgIsPtr; - *workFuncIndex = ncclDevFuncId(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); - - int stepSize = info->comm->buffSizes[info->protocol]/NCCL_STEPS; - int chunkSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->chunkSteps : 1; - int sliceSteps = (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) ? info->sliceSteps : 1; - int chunkSize = stepSize*chunkSteps; - - // Compute lastChunkSize - if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_SIMPLE) { - if (info->pattern == ncclPatternTreeUpDown) { - // Optimize chunkSize / nSteps - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*8 && chunkSize > 131072) chunkSize /= 2; - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth*4 && chunkSize > 65536) chunkSize /= 2; - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].tree.depth && chunkSize > 32768) chunkSize /= 2; - } - // Use lastChunkSize as chunkSize - work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); - } else if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { - // Optimize chunkSize / nSteps - while (info->nBytes / (info->nChannels*info->comm->channels[0].collnetDirect.nHeads*chunkSize) < info->comm->channels[0].collnetDirect.depth*64 && chunkSize > 131072) chunkSize /= 2; - while (info->nBytes / (info->nChannels*info->comm->channels[0].collnetDirect.nHeads*chunkSize) < info->comm->channels[0].collnetDirect.depth*8 && chunkSize > 65536) chunkSize /= 2; - while (info->nBytes / (info->nChannels*info->comm->channels[0].collnetDirect.nHeads*chunkSize) < info->comm->channels[0].collnetDirect.depth*8 && chunkSize > 32768) chunkSize /= 2; - // Use lastChunkSize as chunkSize - work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); + if (collInfo->comm->nNodes == 1) + work->oneNode = 1; + else + work->oneNode = 0; + if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT) { // Set direct direction for broadcast-gather (read or write) - work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ; - } else if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) { - stepSize = info->comm->buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; - chunkSize = std::min(256*1024, stepSize*chunkSteps); - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2; - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2; - while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; - work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); - } else if (info->algorithm == NCCL_ALGO_NVLS) { + work->direct = (collInfo->nBytes / collInfo->nChannels <= 1024 * 1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ; + } else { + work->direct = 0; + } + return ncclSuccess; +} + +static ncclResult_t setCollWorkElem(uint64_t workCount, uint64_t workOffset, size_t lastChunkCount, struct ncclWorkElem* work) { + work->workCount = workCount; + work->workOffset = workOffset; + work->lastChunkCount = lastChunkCount; + return ncclSuccess; +} + +static ncclResult_t initCollWorkElemReg(struct ncclComm* comm, struct ncclWorkElem* work, struct ncclChannel* channel, ncclRegBufferType regBufType, void* regBufSend[], void* regBufRecv[], struct ncclWorkElemReg* workElemReg) { + if (regBufType == NCCL_IPC_REG_BUFFER) { + workElemReg->elem = *work; + workElemReg->elem.regUsed = 1; + for (int i = 0; i < NCCL_MAX_DIRECT_ARITY; i++) { + int peer = channel->collnetDirect.down[i]; + if (peer == -1) break; + int j = comm->rankToLocalRank[peer]; // Get intra-node slot + workElemReg->dnInputs[i] = regBufSend[j]; // Input buffer of leaf peer + workElemReg->dnOutputs[i] = regBufRecv[j]; // Output buffer of leaf peer + } + for (int i = 0; i < NCCL_MAX_DIRECT_ARITY; i++) { + int peer = channel->collnetDirect.up[i]; + if (peer == -1) break; + int j = comm->rankToLocalRank[peer]; + // Output buffer of root peer + workElemReg->upOutputs[i] = regBufRecv[j]; + } + } else if (regBufType == NCCL_NVLS_REG_BUFFER) { + workElemReg->elem = *work; + workElemReg->elem.regUsed = 1; + /* NVLS only has one send and recv buffer registered */ + workElemReg->dnInputs[0] = regBufSend[0]; + workElemReg->dnOutputs[0] = regBufRecv[0]; + } else { + /* impossible value */ + WARN("Invalid regBufType %d\n", regBufType); + return ncclInvalidArgument; + } + return ncclSuccess; +} + +NCCL_PARAM(NvlsTreeChunkSize, "NVLSTREE_MAX_CHUNKSIZE", -2); + +static ncclResult_t computeCollChunkInfo(struct ncclInfo* collInfo, size_t nBytes, int nChannels) { + int stepSize = collInfo->comm->buffSizes[collInfo->protocol] / NCCL_STEPS; + int chunkSteps = (collInfo->protocol == NCCL_PROTO_SIMPLE && collInfo->algorithm == NCCL_ALGO_RING) ? collInfo->chunkSteps : 1; + int sliceSteps = (collInfo->protocol == NCCL_PROTO_SIMPLE && collInfo->algorithm == NCCL_ALGO_RING) ? collInfo->sliceSteps : 1; + int chunkSize = stepSize * chunkSteps; + + if (collInfo->protocol == NCCL_PROTO_LL) chunkSize /= 2; + if (collInfo->protocol == NCCL_PROTO_LL128) chunkSize = (chunkSize / NCCL_LL128_LINEELEMS) * NCCL_LL128_DATAELEMS; + + if (collInfo->algorithm == NCCL_ALGO_COLLNET_DIRECT) { + // Optimize chunkSize / nSteps + while (nBytes / (nChannels * collInfo->comm->channels[0].collnetDirect.nHeads * chunkSize) < collInfo->comm->channels[0].collnetDirect.depth * 64 && chunkSize > 131072) chunkSize /= 2; + while (nBytes / (nChannels * collInfo->comm->channels[0].collnetDirect.nHeads * chunkSize) < collInfo->comm->channels[0].collnetDirect.depth * 8 && chunkSize > 65536) chunkSize /= 2; + while (nBytes / (nChannels * collInfo->comm->channels[0].collnetDirect.nHeads * chunkSize) < collInfo->comm->channels[0].collnetDirect.depth * 8 && chunkSize > 32768) chunkSize /= 2; + } else if (collInfo->algorithm == NCCL_ALGO_COLLNET_CHAIN) { + stepSize = collInfo->comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS; + chunkSize = std::min(256 * 1024, stepSize * chunkSteps); + while (nBytes / (nChannels * chunkSize) < collInfo->comm->channels[0].collnetChain.depth * 64 && chunkSize > 131072) chunkSize /= 2; + while (nBytes / (nChannels * chunkSize) < collInfo->comm->channels[0].collnetChain.depth * 8 && chunkSize > 65536) chunkSize /= 2; + while (nBytes / (nChannels * chunkSize) < collInfo->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; + } else if (collInfo->algorithm == NCCL_ALGO_NVLS) { int maxChunkSize = 131072; - if (info->comm->nNodes > 1 && info->comm->bandwidths[ncclFuncAllReduce][NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] < 150) maxChunkSize = 32768; + if (collInfo->comm->nNodes > 1 && collInfo->comm->bandwidths[ncclFuncAllReduce][NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] < 150) maxChunkSize = 32768; if (chunkSize > maxChunkSize) chunkSize = maxChunkSize; // Use uint64_t so that concurrentOps*chunkSize*X does not overflow - uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads; - if ((info->nBytes < (64 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; - if ((info->nBytes < (8 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; - if ((info->nBytes < (2 * (concurrentOps*chunkSize))) && (chunkSize > 16384)) chunkSize = 16384; - work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); - } else if (info->algorithm == NCCL_ALGO_NVLS_TREE) { + uint64_t concurrentOps = nChannels * collInfo->comm->channels[0].nvls.nHeads; + if ((nBytes < (64 * (concurrentOps * chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; + if ((nBytes < (8 * (concurrentOps * chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; + if ((nBytes < (2 * (concurrentOps * chunkSize))) && (chunkSize > 16384)) chunkSize = 16384; + } else if (collInfo->algorithm == NCCL_ALGO_NVLS_TREE) { // Use uint64_t so that concurrentOps*chunkSize*X does not overflow - uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads; - if (info->comm->nNodes >= 4) chunkSize = 65536; - if ((info->nBytes < (32 * (concurrentOps*chunkSize))) && (chunkSize > 262144)) chunkSize = 262144; - if ((info->nBytes < (16 * (concurrentOps*chunkSize))) && (chunkSize > 131072)) chunkSize = 131072; - if ((info->nBytes < (4 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; - if ((info->nBytes < (1 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; - work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); - } else if (info->protocol == NCCL_PROTO_LL) { - const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine); - const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize; - work->lastChunkSize = DIVUP((info->nBytes-(info->nBytes/loopSize)*loopSize), info->nChannels*info->nchunksPerLoop); - ALIGN_SIZE(work->lastChunkSize, info->nThreads*sizeof(uint64_t)); - work->lastChunkSize /= ncclTypeSize(info->datatype); - } else if (info->algorithm == NCCL_ALGO_TREE && info->protocol == NCCL_PROTO_LL128) { - int nNodes = info->comm->nNodes; - float ppn = info->comm->nRanks / (float)nNodes; + uint64_t concurrentOps = nChannels * collInfo->comm->channels[0].nvls.nHeads; + int maxChunkSize = ncclParamNvlsTreeChunkSize(); + if (maxChunkSize == -2) maxChunkSize = collInfo->comm->nNodes >= 4 ? 65536 : chunkSize; + chunkSize = std::min(chunkSize, maxChunkSize); + if ((nBytes < (32 * (concurrentOps * chunkSize))) && (chunkSize > 262144)) chunkSize = 262144; + if ((nBytes < (16 * (concurrentOps * chunkSize))) && (chunkSize > 131072)) chunkSize = 131072; + if ((nBytes < (4 * (concurrentOps * chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; + if ((nBytes < (1 * (concurrentOps * chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; + } else if (collInfo->algorithm == NCCL_ALGO_TREE && collInfo->protocol == NCCL_PROTO_LL128) { + int nNodes = collInfo->comm->nNodes; + float ppn = collInfo->comm->nRanks / (float)nNodes; float nstepsLL128 = 1+log2i(nNodes) + 0.1*ppn; - while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*64/ppn && chunkSize > 131072) chunkSize /= 2; - while (info->nBytes / (info->nChannels*chunkSize) < nstepsLL128*16/ppn && chunkSize > 32768) chunkSize /= 2; - // Use lastChunkSize as chunkSize - work->lastChunkSize = chunkSize*NCCL_LL128_DATAELEMS/(NCCL_LL128_LINEELEMS*ncclTypeSize(info->datatype)); + while (nBytes / (nChannels*chunkSize) < nstepsLL128*64/ppn && chunkSize > 131072) chunkSize /= 2; + while (nBytes / (nChannels*chunkSize) < nstepsLL128*16/ppn && chunkSize > 32768) chunkSize /= 2; } - // Compute nSteps for proxies - int chunkEffectiveSize = chunkSize; - if (info->protocol == NCCL_PROTO_LL) chunkEffectiveSize /= 2; - if (info->protocol == NCCL_PROTO_LL128) chunkEffectiveSize = (chunkSize / NCCL_LL128_LINEELEMS) * NCCL_LL128_DATAELEMS; - //if (info->comm->rank == 0) printf("Coll %d, size %ld -> %dx%d, chunkSize %d (algo %d proto%d)\n", info->coll, info->nBytes, info->nChannels, info->nThreads, chunkSize, info->algorithm, info->protocol); - int nLoops = (int)(DIVUP(info->nBytes, (((size_t)(info->nChannels))*info->nchunksPerLoop*chunkEffectiveSize))); - proxyOp->nsteps = info->nstepsPerLoop * nLoops * chunkSteps; - proxyOp->sliceSteps = sliceSteps; - proxyOp->chunkSteps = chunkSteps; - proxyOp->chunkSize = chunkSize; - proxyOp->protocol = info->protocol; - proxyOp->dtype = info->datatype; - proxyOp->redOp = info->opFull.op==ncclDevPreMulSum || info->opFull.op==ncclDevSumPostDiv ? ncclSum : // Network sees avg as sum - info->opFull.proxyOp; - proxyOp->pattern = info->pattern; - proxyOp->root = info->root; + collInfo->chunkSize = chunkSize; + collInfo->chunkCount = chunkSize / ncclTypeSize(collInfo->datatype); + collInfo->chunkSteps = chunkSteps; + collInfo->sliceSteps = sliceSteps; + collInfo->stepSize = stepSize; + return ncclSuccess; +} + +static ncclResult_t initCollProxyOp(struct ncclInfo* collInfo, int channelId, uint64_t opCount, uint32_t nsteps, struct ncclProxyOp* proxyOp) { + proxyOp->nsteps = nsteps; + proxyOp->sliceSteps = collInfo->sliceSteps; + proxyOp->chunkSteps = collInfo->chunkSteps; + proxyOp->chunkSize = collInfo->chunkSize; + proxyOp->protocol = collInfo->protocol; + proxyOp->dtype = collInfo->datatype; + // Network sees avg as sum + proxyOp->redOp = collInfo->opFull.op == ncclDevPreMulSum || collInfo->opFull.op == ncclDevSumPostDiv ? ncclSum : collInfo->opFull.proxyOp; + proxyOp->pattern = collInfo->pattern; + proxyOp->coll = collInfo->coll; + proxyOp->root = collInfo->root; + proxyOp->reg = 0; // This is used by P2P to reduce the receive buffer size. We don't use it in collectives // because some protocols need to transmit more than the total size, plus they sometimes // round up - proxyOp->nbytes = stepSize*proxyOp->sliceSteps; + proxyOp->nbytes = collInfo->stepSize * proxyOp->sliceSteps; + proxyOp->channelId = channelId; + proxyOp->opCount = opCount; - TRACE(NCCL_COLL,"opCount %lx slicesteps %d spl %d cpl %d nbytes %zi -> protocol %d nchannels %d nthreads %d, nloops %d nsteps %d chunksize %d comm %p", - proxyOp->opCount, sliceSteps, info->nstepsPerLoop, info->nchunksPerLoop, info->nBytes, info->protocol, info->nChannels, info->nThreads, - nLoops, proxyOp->nsteps, chunkSize, info->comm); + if (collInfo->pattern == ncclPatternCollnetDirect) { + proxyOp->specifics.collnetDirect.nNodes = collInfo->comm->nNodes; + proxyOp->specifics.collnetDirect.node = collInfo->comm->node; + if (collInfo->coll == ncclFuncAllGather || collInfo->coll == ncclFuncReduceScatter) { + proxyOp->specifics.collnetDirect.sizePerRank = collInfo->count * ncclTypeSize(collInfo->datatype); + } + } return ncclSuccess; } @@ -1511,11 +1843,26 @@ static ncclResult_t hostToDevRedOp( return ncclSuccess; } +static int collCmp(struct ncclInfo *a, struct ncclInfo *b) { + if (a->coll > b->coll) + return 1; + else if (a->coll == b->coll && a->datatype > b->datatype) + return 1; + else if (a->coll == b->coll && a->datatype == b->datatype && a->opFull.op > b->opFull.op) + return 1; + else if (a->coll == b->coll && a->datatype == b->datatype && a->opFull.op == b->opFull.op && a->count > b->count) + return 1; + else + return -1; +} + // Converts `info` to a task and adds it to `comm->tasks`. The exception is with // single rank communicators, collectives are issued as `ncclMemcpyAsync`s and // thus don't need a task. -static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo const* info) { +static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { ncclTasks *tasks = &comm->tasks; + + if (info->count == 0 && info->coll != ncclFuncSend && info->coll != ncclFuncRecv) return ncclSuccess; if (info->coll == ncclFuncSend || info->coll == ncclFuncRecv) { int peer = info->root; ssize_t nBytes = info->count*ncclTypeSize(info->datatype); @@ -1558,27 +1905,23 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo const* inf } else { // Copy reduction op state from op handle into info struct here since the // op handle may be destroyed before ncclGroupEnd(). - struct ncclDevRedOpFull opFull; - NCCLCHECK(hostToDevRedOp(&opFull, info->op, info->datatype, comm)); + NCCLCHECK(hostToDevRedOp(&info->opFull, info->op, info->datatype, comm)); if (comm->nRanks == 1) { - NCCLCHECK(ncclLaunchOneRank(info->recvbuff, info->sendbuff, info->count, opFull, info->datatype, info->stream)); + NCCLCHECK(ncclLaunchOneRank(info->recvbuff, info->sendbuff, info->count, info->opFull, info->datatype, info->stream)); return ncclSuccess; } else { // Must be in thread local group before tasks can be alloc'd in `comm->memScoped`. ncclGroupCommJoin(info->comm); - struct ncclTaskColl* t = ncclMemoryStackAlloc(&comm->memScoped); - t->func = info->coll; - t->sendbuff = info->sendbuff; - t->recvbuff = info->recvbuff; - t->count = info->count; - t->root = info->root; - t->datatype = info->datatype; - t->op = opFull; // C++ struct assignment - t->chunkSteps = info->chunkSteps; - t->sliceSteps = info->sliceSteps; - ncclIntruQueueEnqueue(&tasks->collQueue, t); - tasks->collBytesTotal += info->nBytes; + struct ncclInfo* t = ncclMemoryStackAlloc(&comm->memScoped); + info->nChannels = 0; + info->nThreads = 0; + info->algorithm = NCCL_ALGO_UNDEF; + info->protocol = NCCL_PROTO_UNDEF; + info->userTuned = false; + memcpy(t, info, sizeof(struct ncclInfo)); + ncclIntruQueueSortEnqueue(&tasks->collQueue, t, collCmp); + tasks->workBytesTotal += info->count * ncclTypeSize(info->datatype); tasks->nTasksColl += 1; } } diff --git a/projects/rccl/src/graph/connect.cc b/projects/rccl/src/graph/connect.cc index 5af0020eda..86efcbaf4f 100644 --- a/projects/rccl/src/graph/connect.cc +++ b/projects/rccl/src/graph/connect.cc @@ -19,6 +19,7 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs int localRanks = comm->topo->nodes[GPU].count; int nChannels = comm->nChannels; + topoRanks->nvlsHeadNum = 0; for (int c=0; cchannels+c; channel->ring.prev = channel->ring.next = -1; @@ -30,20 +31,20 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs channel->collnetDirect.headRank = -1; channel->collnetDirect.nHeads = 0; channel->collnetDirect.shift = 0; + for (int i=0; icollnetDirect.heads[i] = -1; for (int i=0; icollnetDirect.up[i] = -1; for (int i=0; icollnetDirect.down[i] = -1; int* ringIntra = graphs[NCCL_ALGO_RING]->intra+c*localRanks; int* treeIntra = graphs[NCCL_ALGO_TREE]->intra+c*localRanks; int* collNetIntra = graphs[NCCL_ALGO_COLLNET_CHAIN]->intra+c*localRanks; - int* nvlsIntra = graphs[NCCL_ALGO_NVLS]->intra+c*localRanks; for (int i=0; iringRecv[c] = ringIntra[0]; topoRanks->ringSend[c] = ringIntra[localRanks-1]; - channel->ring.prev = (i == 0) ? -1 : ringIntra[i-1]; - channel->ring.next = (i == localRanks-1) ? -1 : ringIntra[i+1]; + topoRanks->ringPrev[c] = (i == 0) ? -1 : ringIntra[i-1]; + topoRanks->ringNext[c] = (i == localRanks-1) ? -1 : ringIntra[i+1]; } if (treeIntra[i] == rank) { int parentIndex = 0; @@ -61,14 +62,28 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs channel->collnetChain.down[0] = i == localRanks-1 ? -1 : collNetIntra[i+1]; } } - topoRanks->ringPrev[c] = channel->ring.prev; - topoRanks->ringNext[c] = channel->ring.next; - topoRanks->nvlsHeads[c] = nvlsIntra[0]; } - // Duplicate channels rings/trees + // Duplicate channels trees struct ncclChannel* channel0 = comm->channels; struct ncclChannel* channel1 = channel0+nChannels; memcpy(channel1, channel0, nChannels*sizeof(struct ncclChannel)); + + // Get nvls heads and the number of heads. Duplicate head is not allowed. + for (int c = 0; c < graphs[NCCL_ALGO_NVLS]->nChannels; ++c) { + bool addHead = true; + int* nvlsIntra = graphs[NCCL_ALGO_NVLS]->intra + c * localRanks; + + for (int dup = 0; dup < topoRanks->nvlsHeadNum; dup++) { + if (topoRanks->nvlsHeads[dup] == nvlsIntra[0]) { + addHead = false; + break; + } + } + if (addHead) { + topoRanks->nvlsHeads[topoRanks->nvlsHeadNum++] = nvlsIntra[0]; + } + } + return ncclSuccess; } @@ -80,26 +95,14 @@ static ncclResult_t connectRings(struct ncclComm* comm, int* ringRecv, int* ring int* send = ringSend+c*comm->nNodes; int* prev = ringPrev+c*comm->nRanks; int* next = ringNext+c*comm->nRanks; - struct ncclChannel* channel0 = comm->channels+c; - struct ncclChannel* channel1 = channel0+nChannels; for (int n=0; nrank == recvRank) { - channel0->ring.prev = prevSendRank; - channel1->ring.prev = prevSendRank; - } int sendRank = send[n]; int nextRecvRank = recv[(n+1)%nNodes]; next[sendRank] = nextRecvRank; - if (comm->rank == sendRank) { - channel0->ring.next = nextRecvRank; - channel1->ring.next = nextRecvRank; - } } - TRACE(NCCL_GRAPH, "Ring %d : %d -> %d -> %d", c, channel0->ring.prev, comm->rank, channel0->ring.next); - TRACE(NCCL_GRAPH, "Ring %d : %d -> %d -> %d", c+nChannels, channel1->ring.prev, comm->rank, channel1->ring.next); } return ncclSuccess; } @@ -209,6 +212,15 @@ static ncclResult_t connectCollNet(struct ncclComm* comm, struct ncclTopoGraph* channel->collnetDirect.up[nUp++] = heads[h]; sprintf(line+strlen(line), " %d ", heads[h]); } + sprintf(line+strlen(line), "heads "); + { // heads[] is the list of heads ordered in head order startubg with self + int h0 = (channel->collnetDirect.headRank == -1) ? 0 : channel->collnetDirect.headRank; + for (int h1=0; h1 < nHeads; h1++) { + int h = (h0+h1)%nHeads; + channel->collnetDirect.heads[h1] = heads[h]; + sprintf(line+strlen(line), " %d ", heads[h]); + } + } channel->collnetDirect.nHeads = nHeads; channel->collnetDirect.shift = (rank%localRanks)%nHeads; // Shift by intraRank so that leaves don't send to same head simultaneously channel->collnetDirect.depth = (nUp == 0 && nDown == 0) ? 1 : 2; @@ -217,27 +229,22 @@ static ncclResult_t connectCollNet(struct ncclComm* comm, struct ncclTopoGraph* INFO(NCCL_GRAPH, "%s", line); channel->collnetChain.depth = comm->nRanks/comm->nNodes; } - for (int c=0; cnvlsChannels; c++) { - struct ncclChannel* channel = comm->channels+c; - if (channel->nvls.headRank != -1) channel->nvls.out = comm->nRanks; - } free(heads); return ncclSuccess; } -static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, struct ncclTopoGraph* nvlsGraph) { - int nHeads = nvlsGraph->nChannels; +static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, int nHeads) { int headRank = -1; - for (int h=0; hintra[h*comm->localRanks] == comm->rank) headRank = h; - } - if (nHeads == 0) { comm->nvlsChannels = 0; return ncclSuccess; } - for (int c=0; cnvlsChannels; c++) { + for (int h = 0; h < nHeads; h++) { + if (nvlsHeads[h * comm->nNodes + comm->node] == comm->rank) headRank = h; + } + + for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; channel->nvls.nHeads = nHeads; for (int h=0; hnvls.up[h] = comm->nRanks+1+h; @@ -248,8 +255,10 @@ static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, struct nc channel->nvls.treeUp = channel->nvls.treeDown[0] = channel->nvls.treeDown[1] = channel->nvls.treeDown[2] = -1; channel->nvls.node = comm->node; channel->nvls.nNodes = comm->nNodes; + if (comm->collNetSupport && channel->nvls.headRank != -1) channel->nvls.out = comm->nRanks; } - if (comm->nNodes == 1) return ncclSuccess; + // MNNVL: NVLS not yet supported + if (comm->nNodes == 1 || comm->MNNVL) return ncclSuccess; // Connect Trees int tree0Parent, tree0Child0, tree0Child1, tree1Parent, tree1Child0, tree1Child1; @@ -290,7 +299,7 @@ static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, struct nc } // Set prev/next in all channels (NVLS compute channels work // orthogonally to NVLS search channels). - for (int c=0; cnvlsChannels; c++) { + for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; channel->nvls.treeUp = treeUp[c%2]; channel->nvls.treeDown[0] = channel->nvls.down; @@ -348,12 +357,19 @@ static int copyChannels(struct ncclComm* comm, int start, int end, int* ringPrev return c; } +void exchangeValues(int* v0, int* v1) { + int tmp = *v1; + *v1 = *v0; + *v0 = tmp; +} + ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePatterns, struct ncclTopoRanks** allTopoRanks, int* rings, struct ncclTopoGraph** graphs) { // Gather data from all ranks int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeToParent, *treeToChild0, *treeToChild1, *nvlsHeads; int nranks = comm->nRanks; int nNodes = comm->nNodes; int nChannels = comm->nChannels; + int minHeadNum = INT_MAX; NCCLCHECK(ncclCalloc(&ringRecv, nNodes*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringSend, nNodes*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&ringPrev, nranks*MAXCHANNELS)); @@ -362,6 +378,22 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa NCCLCHECK(ncclCalloc(&treeToChild0, nNodes*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&treeToChild1, nNodes*MAXCHANNELS)); NCCLCHECK(ncclCalloc(&nvlsHeads, nNodes*MAXCHANNELS)); + + // Alternate rings to avoid crossing rails + if (graphs[NCCL_ALGO_RING]->crossNic && (comm->nNodes % 2) == 0 && (nChannels % 2) == 0) { + for (int r=0; rnRanks; r++) { + if (comm->rankToNode[r] % 2 == 1) { + // Exchange rings + for (int c=0; cringRecv+c, allTopoRanks[r]->ringRecv+(c^1)); + exchangeValues(allTopoRanks[r]->ringSend+c, allTopoRanks[r]->ringSend+(c^1)); + exchangeValues(allTopoRanks[r]->ringPrev+c, allTopoRanks[r]->ringPrev+(c^1)); + exchangeValues(allTopoRanks[r]->ringNext+c, allTopoRanks[r]->ringNext+(c^1)); + } + } + } + } + for (int c=0; cringNext[c]; } } - for (int c=0; cnChannels; c++) { - for (int n=0; n allTopoRanks[r]->nvlsHeadNum) + minHeadNum = allTopoRanks[r]->nvlsHeadNum; + } + + for (int c = 0; c < minHeadNum; c++) { + for (int n = 0; n < nNodes; n++) { int r = firstRanks[n]; - nvlsHeads[c*nNodes+n] = allTopoRanks[r]->nvlsHeads[c]; + nvlsHeads[c * nNodes + n] = allTopoRanks[r]->nvlsHeads[c]; } } // Connect rings and trees. This should also duplicate the channels. NCCLCHECK(connectRings(comm, ringRecv, ringSend, ringPrev, ringNext)); NCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, treePatterns)); - NCCLCHECK(connectNvls(comm, nvlsHeads, graphs[NCCL_ALGO_NVLS])); // Duplicate ringPrev/ringNext for ncclBuildRing memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int)); memcpy(ringNext+nChannels*nranks, ringNext, nChannels*nranks*sizeof(int)); + // Set ring prev/next for my rank + for (int c=0; cchannels+c; + struct ncclChannel* channel1 = channel0+nChannels; + channel0->ring.prev = channel1->ring.prev = ringPrev[c*nranks+comm->rank]; + channel0->ring.next = channel1->ring.next = ringNext[c*nranks+comm->rank]; + } + // Duplication should be complete now nChannels = comm->nChannels = std::min(MAXCHANNELS,nChannels*2); @@ -407,7 +453,7 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa } // Use 4 compute channels per search channel to reach peak BW on <8 PPN - if (comm->minCompCap == 90 && comm->nNodes > 1 && graphs[NCCL_ALGO_RING]->bwIntra > 45.0 && 2*nChannels <= MAXCHANNELS) { + if (comm->minCompCap == 90 && comm->nNodes > 1 && graphs[NCCL_ALGO_RING]->bwIntra > 45.0 && nChannels < 16) { nChannels = comm->nChannels = copyChannels(comm, nChannels, 2*nChannels, ringPrev, ringNext); } @@ -422,6 +468,13 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa nChannels = comm->nChannels = copyChannels(comm, nChannels, std::max(ncclMinNchannels(), comm->config.minCTAs), ringPrev, ringNext); } + comm->collChannels = comm->nChannels; + // Support maximal channel usage for aggregation + if (comm->nChannels < comm->nvlsChannels) { + nChannels = comm->nChannels = copyChannels(comm, comm->nChannels, comm->nvlsChannels, ringPrev, ringNext); + } + NCCLCHECK(connectNvls(comm, nvlsHeads, minHeadNum)); + // Create rings array and check all is fine NCCLCHECK(ncclBuildRings(nChannels, rings, comm->rank, comm->nRanks, ringPrev, ringNext)); diff --git a/projects/rccl/src/graph/paths.cc b/projects/rccl/src/graph/paths.cc index 42be5919ed..dea2e70869 100644 --- a/projects/rccl/src/graph/paths.cc +++ b/projects/rccl/src/graph/paths.cc @@ -341,6 +341,23 @@ compare: return ncclSuccess; } +// MNNVL: Check whether peers are in the same fabric cluster and clique +ncclResult_t ncclTopoCheckMNNVL(struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* ret) { + *ret = 0; + + nvmlGpuFabricInfoV_t *fabricInfo1 = &info1->fabricInfo; + nvmlGpuFabricInfoV_t *fabricInfo2 = &info2->fabricInfo; + // A zero UUID means we don't have MNNVL fabric info + if ((((long *)&fabricInfo2->clusterUuid)[0]|((long *)fabricInfo2->clusterUuid)[1]) == 0) return ncclSuccess; + if ((memcmp(fabricInfo1->clusterUuid, fabricInfo2->clusterUuid, NVML_GPU_FABRIC_UUID_LEN) == 0) && + (fabricInfo1->cliqueId == fabricInfo2->cliqueId)) { + INFO(NCCL_NET, "MNNVL matching peer 0x%lx UUID %lx.%lx cliqueId 0x%x", + info2->busId, ((long *)fabricInfo2->clusterUuid)[0], ((long *)fabricInfo2->clusterUuid)[1], fabricInfo2->cliqueId); + *ret = 1; + } + return ncclSuccess; +} + NCCL_PARAM(NetGdrRead, "NET_GDR_READ", -2); int ncclTopoUserGdrLevel = -1; @@ -652,7 +669,8 @@ ncclResult_t ncclTopoTrimSystem(struct ncclTopoSystem* system, struct ncclComm* NCCLCHECK(ncclTopoRemoveNode(system, GPU, g)); } - if (system->nodes[GPU].count == comm->nRanks) { + // MNNVL: Remove network nodes as they are connected via NVLink + if (system->nodes[GPU].count == comm->nRanks || comm->MNNVL) { for (int n=system->nodes[NET].count-1; n>=0; n--) NCCLCHECK(ncclTopoRemoveNode(system, NET, n)); } @@ -666,10 +684,11 @@ void ncclTopoFree(struct ncclTopoSystem* system) { free(system); } -NCCL_PARAM(NChannelsPerNetPeer, "NCHANNELS_PER_NET_PEER", 2); +NCCL_PARAM(NChannelsPerNetPeer, "NCHANNELS_PER_NET_PEER", -1); -static ncclResult_t ncclTopoGetNchannels(struct ncclTopoSystem* system, int g /*local gpu index*/, int peerRank, int* nChannels) { +static ncclResult_t ncclTopoGetNchannels(struct ncclComm* comm, int g /*local gpu index*/, int peerRank, int* nChannels) { int peer; + struct ncclTopoSystem* system = comm->topo; struct ncclTopoLinkList* path = NULL; if (ncclTopoRankToIndex(system, peerRank, &peer) == ncclSuccess) { // Same rank @@ -685,9 +704,28 @@ static ncclResult_t ncclTopoGetNchannels(struct ncclTopoSystem* system, int g /* } else { *nChannels = 2; } + } else if (comm->MNNVL) { + // MNNVL assume all GPUs are connected via NVLink + path = system->nodes[GPU].nodes[g].paths[GPU]+((g+1)%system->nodes[GPU].count); + float nvlBw = ncclTopoNVLinkBw(system->nodes[GPU].nodes[g].gpu.cudaCompCap); + *nChannels = 2*std::max(1, (int)(path->bw / nvlBw)); } else { // Remote rank, use network - *nChannels = ncclParamNChannelsPerNetPeer(); + int nNetChannels = ncclParamNChannelsPerNetPeer(); + if (nNetChannels == -1) { + //start from 2 channels per NIC and reduce with scale + nNetChannels = 2; + + // check if we need to use more than one NIC, hence more than one channel + int netCountByBw = 1, nChannelsMax = nNetChannels; + NCCLCHECK(getLocalNetCountByBw(system, g, &netCountByBw)); + // Avoid overloading channels with 8+ operations as we loose the sync warp, hence a bit of bandwidth. + while (nChannelsMax*comm->nRanks > comm->p2pnChannels*4 && nChannelsMax > 1) nChannelsMax /= 2; + + //allow upto channels requires to drive the NICs + nNetChannels = std::max(netCountByBw, nChannelsMax); + } + *nChannels = nNetChannels; } return ncclSuccess; } @@ -716,7 +754,7 @@ ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { for (int g=0; gtopo->nodes[GPU].count; g++) { for (int r=0; rnRanks; r++) { int nChannels; - NCCLCHECK(ncclTopoGetNchannels(comm->topo, g, r, &nChannels)); + NCCLCHECK(ncclTopoGetNchannels(comm, g, r, &nChannels)); if (nChannels >= 0) minChannels = std::min(minChannels, nChannels); } } diff --git a/projects/rccl/src/graph/search.cc b/projects/rccl/src/graph/search.cc index 3ebb0d4204..c3287b0bae 100644 --- a/projects/rccl/src/graph/search.cc +++ b/projects/rccl/src/graph/search.cc @@ -372,13 +372,12 @@ ncclResult_t ncclTopoCompareGraphs(struct ncclTopoSystem* system, struct ncclTop return ncclSuccess; } // 2. Try to get better bandwidth - // Give a 15% perf bonus to paths not crossing nics - float target = 1.0 - (refGraph->crossNic - graph->crossNic) * .15; - if (graph->nChannels*graph->bwIntra > refGraph->nChannels*refGraph->bwIntra*target) { + // Give a 5% perf bonus to paths not crossing nics + if (graph->nChannels*graph->bwIntra > refGraph->nChannels*refGraph->bwIntra) { *copy = 1; return ncclSuccess; } - if (graph->nChannels*graph->bwIntra < refGraph->nChannels*refGraph->bwIntra*target) return ncclSuccess; + if (graph->nChannels*graph->bwIntra < refGraph->nChannels*refGraph->bwIntra) return ncclSuccess; // 3. Less hops if (graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops) *copy = 1; @@ -484,6 +483,7 @@ ncclResult_t ncclTopoSearchRecGpu(struct ncclTopoSystem* system, struct ncclTopo struct ncclTopoNode* net = system->nodes[NET].nodes+n; if (graph->pattern == NCCL_TOPO_PATTERN_TREE && net->id != startNet->id) continue; // Trees are symmetric if (graph->crossNic != 1 && (net->net.asic != startNet->net.asic || net->net.port != startNet->net.port)) continue; + if (graph->crossNic && (graph->nChannels & 1) && net->id != graph->inter[(graph->nChannels-1)*2]) continue; // Balanced Tree : count half of the bandwidth on first two GPUs int nextBackToNet = -1; @@ -555,6 +555,7 @@ ncclResult_t ncclTopoSearchRecNet(struct ncclTopoSystem* system, struct ncclTopo struct ncclTopoNode* net = system->nodes[NET].nodes+n; if (graph->collNet && net->net.collSupport == 0) continue; if (net->net.bw < bw) continue; + if (graph->crossNic && (graph->nChannels & 1) && net->id != graph->inter[(graph->nChannels-1)*2+1]) continue; graph->inter[graph->nChannels*2] = net->id; graph->latencyInter = net->net.latency; @@ -1071,16 +1072,29 @@ ncclResult_t ncclTopoDumpGraphs(struct ncclTopoSystem* system, int ngraphs, stru #include "comm.h" // NVLS channels aren't compute channels. Find which NIC corresponds to our rank being the head -ncclResult_t getNvlsNetDev(struct ncclComm* comm, struct ncclTopoGraph* graph, int* dev) { +ncclResult_t getNvlsNetDev(struct ncclComm* comm, struct ncclTopoGraph* graph, int channelId, int* dev) { + ncclResult_t ret = ncclSuccess; int localRanks = comm->topo->nodes[GPU].count; - for (int c=0; cnChannels; c++) { - if (graph->intra[c*localRanks] == comm->rank) { - *dev = graph->inter[c*2]; - return ncclSuccess; + int netNum = 0; + int net[MAXCHANNELS]; + + for (int c = 0; c < graph->nChannels; c++) { + if (graph->intra[c * localRanks] == comm->rank) { + net[netNum++] = graph->inter[c * 2]; } } + if (netNum) { + *dev = net[channelId % netNum]; + } else { + ret = ncclInternalError; + goto fail; + } + +exit: + return ret; +fail: WARN("Could not find NIC for rank %d in NVLS graph\n", comm->rank); - return ncclInternalError; + goto exit; } // 0: don't use PXN for P2P, 1: use PXN if needed, 2: use PXN as much as possible to maximize aggregation @@ -1095,7 +1109,7 @@ ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoG if (graph->pattern != NCCL_TOPO_PATTERN_NVLS) { *dev = graph->inter[channel*2+index]; } else { - NCCLCHECK(getNvlsNetDev(comm, graph, dev)); + NCCLCHECK(getNvlsNetDev(comm, graph, channelId, dev)); } NCCLCHECK(ncclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank)); } else if (peerRank == -1) { diff --git a/projects/rccl/src/graph/topo.cc b/projects/rccl/src/graph/topo.cc index 481def486b..402fa15ee2 100644 --- a/projects/rccl/src/graph/topo.cc +++ b/projects/rccl/src/graph/topo.cc @@ -180,12 +180,17 @@ ncclResult_t ncclTopoConnectNodes(struct ncclTopoNode* node, struct ncclTopoNode // even though they're supposed to sustain full BW across all ports. // Flatten the switch as this extra level can break the search and make // NCCL take wrong topology decisions. +int getBcmGen(uint64_t id, int level) { + if ((id & 0xfffffffffffff000) == 0x1000c0101000a000) return 4; + if ((id & 0xfffffffffffff000) == (0x1000c03010000000 | level*0x1000)) return 5; + return 0; +} ncclResult_t ncclTopoFlattenBcmSwitches(struct ncclTopoSystem* system) { for (int s=0; snodes[PCI].count; s++) { struct ncclTopoNode* pciSwitch = system->nodes[PCI].nodes+s; - uint64_t device = pciSwitch->pci.device; - // Only flatten PEX Gen 4 switches in base mode - if ((device & 0xfffffffffffff000) == 0x1000c0101000a000) { + int gen = getBcmGen(pciSwitch->pci.device, 0); + // Flatten Gen4 PEX switches in base mode + if (gen) { // Find sub switches with the same device ID. int64_t* subSwIds; NCCLCHECK(ncclCalloc(&subSwIds, pciSwitch->nlinks)); @@ -193,7 +198,7 @@ ncclResult_t ncclTopoFlattenBcmSwitches(struct ncclTopoSystem* system) { for (int l=0; lnlinks; l++) { struct ncclTopoNode* sub = pciSwitch->links[l].remNode; // Only fuse sub switches with the same device ID. - if (sub->type != PCI || sub->pci.device != device) continue; + if (sub->type != PCI || getBcmGen(sub->pci.device, 1) != gen) continue; // Save sub switch for later subSwIds[subs++] = sub->id; // Remove link to that sub switch @@ -225,8 +230,8 @@ ncclResult_t ncclTopoFlattenBcmSwitches(struct ncclTopoSystem* system) { } NCCLCHECK(ncclTopoRemoveNode(system, PCI, index)); } - // Set subdevice to 0x0000 to make sure we don't merge this switch again. - pciSwitch->pci.device = 0x1000c01010000000; + // Set subdevice to 0xffff to make sure we don't merge this switch again. + pciSwitch->pci.device |= 0xffff; free(subSwIds); // Restart, as system->nodes[PCI].nodes has changed. s = 0; @@ -732,6 +737,30 @@ ncclResult_t ncclTopoGetLocal(struct ncclTopoSystem* system, int type, int index return ncclSuccess; } +ncclResult_t getLocalNetCountByBw(struct ncclTopoSystem* system, int gpu, int *count) { + int localNetCount = 0, netCountByBw = 0; + int* localNets; + float totalNetBw = 0, gpuBw = 0; + + for (int l=0; lnodes[GPU].nodes[gpu].nlinks; l++) { + //assuming BW to CPU reflects the GPU bandwidth via P2P or C2C + //caveat, this could be wrong if there is a PCIe switch, + //and a narrower link to the CPU + if (system->nodes[GPU].nodes[gpu].links[l].remNode->type == CPU) { + gpuBw = system->nodes[GPU].nodes[gpu].links[l].bw; + } + } + + NCCLCHECK(ncclTopoGetLocal(system, GPU, gpu, NET, &localNets, &localNetCount, NULL)); + for (int l=0; (l < localNetCount) && (totalNetBw < gpuBw); l++, netCountByBw++) { + totalNetBw += system->nodes[GPU].nodes[gpu].paths[NET][localNets[l]].bw; + } + *count = netCountByBw; + + free(localNets); + return ncclSuccess; +} + ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int channelId, int* id) { int gpu; NCCLCHECK(ncclTopoRankToIndex(system, rank, &gpu)); diff --git a/projects/rccl/src/graph/topo.h b/projects/rccl/src/graph/topo.h index b067f2f975..db1eb6e244 100644 --- a/projects/rccl/src/graph/topo.h +++ b/projects/rccl/src/graph/topo.h @@ -14,7 +14,7 @@ #define SM60_NVLINK_BW 18.0 #define SM70_NVLINK_BW 20.0 #define SM80_NVLINK_BW 20.0 -#define SM90_NVLINK_BW 20.0 +#define SM90_NVLINK_BW 20.6 #define SM86_NVLINK_BW 12.0 #define PCI_BW 12.0 // PCI Gen3 x16 #define QPI_BW 6.0 diff --git a/projects/rccl/src/graph/tuning.cc b/projects/rccl/src/graph/tuning.cc index a97ed9a1ad..7ca5922935 100644 --- a/projects/rccl/src/graph/tuning.cc +++ b/projects/rccl/src/graph/tuning.cc @@ -132,7 +132,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] = getNthreads("NCCL_LL128_NTHREADS", ncclParamLl128Nthreads(), NCCL_LL128_MAX_NTHREADS/4, NCCL_LL128_MAX_NTHREADS, NCCL_LL128_MAX_NTHREADS); - int nNodes = comm->nNodes; + // MNNVL support - treat as a single NVLink connected node + int nNodes = comm->MNNVL ? 1 : comm->nNodes; int nRanks = comm->nRanks; if (nRanks <= 1) return ncclSuccess; @@ -165,8 +166,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom for (int a=0; anChannels); // GPU/NIC ratio - factor -= (factor-1)/2; - busBw /= factor; + if (coll == ncclFuncAllGather || coll == ncclFuncReduceScatter) { + busBw = ppn * bw; + // AllGather/ReduceScatter requires 1:1 GPU:NIC + int nicPerNode = comm->collNetHeadsUniqueNum; + if (coll == ncclFuncAllGather && comm->nNodes > 1) { + if (!comm->ncclCollNet || !comm->ncclCollNet->iallgather || ppn > nicPerNode) busBw = 0; + } + if (coll == ncclFuncReduceScatter && comm->nNodes > 1) { + if (!comm->ncclCollNet || !comm->ncclCollNet->ireducescatter || ppn > nicPerNode) busBw = 0; + } + // Measured corrective ratio needed at 1 ppn and 8ppn. Here we hackishly + // interpolate the two. + float w = (ppn-1)/(8-1); + busBw *= w*0.85 + (1-w)*0.95; + } else { + // Collnet+Direct requires all GPUs to have a local NIC to work at full speed + float factor = ppn / (1.0*graphs[a]->nChannels); // GPU/NIC ratio + factor -= (factor-1)/2; + busBw /= factor; + if (minCompCap >= 90) busBw *= .85; + } } - if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE && minCompCap >= 90) busBw *= .85; // Convert bus BW to algorithm BW - float ratio; - if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps; - else if (a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE) ratio = 5.0/6.0; - else ratio = .5; - comm->bandwidths[coll][a][p] = busBw * ratio; + if (!(a == NCCL_ALGO_COLLNET_DIRECT && (coll == ncclFuncAllGather || coll == ncclFuncReduceScatter))) { + float ratio = 1.0f; + if (a == NCCL_ALGO_RING) ratio *= (1.0 * nRanks) / nsteps; + else if (a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE) ratio *= 5.0/6.0; + else ratio *= .5; + busBw *= ratio; + } + comm->bandwidths[coll][a][p] = busBw; /* Ring bandwidth backup */ if (a == NCCL_ALGO_RING) comm->ringbdw[coll][p] = comm->bandwidths[coll][NCCL_ALGO_RING][p]; @@ -262,18 +282,19 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable)); } - if (comm->nNodes == 1) algoEnable[NCCL_ALGO_NVLS_TREE] = 0; + // MNNVL: NVLS not yet supported + if (comm->nNodes == 1 || comm->MNNVL) algoEnable[NCCL_ALGO_NVLS_TREE] = 0; // Disable CollNet if it is not supported if (comm->collNetSupport == 0) { algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; algoEnable[NCCL_ALGO_COLLNET_CHAIN] = 0; - if (comm->nNodes > 1) algoEnable[NCCL_ALGO_NVLS] = 0; + // MNNVL: NVLS not yet supported + if (comm->nNodes > 1 || comm->MNNVL) algoEnable[NCCL_ALGO_NVLS] = 0; // If user has hard set NCCL_ALGO=COLLNET, ignore it if (algoEnable[NCCL_ALGO_RING] == 0 && algoEnable[NCCL_ALGO_TREE] == 0 && algoEnable[NCCL_ALGO_NVLS] == 0 && algoEnable[NCCL_ALGO_NVLS_TREE] == 0) { algoEnable[NCCL_ALGO_RING] = algoEnable[NCCL_ALGO_TREE] = 1; - if (comm->rank == 0) WARN("CollNet is not supported or fails to initialize, ignoring NCCL_ALGO=COLLNET"); } } else { // Disable CollNet+Direct if not on an NVSwitch system @@ -398,9 +419,9 @@ static float treeCorrectionFactor[NCCL_NUM_PROTOCOLS][23] = { }; ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time, bool* backup) { - float bw = info->comm->bandwidths[info->coll][algorithm][protocol]; + float bw = info->comm->bandwidths[info->coll][algorithm][protocol]; float lat = info->comm->latencies[info->coll][algorithm][protocol]; - + if (backup) { *backup = false; if (algorithm == NCCL_ALGO_RING && bw == 0.0f) { @@ -416,7 +437,7 @@ ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int proto int logSize = log2i(info->nBytes>>6); if (algorithm == NCCL_ALGO_TREE && logSize < 23) bw *= treeCorrectionFactor[protocol][logSize]; if (info->nChannels != 0) bw = bw / info->comm->nChannels * info->nChannels; - if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && info->comm->nNodes > 1 + if (algorithm == NCCL_ALGO_RING && protocol == NCCL_PROTO_SIMPLE && (!info->comm->MNNVL && info->comm->nNodes > 1) && info->coll == ncclFuncAllReduce && info->nBytes/(info->comm->nChannels*info->comm->nRanks) >= 64) { lat *= info->comm->minCompCap < 80 ? 1.9 : 1.4; // Plateau effect of ring } diff --git a/projects/rccl/src/graph/xml.cc b/projects/rccl/src/graph/xml.cc index ffce2e64bb..6f4d9ea618 100644 --- a/projects/rccl/src/graph/xml.cc +++ b/projects/rccl/src/graph/xml.cc @@ -592,8 +592,8 @@ ncclResult_t ncclTopoGetXmlFromSys(struct ncclXmlNode* pciNode, struct ncclXml* NCCLCHECK(xmlGetAttrStr(pciNode, "busid", &newBusId)); for (int s=0; snSubs; s++) { const char* busId; - NCCLCHECK(xmlGetAttrStr(parent->subs[s], "busid", &busId)); - if (strcmp(newBusId, busId) < 0) { subIndex = s; break; } + NCCLCHECK(xmlGetAttr(parent->subs[s], "busid", &busId)); + if (busId != NULL && strcmp(newBusId, busId) < 0) { subIndex = s; break; } } for (int s = parent->nSubs; s > subIndex; s--) parent->subs[s] = parent->subs[s-1]; parent->subs[subIndex] = pciNode; diff --git a/projects/rccl/src/group.cc b/projects/rccl/src/group.cc index 29400d6bcb..81fbf13045 100644 --- a/projects/rccl/src/group.cc +++ b/projects/rccl/src/group.cc @@ -235,9 +235,9 @@ static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclComm** g // Reset comm->tasks to empty. comm->tasks.nTasksColl = 0; comm->tasks.nTasksP2p = 0; + comm->tasks.workBytesTotal = 0; comm->tasks.streams = nullptr; ncclIntruQueueConstruct(&comm->tasks.collQueue); - comm->tasks.collBytesTotal = 0; for (int i = 0; i < comm->nRanks; i++) { ncclIntruQueueConstruct(&comm->tasks.peers[i].sendQueue); ncclIntruQueueConstruct(&comm->tasks.peers[i].recvQueue); @@ -321,9 +321,9 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) { assert(state == ncclGroupJobJoined); } - if (*groupAbortFlag == true || errorJobAbortFlag == true) { - *job->abortFlag = 1; - if (job->childAbortFlag) *job->childAbortFlag = 1; + if (__atomic_load_n(groupAbortFlag, __ATOMIC_RELAXED) || errorJobAbortFlag == true) { + __atomic_store_n(job->abortFlag, 1, __ATOMIC_RELAXED); + if (job->childAbortFlag) __atomic_store_n(job->childAbortFlag, 1, __ATOMIC_RELAXED); } job = job->next; @@ -438,7 +438,7 @@ ncclResult_t ncclGroupJobComplete(struct ncclGroupJob* groupJob) { ncclResult_t ncclGroupJobAbort(struct ncclGroupJob* groupJob) { if (groupJob && groupJob->initialized) { - *groupJob->abortFlagPtr = true; + __atomic_store_n(groupJob->abortFlagPtr, true, __ATOMIC_RELAXED); NCCLCHECK(ncclGroupJobComplete(groupJob)); } return ncclSuccess; diff --git a/projects/rccl/src/include/alloc.h b/projects/rccl/src/include/alloc.h index f8d954469e..aa522ea1a0 100644 --- a/projects/rccl/src/include/alloc.h +++ b/projects/rccl/src/include/alloc.h @@ -85,13 +85,14 @@ static inline ncclResult_t ncclCuMemAlloc(void **ptr, CUmemGenericAllocationHand CUmemAllocationProp prop = {}; CUmemAccessDesc accessDesc = {}; CUmemGenericAllocationHandle handle; + CUmemAllocationHandleType type = ncclCuMemHandleType; int cudaDev; int flag = 0; CUDACHECK(cudaGetDevice(&cudaDev)); CUCHECK(cuDeviceGet(¤tDev, cudaDev)); prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.requestedHandleTypes = NCCL_P2P_HANDLE_TYPE; // So it can be exported + prop.requestedHandleTypes = type; prop.location.id = currentDev; // Query device to see if RDMA support is available CUCHECK(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED, currentDev)); diff --git a/projects/rccl/src/include/argcheck.h b/projects/rccl/src/include/argcheck.h index 8d8b74e8e4..e4bdc32e21 100644 --- a/projects/rccl/src/include/argcheck.h +++ b/projects/rccl/src/include/argcheck.h @@ -12,5 +12,6 @@ ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname); ncclResult_t ArgsCheck(struct ncclInfo* info); +ncclResult_t CudaPtrCheck(const void* pointer, struct ncclComm* comm, const char* ptrname, const char* opname); #endif diff --git a/projects/rccl/src/include/coll_net.h b/projects/rccl/src/include/coll_net.h index f4b5408669..affbf0a24a 100644 --- a/projects/rccl/src/include/coll_net.h +++ b/projects/rccl/src/include/coll_net.h @@ -19,9 +19,9 @@ static ncclResult_t collNetGetProperties(struct ncclComm* comm, int dev, ncclNet static ncclResult_t collNetListen(struct ncclComm* comm, int dev, void* handle, void** listenComm) { NCCLCHECK(comm->ncclCollNet->listen(dev, handle, listenComm)); return ncclSuccess; } static ncclResult_t collNetConnect(struct ncclComm* comm, void* handles[], int nranks, int rank, void* listenComm, void** collComm) { NCCLCHECK(comm->ncclCollNet->connect(handles, nranks, rank, listenComm, collComm)); return ncclSuccess; } static ncclResult_t collNetReduceSupport(struct ncclComm* comm, ncclDataType_t dataType, ncclRedOp_t redOp, int* supported) { NCCLCHECK(comm->ncclCollNet->reduceSupport(dataType, redOp, supported)); return ncclSuccess; } -static ncclResult_t collNetRegMr(struct ncclComm* comm, void* collComm, void* data, int size, int type, void** mhandle) { NCCLCHECK(comm->ncclCollNet->regMr(collComm, data, size, type, mhandle)); return ncclSuccess; } +static ncclResult_t collNetRegMr(struct ncclComm* comm, void* collComm, void* data, size_t size, int type, void** mhandle) { NCCLCHECK(comm->ncclCollNet->regMr(collComm, data, size, type, mhandle)); return ncclSuccess; } /* DMA-BUF support */ -static ncclResult_t collNetRegMrDmaBuf(struct ncclComm* comm, void* collComm, void* data, int size, int type, uint64_t offset, int fd, void** mhandle) { NCCLCHECK(comm->ncclCollNet->regMrDmaBuf(collComm, data, size, type, offset, fd, mhandle)); return ncclSuccess; } +static ncclResult_t collNetRegMrDmaBuf(struct ncclComm* comm, void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { NCCLCHECK(comm->ncclCollNet->regMrDmaBuf(collComm, data, size, type, offset, fd, mhandle)); return ncclSuccess; } static ncclResult_t collNetDeregMr(struct ncclComm* comm, void* collComm, void* mhandle) { NCCLCHECK(comm->ncclCollNet->deregMr(collComm, mhandle)); return ncclSuccess; } static ncclResult_t collNetIallreduce(struct ncclComm* comm, void* collComm, void* sendData, void* recvData, int count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { NCCLCHECK(comm->ncclCollNet->iallreduce(collComm, sendData, recvData, count, dataType, redOp, sendMhandle, recvMhandle, request)); return ncclSuccess; } diff --git a/projects/rccl/src/include/collectives.h b/projects/rccl/src/include/collectives.h index 0f965276a4..888df728f4 100644 --- a/projects/rccl/src/include/collectives.h +++ b/projects/rccl/src/include/collectives.h @@ -45,4 +45,15 @@ inline int ncclTypeSize(ncclDataType_t type) { } } +#include + +#define NCCL_MODE_NORMAL 0 +#define NCCL_MODE_OFFSET 1 +#define NCCL_MODE_PTR 2 +struct ncclConnFifo { + int mode; + int offset; + ssize_t size; + void* ptr; +}; #endif diff --git a/projects/rccl/src/include/comm.h b/projects/rccl/src/include/comm.h index 328ffef3b5..81da1ad6ae 100644 --- a/projects/rccl/src/include/comm.h +++ b/projects/rccl/src/include/comm.h @@ -14,6 +14,7 @@ #include "proxy.h" #include "strongstream.h" #include "nccl_net.h" +#include "register.h" #if CUDART_VERSION < 9000 struct cudaLaunchParams { @@ -54,8 +55,7 @@ struct ncclRecvMem { struct { uint64_t tail; char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)]; - int sizesFifo[NCCL_STEPS]; - int offsFifo[NCCL_STEPS]; + struct ncclConnFifo connFifo[NCCL_STEPS]; int flush; // For GDRCopy-based flush }; char pad4[MEM_ALIGN]; @@ -169,7 +169,6 @@ struct ncclKernelPlan { // A kernel plan is also a callback that reclaims itself. Hence this must // be the first member. struct ncclCommCallback reclaimer; - struct ncclMemoryPool memPool_ncclProxyOp; // memory to return to comm in cleanup struct ncclComm* comm; struct ncclKernelPlan* next; @@ -200,23 +199,7 @@ struct ncclKernelPlan { struct ncclIntruQueue workQueue; struct ncclIntruQueue proxyOpQueue; } channels[MAXCHANNELS]; -}; - -struct ncclRegRequest { - uintptr_t buff; - size_t size; - struct ncclRegRequest *next; -}; - -struct ncclRegRecord { - uintptr_t buff; - size_t size; - CUdeviceptr regAddr; - size_t regSize; - int dev; - CUmemGenericAllocationHandle mcHandle; - uintptr_t *addrs; /* use to check if NVLS buffers match among intra-node ranks */ - struct ncclRegRecord *next; + size_t maxBytesPerChannel; }; struct ncclComm { @@ -262,6 +245,7 @@ struct ncclComm { int* localRankToRank; // localRanks and localRanktoRank for all nodes struct ncclNodeRanks* nodeRanks; + int MNNVL; // MNNVL: Multi-Node NVLink bool checkPointers; bool dmaBufSupport; @@ -270,8 +254,9 @@ struct ncclComm { uint64_t opCount; // Channels for collectives - int nChannels; - int nvlsChannels; + int nChannels; // connection nChannels + int collChannels; // enqueue nChannels + int nvlsChannels; // enqueue nChannels int collNetChannels; // Channels (per peer) for p2p int p2pnChannels; @@ -334,6 +319,9 @@ struct ncclComm { int intraHighestTransportType; int* collNetHeads; int collNetHeadsNum; + int collNetHeadsUniqueNum; + int* collNetDenseToUserRank; + int* collNetUserToDenseRank; /* sharable collNet proxy progress resource. */ struct ncclCollNetSharedRes* collNetSharedRes; @@ -343,8 +331,6 @@ struct ncclComm { /* sharable NVLS resource. */ struct ncclNvlsSharedRes* nvlsResources; - ssize_t channelSize; // User requested work size (bytes) for channel partitions - // pools backed by comm->memPermanent struct ncclMemoryPool memPool_ncclProxyOp; struct ncclMemoryPool memPool_ncclKernelPlan; @@ -380,13 +366,10 @@ struct ncclComm { // group job to support multi-thread FT struct ncclGroupJob *groupJob; - /* store to buffer register request */ - struct ncclIntruQueue regRequestQueue; - /* store registered buffer */ - struct ncclIntruQueue regRecordQueue; - // Tuning plugin ncclTuner_t* tuner; + // buffer registration cache + struct ncclRegCache regCache; }; enum ncclLaunchMode { diff --git a/projects/rccl/src/include/cudawrap.h b/projects/rccl/src/include/cudawrap.h index cc363c1ac7..9350306d9a 100644 --- a/projects/rccl/src/include/cudawrap.h +++ b/projects/rccl/src/include/cudawrap.h @@ -16,6 +16,10 @@ extern int ncclCuMemEnable(); #if CUDART_VERSION >= 11030 #include + +// Handle type used for cuMemCreate() +extern CUmemAllocationHandleType ncclCuMemHandleType; + #else typedef CUresult (CUDAAPI *PFN_cuInit_v2000)(unsigned int Flags); typedef CUresult (CUDAAPI *PFN_cuDriverGetVersion_v2020)(int *driverVersion); diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index 56f8039f30..02ea883a95 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -96,8 +96,7 @@ struct ncclConnInfo { void **ptrExchange; // Pointer exchange for direct communication uint64_t* redOpArgExchange; // PreOp scaler exchange for direct pull case - int *sizesFifo; // Sizes fifo from GPU to proxy - int *offsFifo; // Buffer fifo from proxy to GPU + struct ncclConnFifo* connFifo; // Used for GPU - Proxy communication uint64_t step; // Keep where we are uint64_t llLastCleaning; @@ -151,6 +150,9 @@ struct ncclDirect { int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) int shift; // Shuffling of send/recv for scatter/gather operations, basically localRank%nHeads + // The heads[...] are guaranteed to be in rotated order start with self: + // headRank, (headRank+1)%nHeads, (headRank+2)%nHeads, ... + int heads[NCCL_MAX_DIRECT_ARITY+1]; int up[NCCL_MAX_DIRECT_ARITY]; int down[NCCL_MAX_DIRECT_ARITY]; }; @@ -210,21 +212,28 @@ struct ncclWorkElem { union { uint8_t flagBits; struct { - uint8_t isUsed:1, redOpArgIsPtr:1, regUsed:1; + uint8_t isUsed:1, redOpArgIsPtr:1, regUsed:1, oneNode:1; }; }; uint8_t nWarps; uint8_t direct; - - const void * sendbuff; - void * recvbuff; + uint32_t root; + const void *sendbuff; + void *recvbuff; size_t count; - size_t lastChunkSize; - uint32_t root; - uint8_t bid; - uint8_t nChannels; uint64_t redOpArg; + uint64_t chunkCount:25, workCount:39; + union { + struct { + uint64_t lastChunkCount:25; + uint64_t workOffset:39; + }; + struct { + uint64_t bid:32; + uint64_t nChannels:32; + }; + }; }; #define NCCL_MAX_WORK_ELEMENTS ((NCCL_WORK_SIZE - alignUp(sizeof(ncclWorkHeader), alignof(ncclWorkElem)))/sizeof(ncclWorkElem)) @@ -235,7 +244,8 @@ struct ncclWorkElemP2p { int proto : 2; enum ncclWorkP2PType p2pType; - uint8_t nWarps; + uint8_t reg:1; + uint8_t nWarps:5; uint8_t warpStart; uint8_t ngroups; // Important not to use any fields with greater than 4-byte alignment since @@ -296,6 +306,8 @@ struct alignas(16) ncclDevChannel { struct ncclDevComm { int rank; int nRanks; + int node; + int nNodes; int buffSizes[NCCL_NUM_PROTOCOLS]; int p2pChunkSize; @@ -303,6 +315,8 @@ struct ncclDevComm { int workFifoDepth; struct ncclWork* workFifoHeap; // may be cudaHost or GDR memory + int* collNetDenseToUserRank; + // Flag to ask NCCL kernels to abort volatile uint32_t* abortFlag; @@ -415,46 +429,54 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) #else constexpr int NumTypes = ncclNumTypes + 1; #endif + int row; + do { + row = 0; // ncclDevFuncIndex_P2p + if (coll == ncclFuncSendRecv) break; + row += 1; - int row = 0; // ncclDevFuncIndex_P2p - if (coll == ncclFuncSendRecv) goto have_row; - row += 1; + int nAlgos = 3; + if (coll == ncclFuncAllGather) { + int algo1 = algo == NCCL_ALGO_RING ? 0 : + algo == NCCL_ALGO_COLLNET_DIRECT ? 1 : + /*algo == NCCL_ALGO_NVLS*/ 2; + row += algo1*NCCL_NUM_PROTOCOLS + proto; + break; + } + row += nAlgos*NCCL_NUM_PROTOCOLS; - if (coll == ncclFuncAllGather) { - int algo1 = algo == NCCL_ALGO_RING ? 0 : - /*algo == NCCL_ALGO_NVLS*/ 1; - row += algo1*NCCL_NUM_PROTOCOLS + proto; - goto have_row; - } - row += (/*NumAlgos=*/2)*NCCL_NUM_PROTOCOLS; + nAlgos = 1; + if (coll == ncclFuncBroadcast) { + row += proto; + break; + } + row += nAlgos*NCCL_NUM_PROTOCOLS; - if (coll == ncclFuncBroadcast) { - row += proto; - goto have_row; - } - row += (/*NumAlgos=*/1)*NCCL_NUM_PROTOCOLS; + nAlgos = NCCL_NUM_ALGORITHMS; + if (coll == ncclFuncAllReduce) { + row += ((devRedOp*NumTypes + type)*nAlgos + algo)*NCCL_NUM_PROTOCOLS + proto; + break; + } + row += ncclNumDevRedOps*NumTypes*nAlgos*NCCL_NUM_PROTOCOLS; - if (coll == ncclFuncAllReduce) { - row += ((devRedOp*NumTypes + type)*NCCL_NUM_ALGORITHMS + algo)*NCCL_NUM_PROTOCOLS + proto; - goto have_row; - } - row += ncclNumDevRedOps*NumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS; + nAlgos = 1; + if (coll == ncclFuncReduce) { + row += (devRedOp*NumTypes + type)*NCCL_NUM_PROTOCOLS + proto; + break; + } + row += ncclNumDevRedOps*NumTypes*nAlgos*NCCL_NUM_PROTOCOLS; - if (coll == ncclFuncReduce) { - row += (devRedOp*NumTypes + type)*NCCL_NUM_PROTOCOLS + proto; - goto have_row; - } - row += ncclNumDevRedOps*NumTypes*(/*NumAlgos=*/1)*NCCL_NUM_PROTOCOLS; + nAlgos = 3; + if (coll == ncclFuncReduceScatter) { + int algo1 = algo == NCCL_ALGO_RING ? 0 : + algo == NCCL_ALGO_COLLNET_DIRECT ? 1 : + /*algo == NCCL_ALGO_NVLS*/ 2; + row += ((devRedOp*NumTypes + type)*nAlgos + algo1)*NCCL_NUM_PROTOCOLS + proto; + break; + } + row += ncclNumDevRedOps*NumTypes*nAlgos*NCCL_NUM_PROTOCOLS; + } while (false); - if (coll == ncclFuncReduceScatter) { - int algo1 = algo == NCCL_ALGO_RING ? 0 : - /*algo == NCCL_ALGO_NVLS*/ 1; - row += ((devRedOp*NumTypes + type)*2 + algo1)*NCCL_NUM_PROTOCOLS + proto; - goto have_row; - } - row += ncclNumDevRedOps*NumTypes*(/*NumAlgos=*/2)*NCCL_NUM_PROTOCOLS; - -have_row: return ncclDevFuncRowToId[row]; } diff --git a/projects/rccl/src/include/enqueue.h b/projects/rccl/src/include/enqueue.h index 634f037cb3..8ab59607d6 100644 --- a/projects/rccl/src/include/enqueue.h +++ b/projects/rccl/src/include/enqueue.h @@ -12,8 +12,10 @@ #include "collectives.h" #include "utils.h" -#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64) -#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */ +#define NCCL_LL_ALIGNMENT_PER_THREAD sizeof(uint64_t) +#define NCCL_LL128_ALIGNMENT_PER_WARP 480 +#define NCCL_SIMPLE_ALIGNMENT (WARP_SIZE * 8LL * 16LL) +#define NCCL_BYTES_ALIGNMENT 16 ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize); ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); diff --git a/projects/rccl/src/include/graph.h b/projects/rccl/src/include/graph.h index fdd634894d..2a455e9e2b 100644 --- a/projects/rccl/src/include/graph.h +++ b/projects/rccl/src/include/graph.h @@ -33,6 +33,7 @@ int ncclTopoPathAllNVLink(struct ncclTopoSystem* system); // Query topology ncclResult_t ncclTopoGetNetDev(struct ncclComm* comm, int rank, struct ncclTopoGraph* graph, int channelId, int peerRank, int* net, int* proxyRank); ncclResult_t ncclTopoCheckP2p(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* p2p, int *read, int* intermediateRank); +ncclResult_t ncclTopoCheckMNNVL(struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* ret); ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int64_t busId, int netDev, int read, int* useGdr); ncclResult_t ncclTopoNeedFlush(struct ncclTopoSystem* system, int64_t busId, int* flush); ncclResult_t ncclTopoCheckNet(struct ncclTopoSystem* system, int64_t id1, int64_t id2, int* net); @@ -53,10 +54,11 @@ ncclResult_t ncclTopoGetCpuAffinity(struct ncclTopoSystem* system, int rank, cpu #define NCCL_TOPO_CPU_TYPE_YONGFENG 1 ncclResult_t ncclTopoCpuType(struct ncclTopoSystem* system, int* arch, int* vendor, int* model); ncclResult_t ncclTopoGetGpuCount(struct ncclTopoSystem* system, int* count); -ncclResult_t ncclTopoGetNvsCount(struct ncclTopoSystem* system, int* count); +ncclResult_t ncclTopoGetNetCount(struct ncclTopoSystem* system, int* count); ncclResult_t ncclTopoGetNvsCount(struct ncclTopoSystem* system, int* count); ncclResult_t ncclTopoGetLocalNet(struct ncclTopoSystem* system, int rank, int channelId, int* id); ncclResult_t ncclTopoGetLocalGpu(struct ncclTopoSystem* system, int net, int* gpuIndex); +ncclResult_t getLocalNetCountByBw(struct ncclTopoSystem* system, int gpu, int *count); #define NCCL_TOPO_MAX_NODES 256 @@ -102,6 +104,7 @@ struct ncclTopoRanks { int treeToChild0[MAXCHANNELS]; int treeToChild1[MAXCHANNELS]; int nvlsHeads[MAXCHANNELS]; + int nvlsHeadNum; }; ncclResult_t ncclTopoPreset(struct ncclComm* comm, struct ncclTopoGraph** graphs, struct ncclTopoRanks* topoRanks); diff --git a/projects/rccl/src/include/info.h b/projects/rccl/src/include/info.h index f65ed2e698..3a3f4e5b15 100644 --- a/projects/rccl/src/include/info.h +++ b/projects/rccl/src/include/info.h @@ -13,6 +13,7 @@ #include "core.h" #include "utils.h" #include "strongstream.h" +#define NCCL_MAX_LOCAL_RANKS 64 typedef enum : uint8_t { ncclPatternRing, @@ -30,6 +31,13 @@ typedef enum : uint8_t { ncclPatternRecv } ncclPattern_t; +enum ncclRegBufferType { + NCCL_REGULAR_BUFFER = 0, + NCCL_IPC_REG_BUFFER = 1, + NCCL_NVLS_REG_BUFFER = 2, + NCCL_REG_BUFFER_NUM = 3 +}; + // Used to pass NCCL call information between functions struct ncclInfo { ncclFunc_t coll; @@ -48,37 +56,46 @@ struct ncclInfo { int sliceSteps; // Computed later ncclDevRedOpFull opFull; - int algorithm; - int protocol; ncclPattern_t pattern; - int nChannels; - int nThreads; size_t nBytes; + size_t aggnBytes; + size_t workBytes; size_t sendbuffSize; size_t recvbuffSize; - int nstepsPerLoop; - int nchunksPerLoop; + int stepSize; + int chunkCount; int chunkSize; int channelId; + int workFuncIndex; + ncclRegBufferType regBufType; + void* regBufSend[NCCL_MAX_LOCAL_RANKS]; + void* regBufRecv[NCCL_MAX_LOCAL_RANKS]; + // Need to initialize + int nThreads; + int nChannels; + int algorithm; + int protocol; + bool userTuned; + struct ncclInfo *next; }; inline ncclResult_t ncclInfoSetDerived(struct ncclInfo* info, int nRanks) { - info->nBytes = info->count * ncclTypeSize(info->datatype); + info->nBytes = info->workBytes = info->count * ncclTypeSize(info->datatype); if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) { - info->count = info->nBytes; + info->count = info->workBytes; info->datatype = ncclInt8; } if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= nRanks; // count is per rank /* compute buffer size for NVLS buffer registration */ if (info->coll == ncclFuncAllGather) { - info->sendbuffSize = info->count * ncclTypeSize(info->datatype); + info->sendbuffSize = info->workBytes; info->recvbuffSize = info->sendbuffSize * nRanks; } else if (info->coll == ncclFuncReduceScatter) { - info->recvbuffSize = info->count * ncclTypeSize(info->datatype); + info->recvbuffSize = info->workBytes; info->sendbuffSize = info->recvbuffSize * nRanks; } else { - info->sendbuffSize = info->recvbuffSize = info->count * ncclTypeSize(info->datatype); + info->sendbuffSize = info->recvbuffSize = info->workBytes; } return ncclSuccess; } @@ -93,6 +110,7 @@ struct ncclTaskColl { ncclDataType_t datatype; ncclDevRedOpFull op; int chunkSteps, sliceSteps; + struct ncclInfo info; }; struct ncclTaskP2p { ncclTaskP2p *next; @@ -113,8 +131,16 @@ struct ncclTasks { struct ncclIntruQueue sendQueue; struct ncclIntruQueue recvQueue; }; - struct ncclIntruQueue collQueue; - size_t collBytesTotal; + struct ncclIntruQueue collQueue; + // Queue for user-tuned executed collectives + struct ncclIntruQueue collTunedQueue; + // Queue for continuous bytes distribution (CBD) collectives + struct ncclIntruQueue collCBDQueue; + // Queue for collnet + struct ncclIntruQueue collnetQueue; + size_t workBytesTotal; + int usableChannels; + bool sorted; struct Peer* peers/*[nRanks]*/; int *p2pSendOrder, *p2pRecvOrder; int p2pOrderSteps; diff --git a/projects/rccl/src/include/ipcsocket.h b/projects/rccl/src/include/ipcsocket.h index ccecde84c7..5a09e90b74 100644 --- a/projects/rccl/src/include/ipcsocket.h +++ b/projects/rccl/src/include/ipcsocket.h @@ -35,4 +35,7 @@ ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket* handle, int* fd); ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank, uint64_t hash); +ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, const int sendFd, int rank, uint64_t hash); +ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd); + #endif /* NCCL_IPCSOCKET_H */ diff --git a/projects/rccl/src/include/nccl_common.h b/projects/rccl/src/include/nccl_common.h index a37ac203ea..ded0dae9fa 100644 --- a/projects/rccl/src/include/nccl_common.h +++ b/projects/rccl/src/include/nccl_common.h @@ -13,7 +13,17 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); #define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now -typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclNumFuncs} ncclFunc_t; +typedef enum { + ncclFuncBroadcast = 0, + ncclFuncReduce = 1, + ncclFuncAllGather = 2, + ncclFuncReduceScatter = 3, + ncclFuncAllReduce = 4, + ncclFuncSendRecv = 5, + ncclFuncSend = 6, + ncclFuncRecv = 7, + ncclNumFuncs = 8 +} ncclFunc_t; #define NCCL_NUM_ALGORITHMS 6 // Tree/Ring/CollNet* #define NCCL_ALGO_UNDEF -1 diff --git a/projects/rccl/src/include/nccl_net.h b/projects/rccl/src/include/nccl_net.h index 9b3e6719fc..467d9fdb89 100644 --- a/projects/rccl/src/include/nccl_net.h +++ b/projects/rccl/src/include/nccl_net.h @@ -21,6 +21,140 @@ // Maximum number of requests per comm object #define NCCL_NET_MAX_REQUESTS 32 +typedef struct { + char* name; // Used mostly for logging. + char* pciPath; // Path to the PCI device in /sys. + uint64_t guid; // Unique identifier for the NIC chip. Important for + // cards with multiple PCI functions (Physical or virtual). + int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF] + int regIsGlobal; // regMr is not tied to a particular comm + int speed; // Port speed in Mbps. + int port; // Port number. + float latency; // Network latency + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives. + ncclNetDeviceType netDeviceType; // Network offload type + int netDeviceVersion; // Version number for network offload +} ncclNetProperties_v8_t; + +typedef ncclNetProperties_v8_t ncclNetProperties_t; + +typedef struct { + // Name of the network (mainly for logs) + const char* name; + // Initialize the network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create a connection. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Connect to a handle and return a sending comm object for that peer. + // This call must not block for the connection to be established, and instead + // should return successfully with sendComm == NULL with the expectation that + // it will be called again until sendComm != NULL. + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm); + // Finalize connection establishment after remote peer has called connect. + // This call must not block for the connection to be established, and instead + // should return successfully with recvComm == NULL with the expectation that + // it will be called again until recvComm != NULL. + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm); + // Register/Deregister memory. Comm can be either a sendComm or a recvComm. + // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + // Asynchronous send to a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request); + // Asynchronous recv from a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* sizes); + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); +} ncclNet_v8_t; + +typedef ncclNet_v8_t ncclNet_t; + +#define NCCL_NET_PLUGIN_SYMBOL ncclNetPlugin_v8 + +typedef struct { + void* mhandle; + void* address; + uint32_t size; +} ncclNetSGE_v8_t; + +typedef struct { + // Name of the collective network (mainly for logs) + const char* name; + // Initialize the collective network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters capable of doing collective operations. + // If ndev returns 0, all other functions might be set to NULL. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create connections. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Create a group for collective operations. handles have been created + // using listen() above. rank indicates caller's rank in the collective network. + ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); + // Returns whether a reduction operation on a data type is supported. + // 1 for supported, 0 otherwise. + ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); + // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* collComm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* collComm, void* mhandle); + // Performs an asynchronous allreduce operation on the collective group. + // May return request == NULL if the call cannot be performed (or would block). + ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); + ncclResult_t (*iallgather)(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + void* sendMhandle, void** request); + ncclResult_t (*ireducescatter)(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + ncclDataType_t dataType, ncclRedOp_t redOp, + void* recvMhandle, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* size); + // Close and free collective comm objects + ncclResult_t (*closeColl)(void* collComm); + ncclResult_t (*closeListen)(void* listenComm); +} ncclCollNet_v8_t; + +typedef ncclCollNet_v8_t ncclCollNet_t; + +#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v8 + typedef struct { char* name; // Used mostly for logging. char* pciPath; // Path to the PCI device in /sys. @@ -36,8 +170,6 @@ typedef struct { int netDeviceVersion; // Version number for network offload } ncclNetProperties_v7_t; -typedef ncclNetProperties_v7_t ncclNetProperties_t; - typedef struct { // Name of the network (mainly for logs) const char* name; @@ -93,11 +225,45 @@ typedef struct { ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); } ncclNet_v7_t; -typedef ncclNet_v7_t ncclNet_t; - -#define NCCL_NET_PLUGIN_SYMBOL ncclNetPlugin_v7 - -#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v7 +typedef struct { + // Name of the collective network (mainly for logs) + const char* name; + // Initialize the collective network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters capable of doing collective operations. + // If ndev returns 0, all other functions might be set to NULL. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v7_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create connections. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Create a group for collective operations. handles have been created + // using listen() above. rank indicates caller's rank in the collective network. + ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); + // Returns whether a reduction operation on a data type is supported. + // 1 for supported, 0 otherwise. + ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); + // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* collComm, void* data, int size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* collComm, void* mhandle); + // Performs an asynchronous allreduce operation on the collective group. + // May return request == NULL if the call cannot be performed (or would block). + ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* size); + // Close and free collective comm objects + ncclResult_t (*closeColl)(void* collComm); + ncclResult_t (*closeListen)(void* listenComm); +} ncclCollNet_v7_t; #define NCCL_NET_MAX_REQUESTS_V6 8 @@ -162,49 +328,6 @@ typedef struct { ncclResult_t (*closeListen)(void* listenComm); } ncclNet_v6_t; -typedef struct { - // Name of the collective network (mainly for logs) - const char* name; - // Initialize the collective network. - ncclResult_t (*init)(ncclDebugLogger_t logFunction); - // Return the number of adapters capable of doing collective operations. - // If ndev returns 0, all other functions might be set to NULL. - ncclResult_t (*devices)(int* ndev); - // Get various device properties. - ncclResult_t (*getProperties)(int dev, ncclNetProperties_v7_t* props); - // Create a receiving object and provide a handle to connect to it. The - // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged - // between ranks to create connections. - ncclResult_t (*listen)(int dev, void* handle, void** listenComm); - // Create a group for collective operations. handles have been created - // using listen() above. rank indicates caller's rank in the collective network. - ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); - // Returns whether a reduction operation on a data type is supported. - // 1 for supported, 0 otherwise. - ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); - // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. - ncclResult_t (*regMr)(void* collComm, void* data, int size, int type, void** mhandle); - /* DMA-BUF support */ - ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); - ncclResult_t (*deregMr)(void* collComm, void* mhandle); - // Performs an asynchronous allreduce operation on the collective group. - // May return request == NULL if the call cannot be performed (or would block). - ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count, - ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); - // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is - // visible to the GPU - ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); - // Test whether a request is complete. If size is not NULL, it returns the - // number of bytes sent/received. - ncclResult_t (*test)(void* request, int* done, int* size); - // Close and free collective comm objects - ncclResult_t (*closeColl)(void* collComm); - ncclResult_t (*closeListen)(void* listenComm); -} ncclCollNet_v7_t; - -typedef ncclCollNet_v7_t ncclCollNet_t; - -// v6 struct for backwards compatibility typedef struct { // Name of the collective network (mainly for logs) const char* name; diff --git a/projects/rccl/src/include/net_device.h b/projects/rccl/src/include/net_device.h index 8f7c0d6e1e..7bb2968c05 100644 --- a/projects/rccl/src/include/net_device.h +++ b/projects/rccl/src/include/net_device.h @@ -24,6 +24,7 @@ typedef struct { int needsProxyProgress; } ncclNetDeviceHandle_v7_t; -typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_t; +typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t; +typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_t; #endif diff --git a/projects/rccl/src/include/nvmlwrap.h b/projects/rccl/src/include/nvmlwrap.h index 2ab8e3a2b0..bad0b79371 100644 --- a/projects/rccl/src/include/nvmlwrap.h +++ b/projects/rccl/src/include/nvmlwrap.h @@ -20,6 +20,12 @@ // Dynamically handle dependencies on NVML /* Extracted from nvml.h */ + +#define NVML_API_VERSION 12 + +#define NVML_STRUCT_VERSION(data, ver) (unsigned int)(sizeof(nvml ## data ## _v ## ver ## _t) | \ + (ver << 24U)) + typedef struct nvmlDevice_st* nvmlDevice_t; #define NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16 @@ -181,6 +187,72 @@ typedef struct nvmlFieldValue_st nvmlValue_t value; //!< Value for this field. This is only valid if nvmlReturn == NVML_SUCCESS } nvmlFieldValue_t; + +#define NVML_GPU_FABRIC_UUID_LEN 16 + +#define NVML_GPU_FABRIC_STATE_NOT_SUPPORTED 0 +#define NVML_GPU_FABRIC_STATE_NOT_STARTED 1 +#define NVML_GPU_FABRIC_STATE_IN_PROGRESS 2 +#define NVML_GPU_FABRIC_STATE_COMPLETED 3 + +typedef unsigned char nvmlGpuFabricState_t; + +typedef struct { + unsigned char clusterUuid[NVML_GPU_FABRIC_UUID_LEN]; //!< Uuid of the cluster to which this GPU belongs + nvmlReturn_t status; //!< Error status, if any. Must be checked only if state returns "complete". + unsigned int cliqueId; //!< ID of the fabric clique to which this GPU belongs + nvmlGpuFabricState_t state; //!< Current state of GPU registration process +} nvmlGpuFabricInfo_t; + +#define NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED 0 +#define NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE 1 +#define NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE 2 + +#define NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW 0 +#define NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW 0x11 + +/** + * GPU Fabric Health Status Mask for various fields can be obtained + * using the below macro. + * Ex - NVML_GPU_FABRIC_HEALTH_GET(var, _DEGRADED_BW) + */ +#define NVML_GPU_FABRIC_HEALTH_GET(var, type) \ + (((var) >> NVML_GPU_FABRIC_HEALTH_MASK_SHIFT##type) & \ + (NVML_GPU_FABRIC_HEALTH_MASK_WIDTH##type)) + +/** + * GPU Fabric Health Status Mask for various fields can be tested + * using the below macro. + * Ex - NVML_GPU_FABRIC_HEALTH_TEST(var, _DEGRADED_BW, _TRUE) + */ +#define NVML_GPU_FABRIC_HEALTH_TEST(var, type, val) \ + (NVML_GPU_FABRIC_HEALTH_GET(var, type) == \ + NVML_GPU_FABRIC_HEALTH_MASK##type##val) + +/** +* GPU Fabric information (v2). +* +* Version 2 adds the \ref nvmlGpuFabricInfo_v2_t.version field +* to the start of the structure, and the \ref nvmlGpuFabricInfo_v2_t.healthMask +* field to the end. This structure is not backwards-compatible with +* \ref nvmlGpuFabricInfo_t. +*/ +typedef struct { + unsigned int version; //!< Structure version identifier (set to \ref nvmlGpuFabricInfo_v2) + unsigned char clusterUuid[NVML_GPU_FABRIC_UUID_LEN]; //!< Uuid of the cluster to which this GPU belongs + nvmlReturn_t status; //!< Error status, if any. Must be checked only if state returns "complete". + unsigned int cliqueId; //!< ID of the fabric clique to which this GPU belongs + nvmlGpuFabricState_t state; //!< Current state of GPU registration process + unsigned int healthMask; //!< GPU Fabric health Status Mask +} nvmlGpuFabricInfo_v2_t; + +typedef nvmlGpuFabricInfo_v2_t nvmlGpuFabricInfoV_t; + +/** +* Version identifier value for \ref nvmlGpuFabricInfo_v2_t.version. +*/ +#define nvmlGpuFabricInfo_v2 NVML_STRUCT_VERSION(GpuFabricInfo, 2) + /* End of nvml.h */ #endif // NCCL_NVML_DIRECT @@ -210,5 +282,6 @@ ncclResult_t ncclNvmlDeviceGetNvLinkCapability(nvmlDevice_t device, unsigned int ncclResult_t ncclNvmlDeviceGetCudaComputeCapability(nvmlDevice_t device, int* major, int* minor); ncclResult_t ncclNvmlDeviceGetP2PStatus(nvmlDevice_t device1, nvmlDevice_t device2, nvmlGpuP2PCapsIndex_t p2pIndex, nvmlGpuP2PStatus_t* p2pStatus); ncclResult_t ncclNvmlDeviceGetFieldValues(nvmlDevice_t device, int valuesCount, nvmlFieldValue_t *values); +ncclResult_t ncclNvmlDeviceGetGpuFabricInfoV(nvmlDevice_t device, nvmlGpuFabricInfoV_t *gpuFabricInfo); #endif // End include guard diff --git a/projects/rccl/src/include/p2p.h b/projects/rccl/src/include/p2p.h index 6ffba4b0e1..9a3dbdb3b0 100644 --- a/projects/rccl/src/include/p2p.h +++ b/projects/rccl/src/include/p2p.h @@ -9,10 +9,22 @@ #ifndef NCCL_P2P_H_ #define NCCL_P2P_H_ -#define NCCL_P2P_HANDLE_TYPE CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR +#include -typedef struct { +#if CUDART_VERSION < 12030 +// MNNVL: FABRIC handle support lifted from CUDA 12.3 +#define CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED ((CUdevice_attribute)128) +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + +typedef union { uint64_t data; // Needs to hold a CUmemGenericAllocationHandle for UDS fd support + CUmemFabricHandle handle; } ncclCuDesc; typedef union { diff --git a/projects/rccl/src/include/proxy.h b/projects/rccl/src/include/proxy.h index 8093c0ce63..353426c1d3 100644 --- a/projects/rccl/src/include/proxy.h +++ b/projects/rccl/src/include/proxy.h @@ -24,33 +24,42 @@ typedef ncclResult_t (*proxyProgressFunc_t)(struct ncclProxyState*, struct ncclP #define NCCL_PROXY_MAX_SUBS MAXCHANNELS static_assert(NCCL_MAX_WORK_ELEMENTS <= MAXCHANNELS, "Not enough sub space for max work elements"); +union ncclProxyOpSpecifics { + struct { + size_t sizePerRank; + int nNodes, node; + } collnetDirect; +}; + struct ncclProxyOp { struct ncclProxyConnection* connection; - int channelId; - int nsteps; + void* buffer; ssize_t nbytes; + uint64_t opCount; int root; int next; - - uint64_t opCount; - int sliceSteps; - int chunkSteps; + int nsteps; int chunkSize; + uint8_t sliceSteps; + uint8_t chunkSteps; + uint8_t channelId; uint8_t /*ncclDataType_t*/ dtype; uint8_t /*ncclDevRedOp_t*/ redOp; + uint8_t /*ncclFunc_t*/ coll; uint8_t /*ncclPattern_t*/ pattern; uint8_t protocol; + uint8_t reg; - union { - uint64_t unused; - // For use by enqueue.cc - struct ncclProxyOp *enqNext; - }; + union ncclProxyOpSpecifics specifics; + + struct ncclProxyOp *enqNext; }; -static_assert(sizeof(struct ncclProxyOp) == 64, "Keep ProxyOp aligned with cache lines for effective prefetch"); struct ncclProxySubArgs { struct ncclProxyConnection* connection; + int reg; + void* buffer; + void* mhandle; int channelId; int nsteps; ssize_t nbytes; @@ -82,6 +91,7 @@ struct ncclProxyArgs { uint8_t /*ncclDataType_t*/ dtype; uint8_t /*ncclDevRedOp_t*/ redOp; uint8_t /*ncclPattern_t*/ pattern; + uint8_t /*ncclFunc_t*/ coll; uint8_t protocol; int state; char* sharedBuff[NCCL_STEPS]; @@ -93,6 +103,8 @@ struct ncclProxyArgs { struct ncclProxyArgs* next; struct ncclProxyArgs* nextPeer; struct ncclProxyArgs** proxyAppendPtr; + + union ncclProxyOpSpecifics specifics; }; #define NCCL_MAX_NETDEVS 128 @@ -100,7 +112,7 @@ struct ncclProxyArgs { // Make sure we have enough to store two full rounds of operations on all channels. // Otherwise we'd be unable to post half of them to free new elements. #define MAX_OPS_PER_PEER (2*MAXCHANNELS*NCCL_MAX_WORK_ELEMENTS_P2P) -#define NCCL_MAX_LOCAL_RANKS 64 + struct ncclProxyOpsPool { struct ncclProxyOp ops[MAX_OPS_PER_PEER*NCCL_MAX_LOCAL_RANKS]; volatile int nextOps; @@ -193,6 +205,16 @@ struct ncclProxyRpcResponseHeader { int respSize; }; +// UDS support +struct ncclIpcHdr { + int type; + int rank; + int reqSize; + int respSize; + void *opId; + uint64_t data[16]; // 128-bytes +}; + struct ncclProxyState { int refCount; int tpRank; @@ -208,9 +230,11 @@ struct ncclProxyState { ncclNet_t* ncclNet; ncclCollNet_t* ncclCollNet; volatile uint32_t* abortFlag; - // Service thread + // Service threads pthread_t thread; + pthread_t threadUDS; struct ncclSocket* listenSock; + struct ncclIpcSocket ipcSock; int stop; CUcontext cudaCtx; ncclResult_t asyncResult; @@ -221,6 +245,7 @@ struct ncclProxyState { struct ncclProxyOps* proxyOps; void** sharedDevMems; struct ncclIpcSocket peerIpcSock; // cuMEM API support (UDS) + uint64_t *peerAddressesUDS; // cuMem API support (UDS) // Progress thread struct ncclProxyProgressState progressState; @@ -262,9 +287,9 @@ enum proxyMode { }; ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* proxyOp, bool *justInquire); -ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* proxyOp); +ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* proxyOp, int reg); ncclResult_t ncclProxyStart(struct ncclComm* comm); -ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses); +ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses, uint64_t *peerAddressesUDS); ncclResult_t ncclProxyCreate(struct ncclComm* comm); ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int proxyRank, struct ncclProxyConnector* proxyConn); enum ncclProxyMsgType { @@ -288,7 +313,8 @@ ncclResult_t ncclProxyCallAsync(struct ncclComm* comm, struct ncclProxyConnector ncclResult_t ncclProxyCallBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId); -ncclResult_t ncclProxyClientGetFdBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void *handle, int* convertedFd); +// UDS support +ncclResult_t ncclProxyClientGetFdBlocking(struct ncclComm* comm, int rank, void *handle, int* convertedFd); ncclResult_t ncclProxyStop(struct ncclComm* comm); ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm); diff --git a/projects/rccl/src/include/register.h b/projects/rccl/src/include/register.h new file mode 100644 index 0000000000..2fb387f7c9 --- /dev/null +++ b/projects/rccl/src/include/register.h @@ -0,0 +1,42 @@ +#ifndef NCCL_REGISTER_H_ +#define NCCL_REGISTER_H_ + +enum { + NET_REG_COMPLETE = 0x01, + NVLS_REG_COMPLETE = 0x02, + NVLS_REG_POSSIBLE = 0x04, + NVLS_REG_NO_SUPPORT = 0x08 +}; + +struct ncclReg { + // common attributes + size_t pages; + int refs; + uintptr_t addr; + uint32_t state; + // net reg + int nDevs; + int devs[MAXCHANNELS]; + void** handles; + // nvls reg + uintptr_t baseAddr; + size_t baseSize; + CUdeviceptr regAddr; + size_t regSize; + int dev; + CUmemGenericAllocationHandle mcHandle; + uintptr_t caddrs[NCCL_MAX_LOCAL_RANKS]; /* use to check if NVLS buffers match among intra-node ranks */ +}; + +struct ncclRegCache { + struct ncclReg **slots; + int capacity, population; + uintptr_t pageSize; + void* sComms[MAXCHANNELS]; + void* rComms[MAXCHANNELS]; +}; + +ncclResult_t ncclRegCleanup(struct ncclComm* comm); +ncclResult_t ncclRegFind(struct ncclComm* comm, const void* data, size_t size, struct ncclReg** reg); + +#endif diff --git a/projects/rccl/src/include/shm.h b/projects/rccl/src/include/shm.h index e75caa6a6e..1db16662d5 100644 --- a/projects/rccl/src/include/shm.h +++ b/projects/rccl/src/include/shm.h @@ -18,6 +18,7 @@ struct ncclShmemCollBuff { volatile size_t *cnt[2]; volatile void *ptr[2]; int round; + size_t maxTypeSize; }; ncclResult_t ncclShmemAllgather(struct ncclComm *comm, struct ncclShmemCollBuff *shmem, void *sendbuff, void *recvbuff, size_t typeSize); diff --git a/projects/rccl/src/include/transport.h b/projects/rccl/src/include/transport.h index 27529df5e6..a21114807e 100644 --- a/projects/rccl/src/include/transport.h +++ b/projects/rccl/src/include/transport.h @@ -43,6 +43,8 @@ struct ncclPeerInfo { int64_t busId; struct ncclComm* comm; int cudaCompCap; + // MNNVL support + nvmlGpuFabricInfoV_t fabricInfo; }; #define CONNECT_SIZE 128 diff --git a/projects/rccl/src/include/utils.h b/projects/rccl/src/include/utils.h index 60f6efb5f8..cfc0098610 100644 --- a/projects/rccl/src/include/utils.h +++ b/projects/rccl/src/include/utils.h @@ -30,6 +30,11 @@ uint64_t getHostHash(); uint64_t getPidHash(); ncclResult_t getRandomData(void* buffer, size_t bytes); +const char* ncclOpToString(ncclRedOp_t op); +const char* ncclDatatypeToString(ncclDataType_t type); +const char* ncclAlgoToString(int algo); +const char* ncclProtoToString(int proto); + struct netIf { char prefix[64]; int port; @@ -394,6 +399,36 @@ void ncclIntruQueueFreeAll(ncclIntruQueue *me, ncclMemoryPool *pool) { } } +/* cmp function determines the sequence of objects in the queue. If cmp returns value >= 0, it means a > b, + * and we should put a before b; otherwise, b should be put ahead of a. */ +template +inline void ncclIntruQueueSortEnqueue(ncclIntruQueue *me, T *x, int (*cmp)(T *a, T *b)) { + T *cur = me->head; + T *prev = NULL; + + if (cur == NULL) { + x->*next = nullptr; + me->tail = me->head = x; + } else { + while (cur) { + if (cmp(cur, x) > 0) { + prev = cur; + cur = cur->next; + } else { + break; + } + } + + x->*next = cur; + if (prev) { + prev->*next = x; + if (cur == NULL) me->tail = x; + } else { + me->head = x; + } + } +} + //////////////////////////////////////////////////////////////////////////////// constexpr ncclThreadSignal ncclThreadSignalStaticInitializer() { diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index e82e64e148..39d0213b36 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -180,6 +180,10 @@ static ncclResult_t commFree(ncclComm_t comm) { * resource cleanup in commFree(). */ if (comm->proxyState && comm->proxyRefCountOld == 0 && comm->proxyState->thread) { pthread_join(comm->proxyState->thread, nullptr); + if (comm->proxyState->threadUDS) { + // UDS support + pthread_join(comm->proxyState->threadUDS, nullptr);; + } } delete[] comm->userRedOps; @@ -238,17 +242,7 @@ static ncclResult_t commFree(ncclComm_t comm) { free(comm->topParentRanks); free(comm->topParentLocalRanks); - while (!ncclIntruQueueEmpty(&comm->regRecordQueue)) { - struct ncclRegRecord* rec = ncclIntruQueueDequeue(&comm->regRecordQueue); - NCCLCHECK(ncclNvlsDeregBuffer(&rec->mcHandle, rec->regAddr, rec->dev, rec->regSize)); - free(rec->addrs); - free(rec); - } - - while (!ncclIntruQueueEmpty(&comm->regRequestQueue)) { - struct ncclRegRequest* req = ncclIntruQueueDequeue(&comm->regRequestQueue); - free(req); - } + NCCLCHECK(ncclRegCleanup(comm)); commPoison(comm); // poison comm before free to avoid comm reuse. free(comm); @@ -256,7 +250,6 @@ static ncclResult_t commFree(ncclComm_t comm) { return ncclSuccess; } -NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2); NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0); // GDRCOPY support: FIFO_ENABLE when enabled locates a workFifo in CUDA memory NCCL_PARAM(GdrCopyFifoEnable, "GDRCOPY_FIFO_ENABLE", 1); @@ -288,7 +281,7 @@ ncclResult_t ncclCommEnsureReady(ncclComm_t comm) { /* comm must be ready, or error will be reported */ ncclResult_t ret = ncclSuccess; - if (*comm->abortFlag) { + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED)) { ncclGroupJobAbort(comm->groupJob); } else { NCCLCHECK(ncclCommGetAsyncError(comm, &ret)); @@ -361,7 +354,6 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in comm->groupNext = reinterpret_cast(0x1); comm->preconnectNext = reinterpret_cast(0x1); - comm->channelSize = ncclParamAggChannelSize(); static_assert(MAXCHANNELS <= sizeof(*comm->connectSend)*8, "comm->connectSend must have enough bits for all channels"); static_assert(MAXCHANNELS <= sizeof(*comm->connectRecv)*8, "comm->connectRecv must have enough bits for all channels"); @@ -393,9 +385,9 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in comm->topParentRanks[i] = i; } - ncclIntruQueueConstruct(&comm->regRequestQueue); - ncclIntruQueueConstruct(&comm->regRecordQueue); ncclIntruQueueMpscConstruct(&comm->callbackQueue); + + comm->regCache.pageSize = sysconf(_SC_PAGESIZE); return ncclSuccess; } @@ -411,6 +403,8 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { comm->devComm = &devCommAndChans->comm; tmpCommAndChans.comm.rank = comm->rank; tmpCommAndChans.comm.nRanks = nRanks; + tmpCommAndChans.comm.node = comm->node; + tmpCommAndChans.comm.nNodes = comm->nNodes; tmpCommAndChans.comm.abortFlag = comm->abortFlag; for (int p=0; p < NCCL_NUM_PROTOCOLS; p++) { tmpCommAndChans.comm.buffSizes[p] = comm->buffSizes[p]; @@ -443,6 +437,12 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { comm->workFifoSent = 0; comm->workFifoAckdMin = 0; + if (comm->collNetDenseToUserRank != nullptr) { + NCCLCHECKGOTO(ncclCudaCallocAsync(&tmpCommAndChans.comm.collNetDenseToUserRank, nRanks, comm->sharedRes->deviceStream.cudaStream), ret, fail); + ncclCommPushCudaFree(comm, tmpCommAndChans.comm.collNetDenseToUserRank); + NCCLCHECKGOTO(ncclCudaMemcpyAsync(tmpCommAndChans.comm.collNetDenseToUserRank, comm->collNetDenseToUserRank, nRanks, comm->sharedRes->deviceStream.cudaStream), ret, fail); + } + for (int c=0; c < MAXCHANNELS; c++) { tmpCommAndChans.channels[c].peers = comm->channels[c].devPeers; tmpCommAndChans.channels[c].ring = comm->channels[c].ring; @@ -499,6 +499,24 @@ static ncclResult_t fillInfo(struct ncclComm* comm, struct ncclPeerInfo* info, u NCCLCHECK(ncclGpuGdrSupport(comm, &info->gdrSupport)); info->comm = comm; info->cudaCompCap = comm->minCompCap = comm->maxCompCap = comm->compCap; + + // MNNVL support + { + // MNNVL: Request the fabric UUID and partition info + char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + nvmlDevice_t nvmlDev; + NCCLCHECK(int64ToBusId(info->busId, busId)); + NCCLCHECK(ncclNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev)); + info->fabricInfo.state = NVML_GPU_FABRIC_STATE_NOT_SUPPORTED; + (void) ncclNvmlDeviceGetGpuFabricInfoV(nvmlDev, &info->fabricInfo); + if (info->fabricInfo.state != NVML_GPU_FABRIC_STATE_NOT_SUPPORTED) { + INFO(NCCL_INIT, "MNNVL busId 0x%lx fabric UUID %lx.%lx cliqueId 0x%x state %d healthMask 0x%x", + info->busId, + ((long *)&info->fabricInfo.clusterUuid)[0], ((long *)&info->fabricInfo.clusterUuid)[1], + info->fabricInfo.cliqueId, info->fabricInfo.state, info->fabricInfo.healthMask); + } + } + return ncclSuccess; } @@ -542,8 +560,9 @@ static ncclResult_t computeBuffSizes(struct ncclComm* comm) { comm->buffSizes[p] = envs[p] != -2 ? envs[p] : defaults[p]; } - if (comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize(); - else if (ncclTopoPathAllNVLink(comm->topo)) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); + // MNNVL support + if (!comm->MNNVL && comm->nNodes > 1) comm->p2pChunkSize = ncclParamP2pNetChunkSize(); + else if (comm->MNNVL || ncclTopoPathAllNVLink(comm->topo)) comm->p2pChunkSize = ncclParamP2pNvlChunkSize(); else comm->p2pChunkSize = ncclParamP2pPciChunkSize(); // Make sure P2P chunksize is not larger than coll chunksize. @@ -573,6 +592,8 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, ncclComm_t parent, struct n int highestTypes[NCCL_MAX_LOCAL_RANKS] = { TRANSPORT_P2P }; // Find all head ranks int nHeads = collNetGraph->nChannels; + int nHeadsUnique = 0; + int headsUnique[NCCL_MAX_LOCAL_RANKS]; int highestTransportType0, highestTransportType1; char line[1024]; bool share; @@ -584,13 +605,20 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, ncclComm_t parent, struct n struct collnetShareInfo* infos = NULL; NCCLCHECKGOTO(ncclCalloc(&heads, nHeads), ret, fail); - // Head GPU index is always 0 - for (int c = 0; c < nHeads; c++) { - heads[c] = collNetGraph->intra[c * comm->localRanks + 0]; + { uint64_t mask = 0; + // Head GPU index is always 0 + for (int c = 0; c < nHeads; c++) { + heads[c] = collNetGraph->intra[c * comm->localRanks + 0]; + assert(comm->rankToNode[heads[c]] == comm->node); + uint64_t mask0 = mask; + mask |= 1ull<rankToLocalRank[heads[c]]; + if (mask != mask0) headsUnique[nHeadsUnique++] = heads[c]; + } } comm->collNetHeads = heads; comm->collNetHeadsNum = nHeads; + comm->collNetHeadsUniqueNum = nHeadsUnique; if (parent && parent->collNetSupport && parent->config.splitShare && parent->nNodes == comm->nNodes) { NCCLCHECKGOTO(ncclCalloc(&infos, comm->nRanks), ret, fail); /* check whether child can share collnet resources of parent. Since parent builds each collnet communicator @@ -651,6 +679,26 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, ncclComm_t parent, struct n NCCLCHECK(ncclCalloc(&comm->collNetSharedRes, 1)); comm->collNetChannels = comm->collNetSharedRes->nChannels = comm->nChannels; comm->collNetSharedRes->buffSize = comm->buffSizes[NCCL_PROTO_SIMPLE]; + + comm->collNetDenseToUserRank = ncclMemoryStackAlloc(&comm->memPermanent, comm->nRanks); + comm->collNetUserToDenseRank = ncclMemoryStackAlloc(&comm->memPermanent, comm->nRanks); + { // initialize collNetUserToDenseRank[rank] + uint64_t nonHeadMask = (1ull<localRanks)-1; + comm->collNetUserToDenseRank[rank] = -1; + for (int h=0; h < nHeadsUnique; h++) { + nonHeadMask ^= 1ull<rankToLocalRank[headsUnique[h]]; + if (headsUnique[h] == rank) { comm->collNetUserToDenseRank[rank] = h; break; } + } + if (comm->collNetUserToDenseRank[rank] == -1) { + comm->collNetUserToDenseRank[rank] = __builtin_popcountll(nonHeadMask & ((1ull<localRank)-1)); + } + comm->collNetUserToDenseRank[rank] += comm->node*comm->localRanks; + } + NCCLCHECK(bootstrapAllGather(comm->bootstrap, comm->collNetUserToDenseRank, sizeof(int))); + for (int r=0; r < comm->nRanks; r++) { + comm->collNetDenseToUserRank[comm->collNetUserToDenseRank[r]] = r; + } + for (int c = 0; c < comm->collNetChannels; c++) { struct ncclChannel* channel = comm->channels + c; NCCLCHECKGOTO(initCollnetChannel(comm, c, parent, false), ret, fail); @@ -768,6 +816,9 @@ fail: goto exit; } +// MNNVL: Flag to indicate whether to enable Multi-Node NVLink +NCCL_PARAM(MNNVL, "MNNVL", -2); + static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* parent = NULL) { // We use 2 AllGathers // 1. { peerInfo, comm, compCap} @@ -822,6 +873,56 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p } // AllGather1 - end +#if CUDART_VERSION >= 11030 + +#include +#include "cudawrap.h" + + // MNNVL support + { + int cliqueSize = 0; + comm->MNNVL = 0; + // Determine the size of the MNNVL domain/clique + for (int i = 0; i < nranks; i++) { + nvmlGpuFabricInfoV_t *fabricInfo1 = &comm->peerInfo[rank].fabricInfo; + nvmlGpuFabricInfoV_t *fabricInfo2 = &comm->peerInfo[i].fabricInfo; + // Check that the Fabric state is fully initialized + if (fabricInfo2->state != NVML_GPU_FABRIC_STATE_COMPLETED) continue; + // Check that the cluster UUID and cliqueId match in each rank + // A zero UUID means we don't have MNNVL fabric info - disable MNNVL + if ((((long *)&fabricInfo2->clusterUuid)[0]|((long *)fabricInfo2->clusterUuid)[1]) == 0) continue; + if ((memcmp(fabricInfo1->clusterUuid, fabricInfo2->clusterUuid, NVML_GPU_FABRIC_UUID_LEN) == 0) && + (fabricInfo1->cliqueId == fabricInfo2->cliqueId)) { + cliqueSize++; + } + } + // Determine whether this is a MNNVL system + comm->MNNVL = ncclParamMNNVL() < 0 ? cliqueSize == comm->nRanks : ncclParamMNNVL(); + // MNNVL requires cuMem to be enabled + if (!ncclCuMemEnable()) comm->MNNVL = 0; + if (comm->MNNVL) { + // MNNVL also requires FABRIC handle support + int cudaDev; + int flag = 0; + CUdevice currentDev; + CUDACHECK(cudaGetDevice(&cudaDev)); + CUCHECK(cuDeviceGet(¤tDev, cudaDev)); + // Ignore error if CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED is not supported + (void) CUPFN(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev));; + if (!flag) + comm->MNNVL = 0; + else + // Force the handle type to be FABRIC for MNNVL + ncclCuMemHandleType = CU_MEM_HANDLE_TYPE_FABRIC; + } + if (ncclParamMNNVL() == 1 && !comm->MNNVL) { + WARN("MNNVL is not supported on this system"); + ret = ncclSystemError; + goto fail; + } + } +#endif + do { // Compute intra-process ranks int intraProcRank0 = -1, intraProcRank = -1, intraProcRanks = 0; @@ -1019,6 +1120,9 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p goto fail; } + INFO(NCCL_INIT, "comm %p rank %d nRanks %d nNodes %d localRanks %d localRank %d MNNVL %d", + comm, rank, comm->nRanks, comm->nNodes, comm->localRanks, comm->localRank, comm->MNNVL); + nChannelsOrig = comm->nChannels; NCCLCHECKGOTO(ncclCalloc(&allTopoRanks, comm->nRanks), ret, fail); for (int i=0; itopParentLocalRanks = topParentLocalRanks; // Launch proxy service thread, after this, the proxy calls can be used. - NCCLCHECKGOTO(ncclProxyCreate(comm), ret, fail); + if (parent && parent->config.splitShare) { + comm->proxyState = parent->sharedRes->proxyState; + ncclAtomicRefCountIncrement(&parent->sharedRes->proxyState->refCount); + } else { + NCCLCHECKGOTO(ncclProxyCreate(comm), ret, fail); + } // Connect with prev/next for each ring for (int c=0; cnChannels; c++) { @@ -1124,8 +1233,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p // Setup NVLS NCCLCHECKGOTO(ncclNvlsSetup(comm, parent), ret, fail); // And NVLS trees if needed - if (comm->nvlsSupport && comm->localRanks > 1) { - for (int c=0; cnvlsChannels; c++) { + if (comm->nvlsSupport && comm->nNodes > 1) { + for (int c=0; cnChannels; c++) { struct ncclChannel* channel = comm->channels+c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_NVLS_TREE_ARITY, channel->nvls.treeDown, 1, &channel->nvls.treeUp, 0), ret, fail); NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->nvls.treeUp, NCCL_MAX_NVLS_TREE_ARITY, channel->nvls.treeDown, 0), ret, fail); @@ -1142,7 +1251,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p // Compute time models for algorithm and protocol combinations NCCLCHECKGOTO(ncclTopoTuneModel(comm, comm->minCompCap, comm->maxCompCap, graphs), ret, fail); - INFO(NCCL_INIT, "%d coll channels, %d nvls channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); + INFO(NCCL_INIT, "%d coll channels, %d collnet channels, %d nvls channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->collNetChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); do { // Setup p2p structures in comm->tasks struct ncclTasks* tasks = &comm->tasks; @@ -1376,14 +1485,15 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { if (job->color == NCCL_SPLIT_NOCOLOR) goto exit; snprintf((char*)&job->commId, sizeof(job->commId), "%016lx-%d", job->parent->commHash, job->color); NCCLCHECKGOTO(commAlloc(comm, job->parent, job->nranks, job->myrank), res, fail); + comm->commHash = getHash(job->commId.internal, NCCL_UNIQUE_ID_BYTES); // Needed for UDS support NCCLCHECKGOTO(bootstrapSplit((struct ncclBootstrapHandle*)&job->commId, comm, job->parent, job->color, job->key, parentRanks), res, fail); } else { NCCLCHECKGOTO(commAlloc(comm, NULL, job->nranks, job->myrank), res, fail); + comm->commHash = getHash(job->commId.internal, NCCL_UNIQUE_ID_BYTES); // Needed for UDS support NCCLCHECKGOTO(bootstrapInit((struct ncclBootstrapHandle*)&job->commId, comm), res, fail); } comm->cudaArch = cudaArch; - comm->commHash = getHash(job->commId.internal, NCCL_UNIQUE_ID_BYTES); INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d nvmlDev %d busId %lx commId 0x%llx - Init START", comm, comm->rank, comm->nRanks, comm->cudaDev, comm->nvmlDev, comm->busId, (unsigned long long)hashUniqueId(job->commId)); @@ -1886,7 +1996,7 @@ static ncclResult_t commReclaim(ncclComm_t comm) { NCCLCHECKGOTO(ncclCommGetAsyncError(comm, &state), ret, fail); TRACE(NCCL_INIT, "commReclaim: reclaim comm %p rank %d state %d", comm, comm->rank, state); - if (state == ncclSuccess && *comm->abortFlag == 0 && comm->finalizeCalled == false) { + if (state == ncclSuccess && __atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED) == 0 && comm->finalizeCalled == false) { /* user does not call ncclCommFinalize and this is a normal comm destroy. ncclCommDestroy * should be nonblocking until last call of ncclCommDestroy. */ NCCLCHECKGOTO(commFinalize(comm, false), ret, fail); @@ -2011,9 +2121,9 @@ ncclResult_t ncclCommAbort(ncclComm_t comm) { // Ask anything that might still be running on the device to quit childAbortFlag = __atomic_load_n(&comm->childAbortFlag, __ATOMIC_ACQUIRE); if (childAbortFlag != NULL) { - *childAbortFlag = 1; + __atomic_store_n(childAbortFlag, 1, __ATOMIC_RELAXED); } - *comm->abortFlag = 1; + __atomic_store_n(comm->abortFlag, 1, __ATOMIC_RELAXED); /* init thread must be joined before we destroy the comm, * and we should ignore the init error here. */ ncclCommEnsureReady(comm); @@ -2161,98 +2271,6 @@ ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) { return ncclSuccess; } -NCCL_PARAM(LocalRegister, "LOCAL_REGISTER", 1); - -NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size_t size, void** handle); -ncclResult_t ncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle) { - NVTX3_FUNC_RANGE_IN(nccl_domain); - ncclResult_t ret = ncclSuccess; - -#if CUDART_VERSION >= 12010 - size_t granularity; - if (ncclParamLocalRegister()) { - if (comm == NCCL_COMM_NULL || buff == NULL || handle == NULL || size == 0) { - WARN("Invalid arguments comm %p, buff %p, size %ld, handle %p", comm, buff, size, handle); - ret = ncclInvalidArgument; - } else if (comm->nvlsSupport) { - CUmulticastObjectProp prop = comm->nvlsResources->properties; - - prop.size = size; - CUCHECK(cuMulticastGetGranularity(&granularity, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); - - if ((uintptr_t)buff % comm->nvlsResources->ucGran == 0 && size % granularity == 0) { - /* we can direct register what user provide */ - struct ncclRegRequest* req; - NCCLCHECK(ncclCalloc(&req, 1)); - req->buff = (uintptr_t)buff; - req->size = size; - ncclIntruQueueEnqueue(&comm->regRequestQueue, req); - *handle = (void*)req; - } else { - void* base; - size_t baseSize; - /* Since we don't provide actually allocated buffer size for users by ncclMemAlloc, - * therefore, we need to get the full range of the buffer by cuMemGetAddressRange to - * register buffers. */ - CUCHECK(cuMemGetAddressRange((CUdeviceptr*)&base, &baseSize, (CUdeviceptr)buff)); - if ((uintptr_t)base % comm->nvlsResources->ucGran == 0 && baseSize % granularity == 0) { - struct ncclRegRequest* req; - NCCLCHECK(ncclCalloc(&req, 1)); - req->buff = (uintptr_t)base; - req->size = baseSize; - ncclIntruQueueEnqueue(&comm->regRequestQueue, req); - *handle = (void*)req; - } else { - WARN("register fails, buffer %p (aligned %s, granularity %ld) and size %ld (aligned %s, granularity %ld) for registration", buff, (uintptr_t)buff % comm->nvlsResources->ucGran == 0 ? "TRUE" : "FALSE", comm->nvlsResources->ucGran, size, size % granularity == 0 ? "TRUE" : "FALSE", granularity); - ret = ncclInvalidArgument; - } - } - } - } -#endif - - return ret; -} - -NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); -ncclResult_t ncclCommDeregister(const ncclComm_t comm, void* handle) { - ncclResult_t ret = ncclSuccess; - -#if CUDART_VERSION >= 12010 - struct ncclRegRequest* dreq = (struct ncclRegRequest*)handle; - if (ncclParamLocalRegister()) { - if (comm == NCCL_COMM_NULL || handle == NULL) { - WARN("Invalid arguments comm %p, handle %p", comm, handle); - ret = ncclInvalidArgument; - } else { - struct ncclRegRecord* rec; - - /* first release register record */ - rec = ncclIntruQueueHead(&comm->regRecordQueue); - - while (rec) { - if (rec->buff == dreq->buff && rec->size == dreq->size) { - NCCLCHECK(ncclNvlsDeregBuffer(&rec->mcHandle, rec->regAddr, rec->dev, rec->regSize)); - ncclIntruQueueDelete(&comm->regRecordQueue, rec); - free(rec->addrs); - free(rec); - break; - } - rec = rec->next; - } - - /* then free register request */ - if (ncclIntruQueueDelete(&comm->regRequestQueue, dreq) == false) { - WARN("Invalid handle %p", handle); - ret = ncclInvalidArgument; - } - } - } -#endif - - return ret; -} - NCCL_API(ncclResult_t, ncclMemAlloc, void **ptr, size_t size); ncclResult_t ncclMemAlloc(void **ptr, size_t size) { NVTX3_FUNC_RANGE_IN(nccl_domain); diff --git a/projects/rccl/src/misc/argcheck.cc b/projects/rccl/src/misc/argcheck.cc index 994d1fd9b1..c5909337d2 100644 --- a/projects/rccl/src/misc/argcheck.cc +++ b/projects/rccl/src/misc/argcheck.cc @@ -7,7 +7,7 @@ #include "argcheck.h" #include "comm.h" -static ncclResult_t CudaPtrCheck(const void* pointer, struct ncclComm* comm, const char* ptrname, const char* opname) { +ncclResult_t CudaPtrCheck(const void* pointer, struct ncclComm* comm, const char* ptrname, const char* opname) { cudaPointerAttributes attr; cudaError_t err = cudaPointerGetAttributes(&attr, pointer); if (err != cudaSuccess || attr.devicePointer == NULL) { diff --git a/projects/rccl/src/misc/cudawrap.cc b/projects/rccl/src/misc/cudawrap.cc index f2260a1c0e..8ccc7a876d 100644 --- a/projects/rccl/src/misc/cudawrap.cc +++ b/projects/rccl/src/misc/cudawrap.cc @@ -14,6 +14,9 @@ // This env var (NCCL_CUMEM_ENABLE) toggles cuMem API usage NCCL_PARAM(CuMemEnable, "CUMEM_ENABLE", -2); +// Handle type used for cuMemCreate() +CUmemAllocationHandleType ncclCuMemHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + static int ncclCuMemSupported = 0; // Determine whether CUMEM & VMM RDMA is supported on this platform diff --git a/projects/rccl/src/misc/ipcsocket.cc b/projects/rccl/src/misc/ipcsocket.cc index 9d66ac7197..fc7fd4b66a 100644 --- a/projects/rccl/src/misc/ipcsocket.cc +++ b/projects/rccl/src/misc/ipcsocket.cc @@ -132,7 +132,7 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, WARN("UDS: Receiving data over socket failed : %d", errno); return ncclSystemError; } - if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + if (handle->abortFlag && __atomic_load_n(handle->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; } if (recvFd != NULL) { @@ -221,7 +221,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, strerror(errno), errno); return ncclSystemError; } - if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + if (handle->abortFlag && __atomic_load_n(handle->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; } return ncclSuccess; diff --git a/projects/rccl/src/misc/nvmlwrap.cc b/projects/rccl/src/misc/nvmlwrap.cc index 2de993a6e5..76c989e76a 100644 --- a/projects/rccl/src/misc/nvmlwrap.cc +++ b/projects/rccl/src/misc/nvmlwrap.cc @@ -39,6 +39,8 @@ namespace { NCCL_NVML_FN(nvmlDeviceGetCudaComputeCapability, nvmlReturn_t, (nvmlDevice_t device, int* major, int* minor)) NCCL_NVML_FN(nvmlDeviceGetP2PStatus, nvmlReturn_t, (nvmlDevice_t device1, nvmlDevice_t device2, nvmlGpuP2PCapsIndex_t p2pIndex, nvmlGpuP2PStatus_t* p2pStatus)) NCCL_NVML_FN(nvmlDeviceGetFieldValues, nvmlReturn_t, (nvmlDevice_t device, int valuesCount, nvmlFieldValue_t *values)) + // MNNVL support + NCCL_NVML_FN(nvmlDeviceGetGpuFabricInfoV, nvmlReturn_t, (nvmlDevice_t device, nvmlGpuFabricInfoV_t *gpuFabricInfo)) std::mutex lock; // NVML has had some thread safety bugs bool initialized = false; @@ -82,7 +84,9 @@ ncclResult_t ncclNvmlEnsureInitialized() { {(void**)&pfn_nvmlDeviceGetNvLinkCapability, "nvmlDeviceGetNvLinkCapability"}, {(void**)&pfn_nvmlDeviceGetCudaComputeCapability, "nvmlDeviceGetCudaComputeCapability"}, {(void**)&pfn_nvmlDeviceGetP2PStatus, "nvmlDeviceGetP2PStatus"}, - {(void**)&pfn_nvmlDeviceGetFieldValues, "nvmlDeviceGetFieldValues"} + {(void**)&pfn_nvmlDeviceGetFieldValues, "nvmlDeviceGetFieldValues"}, + // MNNVL support + {(void**)&pfn_nvmlDeviceGetGpuFabricInfoV, "nvmlDeviceGetGpuFabricInfoV"}, }; for(Symbol sym: symbols) { *sym.ppfn = dlsym(libhandle, sym.name); @@ -269,3 +273,12 @@ ncclResult_t ncclNvmlDeviceGetFieldValues(nvmlDevice_t device, int valuesCount, NVMLTRY(nvmlDeviceGetFieldValues, device, valuesCount, values); return ncclSuccess; } + +// MNNVL support +ncclResult_t ncclNvmlDeviceGetGpuFabricInfoV(nvmlDevice_t device, nvmlGpuFabricInfoV_t *gpuFabricInfo) { + NCCLCHECK(ncclNvmlEnsureInitialized()); + std::lock_guard locked(lock); + gpuFabricInfo->version = nvmlGpuFabricInfo_v2; + NVMLTRY(nvmlDeviceGetGpuFabricInfoV, device, gpuFabricInfo); + return ncclSuccess; +} diff --git a/projects/rccl/src/misc/shmutils.cc b/projects/rccl/src/misc/shmutils.cc index 80ece40c1c..04f7c10be7 100644 --- a/projects/rccl/src/misc/shmutils.cc +++ b/projects/rccl/src/misc/shmutils.cc @@ -169,7 +169,7 @@ ncclResult_t ncclShmemAllgather(struct ncclComm *comm, struct ncclShmemCollBuff int curRound = shmem->round; size_t mycnt; - if (comm == NULL || shmem == NULL || sendbuff == NULL || recvbuff == NULL) { + if (comm == NULL || shmem == NULL || sendbuff == NULL || recvbuff == NULL || shmem->maxTypeSize < typeSize) { ret = ncclInvalidArgument; goto exit; } @@ -184,7 +184,7 @@ ncclResult_t ncclShmemAllgather(struct ncclComm *comm, struct ncclShmemCollBuff uint64_t t0 = clockNano(); while(__atomic_load_n(shmem->cnt[curRound], __ATOMIC_ACQUIRE) != comm->localRanks + 1) { if (clockNano() - t0 >= 5 * 1000) sched_yield(); - if (*comm->abortFlag == 1) { + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED) == 1) { ret = ncclInternalError; goto exit; } diff --git a/projects/rccl/src/misc/socket.cc b/projects/rccl/src/misc/socket.cc index 149bd73aa1..3aeed6c578 100644 --- a/projects/rccl/src/misc/socket.cc +++ b/projects/rccl/src/misc/socket.cc @@ -34,7 +34,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr } } (*offset) += bytes; - if (sock->abortFlag && *sock->abortFlag != 0) { + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) { INFO(NCCL_NET, "socketProgressOpt: abort called"); return ncclInternalError; } @@ -529,6 +529,8 @@ static ncclResult_t socketPollConnect(struct ncclSocket* sock) { sock->state = ncclSocketStateConnecting; } else if (ret != EINPROGRESS) { sock->state = ncclSocketStateError; + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketPollConnect: Connect to %s returned %d(%s) errno %d(%s)", ncclSocketToString(&sock->addr, line), ret, strerror(ret), errno, strerror(errno)); return ncclSystemError; } return ncclSuccess; @@ -618,12 +620,12 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { do { NCCLCHECK(socketProgressState(sock)); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && (sock->state == ncclSocketStateConnecting || sock->state == ncclSocketStateConnectPolling || sock->state == ncclSocketStateConnected)); - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; switch (sock->state) { case ncclSocketStateConnecting: @@ -665,11 +667,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen do { NCCLCHECKGOTO(socketProgressState(sock), ret, exit); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || *sock->abortFlag == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && (sock->state == ncclSocketStateAccepting || sock->state == ncclSocketStateAccepted)); - if (sock->abortFlag && *sock->abortFlag != 0) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; switch (sock->state) { case ncclSocketStateAccepting: diff --git a/projects/rccl/src/misc/tuner.cc b/projects/rccl/src/misc/tuner.cc index bfe61e8c1d..8f5b2ce345 100644 --- a/projects/rccl/src/misc/tuner.cc +++ b/projects/rccl/src/misc/tuner.cc @@ -30,25 +30,25 @@ ncclResult_t ncclLoadTunerPlugin(ncclTuner_t** tuner) { if (name) { INFO(NCCL_TUNING, "NCCL_TUNER_PLUGIN set to %s", name); tunerPluginLib = dlopen(name, RTLD_LAZY | RTLD_LOCAL); - } - if (tunerPluginLib == nullptr) { - // dlopen does not guarantee to set errno, but dlerror only gives us a - // string, so checking errno doesn't hurt to try to provide a better - // error message - if (errno == ENOENT) { - INFO(NCCL_TUNING, "Tuner: no plugin found '%s', using default tuner instead.", name); + if (tunerPluginLib == nullptr) { + // dlopen does not guarantee to set errno, but dlerror only gives us a + // string, so checking errno doesn't hurt to try to provide a better + // error message + if (errno == ENOENT) { + INFO(NCCL_TUNING, "Tuner: no plugin found '%s', using default tuner instead.", name); + } else { + INFO(NCCL_TUNING, "Tuner: plugin load '%s' returned error (%d : %s), using default tuner instead.", name, errno, dlerror()); + } } else { - INFO(NCCL_TUNING, "Tuner: plugin load '%s' returned error (%d : %s), using default tuner instead.", name, errno, dlerror()); - } - } else { - tunerSymbol = (ncclTuner_t*)dlsym(tunerPluginLib, NCCL_TUNER_PLUGIN_SYMBOL); - if (tunerSymbol == nullptr) { - INFO(NCCL_TUNING, "Tuner: failed to find " NCCL_TUNER_PLUGIN_SYMBOL " in plugin (%s), using default tuner instead.", name); - dlclose(tunerPluginLib); - tunerPluginLib = nullptr; - } else { - INFO(NCCL_TUNING, "Opened tuner: '%s'", tunerSymbol->name); - tunerPluginRefCount = 0; + tunerSymbol = (ncclTuner_t*)dlsym(tunerPluginLib, NCCL_TUNER_PLUGIN_SYMBOL); + if (tunerSymbol == nullptr) { + INFO(NCCL_TUNING, "Tuner: failed to find " NCCL_TUNER_PLUGIN_SYMBOL " in plugin (%s), using default tuner instead.", name); + dlclose(tunerPluginLib); + tunerPluginLib = nullptr; + } else { + INFO(NCCL_TUNING, "Opened tuner: '%s'", tunerSymbol->name); + tunerPluginRefCount = 0; + } } } } diff --git a/projects/rccl/src/misc/utils.cc b/projects/rccl/src/misc/utils.cc index b775666799..74d5b6d24c 100644 --- a/projects/rccl/src/misc/utils.cc +++ b/projects/rccl/src/misc/utils.cc @@ -291,3 +291,79 @@ void ncclMemoryStackDestruct(struct ncclMemoryStack* me) { h = h1; } } + +const char* ncclOpToString(ncclRedOp_t op) { + switch (op) { + case ncclSum: + return "ncclSum"; + case ncclProd: + return "ncclProd"; + case ncclMax: + return "ncclMax"; + case ncclMin: + return "ncclMin"; + case ncclAvg: + return "ncclAvg"; + default: + return "Unknown"; + } +} + +const char* ncclDatatypeToString(ncclDataType_t type) { + switch (type) { + case ncclInt8: // ncclChar + return "ncclInt8"; + case ncclInt32: // ncclInt + return "ncclInt32"; + case ncclUint32: + return "ncclUint32"; + case ncclInt64: + return "ncclInt64"; + case ncclUint64: + return "ncclUint64"; + case ncclFloat16: // ncclHalf + return "ncclFloat16"; + case ncclFloat32: // ncclFloat + return "ncclFloat32"; + case ncclFloat64: // ncclDouble + return "ncclFloat64"; +#if defined(__CUDA_BF16_TYPES_EXIST__) + case ncclBfloat16: + return "ncclBfloat16"; +#endif + default: + return "Unknown"; + } +} + +const char* ncclAlgoToString(int algo) { + switch (algo) { + case NCCL_ALGO_TREE: + return "TREE"; + case NCCL_ALGO_RING: + return "RING"; + case NCCL_ALGO_COLLNET_DIRECT: + return "COLLNET_DIRECT"; + case NCCL_ALGO_COLLNET_CHAIN: + return "COLLNET_CHAIN"; + case NCCL_ALGO_NVLS: + return "NVLS"; + case NCCL_ALGO_NVLS_TREE: + return "NVLS_TREE"; + default: + return "Unknown"; + } +} + +const char* ncclProtoToString(int proto) { + switch (proto) { + case NCCL_PROTO_LL: + return "LL"; + case NCCL_PROTO_LL128: + return "LL128"; + case NCCL_PROTO_SIMPLE: + return "SIMPLE"; + default: + return "Unknown"; + } +} diff --git a/projects/rccl/src/nccl.h.in b/projects/rccl/src/nccl.h.in index 1585d58acb..901d8c0d28 100644 --- a/projects/rccl/src/nccl.h.in +++ b/projects/rccl/src/nccl.h.in @@ -154,9 +154,7 @@ ncclResult_t pncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *new const char* ncclGetErrorString(ncclResult_t result); const char* pncclGetErrorString(ncclResult_t result); -/* Returns a human-readable message of the last error that occurred. - * comm is currently unused and can be set to NULL - */ +/* Returns a human-readable message of the last error that occurred. */ const char* ncclGetLastError(ncclComm_t comm); const char* pncclGetLastError(ncclComm_t comm); @@ -176,6 +174,15 @@ ncclResult_t pncclCommCuDevice(const ncclComm_t comm, int* device); ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank); ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank); + +/* Register CUDA buffer for zero-copy operation */ +ncclResult_t ncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle); +ncclResult_t pncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle); + +/* Deregister CUDA buffer */ +ncclResult_t ncclCommDeregister(const ncclComm_t comm, void* handle); +ncclResult_t pncclCommDeregister(const ncclComm_t comm, void* handle); + /* Reduction operation selector */ typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t; typedef enum { ncclSum = 0, diff --git a/projects/rccl/src/net.cc b/projects/rccl/src/net.cc index 2bfc9a9277..ba3479282d 100644 --- a/projects/rccl/src/net.cc +++ b/projects/rccl/src/net.cc @@ -15,16 +15,67 @@ //#include //#include -static ncclNet_v7_t ncclNet_v5_as_v7; -static ncclNet_v7_t ncclNet_v6_as_v7; +static ncclNet_v8_t ncclNet_v5_as_v8; +static ncclNet_v8_t ncclNet_v6_as_v8; +static ncclNet_v8_t ncclNet_v7_as_v8; static ncclNet_v5_t *ncclNet_v5; static ncclNet_v6_t *ncclNet_v6; -static ncclCollNet_v7_t ncclCollNet_v5_as_v7; -static ncclCollNet_v7_t ncclCollNet_v6_as_v7; +static ncclNet_v7_t *ncclNet_v7; +static ncclCollNet_v8_t ncclCollNet_v5_as_v8; +static ncclCollNet_v8_t ncclCollNet_v6_as_v8; +static ncclCollNet_v8_t ncclCollNet_v7_as_v8; static ncclCollNet_v5_t *ncclCollNet_v5; static ncclCollNet_v6_t *ncclCollNet_v6; +static ncclCollNet_v7_t *ncclCollNet_v7; -static ncclResult_t ncclNet_v6_as_v7_getProperties(int dev, ncclNetProperties_v7_t* props) { +static ncclResult_t ncclNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { + ncclNetProperties_v7_t p7; + ncclResult_t ans = ncclNet_v7->getProperties(dev, &p7); + if (ans != ncclSuccess) return ans; + props->name = p7.name; + props->pciPath = p7.pciPath; + props->guid = p7.guid; + props->ptrSupport = p7.ptrSupport; + props->regIsGlobal = 0; + props->speed = p7.speed; + props->port = p7.port; + props->maxComms = p7.maxComms; + props->maxRecvs = p7.maxRecvs; + props->latency = p7.latency; + props->netDeviceType = p7.netDeviceType; + props->netDeviceVersion = p7.netDeviceVersion; + return ncclSuccess; +} + +static ncclResult_t ncclNet_v7_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclNet_v7->regMr(comm, data, (int) size, type, mhandle); +} + +static ncclResult_t ncclNet_v7_as_v8_init(ncclDebugLogger_t logfn) { + NCCLCHECK(ncclNet_v7->init(logfn)); + ncclNet_v7_as_v8.name = ncclNet_v7->name; + ncclNet_v7_as_v8.devices = ncclNet_v7->devices; + ncclNet_v7_as_v8.getProperties = ncclNet_v7_as_v8_getProperties; // ncclNet_v5->getProperties; + ncclNet_v7_as_v8.listen = ncclNet_v7->listen; + ncclNet_v7_as_v8.connect = ncclNet_v7->connect; + ncclNet_v7_as_v8.accept = ncclNet_v7->accept; + ncclNet_v7_as_v8.regMr = ncclNet_v7_as_v8_regMr; + ncclNet_v7_as_v8.regMrDmaBuf = ncclNet_v7->regMrDmaBuf; + ncclNet_v7_as_v8.deregMr = ncclNet_v7->deregMr; + ncclNet_v7_as_v8.isend = ncclNet_v7->isend; + ncclNet_v7_as_v8.irecv = ncclNet_v7->irecv; + ncclNet_v7_as_v8.iflush = ncclNet_v7->iflush; + ncclNet_v7_as_v8.test = ncclNet_v7->test; + ncclNet_v7_as_v8.closeSend = ncclNet_v7->closeSend; + ncclNet_v7_as_v8.closeRecv = ncclNet_v7->closeRecv; + ncclNet_v7_as_v8.closeListen = ncclNet_v7->closeListen; + ncclNet_v7_as_v8.getDeviceMr = ncclNet_v7->getDeviceMr; + ncclNet_v7_as_v8.irecvConsumed = ncclNet_v7->irecvConsumed; + return ncclSuccess; +} + +static ncclResult_t ncclNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclNet_v6->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -32,6 +83,7 @@ static ncclResult_t ncclNet_v6_as_v7_getProperties(int dev, ncclNetProperties_v7 props->pciPath = p6.pciPath; props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; + props->regIsGlobal = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -42,38 +94,43 @@ static ncclResult_t ncclNet_v6_as_v7_getProperties(int dev, ncclNetProperties_v7 return ncclSuccess; } -static ncclResult_t ncclNet_v6_as_v7_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { +static ncclResult_t ncclNet_v6_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclNet_v6->regMr(comm, data, (int) size, type, mhandle); +} + +static ncclResult_t ncclNet_v6_as_v8_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { return ncclNet_v6->connect(dev, handle, sendComm); } -static ncclResult_t ncclNet_v6_as_v7_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { +static ncclResult_t ncclNet_v6_as_v8_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { return ncclNet_v6->accept(listenComm, recvComm); } -static ncclResult_t ncclNet_v6_as_v7_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclNet_v6_as_v8_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclNet_v6->init(logfn)); - ncclNet_v6_as_v7.name = ncclNet_v6->name; - ncclNet_v6_as_v7.devices = ncclNet_v6->devices; - ncclNet_v6_as_v7.getProperties = ncclNet_v6_as_v7_getProperties; // ncclNet_v5->getProperties; - ncclNet_v6_as_v7.listen = ncclNet_v6->listen; - ncclNet_v6_as_v7.connect = ncclNet_v6_as_v7_connect; - ncclNet_v6_as_v7.accept = ncclNet_v6_as_v7_accept; - ncclNet_v6_as_v7.regMr = ncclNet_v6->regMr; - ncclNet_v6_as_v7.regMrDmaBuf = ncclNet_v6->regMrDmaBuf; - ncclNet_v6_as_v7.deregMr = ncclNet_v6->deregMr; - ncclNet_v6_as_v7.isend = ncclNet_v6->isend; - ncclNet_v6_as_v7.irecv = ncclNet_v6->irecv; - ncclNet_v6_as_v7.iflush = ncclNet_v6->iflush; - ncclNet_v6_as_v7.test = ncclNet_v6->test; - ncclNet_v6_as_v7.closeSend = ncclNet_v6->closeSend; - ncclNet_v6_as_v7.closeRecv = ncclNet_v6->closeRecv; - ncclNet_v6_as_v7.closeListen = ncclNet_v6->closeListen; - ncclNet_v6_as_v7.getDeviceMr = NULL; - ncclNet_v6_as_v7.irecvConsumed = NULL; + ncclNet_v6_as_v8.name = ncclNet_v6->name; + ncclNet_v6_as_v8.devices = ncclNet_v6->devices; + ncclNet_v6_as_v8.getProperties = ncclNet_v6_as_v8_getProperties; // ncclNet_v5->getProperties; + ncclNet_v6_as_v8.listen = ncclNet_v6->listen; + ncclNet_v6_as_v8.connect = ncclNet_v6_as_v8_connect; + ncclNet_v6_as_v8.accept = ncclNet_v6_as_v8_accept; + ncclNet_v6_as_v8.regMr = ncclNet_v6_as_v8_regMr; + ncclNet_v6_as_v8.regMrDmaBuf = ncclNet_v6->regMrDmaBuf; + ncclNet_v6_as_v8.deregMr = ncclNet_v6->deregMr; + ncclNet_v6_as_v8.isend = ncclNet_v6->isend; + ncclNet_v6_as_v8.irecv = ncclNet_v6->irecv; + ncclNet_v6_as_v8.iflush = ncclNet_v6->iflush; + ncclNet_v6_as_v8.test = ncclNet_v6->test; + ncclNet_v6_as_v8.closeSend = ncclNet_v6->closeSend; + ncclNet_v6_as_v8.closeRecv = ncclNet_v6->closeRecv; + ncclNet_v6_as_v8.closeListen = ncclNet_v6->closeListen; + ncclNet_v6_as_v8.getDeviceMr = NULL; + ncclNet_v6_as_v8.irecvConsumed = NULL; return ncclSuccess; } -static ncclResult_t ncclNet_v5_as_v7_getProperties(int dev, ncclNetProperties_v7_t* props) { +static ncclResult_t ncclNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclNet_v5->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -81,6 +138,7 @@ static ncclResult_t ncclNet_v5_as_v7_getProperties(int dev, ncclNetProperties_v7 props->pciPath = p6.pciPath; props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; + props->regIsGlobal = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -91,40 +149,45 @@ static ncclResult_t ncclNet_v5_as_v7_getProperties(int dev, ncclNetProperties_v7 return ncclSuccess; } -static ncclResult_t ncclNet_v5_as_v7_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { +static ncclResult_t ncclNet_v5_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclNet_v5->regMr(comm, data, (int) size, type, mhandle); +} + +static ncclResult_t ncclNet_v5_as_v8_connect(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { return ncclNet_v5->connect(dev, handle, sendComm); } -static ncclResult_t ncclNet_v5_as_v7_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { +static ncclResult_t ncclNet_v5_as_v8_accept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { return ncclNet_v5->accept(listenComm, recvComm); } // We use a wrapper around the v5 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclNet_v5_as_v7_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclNet_v5_as_v8_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclNet_v5->init(logfn)); - ncclNet_v5_as_v7.name = ncclNet_v5->name; - ncclNet_v5_as_v7.devices = ncclNet_v5->devices; - ncclNet_v5_as_v7.getProperties = ncclNet_v5_as_v7_getProperties; - ncclNet_v5_as_v7.listen = ncclNet_v5->listen; - ncclNet_v5_as_v7.connect = ncclNet_v5_as_v7_connect; - ncclNet_v5_as_v7.accept = ncclNet_v5_as_v7_accept; - ncclNet_v5_as_v7.regMr = ncclNet_v5->regMr; - ncclNet_v5_as_v7.regMrDmaBuf = NULL; - ncclNet_v5_as_v7.deregMr = ncclNet_v5->deregMr; - ncclNet_v5_as_v7.isend = ncclNet_v5->isend; - ncclNet_v5_as_v7.irecv = ncclNet_v5->irecv; - ncclNet_v5_as_v7.iflush = ncclNet_v5->iflush; - ncclNet_v5_as_v7.test = ncclNet_v5->test; - ncclNet_v5_as_v7.closeSend = ncclNet_v5->closeSend; - ncclNet_v5_as_v7.closeRecv = ncclNet_v5->closeRecv; - ncclNet_v5_as_v7.closeListen = ncclNet_v5->closeListen; - ncclNet_v5_as_v7.getDeviceMr = NULL; - ncclNet_v5_as_v7.irecvConsumed = NULL; + ncclNet_v5_as_v8.name = ncclNet_v5->name; + ncclNet_v5_as_v8.devices = ncclNet_v5->devices; + ncclNet_v5_as_v8.getProperties = ncclNet_v5_as_v8_getProperties; + ncclNet_v5_as_v8.listen = ncclNet_v5->listen; + ncclNet_v5_as_v8.connect = ncclNet_v5_as_v8_connect; + ncclNet_v5_as_v8.accept = ncclNet_v5_as_v8_accept; + ncclNet_v5_as_v8.regMr = ncclNet_v5_as_v8_regMr; + ncclNet_v5_as_v8.regMrDmaBuf = NULL; + ncclNet_v5_as_v8.deregMr = ncclNet_v5->deregMr; + ncclNet_v5_as_v8.isend = ncclNet_v5->isend; + ncclNet_v5_as_v8.irecv = ncclNet_v5->irecv; + ncclNet_v5_as_v8.iflush = ncclNet_v5->iflush; + ncclNet_v5_as_v8.test = ncclNet_v5->test; + ncclNet_v5_as_v8.closeSend = ncclNet_v5->closeSend; + ncclNet_v5_as_v8.closeRecv = ncclNet_v5->closeRecv; + ncclNet_v5_as_v8.closeListen = ncclNet_v5->closeListen; + ncclNet_v5_as_v8.getDeviceMr = NULL; + ncclNet_v5_as_v8.irecvConsumed = NULL; return ncclSuccess; } -static ncclResult_t ncclCollNet_v5_as_v7_getProperties(int dev, ncclNetProperties_v7_t* props) { +static ncclResult_t ncclCollNet_v5_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclCollNet_v5->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -132,6 +195,7 @@ static ncclResult_t ncclCollNet_v5_as_v7_getProperties(int dev, ncclNetPropertie props->pciPath = p6.pciPath; props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; + props->regIsGlobal = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -142,28 +206,35 @@ static ncclResult_t ncclCollNet_v5_as_v7_getProperties(int dev, ncclNetPropertie return ncclSuccess; } +static ncclResult_t ncclCollNet_v5_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclCollNet_v5->regMr(comm, data, (int) size, type, mhandle); +} + // We use a wrapper around the v5 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclCollNet_v5_as_v7_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclCollNet_v5_as_v8_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclCollNet_v5->init(logfn)); - ncclCollNet_v5_as_v7.name = ncclCollNet_v5->name; - ncclCollNet_v5_as_v7.devices = ncclCollNet_v5->devices; - ncclCollNet_v5_as_v7.getProperties = ncclCollNet_v5_as_v7_getProperties; - ncclCollNet_v5_as_v7.listen = ncclCollNet_v5->listen; - ncclCollNet_v5_as_v7.connect = ncclCollNet_v5->connect; - ncclCollNet_v5_as_v7.reduceSupport = ncclCollNet_v5->reduceSupport; - ncclCollNet_v5_as_v7.regMr = ncclCollNet_v5->regMr; - ncclCollNet_v5_as_v7.regMrDmaBuf = NULL; - ncclCollNet_v5_as_v7.deregMr = ncclCollNet_v5->deregMr; - ncclCollNet_v5_as_v7.iallreduce = ncclCollNet_v5->iallreduce; - ncclCollNet_v5_as_v7.iflush = ncclCollNet_v5->iflush; - ncclCollNet_v5_as_v7.test = ncclCollNet_v5->test; - ncclCollNet_v5_as_v7.closeColl = ncclCollNet_v5->closeColl; - ncclCollNet_v5_as_v7.closeListen = ncclCollNet_v5->closeListen; + ncclCollNet_v5_as_v8.name = ncclCollNet_v5->name; + ncclCollNet_v5_as_v8.devices = ncclCollNet_v5->devices; + ncclCollNet_v5_as_v8.getProperties = ncclCollNet_v5_as_v8_getProperties; + ncclCollNet_v5_as_v8.listen = ncclCollNet_v5->listen; + ncclCollNet_v5_as_v8.connect = ncclCollNet_v5->connect; + ncclCollNet_v5_as_v8.reduceSupport = ncclCollNet_v5->reduceSupport; + ncclCollNet_v5_as_v8.regMr = ncclCollNet_v5_as_v8_regMr; + ncclCollNet_v5_as_v8.regMrDmaBuf = NULL; + ncclCollNet_v5_as_v8.deregMr = ncclCollNet_v5->deregMr; + ncclCollNet_v5_as_v8.iallreduce = ncclCollNet_v5->iallreduce; + ncclCollNet_v5_as_v8.iallgather = nullptr; + ncclCollNet_v5_as_v8.ireducescatter = nullptr; + ncclCollNet_v5_as_v8.iflush = ncclCollNet_v5->iflush; + ncclCollNet_v5_as_v8.test = ncclCollNet_v5->test; + ncclCollNet_v5_as_v8.closeColl = ncclCollNet_v5->closeColl; + ncclCollNet_v5_as_v8.closeListen = ncclCollNet_v5->closeListen; return ncclSuccess; } -static ncclResult_t ncclCollNet_v6_as_v7_getProperties(int dev, ncclNetProperties_v7_t* props) { +static ncclResult_t ncclCollNet_v6_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { ncclNetProperties_v6_t p6; ncclResult_t ans = ncclCollNet_v6->getProperties(dev, &p6); if (ans != ncclSuccess) return ans; @@ -171,6 +242,7 @@ static ncclResult_t ncclCollNet_v6_as_v7_getProperties(int dev, ncclNetPropertie props->pciPath = p6.pciPath; props->guid = p6.guid; props->ptrSupport = p6.ptrSupport; + props->regIsGlobal = 0; props->speed = p6.speed; props->port = p6.port; props->maxComms = p6.maxComms; @@ -181,24 +253,78 @@ static ncclResult_t ncclCollNet_v6_as_v7_getProperties(int dev, ncclNetPropertie return ncclSuccess; } -// We use a wrapper around the v5 init to copy over the struct contents +static ncclResult_t ncclCollNet_v6_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclCollNet_v6->regMr(comm, data, (int) size, type, mhandle); +} + +// We use a wrapper around the v6 init to copy over the struct contents // post-init since they may not be initialized before hand. -static ncclResult_t ncclCollNet_v6_as_v7_init(ncclDebugLogger_t logfn) { +static ncclResult_t ncclCollNet_v6_as_v8_init(ncclDebugLogger_t logfn) { NCCLCHECK(ncclCollNet_v6->init(logfn)); - ncclCollNet_v6_as_v7.name = ncclCollNet_v6->name; - ncclCollNet_v6_as_v7.devices = ncclCollNet_v6->devices; - ncclCollNet_v6_as_v7.getProperties = ncclCollNet_v6_as_v7_getProperties; - ncclCollNet_v6_as_v7.listen = ncclCollNet_v6->listen; - ncclCollNet_v6_as_v7.connect = ncclCollNet_v6->connect; - ncclCollNet_v6_as_v7.reduceSupport = ncclCollNet_v6->reduceSupport; - ncclCollNet_v6_as_v7.regMr = ncclCollNet_v6->regMr; - ncclCollNet_v6_as_v7.regMrDmaBuf = ncclCollNet_v6->regMrDmaBuf; - ncclCollNet_v6_as_v7.deregMr = ncclCollNet_v6->deregMr; - ncclCollNet_v6_as_v7.iallreduce = ncclCollNet_v6->iallreduce; - ncclCollNet_v6_as_v7.iflush = ncclCollNet_v6->iflush; - ncclCollNet_v6_as_v7.test = ncclCollNet_v6->test; - ncclCollNet_v6_as_v7.closeColl = ncclCollNet_v6->closeColl; - ncclCollNet_v6_as_v7.closeListen = ncclCollNet_v6->closeListen; + ncclCollNet_v6_as_v8.name = ncclCollNet_v6->name; + ncclCollNet_v6_as_v8.devices = ncclCollNet_v6->devices; + ncclCollNet_v6_as_v8.getProperties = ncclCollNet_v6_as_v8_getProperties; + ncclCollNet_v6_as_v8.listen = ncclCollNet_v6->listen; + ncclCollNet_v6_as_v8.connect = ncclCollNet_v6->connect; + ncclCollNet_v6_as_v8.reduceSupport = ncclCollNet_v6->reduceSupport; + ncclCollNet_v6_as_v8.regMr = ncclCollNet_v6_as_v8_regMr; + ncclCollNet_v6_as_v8.regMrDmaBuf = ncclCollNet_v6->regMrDmaBuf; + ncclCollNet_v6_as_v8.deregMr = ncclCollNet_v6->deregMr; + ncclCollNet_v6_as_v8.iallreduce = ncclCollNet_v6->iallreduce; + ncclCollNet_v6_as_v8.iallgather = nullptr; + ncclCollNet_v6_as_v8.ireducescatter = nullptr; + ncclCollNet_v6_as_v8.iflush = ncclCollNet_v6->iflush; + ncclCollNet_v6_as_v8.test = ncclCollNet_v6->test; + ncclCollNet_v6_as_v8.closeColl = ncclCollNet_v6->closeColl; + ncclCollNet_v6_as_v8.closeListen = ncclCollNet_v6->closeListen; + return ncclSuccess; +} + +static ncclResult_t ncclCollNet_v7_as_v8_getProperties(int dev, ncclNetProperties_v8_t* props) { + ncclNetProperties_v7_t p7; + ncclResult_t ans = ncclCollNet_v7->getProperties(dev, &p7); + if (ans != ncclSuccess) return ans; + props->name = p7.name; + props->pciPath = p7.pciPath; + props->guid = p7.guid; + props->ptrSupport = p7.ptrSupport; + props->regIsGlobal = 0; + props->speed = p7.speed; + props->port = p7.port; + props->maxComms = p7.maxComms; + props->maxRecvs = p7.maxRecvs; + props->latency = p7.latency; + props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + return ncclSuccess; +} + +static ncclResult_t ncclCollNet_v7_as_v8_regMr(void* comm, void* data, size_t size, int type, void** mhandle) { + if (size >= 1<<31) return ncclInternalError; + return ncclCollNet_v7->regMr(comm, data, (int) size, type, mhandle); +} + +// We use a wrapper around the v7 init to copy over the struct contents +// post-init since they may not be initialized before hand. +static ncclResult_t ncclCollNet_v7_as_v8_init(ncclDebugLogger_t logfn) { + NCCLCHECK(ncclCollNet_v7->init(logfn)); + ncclCollNet_v7_as_v8.name = ncclCollNet_v7->name; + ncclCollNet_v7_as_v8.devices = ncclCollNet_v7->devices; + ncclCollNet_v7_as_v8.getProperties = ncclCollNet_v7_as_v8_getProperties; + ncclCollNet_v7_as_v8.listen = ncclCollNet_v7->listen; + ncclCollNet_v7_as_v8.connect = ncclCollNet_v7->connect; + ncclCollNet_v7_as_v8.reduceSupport = ncclCollNet_v7->reduceSupport; + ncclCollNet_v7_as_v8.regMr = ncclCollNet_v7_as_v8_regMr; + ncclCollNet_v7_as_v8.regMrDmaBuf = ncclCollNet_v7->regMrDmaBuf; + ncclCollNet_v7_as_v8.deregMr = ncclCollNet_v7->deregMr; + ncclCollNet_v7_as_v8.iallreduce = ncclCollNet_v7->iallreduce; + ncclCollNet_v7_as_v8.iallgather = nullptr; + ncclCollNet_v7_as_v8.ireducescatter = nullptr; + ncclCollNet_v7_as_v8.iflush = ncclCollNet_v7->iflush; + ncclCollNet_v7_as_v8.test = ncclCollNet_v7->test; + ncclCollNet_v7_as_v8.closeColl = ncclCollNet_v7->closeColl; + ncclCollNet_v7_as_v8.closeListen = ncclCollNet_v7->closeListen; return ncclSuccess; } @@ -236,54 +362,72 @@ ncclResult_t ncclNetPluginInit() { return ncclSuccess; } - ncclNets[0] = (ncclNet_v7_t*)dlsym(netPluginLib, "ncclNetPlugin_v7"); + ncclNets[0] = (ncclNet_v8_t*)dlsym(netPluginLib, "ncclNetPlugin_v8"); if (ncclNets[0] == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin_v7 symbol."); - // Try v6 plugin - ncclNet_v6 = (ncclNet_v6_t*)dlsym(netPluginLib, "ncclNetPlugin_v6"); - if (ncclNet_v6 == nullptr) { - // Try v5 plugin - ncclNet_v5 = (ncclNet_v5_t*)dlsym(netPluginLib, "ncclNetPlugin_v5"); - if (ncclNet_v5 == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin symbol (>= v5). ncclNetPlugin symbols v4 and lower are not supported."); - if (netPluginLib != nullptr) dlclose(netPluginLib); - return ncclSuccess; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin_v8 symbol."); + // Try v7 plugin + ncclNet_v7 = (ncclNet_v7_t*)dlsym(netPluginLib, "ncclNetPlugin_v7"); + if (ncclNet_v7 == nullptr) { + // Try v6 plugin + ncclNet_v6 = (ncclNet_v6_t*)dlsym(netPluginLib, "ncclNetPlugin_v6"); + if (ncclNet_v6 == nullptr) { + // Try v5 plugin + ncclNet_v5 = (ncclNet_v5_t*)dlsym(netPluginLib, "ncclNetPlugin_v5"); + if (ncclNet_v5 == nullptr) { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclNetPlugin symbol (>= v5). ncclNetPlugin symbols v4 and lower are not supported."); + if (netPluginLib != nullptr) dlclose(netPluginLib); + return ncclSuccess; + } else { + ncclNets[0] = &ncclNet_v5_as_v8; + ncclNet_v5_as_v8.init = ncclNet_v5_as_v8_init; + // Set the name right away to allow for NCCL_NET=... to work + ncclNet_v5_as_v8.name = ncclNet_v5->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v5)", ncclNets[0]->name); + } } else { - ncclNets[0] = &ncclNet_v5_as_v7; - ncclNet_v5_as_v7.init = ncclNet_v5_as_v7_init; + ncclNets[0] = &ncclNet_v6_as_v8; + ncclNet_v6_as_v8.init = ncclNet_v6_as_v8_init; // Set the name right away to allow for NCCL_NET=... to work - ncclNet_v5_as_v7.name = ncclNet_v5->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v5)", ncclNets[0]->name); + ncclNet_v6_as_v8.name = ncclNet_v6->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v6)", ncclNets[0]->name); } } else { - ncclNets[0] = &ncclNet_v6_as_v7; - ncclNet_v6_as_v7.init = ncclNet_v6_as_v7_init; + ncclNets[0] = &ncclNet_v7_as_v8; + ncclNet_v7_as_v8.init = ncclNet_v7_as_v8_init; // Set the name right away to allow for NCCL_NET=... to work - ncclNet_v6_as_v7.name = ncclNet_v6->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v6)", ncclNets[0]->name); + ncclNet_v7_as_v8.name = ncclNet_v7->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded net plugin %s (v7)", ncclNets[0]->name); } } // Check for CollNet - ncclCollNets[0] = (ncclCollNet_v7_t*) dlsym(netPluginLib, "ncclCollNetPlugin_v7"); + ncclCollNets[0] = (ncclCollNet_v8_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v8"); if (ncclCollNets[0] == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin_v7 symbol."); - ncclCollNet_v6 = (ncclCollNet_v6_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v6"); - if (ncclCollNet_v6 == nullptr) { - ncclCollNet_v5 = (ncclCollNet_v5_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v5"); - if (ncclCollNet_v5 == nullptr) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin symbol (>= v5). ncclCollNetPlugin symbols v4 and lower are not supported."); + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin_v8 symbol."); + ncclCollNet_v7 = (ncclCollNet_v7_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v7"); + if (ncclCollNet_v7 == nullptr) { + ncclCollNet_v6 = (ncclCollNet_v6_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v6"); + if (ncclCollNet_v6 == nullptr) { + ncclCollNet_v5 = (ncclCollNet_v5_t*)dlsym(netPluginLib, "ncclCollNetPlugin_v5"); + if (ncclCollNet_v5 == nullptr) { + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find ncclCollNetPlugin symbol (>= v5). ncclCollNetPlugin symbols v4 and lower are not supported."); + } else { + ncclCollNets[0] = &ncclCollNet_v5_as_v8; + ncclCollNet_v5_as_v8.init = ncclCollNet_v5_as_v8_init; + ncclCollNet_v5_as_v8.name = ncclCollNet_v5->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded coll plugin %s (v5)", ncclCollNets[0]->name); + } } else { - ncclCollNets[0] = &ncclCollNet_v5_as_v7; - ncclCollNet_v5_as_v7.init = ncclCollNet_v5_as_v7_init; - ncclCollNet_v5_as_v7.name = ncclCollNet_v5->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded coll plugin %s (v5)", ncclCollNets[0]->name); + ncclCollNets[0] = &ncclCollNet_v6_as_v8; + ncclCollNet_v6_as_v8.init = ncclCollNet_v6_as_v8_init; + ncclCollNet_v6_as_v8.name = ncclCollNet_v6->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded coll plugin %s (v6)", ncclCollNets[0]->name); } } else { - ncclCollNets[0] = &ncclCollNet_v6_as_v7; - ncclCollNet_v6_as_v7.init = ncclCollNet_v6_as_v7_init; - ncclCollNet_v6_as_v7.name = ncclCollNet_v6->name; - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded coll plugin %s (v6)", ncclCollNets[0]->name); + ncclCollNets[0] = &ncclCollNet_v7_as_v8; + ncclCollNet_v7_as_v8.init = ncclCollNet_v7_as_v8_init; + ncclCollNet_v7_as_v8.name = ncclCollNet_v7->name; + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Loaded coll plugin %s (v7)", ncclCollNets[0]->name); } } return ncclSuccess; @@ -329,6 +473,7 @@ static ncclResult_t netGetState(int i, enum ncclNetState* state) { } static ncclResult_t collNetGetState(int i, enum ncclNetState* state) { + pthread_mutex_lock(&netLock); if (ncclCollNetStates[i] == ncclNetStateInit) { int ndev; if (ncclCollNets[i]->init(ncclDebugLog) != ncclSuccess) ncclCollNetStates[i] = ncclNetStateDisabled; @@ -336,6 +481,7 @@ static ncclResult_t collNetGetState(int i, enum ncclNetState* state) { else ncclCollNetStates[i] = ncclNetStateEnabled; } *state = ncclCollNetStates[i]; + pthread_mutex_unlock(&netLock); return ncclSuccess; } @@ -416,7 +562,7 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) { while (!connected) { // If we're aborting now, skip to cleanup - if (*comm->abortFlag) { + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED)) { goto cleanup2; } @@ -453,11 +599,9 @@ cleanup1: } int ncclNetVersion(struct ncclComm* comm) { - if (comm->ncclNet == &ncclNet_v5_as_v7) { - return 5; - } else if (comm->ncclNet == &ncclNet_v6_as_v7) { - return 6; - } else { - return 7; - } + return + (comm->ncclNet == &ncclNet_v5_as_v8) ? 5 : + (comm->ncclNet == &ncclNet_v6_as_v8) ? 6 : + (comm->ncclNet == &ncclNet_v7_as_v8) ? 7 : + 8; } diff --git a/projects/rccl/src/proxy.cc b/projects/rccl/src/proxy.cc index db36a1573e..b2b488264d 100644 --- a/projects/rccl/src/proxy.cc +++ b/projects/rccl/src/proxy.cc @@ -353,20 +353,22 @@ static ncclResult_t ncclProxyOpToArgs(struct ncclProxyOp* op, struct ncclProxyAr WARN("Proxy append out of bounds"); return ncclInternalError; } - //memset(sub, 0, sizeof(struct ncclProxySubArgs)); sub->connection = op->connection; sub->channelId = op->channelId; sub->nsteps = op->nsteps; sub->nbytes = op->nbytes; sub->peer = op->root; + sub->reg = op->reg; + sub->buffer = op->buffer; args->nsubs = subIndex+1; if (subIndex) { if ((args->sliceSteps != op->sliceSteps) || (args->chunkSteps != op->chunkSteps) || (args->protocol != op->protocol) || (args->dtype != op->dtype) || - (args->redOp != op->redOp)) { + (args->redOp != op->redOp) || + (args->coll != op->coll)) { WARN("Proxy append mismatch"); return ncclInternalError; } @@ -386,6 +388,8 @@ static ncclResult_t ncclProxyOpToArgs(struct ncclProxyOp* op, struct ncclProxyAr args->redOp = op->redOp; args->pattern = op->pattern; args->protocol = op->protocol; + args->coll = op->coll; + args->specifics = op->specifics; args->state = ncclProxyOpReady; args->progress = op->connection->tcomm->proxyProgress; args->proxyAppendPtr = op->connection->proxyAppendPtr; @@ -590,7 +594,7 @@ ncclResult_t ncclProxySaveOp(struct ncclComm* comm, struct ncclProxyOp* op, bool NCCL_PARAM(ChunkSize, "CHUNK_SIZE", 0); -ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) { +ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op, int reg) { memset(op, 0, sizeof(struct ncclProxyOp)); int channelId = info->channelId; struct ncclChannel* channel = info->comm->channels+channelId; @@ -611,15 +615,17 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) op->pattern = ncclPatternSend; if (op->root != info->comm->rank && peer->send[1].transportComm == &netTransport.send) { // Tune chunk size for the network - if (info->count < stepSize) info->chunkSize /= 4; + if (info->protocol == NCCL_PROTO_SIMPLE && info->count < stepSize) info->chunkSize /= 4; else if (info->count < 8*stepSize) info->chunkSize /= 2; + if (info->protocol == NCCL_PROTO_SIMPLE && peer->send[1].proxyConn.sameProcess) op->reg = reg; } } else if (info->coll == ncclFuncRecv) { op->pattern = ncclPatternRecv; if (op->root != info->comm->rank && peer->recv[1].transportComm == &netTransport.recv) { // Tune chunk size for the network - if (info->count < stepSize) info->chunkSize /= 4; + if (info->protocol == NCCL_PROTO_SIMPLE && info->count < stepSize) info->chunkSize /= 4; else if (info->count < 8*stepSize) info->chunkSize /= 2; + if (info->protocol == NCCL_PROTO_SIMPLE && peer->recv[1].proxyConn.sameProcess) op->reg = reg; } } else { WARN("P2p operation is neither send or recv"); @@ -628,17 +634,21 @@ ncclResult_t ncclProxyComputeP2p(struct ncclInfo* info, struct ncclProxyOp* op) if (ncclParamChunkSize() != 0) { info->chunkSize = ncclParamChunkSize(); } + op->buffer = op->reg ? info->recvbuff : NULL; op->chunkSize = info->chunkSize; + op->nbytes = info->count; // Compute nSteps for proxies int chunkEffectiveSize = op->chunkSize; if (op->protocol == NCCL_PROTO_LL) { chunkEffectiveSize /= 2; + op->nbytes *= 2; + op->nbytes = DIVUP(op->nbytes, sizeof(union ncclLLFifoLine)) * sizeof(union ncclLLFifoLine); } - op->nbytes = stepSize; + if (!op->reg) op->nbytes = std::min(op->nbytes, (ssize_t)info->chunkSize); op->nsteps = DIVUP(info->count, chunkEffectiveSize); - if (op->nsteps == 0) op->nsteps = 1; + if (op->nsteps == 0 || op->reg) op->nsteps = 1; return ncclSuccess; } @@ -1069,35 +1079,60 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in return ncclSuccess; } -// cuMem API support -// The response is sent out-of-band using ncclIpcSocket for this specific command -ncclResult_t ncclProxyClientGetFdBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void *handle, int* convertedFd) { - ncclResult_t ret = ncclSuccess; - ncclResult_t res = ncclInProgress; +// UDS support +ncclResult_t ncclProxyCallBlockingUDS(struct ncclComm* comm, int tpRank, int type, void* reqBuff, int reqSize, void* respBuff, int respSize, int *respFd) { + ncclResult_t res = ncclSuccess; struct ncclIpcSocket ipcSock = { 0 }; void *opId = (void*)((((uintptr_t)random()) << 32) | random()); - // Create a UDS socket to receive the converted fd - NCCLCHECK(ncclIpcSocketInit(&ipcSock, comm->topParentLocalRanks[comm->localRank], (uint64_t)opId, comm->abortFlag)); + int rank = comm->topParentLocalRanks[comm->localRank]; + struct ncclProxyState* sharedProxyState = comm->proxyState; + uint64_t pidHash = sharedProxyState->peerAddressesUDS[tpRank]; - // Request the allocation of a UDS fd for the handle over sockets - NCCLCHECKGOTO(ncclProxyCallAsync(comm, proxyConn, ncclProxyMsgGetFd, handle, sizeof(CUmemGenericAllocationHandle), 0, opId), ret, error); + INFO(NCCL_PROXY, "ProxyCall UDS comm %p rank %d tpRank %d(%lx) reqSize %d respSize %d respFd %p opId %p", + comm, rank, tpRank, pidHash, reqSize, respSize, respFd, opId); - // Receive the converted fd over UDS - NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, convertedFd), ret, error); - TRACE(NCCL_PROXY, "UDS: ClientGetFd handle 0x%lx rank %d returned fd %d", *(uint64_t*)handle, proxyConn->tpLocalRank, *convertedFd); - NCCLCHECKGOTO(ncclIpcSocketClose(&ipcSock), ret, error); + // cuMem: Create a UDS socket to receive the response + NCCLCHECK(ncclIpcSocketInit(&ipcSock, rank, (uint64_t)opId, comm->abortFlag)); - // Wait for proxy response (sockets) - while (res == ncclInProgress) { - res = ncclPollProxyResponse(comm, proxyConn, NULL, opId); - } + ncclIpcHdr hdr; + hdr.type = type; + hdr.rank = rank; + hdr.reqSize = reqSize; + hdr.respSize = respSize; + hdr.opId = opId; + assert(reqSize <= sizeof(hdr.data)); + memcpy(&hdr.data, reqBuff, reqSize); + NCCLCHECKGOTO(ncclIpcSocketSendMsg(&ipcSock, &hdr, sizeof(hdr), -1, tpRank, pidHash), res, error); + NCCLCHECKGOTO(ncclIpcSocketRecvMsg(&ipcSock, respBuff, respSize, respFd), res, error); + NCCLCHECKGOTO(ncclIpcSocketClose(&ipcSock), res, error); + + INFO(NCCL_PROXY, "ProxyCall UDS comm %p rank %d tpRank %d(%lx) reqSize %d respSize %d respFd %d opId %p - DONE", + comm, rank, tpRank, pidHash, reqSize, respSize, (respFd ? *respFd : -1), opId); + + return res; + +error: + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + WARN("ncclProxyCallBlockingUDS call to tpRank %d(%lx) failed : %d", tpRank, pidHash, res); + return res; +} + +// cuMem API support +// The request/response is sent out-of-band using ncclIpcSocket for this specific command +ncclResult_t ncclProxyClientGetFdBlocking(struct ncclComm* comm, int tpRank, void *handle, int* convertedFd) { + ncclResult_t ret = ncclSuccess; + + // Request the allocation of a UDS fd for the handle + NCCLCHECKGOTO(ncclProxyCallBlockingUDS(comm, tpRank, ncclProxyMsgGetFd, handle, sizeof(CUmemGenericAllocationHandle), NULL, 0, convertedFd), ret, error); + + // We have now received the converted fd over UDS + INFO(NCCL_PROXY, "UDS: ClientGetFd handle 0x%lx tpRank %d returned fd %d", *(uint64_t*)handle, tpRank, *convertedFd); return ret; error: - NCCLCHECK(ncclIpcSocketClose(&ipcSock)); - WARN("ncclProxyClientGetFd call to rank %d handle 0x%lx failed : %d", proxyConn->tpRank, *(uint64_t*)handle, ret); + WARN("ncclProxyClientGetFd call to tpRank %d handle 0x%lx failed : %d", tpRank, *(uint64_t*)handle, ret); return ret; } @@ -1132,7 +1167,7 @@ error: ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) { struct ncclProxyState* sharedProxyState = comm->proxyState; // Receive the connection pointer from the Proxy - if (*comm->abortFlag) { + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED)) { WARN("Comm %p is in abort state", comm); return ncclInternalError; } @@ -1287,13 +1322,13 @@ static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclPr } // cuMem API support -static ncclResult_t proxyGetFd(struct ncclProxyLocalPeer* peer, void *opId, struct ncclProxyState* proxyState, uint64_t handle) { +static ncclResult_t proxyGetFd(struct ncclProxyState* proxyState, int rank, void *opId, uint64_t handle) { #if CUDART_VERSION >= 11030 // cuMem API support ncclResult_t ret = ncclSuccess; struct ncclIpcSocket ipcSock = { 0 }; uint64_t hash = (uint64_t) opId; - INFO(NCCL_PROXY, "UDS proxyGetFd received handle 0x%lx peer %d opId %lx", handle, peer->tpLocalRank, hash); + INFO(NCCL_PROXY, "UDS proxyGetFd received handle 0x%lx peer %d opId %lx", handle, rank, hash); CUmemAllocationHandleType type = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; int fd = -1; @@ -1301,7 +1336,7 @@ static ncclResult_t proxyGetFd(struct ncclProxyLocalPeer* peer, void *opId, stru CUCHECK(cuMemExportToShareableHandle(&fd, handle, type, 0)); // Send back the converted fd using UDS NCCLCHECKGOTO(ncclIpcSocketInit(&ipcSock, proxyState->tpRank, hash^1, proxyState->abortFlag), ret, error); - NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, peer->tpLocalRank, hash), ret, error); + NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, rank, hash), ret, error); error: NCCLCHECK(ncclIpcSocketClose(&ipcSock)); // We can now safely close the exported fd @@ -1326,11 +1361,8 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclP TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgSharedInit opId=%p op.reqBuff=%p nChannels=%d", op->opId, op->reqBuff, nChannels); if (op->connection->tcomm->proxySharedInit) res = op->connection->tcomm->proxySharedInit(op->connection, proxyState, nChannels); __atomic_store_n(&op->connection->state, connSharedInitialized, __ATOMIC_RELEASE); - } else if (op->type == ncclProxyMsgGetFd) { - uint64_t handle = *(uint64_t*)op->reqBuff; - TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgGetFd opId=%p op.reqBuff=%p handle=0x%lx", op->opId, op->reqBuff, handle); - res = proxyGetFd(peer, op->opId, proxyState, handle); // cuMem API support - } else if (op->type == ncclProxyMsgInit) { + } + else if (op->type == ncclProxyMsgInit) { TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgInit opId=%p op.reqBuff=%p", op->opId, op->reqBuff); res = proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection); } else return ncclInternalError; @@ -1360,7 +1392,7 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclP (*asyncOpCount)--; return ncclSuccess; - } else if (*proxyState->abortFlag != 0) { + } else if (__atomic_load_n(proxyState->abortFlag, __ATOMIC_RELAXED) != 0) { return ncclInternalError; } @@ -1446,7 +1478,7 @@ void* ncclProxyService(void* _args) { /* Even if local comm aborts, we cannot let proxy thread exit if we still have peer * connections. Need to wait until all other related comms call abort and safely exit * together, or we could face segmentation fault. */ - if (*proxyState->abortFlag != 0) stop = 1; + if (__atomic_load_n(proxyState->abortFlag, __ATOMIC_RELAXED) != 0) stop = 1; /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */ int ret; do { @@ -1563,13 +1595,71 @@ void* ncclProxyService(void* _args) { return NULL; } -ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses) { + +// Process a request on the UDS socket +static ncclResult_t proxyUDSRecvReq(struct ncclProxyState* proxyState, int reqFd) { + ncclIpcHdr hdr; + NCCLCHECK(ncclIpcSocketRecvMsg(&proxyState->ipcSock, &hdr, sizeof(hdr), NULL)); + if (hdr.type == ncclProxyMsgGetFd) { + // cuMem API support + uint64_t handle = *(uint64_t*)hdr.data; + INFO(NCCL_PROXY, "proxyUDSRecvReq::ncclProxyMsgGetFd rank %d opId %p handle=0x%lx", hdr.rank, hdr.opId, handle); + return proxyGetFd(proxyState, hdr.rank, hdr.opId, handle); + } + + return ncclInternalError; +} + +// UDS fd handle support +void* ncclProxyServiceUDS(void* _args) { + struct ncclProxyState* proxyState = (struct ncclProxyState*) _args; + struct pollfd pollfds[1]; + + if (setProxyThreadContext(proxyState)) { + INFO(NCCL_INIT, "[Proxy Service UDS] Created CUDA context on device %d", proxyState->cudaDev); + } else if (cudaSetDevice(proxyState->cudaDev) != cudaSuccess) { + WARN("[Proxy Service UDS] Failed to set CUDA device %d", proxyState->cudaDev); + } + + if (ncclIpcSocketGetFd(&proxyState->ipcSock, &pollfds[0].fd) != ncclSuccess) { + WARN("[Proxy Service UDS] Get listenSock fd fails"); + return NULL; + }; + pollfds[0].events = POLLIN|POLLHUP; + + while (1) { + /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */ + int ret; + do { + ret = poll(pollfds, 1, 500); + } while (ret < 0 && errno == EINTR); + if (ret < 0) { + WARN("[Proxy Service UDS] Poll failed: %s", strerror(errno)); + return NULL; + } + + // Check for stop/abort + if (proxyState->stop || *proxyState->abortFlag) break; + + if (pollfds[0].revents) { + // A request was seen on the UDS fd + proxyUDSRecvReq(proxyState, pollfds[0].fd); + } + } + + ncclIpcSocketClose(&proxyState->ipcSock); + INFO(NCCL_PROXY, "[Proxy Service UDS] exit: stop %d abortFlag %d", proxyState->stop, *proxyState->abortFlag); + return NULL; +} + +ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union ncclSocketAddress* peerAddresses, uint64_t *peerAddressesUDS) { assert(comm->sharedRes->proxyState == NULL); NCCLCHECK(ncclCalloc(&comm->sharedRes->proxyState, 1)); comm->proxyState = comm->sharedRes->proxyState; comm->proxyState->refCount = 1; comm->proxyState->listenSock = sock; comm->proxyState->peerAddresses = peerAddresses; + comm->proxyState->peerAddressesUDS = peerAddressesUDS; // Seed the random number generator for UDS filename generation struct timeval time; gettimeofday(&time,NULL); @@ -1601,6 +1691,12 @@ ncclResult_t ncclProxyCreate(struct ncclComm* comm) { pthread_create(&comm->proxyState->thread, NULL, ncclProxyService, comm->proxyState); ncclSetThreadName(comm->proxyState->thread, "NCCL Service %2d", comm->cudaDev); + + // UDS support + INFO(NCCL_PROXY, "UDS: Creating service thread comm %p rank %d pidHash %lx", comm, comm->rank, comm->peerInfo[comm->rank].pidHash); + NCCLCHECK(ncclIpcSocketInit(&comm->proxyState->ipcSock, comm->rank, comm->peerInfo[comm->rank].pidHash, comm->abortFlag)); + pthread_create(&comm->proxyState->threadUDS, NULL, ncclProxyServiceUDS, comm->proxyState); + ncclSetThreadName(comm->proxyState->threadUDS, "NCCL UDS Service %2d", comm->cudaDev); } return ncclSuccess; } @@ -1610,8 +1706,13 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { struct ncclProxyState* sharedProxyState = comm->proxyState; if ((comm->proxyRefCountOld = ncclAtomicRefCountDecrement(&sharedProxyState->refCount)) == 0) { + if (comm->proxyState->threadUDS) { + // UDS support + comm->proxyState->stop = 1; + } + if (sharedProxyState->peerAddresses) { - if (*comm->abortFlag == 0) { + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED) == 0) { struct ncclSocket sock; int type = ncclProxyMsgStop; NCCLCHECK(ncclSocketInit(&sock, sharedProxyState->peerAddresses + comm->topParentRanks[comm->rank], comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); @@ -1636,7 +1737,7 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { } } int type = ncclProxyMsgClose; - if (*comm->abortFlag == 0) NCCLCHECK(ncclSocketSend(sharedProxyState->peerSocks + i, &type, sizeof(int))); + if (__atomic_load_n(comm->abortFlag, __ATOMIC_RELAXED) == 0) NCCLCHECK(ncclSocketSend(sharedProxyState->peerSocks + i, &type, sizeof(int))); NCCLCHECK(ncclSocketClose(sharedProxyState->peerSocks + i)); } } @@ -1652,6 +1753,7 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) { assert(sharedProxyState->refCount == 0); free(sharedProxyState->peerAddresses); + free(sharedProxyState->peerAddressesUDS); free(sharedProxyState->peerSocks); free(sharedProxyState->proxyOps); free(sharedProxyState->sharedDevMems); diff --git a/projects/rccl/src/register.cc b/projects/rccl/src/register.cc new file mode 100644 index 0000000000..0e252a2f20 --- /dev/null +++ b/projects/rccl/src/register.cc @@ -0,0 +1,182 @@ +/************************************************************************* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "argcheck.h" // Need some checks here since we access comm +#include "nccl.h" +#include "comm.h" +#include "net.h" +#include "register.h" + +ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) { + struct ncclRegCache* cache = &comm->regCache; + ncclDebugNoWarn = NCCL_NET; + for (int d=0; dnDevs; d++) { + if (reg->handles[d] != NULL) NCCLCHECK(comm->ncclNet->deregMr(cache->sComms[reg->devs[d]], reg->handles[d])); + } + reg->nDevs = 0; + free(reg->handles); + reg->handles = NULL; + ncclDebugNoWarn = 0; + return ncclSuccess; +} + +ncclResult_t ncclNetRegister(struct ncclComm* comm, void* addr, size_t size, struct ncclReg* reg) { + struct ncclRegCache* cache = &comm->regCache; + int netCount; + NCCLCHECK(ncclTopoGetNetCount(comm->topo, &netCount)); + if (netCount == 0) return ncclSuccess; + + ncclResult_t ret = ncclSuccess; + + // Find local devices for p2p operations + for (int c=0; cp2pnChannels; c++) { + int dev; + if (ncclTopoGetLocalNet(comm->topo, comm->rank, c, &dev) != ncclSuccess) goto end; // No local net + ncclNetProperties_t props; + NCCLCHECKGOTO(comm->ncclNet->getProperties(dev, &props), ret, end); + if (props.regIsGlobal == 0) { // We need to be sure all NICs support global registration. + reg->nDevs = 0; + break; + } + int found = 0; + for (int d=0; dnDevs; d++) if (reg->devs[d] == dev) found = 1; + if (!found) reg->devs[reg->nDevs++] = dev; + } + + NCCLCHECKGOTO(ncclCalloc(®->handles, reg->nDevs), ret, end); + + ncclDebugNoWarn = NCCL_NET; + for (int d=0; dnDevs; d++) { + int dev = reg->devs[d]; + reg->handles[d] = NULL; + + if (cache->sComms[dev] == NULL) { + // Create a loopback network comm object for that device to register the buffers. + void *lComm = NULL; + ncclNetHandle_t netHandle; + bool connected = false; + NCCLCHECKGOTO(comm->ncclNet->listen(dev, &netHandle, &lComm), ret, end); + while (!connected) { + if (*comm->abortFlag) { + goto end; + } + if (cache->sComms[dev] == NULL) + NCCLCHECKGOTO(comm->ncclNet->connect(dev, &netHandle, cache->sComms+dev, NULL), ret, end); + if (cache->rComms[dev] == NULL) + NCCLCHECKGOTO(comm->ncclNet->accept(lComm, cache->rComms+dev, NULL), ret, end); + connected = (cache->rComms[dev] != NULL) && (cache->sComms[dev] != NULL); + } + NCCLCHECK(comm->ncclNet->closeListen(lComm)); + } + if (comm->ncclNet->regMr(cache->sComms[dev], addr, size, NCCL_PTR_CUDA, reg->handles+d) != ncclSuccess) { + reg->handles[d] = NULL; + NCCLCHECK(ncclNetDeregister(comm, reg)); + reg->nDevs = 0; + goto end; + } + } +end: + ncclDebugNoWarn = 0; + if (ret != ncclSuccess) NCCLCHECK(ncclNetDeregister(comm, reg)); + return ret; +} + +ncclResult_t ncclRegFind(struct ncclComm* comm, const void* data, size_t size, struct ncclReg** reg) { + struct ncclRegCache* cache = &comm->regCache; + uintptr_t pageSize = cache->pageSize; + uintptr_t addr = (uintptr_t)data & -pageSize; + size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; + + *reg = NULL; + for (int slot=0; /*true*/; slot++) { + if (slot == cache->population || addr < cache->slots[slot]->addr) return ncclSuccess; + if ((addr >= cache->slots[slot]->addr) && + ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { + *reg = cache->slots[slot]; + return ncclSuccess; + } + } +} +NCCL_PARAM(LocalRegister, "LOCAL_REGISTER", 1); + +ncclResult_t ncclRegister(struct ncclComm* comm, void* data, size_t size, void** handle) { + if (!ncclParamLocalRegister()) return ncclSuccess; + struct ncclRegCache* cache = &comm->regCache; + uintptr_t pageSize = cache->pageSize; + uintptr_t addr = (uintptr_t)data & -pageSize; + size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; + for (int slot=0; /*true*/; slot++) { + if ((slot == cache->population) || (addr < cache->slots[slot]->addr)) { + if (cache->population == cache->capacity) { // must grow cache + cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity; + NCCLCHECK(ncclRealloc(&cache->slots, cache->population, cache->capacity)); + } + memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclReg*)); + NCCLCHECK(ncclCalloc(cache->slots+slot, 1)); + struct ncclReg* regSlot = cache->slots[slot]; + regSlot->addr = addr; + regSlot->pages = pages; + regSlot->refs = 1; + NCCLCHECK(ncclNetRegister(comm, (void*)addr, pages*pageSize, regSlot)); + regSlot->state |= NET_REG_COMPLETE; + cache->population += 1; + *handle = regSlot; + return ncclSuccess; + } else if ((addr >= cache->slots[slot]->addr) && + ((addr-cache->slots[slot]->addr)/pageSize+pages) <= cache->slots[slot]->pages) { + cache->slots[slot]->refs++; + *handle = cache->slots[slot]; + return ncclSuccess; + } + } +} + +ncclResult_t ncclRegCleanup(struct ncclComm* comm) { + struct ncclRegCache* cache = &comm->regCache; + for (int i=0; ipopulation; i++) { + INFO(NCCL_INIT, "Cleanup buffer %p pages %lx", (void*)cache->slots[i]->addr, cache->slots[i]->pages); + NCCLCHECK(ncclNetDeregister(comm, cache->slots[i])); + if (cache->slots[i]->state & NVLS_REG_COMPLETE) NCCLCHECK(ncclNvlsDeregBuffer(&cache->slots[i]->mcHandle, cache->slots[i]->regAddr, cache->slots[i]->dev, cache->slots[i]->regSize)); + free(cache->slots[i]); + } + free(cache->slots); + for (int d=0; dsComms[d]) NCCLCHECK(comm->ncclNet->closeSend(cache->sComms[d])); + if (cache->rComms[d]) NCCLCHECK(comm->ncclNet->closeRecv(cache->rComms[d])); + } + return ncclSuccess; +} + +NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size_t size, void** handle); +ncclResult_t ncclCommRegister(const ncclComm_t comm, void* buff, size_t size, void** handle) { + NCCLCHECK(PtrCheck(comm, "ncclCommRegister", "comm")); + if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister")); + NCCLCHECK(ncclRegister(comm, buff, size, handle)); + return ncclSuccess; +} + +NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); +ncclResult_t ncclCommDeregister(const ncclComm_t comm, void* handle) { + NCCLCHECK(PtrCheck(comm, "ncclCommRegister", "comm")); + struct ncclReg* reg = (struct ncclReg*)handle; + struct ncclRegCache* cache = &comm->regCache; + int slot; + for (slot=0; slotpopulation && cache->slots[slot] != reg; slot++); + if (slot == cache->population) { + WARN("Deregister: Could not find handle"); + return ncclInvalidUsage; + } + if (--reg->refs) return ncclSuccess; + NCCLCHECK(ncclNetDeregister(comm, reg)); + if (reg->state & NVLS_REG_COMPLETE) { + NCCLCHECK(ncclNvlsDeregBuffer(®->mcHandle, reg->regAddr, reg->dev, reg->regSize)); + reg->regAddr = (CUdeviceptr)NULL; + } + free(reg); + memmove(cache->slots+slot, cache->slots+slot+1, (cache->population-slot-1)*sizeof(struct ncclReg*)); + cache->population -= 1; + return ncclSuccess; +} diff --git a/projects/rccl/src/transport.cc b/projects/rccl/src/transport.cc index a465d6b5cf..e229d382ce 100644 --- a/projects/rccl/src/transport.cc +++ b/projects/rccl/src/transport.cc @@ -324,10 +324,10 @@ int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collN for (int r = 0; r < nranks; r++) { if (allConnects[r].isMaster) { memcpy(masterConnects+c, &(allConnects[r].connect), sizeof(struct ncclConnect)); - if (r == rank) rankInCollNet = c; c++; } } + if (isMaster) rankInCollNet = comm->node; } else { // send side : copy in connect info received from peer recv master if (isMaster) memcpy(masterConnects+rankInCollNet, &(sendrecvExchange.connect), sizeof(struct ncclConnect)); } diff --git a/projects/rccl/src/transport/coll_net.cc b/projects/rccl/src/transport/coll_net.cc index 04bab8b4f2..302e5263b0 100644 --- a/projects/rccl/src/transport/coll_net.cc +++ b/projects/rccl/src/transport/coll_net.cc @@ -86,8 +86,8 @@ struct connectMap { }; struct reqSlot { - volatile void* recvBuff; - volatile int size; + bool turnIsSendNotRecv; + int size; }; struct sendResources { @@ -243,9 +243,11 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne struct ncclRecvMem *recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, recvMem); send->conn.tail = &recvMem->tail; - send->conn.sizesFifo = recvMem->sizesFifo; - for (int i=0; iconn.sizesFifo[i] = -1; - send->conn.offsFifo = recvMem->offsFifo; + send->conn.connFifo = recvMem->connFifo; + for (int i=0; iconn.connFifo[i].size = -1; + send->conn.connFifo[i].mode = NCCL_MODE_OFFSET; + } for (int p=0; pconn.buffs[p] = NCCL_NET_MAP_GET_POINTER(map, gpu, buffs[p]); @@ -274,7 +276,10 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne struct ncclRecvMem *recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, recvMem); void* gdcMem = map->mems[NCCL_NET_MAP_GDCMEM].gpuPtr; recv->conn.tail = gdcMem ? (uint64_t*)gdcMem : &recvMem->tail; - recv->conn.offsFifo = recvMem->offsFifo; + recv->conn.connFifo = recvMem->connFifo; + for (int i=0; iconn.connFifo[i].mode = NCCL_MODE_OFFSET; + } for (int p=0; pconn.buffs[p] = NCCL_NET_MAP_GET_POINTER(map, gpu, buffs[p]); @@ -376,6 +381,8 @@ static ncclResult_t sharedBuffersInit(struct ncclCollNetSharedRes* collNet, int if (cuda && collNet->cudaBuff == NULL) { NCCLCHECK(ncclCudaCalloc(&collNet->cudaBuff, *size)); + cudaMemset(collNet->cudaBuff, 0x33, *size/2); + cudaMemset((char*)collNet->cudaBuff + *size/2, 0x66, *size/2); } if (!cuda && collNet->hostBuff == NULL) { NCCLCHECK(ncclCudaHostCalloc(&collNet->hostBuff, *size)); @@ -471,7 +478,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str resources->sendMem = (struct ncclSendMem*) NCCL_NET_MAP_GET_POINTER(map, cpu, sendMem); resources->recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, cpu, recvMem); // Don't give credits yet in shared mode. - resources->sendMem->head = -NCCL_STEPS; + (resources->gdcSync ? *resources->gdcSync : resources->sendMem->head) = -NCCL_STEPS; // Allocate & Register shared buffers for the Simple protocol int bank = resources->useGdr ? NCCL_NET_MAP_SHARED_DEVMEM : NCCL_NET_MAP_SHARED_HOSTMEM; @@ -617,9 +624,49 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct return ncclSuccess; } +static size_t calcAlgoOffset(struct ncclProxyArgs* args, int isAllNotOne, int sub, uint64_t step) { + int chunkSize = args->chunkSize; + int nNodes = args->specifics.collnetDirect.nNodes; + int node = args->specifics.collnetDirect.node; + size_t sizePerRank = args->specifics.collnetDirect.sizePerRank; + size_t offset = (step*(args->nsubs) + sub)*chunkSize; + if (isAllNotOne) { + offset = std::min(offset, nNodes*sizePerRank); + } else { + offset = std::max(offset, (node+0)*sizePerRank); + offset = std::min(offset, (node+1)*sizePerRank); + } + return offset; +} -#define LAST_OF_GROUP(s) \ - (s % COLLNET_GROUP_NSUBS == COLLNET_GROUP_NSUBS-1 || s == args->nsubs-1) +static int calcRegionOffset( + struct ncclProxyArgs* args, int isRecvNotSend, int sub, uint64_t step, + int side // 0=begin, 1=end + ) { + struct ncclCollNetSharedRes* collNet = args->subs[0].connection->collNet; + int slotSize = collNet->buffSize/NCCL_STEPS; + int chunkSize = args->chunkSize; + int base = isRecvNotSend*NCCL_STEPS + (step%NCCL_STEPS); + base *= collNet->nChannels*slotSize; + if (args->coll == ncclFuncAllReduce) { + return base + (sub+side)*chunkSize; + } else { + int isAllNotOne = isRecvNotSend ^ (args->coll == ncclFuncReduceScatter); + int sub0 = sub - (sub%COLLNET_GROUP_NSUBS); + size_t off = sub0*slotSize; + off += calcAlgoOffset(args, isAllNotOne, sub+side, step) + - calcAlgoOffset(args, isAllNotOne, sub0, step); + return base + off; + } +} + +#define LAST_OF_GROUP(args, s) \ + ((s)%COLLNET_GROUP_NSUBS == COLLNET_GROUP_NSUBS-1 || (s) == (args)->nsubs-1) + +static constexpr int calcStepsPerGroup(int nGroups) { + //return NCCL_STEPS/nGroups; + return NCCL_STEPS; +} static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) { if (args->state == ncclProxyOpReady) { @@ -637,83 +684,117 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct if (args->state == ncclProxyOpProgress) { int p = NCCL_PROTO_SIMPLE; int nGroups = DIVUP(args->nsubs, COLLNET_GROUP_NSUBS); - int perGroupSteps = NCCL_STEPS / nGroups; for (int s=0; snsubs; s++) { struct ncclProxySubArgs* sub = args->subs+s; struct sendResources* resources = (struct sendResources*) (sub->connection->transportResources); void* sendMhandle = resources->sendMhandles[p]; void* recvMhandle = resources->recvMhandles[p]; + char* region = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[p]); auto reqFifo = resources->reqFifo; + int group = s/COLLNET_GROUP_NSUBS; + int groupStart = s - (s%COLLNET_GROUP_NSUBS); + if (sub->posted < sub->nsteps && sub->posted < sub->done + NCCL_STEPS) { int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; - int sharedBuffSlot = sub->posted%NCCL_STEPS; - int offset; - NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 0, sharedBuffSlot, 0, &offset)); - resources->recvMem->offsFifo[buffSlot] = offset + s*args->chunkSize; + resources->recvMem->connFifo[buffSlot].offset = calcRegionOffset(args, 0, s, sub->posted, 0); __sync_synchronize(); volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] posted offset %d @ %p signal %ld->%ld", long(sub->posted), group, buffSlot, resources->recvMem->connFifo[buffSlot].offset, &resources->recvMem->connFifo[buffSlot].offset, long(*sendHead), long(sub->base + sub->posted + args->sliceSteps - NCCL_STEPS)); sub->posted += args->sliceSteps; *sendHead = sub->base + sub->posted - NCCL_STEPS; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } - // Enforce sync between operations of the same group. - bool groupSync = (((s == 0) && ((sub+args->nsubs-1)->received == sub->received)) || (s && (sub-1)->received > sub->received)); - if (groupSync && sub->received < sub->posted && sub->received < sub->done + perGroupSteps) { + if (sub->received < sub->posted && sub->received < sub->done + calcStepsPerGroup(nGroups)) { int buffSlot = (sub->base+sub->received)%NCCL_STEPS; - int sharedBuffSlot = sub->received%NCCL_STEPS; - volatile int* sizesFifo = resources->recvMem->sizesFifo; + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; volatile uint64_t* recvTail = &resources->recvMem->tail; - char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, gpu, buffs[p]); - if (sizesFifo[buffSlot] != -1 && ((*recvTail > (sub->base+sub->received)))) { - // We have something to receive, let's check whether data is ready. - int ready = 1; - if (s == 0) { - int offset; - NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 0, sharedBuffSlot, 0, &offset)); - args->sharedBuff[sharedBuffSlot] = localBuff + offset; - args->sharedSize[sharedBuffSlot] = args->chunkSize; - } - if (ready) { - sizesFifo[buffSlot] = -1; - sub->received += args->sliceSteps; - args->idle = 0; - //continue; + if (connFifo[buffSlot].size != -1 && ((*recvTail > (sub->base+sub->received)))) { + if (args->coll != ncclFuncAllReduce) { + int sendBeg = calcRegionOffset(args, 0, s, sub->received, 0); + int sendEnd = calcRegionOffset(args, 0, s, sub->received, 1); + if (sendEnd-sendBeg != connFifo[buffSlot].size) { + WARN("CollNet sizes: want=%d got=%ld", sendEnd-sendBeg, connFifo[buffSlot].size); + return ncclInternalError; + } } + connFifo[buffSlot].size = -1; + sub->received += args->sliceSteps; + args->idle = 0; } } - if (LAST_OF_GROUP(s) && (sub->transmitted < sub->received)) { - int group = s / COLLNET_GROUP_NSUBS; - int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; - int sharedBuffSlot = sub->transmitted%NCCL_STEPS; - if (reqFifo[group][buffSlot].recvBuff != NULL) { - int totalSize = (s-group*COLLNET_GROUP_NSUBS+1) * args->sharedSize[sharedBuffSlot]; - int count = totalSize / ncclTypeSize((ncclDataType_t)args->dtype); - reqFifo[group][buffSlot].size = args->sharedSize[sharedBuffSlot]; - char* sendAddress = (char*)args->sharedBuff[sharedBuffSlot] + group*COLLNET_GROUP_NSUBS*args->sharedSize[sharedBuffSlot]; - NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, sendAddress, (void*)(reqFifo[group][buffSlot].recvBuff), count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, sub->requests+buffSlot)); - if (sub->requests[buffSlot] == NULL) continue; + // Enforce collective ordering of collnet ops. + bool ordered = s==0 ? args->subs[args->nsubs-1].transmitted == sub->transmitted + : sub->transmitted < (sub-1)->transmitted; + if (ordered && (sub->transmitted < sub->received)) { + if (LAST_OF_GROUP(args, s)) { + int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; + if (!reqFifo[group][buffSlot].turnIsSendNotRecv) continue; - TRACE(NCCL_NET, "sendProxy [%d/%d/%d] Iallreduce posted, size %d req %p", sub->transmitted, group, buffSlot, totalSize, sub->requests[buffSlot]); - // Make sure size is reset to zero before we update the head. - __sync_synchronize(); - sub->transmitted += args->sliceSteps; - args->idle = 0; - continue; + ssize_t sizePerRank = 0; + size_t allBeg = calcAlgoOffset(args, 1, groupStart, sub->transmitted); + size_t allEnd = calcAlgoOffset(args, 1, s+1, sub->transmitted); + int sendBeg = calcRegionOffset(args, 0, groupStart, sub->transmitted, 0); + int sendEnd = calcRegionOffset(args, 0, s, sub->transmitted, 1); + int recvBeg = calcRegionOffset(args, 1, groupStart, sub->transmitted, 0); + int recvEnd = calcRegionOffset(args, 1, s, sub->transmitted, 1); + reqFifo[group][buffSlot].size = recvEnd - recvBeg; + size_t eltSize = ncclTypeSize((ncclDataType_t)args->dtype); + + if (sendBeg==sendEnd && recvBeg==recvEnd) { + sub->requests[buffSlot] = nullptr; // trivally finished request + } else { + if (args->coll == ncclFuncAllReduce) { + int count = (sendEnd-sendBeg)/eltSize; + NCCLCHECK(proxyState->ncclCollNet->iallreduce(resources->collNetComm, region+sendBeg, region+recvBeg, count, (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, sendMhandle, recvMhandle, sub->requests+buffSlot)); + } else { + sizePerRank = args->specifics.collnetDirect.sizePerRank; + if (args->coll == ncclFuncAllGather) { + ncclNetSGE_v8_t recvParts; + recvParts.mhandle = recvMhandle; + recvParts.address = region + recvBeg; + recvParts.size = allEnd - allBeg; + NCCLCHECK(proxyState->ncclCollNet->iallgather( + resources->collNetComm, region+sendBeg, 1, &recvParts, + sizePerRank, allBeg, allEnd-allBeg, + sendMhandle, sub->requests+buffSlot)); + } else { + ncclNetSGE_v8_t sendParts; + sendParts.mhandle = sendMhandle; + sendParts.address = region + sendBeg; + sendParts.size = allEnd - allBeg; + NCCLCHECK(proxyState->ncclCollNet->ireducescatter( + resources->collNetComm, 1, &sendParts, region+recvBeg, + sizePerRank, allBeg, allEnd-allBeg, + (ncclDataType_t)args->dtype, (ncclRedOp_t)args->redOp, + recvMhandle, sub->requests+buffSlot)); + } + } + if (sub->requests[buffSlot] == nullptr) continue; + + if (args->coll == ncclFuncAllReduce) { + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Iallreduce posted, size %d req %p", (long)sub->transmitted, group, buffSlot, int(sendEnd-sendBeg), sub->requests[buffSlot]); + } else if (args->coll == ncclFuncAllGather) { + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Iallgather posted sendSize=%ld recvOffset=%ld recvSize=%ld request=%p", (long)sub->transmitted, group, buffSlot, long(sizePerRank), long(allBeg), long(allEnd-allBeg), sub->requests[buffSlot]); + } else { + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] Ireducescatter posted sendOffset=%ld sendSize=%ld recvSize=%ld request=%p", (long)sub->transmitted, group, buffSlot, long(allBeg), long(allEnd-allBeg), long(sizePerRank), sub->requests[buffSlot]); + } + } } + sub->transmitted += args->sliceSteps; + args->idle = 0; + continue; } // Check whether the network has completed some send operations. - if (LAST_OF_GROUP(s) && sub->done < sub->transmitted) { + if (LAST_OF_GROUP(args, s) && sub->done < sub->transmitted) { int done, size; - int group = s / COLLNET_GROUP_NSUBS; int buffSlot = (sub->base+sub->done)%NCCL_STEPS; - NCCLCHECK(proxyState->ncclCollNet->test((void*)(sub->requests[buffSlot]), &done, &size)); + done = 1; + if (sub->requests[buffSlot]) NCCLCHECK(proxyState->ncclCollNet->test((void*)(sub->requests[buffSlot]), &done, &size)); if (done) { - TRACE(NCCL_NET, "sendProxy [%d/%d/%d] request %p done, size %d", sub->done, group, buffSlot, sub->requests[buffSlot], size); - // Make sure size is updated before we set recvBuff to NULL (from the view of recv proxy, concerning the flush) - // (reordered store after store is possible on POWER, though not on x86) - __sync_synchronize(); - reqFifo[group][buffSlot].recvBuff = NULL; // Notify recvProxy - for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].done += args->sliceSteps; + TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] request %p done, size %d", (long)sub->done, group, buffSlot, sub->requests[buffSlot], size); + sub->requests[buffSlot] = nullptr; + reqFifo[group][buffSlot].turnIsSendNotRecv = false; // Notify recvProxy + for (int i=groupStart; i<=s; i++) args->subs[i].done += args->sliceSteps; args->idle = 0; int allDone = 1; for (int i=0; insubs; i++) { @@ -721,7 +802,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct } if (allDone) { args->state = ncclProxyOpNone; - TRACE(NCCL_NET, "sendProxy [%d/%d] stopped", sub->done, s); + TRACE(NCCL_NET, "sendProxy [%ld/%d] stopped", (long)sub->done, s); } } } @@ -739,6 +820,7 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->base = ROUNDUP(resources->step, args->chunkSteps); sub->posted = sub->received = sub->flushed = sub->transmitted = sub->done = 0; resources->step = sub->base + sub->nsteps; + memset(sub->requests, 0, sizeof(sub->requests)); } args->state = ncclProxyOpProgress; } @@ -746,38 +828,32 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (args->state == ncclProxyOpProgress) { int p = NCCL_PROTO_SIMPLE; int nGroups = DIVUP(args->nsubs, COLLNET_GROUP_NSUBS); - int perGroupSteps = NCCL_STEPS / nGroups; for (int s=0; snsubs; s++) { + int group = s/COLLNET_GROUP_NSUBS; + int groupStart = s - (s%COLLNET_GROUP_NSUBS); struct ncclProxySubArgs* sub = args->subs+s; struct recvResources* resources = (struct recvResources*) (sub->connection->transportResources); void* mhandle = resources->mhandles[p]; auto reqFifo = resources->reqFifo; - char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); + char* region = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); // Enforce sync between operations of the same group. - if (LAST_OF_GROUP(s) && (sub->posted < sub->done + perGroupSteps) && (sub->posted < sub->nsteps)) { - int group = s / COLLNET_GROUP_NSUBS; + if (LAST_OF_GROUP(args, s) && (sub->posted < sub->done + calcStepsPerGroup(nGroups)) && (sub->posted < sub->nsteps)) { int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; - int sharedBuffSlot = sub->posted%NCCL_STEPS; - int startChannel = group*COLLNET_GROUP_NSUBS; - int offset; - NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset)); - reqFifo[group][buffSlot].recvBuff = localBuff + offset; - TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff); + reqFifo[group][buffSlot].turnIsSendNotRecv = true; + TRACE(NCCL_NET, "recvProxy [%ld/%d/%d] posted buffer", (long)sub->posted, group, buffSlot); sub->posted += args->sliceSteps; args->idle = 0; continue; } - if (LAST_OF_GROUP(s) && (sub->posted > sub->received)) { - int group = s / COLLNET_GROUP_NSUBS; + if (LAST_OF_GROUP(args, s) && (sub->received < sub->posted)) { int buffSlot = (sub->base+sub->received)%NCCL_STEPS; - int sharedBuffSlot = sub->received%NCCL_STEPS; - if (reqFifo[group][buffSlot].recvBuff == NULL) { // Buffer is cleared : coll is complete - args->sharedSize[sharedBuffSlot] = reqFifo[group][buffSlot].size; - int totalSize = args->sharedSize[sharedBuffSlot]*(s-group*COLLNET_GROUP_NSUBS+1); - TRACE(NCCL_NET, "recvProxy [%d/%d/%d] received, size %d", sub->received, group, buffSlot, totalSize); + if (!reqFifo[group][buffSlot].turnIsSendNotRecv) { // Buffer is cleared : coll is complete + int recvBeg = calcRegionOffset(args, 1, groupStart, sub->received, 0); + int recvEnd = calcRegionOffset(args, 1, s, sub->received, 1); + int totalSize = recvEnd - recvBeg; + TRACE(NCCL_NET, "recvProxy [%ld/%d/%d] received, size %d chunkSize=%d", (long)sub->received, group, buffSlot, totalSize, args->chunkSize); sub->received += args->sliceSteps; - sub->requests[buffSlot] = NULL; if (reqFifo[group][buffSlot].size > 0 && resources->useGdr && resources->needFlush) { // GDRCOPY support if (resources->gdcFlush) { @@ -788,42 +864,31 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct WARN("NET: GDR Flush only supported on x86_64"); return ncclInternalError; #endif - sub->requests[buffSlot] = NULL; } else { - int startChannel = group*COLLNET_GROUP_NSUBS; - int offset; - NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset)); - NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, localBuff + offset, totalSize, mhandle, sub->requests+buffSlot)); + NCCLCHECK(proxyState->ncclCollNet->iflush(resources->collNetComm, region+recvBeg, totalSize, mhandle, sub->requests+buffSlot)); } - } else { - for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps; } args->idle = 0; continue; } } - if (LAST_OF_GROUP(s) && (sub->received > sub->flushed)) { + if (LAST_OF_GROUP(args, s) && (sub->flushed < sub->received)) { // Progress flush operations - int group = s / COLLNET_GROUP_NSUBS; int buffSlot = (sub->base + sub->flushed)%NCCL_STEPS; int done = 1; if (sub->requests[buffSlot]) NCCLCHECK(proxyState->ncclCollNet->test(sub->requests[buffSlot], &done, NULL)); if (done) { - TRACE(NCCL_NET, "recvProxy [%d/%d/%d] flushed", sub->flushed, group, buffSlot); + sub->requests[buffSlot] = nullptr; + TRACE(NCCL_NET, "recvProxy [%ld/%d/%d] flushed", (long)sub->flushed, group, buffSlot); for (int i=group*COLLNET_GROUP_NSUBS; i<=s; i++) args->subs[i].flushed += args->sliceSteps; args->idle = 0; //continue; } } - if (sub->flushed > sub->transmitted) { - int group = s / COLLNET_GROUP_NSUBS; + if (sub->transmitted < sub->flushed) { int buffSlot = (sub->base + sub->transmitted)%NCCL_STEPS; - int sharedBuffSlot = sub->transmitted%NCCL_STEPS; - int startChannel = group*COLLNET_GROUP_NSUBS; - int offset; - NCCLCHECK(sharedBuffersGet(sub->connection->collNet, 1, sharedBuffSlot, startChannel, &offset)); - volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo; - offsFifo[buffSlot] = offset + (s%COLLNET_GROUP_NSUBS)*args->chunkSize; + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; + connFifo[buffSlot].offset = calcRegionOffset(args, 1, s, sub->transmitted, 0); __sync_synchronize(); volatile uint64_t* recvTail = resources->gdcSync ? resources->gdcSync : &resources->recvMem->tail; *recvTail = sub->base + sub->flushed; @@ -835,14 +900,15 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct // Enforce sync here to make sure the last sub doesn't increase "done" before all others in the group have // reached the same point, otherwise we would start posting buffers to the send proxy before we're done // processing all the shared buffer. - bool groupSync = (((s == 0) && ((sub+args->nsubs-1)->done == sub->done)) || (s && (sub-1)->done > sub->done)); + bool groupSync = s==0 ? args->subs[args->nsubs-1].done == sub->done + : (sub-1)->done > sub->done; volatile uint64_t* sendHead = &resources->sendMem->head; if (groupSync && sub->done < sub->transmitted && (sub->base+sub->done) < *sendHead) { sub->done += args->sliceSteps; args->idle = 0; if (sub->done == sub->nsteps && s == args->nsubs-1) { args->state = ncclProxyOpNone; - TRACE(NCCL_NET, "recvProxy [%d/%d] stopped", sub->done, s); + TRACE(NCCL_NET, "recvProxy [%ld/%d] stopped", (long)sub->done, s); } } } diff --git a/projects/rccl/src/transport/net.cc b/projects/rccl/src/transport/net.cc index 0998172f59..58cb92144f 100644 --- a/projects/rccl/src/transport/net.cc +++ b/projects/rccl/src/transport/net.cc @@ -347,9 +347,12 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne struct ncclRecvMem *recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, recvMem); send->conn.tail = &recvMem->tail; - send->conn.sizesFifo = recvMem->sizesFifo; + send->conn.connFifo = recvMem->connFifo; // Only fuse P2P buffers, continue to allocate dedicated buffers for ring/tree - send->conn.offsFifo = map->shared ? recvMem->offsFifo : NULL; + for (int i=0; iconn.connFifo[i].offset = -1; + recvMem->connFifo[i].mode = map->shared ? NCCL_MODE_OFFSET : NCCL_MODE_NORMAL; + } for (int p=0; pconn.buffs[p] = NCCL_NET_MAP_GET_POINTER(map, gpu, buffs[p]); @@ -409,9 +412,11 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne struct ncclRecvMem *recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, recvMem); void* gdcMem = map->mems[NCCL_NET_MAP_GDCMEM].gpuPtr; recv->conn.tail = gdcMem ? (uint64_t*)gdcMem : &recvMem->tail; - recv->conn.sizesFifo = recvMem->sizesFifo; + recv->conn.connFifo = recvMem->connFifo; // Only fuse P2P buffers, continue to allocate dedicated buffers for ring/tree - recv->conn.offsFifo = map->shared ? recvMem->offsFifo : NULL; + for (int i=0; iconnFifo[i].mode = map->shared ? NCCL_MODE_OFFSET : NCCL_MODE_NORMAL; + } for (int p=0; pconn.buffs[p] = NCCL_NET_MAP_GET_POINTER(map, gpu, buffs[p]); @@ -510,10 +515,11 @@ static ncclResult_t sharedNetBuffersInit(struct ncclProxyState* proxyState, int return ncclSuccess; } -static ncclResult_t sharedBuffersGet(struct ncclProxyState* proxyState, int channel, int slot, int* offset) { +static ncclResult_t sharedBuffersGet(struct ncclProxyState* proxyState, int channel, int slot, int* offset, int* size) { // Use different pools for different channels and also separate send/recv. int globalSlot = (channel*NCCL_SHARED_STEPS)+slot; *offset = proxyState->p2pChunkSize * globalSlot; + if (size) *size = proxyState->p2pChunkSize; return ncclSuccess; } @@ -752,8 +758,9 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str resources->recvMem = (struct ncclRecvMem*) NCCL_NET_MAP_GET_POINTER(map, cpu, recvMem); // Don't give credits yet in shared mode. - resources->sendMem->head = map->shared ? -NCCL_STEPS : 0; - for (int i=0; irecvMem->sizesFifo[i] = -1; + (resources->gdcSync ? *resources->gdcSync : resources->sendMem->head) = + (map->shared ? -NCCL_STEPS : 0); + for (int i=0; irecvMem->connFifo[i].size = -1; for (int p=0; pbuffers[p] = NCCL_NET_MAP_GET_POINTER(map, cpu, buffs[p]); @@ -1014,6 +1021,7 @@ static ncclResult_t recvProxyFree(struct ncclProxyConnection* connection, struct } static_assert(NCCL_STEPS <= NCCL_NET_MAX_REQUESTS, "Not enough net requests to cover for steps"); +#define MAX_NET_SIZE (1024*1024*1024L) // Rather than send INT_MAX which is 2G-1, send a power of two. static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct ncclProxyArgs* args) { if (args->state == ncclProxyOpReady) { @@ -1022,8 +1030,15 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct struct sendNetResources* resources = (struct sendNetResources*) (sub->connection->transportResources); // Round to next multiple of sliceSteps sub->base = ROUNDUP(resources->step, args->chunkSteps); + // Set step base for next op + resources->step = sub->base + sub->nsteps; sub->posted = sub->transmitted = sub->done = 0; for (uint64_t step=0; stepnsteps; step++) ncclProfilingRecord(args, s, step, ncclProxyProfileBegin); + if (sub->reg && sub->nbytes > 0) { + NCCLCHECK(proxyState->ncclNet->regMr(resources->netSendComm, sub->buffer, sub->nbytes, NCCL_PTR_CUDA, &sub->mhandle)); + } else { + sub->mhandle = resources->mhandles[args->protocol]; + } } args->state = ncclProxyOpProgress; } @@ -1035,23 +1050,24 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct struct ncclProxySubArgs* sub = args->subs+s; if (sub->done == sub->nsteps) continue; struct sendNetResources* resources = (struct sendNetResources*) (sub->connection->transportResources); - void* mhandle = resources->mhandles[p]; + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); - int buffSize = stepSize*args->sliceSteps; - if (sub->nbytes < buffSize) buffSize = sub->nbytes; // Post buffers to the GPU if (sub->posted < sub->nsteps && sub->posted < sub->done + maxDepth) { int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; if (resources->shared) { - int sharedBuffSlot = sub->posted%maxDepth; - int offset; - NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s, &offset)); - resources->recvMem->offsFifo[buffSlot] = offset; - __sync_synchronize(); + if (!sub->reg) { + int sharedBuffSlot = sub->posted%maxDepth; + int offset; + NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s, &offset, NULL)); + resources->recvMem->connFifo[buffSlot].offset = offset; + __sync_synchronize(); + } volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; sub->posted += args->sliceSteps; - *sendHead = sub->base + sub->posted - NCCL_STEPS; + // Only post one credit for registered buffer + if (sub->reg == 0 || sub->posted == args->sliceSteps) *sendHead = sub->base + sub->posted - NCCL_STEPS; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } else sub->posted += args->sliceSteps; for (uint64_t step=sub->posted-args->sliceSteps; stepposted; step++) { @@ -1063,13 +1079,13 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct // Check whether we received data from the GPU and send it to the network if (sub->transmitted < sub->posted && sub->transmitted < sub->done + NCCL_STEPS) { int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; - volatile int* sizesFifo = resources->recvMem->sizesFifo; volatile uint64_t* recvTail = &resources->recvMem->tail; - if (sizesFifo[buffSlot] != -1 && ((*recvTail > (sub->base+sub->transmitted)) || p == NCCL_PROTO_LL)) { + uint64_t tail = sub->base + (sub->reg ? 0 : sub->transmitted); + if ((sub->reg || connFifo[buffSlot].size != -1) && ((*recvTail > tail) || p == NCCL_PROTO_LL)) { // We have something to receive, let's check if it's completely ready. - int size = sizesFifo[buffSlot]; + int size = sub->reg ? std::min(MAX_NET_SIZE, sub->nbytes) : connFifo[buffSlot].size; bool shared = (p == NCCL_PROTO_SIMPLE) && resources->shared; - char* buff = shared ? localBuff+resources->recvMem->offsFifo[buffSlot] : localBuff+buffSlot*stepSize; + char* buff = shared ? localBuff+connFifo[buffSlot].offset : localBuff+buffSlot*stepSize; int ready = 1; if (p == NCCL_PROTO_LL128) { ready = resources->useGdr; @@ -1077,7 +1093,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct // When data is in sysmem, we need to wait until all flags are correct since the GPU only // called threadfence() uint64_t flag = sub->base+sub->transmitted+1; - int nFifoLines = DIVUP(sizesFifo[buffSlot], sizeof(uint64_t)*NCCL_LL128_LINEELEMS); + int nFifoLines = DIVUP(connFifo[buffSlot].size, sizeof(uint64_t)*NCCL_LL128_LINEELEMS); volatile uint64_t* lines = (volatile uint64_t*)buff; ready = 1; for (int i=0; ishared) { + buff = sub->reg ? (char*)sub->buffer : localBuff+resources->recvMem->connFifo[buffSlot].offset; } if (ready) { // Data is ready, try to send. - NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, mhandle, sub->requests+buffSlot)); + NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, sub->mhandle, sub->requests+buffSlot)); if (sub->requests[buffSlot] != NULL) { - TRACE(NCCL_NET, "sendProxy [%ld/%d] Isend posted, req %p", sub->transmitted, buffSlot, sub->requests[buffSlot]); - sizesFifo[buffSlot] = -1; - // Make sure size is reset to zero before we update the head. - __sync_synchronize(); + TRACE(NCCL_NET, "sendProxy [%ld/%d] Isend posted, req %p, size %d, proto %d, myRank %d, channelId %d", sub->transmitted, buffSlot, sub->requests[buffSlot], size, p, proxyState->tpRank, sub->channelId); sub->transmitted += args->sliceSteps; for (uint64_t step=sub->transmitted-args->sliceSteps; steptransmitted; step++) ncclProfilingRecord(args, s, step, ncclProxyProfileSendWait); args->idle = 0; @@ -1113,21 +1128,43 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct // Check whether the network has completed some send operations. if (sub->done < sub->transmitted) { int done; + int size; int buffSlot = (sub->base+sub->done)%NCCL_STEPS; - NCCLCHECK(proxyState->ncclNet->test(sub->requests[buffSlot], &done, NULL)); + NCCLCHECK(proxyState->ncclNet->test(sub->requests[buffSlot], &done, &size)); if (done) { + if (sub->reg) { + if (size < sub->nbytes) { + sub->buffer = ((char*)sub->buffer)+size; + sub->nbytes -= size; + // Do one more step (at least) + sub->nsteps++; + } else { + // Signal the GPU the send is complete and it can return. + connFifo[sub->base%NCCL_STEPS].size = -1; + } + } + // Make sure size is reset to -1 before we update the head. + if (sub->reg == 0) connFifo[buffSlot].size = -1; + __sync_synchronize(); TRACE(NCCL_NET, "sendProxy [%ld/%d] request %p done", sub->done, buffSlot, sub->requests[buffSlot]); sub->done += args->sliceSteps; for (uint64_t step=sub->done-args->sliceSteps; stepdone; step++) ncclProfilingRecord(args, s, step, ncclProxyProfileEnd); if (resources->shared == 0) { volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; - *sendHead = sub->base + sub->done; + if (sub->reg) { + // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU. + if (sub->done == sub->nsteps) *sendHead = sub->base + args->sliceSteps; + } else { + *sendHead = sub->base + sub->done; + } if (resources->gdcSync) wc_store_fence(); // Flush out WC write } args->idle = 0; if (sub->done == sub->nsteps) { - resources->step = sub->base + sub->nsteps; + if (sub->reg && sub->nbytes > 0) { + NCCLCHECK(proxyState->ncclNet->deregMr(resources->netSendComm, sub->mhandle)); + } args->done++; } } @@ -1171,9 +1208,17 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct recvComm = resources->netRecvComm; // Round to next multiple of sliceSteps sub->base = ROUNDUP(resources->step, args->chunkSteps); + // Set step base for next op + resources->step = sub->base + sub->nsteps; sub->posted = sub->received = sub->transmitted = sub->done = 0; for (int i=0; insteps; step++) ncclProfilingRecord(args, s, step, ncclProxyProfileBegin); + if (sub->reg && sub->nbytes > 0) { + // Register buffer + NCCLCHECK(proxyState->ncclNet->regMr(resources->netRecvComm, sub->buffer, sub->nbytes, NCCL_PTR_CUDA, &sub->mhandle)); + } else { + sub->mhandle = resources->mhandles[args->protocol]; + } } args->state = ncclProxyOpProgress; } @@ -1188,29 +1233,37 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct int sizes[NCCL_PROXY_MAX_SUBS]; int tags[NCCL_PROXY_MAX_SUBS]; void* mhandles[NCCL_PROXY_MAX_SUBS]; - for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; if (sub->posted < sub->nsteps) { if (sub->posted >= sub->done + maxDepth) { subCount = 0; break; } struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); + if (sub->reg) maxDepth = 1; int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; if (p == NCCL_PROTO_SIMPLE && resources->shared) { - int sharedBuffSlot = sub->posted%maxDepth; - int offset; - NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s+i, &offset)); - volatile int* offsFifo = (volatile int*)resources->recvMem->offsFifo; - offsFifo[buffSlot] = offset; - ptrs[subCount] = localBuff+offset; + if (sub->reg) { + // Wait until CUDA kernel has started before we access the user buffer directly. + if (connFifo[sub->base%NCCL_STEPS].size == -1) continue; + ptrs[subCount] = sub->buffer; + sizes[subCount] = std::min(MAX_NET_SIZE, sub->nbytes); + } else { + int sharedBuffSlot = sub->posted%maxDepth; + int offset; + NCCLCHECK(sharedBuffersGet(proxyState, sub->channelId, sharedBuffSlot*args->nsubs+s+i, &offset, sizes+subCount)); + connFifo[buffSlot].offset = offset; + ptrs[subCount] = localBuff+offset; + } } else { ptrs[subCount] = localBuff+buffSlot*stepSize; + sizes[subCount] = stepSize*args->sliceSteps; } sizes[subCount] = stepSize*args->sliceSteps; if (sub->nbytes < sizes[subCount]) sizes[subCount] = sub->nbytes; tags[subCount] = resources->tpRemoteRank; - mhandles[subCount] = resources->mhandles[p]; + mhandles[subCount] = sub->mhandle; subCount++; } } @@ -1246,9 +1299,27 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (done) { int needFlush = 0; int totalSize = 0; + int subIndex = 0; for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; + if (sub->received < sub->nsteps) { + int size = sizes[subIndex++]; + if (sub->reg) { + if (size < sub->nbytes) { + sub->buffer = ((char*)sub->buffer) + size; + sub->nbytes -= size; + // Do one more step (at least) + sub->nsteps++; + } else { + // Reset connFifo size indicating the GPU was ready to receive. + // There is a __sync_synchronize() later to ensure it is reset before it is set again by the GPU. + struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); + volatile struct ncclConnFifo* connFifo = (volatile struct ncclConnFifo*)resources->recvMem->connFifo; + connFifo[sub->base%NCCL_STEPS].size = -1; + } + } + } sub->received += args->sliceSteps; for (uint64_t step=sub->received-args->sliceSteps; stepreceived; step++) ncclProfilingRecord(args, s+i, step, ncclProxyProfileRecvFlushWait); if (step < sub->nsteps) { @@ -1276,9 +1347,11 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); int stepSize = resources->buffSizes[p] / NCCL_STEPS; char* localBuff = NCCL_NET_MAP_GET_POINTER(&resources->map, cpu, buffs[p]); - int buffSlot = (sub->base+sub->posted)%NCCL_STEPS; - ptrs[subCount] = resources->shared ? localBuff+resources->recvMem->offsFifo[buffSlot] : localBuff+buffSlot*stepSize; - mhandles[subCount] = resources->mhandles[p]; + int buffSlot = (sub->base+sub->received-args->sliceSteps)%NCCL_STEPS; + ptrs[subCount] = resources->shared ? + (sub->reg ? sub->buffer : localBuff+resources->recvMem->connFifo[buffSlot].offset) : + localBuff+buffSlot*stepSize; + mhandles[subCount] = sub->mhandle; subCount++; } } @@ -1302,13 +1375,18 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (done) { for (int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; + sub->transmitted += args->sliceSteps; for (uint64_t step=sub->transmitted-args->sliceSteps; steptransmitted; step++) ncclProfilingRecord(args, s+i, step, ncclProxyProfileRecvGPUWait); if (step < sub->nsteps) { __sync_synchronize(); struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* recvTail = resources->gdcSync ? resources->gdcSync : &resources->recvMem->tail; - *recvTail = sub->base + sub->transmitted; + if (sub->reg) { + // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU. + if (sub->transmitted == sub->nsteps) *recvTail = sub->base + args->sliceSteps; + } else + *recvTail = sub->base + sub->transmitted; if (resources->gdcSync) wc_store_fence(); // Flush out WC write } } @@ -1326,7 +1404,7 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (sub->transmitted > sub->done) { struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* sendHead = &resources->sendMem->head; - uint64_t done = *sendHead; + uint64_t done = sub->reg ? sub->base + sub->nsteps : *sendHead; while (done > sub->base + sub->done && // LL and LL128 can acknowledge 0-bytes send before they even happen. Don't go past what we transmitted. sub->transmitted > sub->done) { @@ -1341,7 +1419,9 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct args->idle = 0; if (sub->done == sub->nsteps) { struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); - resources->step = sub->base + sub->nsteps; + if (sub->reg && sub->nbytes > 0) { + NCCLCHECK(proxyState->ncclNet->deregMr(resources->netRecvComm, sub->mhandle)); + } args->done++; break; } diff --git a/projects/rccl/src/transport/net_ib.cc b/projects/rccl/src/transport/net_ib.cc index 8d4313dddc..43b110956d 100644 --- a/projects/rccl/src/transport/net_ib.cc +++ b/projects/rccl/src/transport/net_ib.cc @@ -31,7 +31,7 @@ static union ncclSocketAddress ncclIbIfAddr; struct ncclIbMr { uintptr_t addr; - int pages; + size_t pages; int refs; ibv_mr *mr; }; @@ -41,12 +41,22 @@ struct ncclIbMrCache { int capacity, population; }; +static int ncclNMergedIbDevs = -1; +#define NCCL_IB_MAX_DEVS_PER_NIC 2 +#define MAX_MERGED_DEV_NAME (MAXNAMESIZE*NCCL_IB_MAX_DEVS_PER_NIC)+NCCL_IB_MAX_DEVS_PER_NIC +struct alignas(64) ncclIbMergedDev { + int ndevs; + int devs[NCCL_IB_MAX_DEVS_PER_NIC]; // Points to an index in ncclIbDevs + int speed; + char devName[MAX_MERGED_DEV_NAME]; // Up to NCCL_IB_MAX_DEVS_PER_NIC * name size, and a character for each '+' +}; + static int ncclNIbDevs = -1; struct alignas(64) ncclIbDev { pthread_mutex_t lock; int device; uint64_t guid; - uint8_t port; + uint8_t portNum; uint8_t link; int speed; ibv_context* context; @@ -58,17 +68,12 @@ struct alignas(64) ncclIbDev { int maxQp; struct ncclIbMrCache mrCache; int ar; // ADAPTIVE_ROUTING -}; - -#define MAX_IB_PORT 15 -struct userIbDev { - char devName[MAXNAMESIZE]; - uint16_t port_en; + struct ibv_port_attr portAttr; }; #define MAX_IB_DEVS 32 +struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_DEVS]; struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; -struct userIbDev userIbDevs[MAX_IB_DEVS]; pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; static int ncclIbRelaxedOrderingEnabled = 0; @@ -85,14 +90,14 @@ NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); pthread_t ncclIbAsyncThread; static void* ncclIbAsyncThreadMain(void* args) { - struct ibv_context* context = (struct ibv_context*)args; + struct ncclIbDev* dev = (struct ncclIbDev*)args; while (1) { struct ibv_async_event event; - if (ncclSuccess != wrap_ibv_get_async_event(context, &event)) { break; } + if (ncclSuccess != wrap_ibv_get_async_event(dev->context, &event)) { break; } char *str; if (ncclSuccess != wrap_ibv_event_type_str(&str, event.event_type)) { break; } if (event.event_type != IBV_EVENT_COMM_EST) - WARN("NET/IB : Got async event : %s", str); + WARN("NET/IB : %s:%d Got async event : %s", dev->devName, dev->portNum, str); if (ncclSuccess != wrap_ibv_ack_async_event(&event)) { break; } } return NULL; @@ -100,6 +105,7 @@ static void* ncclIbAsyncThreadMain(void* args) { NCCL_PARAM(IbDisable, "IB_DISABLE", 0); NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1); +NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1); static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) { char devicePath[PATH_MAX]; @@ -156,6 +162,25 @@ static int ncclIbRelaxedOrderingCapable(void) { return r == ncclInternalError ? 0 : 1; } +// Compare ncclIbDev[dev] to all stored mergedIbDevs +int ncclIbFindMatchingDev(int dev) { + for (int i = 0; i < ncclNMergedIbDevs; i++) { + if (ncclIbMergedDevs[i].ndevs < NCCL_IB_MAX_DEVS_PER_NIC) { + int compareDev = ncclIbMergedDevs[i].devs[0]; + if (strcmp(ncclIbDevs[dev].pciPath, ncclIbDevs[compareDev].pciPath) == 0 && + (ncclIbDevs[dev].guid == ncclIbDevs[compareDev].guid) && + (ncclIbDevs[dev].link == ncclIbDevs[compareDev].link)) { + TRACE(NCCL_NET, "NET/IB: Matched name1=%s pciPath1=%s guid1=0x%lx link1=%u name2=%s pciPath2=%s guid2=0x%lx link2=%u", + ncclIbDevs[dev].devName, ncclIbDevs[dev].pciPath, ncclIbDevs[dev].guid, ncclIbDevs[dev].link, + ncclIbDevs[compareDev].devName, ncclIbDevs[compareDev].pciPath, ncclIbDevs[compareDev].guid, ncclIbDevs[compareDev].link); + return i; + } + } + } + + return ncclNMergedIbDevs; +} + ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { if (ncclParamIbDisable()) return ncclInternalError; static int shownIbHcaEnv = 0; @@ -166,6 +191,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { wrap_ibv_fork_init(); if (ncclNIbDevs == -1) { ncclNIbDevs = 0; + ncclNMergedIbDevs = 0; if (ncclFindInterfaces(ncclIbIfName, &ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) { WARN("NET/IB : No IP interface found."); return ncclInternalError; @@ -201,10 +227,10 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; } continue; } - for (int port = 1; port <= devAttr.phys_port_cnt; port++) { + for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) { struct ibv_port_attr portAttr; - if (ncclSuccess != wrap_ibv_query_port(context, port, &portAttr)) { - WARN("NET/IB : Unable to query port %d", port); + if (ncclSuccess != wrap_ibv_query_port(context, port_num, &portAttr)) { + WARN("NET/IB : Unable to query port_num %d", port_num); continue; } if (portAttr.state != IBV_PORT_ACTIVE) continue; @@ -212,15 +238,13 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) continue; // check against user specified HCAs/ports - if (! (matchIfList(devices[d]->name, port, userIfs, nUserIfs, searchExact) ^ searchNot)) { + if (! (matchIfList(devices[d]->name, port_num, userIfs, nUserIfs, searchExact) ^ searchNot)) { continue; } - TRACE(NCCL_INIT|NCCL_NET,"NET/IB: [%d] %s:%d/%s ", d, devices[d]->name, port, - portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); pthread_mutex_init(&ncclIbDevs[ncclNIbDevs].lock, NULL); ncclIbDevs[ncclNIbDevs].device = d; ncclIbDevs[ncclNIbDevs].guid = devAttr.sys_image_guid; - ncclIbDevs[ncclNIbDevs].port = port; + ncclIbDevs[ncclNIbDevs].portNum = port_num; ncclIbDevs[ncclNIbDevs].link = portAttr.link_layer; ncclIbDevs[ncclNIbDevs].speed = ncclIbSpeed(portAttr.active_speed) * ncclIbWidth(portAttr.active_width); ncclIbDevs[ncclNIbDevs].context = context; @@ -238,9 +262,36 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { ncclIbDevs[ncclNIbDevs].ar = (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) ? 1 : 0; if (ncclParamIbAdaptiveRouting() != -2) ncclIbDevs[ncclNIbDevs].ar = ncclParamIbAdaptiveRouting(); - pthread_create(&ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, context); + TRACE(NCCL_NET,"NET/IB: [%d] %s:%s:%d/%s speed=%d context=%p pciPath=%s ar=%d", d, devices[d]->name, devices[d]->dev_name, ncclIbDevs[ncclNIbDevs].portNum, + portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", ncclIbDevs[ncclNIbDevs].speed, context, ncclIbDevs[ncclNIbDevs].pciPath, ncclIbDevs[ncclNIbDevs].ar); + + pthread_create(&ncclIbAsyncThread, NULL, ncclIbAsyncThreadMain, ncclIbDevs + ncclNIbDevs); ncclSetThreadName(ncclIbAsyncThread, "NCCL IbAsync %2d", ncclNIbDevs); pthread_detach(ncclIbAsyncThread); // will not be pthread_join()'d + + int mergedDev = ncclNMergedIbDevs; + if (ncclParamIbMergeNics()) { + mergedDev = ncclIbFindMatchingDev(ncclNIbDevs); + } + + // No matching dev found, create new mergedDev entry (it's okay if there's only one dev inside) + if (mergedDev == ncclNMergedIbDevs) { + // Set ndevs to 1, assign first ibDevN to the current IB device + ncclIbMergedDevs[mergedDev].ndevs = 1; + ncclIbMergedDevs[mergedDev].devs[0] = ncclNIbDevs; + ncclNMergedIbDevs++; + strncpy(ncclIbMergedDevs[mergedDev].devName, ncclIbDevs[ncclNIbDevs].devName, MAXNAMESIZE); + // Matching dev found, edit name + } else { + // Set next device in this array to the current IB device + int ndevs = ncclIbMergedDevs[mergedDev].ndevs; + ncclIbMergedDevs[mergedDev].devs[ndevs] = ncclNIbDevs; + ncclIbMergedDevs[mergedDev].ndevs++; + snprintf(ncclIbMergedDevs[mergedDev].devName + strlen(ncclIbMergedDevs[mergedDev].devName), MAXNAMESIZE+1, "+%s", ncclIbDevs[ncclNIbDevs].devName); + } + + // Aggregate speed + ncclIbMergedDevs[mergedDev].speed += ncclIbDevs[ncclNIbDevs].speed; ncclNIbDevs++; nPorts++; } @@ -251,15 +302,30 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { if (ncclNIbDevs == 0) { INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found."); } else { - char line[1024]; + char line[2048]; line[0] = '\0'; // Determine whether RELAXED_ORDERING is enabled and possible ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); - for (int d=0; dndevs > 1) { + // Print out merged dev info + snprintf(line+strlen(line), 2047-strlen(line), " [%d]={", d); + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDev = mergedDev->devs[i]; + snprintf(line+strlen(line), 2047-strlen(line), "[%d] %s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, + ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", + // Insert comma to delineate + i == (mergedDev->ndevs - 1) ? "" : ", "); + } + snprintf(line+strlen(line), 2047-strlen(line), "}"); + } else { + int ibDev = mergedDev->devs[0]; + snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s", ibDev, ncclIbDevs[ibDev].devName, + ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); + } } - line[1023] = '\0'; + line[2047] = '\0'; char addrline[SOCKET_NAME_MAXLEN+1]; INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); @@ -270,7 +336,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { } ncclResult_t ncclIbDevices(int* ndev) { - *ndev = ncclNIbDevs; + *ndev = ncclNMergedIbDevs; return ncclSuccess; } @@ -278,7 +344,7 @@ ncclResult_t ncclIbDevices(int* ndev) { // Returns : // ncclSuccess : GDR works // ncclSystemError : no module or module loaded but not supported by GPU -ncclResult_t ncclIbGdrSupport(int ibDev) { +ncclResult_t ncclIbGdrSupport() { static int moduleLoaded = -1; if (moduleLoaded == -1) { // Check for the nv_peer_mem module being loaded @@ -300,13 +366,19 @@ ncclResult_t ncclIbDmaBufSupport(int dev) { ncclResult_t res; struct ibv_pd* pd; struct ibv_context* ctx; - ctx = ncclIbDevs[dev].context; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); - // Test kernel DMA-BUF support with a dummy call (fd=-1) - (void) wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL/*offset*/, 0ULL/*len*/, 0ULL/*iova*/, -1/*fd*/, 0/*flags*/); - // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) - dmaBufSupported = (errno != EOPNOTSUPP && errno != EPROTONOSUPPORT) ? 1 : 0; - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + dev; + + // Test each dev + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDev = mergedDev->devs[i]; + ctx = ncclIbDevs[ibDev].context; + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void) wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL/*offset*/, 0ULL/*len*/, 0ULL/*iova*/, -1/*fd*/, 0/*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) + dmaBufSupported = (errno != EOPNOTSUPP && errno != EPROTONOSUPPORT) ? 1 : 0; + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + } } if (dmaBufSupported == 0) return ncclSystemError; return ncclSuccess; @@ -318,20 +390,25 @@ failure: #define NCCL_NET_IB_MAX_RECVS 8 ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { - props->name = ncclIbDevs[dev].devName; - props->pciPath = ncclIbDevs[dev].pciPath; - props->guid = ncclIbDevs[dev].guid; + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs+dev; + props->name = mergedDev->devName; + props->speed = mergedDev->speed; + + // Take the rest of the properties from an arbitrary sub-device (should be the same) + struct ncclIbDev* ibDev = ncclIbDevs + mergedDev->devs[0]; + props->pciPath = ibDev->pciPath; + props->guid = ibDev->guid; props->ptrSupport = NCCL_PTR_HOST; - if (ncclIbGdrSupport(dev) == ncclSuccess) { + if (ncclIbGdrSupport() == ncclSuccess) { props->ptrSupport |= NCCL_PTR_CUDA; // GDR support via nv_peermem } + props->regIsGlobal = 1; if (ncclIbDmaBufSupport(dev) == ncclSuccess) { props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF } - props->speed = ncclIbDevs[dev].speed; props->latency = 0; // Not set - props->port = ncclIbDevs[dev].port + ncclIbDevs[dev].realPort; - props->maxComms = ncclIbDevs[dev].maxQp; + props->port = ibDev->portNum + ibDev->realPort; + props->maxComms = ibDev->maxQp; props->maxRecvs = NCCL_NET_IB_MAX_RECVS; props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; @@ -344,24 +421,39 @@ static_assert(MAX_REQUESTS <= 256, "request id are encoded in wr_id and we need #define NCCL_IB_MAX_QPS 128 +// Per-QP connection metatdata struct ncclIbQpInfo { - uint32_t lid; - uint8_t ib_port; - uint8_t link_layer; - uint32_t qpn[NCCL_IB_MAX_QPS]; + uint32_t qpn; // Fields needed for ece (enhanced connection establishment) - struct ibv_ece ece[NCCL_IB_MAX_QPS]; - int ece_supported[NCCL_IB_MAX_QPS]; + struct ibv_ece ece; + int ece_supported; + int devIndex; +}; + +// Per-Dev connection metadata +struct ncclIbDevInfo { + uint32_t lid; + uint8_t ib_port; + enum ibv_mtu mtu; + uint8_t link_layer; // For RoCE uint64_t spn; uint64_t iid; - enum ibv_mtu mtu; // FIFO RDMA info uint32_t fifoRkey; + union ibv_gid remoteGid; +}; + +// Struct containing everything needed to establish connections +struct ncclIbConnectionMetadata { + struct ncclIbQpInfo qpInfo[NCCL_IB_MAX_QPS]; + struct ncclIbDevInfo devs[NCCL_IB_MAX_DEVS_PER_NIC]; + char devName[MAX_MERGED_DEV_NAME]; uint64_t fifoAddr; + int ndevs; }; enum ncclIbCommState { @@ -388,11 +480,10 @@ struct ncclIbHandle { struct ncclIbCommStage stage; // Used by the other side when connecting }; -// Retain local and remote RoCE addresses for error logging +// Retain local RoCE address for error logging struct ncclIbGidInfo { uint8_t link_layer; union ibv_gid localGid; - union ibv_gid remoteGid; }; #define NCCL_NET_IB_REQ_UNUSED 0 @@ -402,31 +493,31 @@ struct ncclIbGidInfo { const char* reqTypeStr[] = { "Unused", "Send", "Recv", "Flush" }; struct ncclIbRequest { - struct ncclIbVerbs* verbs; + struct ncclIbNetCommBase* base; int type; - int events; struct ncclSocket* sock; - struct ncclIbGidInfo* gidInfo; + int events[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbNetCommDevBase* devBases[NCCL_IB_MAX_DEVS_PER_NIC]; int nreqs; union { struct { int size; void* data; - uint32_t lkey; + uint32_t lkeys[NCCL_IB_MAX_DEVS_PER_NIC]; int offset; } send; struct { - int sizes[NCCL_NET_IB_MAX_RECVS]; + int* sizes; } recv; }; }; -struct ncclIbVerbs { - int dev; - struct ibv_pd* pd; // duplicate of ncclIbDevs[dev].pd +struct ncclIbNetCommDevBase { + int ibDevN; + struct ibv_pd* pd; struct ibv_cq* cq; uint64_t pad[1]; - struct ncclIbRequest reqs[MAX_REQUESTS]; + struct ncclIbGidInfo gidInfo; }; struct ncclIbListenComm { @@ -438,108 +529,157 @@ struct ncclIbListenComm { struct ncclIbSendFifo { uint64_t addr; int size; - uint32_t rkey; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; uint32_t nreqs; uint32_t tag; uint64_t idx; + char padding[24]; +}; + +struct ncclIbQp { + struct ibv_qp* qp; + int devIndex; + int remDevIdx; +}; + +struct ncclIbRemSizesFifo { + int elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + uint64_t fifoTail; + uint64_t addr; + uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; + uint32_t flags; + struct ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ibv_sge sge; +}; + +// A per-dev struct for netIbSendComm +struct alignas(8) ncclIbSendCommDev { + struct ncclIbNetCommDevBase base; + struct ibv_mr* fifoMr; +}; + + +// Wrapper to track an MR per-device, if needed +struct ncclIbMrHandle { + ibv_mr* mrs[NCCL_IB_MAX_DEVS_PER_NIC]; +}; + +struct alignas(32) ncclIbNetCommBase { + int ndevs; + bool isSend; + struct ncclIbRequest reqs[MAX_REQUESTS]; + struct ncclIbQp qps[NCCL_IB_MAX_QPS]; + int nqps; + int qpIndex; + int devIndex; + struct ncclSocket sock; + int ready; + // Track necessary remDevInfo here + int nRemDevs; + struct ncclIbDevInfo remDevs[NCCL_IB_MAX_DEVS_PER_NIC]; }; struct ncclIbSendComm { - struct ncclIbVerbs verbs; + struct ncclIbNetCommBase base; struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; - uint64_t fifoHead; + // Each dev correlates to a mergedIbDev + struct ncclIbSendCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRequest* fifoReqs[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; - struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS+1]; struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS]; - struct ncclSocket sock; - - int ready; - struct ibv_qp* qps[NCCL_IB_MAX_QPS]; - int nqps; - int qpIndex; - struct ibv_mr* fifoMr; - int ar; - struct ncclIbGidInfo gidInfo; + struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS+1]; + struct ncclIbRemSizesFifo remSizesFifo; + uint64_t fifoHead; + int ar; // Use adaptive routing when all merged devices have it enabled }; // The SendFifo needs to be 32-byte aligned and each element needs // to be a 32-byte multiple, so that an entry does not get split and // written out of order when IB Relaxed Ordering is enabled +static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset"); static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples"); +static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned"); +static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned"); struct ncclIbGpuFlush { - int enabled; - int hostMem; struct ibv_mr* hostMr; struct ibv_sge sge; - struct ibv_qp* qp; + struct ncclIbQp qp; }; struct ncclIbRemFifo { struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; uint64_t fifoTail; uint64_t addr; - uint32_t rkey; uint32_t flags; - struct ibv_mr* mr; - struct ibv_sge sge; +}; + +struct alignas(16) ncclIbRecvCommDev { + struct ncclIbNetCommDevBase base; + struct ncclIbGpuFlush gpuFlush; + uint32_t fifoRkey; + struct ibv_mr* fifoMr; + struct ibv_sge fifoSge; + struct ibv_mr* sizesFifoMr; }; struct ncclIbRecvComm { - struct ncclIbVerbs verbs; + struct ncclIbNetCommBase base; + struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRemFifo remFifo; - struct ncclSocket sock; - int ready; - struct ibv_qp* qps[NCCL_IB_MAX_QPS]; - int nqps; - int qpIndex; - struct ncclIbGpuFlush gpuFlush; - struct ncclIbGidInfo gidInfo; + int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + int gpuFlushHostMem; + int flushEnabled; }; -static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); +static_assert((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbRecvComm fifo must be 32-byte aligned"); NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); -ncclResult_t ncclIbInitVerbs(int dev, struct ibv_context* ctx, struct ncclIbVerbs* verbs) { - verbs->dev = dev; +static void ncclIbAddEvent(struct ncclIbRequest* req, int devIndex, struct ncclIbNetCommDevBase* base) { + req->events[devIndex]++; + req->devBases[devIndex] = base; +} - pthread_mutex_lock(&ncclIbDevs[dev].lock); - if (0 == ncclIbDevs[dev].pdRefs++) { +ncclResult_t ncclIbInitCommDevBase(int ibDevN, struct ncclIbNetCommDevBase* base) { + base->ibDevN = ibDevN; + ncclIbDev* ibDev = ncclIbDevs + ibDevN; + pthread_mutex_lock(&ibDev->lock); + if (0 == ibDev->pdRefs++) { ncclResult_t res; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&ncclIbDevs[dev].pd, ctx), res, failure); + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&ibDev->pd, ibDev->context), res, failure); if (0) { failure: - pthread_mutex_unlock(&ncclIbDevs[dev].lock); + pthread_mutex_unlock(&ibDev->lock); return res; } } - verbs->pd = ncclIbDevs[dev].pd; - pthread_mutex_unlock(&ncclIbDevs[dev].lock); + base->pd = ibDev->pd; + pthread_mutex_unlock(&ibDev->lock); // Recv requests can generate 2 completions (one for the post FIFO, one for the Recv). - NCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0)); + NCCLCHECK(wrap_ibv_create_cq(&base->cq, ibDev->context, 2*MAX_REQUESTS*ncclParamIbQpsPerConn(), NULL, NULL, 0)); + return ncclSuccess; } -ncclResult_t ncclIbDestroyVerbs(struct ncclIbVerbs* verbs) { +ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) { ncclResult_t res; - NCCLCHECK(wrap_ibv_destroy_cq(verbs->cq)); + NCCLCHECK(wrap_ibv_destroy_cq(base->cq)); - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); - if (0 == --ncclIbDevs[verbs->dev].pdRefs) { - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(ncclIbDevs[verbs->dev].pd), res, returning); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); + if (0 == --ncclIbDevs[base->ibDevN].pdRefs) { + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(ncclIbDevs[base->ibDevN].pd), res, returning); } res = ncclSuccess; returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); return res; } -ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int access_flags, struct ibv_qp** qp) { +ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, int access_flags, struct ncclIbQp* qp) { struct ibv_qp_init_attr qpInitAttr; memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr)); - qpInitAttr.send_cq = verbs->cq; - qpInitAttr.recv_cq = verbs->cq; + qpInitAttr.send_cq = base->cq; + qpInitAttr.recv_cq = base->cq; qpInitAttr.qp_type = IBV_QPT_RC; // We might send 2 messages per send (RDMA and RDMA_WITH_IMM) qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS; @@ -547,23 +687,23 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbVerbs* verbs, int acce qpInitAttr.cap.max_send_sge = 1; qpInitAttr.cap.max_recv_sge = 1; qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0; - NCCLCHECK(wrap_ibv_create_qp(qp, verbs->pd, &qpInitAttr)); + NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr)); struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_INIT; qpAttr.pkey_index = ncclParamIbPkey(); qpAttr.port_num = ib_port; qpAttr.qp_access_flags = access_flags; - NCCLCHECK(wrap_ibv_modify_qp(*qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t qpn, struct ncclIbQpInfo* info) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) { struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_RTR; qpAttr.path_mtu = info->mtu; - qpAttr.dest_qp_num = qpn; + qpAttr.dest_qp_num = dest_qp_num; qpAttr.rq_psn = 0; qpAttr.max_dest_rd_atomic = 1; qpAttr.min_rnr_timer = 12; @@ -631,110 +771,183 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet } NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); - NCCLCHECK(ncclSocketInit(&comm->sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1)); stage->comm = comm; stage->state = ncclIbCommStateConnect; - NCCLCHECK(ncclSocketConnect(&comm->sock)); + NCCLCHECK(ncclSocketConnect(&comm->base.sock)); ib_connect_check: /* since ncclSocketConnect is async, we must check if connection is complete */ - NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + NCCLCHECK(ncclSocketReady(&comm->base.sock, &ready)); if (!ready) return ncclSuccess; // IB Setup - struct ibv_context* ctx; - ctx = ncclIbDevs[dev].context; - NCCLCHECK(ncclIbInitVerbs(dev, ctx, &comm->verbs)); - uint8_t ib_port; - ib_port = ncclIbDevs[dev].port; - comm->nqps = ncclParamIbQpsPerConn(); - for (int q=0; qnqps; q++) { - NCCLCHECK(ncclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps+q)); - } - comm->ar = ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING + struct ncclIbMergedDev* mergedDev; + mergedDev = ncclIbMergedDevs + dev; + comm->base.ndevs = mergedDev->ndevs; + comm->base.nqps = ncclParamIbQpsPerConn() * comm->base.ndevs; // We must have at least 1 qp per-device + comm->base.isSend = true; - // Send my QP Info to receiver through the socket. Hope this won't block. - struct ibv_port_attr portAttr; - NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr)); - struct ncclIbQpInfo qpInfo; - qpInfo.ib_port = ib_port; - for (int q=0; qnqps; q++) { - qpInfo.qpn[q] = comm->qps[q]->qp_num; + // Init PD, Ctx for each IB device + comm->ar = 1; // Set to 1 for logic + for (int i = 0; i < mergedDev->ndevs; i++) { + int ibDevN = mergedDev->devs[i]; + NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &comm->devs[i].base)); + comm->ar = comm->ar && ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled + } + + struct ncclIbConnectionMetadata meta; + meta.ndevs = comm->base.ndevs; + + // Alternate QPs between devices + int devIndex; + devIndex = 0; + for (int q = 0; q < comm->base.nqps; q++) { + ncclIbSendCommDev* commDev = comm->devs + devIndex; + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, comm->base.qps+q)); + comm->base.qps[q].devIndex = devIndex; + meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex; // Query ece capabilities (enhanced connection establishment) - NCCLCHECK(wrap_ibv_query_ece(comm->qps[q], &qpInfo.ece[q], &qpInfo.ece_supported[q])); + NCCLCHECK(wrap_ibv_query_ece(comm->base.qps[q].qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); + devIndex = (devIndex + 1) % comm->base.ndevs; } - qpInfo.mtu = portAttr.active_mtu; + for (int i = 0; i < comm->base.ndevs; i++) { + ncclIbSendCommDev* commDev = comm->devs + i; + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + // Send my QP Info to receiver through the socket. Hope this won't block. + // TODO - I thought I queried this in init? + NCCLCHECK(wrap_ibv_query_port(ibDev->context, ibDev->portNum, &ibDev->portAttr)); - // Prepare my fifo - NCCLCHECK(wrap_ibv_reg_mr(&comm->fifoMr, comm->verbs.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); - qpInfo.fifoRkey = comm->fifoMr->rkey; - qpInfo.fifoAddr = (uint64_t)comm->fifo; + // Write to the metadata struct via this pointer + ncclIbDevInfo* devInfo = meta.devs + i; + devInfo->ib_port = ibDev->portNum; + devInfo->mtu = ibDev->portAttr.active_mtu; + devInfo->lid = ibDev->portAttr.lid; - // RoCE support - qpInfo.lid = portAttr.lid; - qpInfo.link_layer = comm->gidInfo.link_layer = portAttr.link_layer; - if (qpInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { - NCCLCHECK(wrap_ibv_query_gid(ncclIbDevs[dev].context, ncclIbDevs[dev].port, ncclParamIbGidIndex(), &comm->gidInfo.localGid)); - qpInfo.spn = comm->gidInfo.localGid.global.subnet_prefix; - qpInfo.iid = comm->gidInfo.localGid.global.interface_id; - } - - if (qpInfo.link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB - for (int q=0; qnqps; q++) - INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ncclIbDevs[dev].port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid); - } else { // RoCE - for (int q=0; qnqps; q++) - INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX)", - dev, ncclIbDevs[dev].port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.ece_supported[q], qpInfo.ece[q].vendor_id, qpInfo.ece[q].options, qpInfo.ece[q].comp_mask, ncclParamIbGidIndex(), - qpInfo.spn, qpInfo.iid); + // Prepare my fifo + NCCLCHECK(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); + devInfo->fifoRkey = commDev->fifoMr->rkey; + + // RoCE support + devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET) { + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid)); + devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id; + } + + if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB + for (int q = 0; q < comm->base.nqps; q++) { + // Print just the QPs for this dev + if (comm->base.qps[q].devIndex == i) + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d fifoRkey=0x%x fifoLkey=0x%x", + comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", + dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->fifoRkey, commDev->fifoMr->lkey); + } + } else { // RoCE + for (int q = 0; q < comm->base.nqps; q++) { + // Print just the QPs for this dev + if (comm->base.qps[q].devIndex == i) + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", + comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(), + devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey); + } + } } + meta.fifoAddr = (uint64_t)comm->fifo; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); stage->state = ncclIbCommStateSend; stage->offset = 0; - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(qpInfo))); - memcpy(stage->buffer, &qpInfo, sizeof(qpInfo)); + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(meta))); + + memcpy(stage->buffer, &meta, sizeof(meta)); ib_send: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, stage->buffer, sizeof(qpInfo), &stage->offset)); - if (stage->offset != sizeof(qpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(meta), &stage->offset)); + if (stage->offset != sizeof(meta)) return ncclSuccess; stage->state = ncclIbCommStateConnecting; stage->offset = 0; // Clear the staging buffer for re-use - memset(stage->buffer, 0, sizeof(qpInfo)); + memset(stage->buffer, 0, sizeof(meta)); ib_connect: - struct ncclIbQpInfo remQpInfo; - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, stage->buffer, sizeof(ncclIbQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + struct ncclIbConnectionMetadata remMeta; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclIbConnectionMetadata), &stage->offset)); + if (stage->offset != sizeof(remMeta)) return ncclSuccess; - memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo)); + memcpy(&remMeta, stage->buffer, sizeof(ncclIbConnectionMetadata)); - comm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn; - comm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid; - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = comm->qps[q]; - if (remQpInfo.ece_supported[q] && qpInfo.ece_supported[q]) - NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo.ece[q], &qpInfo.ece_supported[q])); + comm->base.nRemDevs = remMeta.ndevs; + if (comm->base.nRemDevs != comm->base.ndevs) { + mergedDev = ncclIbMergedDevs + dev; + WARN("NET/IB : Local mergedDev=%s has a different number of devices=%d as remoteDev=%s nRemDevs=%d", + mergedDev->devName, comm->base.ndevs, remMeta.devName, comm->base.nRemDevs); + } - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); + int link_layer; + link_layer = remMeta.devs[0].link_layer; + for (int i = 1; i < remMeta.ndevs; i++) { + if (remMeta.devs[i].link_layer != link_layer) { + WARN("NET/IB : Can't merge net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", + i, remMeta.ndevs, link_layer, remMeta.devs[i].link_layer); + return ncclInternalError; + } + } + + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. + for (int i = 0; i < remMeta.ndevs; i++) { + comm->base.remDevs[i] = remMeta.devs[i]; + comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].iid; + comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].spn; + + // Retain remote sizes fifo info and prepare RDMA ops + comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey; + comm->remSizesFifo.addr = remMeta.fifoAddr; + } + + for (int i=0; i < comm->base.ndevs; i++) { + NCCLCHECK(wrap_ibv_reg_mr(comm->remSizesFifo.mrs+i, comm->devs[i].base.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); + } + comm->base.nRemDevs = remMeta.ndevs; + + for (int q = 0; q < comm->base.nqps; q++) { + struct ncclIbQpInfo* remQpInfo = remMeta.qpInfo + q; + struct ncclIbDevInfo* remDevInfo = remMeta.devs + remQpInfo->devIndex; + + // Assign per-QP remDev + comm->base.qps[q].remDevIdx = remQpInfo->devIndex; + + struct ibv_qp* qp = comm->base.qps[q].qp; + if (remQpInfo->ece_supported && remQpInfo->ece_supported) + NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported)); + + NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo)); NCCLCHECK(ncclIbRtsQp(qp)); } - if (qpInfo.link_layer == IBV_LINK_LAYER_ETHERNET ) { // RoCE - for (int q=0; qnqps; q++) - INFO(NCCL_NET,"NET/IB: Dev %d Port %d qpn %d set_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x}", - dev, ncclIbDevs[dev].port, qpInfo.qpn[q], remQpInfo.ece_supported[q], remQpInfo.ece[q].vendor_id, remQpInfo.ece[q].options, remQpInfo.ece[q].comp_mask); + if (link_layer == IBV_LINK_LAYER_ETHERNET ) { // RoCE + for (int q = 0; q < comm->base.nqps; q++) { + struct ncclIbQp* qp = comm->base.qps + q; + int ibDevN = comm->devs[qp->devIndex].base.ibDevN; + struct ncclIbDev* ibDev = ncclIbDevs + ibDevN; + INFO(NCCL_NET,"NET/IB: IbDev %d Port %d qpn %d set_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x}", + ibDevN, ibDev->portNum, remMeta.qpInfo[q].qpn, remMeta.qpInfo[q].ece_supported, remMeta.qpInfo[q].ece.vendor_id, remMeta.qpInfo[q].ece.options, remMeta.qpInfo[q].ece.comp_mask); + } } - comm->ready = 1; + comm->base.ready = 1; stage->state = ncclIbCommStateConnected; stage->offset = 0; ib_send_ready: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, &comm->ready, sizeof(int), &stage->offset)); + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, &comm->base.ready, sizeof(int), &stage->offset)); if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); @@ -765,118 +978,169 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle NCCLCHECK(ncclIbMalloc((void**)&rComm, sizeof(struct ncclIbRecvComm))); stage->comm = rComm; stage->state = ncclIbCommStateAccept; - NCCLCHECK(ncclSocketInit(&rComm->sock)); - NCCLCHECK(ncclSocketAccept(&rComm->sock, &lComm->sock)); + NCCLCHECK(ncclSocketInit(&rComm->base.sock)); + NCCLCHECK(ncclSocketAccept(&rComm->base.sock, &lComm->sock)); ib_accept_check: - NCCLCHECK(ncclSocketReady(&rComm->sock, &ready)); + NCCLCHECK(ncclSocketReady(&rComm->base.sock, &ready)); if (!ready) return ncclSuccess; - struct ncclIbQpInfo remQpInfo; + struct ncclIbConnectionMetadata remMeta; stage->state = ncclIbCommStateRecv; stage->offset = 0; - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remQpInfo))); + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta))); ib_recv: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, stage->buffer, sizeof(remQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset)); + if (stage->offset != sizeof(remMeta)) return ncclSuccess; /* copy back the received info */ - memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo)); - - rComm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn; - rComm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid; + memcpy(&remMeta, stage->buffer, sizeof(struct ncclIbConnectionMetadata)); // IB setup - struct ibv_context* ctx; - uint8_t ib_port; - ctx = ncclIbDevs[lComm->dev].context; - ib_port = ncclIbDevs[lComm->dev].port; - struct ibv_port_attr portAttr; - NCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr)); - NCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, ncclParamIbGidIndex(), &rComm->gidInfo.localGid)); + // Pre-declare variables because of goto + struct ncclIbMergedDev* mergedDev; + struct ncclIbDev* ibDev; + int ibDevN; + struct ncclIbRecvCommDev* rCommDev; + struct ncclIbDevInfo* remDevInfo; + struct ncclIbQp* qp; - // QP Creation - NCCLCHECK(ncclIbInitVerbs(lComm->dev, ctx, &rComm->verbs)); - rComm->nqps = ncclParamIbQpsPerConn(); - for (int q=0; qnqps; q++) { - NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps+q)); + mergedDev = ncclIbMergedDevs + lComm->dev; + rComm->base.ndevs = mergedDev->ndevs; + rComm->base.nqps = ncclParamIbQpsPerConn() * rComm->base.ndevs; // We must have at least 1 qp per-device + rComm->base.isSend = false; + + rComm->base.nRemDevs = remMeta.ndevs; + if (rComm->base.nRemDevs != rComm->base.ndevs) { + WARN("NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", + mergedDev->devName, rComm->base.ndevs, remMeta.devName, rComm->base.nRemDevs); } - // Adjust the MTU - remQpInfo.mtu = (enum ibv_mtu)std::min(remQpInfo.mtu, portAttr.active_mtu); + // Metadata to send back to requestor (sender) + struct ncclIbConnectionMetadata meta; + for (int i = 0; i < rComm->base.ndevs; i++) { + rCommDev = rComm->devs + i; + ibDevN = mergedDev->devs[i]; + NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base)); + ibDev = ncclIbDevs + ibDevN; + NCCLCHECK(wrap_ibv_query_port(ibDev->context, ibDev->portNum, &ibDev->portAttr)); + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid)); + } - // Setup QP - struct ncclIbQpInfo qpInfo; - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = rComm->qps[q]; + // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. + for (int i = 0; i < remMeta.ndevs; i++) { + rComm->base.remDevs[i] = remMeta.devs[i]; + rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].iid; + rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].spn; + } + + // Stripe QP creation across merged devs + // Make sure to get correct remote peer dev and QP info + int remDevIndex; + int devIndex; + devIndex = 0; + for (int q = 0; q < rComm->base.nqps; q++) { + remDevIndex = remMeta.qpInfo[q].devIndex; + remDevInfo = remMeta.devs + remDevIndex; + qp = rComm->base.qps+q; + rCommDev = rComm->devs + devIndex; + qp->remDevIdx = remDevIndex; + + // Local ibDevN + ibDevN = rComm->devs[devIndex].base.ibDevN; + ibDev = ncclIbDevs + ibDevN; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, qp)); + qp->devIndex = devIndex; + devIndex = (devIndex + 1) % rComm->base.ndevs; // Set the ece (enhanced connection establishment) on this QP before RTR - if (remQpInfo.ece_supported[q]) { - NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo.ece[q], &qpInfo.ece_supported[q])); + if (remMeta.qpInfo[q].ece_supported) { + NCCLCHECK(wrap_ibv_set_ece(qp->qp, &remMeta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); // Query the reduced ece for this QP (matching enhancements between the requestor and the responder) // Store this in our own qpInfo for returning to the requestor - if (qpInfo.ece_supported[q]) { - NCCLCHECK(wrap_ibv_query_ece(qp, &qpInfo.ece[q], &qpInfo.ece_supported[q])); - } + if (meta.qpInfo[q].ece_supported) + NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); } - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); + NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo)); + NCCLCHECK(ncclIbRtsQp(qp->qp)); } - // Retain remote fifo info and prepare my RDMA ops - rComm->remFifo.rkey = remQpInfo.fifoRkey; - rComm->remFifo.addr = remQpInfo.fifoAddr; - NCCLCHECK(wrap_ibv_reg_mr(&rComm->remFifo.mr, rComm->verbs.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); - rComm->remFifo.sge.lkey = rComm->remFifo.mr->lkey; - if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; + rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || ncclIbDmaBufSupport(lComm->dev) == ncclSuccess) + && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; - // Allocate Flush dummy buffer for GPU Direct RDMA - rComm->gpuFlush.enabled = ((ncclIbGdrSupport(lComm->dev) == ncclSuccess || ncclIbDmaBufSupport(lComm->dev) == ncclSuccess) - && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; - if (rComm->gpuFlush.enabled) { - NCCLCHECK(wrap_ibv_reg_mr(&rComm->gpuFlush.hostMr, rComm->verbs.pd, &rComm->gpuFlush.hostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE)); - rComm->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlush.hostMem; - rComm->gpuFlush.sge.length = 1; - rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey; - NCCLCHECK(ncclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp)); - struct ncclIbQpInfo localQpInfo; - localQpInfo.lid=portAttr.lid; - localQpInfo.link_layer=portAttr.link_layer; - localQpInfo.ib_port=ib_port; - localQpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix; - localQpInfo.iid=rComm->gidInfo.localGid.global.interface_id; - localQpInfo.mtu=portAttr.active_mtu; - NCCLCHECK(ncclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo)); - NCCLCHECK(ncclIbRtsQp(rComm->gpuFlush.qp)); + for (int i = 0; i < mergedDev->ndevs; i++) { + rCommDev = rComm->devs + i; + ibDevN = rCommDev->base.ibDevN; + ibDev = ncclIbDevs + ibDevN; + + // Retain remote fifo info and prepare my RDMA ops + rCommDev->fifoRkey = remMeta.devs[i].fifoRkey; + rComm->remFifo.addr = remMeta.fifoAddr; + NCCLCHECK(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ)); + rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; + if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; + + // Allocate Flush dummy buffer for GPU Direct RDMA + if (rComm->flushEnabled) { + NCCLCHECK(wrap_ibv_reg_mr(&rCommDev->gpuFlush.hostMr, rCommDev->base.pd, &rComm->gpuFlushHostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE)); + rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem; + rCommDev->gpuFlush.sge.length = 1; + rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey; + NCCLCHECK(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rCommDev->gpuFlush.qp)); + struct ncclIbDevInfo devInfo; + devInfo.lid = ibDev->portAttr.lid; + devInfo.link_layer = ibDev->portAttr.link_layer; + devInfo.ib_port = ibDev->portNum; + devInfo.spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo.iid = rCommDev->base.gidInfo.localGid.global.interface_id; + devInfo.mtu = ibDev->portAttr.active_mtu; + NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo)); + NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); + } + + // Fill Handle + meta.devs[i].lid = ibDev->portAttr.lid; + meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + meta.devs[i].ib_port = ibDev->portNum; + meta.devs[i].spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + meta.devs[i].iid = rCommDev->base.gidInfo.localGid.global.interface_id; + + // Adjust the MTU + remMeta.devs[i].mtu = (enum ibv_mtu) std::min(remMeta.devs[i].mtu, ibDev->portAttr.active_mtu); + meta.devs[i].mtu = remMeta.devs[i].mtu; + + // Prepare sizes fifo + NCCLCHECK(wrap_ibv_reg_mr(&rComm->devs[i].sizesFifoMr, rComm->devs[i].base.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); + meta.devs[i].fifoRkey = rComm->devs[i].sizesFifoMr->rkey; + } + meta.fifoAddr = (uint64_t)rComm->sizesFifo; + + for (int q = 0; q < rComm->base.nqps; q++) { + meta.qpInfo[q].qpn = rComm->base.qps[q].qp->qp_num; + meta.qpInfo[q].devIndex = rComm->base.qps[q].devIndex; } - // Fill Handle - qpInfo.lid=portAttr.lid; - qpInfo.link_layer= rComm->gidInfo.link_layer = portAttr.link_layer; - qpInfo.ib_port=ib_port; - for (int q=0; qnqps; q++) qpInfo.qpn[q]=rComm->qps[q]->qp_num; - qpInfo.spn=rComm->gidInfo.localGid.global.subnet_prefix; - qpInfo.iid=rComm->gidInfo.localGid.global.interface_id; - qpInfo.mtu=remQpInfo.mtu; + meta.ndevs = rComm->base.ndevs; + strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); stage->state = ncclIbCommStateSend; stage->offset = 0; if (stage->buffer) free(stage->buffer); - NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbQpInfo))); - memcpy(stage->buffer, &qpInfo, sizeof(struct ncclIbQpInfo)); + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbConnectionMetadata))); + memcpy(stage->buffer, &meta, sizeof(struct ncclIbConnectionMetadata)); ib_send: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->sock, stage->buffer, sizeof(struct ncclIbQpInfo), &stage->offset)); - if (stage->offset < sizeof(struct ncclIbQpInfo)) return ncclSuccess; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(struct ncclIbConnectionMetadata), &stage->offset)); + if (stage->offset < sizeof(struct ncclIbConnectionMetadata)) return ncclSuccess; stage->offset = 0; stage->state = ncclIbCommStatePendingReady; ib_recv_ready: - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, &rComm->ready, sizeof(int), &stage->offset)); + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, &rComm->base.ready, sizeof(int), &stage->offset)); if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); @@ -890,14 +1154,15 @@ ib_recv_ready: return ncclSuccess; } -ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** req) { +ncclResult_t ncclIbGetRequest(struct ncclIbNetCommBase* base, struct ncclIbRequest** req) { for (int i=0; ireqs+i; + struct ncclIbRequest* r = base->reqs+i; if (r->type == NCCL_NET_IB_REQ_UNUSED) { - r->verbs = verbs; - r->events = 1; + r->base = base; r->sock = NULL; - r->gidInfo = NULL; + r->devBases[0] = NULL; + r->devBases[1] = NULL; + r->events[0] = r->events[1] = 0; *req = r; return ncclSuccess; } @@ -906,6 +1171,7 @@ ncclResult_t ncclIbGetRequest(struct ncclIbVerbs* verbs, struct ncclIbRequest** *req = NULL; return ncclInternalError; } + ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) { r->type = NCCL_NET_IB_REQ_UNUSED; return ncclSuccess; @@ -913,22 +1179,16 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) { ncclResult_t ncclIbTest(void* request, int* done, int* size); -/* DMA-BUF support */ -ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { - static_assert(offsetof(struct ncclIbSendComm, verbs) == offsetof(struct ncclIbRecvComm, verbs), "Send and recv comms must have verbs at the same offset"); - assert(size > 0); - +ncclResult_t ncclIbRegMrDmaBufInternal(ncclIbNetCommDevBase* base, void* data, size_t size, int type, uint64_t offset, int fd, ibv_mr** mhandle) { static __thread uintptr_t pageSize = 0; if (pageSize == 0) pageSize = sysconf(_SC_PAGESIZE); - - struct ncclIbVerbs* verbs = (struct ncclIbVerbs*)comm; - struct ncclIbMrCache* cache = &ncclIbDevs[verbs->dev].mrCache; + struct ncclIbMrCache* cache = &ncclIbDevs[base->ibDevN].mrCache; uintptr_t addr = (uintptr_t)data & -pageSize; size_t pages = ((uintptr_t)data + size - addr + pageSize-1)/pageSize; ncclResult_t res; - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); for (int slot=0; /*true*/; slot++) { - if (slot == cache->population) { // didn't find in cache + if (slot == cache->population || addr < cache->slots[slot].addr) { // didn't find in cache if (cache->population == cache->capacity) { // must grow cache cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity; NCCLCHECKGOTO(ncclRealloc(&cache->slots, cache->population, cache->capacity), res, returning); @@ -939,47 +1199,71 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui if (ncclIbRelaxedOrderingEnabled) flags |= IBV_ACCESS_RELAXED_ORDERING; if (fd != -1) { /* DMA-BUF support */ - NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, verbs->pd, offset, pages*pageSize, addr, fd, flags), res, returning); + NCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, base->pd, offset, pages*pageSize, addr, fd, flags), res, returning); } else { if (ncclIbRelaxedOrderingEnabled) { // Use IBVERBS_1.8 API - needed for IBV_ACCESS_RELAXED_ORDERING support - NCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, verbs->pd, (void*)addr, pages*pageSize, addr, flags), res, returning); + NCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, base->pd, (void*)addr, pages*pageSize, addr, flags), res, returning); } else { - NCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, verbs->pd, (void*)addr, pages*pageSize, flags), res, returning); + NCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, base->pd, (void*)addr, pages*pageSize, flags), res, returning); } } - TRACE(NCCL_INIT,"regAddr %llx size %lld rkey %x fd %d", (unsigned long long)addr, (long long)pages*pageSize, mr->rkey, fd); - cache->population += 1; + TRACE(NCCL_INIT|NCCL_NET,"regAddr=0x%lx size=%lld rkey=0x%x lkey=0x%x fd=%d", (unsigned long)addr, (long long)pages*pageSize, mr->rkey, mr->lkey, fd); + if (slot != cache->population) memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclIbMr)); cache->slots[slot].addr = addr; cache->slots[slot].pages = pages; cache->slots[slot].refs = 1; cache->slots[slot].mr = mr; - *mhandle = (void*)mr; + cache->population += 1; + *mhandle = mr; res = ncclSuccess; goto returning; - } - else if (cache->slots[slot].addr == addr && cache->slots[slot].pages == pages) { + } else if ((addr >= cache->slots[slot].addr) && + ((addr-cache->slots[slot].addr)/pageSize+pages) <= cache->slots[slot].pages) { cache->slots[slot].refs += 1; - *mhandle = (void*)cache->slots[slot].mr; + *mhandle = cache->slots[slot].mr; res = ncclSuccess; goto returning; } } returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); return res; } -ncclResult_t ncclIbRegMr(void* comm, void* data, int size, int type, void** mhandle) { - return ncclIbRegMrDmaBuf(comm, data, (size_t)size, type, 0ULL, -1, mhandle); +struct ncclIbNetCommDevBase* ncclIbGetNetCommDevBase(ncclIbNetCommBase* base, int devIndex) { + if (base->isSend) { + struct ncclIbSendComm* sComm = (struct ncclIbSendComm*) base; + return &sComm->devs[devIndex].base; + } else { + struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*) base; + return &rComm->devs[devIndex].base; + } } -ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { - struct ncclIbVerbs* verbs = (struct ncclIbVerbs*)comm; - struct ncclIbMrCache* cache = &ncclIbDevs[verbs->dev].mrCache; +/* DMA-BUF support */ +ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { + assert(size > 0); + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) malloc(sizeof(struct ncclIbMrHandle)); + for (int i = 0; i < base->ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); + NCCLCHECK(ncclIbRegMrDmaBufInternal(devComm, data, size, type, offset, fd, mhandleWrapper->mrs + i)); + } + *mhandle = (void*) mhandleWrapper; + return ncclSuccess; +} + +ncclResult_t ncclIbRegMr(void* comm, void* data, size_t size, int type, void** mhandle) { + return ncclIbRegMrDmaBuf(comm, data, size, type, 0ULL, -1, mhandle); +} + +ncclResult_t ncclIbDeregMrInternal(ncclIbNetCommDevBase* base, ibv_mr* mhandle) { + struct ncclIbMrCache* cache = &ncclIbDevs[base->ibDevN].mrCache; ncclResult_t res; - pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_lock(&ncclIbDevs[base->ibDevN].lock); for (int i=0; i < cache->population; i++) { if (mhandle == cache->slots[i].mr) { if (0 == --cache->slots[i].refs) { @@ -989,7 +1273,7 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { cache->slots = NULL; cache->capacity = 0; } - NCCLCHECKGOTO(wrap_ibv_dereg_mr((struct ibv_mr*)mhandle), res, returning); + NCCLCHECKGOTO(wrap_ibv_dereg_mr(mhandle), res, returning); } res = ncclSuccess; goto returning; @@ -998,11 +1282,23 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { WARN("NET/IB: could not find mr %p inside cache of %d entries", mhandle, cache->population); res = ncclInternalError; returning: - pthread_mutex_unlock(&ncclIbDevs[verbs->dev].lock); + pthread_mutex_unlock(&ncclIbDevs[base->ibDevN].lock); return res; } -NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 1); +ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; + struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; + for (int i = 0; i < base->ndevs; i++) { + // Each ncclIbNetCommDevBase is at different offset in send and recv netComms + struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); + NCCLCHECK(ncclIbDeregMrInternal(devComm, mhandleWrapper->mrs[i])); + } + free(mhandleWrapper); + return ncclSuccess; +} + +NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0); ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { struct ncclIbRequest** reqs = comm->fifoReqs[slot]; @@ -1011,21 +1307,17 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; uint64_t wr_id = 0ULL; - for (int r=0; rwrs+r; memset(wr, 0, sizeof(struct ibv_send_wr)); struct ibv_sge* sge = comm->sges+r; sge->addr=(uintptr_t)reqs[r]->send.data; - sge->lkey=reqs[r]->send.lkey; - wr->opcode = IBV_WR_RDMA_WRITE; wr->send_flags = 0; wr->wr.rdma.remote_addr = slots[r].addr; - wr->wr.rdma.rkey = slots[r].rkey; - wr->next = wr+1; - wr_id += (reqs[r] - comm->verbs.reqs) << (r*8); + wr->next = wr + 1; + wr_id += (reqs[r] - comm->base.reqs) << (r*8); } // Write size as immediate data. In the case of multi-send, only write @@ -1034,13 +1326,10 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { if (nreqs == 1) { immData = reqs[0]->send.size; } else { - if (nreqs > 32) { - WARN("Cannot store sizes of %d requests in a 32-bits field", nreqs); - return ncclInternalError; - } - for (int r=0; rsend.size ? 1 : 0) << r; - } + int* sizes = comm->remSizesFifo.elems[slot]; + for (int r=0; rsend.size; + comm->remSizesFifo.sge.addr = (uint64_t)sizes; + comm->remSizesFifo.sge.length = nreqs*sizeof(int); } struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; @@ -1050,6 +1339,12 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // completion. lastWr++; memset(lastWr, 0, sizeof(struct ibv_send_wr)); + if (nreqs > 1) { + // Write remote sizes Fifo + lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); + lastWr->num_sge = 1; + lastWr->sg_list = &comm->remSizesFifo.sge; + } } lastWr->wr_id = wr_id; lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -1059,23 +1354,40 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work const int align = 128; - const int nqps = ncclParamIbSplitDataOnQps() ? comm->nqps : 1; - for (int q=0; qbase.nqps : comm->base.ndevs; + for (int i = 0; i < nqps; i++) { + int qpIndex = comm->base.qpIndex; + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; for (int r=0; rdevs[devIndex].base); + + // Select proper rkey (needed even for 0-size send) + comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; + int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align; int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); if (length <= 0) { comm->wrs[r].sg_list = NULL; comm->wrs[r].num_sge = 0; } else { + // Select proper lkey + comm->sges[r].lkey = reqs[r]->send.lkeys[devIndex]; comm->sges[r].length = length; comm->wrs[r].sg_list = comm->sges+r; comm->wrs[r].num_sge = 1; } } + + if (nreqs > 1) { + // Also make sure lastWr writes remote sizes using the right lkey + comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey; + lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex]; + } + struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->qps[comm->qpIndex], comm->wrs, &bad_wr)); - comm->qpIndex = (comm->qpIndex+1)%comm->nqps; + NCCLCHECK(wrap_ibv_post_send(qp->qp, comm->wrs, &bad_wr)); for (int r=0; rsend.size, nqps), align) * align; @@ -1083,6 +1395,9 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { comm->sges[r].addr += chunkSize; comm->wrs[r].wr.rdma.remote_addr += chunkSize; } + + // Select the next qpIndex + comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; } return ncclSuccess; @@ -1090,16 +1405,16 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - if (comm->ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->ready == 0"); return ncclInternalError; } - if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->base.ready == 0"); return ncclInternalError; } + if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } - struct ibv_mr* mr = (struct ibv_mr*)mhandle; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; // Wait for the receiver to have posted the corresponding receive int nreqs = 0; volatile struct ncclIbSendFifo* slots; - int slot = (comm->fifoHead)%MAX_REQUESTS; + int slot = (comm->fifoHead) % MAX_REQUESTS; struct ncclIbRequest** reqs = comm->fifoReqs[slot]; slots = comm->fifo[slot]; uint64_t idx = comm->fifoHead+1; @@ -1111,35 +1426,47 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh for (int r=0; r slots[r].size) { + if (size > slots[r].size) size = slots[r].size; + // Sanity checks + if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; - ncclSocketGetAddr(&comm->sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s collective mismatch error, local size %d remote size %d", - r, nreqs, tag, ncclSocketToString(&addr, line), size, slots[r].size); - return ncclInvalidUsage; - } // plus any potential programming errors - else if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) { - char line[SOCKET_NAME_MAXLEN + 1]; - union ncclSocketAddress addr; - ncclSocketGetAddr(&comm->sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x", - r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkey); + ncclSocketGetAddr(&comm->base.sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkeys[0]=%x", + r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); return ncclInternalError; } + struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_SEND; - req->sock = &comm->sock; - req->verbs = &comm->verbs; + req->sock = &comm->base.sock; + req->base = &comm->base; req->nreqs = nreqs; req->send.size = size; req->send.data = data; - req->send.lkey = mr->lkey; req->send.offset = 0; - req->events = ncclParamIbSplitDataOnQps() ? comm->nqps : 1; - if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo; + + // Populate events + int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int qpIndex = comm->base.qpIndex; + // Count down + while (nEvents > 0) { + ncclIbQp* qp = comm->base.qps + qpIndex; + int devIndex = qp->devIndex; + ncclIbAddEvent(req, devIndex, &comm->devs[devIndex].base); + // Track the valid lkey for this RDMA_Write + req->send.lkeys[devIndex] = mhandleWrapper->mrs[devIndex]->lkey; + nEvents--; + // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs inside ncclIbMultiSend() + qpIndex = (qpIndex+1)%comm->base.nqps; + } + + // Store all lkeys + for (int i = 0; i < comm->base.ndevs; i++) { + req->send.lkeys[i] = mhandleWrapper->mrs[i]->lkey; + } + *request = reqs[r] = req; // If this is a multi-recv, send only when all requests have matched. @@ -1167,24 +1494,39 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int memset(&wr, 0, sizeof(wr)); int slot = comm->remFifo.fifoTail%MAX_REQUESTS; + req->recv.sizes = comm->sizesFifo[slot]; + for (int i=0; irecv.sizes[i] = 0; struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot]; + // Select the next devIndex (local) and QP to use for posting this CTS message + // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value + ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.ndevs; + for (int i=0; irkey; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandles[i]; + + // Send all applicable rkeys + for (int j = 0; j < comm->base.ndevs; j++) + localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + localElem[i].nreqs = n; localElem[i].size = sizes[i]; // Sanity/Debugging localElem[i].tag = tags[i]; localElem[i].idx = comm->remFifo.fifoTail+1; } - wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo); - wr.wr.rdma.rkey = comm->remFifo.rkey; - comm->remFifo.sge.addr = (uint64_t)localElem; - comm->remFifo.sge.length = n*sizeof(struct ncclIbSendFifo); - wr.sg_list = &comm->remFifo.sge; + + // Lookup the correct fifoRkey + wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey; + + // Set the correct sge properties + comm->devs[ctsQp->devIndex].fifoSge.addr = (uint64_t)localElem; + comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); + wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge; wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = comm->remFifo.flags; // IBV_SEND_INLINE @@ -1209,14 +1551,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int // polling it will empty the Send Queue, can be posted) // - The status of all posted Send Request is considered unknown // - if (slot == 0) { + // slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled. + // This works out that each fifo posting QP gets drained + if (slot == ctsQp->devIndex) { wr.send_flags |= IBV_SEND_SIGNALED; - wr.wr_id = req - comm->verbs.reqs; - req->events++; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); } struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr)); + NCCLCHECK(wrap_ibv_post_send(ctsQp->qp, &wr, &bad_wr)); comm->remFifo.fifoTail++; return ncclSuccess; @@ -1224,42 +1568,47 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - if (comm->ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->ready == 0"); return ncclInternalError; } - if (comm->ready == 0) { *request = NULL; return ncclSuccess; } + if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); return ncclInternalError; } + if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_RECV; - req->sock = &comm->sock; + req->sock = &comm->base.sock; req->nreqs = n; - if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo; - for (int i=0; irecv.sizes[i] = 0; + + for (int i = 0; i < comm->base.ndevs; i++) { + req->devBases[i] = &comm->devs[i].base; + } struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = req - comm->verbs.reqs; - + wr.wr_id = req - comm->base.reqs; wr.sg_list = NULL; wr.num_sge = 0; TIME_START(1); - const int nqps = ncclParamIbSplitDataOnQps() ? comm->nqps : 1; - for (int q=0; qqps[comm->qpIndex]; - struct ibv_recv_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr)); - comm->qpIndex = (comm->qpIndex+1)%comm->nqps; - } - TIME_STOP(1); - req->events = nqps; + // Select either all QPs, or one qp per-device + const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; - *request = req; + // Post recvs + struct ibv_recv_wr* bad_wr; + for (int i = 0; i < nqps; i++) { + struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; + ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); + NCCLCHECK(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr)); + comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; + } + + TIME_STOP(1); // Post to FIFO to notify sender TIME_START(2); NCCLCHECK(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req)); TIME_STOP(2); + + *request = req; return ncclSuccess; } @@ -1267,30 +1616,35 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; int last = -1; for (int i=0; igpuFlush.enabled == 0 || last == -1) return ncclSuccess; + if (comm->flushEnabled == 0 || last == -1) return ncclSuccess; // Only flush once using the last non-zero receive struct ncclIbRequest* req; - NCCLCHECK(ncclIbGetRequest(&comm->verbs, &req)); + NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_FLUSH; - req->sock = &comm->sock; - struct ibv_mr* mr = (struct ibv_mr*)mhandles[last]; + req->sock = &comm->base.sock; + struct ncclIbMrHandle* mhandle = (struct ncclIbMrHandle*) mhandles[last]; - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = req - comm->verbs.reqs; + // We don't know which devIndex the recv was on, so we flush on all devices + for (int i = 0; i < comm->base.ndevs; i++) { + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; - wr.wr.rdma.remote_addr = (uint64_t)data[last]; - wr.wr.rdma.rkey = mr->rkey; - wr.sg_list = &comm->gpuFlush.sge; - wr.num_sge = 1; - wr.opcode = IBV_WR_RDMA_READ; - wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = (uint64_t)data[last]; + wr.wr.rdma.rkey = mhandle->mrs[i]->rkey; + wr.sg_list = &comm->devs[i].gpuFlush.sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; - TIME_START(4); - struct ibv_send_wr* bad_wr; - NCCLCHECK(wrap_ibv_post_send(comm->gpuFlush.qp, &wr, &bad_wr)); - TIME_STOP(4); + TIME_START(4); + struct ibv_send_wr* bad_wr; + NCCLCHECK(wrap_ibv_post_send(comm->devs[i].gpuFlush.qp.qp, &wr, &bad_wr)); + TIME_STOP(4); + + ncclIbAddEvent(req, i, &comm->devs[i].base); + } *request = req; return ncclSuccess; @@ -1299,76 +1653,105 @@ ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { struct ncclIbRequest *r = (struct ncclIbRequest*)request; *done = 0; - while (1) { - if (r->events == 0) { + if (r->events[0] == 0 && r->events[1] == 0) { + TRACE(NCCL_NET, "r=%p done", r); *done = 1; if (sizes && r->type == NCCL_NET_IB_REQ_RECV) { for (int i=0; inreqs; i++) sizes[i] = r->recv.sizes[i]; } + if (sizes && r->type == NCCL_NET_IB_REQ_SEND) { + sizes[0] = r->send.size; + } NCCLCHECK(ncclIbFreeRequest(r)); return ncclSuccess; } + int totalWrDone = 0; int wrDone = 0; struct ibv_wc wcs[4]; - TIME_START(3); - NCCLCHECK(wrap_ibv_poll_cq(r->verbs->cq, 4, wcs, &wrDone)); - if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } - if (wrDone == 0) return ncclSuccess; - for (int w=0; wstatus != IBV_WC_SUCCESS) { - char line[SOCKET_NAME_MAXLEN+1]; - union ncclSocketAddress addr; - ncclSocketGetAddr(r->sock, &addr); - char localGidString[INET6_ADDRSTRLEN] = ""; - char remoteGidString[INET6_ADDRSTRLEN] = ""; - const char* localGidStr = NULL, *remoteGidStr = NULL; - if (r->gidInfo) { - localGidStr = inet_ntop(AF_INET6, &r->gidInfo->localGid, localGidString, sizeof(localGidString)); - remoteGidStr = inet_ntop(AF_INET6, &r->gidInfo->remoteGid, remoteGidString, sizeof(remoteGidString)); - } - WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d (%s)%s%s%s%s", - ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], - localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGid ":"", remoteGidString); - return ncclRemoteError; - } + for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) { + TIME_START(3); + // If we expect any completions from this device's CQ + if (r->events[i]) { + NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone)); + totalWrDone += wrDone; + if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } + if (wrDone == 0) continue; + for (int w=0; wstatus != IBV_WC_SUCCESS) { + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + char localGidString[INET6_ADDRSTRLEN] = ""; + char remoteGidString[INET6_ADDRSTRLEN] = ""; + const char* localGidStr = NULL, *remoteGidStr = NULL; + if (r->devBases[i]->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { + localGidStr = inet_ntop(AF_INET6, &r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); + remoteGidStr = inet_ntop(AF_INET6, &r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); + } - struct ncclIbRequest* req = r->verbs->reqs+(wc->wr_id & 0xff); - if (req->type == NCCL_NET_IB_REQ_SEND) { - for (int i=0; inreqs; i++) { - struct ncclIbRequest* sendReq = r->verbs->reqs+((wc->wr_id >> (i*8)) & 0xff); - if ((sendReq->events <= 0)) return ncclInternalError; - sendReq->events--; - } - } else { - if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - if (req->type != NCCL_NET_IB_REQ_RECV) return ncclInternalError; - if (req->nreqs > 1) { - // In the case of a multi recv, we only set sizes to 0 or 1. - for (int i=0; inreqs; i++) { - req->recv.sizes[i] = (wc->imm_data >> i) & 0x1; + char line[SOCKET_NAME_MAXLEN+1]; + WARN("NET/IB : Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s", + ncclSocketToString(&addr, line), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], + localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString); + return ncclRemoteError; + } + + union ncclSocketAddress addr; + ncclSocketGetAddr(r->sock, &addr); + struct ncclIbRequest* req = r->base->reqs+(wc->wr_id & 0xff); + + #ifdef ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN+1]; + TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%d r=%p type=%d events={%d,%d}, i=%d", + ncclSocketToString(&addr, line), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i); + #endif + if (req->type == NCCL_NET_IB_REQ_SEND) { + for (int j = 0; j < req->nreqs; j++) { + struct ncclIbRequest* sendReq = r->base->reqs+((wc->wr_id >> (j*8)) & 0xff); + if ((sendReq->events[i] <= 0)) { + WARN("NET/IB: sendReq(%p)->events={%d,%d}, i=%d, j=%d <= 0", sendReq, sendReq->events[0], sendReq->events[1], i, j); + return ncclInternalError; + } + sendReq->events[i]--; } } else { - req->recv.sizes[0] += wc->imm_data; + if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + if (req->type != NCCL_NET_IB_REQ_RECV) { + WARN("NET/IB: wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM and req->type=%d", req->type); + return ncclInternalError; + } + if (req->nreqs == 1) { + req->recv.sizes[0] += wc->imm_data; + } + } + req->events[i]--; } } - req->events--; } } + + // If no CQEs found on any device, return and come back later + if (totalWrDone == 0) return ncclSuccess; } } ncclResult_t ncclIbCloseSend(void* sendComm) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm) { - NCCLCHECK(ncclSocketClose(&comm->sock)); - for (int q=0; qnqps; q++) - if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); - if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr)); - NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs)); + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.ndevs; i++) { + struct ncclIbSendCommDev* commDev = comm->devs + i; + if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (comm->remSizesFifo.mrs[i] != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mrs[i])); + NCCLCHECK(ncclIbDestroyBase(&commDev->base)); + } free(comm); } TIME_PRINT("IB"); @@ -1378,15 +1761,21 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { ncclResult_t ncclIbCloseRecv(void* recvComm) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm) { - NCCLCHECK(ncclSocketClose(&comm->sock)); - for (int q=0; qnqps; q++) - if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q])); - if (comm->gpuFlush.enabled) { - if (comm->gpuFlush.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp)); - if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr)); + NCCLCHECK(ncclSocketClose(&comm->base.sock)); + + for (int q = 0; q < comm->base.nqps; q++) + if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); + + for (int i = 0; i < comm->base.ndevs; i++) { + struct ncclIbRecvCommDev* commDev = comm->devs + i; + if (comm->flushEnabled) { + if (commDev->gpuFlush.qp.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(commDev->gpuFlush.qp.qp)); + if (commDev->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->gpuFlush.hostMr)); + } + if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); + if (commDev->sizesFifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->sizesFifoMr)); + NCCLCHECK(ncclIbDestroyBase(&commDev->base)); } - if (comm->remFifo.mr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remFifo.mr)); - NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs)); free(comm); } return ncclSuccess; diff --git a/projects/rccl/src/transport/net_socket.cc b/projects/rccl/src/transport/net_socket.cc index 502179a217..e9e0357141 100644 --- a/projects/rccl/src/transport/net_socket.cc +++ b/projects/rccl/src/transport/net_socket.cc @@ -96,6 +96,7 @@ ncclResult_t ncclNetSocketGetProperties(int dev, ncclNetProperties_t* props) { props->pciPath = ncclNetSocketDevs[dev].pciPath; props->guid = dev; props->ptrSupport = NCCL_PTR_HOST; + props->regIsGlobal = 0; NCCLCHECK(ncclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; @@ -534,7 +535,7 @@ ncclResult_t ncclNetSocketTest(void* request, int* done, int* size) { return ncclSuccess; } -ncclResult_t ncclNetSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { +ncclResult_t ncclNetSocketRegMr(void* comm, void* data, size_t size, int type, void** mhandle) { return (type != NCCL_PTR_HOST) ? ncclInternalError : ncclSuccess; } ncclResult_t ncclNetSocketDeregMr(void* comm, void* mhandle) { return ncclSuccess; } diff --git a/projects/rccl/src/transport/nvls.cc b/projects/rccl/src/transport/nvls.cc index 4dfae51cfe..a03aab387a 100644 --- a/projects/rccl/src/transport/nvls.cc +++ b/projects/rccl/src/transport/nvls.cc @@ -11,6 +11,7 @@ #include "utils.h" #include "proxy.h" #include "enqueue.h" +#include "register.h" #if CUDART_VERSION >= 12010 @@ -20,19 +21,8 @@ struct graphRegData { }; struct localRegData { - /* Registration record data */ - uintptr_t recSendbuff, recRecvbuff; - intptr_t recSendOffset, recRecvOffset; - /* Registration request data */ - uintptr_t reqSendbuff, reqRecvbuff; - size_t reqSendSize, reqRecvSize; - intptr_t reqSendOffset, reqRecvOffset; -}; - -struct localRequestData { - uintptr_t reqBuff; - size_t reqSize; - intptr_t reqOffset; + struct ncclReg reg; + intptr_t offset; }; ncclResult_t nvlsCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { @@ -116,11 +106,9 @@ ncclResult_t nvlsGroupConnect(struct ncclComm *comm, char *shareableHandle, int // cuMem UDS support int fd = -1; TRACE(NCCL_NVLS, "NVLS rank %d Importing shareable handle %p from rank %d", comm->localRank, shareableHandle, rank); - struct ncclProxyConnector proxyConn; int tpProxyRank = comm->topParentRanks[rank]; - NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &proxyConn)); TRACE(NCCL_NVLS, "NVLS rank %d request conversion of handle 0x%lx from rank %d", comm->localRank, *(uint64_t*)shareableHandle, rank); - NCCLCHECK(ncclProxyClientGetFdBlocking(comm, &proxyConn, shareableHandle, &fd)); + NCCLCHECK(ncclProxyClientGetFdBlocking(comm, tpProxyRank, shareableHandle, &fd)); TRACE(NCCL_NVLS, "NVLS rank %d received converted fd %d from rank %d", comm->localRank, fd, rank); CUCHECK(cuMemImportFromShareableHandle(mcHandle, (void *)(uintptr_t)fd, type)); (void) close(fd); @@ -248,7 +236,8 @@ ncclResult_t ncclNvlsInit(struct ncclComm* comm) { int gpuCount; NCCLCHECK(ncclTopoGetGpuCount(comm->topo, &gpuCount)); - if (!ncclParamNvlsEnable() || gpuCount <= 2) return ncclSuccess; + // NVLS is not supported on MNNVL yet + if (!ncclParamNvlsEnable() || gpuCount <= 2 || comm->nNodes > 1 || comm->MNNVL) return ncclSuccess; CUdevice dev; int driverVersion; @@ -292,14 +281,14 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) { if (nvlsShare) { /* reuse NVLS resources */ comm->nvlsChannels = std::min(comm->nvlsChannels, parent->nvlsResources->nChannels); - for (int c = 0; c < comm->nvlsChannels; c++) { + for (int c = 0; c < comm->nChannels; c++) { NCCLCHECKGOTO(initNvlsChannel(comm, c, parent, true), res, cleanup); } comm->nvlsResources = parent->nvlsResources; ncclAtomicRefCountIncrement(&parent->nvlsResources->refCount); } else { - int nChannels; + int nChannels = comm->nChannels; struct ncclNvlsSharedRes* resources; NCCLCHECK(ncclCalloc(&resources, 1)); @@ -312,7 +301,7 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) { comm->nvlsChannels = std::min(comm->nvlsChannels, parent->nvlsResources->nChannels); } - nChannels = resources->nChannels = comm->nvlsChannels; + resources->nChannels = comm->nvlsChannels; for (int c = 0; c < nChannels; c++) { NCCLCHECK(initNvlsChannel(comm, c, parent, false)); } @@ -390,7 +379,8 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) { } /* create shared memory for fast NVLS buffer registration */ - typeSize = sizeof(struct localRegData); + typeSize = sizeof(struct localRegData) << 1; + if (comm->localRank == 0) { shmPath[0] = '\0'; NCCLCHECKGOTO(ncclShmOpen(shmPath, (sizeof(size_t) + typeSize * comm->localRanks) * 2, (void**)&nvlsShmem, NULL, comm->localRanks - 1, &comm->nvlsResources->nvlsShmemHandle), res, cleanup); @@ -405,6 +395,7 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) { comm->nvlsResources->nvlsShmem.cnt[1] = (size_t*)((char*)comm->nvlsResources->nvlsShmem.ptr[0] + typeSize * comm->localRanks); comm->nvlsResources->nvlsShmem.ptr[1] = (void*)((char*)comm->nvlsResources->nvlsShmem.cnt[1] + sizeof(size_t)); comm->nvlsResources->nvlsShmem.round = 0; + comm->nvlsResources->nvlsShmem.maxTypeSize = typeSize; return res; @@ -427,23 +418,59 @@ ncclResult_t ncclNvlsFree(struct ncclComm* comm) { return ncclSuccess; } -ncclResult_t tryRegisterBuffer(struct ncclComm *comm, struct localRequestData *reqData, uintptr_t userBuff, size_t buffSize, CUdeviceptr *regAddr, bool *regUsed) { +ncclResult_t tryRegisterBuffer(struct ncclComm *comm, uintptr_t userBuff, size_t buffSize, CUdeviceptr *regAddr, bool *regUsed) { ncclResult_t ret = ncclSuccess; - struct ncclRegRecord *regRecord = NULL; - struct localRequestData *myReqData = &reqData[comm->localRank]; + struct ncclReg *regRecord = NULL; CUdeviceptr regPtr = 0; CUmulticastObjectProp prop; char shareableHandle[NVLS_HANDLE_SIZE]; CUmemGenericAllocationHandle mcHandle; size_t granularity; - size_t minSize; + size_t minSize = SIZE_MAX; bool localRegBufUsed = false; + struct localRegData* regData = NULL; + cudaPointerAttributes attr; - /* get minimal size of nvls buffers */ - minSize = reqData[0].reqSize; - for (int i = 1; i < comm->localRanks; ++i) { - if (minSize > reqData[i].reqSize) - minSize = reqData[i].reqSize; + NCCLCHECKGOTO(ncclCalloc(®Data, comm->localRanks), ret, fail); + + if (userBuff) { + NCCLCHECKGOTO(ncclRegFind(comm, (void*)userBuff, buffSize, ®Record), ret, fail); + if (regRecord) { + CUDACHECK(cudaPointerGetAttributes(&attr, (void*)regRecord->addr)); + if (attr.type == cudaMemoryTypeDevice) { + size_t regSize = regRecord->pages * comm->regCache.pageSize; + prop = comm->nvlsResources->properties; + prop.size = regSize; + CUCHECK(cuMulticastGetGranularity(&granularity, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + CUCHECK(cuMemGetAddressRange((CUdeviceptr*)®Record->baseAddr, ®Record->baseSize, (CUdeviceptr)regRecord->addr)); + if (regSize % granularity == 0) { + regRecord->regSize = regSize; + } else { + regRecord->regSize = regRecord->baseSize - (regRecord->addr - regRecord->baseAddr); + } + + if (regRecord->addr % comm->nvlsResources->ucGran == 0 && regRecord->regSize % granularity == 0) { + regRecord->state |= NVLS_REG_POSSIBLE; + memcpy(®Data[comm->localRank].reg, regRecord, sizeof(struct ncclReg)); + regData[comm->localRank].offset = userBuff - regRecord->addr; + } + } + + if ((regRecord->state & NVLS_REG_POSSIBLE) == 0) { + regRecord->state |= NVLS_REG_NO_SUPPORT; + } + } + } + + NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regData + comm->localRank, regData, sizeof(struct localRegData)), ret, fail); + + for (int i = 0; i < comm->localRanks; ++i) { + if ((regData[i].reg.state & NVLS_REG_POSSIBLE) == 0) { + goto fail; + } + /* get minimal reg size of nvls buffers */ + if (minSize > regData[i].reg.regSize) + minSize = regData[i].reg.regSize; } /* start registration */ @@ -459,7 +486,7 @@ ncclResult_t tryRegisterBuffer(struct ncclComm *comm, struct localRequestData *r } CUCHECKGOTO(cuMulticastAddDevice(mcHandle, comm->nvlsResources->dev), ret, fail); - CUCHECKGOTO(cuMulticastBindAddr(mcHandle, 0, (CUdeviceptr)myReqData->reqBuff, minSize, 0), ret, fail); + CUCHECKGOTO(cuMulticastBindAddr(mcHandle, 0, (CUdeviceptr)regRecord->addr, minSize, 0), ret, fail); // Create a VA for the NVLS CUCHECKGOTO(cuMemAddressReserve(®Ptr, minSize, granularity, 0U, 0), ret, fail); @@ -467,26 +494,28 @@ ncclResult_t tryRegisterBuffer(struct ncclComm *comm, struct localRequestData *r CUCHECKGOTO(cuMemMap(regPtr, minSize, 0, mcHandle, 0), ret, fail); CUCHECKGOTO(cuMemSetAccess(regPtr, minSize, &comm->nvlsResources->accessDesc, 1), ret, fail); - NCCLCHECKGOTO(ncclCalloc(®Record, 1), ret, fail); - regRecord->buff = myReqData->reqBuff; - regRecord->size = myReqData->reqSize; regRecord->regAddr = regPtr; regRecord->regSize = minSize; regRecord->dev = comm->nvlsResources->dev; regRecord->mcHandle = mcHandle; + regRecord->state |= NVLS_REG_COMPLETE; /* get all buffer addresses */ - NCCLCHECKGOTO(ncclCalloc(®Record->addrs, comm->localRanks), ret, fail); - regRecord->addrs[comm->localRank] = regRecord->buff; - NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regRecord->addrs + comm->localRank, regRecord->addrs, sizeof(uintptr_t)), ret, fail); - /* enqueue record */ - ncclIntruQueueEnqueue(&comm->regRecordQueue, regRecord); + regRecord->caddrs[comm->localRank] = regRecord->addr; + NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regRecord->caddrs + comm->localRank, regRecord->caddrs, sizeof(uintptr_t)), ret, fail); + + /* Although registration is done, we still need to check whether the offsets are same among ranks. */ + for (int i = 0; i < comm->localRanks - 1; ++i) { + if (regData[i].offset != regData[i + 1].offset) { + goto fail; + } + } localRegBufUsed = true; exit: - if (localRegBufUsed) - *regAddr = (uintptr_t)regPtr + userBuff - myReqData->reqBuff; + if (localRegBufUsed) *regAddr = (uintptr_t)regPtr + regData[comm->localRank].offset; *regUsed = localRegBufUsed; + free(regData); return ret; fail: localRegBufUsed = false; @@ -497,77 +526,52 @@ ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *send ncclResult_t ret = ncclSuccess; bool localRegBufUsed = false; struct localRegData *regData = NULL; - struct localRequestData *reqData = NULL; - struct ncclRegRecord *regRecordHead = NULL, *sendRegRecord = NULL, *recvRegRecord = NULL; - struct ncclRegRequest *regRequestHead = NULL, *sendRegRequest = NULL, *recvRegRequest = NULL; bool sendNeedReg = false, recvNeedReg = false; CUdeviceptr regSendPtr = 0; CUdeviceptr regRecvPtr = 0; + struct ncclReg *sendRegRecord = NULL; + struct ncclReg *recvRegRecord = NULL; *outRegBufUsed = false; - NCCLCHECKGOTO(ncclCalloc(®Data, comm->localRanks), ret, fail); + NCCLCHECKGOTO(ncclCalloc(®Data, comm->localRanks * 2), ret, fail); - /* first check whether the buffer has been registered and matches each other globally */ - regRecordHead = ncclIntruQueueHead(&comm->regRecordQueue); - while (regRecordHead && ((sendRegRecord == NULL && sendbuff != NULL) || (recvRegRecord == NULL && recvbuff != NULL))) { - /* check send reg record */ - if (sendRegRecord == NULL && regRecordHead->buff <= (uintptr_t)sendbuff && - regRecordHead->buff + regRecordHead->size >= (uintptr_t)sendbuff + sendbuffSize) { - regData[comm->localRank].recSendbuff = regRecordHead->buff; - regData[comm->localRank].recSendOffset = (uintptr_t)sendbuff - regRecordHead->buff; - sendRegRecord = regRecordHead; + if (sendbuff) { + NCCLCHECKGOTO(ncclRegFind(comm, sendbuff, sendbuffSize, &sendRegRecord), ret, fail); + if (sendRegRecord) { + memcpy(®Data[comm->localRank * 2].reg, sendRegRecord, sizeof(struct ncclReg)); + regData[comm->localRank * 2].offset = (uintptr_t)sendbuff - sendRegRecord->addr; } - - /* check recv reg record */ - if (recvRegRecord == NULL && regRecordHead->buff <= (uintptr_t)recvbuff && - regRecordHead->buff + regRecordHead->size >= (uintptr_t)recvbuff + recvbuffSize) { - regData[comm->localRank].recRecvbuff = regRecordHead->buff; - regData[comm->localRank].recRecvOffset = (uintptr_t)recvbuff - regRecordHead->buff; - recvRegRecord = regRecordHead; - } - regRecordHead = regRecordHead->next; } - /* prepare registration request for later reference */ - regRequestHead = ncclIntruQueueHead(&comm->regRequestQueue); - while (regRequestHead && ((sendRegRequest == NULL && sendbuff != NULL) || (recvRegRequest == NULL && recvbuff != NULL))) { - /* check send reg request */ - if (regRequestHead->buff <= (uintptr_t)sendbuff && - regRequestHead->buff + regRequestHead->size >= (uintptr_t)sendbuff + sendbuffSize) { - regData[comm->localRank].reqSendbuff = regRequestHead->buff; - regData[comm->localRank].reqSendSize = regRequestHead->size; - regData[comm->localRank].reqSendOffset = (uintptr_t)sendbuff - regRequestHead->buff; - sendRegRequest = regRequestHead; + if (recvbuff) { + NCCLCHECKGOTO(ncclRegFind(comm, recvbuff, recvbuffSize, &recvRegRecord), ret, fail); + if (recvRegRecord) { + memcpy(®Data[comm->localRank * 2 + 1].reg, recvRegRecord, sizeof(struct ncclReg)); + regData[comm->localRank * 2 + 1].offset = (uintptr_t)recvbuff - recvRegRecord->addr; } - - /* check recv reg request */ - if (regRequestHead->buff <= (uintptr_t)recvbuff && - regRequestHead->buff + regRequestHead->size >= (uintptr_t)recvbuff + recvbuffSize) { - regData[comm->localRank].reqRecvbuff = regRequestHead->buff; - regData[comm->localRank].reqRecvSize = regRequestHead->size; - regData[comm->localRank].reqRecvOffset = (uintptr_t)recvbuff - regRequestHead->buff; - recvRegRequest = regRequestHead; - } - regRequestHead = regRequestHead->next; } - NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regData + comm->localRank, regData, sizeof(struct localRegData)), ret, fail); + NCCLCHECKGOTO(ncclShmemAllgather(comm, &comm->nvlsResources->nvlsShmem, regData + comm->localRank * 2, regData, sizeof(struct localRegData) * 2), ret, fail); /* first check whether all local ranks find their registered buffer */ for (int i = 0; i < comm->localRanks; ++i) { - if (regData[i].recSendbuff == 0 || sendRegRecord->addrs[i] != regData[i].recSendbuff) { + if ((regData[i * 2].reg.state & NVLS_REG_COMPLETE) == 0 || regData[comm->localRank * 2].reg.caddrs[i] != regData[i * 2].reg.addr) { sendNeedReg = true; } - if (regData[i].recRecvbuff == 0 || recvRegRecord->addrs[i] != regData[i].recRecvbuff) { + if ((regData[i * 2 + 1].reg.state & NVLS_REG_COMPLETE) == 0 || regData[comm->localRank * 2 + 1].reg.caddrs[i] != regData[i * 2 + 1].reg.addr) { recvNeedReg = true; } + + if ((regData[i * 2].reg.state & NVLS_REG_NO_SUPPORT) || (regData[i * 2 + 1].reg.state & NVLS_REG_NO_SUPPORT)) { + goto fail; + } } if (sendNeedReg == false) { for (int i = 0; i < comm->localRanks - 1; ++i) { - if (regData[i].recSendOffset != regData[i + 1].recSendOffset) { + if (regData[i * 2].offset != regData[(i + 1) * 2].offset) { /* offset are different, we cannot apply user buffer registration */ goto fail; } @@ -575,18 +579,18 @@ ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *send /* reuse previous registered buffer if possible */ if (!sendNeedReg) - regSendPtr = (CUdeviceptr)((uintptr_t)sendRegRecord->regAddr + regData[comm->localRank].recSendOffset); + regSendPtr = (CUdeviceptr)((uintptr_t)sendRegRecord->regAddr + regData[comm->localRank * 2].offset); } if (recvNeedReg == false) { for (int i = 0; i < comm->localRanks - 1; ++i) { - if (regData[i].recRecvOffset != regData[i + 1].recRecvOffset) { + if (regData[i * 2 + 1].offset != regData[(i + 1) * 2 + 1].offset) { goto fail; } } if (!recvNeedReg) - regRecvPtr = (CUdeviceptr)((uintptr_t)recvRegRecord->regAddr + regData[comm->localRank].recRecvOffset); + regRecvPtr = (CUdeviceptr)((uintptr_t)recvRegRecord->regAddr + regData[comm->localRank * 2 + 1].offset); } if ((!sendNeedReg || sendbuff == NULL) && (!recvNeedReg || recvbuff == NULL)) { @@ -597,29 +601,13 @@ ncclResult_t ncclNvlsLocalRegisterBuffer(struct ncclComm *comm, const void *send /* Start Registration. Not found registered buffers, then check whether both send and recv buffer locate * in register request cache. */ - NCCLCHECKGOTO(ncclCalloc(&reqData, comm->localRanks), ret, fail); - if (sendNeedReg && sendbuff != NULL) { - /* copy request data got from previous shmem AG */ - intptr_t offset = regData[0].reqSendOffset; - for (int i = 0; i < comm->localRanks; ++i) { - if (regData[i].reqSendbuff == 0 || offset != regData[i].reqSendOffset) goto fail; - reqData[i].reqBuff = regData[i].reqSendbuff; - reqData[i].reqSize = regData[i].reqSendSize; - reqData[i].reqOffset = regData[i].reqSendOffset; - } - tryRegisterBuffer(comm, reqData, (uintptr_t)sendbuff, sendbuffSize, ®SendPtr, &localRegBufUsed); + if (sendNeedReg && sendbuff) { + tryRegisterBuffer(comm, (uintptr_t)sendbuff, sendbuffSize, ®SendPtr, &localRegBufUsed); if (localRegBufUsed == false) goto fail; } - if (recvNeedReg && recvbuff != NULL) { - intptr_t offset = regData[0].reqRecvOffset; - for (int i = 0; i < comm->localRanks; ++i) { - if (regData[i].reqRecvbuff == 0 || offset != regData[i].reqRecvOffset) goto fail; - reqData[i].reqBuff = regData[i].reqRecvbuff; - reqData[i].reqSize = regData[i].reqRecvSize; - reqData[i].reqOffset = regData[i].reqRecvOffset; - } - tryRegisterBuffer(comm, reqData, (uintptr_t)recvbuff, recvbuffSize, ®RecvPtr, &localRegBufUsed); + if (recvNeedReg && recvbuff) { + tryRegisterBuffer(comm, (uintptr_t)recvbuff, recvbuffSize, ®RecvPtr, &localRegBufUsed); if (localRegBufUsed == false) goto fail; } @@ -630,7 +618,6 @@ exit: *outRegBufRecv = (void*)regRecvPtr; *outRegBufUsed = localRegBufUsed; free(regData); - free(reqData); return ncclSuccess; fail: localRegBufUsed = false; @@ -647,7 +634,7 @@ ncclResult_t ncclNvlsGraphRegisterBuffer(struct ncclComm *comm, struct ncclKerne CUmulticastObjectProp prop; char shareableHandle[NVLS_HANDLE_SIZE]; CUmemGenericAllocationHandle sendMcHandle, recvMcHandle; - size_t sendGran, recvGran; + size_t sendGran = 0, recvGran = 0; bool *regBufFlags = NULL; struct graphRegData *rdata = NULL; const void *baseSend = NULL; @@ -667,19 +654,17 @@ ncclResult_t ncclNvlsGraphRegisterBuffer(struct ncclComm *comm, struct ncclKerne if (recvbuff != NULL) CUCHECKGOTO(cuMemGetAddressRange((CUdeviceptr *)&baseRecv, &baseRecvSize, (CUdeviceptr)recvbuff), ret, fail); - memcpy(&prop, &comm->nvlsResources->properties, sizeof(CUmulticastObjectProp)); - prop.size = baseSendSize; - CUCHECKGOTO(cuMulticastGetGranularity(&sendGran, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); - prop.size = baseRecvSize; - CUCHECKGOTO(cuMulticastGetGranularity(&recvGran, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); - - localRegBufUsed = ((uint64_t)baseSend % sendGran != 0 || (uint64_t)baseRecv % recvGran != 0) ? false : true; + localRegBufUsed = ((uint64_t)baseSend % comm->nvlsResources->ucGran != 0 || (uint64_t)baseRecv % comm->nvlsResources->ucGran != 0) ? false : true; regBufFlags[comm->localRank] = localRegBufUsed; NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, regBufFlags, sizeof(bool)), ret, fail); for (int i = 0; i < comm->localRanks; ++i) if (regBufFlags[i] == false) goto fail; + memcpy(&prop, &comm->nvlsResources->properties, sizeof(CUmulticastObjectProp)); if (sendbuff != NULL) { + prop.size = baseSendSize; + CUCHECKGOTO(cuMulticastGetGranularity(&sendGran, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); + /* check send buffer offset and size */ rdata[comm->localRank].offset = (uintptr_t)sendbuff - (uintptr_t)baseSend; rdata[comm->localRank].size = baseSendSize; @@ -719,6 +704,9 @@ ncclResult_t ncclNvlsGraphRegisterBuffer(struct ncclComm *comm, struct ncclKerne } if (recvbuff != NULL) { + prop.size = baseRecvSize; + CUCHECKGOTO(cuMulticastGetGranularity(&recvGran, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED), ret, fail); + rdata[comm->localRank].offset = (uintptr_t)recvbuff - (uintptr_t)baseRecv; rdata[comm->localRank].size = baseRecvSize; NCCLCHECKGOTO(bootstrapIntraNodeAllGather(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, rdata, sizeof(struct graphRegData)), ret, fail); diff --git a/projects/rccl/src/transport/p2p.cc b/projects/rccl/src/transport/p2p.cc index 3e4dab7e44..b29224e13e 100644 --- a/projects/rccl/src/transport/p2p.cc +++ b/projects/rccl/src/transport/p2p.cc @@ -103,6 +103,12 @@ static void initCeOperation(); ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { initCeOperation(); + // MNNVL support + if (info1->hostHash != info2->hostHash) { + NCCLCHECK(ncclTopoCheckMNNVL(topo, info1, info2, ret)); + if (*ret) return ncclSuccess; + } + // Rule out different nodes / isolated containers if (info1->hostHash != info2->hostHash || info1->shmDev != info2->shmDev) { *ret = 0; @@ -190,8 +196,9 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop ncclResult_t ncclP2pAllocateShareableBuffer(size_t size, ncclIpcDesc *ipcDesc, void **ptr) { if (ncclCuMemEnable()) { #if CUDART_VERSION >= 11030 + CUmemAllocationHandleType type = ncclCuMemHandleType; + // cuMem API support - CUmemAllocationHandleType type = NCCL_P2P_HANDLE_TYPE; CUmemGenericAllocationHandle handle; NCCLCHECK(ncclCuMemAlloc(ptr, &handle, size)); if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { @@ -227,18 +234,16 @@ ncclResult_t ncclP2pImportShareableBuffer(struct ncclComm *comm, int tpPeer, siz #if CUDART_VERSION >= 11030 // cuMem API support CUdeviceptr dptr = 0; - CUmemAllocationHandleType type = NCCL_P2P_HANDLE_TYPE; + CUmemAllocationHandleType type = ncclCuMemHandleType; CUmemGenericAllocationHandle handle; ncclCuDesc *cuDesc = &ipcDesc->cuDesc; // Import and map the remote memory descriptor to the local GPU if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { // UDS fd support - struct ncclProxyConnector proxyConn; int fd = -1; // Send cuMem handle to remote for conversion to an fd - NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpPeer, &proxyConn)); - NCCLCHECK(ncclProxyClientGetFdBlocking(comm, &proxyConn, &cuDesc->data, &fd)); + NCCLCHECK(ncclProxyClientGetFdBlocking(comm, tpPeer, &cuDesc->data, &fd)); INFO(NCCL_P2P, "UDS converted handle 0x%lx to fd %d on remote peer %d", *(uint64_t*)&cuDesc->data, fd, tpPeer); CUCHECK(cuMemImportFromShareableHandle(&handle, (void *)(uintptr_t)fd, type)); (void) close(fd); @@ -276,6 +281,8 @@ ncclResult_t ncclP2pImportShareableBuffer(struct ncclComm *comm, int tpPeer, siz NCCL_PARAM(P2pReadEnable, "P2P_READ_ENABLE", -2); NCCL_PARAM(P2pDirectDisable, "P2P_DIRECT_DISABLE", 0); +#define P2P_SAME_PID(MYINFO, PEERINFO) ((MYINFO->hostHash == PEERINFO->hostHash) && (MYINFO->pidHash == PEERINFO->pidHash)) + static ncclResult_t p2pGetInfo(struct ncclTopoSystem* topo, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2, int* read, int* intermediateRank) { int p2p; // Queries the topology to see if the GPUs are Ampere and @@ -288,7 +295,7 @@ static ncclResult_t p2pGetInfo(struct ncclTopoSystem* topo, struct ncclPeerInfo* } static ncclResult_t p2pMap(struct ncclComm *comm, struct ncclProxyConnector* proxyConn, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclP2pBuff* p2pBuff, void** devMem, void** ipcPtr) { - if (myInfo->pidHash == peerInfo->pidHash) { + if (P2P_SAME_PID(myInfo, peerInfo)) { if (peerInfo->cudaDev != myInfo->cudaDev) { // Same PID different GPUs, enable P2P access // Legacy CUDA IPC @@ -316,15 +323,9 @@ static ncclResult_t p2pMap(struct ncclComm *comm, struct ncclProxyConnector* pro *devMem = p2pBuff->directPtr; *ipcPtr = NULL; } else { - if ((myInfo->pidHash == peerInfo->pidHash) && (peerInfo->cudaDev == myInfo->cudaDev)) { - // Same PID and GPU - *devMem = p2pBuff->directPtr; - *ipcPtr = NULL; - } else { - // Different PID or different GPU - NCCLCHECK(ncclP2pImportShareableBuffer(comm, comm->topParentRanks[peerInfo->rank], p2pBuff->size, &p2pBuff->ipcDesc, devMem)); - *ipcPtr = *devMem; - } + // Different PID + NCCLCHECK(ncclP2pImportShareableBuffer(comm, comm->topParentRanks[peerInfo->rank], p2pBuff->size, &p2pBuff->ipcDesc, devMem)); + *ipcPtr = *devMem; } return ncclSuccess; } @@ -354,7 +355,7 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; - if (myInfo->pidHash == peerInfo->pidHash && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { + if (P2P_SAME_PID(myInfo, peerInfo) && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { resources->type = P2P_DIRECT; send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/direct pointer%s", @@ -363,8 +364,9 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st // cuMem API support if (ncclCuMemEnable()) { resources->type = P2P_CUMEM; - INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/CUMEM%s%s", - channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, useReadStr, useMemcpy ? "/CE" : "");; + const char *MNNVL = comm->MNNVL ? "MNNVL" : "CUMEM"; + INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%d] -> %d[%d] via P2P/%s%s%s", + channelId, connIndex, myInfo->rank, myInfo->nvmlDev, peerInfo->rank, peerInfo->nvmlDev, MNNVL, useReadStr, useMemcpy ? "/CE" : "");; } else { // Legacy CUDA IPC resources->type = P2P_IPC; @@ -418,7 +420,7 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; - if (myInfo->pidHash == peerInfo->pidHash && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { + if (P2P_SAME_PID(myInfo, peerInfo) && ncclParamP2pDirectDisable() == 0 && useMemcpy == 0) { resources->type = P2P_DIRECT; recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; } else { @@ -468,7 +470,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co if (useMemcpy) { send->conn.tail = &resources->proxyInfo.ceRecvMem->tail; - send->conn.sizesFifo = resources->proxyInfo.ceRecvMem->sizesFifo; + send->conn.connFifo = resources->proxyInfo.ceRecvMem->connFifo; send->conn.head = &resources->proxyInfo.devShm->sendMem.head; // Send SIMPLE buff to proxy, and replace it by local buffer NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); @@ -712,11 +714,11 @@ static ncclResult_t p2pSendProxyProgress(struct ncclProxyState* proxyState, stru } if (sub->transmitted < sub->done + NCCL_STEPS && sub->transmitted < sub->nsteps) { int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; - volatile int* sizesFifo = resources->ceRecvMem->sizesFifo; + volatile struct ncclConnFifo* connFifo = resources->ceRecvMem->connFifo; volatile uint64_t* recvTail = &resources->ceRecvMem->tail; // Check GPU has sent everything if ((*recvTail > sub->base+sub->transmitted)) { - int size = sizesFifo[buffSlot]; + int size = connFifo[buffSlot].size; CUDACHECK(cudaMemcpyAsync(resources->recvFifo+buffSlot*stepSize, resources->ceDevBuff+buffSlot*stepSize, size, cudaMemcpyDeviceToDevice, resources->stream)); CUDACHECK(cudaEventRecord(resources->events[buffSlot], resources->stream)); sub->transmitted += args->sliceSteps; diff --git a/projects/rccl/src/transport/shm.cc b/projects/rccl/src/transport/shm.cc index 5b24429199..99d26c5da9 100644 --- a/projects/rccl/src/transport/shm.cc +++ b/projects/rccl/src/transport/shm.cc @@ -152,7 +152,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co send->conn.head = &resources->devHostMem->head; if (useMemcpyRecv) { - send->conn.sizesFifo = resources->devRemHostMem->sizesFifo; + send->conn.connFifo = resources->devRemHostMem->connFifo; } if (useMemcpySend) { int tpProxyRank; @@ -162,7 +162,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co NCCLCHECK(ncclProxyCallBlocking(comm, &send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); send->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; send->conn.tail = &proxyInfo.ceRecvMem->tail; - send->conn.sizesFifo = proxyInfo.ceRecvMem->sizesFifo; + send->conn.connFifo = proxyInfo.ceRecvMem->connFifo; } // We must assign the proxyConn's proxyProgress property for proper checking at enqueue-time @@ -315,15 +315,15 @@ static ncclResult_t shmSendProxyProgress(struct ncclProxyState* proxyState, stru } if (sub->transmitted < sub->done + NCCL_STEPS && sub->transmitted < sub->nsteps) { int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; - volatile int* sizesFifo = resources->ceRecvMem->sizesFifo; + volatile struct ncclConnFifo* connFifo = resources->ceRecvMem->connFifo; volatile uint64_t* recvTail = &resources->ceRecvMem->tail; // Check GPU has sent everything if ((*recvTail > sub->base+sub->transmitted)) { - int size = sizesFifo[buffSlot]; + int size = connFifo[buffSlot].size; CUDACHECK(cudaMemcpyAsync(resources->shmFifo+buffSlot*stepSize, resources->devFifo+buffSlot*stepSize, size, cudaMemcpyDeviceToHost, resources->stream)); CUDACHECK(cudaEventRecord(resources->events[buffSlot], resources->stream)); - resources->recvMem->sizesFifo[buffSlot] = size; - __sync_synchronize(); // make sure sizesFifo is visible + resources->recvMem->connFifo[buffSlot].size = size; + __sync_synchronize(); // make sure connFifo[].size is visible sub->transmitted += args->sliceSteps; } } @@ -374,11 +374,11 @@ static ncclResult_t shmRecvProxyProgress(struct ncclProxyState* proxyState, stru } if (sub->transmitted < sub->done + NCCL_STEPS && sub->transmitted < sub->nsteps) { int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; - volatile int* sizesFifo = resources->recvMem->sizesFifo; + volatile struct ncclConnFifo* connFifo = resources->recvMem->connFifo; volatile uint64_t* recvTail = &resources->recvMem->tail; // Check data is ready in SHM if ((*recvTail > sub->base+sub->transmitted)) { - int size = sizesFifo[buffSlot]; + int size = connFifo[buffSlot].size; CUDACHECK(cudaMemcpyAsync(resources->devFifo+buffSlot*stepSize, resources->shmFifo+buffSlot*stepSize, size, cudaMemcpyHostToDevice, resources->stream)); CUDACHECK(cudaEventRecord(resources->events[buffSlot], resources->stream)); sub->transmitted += args->sliceSteps;