diff --git a/src/include/net.h b/src/include/net.h index a6ac5ba327..29afc41b78 100644 --- a/src/include/net.h +++ b/src/include/net.h @@ -36,6 +36,14 @@ static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->cl static ncclResult_t ncclGpuGdrSupport(int* gdrSupport) { int netDevs; NCCLCHECK(ncclNetDevices(&netDevs)); + pthread_mutex_t ncclParamMutexGpuGdrSupport = PTHREAD_MUTEX_INITIALIZER; + static int gdrSupportCached[16] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; + int cudaDev; + CUDACHECK(hipGetDevice(&cudaDev)); + if (gdrSupportCached[cudaDev] != -1) { + *gdrSupport = gdrSupportCached[cudaDev]; + return ncclSuccess; + } *gdrSupport = 0; for (int dev=0; dev