Set up collectives threshold for enabling intranet (#387)

* Set up collectives threshold for enabling intranet

* Use separate operation counters for coll and p2p

[ROCm/rccl commit: b815a2800f]
このコミットが含まれているのは:
Wenkai Du
2021-06-09 13:24:26 -07:00
committed by GitHub
コミット 5bebcb0015
18個のファイルの変更、79行の追加、36行の削除
+1 -1
ファイルの表示
@@ -31,7 +31,7 @@ class ncclFunction<ncclFuncAllGather, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS, T, 1, 1, 1, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+1 -1
ファイルの表示
@@ -37,7 +37,7 @@ class ncclFunction<ncclFuncAllReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS, T, 1, 1, 1, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, 0);
prims(tid, nthreads, &ring->prev, &ring->next, thisOutput, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) {
ssize_t realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nranks*nChannels));
+1 -1
ファイルの表示
@@ -33,7 +33,7 @@ class ncclFunction<ncclFuncBroadcast, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+4 -2
ファイルの表示
@@ -247,7 +247,8 @@ class ncclPrimitives {
__device__ __forceinline__ void loadRecvConn(struct ncclChannel* channel, T* directBuff) {
if (role & (ROLE_WAIT_RECV|ROLE_POST_RECV)) {
// For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 : (((p2pNet && (NSEND+NRECV) == 1)) ? NCCL_CONN_IDX_P2P_NET : 0);
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 :
((p2pNet && (NSEND+NRECV) == 1 ? NCCL_CONN_IDX_P2P_NET : ((NSEND+NRECV) == 1 ? 0 : group)));
conn = &channel->devPeers[peer].recv[connIndex].conn;
step = conn->step;
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
@@ -272,7 +273,8 @@ class ncclPrimitives {
__device__ __forceinline__ void loadSendConn(struct ncclChannel* channel) {
if (role & (ROLE_WAIT_SEND|ROLE_POST_SEND)) {
// For oneshot: groups 0,1 use conn 0, groups 2,3 use conn 1
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 : (((p2pNet && (NSEND+NRECV) == 1)) ? NCCL_CONN_IDX_P2P_NET : 0);
const int connIndex = (NSEND == NCCL_MAX_DIRECT_ARITY || NRECV == NCCL_MAX_DIRECT_ARITY) ? group/2 :
((p2pNet && (NSEND+NRECV) == 1 ? NCCL_CONN_IDX_P2P_NET : ((NSEND+NRECV) == 1 ? 0 : group)));
conn = &channel->devPeers[peer].send[connIndex].conn;
step = conn->step;
step = ROUNDUP(step, SLICESPERCHUNK*SLICESTEPS);
+1 -1
ファイルの表示
@@ -34,7 +34,7 @@ class ncclFunction<ncclFuncReduce, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUNC, T, U
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+1 -1
ファイルの表示
@@ -31,7 +31,7 @@ class ncclFunction<ncclFuncReduceScatter, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE, FUN
T * __restrict__ thisOutput = (T*)args->recvbuff;
ncclPrimitives<UNROLL, REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS, T, 1, 1, 0, FUNC>
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, 0);
prims(tid, nthreads, &ring->prev, &ring->next, NULL, stepSize, channel, comm, ncclShmem->ptrs, args->coll.connIndex);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
int realChunkSize = min(chunkSize, DIVUP(size-gridOffset,nChannels));
+14 -6
ファイルの表示
@@ -453,6 +453,8 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
return ncclSuccess;
}
// Tunable threshold read through rcclParamIntraNetThreshold(); default 8388608 (8 * 1024 * 1024).
// computeColl() compares info->nBytes against it: for NCCL_PROTO_SIMPLE + NCCL_ALGO_RING
// collectives on a communicator with useIntraNet set, a size strictly greater than this
// value switches both work->coll.connIndex and proxyArgs->connIndex to NCCL_CONN_IDX_P2P_NET.
// NOTE(review): presumably overridable through the RCCL_INTRANET_THRESHOLD environment
// variable and expressed in bytes — confirm against the RCCL_PARAM macro definition.
RCCL_PARAM(IntraNetThreshold, "RCCL_INTRANET_THRESHOLD", 8388608);
static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWorkElem* work, struct ncclProxyArgs* proxyArgs /* output */) {
work->comm = info->comm->devComm;
@@ -471,6 +473,15 @@ static ncclResult_t computeColl(struct ncclInfo* info /* input */, struct ncclWo
work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
work->coll.connIndex = 0;
proxyArgs->connIndex = 0;
if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) {
if (info->comm->useIntraNet && info->nBytes > rcclParamIntraNetThreshold()) {
work->coll.connIndex = NCCL_CONN_IDX_P2P_NET;
proxyArgs->connIndex = NCCL_CONN_IDX_P2P_NET;
}
}
{ // [RCCL] Check for clique-based kernel support
if (info->comm->cliqueManager->IsSupported(info->coll,
info->count,
@@ -728,7 +739,7 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct
elem->nThreads = NCCL_MAX_NTHREADS;
elem->sendbuff = info->sendbuff;
elem->recvbuff = info->recvbuff;
elem->op.opCount = info->comm->collOpCount;
elem->op.opCount = info->comm->p2pOpCount;
elem->p2p.sendCount = info->sendbytes;
elem->p2p.recvCount = info->recvbytes;
elem->p2p.sendChunkSize = info->sendChunkSize;
@@ -772,7 +783,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
// store work element into FIFO
NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs));
NCCLCHECK(enqueueP2pOp(workElem, w, segment));
comm->collOpCount++;
comm->p2pOpCount++;
return ncclSuccess;
}
@@ -916,13 +927,10 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
NCCLCHECKGOTO(checkSetStream(info), ret, end);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
info->opName, info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count,
info->opName, info->coll == ncclFuncSendRecv ? info->comm->p2pOpCount : info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count,
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
if (info->coll == ncclFuncSendRecv) { //p2p stored separately
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
info->opName, info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count,
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
NCCLCHECKGOTO(ncclSaveP2p(info), ret, end);
} else {
NCCLCHECKGOTO(ncclSaveAsyncColl(info), ret, end);
+1 -3
ファイルの表示
@@ -1017,11 +1017,9 @@ ncclResult_t ncclTopoGetNetDev(struct ncclTopoSystem* system, int rank, struct n
return ncclSuccess;
}
extern int64_t rcclParamP2pNetDisable();
ncclResult_t ncclTopoGetIntraNetDev(struct ncclTopoSystem* system, int rank, struct ncclTopoGraph* graph, int channelId, int type, int* dev) {
*dev = -1;
if (graph && graph->nIntraChannels && rcclParamP2pNetDisable() == 0) {
if (graph && graph->nIntraChannels) {
int n1 = -1;
int ngpus = system->nodes[GPU].count;
int nnets = system->nodes[NET].count;
+3
ファイルの表示
@@ -98,6 +98,8 @@ struct ncclComm {
uint64_t opCount;
// Collective operation counter
uint64_t collOpCount;
// P2P operation counter
uint64_t p2pOpCount;
// Channels for collectives
int nChannels;
@@ -127,6 +129,7 @@ struct ncclComm {
// Flags for enable P2P NET
uint32_t *p2pNet;
uint32_t useIntraNet;
// Device side of the communicator
struct ncclDevComm *devComm;
+1
ファイルの表示
@@ -200,6 +200,7 @@ struct ncclWorkElem {
uint32_t root;
uint8_t bid;
uint8_t nChannels;
uint8_t connIndex;
} coll;
struct {
size_t sendCount;
+1
ファイルの表示
@@ -60,6 +60,7 @@ struct ncclProxyArgs {
int idle;
uint64_t hdp_flushed;
uint8_t connIndex;
// Element linking
pthread_mutex_t mutex;
+12
ファイルの表示
@@ -433,6 +433,8 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
NCCLCHECK(ncclCudaHostCalloc((uint32_t**)&comm->p2pNet, 1));
comm->hostDevComm.p2pNet = comm->p2pNet;
STORE(comm->p2pNet, 0);
comm->collOpCount = 0;
comm->p2pOpCount = 0x8000;
comm->argsptr = &comm->args;
#ifdef ENABLE_PROFILING
@@ -1005,6 +1007,16 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, affinity_restore);
}
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, affinity_restore);
if (ringGraph.nIntraChannels && rcclParamP2pNetDisable() == 0) {
comm->useIntraNet = 1;
// Connect NET for intranode use
for (int c=0; c<comm->nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
if (comm->nRanks == 1) continue;
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->ring.prev, 1, &channel->ring.next, NCCL_CONN_IDX_P2P_NET), ret, affinity_restore);
}
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, NCCL_CONN_IDX_P2P_NET), ret, affinity_restore);
}
free(rings);
INFO(NCCL_INIT, "Connected all rings");
+2 -2
ファイルの表示
@@ -209,8 +209,8 @@ ncclResult_t ncclProxySaveColl(struct ncclProxyArgs* args, int nranks) {
int pattern = args->pattern;
if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice || pattern == ncclPatternPipelineFrom || pattern == ncclPatternPipelineTo) {
struct ncclRing* ring = &channel->ring;
if (NeedProxy(proxyRecv, pattern, args->root, ring, nranks)) NCCLCHECK(SaveProxy(proxyRecv, ring->prev, args, 0));
if (NeedProxy(proxySend, pattern, args->root, ring, nranks)) NCCLCHECK(SaveProxy(proxySend, ring->next, args, 0));
if (NeedProxy(proxyRecv, pattern, args->root, ring, nranks)) NCCLCHECK(SaveProxy(proxyRecv, ring->prev, args, args->connIndex));
if (NeedProxy(proxySend, pattern, args->root, ring, nranks)) NCCLCHECK(SaveProxy(proxySend, ring->next, args, args->connIndex));
}
if (pattern == ncclPatternTreeUp || pattern == ncclPatternTreeUpDown) {
// Tree up
+7 -6
ファイルの表示
@@ -38,16 +38,17 @@ static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph*
comm->channels[channelId].peers[peer].recv + connIndex;
// handle intra-node network connections
int n1, n2;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
int n1 = -1, n2 = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) {
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
}
int xgmi;
NCCLCHECK(connectedByXGMI(&xgmi, comm->topo, myInfo, peerInfo));
for (int t=0; t<NTRANSPORTS; t++) {
if (connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P)))
continue;
if (n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
if (graph == NULL && connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P))) continue;
if (graph && n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
struct ncclTransport *transport = ncclTransports+t;
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
int ret = 0;
+4 -2
ファイルの表示
@@ -78,7 +78,8 @@ ncclResult_t netSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
send->conn.shared = resources->shared = ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : graph ? 0 : 1;
send->proxyAppendPtr = send->conn.shared ? comm->proxyState.sharedBuffs.proxyAppend+2*channelId+1 : &send->proxyAppend;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &resources->netDev));
resources->netDev = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &resources->netDev));
if (resources->netDev < 0) {
// Send/Receive: Round-robin NICs based on the receiver's CUDA device
int nicRR = comm->peerInfo[peerInfo->rank].cudaDev;
@@ -146,7 +147,8 @@ ncclResult_t netRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
recv->conn.shared = resources->shared = ncclParamNetSharedBuffers() != -2 ? ncclParamNetSharedBuffers() : graph ? 0 : 1;
recv->proxyAppendPtr = recv->conn.shared ? comm->proxyState.sharedBuffs.proxyAppend+2*channelId : &recv->proxyAppend;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 0, &resources->netDev));
resources->netDev = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 0, &resources->netDev));
if (resources->netDev < 0) {
// Send/Receive: Round-robin NICs based on the receiver's CUDA device
int nicRR = comm->cudaDev;
+4 -2
ファイルの表示
@@ -151,7 +151,8 @@ BEGIN {
s=ary[1]
match($col_p5, /([0-9]+)\[.*\]/, ary)
d=ary[1]
conn[s "," d "," chan]=$col_p7
if(!((s "," d "," chan) in conn) || match($col_p7,"NET"))
conn[s "," d "," chan]=$col_p7
}
if($col_p6=="[receive]" && $col_p7=="via") {
@@ -161,7 +162,8 @@ BEGIN {
s=ary[1]
match($col_p5, /([0-9]+)\[.*\]/, ary)
d=ary[1]
conn[s "," d "," chan]=$col_p8
if(!((s "," d "," chan) in conn) || match($col_p8,"NET"))
conn[s "," d "," chan]=$col_p8
}
}
+4 -2
ファイルの表示
@@ -151,7 +151,8 @@ ncclResult_t netCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
ncclResult_t netSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* send, int channelId, int connIndex) {
int netDev, useGdr = 0;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &netDev));
netDev = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &netDev));
if (netDev < 0) {
// Send/Receive: Round-robin NICs based on the receiver's CUDA device
int nicRR = comm->peerInfo[peerInfo->rank].cudaDev;
@@ -169,7 +170,8 @@ NCCL_PARAM(NetGdrLevel, "NET_GDR_LEVEL", PATH_PHB);
ncclResult_t netRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclPeerInfo* myInfo, struct ncclPeerInfo* peerInfo, struct ncclConnect* connectInfo, struct ncclConnector* recv, int channelId, int connIndex) {
int netDev, useGdr = 0;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 0, &netDev));
netDev = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 0, &netDev));
if (netDev < 0) {
// Send/Receive: Round-robin NICs based on the receiver's CUDA device
int nicRR = comm->cudaDev;
+17 -6
ファイルの表示
@@ -539,16 +539,17 @@ static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph*
struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer].send + connIndex :
comm->channels[channelId].peers[peer].recv + connIndex;
// handle intra-node network connections
int n1, n2;
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
int n1 = -1, n2 = -1;
if (connIndex == NCCL_CONN_IDX_P2P_NET) {
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, comm->rank, graph, channelId, (type == 1) ? 1 : 0, &n1));
NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, peer, graph, channelId, (type == 1) ? 0 : 1, &n2));
}
int xgmi;
NCCLCHECK(connectedByXGMI(&xgmi, comm->topo, myInfo, peerInfo));
for (int t=0; t<NTRANSPORTS; t++) {
if (connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P)))
continue;
if (n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
if (graph == NULL && connIndex == NCCL_CONN_IDX_P2P_NET && (t == TRANSPORT_SHM || (!xgmi && t == TRANSPORT_P2P))) continue;
if (graph && n1 >= 0 && n2 >= 0 && t != TRANSPORT_NET) continue;
struct ncclTransport *transport = ncclTransports+t;
struct ncclTransportComm* transportComm = type == 1 ? &transport->send : &transport->recv;
int ret = 0;
@@ -748,6 +749,16 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->ring.prev, 1, &channel->ring.next, 0), ret, affinity_restore);
}
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, 0), ret, affinity_restore);
if (ringGraph.nIntraChannels && rcclParamP2pNetDisable() == 0) {
comm->useIntraNet = 1;
// Connect NET for intranode use
for (int c=0; c<comm->nChannels; c++) {
struct ncclChannel* channel = comm->channels+c;
if (comm->nRanks == 1) continue;
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channel, 1, &channel->ring.prev, 1, &channel->ring.next, NCCL_CONN_IDX_P2P_NET), ret, affinity_restore);
}
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &ringGraph, NCCL_CONN_IDX_P2P_NET), ret, affinity_restore);
}
free(rings);
INFO(NCCL_INIT, "Connected all rings");