Add a generic net plugin context (ctxt) that is extensible for use in multiple APIs (#1735)
Co-authored-by: Sarat Kamisetty <sakamiset@amd.com>
This commit is contained in:
@@ -601,4 +601,10 @@ typedef struct {
|
||||
ncclResult_t (*closeListen)(void* listenComm);
|
||||
} ncclCollNet_v5_t;
|
||||
|
||||
// Context passed from the RCCL library to the network plugin.
// In the proxy connect paths shown in this change, a pointer to this struct is
// handed to the plugin's connect()/accept() in place of the device-handle
// argument when the ANP plugin is selected.
// NOTE(review): per the commit description this is meant to be extended with
// further fields for other plugin APIs — confirm plugin-side layout expectations.
|
||||
typedef struct {
|
||||
// Channel id; the visible callers set it from resources->channelId before
// invoking the plugin.
|
||||
uint32_t chId;
|
||||
} ncclNet_ctxt_t;
|
||||
|
||||
#endif // end include guard
|
||||
|
||||
+15
-6
@@ -20,6 +20,7 @@
|
||||
#include <assert.h>
|
||||
#include "graph.h"
|
||||
#include "graph/topo.h"
|
||||
#include "nccl_net.h"
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
@@ -741,6 +742,7 @@ static ncclResult_t ncclNetGetDeviceHandle(ncclNetDeviceType type, int version,
|
||||
}
|
||||
|
||||
static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclProxyState* proxyState, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) {
|
||||
ncclNet_ctxt_t ncclNetCtxt = {};
|
||||
struct sendNetResources* resources = (struct sendNetResources*)(connection->transportResources);
|
||||
if (reqSize != sizeof(netSendConnectArgs)) return ncclInternalError;
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
@@ -767,7 +769,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
|
||||
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank;
|
||||
if (comms->sendComm[resources->channelId] == NULL) {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
|
||||
}
|
||||
@@ -776,7 +779,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
|
||||
if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++;
|
||||
} else {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
}
|
||||
@@ -784,7 +788,8 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
|
||||
} else {
|
||||
// Connect to remote peer
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
}
|
||||
@@ -935,6 +940,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
netRecvConnectArgs* req = (netRecvConnectArgs*) reqBuff;
|
||||
resources->tpRemoteProxyRank = req->proxyRank;
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
ncclNet_ctxt_t ncclNetCtxt = {};
|
||||
|
||||
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
|
||||
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle));
|
||||
@@ -959,7 +965,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank;
|
||||
if (comms->recvComm[resources->channelId] == NULL) {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
|
||||
}
|
||||
@@ -968,7 +975,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++;
|
||||
} else {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
}
|
||||
@@ -976,7 +984,8 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
} else {
|
||||
// Connect to remote peer
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
ncclNetCtxt.chId = resources->channelId;
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)&ncclNetCtxt);
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
}
|
||||
|
||||
Reference in a new issue
Block a user