Adding AINIC Network Plugin check (#1528)
- Add a check for the AINIC network plugin (RCCL-ANP): when that plugin is in use, an unused parameter of the connect/accept calls is used to pass the channelId to the network plugin layer
[ROCm/rccl commit: 3494f52d40]
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
1fad40309e
Коммит
b9051c3eca
@@ -25,6 +25,8 @@
|
||||
|
||||
static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large");
|
||||
|
||||
#define RCCL_ANP_PLUGIN_STR "RCCL-ANP"
|
||||
|
||||
#define NCCL_NET_MAP_HOSTMEM 0
|
||||
#define NCCL_NET_MAP_DEVMEM 1
|
||||
#define NCCL_NET_MAP_SHARED_HOSTMEM 2
|
||||
@@ -715,6 +717,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
netSendConnectArgs* req = (netSendConnectArgs*) reqBuff;
|
||||
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, false /*isRecv*/, &resources->netDeviceHandle));
|
||||
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
|
||||
if (resources->shared) {
|
||||
// Shared buffers
|
||||
struct ncclProxyProgressState* progressState = &proxyState->progressState;
|
||||
@@ -733,15 +736,29 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str
|
||||
NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
|
||||
}
|
||||
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank;
|
||||
if (comms->sendComm[resources->channelId] == NULL) ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
|
||||
if (comms->sendComm[resources->channelId] == NULL) {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle);
|
||||
}
|
||||
}
|
||||
resources->netSendComm = comms->sendComm[resources->channelId];
|
||||
if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++;
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Connect to remote peer
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle);
|
||||
}
|
||||
connection->proxyAppendPtr = &connection->proxyAppend;
|
||||
}
|
||||
|
||||
@@ -886,6 +903,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
resources->tpRemoteProxyRank = req->proxyRank;
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
|
||||
bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR));
|
||||
NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle));
|
||||
// Finish connection establishment from remote peer
|
||||
if (resources->shared) {
|
||||
@@ -906,15 +924,29 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str
|
||||
NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks));
|
||||
}
|
||||
struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank;
|
||||
if (comms->recvComm[resources->channelId] == NULL) ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
|
||||
if (comms->recvComm[resources->channelId] == NULL) {
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle);
|
||||
}
|
||||
}
|
||||
resources->netRecvComm = comms->recvComm[resources->channelId];
|
||||
if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++;
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Connect to remote peer
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
if (rccl_anp) {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId));
|
||||
} else {
|
||||
ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle);
|
||||
}
|
||||
connection->proxyAppendPtr = &connection->proxyAppend;
|
||||
}
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user