Files
rocm-systems/src/plugin/net.cc
T
2025-06-20 07:54:49 -05:00

324 строки
10 KiB
C++

/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "net.h"
#include "bootstrap.h"
#include "checks.h"
#include "plugin.h"
#include <string.h>
#include <errno.h>
//#include <sys/types.h>
//#include <sys/stat.h>
//#include <unistd.h>
extern ncclNet_t* getNcclNet_v6(void* netPluginLib);
extern ncclNet_t* getNcclNet_v7(void* netPluginLib);
extern ncclNet_t* getNcclNet_v8(void* netPluginLib);
extern ncclNet_t* getNcclNet_v9(void* netPluginLib);
extern ncclNet_t* getNcclNet_v10(void* netPluginLib);
extern ncclCollNet_t* getNcclCollNet_v6(void* netPluginLib);
extern ncclCollNet_t* getNcclCollNet_v7(void* netPluginLib);
extern ncclCollNet_t* getNcclCollNet_v8(void* netPluginLib);
extern ncclCollNet_t* getNcclCollNet_v9(void* netPluginLib);
extern ncclCollNet_t* getNcclCollNet_v10(void* netPluginLib);
static pthread_mutex_t netLock = PTHREAD_MUTEX_INITIALIZER;
ncclNet_t* ncclNets[NCCL_NET_MAX_PLUGINS] = { nullptr, &ncclNetIb, &ncclNetSocket };
static int ncclNetsVer[NCCL_NET_MAX_PLUGINS] = { -1, 10, 10 };
ncclCollNet_t* ncclCollNets[NCCL_NET_MAX_PLUGINS] = { nullptr, nullptr, nullptr };
enum ncclNetState {
ncclNetStateInit = 0,
ncclNetStateEnabled = 1,
ncclNetStateDisabled = 2
};
enum ncclNetState ncclNetStates[NCCL_NET_MAX_PLUGINS] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit };
enum ncclNetState ncclCollNetStates[NCCL_NET_MAX_PLUGINS] = { ncclNetStateInit, ncclNetStateInit, ncclNetStateInit };
NCCL_PARAM(NetPluginRefCount, "NET_PLUGIN_REF_COUNT", 1);
static pthread_mutex_t netPluginLock = PTHREAD_MUTEX_INITIALIZER;
static void* netPluginLib;
static int netPluginRefCount;
static void initNetPluginRefCountOnce(void) { netPluginRefCount = ncclParamNetPluginRefCount();}
enum {
netPluginLoadFailed = -1,
netPluginLoadReady = 0,
netPluginLoadSuccess = 1,
};
static int netPluginStatus = netPluginLoadReady;
ncclResult_t ncclNetPluginLoad(struct ncclComm* comm) {
static pthread_once_t netPluginRefCountOnce = PTHREAD_ONCE_INIT;
pthread_once(&netPluginRefCountOnce, initNetPluginRefCountOnce);
pthread_mutex_lock(&netPluginLock);
if (netPluginLoadFailed == netPluginStatus) {
goto exit;
}
if (netPluginLoadSuccess == netPluginStatus) {
++netPluginRefCount;
goto exit;
}
netPluginLib = ncclOpenNetPluginLib(ncclGetEnv("NCCL_NET_PLUGIN"));
if (netPluginLib == nullptr) {
goto fail;
}
ncclNets[0] = getNcclNet_v10(netPluginLib);
if (ncclNets[0]) ncclNetsVer[0] = 10;
if (ncclNets[0] == nullptr) {
// Try v9 plugin
ncclNets[0] = getNcclNet_v9(netPluginLib);
if (ncclNets[0]) ncclNetsVer[0] = 9;
}
if (ncclNets[0] == nullptr) {
// Try v8 plugin
ncclNets[0] = getNcclNet_v8(netPluginLib);
if (ncclNets[0]) ncclNetsVer[0] = 8;
}
if (ncclNets[0] == nullptr) {
// Try v7 plugin
ncclNets[0] = getNcclNet_v7(netPluginLib);
if (ncclNets[0]) ncclNetsVer[0] = 7;
}
if (ncclNets[0] == nullptr) {
// Try v6 plugin
ncclNets[0] = getNcclNet_v6(netPluginLib);
if (ncclNets[0]) ncclNetsVer[0] = 6;
}
if (ncclNets[0] == nullptr) {
goto fail;
}
// Check for CollNet
ncclCollNets[0] = getNcclCollNet_v10(netPluginLib);
if (ncclCollNets[0] == nullptr) {
ncclCollNets[0] = getNcclCollNet_v9(netPluginLib);
}
if (ncclCollNets[0] == nullptr) {
ncclCollNets[0] = getNcclCollNet_v8(netPluginLib);
}
if (ncclCollNets[0] == nullptr) {
ncclCollNets[0] = getNcclCollNet_v7(netPluginLib);
}
if (ncclCollNets[0] == nullptr) {
ncclCollNets[0] = getNcclCollNet_v6(netPluginLib);
}
++netPluginRefCount;
netPluginStatus = netPluginLoadSuccess;
comm->netPluginLoaded = 1;
exit:
pthread_mutex_unlock(&netPluginLock);
return ncclSuccess;
fail:
if (netPluginLib) NCCLCHECK(ncclClosePluginLib(netPluginLib));
netPluginStatus = netPluginLoadFailed;
goto exit;
}
ncclResult_t ncclNetPluginUnload(struct ncclComm* comm) {
pthread_mutex_lock(&netPluginLock);
if (comm->netPluginLoaded && 0 == (--netPluginRefCount)) {
if (ncclNets[0]) {
INFO(NCCL_NET, "NET/Plugin: Closing net plugin '%s'", ncclNets[0]->name);
}
if (ncclCollNets[0]) {
INFO(NCCL_NET, "NET/Plugin: Closing collnet plugin '%s'", ncclCollNets[0]->name);
}
NCCLCHECK(ncclClosePluginLib(netPluginLib));
netPluginLib = nullptr;
ncclNets[0] = nullptr;
ncclCollNets[0] = nullptr;
netPluginStatus = netPluginLoadReady;
comm->netPluginLoaded = 0;
for (int i = 0; i < NCCL_NET_MAX_PLUGINS; ++i)
ncclCollNetStates[i] = ncclNetStates[i] = ncclNetStateInit;
}
pthread_mutex_unlock(&netPluginLock);
return ncclSuccess;
}
ncclResult_t ncclNetCheckDeviceVersion(struct ncclComm* comm, ncclNet_t* net, int dev) {
ncclNetProperties_t props;
NCCLCHECK(net->getProperties(dev, &props));
ncclNetDeviceType type = props.netDeviceType;
if (type) switch (type) {
case NCCL_NET_DEVICE_UNPACK:
if (props.netDeviceVersion == NCCL_NET_DEVICE_UNPACK_VERSION) {
INFO(NCCL_INIT, "Using NCCL_NET_DEVICE_UNPACK net plugin version %d",
props.netDeviceVersion);
return ncclSuccess;
} else {
WARN("NCCL_DEVICE_UNPACK plugin has incompatible version %d, this NCCL build is compatible with %d, not using it",
props.netDeviceVersion, NCCL_NET_DEVICE_UNPACK_VERSION);
return ncclInternalError;
}
default:
WARN("Unknown device code index %d \n", type);
return ncclInternalError;
}
return ncclSuccess;
}
static ncclResult_t netGetState(int i, enum ncclNetState* state) {
pthread_mutex_lock(&netLock);
if (ncclNetStates[i] == ncclNetStateInit) {
int ndev;
if (ncclNets[i]->init(ncclDebugLog, ncclProfilerCallback) != ncclSuccess) ncclNetStates[i] = ncclNetStateDisabled;
else if (ncclNets[i]->devices(&ndev) != ncclSuccess || ndev <= 0) ncclNetStates[i] = ncclNetStateDisabled;
else ncclNetStates[i] = ncclNetStateEnabled;
}
*state = ncclNetStates[i];
pthread_mutex_unlock(&netLock);
return ncclSuccess;
}
static ncclResult_t collNetGetState(int i, enum ncclNetState* state) {
pthread_mutex_lock(&netLock);
if (ncclCollNetStates[i] == ncclNetStateInit) {
int ndev;
if (ncclCollNets[i]->init(ncclDebugLog) != ncclSuccess) ncclCollNetStates[i] = ncclNetStateDisabled;
else if (ncclCollNets[i]->devices(&ndev) != ncclSuccess || ndev <= 0) ncclCollNetStates[i] = ncclNetStateDisabled;
else ncclCollNetStates[i] = ncclNetStateEnabled;
}
*state = ncclCollNetStates[i];
pthread_mutex_unlock(&netLock);
return ncclSuccess;
}
ncclResult_t ncclNetInit(struct ncclComm* comm) {
// Initialize main communication network
const char* netName;
bool ok = false;
netName = comm->config.netName;
for (int i=0; i<3; i++) {
if (ncclNets[i] == nullptr) continue;
enum ncclNetState state;
NCCLCHECK(netGetState(i, &state));
if (state != ncclNetStateEnabled) continue;
if (netName && strcasecmp(netName, ncclNets[i]->name) != 0) continue;
if (ncclSuccess != ncclNetCheckDeviceVersion(comm, ncclNets[i], 0)) {
// Mismatched device plugin version
continue;
}
comm->ncclNet = ncclNets[i];
comm->ncclNetVer = ncclNetsVer[i];
ok = true;
if (ncclCollNets[i]) {
NCCLCHECK(collNetGetState(i, &state));
if (state == ncclNetStateEnabled) {
comm->ncclCollNet = ncclCollNets[i];
}
}
break;
}
if (!ok) {
WARN("Error: network %s not found.", netName ? netName : "");
return ncclInvalidUsage;
}
return ncclSuccess;
}
ncclResult_t ncclNetFinalize(struct ncclComm* comm) {
comm->ncclNet = nullptr;
comm->ncclCollNet = nullptr;
return ncclSuccess;
}
ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) {
constexpr int GPU_BUF_SIZE = 2*1024*1024;
#if CUDART_VERSION >= 11030
// In CUDA 11.3 and later we can now query the cudaDevAttrGPUDirectRDMASupported attribute
int driverVersion;
CUDACHECK(cudaDriverGetVersion(&driverVersion));
if (driverVersion >= 11030) {
int cudaDev, attr = 0;
CUDACHECK(cudaGetDevice(&cudaDev));
CUDACHECK(cudaDeviceGetAttribute(&attr, cudaDevAttrGPUDirectRDMASupported, cudaDev));
*gdrSupport = attr;
return ncclSuccess;
}
#endif
static int gdrSupportMatrix[32] = {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 };
if (gdrSupportMatrix[comm->cudaDev] == -1) {
int netDevs;
NCCLCHECK(comm->ncclNet->devices(&netDevs));
gdrSupportMatrix[comm->cudaDev] = 0;
for (int dev=0; dev<netDevs; dev++) {
// Find a net device which is GDR-capable
ncclNetProperties_t props;
NCCLCHECK(comm->ncclNet->getProperties(dev, &props));
if ((props.ptrSupport & NCCL_PTR_CUDA) == 0) continue;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
gdrSupportMatrix[comm->cudaDev] = 1;
break;
#endif
// Allocate memory on the GPU and try to register it on the NIC.
void *lComm = NULL, *sComm = NULL, *rComm = NULL;
ncclNetHandle_t handle;
char* gpuPtr = NULL;
void* mHandle = NULL;
ncclResult_t ret;
ncclDebugNoWarn = NCCL_NET;
NCCLCHECKGOTO(comm->ncclNet->listen(dev, &handle, &lComm), ret, cleanup1);
bool connected;
connected = false;
while (!connected) {
// If we're aborting now, skip to cleanup
if (__atomic_load_n(comm->abortFlag, __ATOMIC_ACQUIRE)) {
goto cleanup2;
}
if (sComm == NULL)
NCCLCHECKGOTO(comm->ncclNet->connect(dev, NULL, &handle, &sComm, NULL), ret, cleanup2);
if (rComm == NULL)
NCCLCHECKGOTO(comm->ncclNet->accept(lComm, &rComm, NULL), ret, cleanup2);
connected = (rComm != NULL) && (sComm != NULL);
}
NCCLCHECKGOTO(ncclCudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2);
if (comm->ncclNet->regMr(sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) {
NCCLCHECK(comm->ncclNet->deregMr(sComm, mHandle));
NCCLCHECK(comm->ncclNet->regMr(rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle));
NCCLCHECK(comm->ncclNet->deregMr(rComm, mHandle));
gdrSupportMatrix[comm->cudaDev] = 1;
}
ncclDebugNoWarn = 0;
NCCLCHECK(ncclCudaFree(gpuPtr));
cleanup2:
if (rComm != NULL)
NCCLCHECK(comm->ncclNet->closeRecv(rComm));
if (sComm != NULL)
NCCLCHECK(comm->ncclNet->closeSend(sComm));
NCCLCHECK(comm->ncclNet->closeListen(lComm));
cleanup1:
break;
}
}
*gdrSupport = gdrSupportMatrix[comm->cudaDev];
return ncclSuccess;
}