generic net plugin ctxt that is extensible for use in multiple APIs (#1735)

Co-authored-by: Sarat Kamisetty <sakamiset@amd.com>
This commit is contained in:
Sarat Kamisetty
2025-06-16 14:48:08 -07:00
committato da GitHub
parent 39211c6b41
commit fa0422f174
2 ha cambiato i file con 21 aggiunte e 6 eliminazioni
+6
Vedi File
@@ -601,4 +601,10 @@ typedef struct {
ncclResult_t (*closeListen)(void* listenComm);
} ncclCollNet_v5_t;
// Context passed from the RCCL library to the network plugin.
// Currently carries only the channel id; per the change that introduced it,
// it is meant to be extended with further fields for use in multiple plugin APIs.
// NOTE(review): the visible callers (send/recv proxy connect) fill this in and
// pass its address through the ncclNetDeviceHandle_t** argument of
// connect()/accept() when the RCCL ANP plugin is selected. They keep it on the
// stack, so presumably a plugin must not retain the pointer after the call
// returns -- confirm against the plugin implementation.
typedef struct {
// Channel id of the connection being set up (callers assign resources->channelId).
uint32_t chId;
} ncclNet_ctxt_t;
#endif // end include guard
+15 -6
Vedi File
@@ -20,6 +20,7 @@
#include <assert.h>
#include "graph.h"
#include "graph/topo.h"
#include "nccl_net.h"
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
@@ -741,6 +742,7 @@ static ncclResult_t ncclNetGetDeviceHandle(ncclNetDeviceType type, int version,
}
static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
ncclNet_ctxt_t ncclNetCtxt = {};
struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources);
if (reqSize != sizeof(netSendConnectArgs)) return ncclInternalError;
ncclResult_t ret = ncclSuccess;
@@ -767,7 +769,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank;
if (comms->sendComm[resources->channelId] == NULL) {
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
}
@@ -776,7 +779,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++;
} else {
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
}
@@ -784,7 +788,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
} else {
// Connect to remote peer
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
}
@@ -935,6 +940,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
netRecvConnectArgs* req = (netRecvConnectArgs*) reqBuff;
resources->tpRemoteProxyRank = req->proxyRank;
ncclResult_t ret = ncclSuccess;
ncclNet_ctxt_t ncclNetCtxt = {};
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle));
@@ -959,7 +965,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank;
if (comms->recvComm[resources->channelId] == NULL) {
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
}
@@ -968,7 +975,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++;
} else {
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
}
@@ -976,7 +984,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
} else {
// Connect to remote peer
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
ncclNetCtxt.chId = resources->channelId;
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
}