From b9051c3ecac78fa56b620c40cbefded97cc37e69 Mon Sep 17 00:00:00 2001 From: Vijay Srinivasan <8528272+vijasrin@users.noreply.github.com> Date: Thu, 6 Feb 2025 21:37:53 -0800 Subject: [PATCH] Adding AINIC Network Plugin check (#1528) - Add an AINIC network plugin check: when the loaded net plugin is named "RCCL-ANP", the otherwise-unused device-handle parameter of connect/accept is used to pass the channelId to the network plugin layer [ROCm/rccl commit: 3494f52d406e5f3dc3d6bb4f0e41f6b2abfaf453] --- projects/rccl/src/transport/net.cc | 44 ++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/projects/rccl/src/transport/net.cc b/projects/rccl/src/transport/net.cc index 3b8f04fcf6..4261dd0fb0 100644 --- a/projects/rccl/src/transport/net.cc +++ b/projects/rccl/src/transport/net.cc @@ -25,6 +25,8 @@ static_assert(sizeof(ncclNetHandle_t) <= CONNECT_SIZE, "NET Connect info is too large"); +#define RCCL_ANP_PLUGIN_STR "RCCL-ANP" + #define NCCL_NET_MAP_HOSTMEM 0 #define NCCL_NET_MAP_DEVMEM 1 #define NCCL_NET_MAP_SHARED_HOSTMEM 2 @@ -715,6 +717,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str ncclResult_t ret = ncclSuccess; netSendConnectArgs* req = (netSendConnectArgs*) reqBuff; NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, false /*isRecv*/, &resources->netDeviceHandle)); + bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR)); if (resources->shared) { // Shared buffers struct ncclProxyProgressState* progressState = &proxyState->progressState; @@ -733,15 +736,29 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteRank; - if (comms->sendComm[resources->channelId] == NULL) ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle); + if 
(comms->sendComm[resources->channelId] == NULL) { + if (rccl_anp) { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, comms->sendComm + resources->channelId, &resources->netDeviceHandle); + } + } resources->netSendComm = comms->sendComm[resources->channelId]; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; } else { - ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); + if (rccl_anp) { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); + } } } else { // Connect to remote peer - ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); + if (rccl_anp) { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->connect(resources->netDev, req->handle, &resources->netSendComm, &resources->netDeviceHandle); + } connection->proxyAppendPtr = &connection->proxyAppend; } @@ -886,6 +903,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str resources->tpRemoteProxyRank = req->proxyRank; ncclResult_t ret = ncclSuccess; + bool rccl_anp = !(strcmp(proxyState->ncclNet->name, RCCL_ANP_PLUGIN_STR)); NCCLCHECK(ncclNetGetDeviceHandle(resources->netDeviceType, resources->netDeviceVersion, true /*isRecv*/, &resources->netDeviceHandle)); // Finish connection establishment from remote peer if (resources->shared) { @@ 
-906,15 +924,29 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms + resources->netDev, proxyState->tpnRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev] + resources->tpRemoteProxyRank; - if (comms->recvComm[resources->channelId] == NULL) ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle); + if (comms->recvComm[resources->channelId] == NULL) { + if (rccl_anp) { + ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->accept(resources->netListenComm, comms->recvComm+resources->channelId, &resources->netDeviceHandle); + } + } resources->netRecvComm = comms->recvComm[resources->channelId]; if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; } else { - ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); + if (rccl_anp) { + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); + } } } else { // Connect to remote peer - ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); + if (rccl_anp) { + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, (ncclNetDeviceHandle_t **)(uintptr_t)(resources->channelId)); + } else { + ret = proxyState->ncclNet->accept(resources->netListenComm, &resources->netRecvComm, &resources->netDeviceHandle); + } connection->proxyAppendPtr = &connection->proxyAppend; }