diff --git a/examples/rocshmem_init_attr_test.cc b/examples/rocshmem_init_attr_test.cc index d57d37f75a..44c73704c7 100644 --- a/examples/rocshmem_init_attr_test.cc +++ b/examples/rocshmem_init_attr_test.cc @@ -30,24 +30,35 @@ using namespace rocshmem; int main (int argc, char **argv) { - int rank, nranks; + int world_rank, world_nranks; int ret; rocshmem_uniqueid_t uid; rocshmem_init_attr_t attr; MPI_Init(&argc, &argv); - MPI_Comm_rank (MPI_COMM_WORLD, &rank); - MPI_Comm_size (MPI_COMM_WORLD, &nranks); + MPI_Comm_rank (MPI_COMM_WORLD, &world_rank); + MPI_Comm_size (MPI_COMM_WORLD, &world_nranks); + + // Create two disjoint groups of processes, each + // one creating a unique rocshmem environment independent + // of the other group + MPI_Comm newcomm; + int color = world_rank %2; + int rank, nranks; + + MPI_Comm_split(MPI_COMM_WORLD, color, world_rank, &newcomm); + MPI_Comm_rank (newcomm, &rank); + MPI_Comm_size (newcomm, &nranks); if (rank == 0) { ret = rocshmem_get_uniqueid (&uid); if (ret != ROCSHMEM_SUCCESS) { - std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n"; - MPI_Abort (MPI_COMM_WORLD, ret); + std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n"; + MPI_Abort (MPI_COMM_WORLD, ret); } } - MPI_Bcast (&uid, sizeof(rocshmem_uniqueid_t), MPI_BYTE, 0, MPI_COMM_WORLD); + MPI_Bcast (&uid, sizeof(rocshmem_uniqueid_t), MPI_BYTE, 0, newcomm); ret = rocshmem_set_attr_uniqueid_args(rank, nranks, &uid, &attr); if (ret != ROCSHMEM_SUCCESS) { std::cout << rank << ": Error in rocshmem_set_attr_uniqueid_args. 
Aborting.\n"; @@ -63,6 +74,7 @@ int main (int argc, char **argv) std::cout << rank << ": rocshmem_init_attr SUCCESS\n"; rocshmem_finalize(); + MPI_Comm_free (&newcomm); MPI_Finalize(); return 0; } diff --git a/include/rocshmem/rocshmem_common.hpp b/include/rocshmem/rocshmem_common.hpp index e2cd646e4e..a624c947e0 100644 --- a/include/rocshmem/rocshmem_common.hpp +++ b/include/rocshmem/rocshmem_common.hpp @@ -128,13 +128,11 @@ const rocshmem_team_t ROCSHMEM_TEAM_INVALID = nullptr; /** * @brief Data structure defining the unqiueId */ -constexpr int ROCSHMEM_HOSTNAME_LEN = 20; -struct rocshmem_uniqueid_t { - uint64_t random; - char hostname[ROCSHMEM_HOSTNAME_LEN]; - uint32_t pid; -}; -typedef struct rocshmem_uniqueid_t rocshmem_unique_id_t; + + +/// Unique ID for a process. This is a ROCSHMEM_UNIQUE_ID_BYTES byte array that uniquely identifies a process. +#define ROCSHMEM_UNIQUE_ID_BYTES 128 +using rocshmem_uniqueid_t = std::array; /** * @brief Data structure used for attribute based diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8b928eef68..7c49759cb8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,3 +67,4 @@ add_subdirectory(containers) add_subdirectory(host) add_subdirectory(memory) add_subdirectory(sync) +add_subdirectory(bootstrap) diff --git a/src/bootstrap/CMakeLists.txt b/src/bootstrap/CMakeLists.txt new file mode 100644 index 0000000000..8bde7f6392 --- /dev/null +++ b/src/bootstrap/CMakeLists.txt @@ -0,0 +1,33 @@ +############################################################################### +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +############################################################################### + +############################################################################### +# ADD ROCSHMEM TARGET FOR FILES IN CURRENT DIRECTORY +############################################################################### +target_sources( + ${PROJECT_NAME} + PRIVATE + socket.cpp + bootstrap.cpp + env.cpp + utils.cpp +) diff --git a/src/bootstrap/bootstrap.cpp b/src/bootstrap/bootstrap.cpp new file mode 100644 index 0000000000..f43001c6ec --- /dev/null +++ b/src/bootstrap/bootstrap.cpp @@ -0,0 +1,589 @@ +// Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. +// Licensed under the MIT license. 
+ +#include + +#include +#include +#include +#include + +#include "bootstrap.hpp" +#include "utils.hpp" +#include "socket.hpp" + +namespace rocshmem { + +static void setFilesLimit() { + rlimit filesLimit; + if (getrlimit(RLIMIT_NOFILE, &filesLimit) != 0) { + INFO("getrlimit failed\n"); + return; + } + filesLimit.rlim_cur = filesLimit.rlim_max; + if (setrlimit(RLIMIT_NOFILE, &filesLimit) != 0) { + INFO("setrlimit failed\n"); + return; + } +} + +/* Socket Interface Selection type */ +enum bootstrapInterface_t { findSubnetIf = -1, dontCareIf = -2 }; + +struct ExtInfo { + int rank; + int nRanks; + SocketAddress extAddressListenRoot; + SocketAddress extAddressListen; +}; + + void Bootstrap::groupBarrier(const std::vector& ranks) { + int dummy = 0; + for (auto rank : ranks) { + if (rank != this->getRank()) { + this->send(static_cast(&dummy), sizeof(dummy), rank, 0); + } + } + for (auto rank : ranks) { + if (rank != this->getRank()) { + this->recv(static_cast(&dummy), sizeof(dummy), rank, 0); + } + } +} + + void Bootstrap::send(const std::vector& data, int peer, int tag) { + size_t size = data.size(); + send((void*)&size, sizeof(size_t), peer, tag); + send((void*)data.data(), data.size(), peer, tag + 1); +} + + void Bootstrap::recv(std::vector& data, int peer, int tag) { + size_t size; + recv((void*)&size, sizeof(size_t), peer, tag); + data.resize(size); + recv((void*)data.data(), data.size(), peer, tag + 1); +} + +struct UniqueIdInternal { + uint64_t magic; + union SocketAddress addr; +}; +static_assert(sizeof(UniqueIdInternal) <= sizeof(rocshmem_uniqueid_t), "UniqueIdInternal is too large to fit into rocshmem_uniqueid_t"); + +class TcpBootstrap::Impl { + public: + static rocshmem_uniqueid_t createUniqueId(); + static rocshmem_uniqueid_t getUniqueId(const UniqueIdInternal& uniqueId); + + Impl(int rank, int nRanks); + ~Impl(); + void initialize(const rocshmem_uniqueid_t& uniqueId, int64_t timeoutSec); + void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec); 
+ void establishConnections(int64_t timeoutSec); + rocshmem_uniqueid_t getUniqueId() const; + int getRank(); + int getNranks(); + int getNranksPerNode(); + void allGather(void* allData, int size); + void send(void* data, int size, int peer, int tag); + void recv(void* data, int size, int peer, int tag); + void barrier(); + void close(); + + private: + UniqueIdInternal uniqueId_; + int rank_; + int nRanks_; + int nRanksPerNode_; + bool netInitialized; + std::unique_ptr listenSockRoot_; + std::unique_ptr listenSock_; + std::unique_ptr ringRecvSocket_; + std::unique_ptr ringSendSocket_; + std::vector peerCommAddresses_; + std::vector barrierArr_; + std::unique_ptr abortFlagStorage_; + volatile uint32_t* abortFlag_; + std::thread rootThread_; + SocketAddress netIfAddr_; + std::unordered_map, std::shared_ptr, PairHash> peerSendSockets_; + std::unordered_map, std::shared_ptr, PairHash> peerRecvSockets_; + + void netSend(Socket* sock, const void* data, int size); + void netRecv(Socket* sock, void* data, int size); + + std::shared_ptr getPeerSendSocket(int peer, int tag); + std::shared_ptr getPeerRecvSocket(int peer, int tag); + + static void assignPortToUniqueId(UniqueIdInternal& uniqueId); + static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr); + + void bootstrapCreateRoot(); + void bootstrapRoot(); + void getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, + std::vector& rankAddressesRoot, int& rank); + void sendHandleToPeer(int peer, const std::vector& rankAddresses, + const std::vector& rankAddressesRoot); +}; + +rocshmem_uniqueid_t TcpBootstrap::Impl::createUniqueId() { + UniqueIdInternal uniqueId; + SocketAddress netIfAddr; + netInit("", "", netIfAddr); + getRandomData(&uniqueId.magic, sizeof(uniqueId_.magic)); + std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress)); + assignPortToUniqueId(uniqueId); + return getUniqueId(uniqueId); +} + +rocshmem_uniqueid_t TcpBootstrap::Impl::getUniqueId(const 
UniqueIdInternal& uniqueId) { + rocshmem_uniqueid_t ret; + std::memcpy(&ret, &uniqueId, sizeof(uniqueId)); + return ret; +} + +TcpBootstrap::Impl::Impl(int rank, int nRanks) + : rank_(rank), + nRanks_(nRanks), + nRanksPerNode_(0), + netInitialized(false), + peerCommAddresses_(nRanks, SocketAddress()), + barrierArr_(nRanks, 0), + abortFlagStorage_(new uint32_t(0)), + abortFlag_(abortFlagStorage_.get()) {} + +rocshmem_uniqueid_t TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); } + +int TcpBootstrap::Impl::getRank() { return rank_; } + +int TcpBootstrap::Impl::getNranks() { return nRanks_; } + +void TcpBootstrap::Impl::initialize(const rocshmem_uniqueid_t& uniqueId, int64_t timeoutSec) { + if (!netInitialized) { + netInit("", "", netIfAddr_); + netInitialized = true; + } + + std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); + if (rank_ == 0) { + bootstrapCreateRoot(); + } + + char line[MAX_IF_NAME_SIZE + 1]; + SocketToString(&uniqueId_.addr, line); + TRACE("rank %d nranks %d - connecting to %s\n", rank_, nRanks_, line); + establishConnections(timeoutSec); +} + +void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) { + // first check if it is a trio + int nColons = 0; + for (auto c : ifIpPortTrio) { + if (c == ':') { + nColons++; + } + } + std::string ipPortPair = ifIpPortTrio; + std::string interface = ""; + if (nColons == 2) { + // we know the interface: it is the part before the first ':' + interface = ifIpPortTrio.substr(0, ipPortPair.find_first_of(':')); + ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1); + } + + if (!netInitialized) { + netInit(ipPortPair, interface, netIfAddr_); + netInitialized = true; + } + + uniqueId_.magic = 0xdeadbeef; + std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); + SocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str()); + + if (rank_ == 0) { + bootstrapCreateRoot(); + } + + establishConnections(timeoutSec); +} + +TcpBootstrap::Impl::~Impl() { + if (abortFlag_) { + *abortFlag_ 
= 1; + } + if (rootThread_.joinable()) { + rootThread_.join(); + } +} + +void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, + std::vector& rankAddressesRoot, int& rank) { + ExtInfo info; + SocketAddress zero; + std::memset(&zero, 0, sizeof(SocketAddress)); + + { + Socket sock(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); + sock.accept(listenSock); + netRecv(&sock, &info, sizeof(info)); + } + + if (this->nRanks_ != info.nRanks) { + ERROR("Bootstrap Root : mismatch in rank count from procs %d : %d\n", this->nRanks_, info.nRanks); + return; + } + + if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(SocketAddress)) != 0) { + ERROR("Bootstrap Root : rank %d of %d has already checked in\n", info.rank, this->nRanks_); + return; + } + + // Save the connection handle for that rank + rankAddressesRoot[info.rank] = info.extAddressListenRoot; + rankAddresses[info.rank] = info.extAddressListen; + rank = info.rank; +} + +void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector& rankAddresses, + const std::vector& rankAddressesRoot) { + int next = (peer + 1) % nRanks_; + Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_); + sock.connect(); + netSend(&sock, &rankAddresses[next], sizeof(SocketAddress)); +} + +void TcpBootstrap::Impl::assignPortToUniqueId(UniqueIdInternal& uniqueId) { + std::unique_ptr socket = std::make_unique(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap); + socket->bind(); + uniqueId.addr = socket->getAddr(); +} + +void TcpBootstrap::Impl::bootstrapCreateRoot() { + listenSockRoot_ = std::make_unique(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0); + listenSockRoot_->bindAndListen(); + uniqueId_.addr = listenSockRoot_->getAddr(); + + rootThread_ = std::thread([this]() { + // try { + bootstrapRoot(); + //} catch (const std::exception& e) { + //if (abortFlag_ && *abortFlag_) r; + //throw e; + //} + }); +} + +void 
TcpBootstrap::Impl::bootstrapRoot() { + int numCollected = 0; + std::vector rankAddresses(nRanks_, SocketAddress()); + // for initial rank <-> root information exchange + std::vector rankAddressesRoot(nRanks_, SocketAddress()); + + std::memset(rankAddresses.data(), 0, sizeof(SocketAddress) * nRanks_); + std::memset(rankAddressesRoot.data(), 0, sizeof(SocketAddress) * nRanks_); + setFilesLimit(); + + TRACE("BEGIN bootstrapRoot\n"); + /* Receive addresses from all ranks */ + do { + int rank; + getRemoteAddresses(listenSockRoot_.get(), rankAddresses, rankAddressesRoot, rank); + ++numCollected; + TRACE("Received connect from rank %d total %d/%d\n", rank, numCollected, nRanks_); + } while (numCollected < nRanks_ && (!abortFlag_ || *abortFlag_ == 0)); + + if (abortFlag_ && *abortFlag_) { + TRACE("ABORTED\n"); + return; + } + + TRACE("COLLECTED ALL %d HANDLES\n", nRanks_); + + // Send the connect handle for the next rank in the AllGather ring + for (int peer = 0; peer < nRanks_; ++peer) { + sendHandleToPeer(peer, rankAddresses, rankAddressesRoot); + } + + TRACE("DONE bootstrapRoot\n"); +} + +void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, + SocketAddress& netIfAddr) { + char netIfName[MAX_IF_NAME_SIZE + 1]; + if (!ipPortPair.empty()) { + if (interface != "") { + // we know the interface name, so look it up directly + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str()); + if (ret <= 0) { + ERROR("NET/Socket : No interface named %s found\n", interface.c_str()); + return; + } + } else { + // we do not know the interface name, so try to match one by subnet next + SocketAddress remoteAddr; + SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()); + if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + ERROR("NET/Socket : No usable listening interface found\n"); + return; + } + } + + } else { + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1); + if (ret <= 0) { + ERROR("TcpBootstrap : no socket interface found\n"); 
+ return; + } + } + + char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; + std::sprintf(line, " %s:", netIfName); + SocketToString(&netIfAddr, line + strlen(line)); + TRACE("TcpBootstrap : Using%s", line); +} + +#define TIMEOUT(__exp) \ + do { \ + try { \ + __exp; \ + } catch (const Error& e) { \ + if (e.getErrorCode() == ErrorCode::Timeout) { \ + throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); \ + } \ + throw; \ + } \ + } while (0); + +void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { + const int64_t connectionTimeoutUs = timeoutSec * 1000000; + Timer timer; + SocketAddress nextAddr; + ExtInfo info; + + TRACE("establishConnections: rank %d nranks %d\n", rank_, nRanks_); + + auto getLeftTime = [&]() { + if (connectionTimeoutUs < 0) { + // no timeout: always return a large number + return int64_t(1e9); + } + int64_t timeout = connectionTimeoutUs - timer.elapsed(); + if (timeout <= 0) { + ERROR("TcpBootstrap connection timeout\n"); + return (long int)-1; + } + return timeout; + }; + + info.rank = rank_; + info.nRanks = nRanks_; + + uint64_t magic = uniqueId_.magic; + // Create socket for other ranks to contact me + listenSock_ = std::make_unique(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); + listenSock_->bindAndListen(); + info.extAddressListen = listenSock_->getAddr(); + + { + // Create socket for root to contact me + Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); + lsock.bindAndListen(); + info.extAddressListenRoot = lsock.getAddr(); + + // stagger connection times to avoid an overload of the root + auto randomSleep = [](int rank) { + timespec tv; + tv.tv_sec = rank / 1000; + tv.tv_nsec = 1000000 * (rank % 1000); + TRACE("rank %d delaying connection to root by %ld sec %ld nsec\n", rank, + tv.tv_sec, tv.tv_nsec); + (void)nanosleep(&tv, NULL); + }; + if (nRanks_ > 128) { + randomSleep(rank_); + } + + // send info on my listening socket to root + { + Socket sock(&uniqueId_.addr, magic, 
SocketTypeBootstrap, abortFlag_); + //TIMEOUT(sock.connect(getLeftTime())); + sock.connect(getLeftTime()); + netSend(&sock, &info, sizeof(info)); + } + + // get info on my "next" rank in the bootstrap ring from root + { + Socket sock(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); + //TIMEOUT(sock.accept(&lsock, getLeftTime())); + sock.accept(&lsock, getLeftTime()); + netRecv(&sock, &nextAddr, sizeof(SocketAddress)); + } + } + + ringSendSocket_ = std::make_unique(&nextAddr, magic, SocketTypeBootstrap, abortFlag_); + //TIMEOUT(ringSendSocket_->connect(getLeftTime())); + ringSendSocket_->connect(getLeftTime()); + // Accept the connect request from the previous rank in the AllGather ring + ringRecvSocket_ = std::make_unique(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, + abortFlag_); + //TIMEOUT(ringRecvSocket_->accept(listenSock_.get(), getLeftTime())); + ringRecvSocket_->accept(listenSock_.get(), getLeftTime()); + + // AllGather all listen handlers + peerCommAddresses_[rank_] = listenSock_->getAddr(); + allGather(peerCommAddresses_.data(), sizeof(SocketAddress)); + + TRACE("rank %d nranks %d - DONE\n", rank_, nRanks_); +} + +int TcpBootstrap::Impl::getNranksPerNode() { + if (nRanksPerNode_ > 0) return nRanksPerNode_; + int nRanksPerNode = 0; + bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET; + for (int i = 0; i < nRanks_; i++) { + if (useIpv4) { + if (peerCommAddresses_[i].sin.sin_addr.s_addr == + peerCommAddresses_[rank_].sin.sin_addr.s_addr) { + nRanksPerNode++; + } + } else { + if (std::memcmp(&(peerCommAddresses_[i].sin6.sin6_addr), + &(peerCommAddresses_[rank_].sin6.sin6_addr), + sizeof(in6_addr)) == 0) { + nRanksPerNode++; + } + } + } + nRanksPerNode_ = nRanksPerNode; + return nRanksPerNode_; +} + +void TcpBootstrap::Impl::allGather(void* allData, int size) { + char* data = static_cast(allData); + int rank = rank_; + int nRanks = nRanks_; + + TRACE("allGather: rank %d nranks %d size %d\n", rank, nRanks, size); + + /* 
Simple ring based AllGather + * At each step i receive data from (rank-i-1) from left + * and send previous step's data from (rank-i) to right + */ + for (int i = 0; i < nRanks - 1; i++) { + size_t rSlice = (rank - i - 1 + nRanks) % nRanks; + size_t sSlice = (rank - i + nRanks) % nRanks; + + // Send slice to the right + netSend(ringSendSocket_.get(), data + sSlice * size, size); + // Recv slice from the left + netRecv(ringRecvSocket_.get(), data + rSlice * size, size); + } + + TRACE("allGather: rank %d nranks %d size %d - DONE\n", rank, nRanks, size); +} + +std::shared_ptr TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) { + auto it = peerSendSockets_.find(std::make_pair(peer, tag)); + if (it != peerSendSockets_.end()) { + return it->second; + } + auto sock = std::make_shared(&peerCommAddresses_[peer], uniqueId_.magic, + SocketTypeBootstrap, abortFlag_); + sock->connect(); + netSend(sock.get(), &rank_, sizeof(int)); + netSend(sock.get(), &tag, sizeof(int)); + peerSendSockets_[std::make_pair(peer, tag)] = sock; + return sock; +} + +std::shared_ptr TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag) { + auto it = peerRecvSockets_.find(std::make_pair(peer, tag)); + if (it != peerRecvSockets_.end()) { + return it->second; + } + for (;;) { + auto sock = std::make_shared(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, + abortFlag_); + sock->accept(listenSock_.get()); + int recvPeer, recvTag; + netRecv(sock.get(), &recvPeer, sizeof(int)); + netRecv(sock.get(), &recvTag, sizeof(int)); + peerRecvSockets_[std::make_pair(recvPeer, recvTag)] = sock; + if (recvPeer == peer && recvTag == tag) { + return sock; + } + } +} + +void TcpBootstrap::Impl::netSend(Socket* sock, const void* data, int size) { + sock->send(&size, sizeof(int)); + sock->send(const_cast(data), size); +} + +void TcpBootstrap::Impl::netRecv(Socket* sock, void* data, int size) { + int recvSize; + sock->recv(&recvSize, sizeof(int)); + if (recvSize > size) { + ERROR("Message truncated : received %d 
bytes instead of %d\n", recvSize, size); + return; + } + sock->recv(data, std::min(recvSize, size)); +} + +void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) { + auto sock = getPeerSendSocket(peer, tag); + netSend(sock.get(), data, size); +} + +void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) { + auto sock = getPeerRecvSocket(peer, tag); + netRecv(sock.get(), data, size); +} + +void TcpBootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); } + +void TcpBootstrap::Impl::close() { + listenSockRoot_.reset(nullptr); + listenSock_.reset(nullptr); + ringRecvSocket_.reset(nullptr); + ringSendSocket_.reset(nullptr); + peerSendSockets_.clear(); + peerRecvSockets_.clear(); +} + + rocshmem_uniqueid_t TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); } + + TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } + + rocshmem_uniqueid_t TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); } + + int TcpBootstrap::getRank() { return pimpl_->getRank(); } + + int TcpBootstrap::getNranks() { return pimpl_->getNranks(); } + + int TcpBootstrap::getNranksPerNode() { return pimpl_->getNranksPerNode(); } + + void TcpBootstrap::send(void* data, int size, int peer, int tag) { + pimpl_->send(data, size, peer, tag); +} + + void TcpBootstrap::recv(void* data, int size, int peer, int tag) { + pimpl_->recv(data, size, peer, tag); +} + + void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); } + + void TcpBootstrap::initialize(rocshmem_uniqueid_t uniqueId, int64_t timeoutSec) { + pimpl_->initialize(uniqueId, timeoutSec); +} + + void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) { + pimpl_->initialize(ipPortPair, timeoutSec); +} + + void TcpBootstrap::barrier() { pimpl_->barrier(); } + + TcpBootstrap::~TcpBootstrap() { pimpl_->close(); } + +} // namespace rocshmem diff --git a/src/bootstrap/bootstrap.hpp 
b/src/bootstrap/bootstrap.hpp new file mode 100644 index 0000000000..273ae058a5 --- /dev/null +++ b/src/bootstrap/bootstrap.hpp @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. +// Licensed under the MIT license. + +#ifndef ROCSHMEM_BOOTSTRAP_HPP_ +#define ROCSHMEM_BOOTSTRAP_HPP_ + + +#include +#include +#include +#include +#include +#include + +#include "rocshmem/rocshmem_common.hpp" + +namespace rocshmem { + +/// Return a version string. +std::string version(); + +/// Base class for bootstraps. +class Bootstrap { + public: + Bootstrap(){}; + virtual ~Bootstrap() = default; + virtual int getRank() = 0; + virtual int getNranks() = 0; + virtual int getNranksPerNode() = 0; + virtual void send(void* data, int size, int peer, int tag) = 0; + virtual void recv(void* data, int size, int peer, int tag) = 0; + virtual void allGather(void* allData, int size) = 0; + virtual void barrier() = 0; + + void groupBarrier(const std::vector& ranks); + void send(const std::vector& data, int peer, int tag); + void recv(std::vector& data, int peer, int tag); +}; + +/// A native implementation of the bootstrap using TCP sockets. +class TcpBootstrap : public Bootstrap { + public: + /// Create a random unique ID. + /// @return The created unique ID. + static rocshmem_uniqueid_t createUniqueId(); + + /// Constructor. + /// @param rank The rank of the process. + /// @param nRanks The total number of ranks. + TcpBootstrap(int rank, int nRanks); + + /// Destructor. + ~TcpBootstrap(); + + /// Return the unique ID stored in the @ref TcpBootstrap. + /// @return The unique ID stored in the @ref TcpBootstrap. + rocshmem_uniqueid_t getUniqueId() const; + + /// Initialize the @ref TcpBootstrap with a given unique ID. + /// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with. + /// @param timeoutSec The connection timeout in seconds. 
+ void initialize(rocshmem_uniqueid_t uniqueId, int64_t timeoutSec = 30); + + /// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port". + /// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port". + /// @param timeoutSec The connection timeout in seconds. + void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30); + + /// Return the rank of the process. + int getRank() override; + + /// Return the total number of ranks. + int getNranks() override; + + /// Return the total number of ranks per node. + int getNranksPerNode() override; + + /// Send data to another process. + /// + /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, + /// senderRank, tag)`. + /// + /// @param data The data to send. + /// @param size The size of the data to send. + /// @param peer The rank of the process to send the data to. + /// @param tag The tag to send the data with. + void send(void* data, int size, int peer, int tag) override; + + /// Receive data from another process. + /// + /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, + /// senderRank, tag)`. + /// + /// @param data The buffer to write the received data to. + /// @param size The size of the data to receive. + /// @param peer The rank of the process to receive the data from. + /// @param tag The tag to receive the data with. + void recv(void* data, int size, int peer, int tag) override; + + /// Gather data from all processes. + /// + /// When called by rank `r`, this sends data from `allData[r * size]` to `allData[(r + 1) * size - 1]` to all other + /// ranks. The data sent by rank `r` is received into `allData[r * size]` of other ranks. + /// + /// @param allData The buffer to write the received data to. + /// @param size The size of the data each rank sends. 
+ void allGather(void* allData, int size) override; + + /// Synchronize all processes. + void barrier() override; + + private: + // The internal implementation. + class Impl; + + // Pointer to the internal implementation. + std::unique_ptr pimpl_; +}; + +} // namespace rocshmem + +#endif // ROCSHMEM_BOOTSTRAP_HPP_ diff --git a/src/bootstrap/env.cpp b/src/bootstrap/env.cpp new file mode 100644 index 0000000000..58452f633c --- /dev/null +++ b/src/bootstrap/env.cpp @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc, +// Licensed under the MIT license. + +#include +#include + +#include "env.hpp" +#include "utils.hpp" + +template +T readEnv(const std::string &envName, const T &defaultValue) { + const char *envCstr = getenv(envName.c_str()); + if (envCstr == nullptr) return defaultValue; + if constexpr (std::is_same_v) { + return atoi(envCstr); + } else if constexpr (std::is_same_v) { + return (std::string(envCstr) != "0"); + } + return T(envCstr); +} + +template +void readAndSetEnv(const std::string &envName, T &env) { + const char *envCstr = getenv(envName.c_str()); + if (envCstr == nullptr) return; + if constexpr (std::is_same_v) { + env = atoi(envCstr); + } else if constexpr (std::is_same_v) { + env = (std::string(envCstr) != "0"); + } else { + env = std::string(envCstr); + } +} + +template +void logEnv(const std::string &envName, const T &env) { + if (!getenv(envName.c_str())) return; + INFO("%s=%d", envName.c_str(), env); +} + +template <> +void logEnv(const std::string &envName, const std::string &env) { + if (!getenv(envName.c_str())) return; + INFO("%s=%s", envName.c_str(), env.c_str()); +} + +namespace rocshmem { + +Env::Env() + : debug(readEnv("ROCSHMEM_DEBUG", "")), + debugSubsys(readEnv("ROCSHMEM_DEBUG_SUBSYS", "")), + debugFile(readEnv("ROCSHMEM_DEBUG_FILE", "")), + hostid(readEnv("ROCSHMEM_HOSTID", "")), + socketFamily(readEnv("ROCSHMEM_SOCKET_FAMILY", "")), + 
socketIfname(readEnv("ROCSHMEM_SOCKET_IFNAME", "")) {} + +std::shared_ptr env() { + static std::shared_ptr globalEnv = std::shared_ptr(new Env()); + static bool logged = false; + if (!logged) { + logged = true; + // cannot log inside the constructor because of circular dependency + logEnv("ROCSHMEM_DEBUG", globalEnv->debug); + logEnv("ROCSHMEM_DEBUG_SUBSYS", globalEnv->debugSubsys); + logEnv("ROCSHMEM_DEBUG_FILE", globalEnv->debugFile); + logEnv("ROCSHMEM_HOSTID", globalEnv->hostid); + logEnv("ROCSHMEM_SOCKET_FAMILY", globalEnv->socketFamily); + logEnv("ROCSHMEM_SOCKET_IFNAME", globalEnv->socketIfname); + } + return globalEnv; +} + +} // namespace rocshmem diff --git a/src/bootstrap/env.hpp b/src/bootstrap/env.hpp new file mode 100644 index 0000000000..ca14c9b357 --- /dev/null +++ b/src/bootstrap/env.hpp @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc, +// Licensed under the MIT license. + +#ifndef ROCSHMEM_ENV_HPP_ +#define ROCSHMEM_ENV_HPP_ + +#include +#include + +namespace rocshmem { + +class Env; + +/// Get the environment. +/// @return A reference to the global environment object. +std::shared_ptr env(); + +/// The constructor reads environment variables and sets the corresponding fields. +/// Use the @ref env() function to get the environment object. +class Env { + public: + const std::string debug; + const std::string debugSubsys; + const std::string debugFile; + const std::string hostid; + const std::string socketFamily; + const std::string socketIfname; + + private: + Env(); + + friend std::shared_ptr env(); +}; + +} // namespace rocshmem + +#endif // ROCSHMEM_ENV_HPP_ diff --git a/src/bootstrap/socket.cpp b/src/bootstrap/socket.cpp new file mode 100644 index 0000000000..f333999d90 --- /dev/null +++ b/src/bootstrap/socket.cpp @@ -0,0 +1,774 @@ +// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. 
+// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include + +#include "socket.hpp" +#include "utils.hpp" +#include "env.hpp" + +namespace rocshmem { + +#define ROCSHMEM_SOCKET_SEND 0 +#define ROCSHMEM_SOCKET_RECV 1 + +/* Format a string representation of a (union SocketAddress *) + * socket address using getnameinfo() + * + * Output: "IPv4/IPv6 address<port>" + */ +const char* SocketToString(union SocketAddress* addr, char* buf, + const int numericHostForm /*= 1*/) { + if (buf == NULL || addr == NULL) return NULL; + struct sockaddr* saddr = &addr->sa; + if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { + buf[0] = '\0'; + return buf; + } + char host[NI_MAXHOST], service[NI_MAXSERV]; + /* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned. + * (When not set, this will still happen in case the node's name cannot be determined.) + */ + int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0); + (void)getnameinfo(saddr, sizeof(union SocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag); + sprintf(buf, "%s<%s>", host, service); + return buf; +} + +// Equivalent with ($ cat /proc/sys/net/ipv4/tcp_fin_timeout) +static int getTcpFinTimeout() { + std::ifstream ifs("/proc/sys/net/ipv4/tcp_fin_timeout"); + if (!ifs.is_open()) { + ERROR("open /proc/sys/net/ipv4/tcp_fin_timeout failed errno %d\n", errno); + return -1; + } + int timeout; + ifs >> timeout; + return timeout; +} + +static uint16_t socketToPort(union SocketAddress* addr) { + struct sockaddr* saddr = &addr->sa; + return ntohs(saddr->sa_family == AF_INET ? 
addr->sin.sin_port : addr->sin6.sin6_port); +} + +/* Allow the user to force the IPv4/IPv6 interface selection */ +static int envSocketFamily(void) { + int family = -1; // Family selection is not forced, will use first one found + const std::string& socketFamily = env()->socketFamily; + if (socketFamily == "") return family; + + if (socketFamily == "AF_INET") + family = AF_INET; // IPv4 + else if (socketFamily == "AF_INET6") + family = AF_INET6; // IPv6 + return family; +} + +static int findInterfaces(const char* prefixList, char* names, union SocketAddress* addrs, + int sock_family, int maxIfNameSize, int maxIfs) { +#ifdef ROCSHMEM_ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN + 1]; +#endif + struct netIf userIfs[MAX_IFS]; + bool searchNot = prefixList && prefixList[0] == '^'; + if (searchNot) prefixList++; + bool searchExact = prefixList && prefixList[0] == '='; + if (searchExact) prefixList++; + int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS); + + int found = 0; + struct ifaddrs *interfaces, *interface; + getifaddrs(&interfaces); + for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + + /* We only support IPv4 & IPv6 */ + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + + TRACE("Found interface %s:%s\n", interface->ifa_name, + SocketToString((union SocketAddress*)interface->ifa_addr, line)); + + /* Allow the caller to force the socket family type */ + if (sock_family != -1 && family != sock_family) continue; + + /* We also need to skip IPv6 loopback interfaces */ + if (family == AF_INET6) { + struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr); + if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue; + } + + // check against user specified interfaces + if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) { + continue; + } + + // Check that this interface has not 
already been saved + // getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link + bool duplicate = false; + for (int i = 0; i < found; i++) { + if (strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) { + duplicate = true; + break; + } + } + + if (!duplicate) { + // Store the interface name + strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize); + // Store the IP address + int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + memcpy(addrs + found, interface->ifa_addr, salen); + found++; + } + } + + freeifaddrs(interfaces); + return found; +} + +static bool matchSubnet(struct ifaddrs local_if, union SocketAddress* remote) { + /* Check family first */ + int family = local_if.ifa_addr->sa_family; + if (family != remote->sa.sa_family) { + return false; + } + + if (family == AF_INET) { + struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr); + struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask); + struct sockaddr_in& remote_addr = remote->sin; + struct in_addr local_subnet, remote_subnet; + local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr; + remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr; + return (local_subnet.s_addr ^ remote_subnet.s_addr) ? 
false : true; + } else if (family == AF_INET6) { + struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr); + struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask); + struct sockaddr_in6& remote_addr = remote->sin6; + struct in6_addr& local_in6 = local_addr->sin6_addr; + struct in6_addr& mask_in6 = mask->sin6_addr; + struct in6_addr& remote_in6 = remote_addr.sin6_addr; + bool same = true; + int len = 16; // IPv6 address is 16 unsigned char + for (int c = 0; c < len; c++) { // Network byte order is big-endian + char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c]; + char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c]; + if (c1 ^ c2) { + same = false; + break; + } + } + // At last, we need to compare scope id + // Two Link-type addresses can have the same subnet address even though they are not in the same scope + // For Global type, this field is 0, so a comparison wouldn't matter + same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id); + return same; + } else { + ERROR("Net : Unsupported address family type\n"); + return false; + } +} + +int FindInterfaceMatchSubnet(char* ifNames, union SocketAddress* localAddrs, union SocketAddress* remoteAddr, + int ifNameMaxSize, int maxIfs) { +#ifdef ROCSHMEM_ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN + 1]; +#endif + char line_a[SOCKET_NAME_MAXLEN + 1]; + int found = 0; + struct ifaddrs *interfaces, *interface; + getifaddrs(&interfaces); + for (interface = interfaces; interface && !found; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + + /* We only support IPv4 & IPv6 */ + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + + // check against user specified interfaces + if (!matchSubnet(*interface, remoteAddr)) { + continue; + } + + // Store the local IP address + int salen = (family == AF_INET) ? 
sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + memcpy(localAddrs + found, interface->ifa_addr, salen); + + // Store the interface name + strncpy(ifNames + found * ifNameMaxSize, interface->ifa_name, ifNameMaxSize); + + TRACE("NET : Found interface %s:%s in the same subnet as remote address %s\n", + interface->ifa_name, SocketToString(localAddrs + found, line), SocketToString(remoteAddr, line_a)); + found++; + if (found == maxIfs) break; + } + + if (found == 0) { + ERROR("Net : No interface found in the same subnet as remote address %s\n", + SocketToString(remoteAddr, line_a)); + } + freeifaddrs(interfaces); + return found; +} + +void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair) { + if (!(ip_port_pair && strlen(ip_port_pair) > 1)) { + ERROR("Net : string is null\n"); + return; + } + + bool ipv6 = ip_port_pair[0] == '['; + /* Construct the sockaddress structure */ + if (!ipv6) { + struct netIf ni; + // parse : string, expect one pair + if (parseStringList(ip_port_pair, &ni, 1) != 1) { + ERROR("Net : No valid : pair found\n"); + return; + } + + struct addrinfo hints, *p; + int rv; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + if ((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) { + ERROR("Net : error encountered when getting address info : %s\n", gai_strerror(rv)); + return; + } + + // use the first + if (p->ai_family == AF_INET) { + struct sockaddr_in& sin = ua->sin; + memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in)); + sin.sin_family = AF_INET; // IPv4 + // inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address + sin.sin_port = htons(ni.port); // port + } else if (p->ai_family == AF_INET6) { + struct sockaddr_in6& sin6 = ua->sin6; + memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6)); + sin6.sin6_family = AF_INET6; // IPv6 + sin6.sin6_port = htons(ni.port); // port + sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete + 
sin6.sin6_scope_id = 0; // should be global scope, set to 0 + } else { + ERROR("Net : unsupported IP family\n"); + return; + } + + freeaddrinfo(p); // all done with this structure + + } else { + int i, j = -1, len = strlen(ip_port_pair); + for (i = 1; i < len; i++) { + if (ip_port_pair[i] == '%') j = i; + if (ip_port_pair[i] == ']') break; + } + if (i == len) { + ERROR("Net : No valid [IPv6]:port pair found\n"); + return; + } + bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope + + char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ]; + memset(ip_str, '\0', sizeof(ip_str)); + memset(port_str, '\0', sizeof(port_str)); + memset(if_name, '\0', sizeof(if_name)); + strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1); + strncpy(port_str, ip_port_pair + i + 2, len - i - 1); + int port = atoi(port_str); + + // If not global scope, we need the intf name + if (!global_scope) + strncpy(if_name, ip_port_pair + j + 1, i - j - 1); + + struct sockaddr_in6& sin6 = ua->sin6; + sin6.sin6_family = AF_INET6; // IPv6 + inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address + sin6.sin6_port = htons(port); // port + sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete + sin6.sin6_scope_id = global_scope ? 
0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope + } +} + +int FindInterfaces(char* ifNames, union SocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs, + const char* inputIfName) { + static int shownIfName = 0; + int nIfs = 0; + + // Allow user to force the INET socket family selection + int sock_family = envSocketFamily(); + + // User specified interface + const std::string& socketIfname = env()->socketIfname; + if (inputIfName) { + TRACE("using iterface %s", inputIfName); + nIfs = findInterfaces(inputIfName, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); + } else if (socketIfname != "") { + // Specified by user : find or fail + if (shownIfName++ == 0) TRACE ("ROCSHMEM_SOCKET_IFNAME set to %s", socketIfname.c_str()); + nIfs = findInterfaces(socketIfname.c_str(), ifNames, ifAddrs, sock_family, + ifNameMaxSize, maxIfs); + } else { + // Try to automatically pick the right one + // Look for anything (but not docker or lo) + if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, + ifNameMaxSize, maxIfs); + // Finally look for docker, then lo. 
+ if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, + ifNameMaxSize, maxIfs); + if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, + ifNameMaxSize, maxIfs); + } + return nIfs; +} + +Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type, volatile uint32_t* abortFlag, + int asyncFlag) { + fd_ = -1; + acceptFd_ = -1; + connectRetries_ = 0; + acceptRetries_ = 0; + abortFlag_ = abortFlag; + asyncFlag_ = asyncFlag; + state_ = SocketStateInitialized; + magic_ = magic; + type_ = type; + + if (addr) { + /* IPv4/IPv6 support */ + int family; + memcpy(&addr_, addr, sizeof(union SocketAddress)); + family = addr_.sa.sa_family; + if (family != AF_INET && family != AF_INET6) { + char line[SOCKET_NAME_MAXLEN + 1]; + ERROR("SocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)\n", + SocketToString(&addr_, line), family, (int)AF_INET, (int)AF_INET6); + return; + } + salen_ = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); + + /* Connect to a hostname / port */ + fd_ = ::socket(family, SOCK_STREAM, 0); + if (fd_ == -1) { + ERROR("socket creation failed %d\n", errno); + return; + } + } else { + memset(&addr_, 0, sizeof(union SocketAddress)); + } + + /* Set socket as non-blocking if async or if we need to be able to abort */ + if ((asyncFlag_ || abortFlag_) && fd_ >= 0) { + int flags = fcntl(fd_, F_GETFL); + if (flags == -1) { + ERROR("fcntl(F_GETFL) failed errno %d\n", errno); + return; + } + if (fcntl(fd_, F_SETFL, flags | O_NONBLOCK) == -1) { + ERROR("fcntl(F_SETFL) failed errno %d\n", errno); + return; + } + } +} + +Socket::~Socket() { close(); } + +void Socket::bind() { + if (fd_ == -1) { + ERROR("file descriptor is -1\n"); + return; + } + + if (socketToPort(&addr_)) { + // Port is forced by env. Make sure we get the port. 
+ int opt = 1; +#if defined(SO_REUSEPORT) + if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)) != 0) { + ERROR("::setsockopt(SO_REUSEADDR | SO_REUSEPORT) failed errno %d\n", errno); + return; + } +#else + if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) != 0) { + ERROR("setsockopt(SO_REUSEADDR) failed errno %d\n", errno); + return; + } +#endif + } + + int finTimeout = getTcpFinTimeout(); + int retrySecs = finTimeout + 1; + int remainSecs = retrySecs; + + // addr port should be 0 (Any port) + while (::bind(fd_, &addr_.sa, salen_) != 0) { + // upon EADDRINUSE, retry up to for (finTimeout + 1) seconds + if (errno != EADDRINUSE) { + ERROR("bind failed errno %d\n", errno); + return; + } + if (remainSecs > 0) { + TRACE("No available ephemeral ports found, will retry after 1 second"); + sleep(1); + remainSecs--; + } else { + ERROR("No available ephemeral ports found for %d seconds \n", retrySecs); + return; + } + } + + /* Get the assigned Port */ + socklen_t size = salen_; + if (::getsockname(fd_, &addr_.sa, &size) != 0) { + ERROR("getsockname failed errno %d\n", errno); + return; + } + state_ = SocketStateBound; +} + +void Socket::bindAndListen() { +#ifdef ROCSHMEM_ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN + 1]; +#endif + bind(); + TRACE("Listening on socket %s\n", SocketToString(&addr_, line)); + + /* Put the socket in listen mode + * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn + */ + if (::listen(fd_, 16384) != 0) { + ERROR("listen failed errno %d\n", errno); + return; + } + state_ = SocketStateReady; +} + +void Socket::connect(int64_t timeout) { +#ifdef ROCSHMEM_ENABLE_TRACE + char line[SOCKET_NAME_MAXLEN + 1]; +#endif + Timer timer; + const int one = 1; + + if (fd_ == -1) { + ERROR("file descriptor is -1\n"); + return; + } + + if (state_ != SocketStateInitialized) { + ERROR("wrong socket state %d\n", state_); + return; + } + TRACE("Connecting to socket %s \n", 
SocketToString(&addr_, line)); + + if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)) != 0) { + INFO("setsockopt(TCP_NODELAY) failed, errno %d\n", errno); + return; + } + + state_ = SocketStateConnecting; + do { + progressState(); + if (timeout > 0 && timer.elapsed() > timeout) { + ERROR("connect timeout\n"); + return; + } + } while (asyncFlag_ == 0 && (abortFlag_ == NULL || *abortFlag_ == 0) && + (state_ == SocketStateConnecting || state_ == SocketStateConnectPolling || state_ == SocketStateConnected)); + + if (abortFlag_ && *abortFlag_ != 0) { + ERROR("aborted\n"); + return; + } +} + +void Socket::accept(const Socket* listenSocket, int64_t timeout) { + Timer timer; + + if (listenSocket == NULL) { + ERROR("listenSocket is NULL\n"); + return; + } + if (listenSocket->getState() != SocketStateReady) { + ERROR("listenSocket is in error state %u\n", listenSocket->getState()); + return; + } + + if (acceptFd_ == -1) { + fd_ = listenSocket->getFd(); + connectRetries_ = listenSocket->getConnectRetries(); + acceptRetries_ = listenSocket->getAcceptRetries(); + abortFlag_ = listenSocket->getAbortFlag(); + asyncFlag_ = listenSocket->getAsyncFlag(); + magic_ = listenSocket->getMagic(); + type_ = listenSocket->getType(); + addr_ = listenSocket->getAddr(); + salen_ = listenSocket->getSalen(); + + acceptFd_ = listenSocket->getFd(); + state_ = SocketStateAccepting; + } + + do { + progressState(); + if (timeout > 0 && timer.elapsed() > timeout) { + ERROR("accept timeout\n"); + return; + } + } while (asyncFlag_ == 0 && (abortFlag_ == NULL || *abortFlag_ == 0) && + (state_ == SocketStateAccepting || state_ == SocketStateAccepted)); + + if (abortFlag_ && *abortFlag_ != 0) { + ERROR("aborted\n"); + return; + } +} + +void Socket::send(void* ptr, int size) { + int offset = 0; + if (state_ != SocketStateReady) { + ERROR("socket state (%d) is not ready\n", state_); + return; + } + socketWait(ROCSHMEM_SOCKET_SEND, ptr, size, &offset); +} + +void Socket::recv(void* ptr, 
int size) { + int offset = 0; + if (state_ != SocketStateReady) { + ERROR("socket state (%d) is not read\n", state_); + return; + } + socketWait(ROCSHMEM_SOCKET_RECV, ptr, size, &offset); +} + +void Socket::recvUntilEnd(void* ptr, int size, int* closed) { + int offset = 0; + *closed = 0; + if (state_ != SocketStateReady) { + ERROR("socket state (%d) is not ready in recvUntilEnd\n", state_); + return; + } + + int bytes = 0; + char* data = (char*)ptr; + + do { + bytes = ::recv(fd_, data + (offset), size - (offset), 0); + if (bytes == 0) { + *closed = 1; + return; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) { + ERROR("recv until end failed errno %d\n", errno); + return; + } else { + bytes = 0; + } + } + (offset) += bytes; + if (abortFlag_ && *abortFlag_ != 0) { + ERROR("aborted\n"); + return; + } + } while (bytes > 0 && (offset) < size); +} + +void Socket::close() { + if (fd_ >= 0) ::close(fd_); + state_ = SocketStateClosed; + fd_ = -1; +} + +void Socket::progressState() { + if (state_ == SocketStateAccepting) { + tryAccept(); + } + if (state_ == SocketStateAccepted) { + finalizeAccept(); + } + if (state_ == SocketStateConnecting) { + startConnect(); + } + if (state_ == SocketStateConnectPolling) { + pollConnect(); + } + if (state_ == SocketStateConnected) { + finalizeConnect(); + } +} + +void Socket::tryAccept() { + socklen_t socklen = sizeof(union SocketAddress); + fd_ = ::accept(acceptFd_, &addr_.sa, &socklen); + if (fd_ != -1) { + state_ = SocketStateAccepted; + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + ERROR("accept failed (fd %d) errno %d\n", acceptFd_, errno); + } else { + usleep(SLEEP_INT); + if (++acceptRetries_ % 1000 == 0) + INFO("tryAccept: Call to try accept returned %s, retrying", strerror(errno)); + } +} + +void Socket::finalizeAccept() { + uint64_t magic; + enum SocketType type; + int received = 0; + socketProgress(ROCSHMEM_SOCKET_RECV, &magic, sizeof(magic), 
&received); + if (received == 0) return; + socketWait(ROCSHMEM_SOCKET_RECV, &magic, sizeof(magic), &received); + if (magic != magic_) { + ERROR("finalizeAccept: wrong magic %lx != %lx\n", magic, magic_); + ::close(fd_); + fd_ = -1; + // Ignore spurious connection and accept again + state_ = SocketStateAccepting; + return; + } else { + received = 0; + socketWait(ROCSHMEM_SOCKET_RECV, &type, sizeof(type), &received); + if (type != type_) { + state_ = SocketStateError; + ::close(fd_); + fd_ = -1; + ERROR("wrong socket type %d != %d \n", type, type_); + return; + } else { + state_ = SocketStateReady; + } + } +} + +void Socket::startConnect() { + /* blocking/non-blocking connect() is determined by asyncFlag. */ + int ret = ::connect(fd_, &addr_.sa, salen_); + if (ret == 0) { + state_ = SocketStateConnected; + return; + } else if (errno == EINPROGRESS) { + state_ = SocketStateConnectPolling; + return; + } else if (errno == ECONNREFUSED || errno == ETIMEDOUT) { + usleep(SLEEP_INT); + if (++connectRetries_ % 1000 == 0) INFO("Call to connect returned %s, retrying", strerror(errno)); + return; + } else { + char line[SOCKET_NAME_MAXLEN + 1]; + state_ = SocketStateError; + ERROR("connect to %s failed, errno %d\n", SocketToString(&addr_, line), errno); + return; + } +} + +void Socket::pollConnect() { + struct pollfd pfd; + int timeout = 1, ret; + socklen_t rlen = sizeof(int); + + memset(&pfd, 0, sizeof(struct pollfd)); + pfd.fd = fd_; + pfd.events = POLLOUT; + ret = ::poll(&pfd, 1, timeout); + if (ret == -1) { + ERROR("poll failed errno %d\n", errno); + return; + } + if (ret == 0) return; + + /* check socket status */ + if ((ret == 1 && (pfd.revents & POLLOUT)) == 0) { + ERROR("poll failed\n"); + return; + } + if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen) == -1) { + ERROR("getsockopt failed, errno %d\n", errno); + return; + } + + if (ret == 0) { + state_ = SocketStateConnected; + } else if (ret == ECONNREFUSED || ret == ETIMEDOUT) { + if (++connectRetries_ % 
1000 == 0) { + INFO("Call to connect returned %s, retrying", strerror(errno)); + } + usleep(SLEEP_INT); + + ::close(fd_); + fd_ = ::socket(addr_.sa.sa_family, SOCK_STREAM, 0); + state_ = SocketStateConnecting; + } else if (ret != EINPROGRESS) { + state_ = SocketStateError; + ERROR("connect failed \n"); + return; + } +} + +void Socket::finalizeConnect() { + int sent = 0; + socketProgress(ROCSHMEM_SOCKET_SEND, &magic_, sizeof(magic_), &sent); + if (sent == 0) return; + socketWait(ROCSHMEM_SOCKET_SEND, &magic_, sizeof(magic_), &sent); + sent = 0; + socketWait(ROCSHMEM_SOCKET_SEND, &type_, sizeof(type_), &sent); + state_ = SocketStateReady; +} + +void Socket::socketProgressOpt(int op, void* ptr, int size, int* offset, int block, int* closed) { + int bytes = 0; + *closed = 0; + char* data = (char*)ptr; + + do { + if (op == ROCSHMEM_SOCKET_RECV) bytes = ::recv(fd_, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT); + if (op == ROCSHMEM_SOCKET_SEND) + bytes = ::send(fd_, data + (*offset), size - (*offset), block ? 
MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL); + if (op == ROCSHMEM_SOCKET_RECV && bytes == 0) { + *closed = 1; + return; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { + ERROR("recv failed, errno %d\n", errno); + return; + } else { + bytes = 0; + } + } + (*offset) += bytes; + if (abortFlag_ && *abortFlag_ != 0) { + ERROR("aborted\n"); + return; + } + } while (bytes > 0 && (*offset) < size); +} + +void Socket::socketProgress(int op, void* ptr, int size, int* offset) { + int closed; + socketProgressOpt(op, ptr, size, offset, 0, &closed); + if (closed) { + char line[SOCKET_NAME_MAXLEN + 1]; + ERROR("connection closed by remote peer %s\n", SocketToString(&addr_, line, 0)); + return; + } +} + +void Socket::socketWait(int op, void* ptr, int size, int* offset) { + while (*offset < size) socketProgress(op, ptr, size, offset); +} + +} // namespace rocshmem diff --git a/src/bootstrap/socket.hpp b/src/bootstrap/socket.hpp new file mode 100644 index 0000000000..253a1d30cf --- /dev/null +++ b/src/bootstrap/socket.hpp @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. +// Licensed under the MIT license. 
+
+#ifndef ROCSHMEM_SOCKET_H_
+#define ROCSHMEM_SOCKET_H_
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/socket.h>
+
+namespace rocshmem {
+
+#define MAX_IFS 16
+#define MAX_IF_NAME_SIZE 16
+#define SLEEP_INT 1000  // connection retry sleep interval in usec
+#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
+#define ROCSHMEM_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL
+
+/* Common socket address storage structure for IPv4/IPv6 */
+union SocketAddress {
+  struct sockaddr sa;
+  struct sockaddr_in sin;
+  struct sockaddr_in6 sin6;
+};
+
+enum SocketState {
+  SocketStateNone = 0,
+  SocketStateInitialized = 1,
+  SocketStateAccepting = 2,
+  SocketStateAccepted = 3,
+  SocketStateConnecting = 4,
+  SocketStateConnectPolling = 5,
+  SocketStateConnected = 6,
+  SocketStateBound = 7,
+  SocketStateReady = 8,
+  SocketStateClosed = 9,
+  SocketStateError = 10,
+  SocketStateNum = 11
+};
+
+enum SocketType {
+  SocketTypeUnknown = 0,
+  SocketTypeBootstrap = 1,
+  SocketTypeProxy = 2,
+  SocketTypeNetSocket = 3,
+  SocketTypeNetIb = 4
+};
+
+const char* SocketToString(union SocketAddress* addr, char* buf, const int numericHostForm = 1);
+void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair);
+int FindInterfaceMatchSubnet(char* ifNames, union SocketAddress* localAddrs, union SocketAddress* remoteAddr,
+                             int ifNameMaxSize, int maxIfs);
+int FindInterfaces(char* ifNames, union SocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs,
+                   const char* inputIfName = nullptr);
+
+class Socket {
+ public:
+  Socket(const SocketAddress* addr = nullptr, uint64_t magic = ROCSHMEM_SOCKET_MAGIC,
+         enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0);
+  ~Socket();
+
+  void bind();
+  void bindAndListen();
+  void connect(int64_t timeout = -1);
+  void accept(const Socket* listenSocket, int64_t timeout = -1);
+  void send(void* ptr, int size);
+  void recv(void* ptr, int size);
+  void recvUntilEnd(void* ptr, int size, int* closed);
+  void close();
+
+  int getFd() const { return fd_; }
+  int getAcceptFd() const { return acceptFd_; }
+  int getConnectRetries() const { return connectRetries_; }
+  int getAcceptRetries() const { return acceptRetries_; }
+  volatile uint32_t* getAbortFlag() const { return abortFlag_; }
+  int getAsyncFlag() const { return asyncFlag_; }
+  enum SocketState getState() const { return state_; }
+  uint64_t getMagic() const { return magic_; }
+  enum SocketType getType() const { return type_; }
+  SocketAddress getAddr() const { return addr_; }
+  int getSalen() const { return salen_; }
+
+ private:
+  void tryAccept();
+  void finalizeAccept();
+  void startConnect();
+  void pollConnect();
+  void finalizeConnect();
+  void progressState();
+
+  void socketProgressOpt(int op, void* ptr, int size, int* offset, int block, int* closed);
+  void socketProgress(int op, void* ptr, int size, int* offset);
+  void socketWait(int op, void* ptr, int size, int* offset);
+
+  int fd_;
+  int acceptFd_;
+  int connectRetries_;
+  int acceptRetries_;
+  volatile uint32_t* abortFlag_;
+  int asyncFlag_;
+  enum SocketState state_;
+  uint64_t magic_;
+  enum SocketType type_;
+
+  union SocketAddress addr_;
+  int salen_;
+};
+
+}  // namespace rocshmem
+
+#endif  // ROCSHMEM_SOCKET_H_
diff --git a/src/bootstrap/utils.cpp b/src/bootstrap/utils.cpp
new file mode 100644
index 0000000000..38c8e4a3b5
--- /dev/null
+++ b/src/bootstrap/utils.cpp
@@ -0,0 +1,255 @@
+// Copyright (c) Microsoft Corporation.
+// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
+// Licensed under the MIT license.
+
+#include <signal.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "utils.hpp"
+#include "env.hpp"
+
+constexpr char HOSTID_FILE[32] = "/proc/sys/kernel/random/boot_id";
+
+static bool matchIf(const char* string, const char* ref, bool matchExact) {
+  // Make sure to include '\0' in the exact case
+  int matchLen = matchExact ? strlen(string) + 1 : strlen(ref);
+  return strncmp(string, ref, matchLen) == 0;
+}
+
+static bool matchPort(const int port1, const int port2) {
+  if (port1 == -1) return true;
+  if (port2 == -1) return true;
+  if (port1 == port2) return true;
+  return false;
+}
+
+namespace rocshmem {
+
+std::string int64ToBusId(int64_t id) {
+  char busId[20];
+  std::snprintf(busId, sizeof(busId), "%04lx:%02lx:%02lx.%01lx", (id) >> 20, (id & 0xff000) >> 12, (id & 0xff0) >> 4,
+                (id & 0xf));
+  return std::string(busId);
+}
+
+int64_t busIdToInt64(const std::string busId) {
+  char hexStr[17];  // Longest possible int64 hex string + null terminator.
+  size_t hexOffset = 0;
+  for (size_t i = 0; hexOffset < sizeof(hexStr) - 1 && i < busId.length(); ++i) {
+    char c = busId[i];
+    if (c == '.' || c == ':') continue;
+    if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f')) {
+      hexStr[hexOffset++] = busId[i];
+    } else
+      break;
+  }
+  hexStr[hexOffset] = '\0';
+  return std::strtol(hexStr, NULL, 16);
+}
+
+uint64_t getHash(const char* string, int n) {
+  // Based on DJB2a, result = result * 33 ^ char
+  uint64_t result = 5381;
+  for (int c = 0; c < n; c++) {
+    result = ((result << 5) + result) ^ string[c];
+  }
+  return result;
+}
+
+/* Generate a hash of the unique identifying string for this host
+ * that will be unique for both bare-metal and container instances
+ * Equivalent of a hash of;
+ *
+ * $(hostname)$(cat /proc/sys/kernel/random/boot_id)
+ *
+ * This string can be overridden by using the ROCSHMEM_HOSTID env var.
+ */ +uint64_t computeHostHash(void) { + const size_t hashLen = 1024; + char hostHash[hashLen]; + + memset(hostHash, 0, hashLen); + + std::string hostName = getHostName(hashLen, '\0'); + strncpy(hostHash, hostName.c_str(), hostName.size()); + + std::string hostid = env()->hostid; + if (hostid != "") { + strncpy(hostHash, hostid.c_str(), hashLen); + } else if (hostName.size() < hashLen) { + std::ifstream file(HOSTID_FILE, std::ios::binary); + if (file.is_open()) { + file.read(hostHash + hostName.size(), hashLen - hostName.size()); + } + } + + // Make sure the string is terminated + hostHash[sizeof(hostHash) - 1] = '\0'; + TRACE("unique hostname '%s'", hostHash); + return getHash(hostHash, strlen(hostHash)); +} + +uint64_t getHostHash(void) { + thread_local std::unique_ptr hostHash = std::make_unique(computeHostHash()); + // avoid crash on static destruction + if (hostHash == nullptr) { + hostHash = std::make_unique(computeHostHash()); + } + return *hostHash; +} + +/* Generate a hash of the unique identifying string for this process + * that will be unique for both bare-metal and container instances + * Equivalent of a hash of; + * + * $$ $(readlink /proc/self/ns/pid) + */ +uint64_t computePidHash(void) { + char pname[1024]; + // Start off with our pid ($$) + std::snprintf(pname, sizeof(pname), "%ld", (long)getpid()); + int plen = strlen(pname); + int len = readlink("/proc/self/ns/pid", pname + plen, sizeof(pname) - 1 - plen); + if (len < 0) len = 0; + + pname[plen + len] = '\0'; + TRACE("unique PID '%s'", pname); + + return getHash(pname, strlen(pname)); +} + +uint64_t getPidHash(void) { + thread_local std::unique_ptr pidHash = std::make_unique(computePidHash()); + // avoid crash on static destruction + if (pidHash == nullptr) { + pidHash = std::make_unique(computePidHash()); + } + return *pidHash; +} + +int parseStringList(const char* string, netIf* ifList, int maxList) { + if (!string) return 0; + + const char* ptr = string; + + int ifNum = 0; + int ifC = 0; + 
char c; + do { + c = *ptr; + if (c == ':') { + if (ifC > 0) { + ifList[ifNum].prefix[ifC] = '\0'; + ifList[ifNum].port = atoi(ptr + 1); + ifNum++; + ifC = 0; + } + while (c != ',' && c != '\0') c = *(++ptr); + } else if (c == ',' || c == '\0') { + if (ifC > 0) { + ifList[ifNum].prefix[ifC] = '\0'; + ifList[ifNum].port = -1; + ifNum++; + ifC = 0; + } + } else { + ifList[ifNum].prefix[ifC] = c; + ifC++; + } + ptr++; + } while (ifNum < maxList && c); + return ifNum; +} + +bool matchIfList(const char* string, int port, netIf* ifList, int listSize, bool matchExact) { + // Make an exception for the case where no user list is defined + if (listSize == 0) return true; + + for (int i = 0; i < listSize; i++) { + if (matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) { + return true; + } + } + return false; +} + +/* Fill buffer with `bytes` bytes of random data read from /dev/urandom */ +void getRandomData(void* buffer, size_t bytes) { + if (bytes > 0) { + const size_t one = 1UL; + FILE* fp = fopen("/dev/urandom", "r"); + if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) { + ERROR("Failed to read random data\n"); + return; + } + if (fp) fclose(fp); + } +} + +} // namespace rocshmem + +// Report a timer timeout upon SIGALRM (via ERROR; no exception is thrown). 
+static void sigalrmTimeoutHandler(int) { + signal(SIGALRM, SIG_IGN); + //throw mscclpp::Error("Timer timed out", ErrorCode::Timeout); + ERROR("Timer timed out\n"); + return; +} + +namespace rocshmem { + +Timer::Timer(int timeout) { set(timeout); } + +Timer::~Timer() { + if (timeout_ > 0) { + alarm(0); + signal(SIGALRM, SIG_DFL); + } +} + +int64_t Timer::elapsed() const { + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end - start_).count(); +} + +void Timer::set(int timeout) { + timeout_ = timeout; + if (timeout > 0) { + signal(SIGALRM, sigalrmTimeoutHandler); + alarm(timeout); + } + start_ = std::chrono::steady_clock::now(); +} + +void Timer::reset() { set(timeout_); } + +void Timer::print(const std::string& name) { + auto us = elapsed(); + printf("%s : %ld\n", name.c_str(), us); +} + +ScopedTimer::ScopedTimer(const std::string& name) : name_(name) {} + +ScopedTimer::~ScopedTimer() { print(name_); } + +std::string getHostName(int maxlen, const char delim) { + std::string hostname(maxlen + 1, '\0'); + if (gethostname(const_cast(hostname.data()), maxlen) != 0) { + ERROR("gethostname failed\n"); + return nullptr; + } + int i = 0; + while ((hostname[i] != delim) && (hostname[i] != '\0') && + (i < maxlen - 1)) i++; + hostname[i] = '\0'; + return hostname.substr(0, i); +} + +} // namespace rocshmem diff --git a/src/bootstrap/utils.hpp b/src/bootstrap/utils.hpp new file mode 100644 index 0000000000..08350763e3 --- /dev/null +++ b/src/bootstrap/utils.hpp @@ -0,0 +1,93 @@ +// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. +// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. +// Licensed under the MIT License. + +#ifndef ROCSHMEM_UTILS_HPP_ +#define ROCSHMEM_UTILS_HPP_ + +#include +#include +#include + +#define ERROR(...) { fprintf(stderr, __VA_ARGS__); abort(); } + +#ifdef ROCSHMEM_ENABLE_TRACE +#define TRACE(...) 
printf(__VA_ARGS__) +#else +#define TRACE(...) +#endif + +#if defined ROCSHMEM_ENABLE_INFO +#define INFO(FLAGS, ...)printf(__VA_ARGS__) +#else +#define INFO(...) +#endif + +namespace rocshmem { + +struct Timer { + std::chrono::steady_clock::time_point start_; + int timeout_; + + Timer(int timeout = -1); + + ~Timer(); + + /// Returns the elapsed time in microseconds. + int64_t elapsed() const; + + void set(int timeout); + + void reset(); + + void print(const std::string& name); +}; + +struct ScopedTimer : public Timer { + const std::string name_; + + ScopedTimer(const std::string& name); + + ~ScopedTimer(); +}; + +std::string getHostName(int maxlen, const char delim); + +// PCI Bus ID <-> int64 conversion functions +std::string int64ToBusId(int64_t id); +int64_t busIdToInt64(const std::string busId); + +uint64_t getHash(const char* string, int n); +uint64_t getHostHash(); +uint64_t getPidHash(); +void getRandomData(void* buffer, size_t bytes); + +struct netIf { + char prefix[64]; + int port; +}; + +int parseStringList(const char* string, struct netIf* ifList, int maxList); +bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact); + +template +inline void hashCombine(std::size_t& hash, const T& v) { + std::hash hasher; + hash ^= hasher(v) + 0x9e3779b9 + (hash << 6) + (hash >> 2); +} + +struct PairHash { + public: + template + std::size_t operator()(const std::pair& x) const { + std::size_t hash = 0; + hashCombine(hash, x.first); + hashCombine(hash, x.second); + return hash; + } +}; + +} // namespace rocshmem + +#endif // ROCSHMEM_UTILS_HPP_ diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index 1932e221e7..49be967068 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -49,6 +49,7 @@ #include "team.hpp" #include "templates_host.hpp" #include "util.hpp" +#include "bootstrap/bootstrap.hpp" #include @@ -102,7 +103,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; [[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int 
flags, rocshmem_init_attr_t *attr) { - MPI_Comm comm = MPI_COMM_WORLD; + MPI_Comm comm = MPI_COMM_NULL; if ((attr == nullptr) || ((flags != ROCSHMEM_INIT_WITH_UNIQUEID) && @@ -115,22 +116,69 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; if (flags == ROCSHMEM_INIT_WITH_MPI_COMM) { comm = *(static_cast(attr->mpi_comm)); + library_init(comm); + return ROCSHMEM_SUCCESS; } - // As of right now, we require initialization through the MPI library. - library_init(comm); - - // The unique Id can be used to verify that the processes participating matches - // (i.e. they all need to have the same unique Id, as well as the number of ranks. if (flags == ROCSHMEM_INIT_WITH_UNIQUEID) { - int worldsize = backend->getNumPEs(); - if (worldsize != attr->nranks) { - fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", - "Call 'rocshmem_init_attr: mismatch between world-team size and " - "attribute value'", __FILE__, __LINE__); - // This is a fatal error, a fundamental mismatch between what was requested - // and what we have. - abort(); + int initialized; + int world_size = -1; + MPI_Initialized(&initialized); + + if (!initialized) { + // This is an Open MPI specific solution to retrieve the number of + // processes that have been started, value can be checked before MPI_Init + char *value = getenv("OMPI_COMM_WORLD_SIZE"); + if (value != NULL) { + world_size = atoi(value); + } + if (world_size != attr->nranks) { + // This solution will require MPI_Sessions. This is planned for the + // future, but is not supported in the current version. + fprintf (stderr, "Unsupported configuration to initialize rocSHMEM. 
Please " + "initialize the MPI library using MPI_Init first, if you want to " + "initialize rocSHMEM with a subset of the processes\n"); + abort(); + } + } else { + MPI_Comm_size (MPI_COMM_WORLD, &world_size); + } + + if (world_size == attr->nranks) { + library_init(MPI_COMM_WORLD); + } else { + MPI_Group world_group; + int world_rank; + + MPI_Comm_rank (MPI_COMM_WORLD, &world_rank); + MPI_Comm_group (MPI_COMM_WORLD, &world_group); + + TcpBootstrap bootstr(attr->rank, attr->nranks); + + int timeout = 5; + char *value; + value = getenv("ROCSHMEM_BOOTSTRAP_TIMEOUT"); + if (value != nullptr) { + timeout = atoi(value); + } + + bootstr.initialize(attr->uid, timeout); + int *inc_ranks = new int[attr->nranks]; + inc_ranks[attr->rank] = world_rank; + + bootstr.allGather (inc_ranks, sizeof(int)); + + MPI_Group sub_group; + MPI_Comm sub_comm; + MPI_Group_incl (world_group, attr->nranks, inc_ranks, &sub_group); + MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm); + + library_init(sub_comm); + + MPI_Group_free (&sub_group); + MPI_Group_free (&world_group); + MPI_Comm_free (&sub_comm); + delete[] inc_ranks; } } @@ -158,6 +206,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; // Note: this function will be called before rocshmem_init_*, so one // cannot assume that a backend is already set [[maybe_unused]] __host__ int rocshmem_get_uniqueid(rocshmem_uniqueid_t *uid) { + rocshmem_uniqueid_t tuid; if (uid == nullptr) { fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", "Call 'rocshmem_get_uniqueid: invalid input argument'", @@ -165,21 +214,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; return ROCSHMEM_ERROR; } - std::random_device dev; - std::mt19937_64 rng(dev()); - std::uniform_int_distribution dist(0, std::numeric_limits::max()); - - char hostname[HOST_NAME_MAX+1]; - if (0 != gethostname(hostname, HOST_NAME_MAX)) { - fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", - "Call 'rocshmem_get_uniqueid: could not get hostname'", - __FILE__, 
__LINE__); - return ROCSHMEM_ERROR; - } - - uid->random = dist(rng); - std::memcpy(uid->hostname, hostname, ROCSHMEM_HOSTNAME_LEN); - uid->pid = static_cast(getpid()); + tuid = TcpBootstrap::createUniqueId(); + *uid = tuid; return ROCSHMEM_SUCCESS; }