372 строки
15 KiB
C++
372 строки
15 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
|
|
#include "net.h"
|
|
#include "bootstrap.h"
|
|
#include "checks.h"
|
|
#include "plugin.h"
|
|
#include "nccl_net.h"
|
|
|
|
#include <string.h>
|
|
#include <errno.h>
|
|
//#include <sys/types.h>
|
|
//#include <sys/stat.h>
|
|
//#include <unistd.h>
|
|
|
|
typedef ncclNet_t* getNcclNet_t(void* netPluginLib);
|
|
typedef ncclCollNet_t* getNcclCollNet_t(void* netPluginLib);
|
|
|
|
extern getNcclNet_t getNcclNet_v6;
|
|
extern getNcclNet_t getNcclNet_v7;
|
|
extern getNcclNet_t getNcclNet_v8;
|
|
extern getNcclNet_t getNcclNet_v9;
|
|
extern getNcclNet_t getNcclNet_v10;
|
|
extern getNcclCollNet_t getNcclCollNet_v6;
|
|
extern getNcclCollNet_t getNcclCollNet_v7;
|
|
extern getNcclCollNet_t getNcclCollNet_v8;
|
|
extern getNcclCollNet_t getNcclCollNet_v9;
|
|
extern getNcclCollNet_t getNcclCollNet_v10;
|
|
|
|
NCCL_PARAM(NetPluginRefCount, "NET_PLUGIN_REF_COUNT", 1);
|
|
#define NCCL_NET_VERSION_COUNT 5
|
|
int ncclNetVersion[NCCL_NET_VERSION_COUNT] = {10, 9, 8, 7, 6};
|
|
getNcclNet_t* getNcclNet[NCCL_NET_VERSION_COUNT] = {getNcclNet_v10, getNcclNet_v9, getNcclNet_v8, getNcclNet_v7, getNcclNet_v6};
|
|
getNcclCollNet_t* getNcclCollNet[NCCL_NET_VERSION_COUNT] = {getNcclCollNet_v10, getNcclCollNet_v9, getNcclCollNet_v8, getNcclCollNet_v7, getNcclCollNet_v6};
|
|
|
|
#define NCCL_NET_NUM_INTERNAL_PLUGINS 2
|
|
|
|
typedef enum ncclNetPluginState {
|
|
ncclNetPluginStateDisabled = -2, // Plugin library failed to initialize
|
|
ncclNetPluginStateLoadFailed = -1, // Plugin library failed to load
|
|
ncclNetPluginStateLoadReady = 0, // Plugin library is ready to be loaded
|
|
ncclNetPluginStateInitReady = 1, // Plugin library is loaded and ready to be initialized
|
|
ncclNetPluginStateEnabled = 2, // Plugin library is loaded and initialized
|
|
} ncclNetPluginState_t;
|
|
|
|
#define MAX_STR_LEN 255
|
|
typedef struct netPluginLib {
|
|
char name[MAX_STR_LEN]; // Name of the plugin library
|
|
void* dlHandle; // Handle to the plugin library
|
|
ncclNet_t* ncclNet; // Pointer to the ncclNet_t structure
|
|
int ncclNetVer; // Version of the nccl net plugin
|
|
ncclCollNet_t* ncclCollNet; // Pointer to the ncclCollNet_t structure
|
|
ncclNetPluginState_t ncclNetPluginState; // State of the nccl net plugin
|
|
ncclNetPluginState_t ncclCollNetPluginState; // State of the nccl coll net plugin
|
|
int ncclNetPluginRefCount; // Reference count for the nccl net plugin
|
|
} netPluginLib_t;
|
|
|
|
int pluginCount = 0;
|
|
bool netPluginLibsInitialized = false;
|
|
netPluginLib_t netPluginLibs[NCCL_NET_MAX_PLUGINS] = { 0 };
|
|
static pthread_mutex_t netPluginLock = PTHREAD_MUTEX_INITIALIZER;
|
|
static pthread_once_t initPluginLibsOnceControl = PTHREAD_ONCE_INIT;
|
|
|
|
static ncclResult_t ncclNetPluginUnload(netPluginLib_t* pluginLib) {
|
|
if ((pluginLib->dlHandle) && ((pluginLib->ncclNetPluginRefCount) == 0)) {
|
|
INFO(NCCL_INIT|NCCL_NET, "Unloading plugin %s", pluginLib->name);
|
|
NCCLCHECK(ncclClosePluginLib(pluginLib->dlHandle));
|
|
memset(pluginLib, 0, sizeof(netPluginLib_t));
|
|
}
|
|
return ncclSuccess;
|
|
}
|
|
|
|
static ncclResult_t ncclNetPluginLoad(netPluginLib_t* pluginLib) {
|
|
pluginLib->dlHandle = ncclOpenNetPluginLib(pluginLib->name);
|
|
|
|
if (pluginLib->dlHandle == nullptr) goto fail;
|
|
// load ncclNet
|
|
for (int i = 0; i < NCCL_NET_VERSION_COUNT; i++) {
|
|
pluginLib->ncclNetVer = ncclNetVersion[i];
|
|
pluginLib->ncclNet = getNcclNet[i](pluginLib->dlHandle);
|
|
if (pluginLib->ncclNet) break;
|
|
}
|
|
|
|
// if we fail to find a net, exit
|
|
if (pluginLib->ncclNet == nullptr) goto fail;
|
|
|
|
pluginLib->ncclNetPluginState = ncclNetPluginStateInitReady;
|
|
|
|
// load ncclColNet
|
|
for (int i = 0; i < NCCL_NET_VERSION_COUNT; i++) {
|
|
pluginLib->ncclCollNet = getNcclCollNet[i](pluginLib->dlHandle);
|
|
if (pluginLib->ncclCollNet) break;
|
|
}
|
|
|
|
if (pluginLib->ncclCollNet == nullptr)
|
|
pluginLib->ncclCollNetPluginState = ncclNetPluginStateLoadFailed;
|
|
else
|
|
pluginLib->ncclCollNetPluginState = ncclNetPluginStateInitReady;
|
|
|
|
INFO(NCCL_INIT|NCCL_NET, "Successfully loaded external plugin %s", pluginLib->name);
|
|
exit:
|
|
return ncclSuccess;
|
|
fail:
|
|
if (pluginLib->dlHandle) {
|
|
NCCLCHECK(ncclClosePluginLib(pluginLib->dlHandle));
|
|
}
|
|
pluginLib->ncclNetPluginState = ncclNetPluginStateLoadFailed;
|
|
pluginLib->ncclCollNetPluginState = ncclNetPluginStateLoadFailed;
|
|
goto exit;
|
|
}
|
|
|
|
ncclResult_t ncclNetCheckDeviceVersion(struct ncclComm* comm, ncclNet_t* net, int dev) {
|
|
ncclNetProperties_t props;
|
|
|
|
NCCLCHECK(net->getProperties(dev, &props));
|
|
ncclNetDeviceType type = props.netDeviceType;
|
|
if (type) switch (type) {
|
|
case NCCL_NET_DEVICE_UNPACK:
|
|
if (props.netDeviceVersion == NCCL_NET_DEVICE_UNPACK_VERSION) {
|
|
INFO(NCCL_INIT, "Using NCCL_NET_DEVICE_UNPACK net plugin version %d",
|
|
props.netDeviceVersion);
|
|
return ncclSuccess;
|
|
} else {
|
|
WARN("NCCL_DEVICE_UNPACK plugin has incompatible version %d, this NCCL build is compatible with %d, not using it",
|
|
props.netDeviceVersion, NCCL_NET_DEVICE_UNPACK_VERSION);
|
|
return ncclInternalError;
|
|
}
|
|
default:
|
|
WARN("Unknown device code index %d \n", type);
|
|
return ncclInternalError;
|
|
}
|
|
|
|
return ncclSuccess;
|
|
}
|
|
|
|
static ncclResult_t ncclNetPluginInit(netPluginLib_t* pluginLib) {
|
|
int ndev;
|
|
if (pluginLib->ncclNetPluginState == ncclNetPluginStateInitReady && pluginLib->ncclNet) {
|
|
if (pluginLib->ncclNet->init(ncclDebugLog, ncclProfilerCallback) != ncclSuccess) goto fail;
|
|
if (pluginLib->ncclNet->devices(&ndev) != ncclSuccess || ndev <= 0) goto fail;
|
|
}
|
|
pluginLib->ncclNetPluginState = ncclNetPluginStateEnabled;
|
|
INFO(NCCL_INIT|NCCL_NET, "Initialized NET plugin %s", pluginLib->ncclNet->name);
|
|
|
|
if (pluginLib->ncclCollNetPluginState == ncclNetPluginStateInitReady && pluginLib->ncclCollNet) {
|
|
if (pluginLib->ncclCollNet->init(ncclDebugLog) != ncclSuccess) pluginLib->ncclCollNetPluginState = ncclNetPluginStateDisabled;
|
|
else if (pluginLib->ncclCollNet->devices(&ndev) != ncclSuccess || ndev <= 0) pluginLib->ncclCollNetPluginState = ncclNetPluginStateDisabled;
|
|
else {
|
|
pluginLib->ncclCollNetPluginState = ncclNetPluginStateEnabled;
|
|
}
|
|
}
|
|
exit:
|
|
return ncclSuccess;
|
|
fail:
|
|
pluginLib->ncclNetPluginState = ncclNetPluginStateDisabled;
|
|
pluginLib->ncclCollNetPluginState = ncclNetPluginStateDisabled;
|
|
goto exit;
|
|
}
|
|
|
|
static ncclResult_t ncclNetPluginAssignToComm(struct ncclComm* comm, int pluginIndex, bool* isAssigned) {
|
|
const char* netName = comm->config.netName;
|
|
if (netName && strcasecmp(netName, netPluginLibs[pluginIndex].ncclNet->name) != 0) goto fail;
|
|
if (ncclSuccess != ncclNetCheckDeviceVersion(comm, netPluginLibs[pluginIndex].ncclNet, 0)) goto fail;
|
|
|
|
if (netPluginLibs[pluginIndex].ncclNetPluginState >= ncclNetPluginStateEnabled) {
|
|
comm->ncclNet = netPluginLibs[pluginIndex].ncclNet;
|
|
comm->ncclNetVer = netPluginLibs[pluginIndex].ncclNetVer;
|
|
comm->netPluginIndex = pluginIndex;
|
|
netPluginLibs[pluginIndex].ncclNetPluginRefCount++;
|
|
*isAssigned = true;
|
|
INFO(NCCL_INIT|NCCL_NET, "Assigned NET plugin %s to comm", netPluginLibs[pluginIndex].ncclNet->name);
|
|
if (netPluginLibs[pluginIndex].ncclCollNetPluginState >= ncclNetPluginStateEnabled) {
|
|
comm->ncclCollNet = netPluginLibs[pluginIndex].ncclCollNet;
|
|
}
|
|
}
|
|
exit:
|
|
return ncclSuccess;
|
|
fail:
|
|
*isAssigned = false;
|
|
netPluginLibs[pluginIndex].ncclNetPluginState = ncclNetPluginStateEnabled;
|
|
netPluginLibs[pluginIndex].ncclCollNetPluginState = ncclNetPluginStateEnabled;
|
|
goto exit;
|
|
}
|
|
|
|
static ncclResult_t ncclNetPluginDisableOtherExternal(int pluginIndex) {
|
|
// Only if an external plugin is enabled, disable other external plugins
|
|
if (pluginIndex >= (pluginCount - NCCL_NET_NUM_INTERNAL_PLUGINS)) return ncclSuccess;
|
|
char names[MAX_STR_LEN*(NCCL_NET_MAX_PLUGINS - NCCL_NET_NUM_INTERNAL_PLUGINS)] = { 0 };
|
|
for (int i = 0; i < (pluginCount - NCCL_NET_NUM_INTERNAL_PLUGINS); i++) {
|
|
if (i != pluginIndex) {
|
|
// Append all disabled plugin names to a string
|
|
snprintf(names+strlen(names), sizeof(names)-strlen(names), (strlen(names) == 0) ? "%s" : ", %s", netPluginLibs[i].name);
|
|
netPluginLibs[i].ncclNetPluginState = ncclNetPluginStateDisabled;
|
|
}
|
|
}
|
|
if(strlen(names) > 0) {
|
|
INFO(NCCL_INIT|NCCL_NET, "Disabling external plugins: %s", names);
|
|
}
|
|
return ncclSuccess;
|
|
}
|
|
|
|
static void initPluginLibsOnceFunc() {
|
|
char* netPluginName = nullptr;
|
|
const char* defaultNetPlugin = "libnccl-net.so";
|
|
const char* envNetPlugin = nullptr;
|
|
char* envNetPluginList = nullptr;
|
|
char* savePtr = nullptr;
|
|
int pluginCounter = 0;
|
|
|
|
memset(netPluginLibs, 0, NCCL_NET_MAX_PLUGINS * sizeof(netPluginLib_t));
|
|
envNetPlugin = ncclGetEnv("NCCL_NET_PLUGIN");
|
|
if (envNetPlugin) {
|
|
envNetPluginList = strdup(envNetPlugin);
|
|
// Iterate over list until the list is empty
|
|
netPluginName = strtok_r(envNetPluginList, ",", &savePtr);
|
|
while(netPluginName) {
|
|
// We have 2 internal plugins (ib and socket)
|
|
// So, we can have at most( NCCL_NET_MAX_PLUGINS - (NCCL_NET_NUM_INTERNAL_PLUGINS)) in the NCCL_NET_PLUGIN list
|
|
if (pluginCounter >= (NCCL_NET_MAX_PLUGINS - (NCCL_NET_NUM_INTERNAL_PLUGINS))) {
|
|
INFO(NCCL_NET|NCCL_INIT,"NCCL_NET_PLUGIN list contains more than %d plugins, ignoring the rest", (NCCL_NET_MAX_PLUGINS - (NCCL_NET_NUM_INTERNAL_PLUGINS + 1)));
|
|
break;
|
|
}
|
|
// need to leave space for the name + "\n"
|
|
if((strlen(netPluginName)+1) <= MAX_STR_LEN) {
|
|
netPluginLibs[pluginCounter].ncclNetPluginState = ncclNetPluginStateLoadReady;
|
|
netPluginLibs[pluginCounter].ncclNetPluginRefCount = ncclParamNetPluginRefCount();
|
|
strcpy(netPluginLibs[pluginCounter].name, netPluginName);
|
|
pluginCounter++;
|
|
} else {
|
|
INFO(NCCL_NET|NCCL_INIT,"NCCL_NET_PLUGIN list contains a plugin name %s longer than %d characters, ignoring it.", netPluginName, MAX_STR_LEN);
|
|
}
|
|
netPluginName = strtok_r(nullptr, ",", &savePtr);
|
|
}
|
|
if (envNetPluginList) free(envNetPluginList);
|
|
} else {
|
|
// Add default net plugin
|
|
netPluginLibs[pluginCounter].ncclNetPluginState = ncclNetPluginStateLoadReady;
|
|
netPluginLibs[pluginCounter].ncclNetPluginRefCount = ncclParamNetPluginRefCount();
|
|
strcpy(netPluginLibs[pluginCounter++].name, defaultNetPlugin);
|
|
}
|
|
|
|
// Add 2 internal ib and socket plugins
|
|
netPluginLibs[pluginCounter].ncclNet = &ncclNetIb;
|
|
netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady;
|
|
netPluginLibs[pluginCounter].ncclNet = &ncclNetSocket;
|
|
netPluginLibs[pluginCounter++].ncclNetPluginState = ncclNetPluginStateInitReady;
|
|
pluginCount = pluginCounter;
|
|
}
|
|
|
|
ncclResult_t ncclNetInit(struct ncclComm* comm) {
|
|
bool ncclNetPluginInitialized = false;
|
|
pthread_once(&initPluginLibsOnceControl, initPluginLibsOnceFunc);
|
|
pthread_mutex_lock(&netPluginLock);
|
|
for (int pluginIndex = 0; pluginIndex < pluginCount; pluginIndex++) {
|
|
if ((pluginIndex < (pluginCount - NCCL_NET_NUM_INTERNAL_PLUGINS)) && (netPluginLibs[pluginIndex].ncclNetPluginState == ncclNetPluginStateLoadReady)) {
|
|
NCCLCHECK(ncclNetPluginLoad(&netPluginLibs[pluginIndex]));
|
|
}
|
|
if (netPluginLibs[pluginIndex].ncclNetPluginState == ncclNetPluginStateInitReady) {
|
|
NCCLCHECK(ncclNetPluginInit(&netPluginLibs[pluginIndex]));
|
|
}
|
|
if (netPluginLibs[pluginIndex].ncclNetPluginState == ncclNetPluginStateEnabled) {
|
|
bool isAssigned = false;
|
|
NCCLCHECK(ncclNetPluginAssignToComm(comm, pluginIndex, &isAssigned));
|
|
if (isAssigned) {
|
|
// If one external plugin is assigned to a comm, then disable all other external plugins
|
|
ncclNetPluginDisableOtherExternal(pluginIndex);
|
|
ncclNetPluginInitialized = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
pthread_mutex_unlock(&netPluginLock);
|
|
if (ncclNetPluginInitialized) return ncclSuccess;
|
|
WARN("Failed to initialize any NET plugin");
|
|
return ncclInvalidUsage;
|
|
}
|
|
|
|
ncclResult_t ncclNetFinalize(struct ncclComm* comm) {
|
|
int pluginIndex = comm->netPluginIndex;
|
|
pthread_mutex_lock(&netPluginLock);
|
|
netPluginLibs[pluginIndex].ncclNetPluginRefCount--;
|
|
for (int i = 0; i < (pluginCount - NCCL_NET_NUM_INTERNAL_PLUGINS); i++) {
|
|
NCCLCHECK(ncclNetPluginUnload(&netPluginLibs[i]));
|
|
}
|
|
pthread_mutex_unlock(&netPluginLock);
|
|
return ncclSuccess;
|
|
}
|
|
|
|
ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) {
|
|
constexpr int GPU_BUF_SIZE = 2*1024*1024;
|
|
#if CUDART_VERSION >= 11030
|
|
// In CUDA 11.3 and later we can now query the cudaDevAttrGPUDirectRDMASupported attribute
|
|
int driverVersion;
|
|
CUDACHECK(cudaDriverGetVersion(&driverVersion));
|
|
if (driverVersion >= 11030) {
|
|
int cudaDev, attr = 0;
|
|
CUDACHECK(cudaGetDevice(&cudaDev));
|
|
CUDACHECK(cudaDeviceGetAttribute(&attr, cudaDevAttrGPUDirectRDMASupported, cudaDev));
|
|
*gdrSupport = attr;
|
|
return ncclSuccess;
|
|
}
|
|
#endif
|
|
static int gdrSupportMatrix[32] = {
|
|
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
|
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 };
|
|
if (gdrSupportMatrix[comm->cudaDev] == -1) {
|
|
int netDevs;
|
|
NCCLCHECK(comm->ncclNet->devices(&netDevs));
|
|
gdrSupportMatrix[comm->cudaDev] = 0;
|
|
for (int dev=0; dev<netDevs; dev++) {
|
|
// Find a net device which is GDR-capable
|
|
ncclNetProperties_t props;
|
|
NCCLCHECK(comm->ncclNet->getProperties(dev, &props));
|
|
if ((props.ptrSupport & NCCL_PTR_CUDA) == 0) continue;
|
|
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
|
|
gdrSupportMatrix[comm->cudaDev] = 1;
|
|
break;
|
|
#endif
|
|
|
|
// Allocate memory on the GPU and try to register it on the NIC.
|
|
void *lComm = NULL, *sComm = NULL, *rComm = NULL;
|
|
ncclNetHandle_t handle;
|
|
char* gpuPtr = NULL;
|
|
void* mHandle = NULL;
|
|
ncclResult_t ret;
|
|
ncclDebugNoWarn = NCCL_NET;
|
|
NCCLCHECKGOTO(comm->ncclNet->listen(dev, &handle, &lComm), ret, cleanup1);
|
|
|
|
bool connected;
|
|
connected = false;
|
|
while (!connected) {
|
|
|
|
// If we're aborting now, skip to cleanup
|
|
if (__atomic_load_n(comm->abortFlag, __ATOMIC_ACQUIRE)) {
|
|
goto cleanup2;
|
|
}
|
|
|
|
if (sComm == NULL)
|
|
NCCLCHECKGOTO(comm->ncclNet->connect(dev, NULL, &handle, &sComm, NULL), ret, cleanup2);
|
|
|
|
if (rComm == NULL)
|
|
NCCLCHECKGOTO(comm->ncclNet->accept(lComm, &rComm, NULL), ret, cleanup2);
|
|
|
|
connected = (rComm != NULL) && (sComm != NULL);
|
|
}
|
|
|
|
NCCLCHECKGOTO(ncclCudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2);
|
|
if (comm->ncclNet->regMr(sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) {
|
|
NCCLCHECK(comm->ncclNet->deregMr(sComm, mHandle));
|
|
NCCLCHECK(comm->ncclNet->regMr(rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle));
|
|
NCCLCHECK(comm->ncclNet->deregMr(rComm, mHandle));
|
|
gdrSupportMatrix[comm->cudaDev] = 1;
|
|
}
|
|
ncclDebugNoWarn = 0;
|
|
NCCLCHECK(ncclCudaFree(gpuPtr));
|
|
cleanup2:
|
|
if (rComm != NULL)
|
|
NCCLCHECK(comm->ncclNet->closeRecv(rComm));
|
|
if (sComm != NULL)
|
|
NCCLCHECK(comm->ncclNet->closeSend(sComm));
|
|
NCCLCHECK(comm->ncclNet->closeListen(lComm));
|
|
cleanup1:
|
|
break;
|
|
}
|
|
}
|
|
*gdrSupport = gdrSupportMatrix[comm->cudaDev];
|
|
return ncclSuccess;
|
|
}
|