diff --git a/src/transport.cc b/src/transport.cc index 8149d94429..860b463a22 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -8,7 +8,6 @@ #include "comm.h" #include "info.h" #include "bootstrap.h" -#include "../graph/topo.h" extern struct ncclTransport p2pTransport; extern struct ncclTransport shmTransport; @@ -20,16 +19,6 @@ struct ncclTransport ncclTransports[NTRANSPORTS] = { netTransport, }; -static ncclResult_t connectedByXGMI(int* ret, struct ncclTopoSystem* system, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { - *ret = 0; - if (info1->hostHash != info2->hostHash) return ncclSuccess; - int g1, g2; - NCCLCHECK(ncclTopoRankToIndex(system, info1->rank, &g1)); - NCCLCHECK(ncclTopoRankToIndex(system, info2->rank, &g2)); - if (system->nodes[GPU].nodes[g1].paths[GPU][g2].type == PATH_NVL) *ret = 1; - return ncclSuccess; -} - template static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex) { struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank; @@ -44,8 +33,8 @@ static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2)); } - int xgmi; - NCCLCHECK(connectedByXGMI(&xgmi, comm->topo, myInfo, peerInfo)); + bool xgmi; + NCCLCHECK(ncclTopoGetLinkType(comm->topo, myInfo->cudaDev, peerInfo->cudaDev, &xgmi)); for (int t=0; t= 0 && n2 >= 0 && t != TRANSPORT_NET) continue; diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index b02c6b8de6..5586649143 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -199,6 +199,8 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { if (ncclNIbDevs == 0) { INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found."); } else { + auto cmpIbDevs = [](const void* n1, const void* n2) { return strcmp(((struct ncclIbDev*)n1)->devName, ((struct ncclIbDev*)n2)->devName); }; + qsort(ncclIbDevs, ncclNIbDevs, sizeof(struct ncclIbDev), cmpIbDevs); char line[1024]; line[0] = '\0'; for (int d=0; d