diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index f165aa1bf0..9c1b3d61c6 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -601,4 +601,10 @@ typedef struct { ncclResult_t (*closeListen)(void* listenComm); } ncclCollNet_v5_t; +// context passed from RCCL lib to n/w plugin +typedef struct { + // channel id + uint32_t chId; +} ncclNet_ctxt_t; + #endif // end include guard diff --git a/src/transport/net.cc b/src/transport/net.cc index bd62b719fa..c8a40869c2 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -20,6 +20,7 @@ #include #include "graph.h" #include "graph/topo.h" +#include "nccl_net.h" #if defined(ENABLE_NPKIT) #include "npkit/npkit.h" #endif @@ -741,6 +742,7 @@ static ncclResult_t ncclNetGetDeviceHandle(ncclNetDeviceType type, int version, } static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { + ncclNet_ctxt_t ncclNetCtxt = {}; struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources); if (reqSize != sizeof(netSendConnectArgs)) return ncclInternalError; ncclResult_t ret = ncclSuccess; @@ -767,7 +769,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank; if (comms->sendComm[resources->channelId] == NULL) { if (rccl_anp) { - ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle); } @@ -776,7 +779,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; } else { if (rccl_anp) { - ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); } @@ -784,7 +788,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str } else { // Connect to remote peer if (rccl_anp) { - ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); } @@ -935,6 +940,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str netRecvConnectArgs* req = (netRecvConnectArgs*) reqBuff; resources->tpRemoteProxyRank = req->proxyRank; ncclResult_t ret = ncclSuccess; + ncclNet_ctxt_t ncclNetCtxt = {}; bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR)); NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle)); @@ -959,7 +965,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank; if (comms->recvComm[resources->channelId] == NULL) { if (rccl_anp) { - ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle); } @@ -968,7 +975,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; } else { if (rccl_anp) { - ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); } @@ -976,7 +984,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str } else { // Connect to remote peer if (rccl_anp) { - ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + ncclNetCtxt.chId = resources->channelId; + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt); } else { ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); }