Adding AINIC Network Plugin check (#1528)

- Add an AINIC network plugin check so that an otherwise unused parameter can be used to pass the channelId to the network plugin layer

[ROCm/rccl commit: 3494f52d40]
Этот коммит содержится в:
Vijay Srinivasan
2025-02-06 21:37:53 -08:00
коммит произвёл GitHub
родитель 1fad40309e
Коммит b9051c3eca
+38 -6
Просмотреть файл
@@ -25,6 +25,8 @@
static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large");
// Network plugin name that is compared (via strcmp) against
// proxyState->ncclNet->name in sendProxyConnect/recvProxyConnect. When the
// names match, connect()/accept() receive resources->channelId cast to a
// pointer in place of &resources->netDeviceHandle.
#define RCCL_ANP_PLUGIN_STR "RCCL-ANP"
#define NCCL_NET_MAP_HOSTMEM 0
#define NCCL_NET_MAP_DEVMEM 1
#define NCCL_NET_MAP_SHARED_HOSTMEM 2
@@ -715,6 +717,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
ncclResult_t ret = ncclSuccess;
netSendConnectArgs* req = (netSendConnectArgs*) reqBuff;
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, false /*isRecv*/, &resources->netDeviceHandle));
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
if (resources->shared) {
// Shared buffers
struct ncclProxyProgressState* progressState = &proxyState->progressState;
@@ -733,15 +736,29 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
}
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank;
if (comms->sendComm[resources->channelId] == NULL) ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
if (comms->sendComm[resources->channelId] == NULL) {
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
}
}
resources->netSendComm = comms->sendComm[resources->channelId];
if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++;
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
}
}
} else {
// Connect to remote peer
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
if (rccl_anp) {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
}
connection->proxyAppendPtr = &connection->proxyAppend;
}
@@ -886,6 +903,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
resources->tpRemoteProxyRank = req->proxyRank;
ncclResult_t ret = ncclSuccess;
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle));
// Finish connection establishment from remote peer
if (resources->shared) {
@@ -906,15 +924,29 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
}
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank;
if (comms->recvComm[resources->channelId] == NULL) ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
if (comms->recvComm[resources->channelId] == NULL) {
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
}
}
resources->netRecvComm = comms->recvComm[resources->channelId];
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++;
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
}
}
} else {
// Connect to remote peer
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
if (rccl_anp) {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
} else {
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
}
connection->proxyAppendPtr = &connection->proxyAppend;
}