Revamp the uniqueId code to support subgroups of processes (#80)
* add code for bootstrapping: the bootstrapping code has been extracted from the MSCCLPP library, which in part is based on code from NVIDIA. The code has been modified to match the specific requirements of the rocSHMEM library. * add code to use the new uniqueId bootstrapping * adjust init_attr example: extend the rocshmem_init_attr example to use two disjoint groups of processes, in order to trigger the new code path. * add env variable for bootstrap timeout * Update examples/rocshmem_init_attr_test.cc Co-authored-by: Aurelien Bouteiller <Aurelien.bouteiller@gmail.com> * Update src/rocshmem.cpp Co-authored-by: Aurelien Bouteiller <Aurelien.bouteiller@gmail.com> --------- Co-authored-by: Aurelien Bouteiller <Aurelien.bouteiller@gmail.com>
This commit is contained in:
@@ -30,24 +30,35 @@ using namespace rocshmem;
|
||||
|
||||
int main (int argc, char **argv)
|
||||
{
|
||||
int rank, nranks;
|
||||
int world_rank, world_nranks;
|
||||
int ret;
|
||||
rocshmem_uniqueid_t uid;
|
||||
rocshmem_init_attr_t attr;
|
||||
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank (MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size (MPI_COMM_WORLD, &nranks);
|
||||
MPI_Comm_rank (MPI_COMM_WORLD, &world_rank);
|
||||
MPI_Comm_size (MPI_COMM_WORLD, &world_nranks);
|
||||
|
||||
// Create two disjoint groups of processes, each
|
||||
// one creating a unique rocshmem environment independent
|
||||
// of the other group
|
||||
MPI_Comm newcomm;
|
||||
int color = world_rank %2;
|
||||
int rank, nranks;
|
||||
|
||||
MPI_Comm_split(MPI_COMM_WORLD, color, world_rank, &newcomm);
|
||||
MPI_Comm_rank (newcomm, &rank);
|
||||
MPI_Comm_size (newcomm, &nranks);
|
||||
|
||||
if (rank == 0) {
|
||||
ret = rocshmem_get_uniqueid (&uid);
|
||||
if (ret != ROCSHMEM_SUCCESS) {
|
||||
std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n";
|
||||
MPI_Abort (MPI_COMM_WORLD, ret);
|
||||
std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n";
|
||||
MPI_Abort (MPI_COMM_WORLD, ret);
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Bcast (&uid, sizeof(rocshmem_uniqueid_t), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
MPI_Bcast (&uid, sizeof(rocshmem_uniqueid_t), MPI_BYTE, 0, newcomm);
|
||||
ret = rocshmem_set_attr_uniqueid_args(rank, nranks, &uid, &attr);
|
||||
if (ret != ROCSHMEM_SUCCESS) {
|
||||
std::cout << rank << ": Error in rocshmem_set_attr_uniqueid_args. Aborting.\n";
|
||||
@@ -63,6 +74,7 @@ int main (int argc, char **argv)
|
||||
std::cout << rank << ": rocshmem_init_attr SUCCESS\n";
|
||||
|
||||
rocshmem_finalize();
|
||||
MPI_Comm_free (&newcomm);
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -128,13 +128,11 @@ const rocshmem_team_t ROCSHMEM_TEAM_INVALID = nullptr;
|
||||
/**
|
||||
* @brief Data structure defining the uniqueId
|
||||
*/
|
||||
constexpr int ROCSHMEM_HOSTNAME_LEN = 20;
|
||||
struct rocshmem_uniqueid_t {
|
||||
uint64_t random;
|
||||
char hostname[ROCSHMEM_HOSTNAME_LEN];
|
||||
uint32_t pid;
|
||||
};
|
||||
typedef struct rocshmem_uniqueid_t rocshmem_unique_id_t;
|
||||
|
||||
|
||||
/// Unique ID for a process. This is a ROCSHMEM_UNIQUE_ID_BYTES byte array that uniquely identifies a process.
|
||||
#define ROCSHMEM_UNIQUE_ID_BYTES 128
|
||||
using rocshmem_uniqueid_t = std::array<uint8_t, ROCSHMEM_UNIQUE_ID_BYTES>;
|
||||
|
||||
/**
|
||||
* @brief Data structure used for attribute based
|
||||
|
||||
@@ -67,3 +67,4 @@ add_subdirectory(containers)
|
||||
add_subdirectory(host)
|
||||
add_subdirectory(memory)
|
||||
add_subdirectory(sync)
|
||||
add_subdirectory(bootstrap)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
###############################################################################
|
||||
# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
###############################################################################
|
||||
|
||||
###############################################################################
|
||||
# ADD ROCSHMEM TARGET FOR FILES IN CURRENT DIRECTORY
|
||||
###############################################################################
|
||||
target_sources(
|
||||
${PROJECT_NAME}
|
||||
PRIVATE
|
||||
socket.cpp
|
||||
bootstrap.cpp
|
||||
env.cpp
|
||||
utils.cpp
|
||||
)
|
||||
@@ -0,0 +1,589 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <sys/resource.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "bootstrap.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "socket.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
static void setFilesLimit() {
|
||||
rlimit filesLimit;
|
||||
if (getrlimit(RLIMIT_NOFILE, &filesLimit) != 0) {
|
||||
INFO("getrlimit failed\n");
|
||||
return;
|
||||
}
|
||||
filesLimit.rlim_cur = filesLimit.rlim_max;
|
||||
if (setrlimit(RLIMIT_NOFILE, &filesLimit) != 0) {
|
||||
INFO("setrlimit failed\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* Socket interface selection modes. */
enum bootstrapInterface_t {
  findSubnetIf = -1,
  dontCareIf = -2
};
|
||||
|
||||
struct ExtInfo {
|
||||
int rank;
|
||||
int nRanks;
|
||||
SocketAddress extAddressListenRoot;
|
||||
SocketAddress extAddressListen;
|
||||
};
|
||||
|
||||
void Bootstrap::groupBarrier(const std::vector<int>& ranks) {
|
||||
int dummy = 0;
|
||||
for (auto rank : ranks) {
|
||||
if (rank != this->getRank()) {
|
||||
this->send(static_cast<void*>(&dummy), sizeof(dummy), rank, 0);
|
||||
}
|
||||
}
|
||||
for (auto rank : ranks) {
|
||||
if (rank != this->getRank()) {
|
||||
this->recv(static_cast<void*>(&dummy), sizeof(dummy), rank, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Bootstrap::send(const std::vector<char>& data, int peer, int tag) {
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag + 1);
|
||||
}
|
||||
|
||||
/// Receive a variable-length byte buffer from @p peer.
///
/// Reads the payload length (a size_t) on @p tag, resizes @p data to that
/// length, then reads the payload on @p tag + 1.
///
/// @param data Output buffer; resized to the received length.
/// @param peer Source rank.
/// @param tag  Tag of the length message; the payload uses tag + 1.
void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
  size_t incomingBytes;
  recv(static_cast<void*>(&incomingBytes), sizeof(size_t), peer, tag);
  data.resize(incomingBytes);
  recv(static_cast<void*>(data.data()), data.size(), peer, tag + 1);
}
|
||||
|
||||
struct UniqueIdInternal {
|
||||
uint64_t magic;
|
||||
union SocketAddress addr;
|
||||
};
|
||||
static_assert(sizeof(UniqueIdInternal) <= sizeof(rocshmem_uniqueid_t), "UniqueIdInternal is too large to fit into rocshmem_uniqueid_t");
|
||||
|
||||
/// Private implementation of TcpBootstrap (pimpl).
///
/// NOTE: the data-member declaration order is significant: the constructor's
/// initializer list relies on abortFlagStorage_ being declared before
/// abortFlag_.
class TcpBootstrap::Impl {
 public:
  // --- unique-id helpers ---------------------------------------------------
  static rocshmem_uniqueid_t createUniqueId();
  static rocshmem_uniqueid_t getUniqueId(const UniqueIdInternal& uniqueId);

  // --- lifetime ------------------------------------------------------------
  Impl(int rank, int nRanks);
  ~Impl();

  // --- initialization ------------------------------------------------------
  void initialize(const rocshmem_uniqueid_t& uniqueId, int64_t timeoutSec);
  void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec);
  void establishConnections(int64_t timeoutSec);

  // --- queries -------------------------------------------------------------
  rocshmem_uniqueid_t getUniqueId() const;
  int getRank();
  int getNranks();
  int getNranksPerNode();

  // --- communication -------------------------------------------------------
  void allGather(void* allData, int size);
  void send(void* data, int size, int peer, int tag);
  void recv(void* data, int size, int peer, int tag);
  void barrier();
  void close();

 private:
  UniqueIdInternal uniqueId_;
  int rank_;
  int nRanks_;
  int nRanksPerNode_;  // cached result of getNranksPerNode(); 0 = not computed
  bool netInitialized;
  std::unique_ptr<Socket> listenSockRoot_;
  std::unique_ptr<Socket> listenSock_;
  std::unique_ptr<Socket> ringRecvSocket_;
  std::unique_ptr<Socket> ringSendSocket_;
  std::vector<SocketAddress> peerCommAddresses_;
  std::vector<int> barrierArr_;
  std::unique_ptr<uint32_t> abortFlagStorage_;
  volatile uint32_t* abortFlag_;  // points into abortFlagStorage_
  std::thread rootThread_;
  SocketAddress netIfAddr_;
  // Point-to-point sockets, keyed by (peer rank, tag).
  std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
  std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerRecvSockets_;

  // Size-prefixed send/receive on an already connected socket.
  void netSend(Socket* sock, const void* data, int size);
  void netRecv(Socket* sock, void* data, int size);

  // Lookup-or-create for the point-to-point socket caches.
  std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
  std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);

  static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
  static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);

  // Root-side logic (executed by rank 0).
  void bootstrapCreateRoot();
  void bootstrapRoot();
  void getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
                          std::vector<SocketAddress>& rankAddressesRoot, int& rank);
  void sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
                        const std::vector<SocketAddress>& rankAddressesRoot);
};
|
||||
|
||||
/// Create a fresh unique id: a random magic number plus the address of a
/// local network interface with a port assigned by binding a socket to it.
///
/// @return The public, byte-array form of the new id.
rocshmem_uniqueid_t TcpBootstrap::Impl::createUniqueId() {
  // Zero both objects first: they are copied byte-wise into the public id,
  // so padding (and the address, should netInit() return early after
  // reporting an error) must not contain indeterminate bytes.
  UniqueIdInternal uniqueId;
  std::memset(&uniqueId, 0, sizeof(uniqueId));
  SocketAddress netIfAddr;
  std::memset(&netIfAddr, 0, sizeof(netIfAddr));

  netInit("", "", netIfAddr);
  // This is a static member function: size the buffer from the local object
  // rather than from the instance member uniqueId_.
  getRandomData(&uniqueId.magic, sizeof(uniqueId.magic));
  std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress));
  assignPortToUniqueId(uniqueId);
  return getUniqueId(uniqueId);
}
|
||||
|
||||
/// Convert the internal id representation into the public byte array.
///
/// @param uniqueId Internal id (magic + root address).
/// @return Byte array whose leading sizeof(UniqueIdInternal) bytes hold a copy
///         of @p uniqueId; all remaining bytes are zero.
rocshmem_uniqueid_t TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) {
  // Value-initialize: the public id is larger than UniqueIdInternal and the
  // whole array gets handed to the user (and typically broadcast), so the
  // tail must not be left uninitialized.
  rocshmem_uniqueid_t ret{};
  std::memcpy(&ret, &uniqueId, sizeof(uniqueId));
  return ret;
}
|
||||
|
||||
/// Construct the implementation object for one rank.
///
/// No network activity happens here; the address table and barrier array are
/// sized for @p nRanks entries and the abort flag is allocated and cleared.
///
/// @param rank   Rank of this process.
/// @param nRanks Total number of ranks.
TcpBootstrap::Impl::Impl(int rank, int nRanks)
    : rank_(rank),
      nRanks_(nRanks),
      nRanksPerNode_(0),
      netInitialized(false),
      peerCommAddresses_(nRanks, SocketAddress()),
      barrierArr_(nRanks, 0),
      abortFlagStorage_(std::make_unique<uint32_t>(0)),
      abortFlag_(abortFlagStorage_.get()) {
}
|
||||
|
||||
/// Return the stored unique id in its public byte-array form.
rocshmem_uniqueid_t TcpBootstrap::Impl::getUniqueId() const {
  return getUniqueId(uniqueId_);
}
|
||||
|
||||
/// Return the rank given at construction.
int TcpBootstrap::Impl::getRank() {
  return rank_;
}
|
||||
|
||||
/// Return the number of ranks given at construction.
int TcpBootstrap::Impl::getNranks() {
  return nRanks_;
}
|
||||
|
||||
void TcpBootstrap::Impl::initialize(const rocshmem_uniqueid_t& uniqueId, int64_t timeoutSec) {
|
||||
if (!netInitialized) {
|
||||
netInit("", "", netIfAddr_);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
|
||||
if (rank_ == 0) {
|
||||
bootstrapCreateRoot();
|
||||
}
|
||||
|
||||
char line[MAX_IF_NAME_SIZE + 1];
|
||||
SocketToString(&uniqueId_.addr, line);
|
||||
TRACE("rank %d nranks %d - connecting to %s\n", rank_, nRanks_, line);
|
||||
establishConnections(timeoutSec);
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
|
||||
// first check if it is a trio
|
||||
int nColons = 0;
|
||||
for (auto c : ifIpPortTrio) {
|
||||
if (c == ':') {
|
||||
nColons++;
|
||||
}
|
||||
}
|
||||
std::string ipPortPair = ifIpPortTrio;
|
||||
std::string interface = "";
|
||||
if (nColons == 2) {
|
||||
// we know the <interface>
|
||||
interface = ifIpPortTrio.substr(0, ipPortPair.find_first_of(':'));
|
||||
ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1);
|
||||
}
|
||||
|
||||
if (!netInitialized) {
|
||||
netInit(ipPortPair, interface, netIfAddr_);
|
||||
netInitialized = true;
|
||||
}
|
||||
|
||||
uniqueId_.magic = 0xdeadbeef;
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
|
||||
SocketGetAddrFromString(&uniqueId_.addr, ipPortPair.c_str());
|
||||
|
||||
if (rank_ == 0) {
|
||||
bootstrapCreateRoot();
|
||||
}
|
||||
|
||||
establishConnections(timeoutSec);
|
||||
}
|
||||
|
||||
/// Raise the abort flag and wait for the root thread, if one was started.
TcpBootstrap::Impl::~Impl() {
  if (abortFlag_ != nullptr) {
    *abortFlag_ = 1;
  }

  if (rootThread_.joinable()) {
    rootThread_.join();
  }
}
|
||||
|
||||
/// Root side: accept one connection on @p listenSock and record the check-in
/// information (ExtInfo) sent by the connecting rank.
///
/// @param listenSock        Root listen socket to accept on.
/// @param rankAddresses     Per-rank peer listen addresses; the entry of the
///                          reporting rank is filled in on success.
/// @param rankAddressesRoot Per-rank root-contact addresses; the entry of the
///                          reporting rank is filled in on success.
/// @param rank              Set to the reporting rank on success, -1 when the
///                          check-in was rejected.
void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
                                            std::vector<SocketAddress>& rankAddressesRoot, int& rank) {
  // Give the caller a well-defined value on every early return.
  rank = -1;

  // Zeroed so a failed receive cannot leave indeterminate fields behind.
  ExtInfo info;
  std::memset(&info, 0, sizeof(info));
  SocketAddress zero;
  std::memset(&zero, 0, sizeof(SocketAddress));

  {
    Socket sock(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
    sock.accept(listenSock);
    netRecv(&sock, &info, sizeof(info));
  }

  if (this->nRanks_ != info.nRanks) {
    ERROR("Bootstrap Root : mismatch in rank count from procs %d : %d\n", this->nRanks_, info.nRanks);
    return;
  }

  // info.rank comes off the network and is used as an index below.
  if (info.rank < 0 || info.rank >= this->nRanks_) {
    ERROR("Bootstrap Root : invalid rank %d reported (expected 0..%d)\n", info.rank, this->nRanks_ - 1);
    return;
  }

  // A non-zero entry means this rank has been seen before.
  if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(SocketAddress)) != 0) {
    ERROR("Bootstrap Root : rank %d of %d has already checked in\n", info.rank, this->nRanks_);
    return;
  }

  // Save the connection handle for that rank
  rankAddressesRoot[info.rank] = info.extAddressListenRoot;
  rankAddresses[info.rank] = info.extAddressListen;
  rank = info.rank;
}
|
||||
|
||||
void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
|
||||
const std::vector<SocketAddress>& rankAddressesRoot) {
|
||||
int next = (peer + 1) % nRanks_;
|
||||
Socket sock(&rankAddressesRoot[peer], uniqueId_.magic, SocketTypeBootstrap, abortFlag_);
|
||||
sock.connect();
|
||||
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
|
||||
}
|
||||
|
||||
/// Bind a temporary socket to the address in @p uniqueId and store the
/// resulting address (as reported by the socket) back into the id.
///
/// @param uniqueId Id whose address is updated in place.
void TcpBootstrap::Impl::assignPortToUniqueId(UniqueIdInternal& uniqueId) {
  // The probe socket only lives for the duration of this call.
  // NOTE(review): presumably the port could be taken by another process
  // before the root binds it again - confirm against Socket's behavior.
  Socket probe(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap);
  probe.bind();
  uniqueId.addr = probe.getAddr();
}
|
||||
|
||||
/// Create the root listen socket on the address stored in the unique id and
/// start the thread that runs bootstrapRoot().
void TcpBootstrap::Impl::bootstrapCreateRoot() {
  listenSockRoot_ =
      std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
  listenSockRoot_->bindAndListen();
  // Keep the address the socket actually ended up with.
  uniqueId_.addr = listenSockRoot_->getAddr();

  // No exception handling around bootstrapRoot(): anything it throws escapes
  // the thread function.
  rootThread_ = std::thread([this]() { bootstrapRoot(); });
}
|
||||
|
||||
void TcpBootstrap::Impl::bootstrapRoot() {
|
||||
int numCollected = 0;
|
||||
std::vector<SocketAddress> rankAddresses(nRanks_, SocketAddress());
|
||||
// for initial rank <-> root information exchange
|
||||
std::vector<SocketAddress> rankAddressesRoot(nRanks_, SocketAddress());
|
||||
|
||||
std::memset(rankAddresses.data(), 0, sizeof(SocketAddress) * nRanks_);
|
||||
std::memset(rankAddressesRoot.data(), 0, sizeof(SocketAddress) * nRanks_);
|
||||
setFilesLimit();
|
||||
|
||||
TRACE("BEGIN bootstrapRoot\n");
|
||||
/* Receive addresses from all ranks */
|
||||
do {
|
||||
int rank;
|
||||
getRemoteAddresses(listenSockRoot_.get(), rankAddresses, rankAddressesRoot, rank);
|
||||
++numCollected;
|
||||
TRACE("Received connect from rank %d total %d/%d\n", rank, numCollected, nRanks_);
|
||||
} while (numCollected < nRanks_ && (!abortFlag_ || *abortFlag_ == 0));
|
||||
|
||||
if (abortFlag_ && *abortFlag_) {
|
||||
TRACE("ABORTED\n");
|
||||
return;
|
||||
}
|
||||
|
||||
TRACE("COLLECTED ALL %d HANDLES\n", nRanks_);
|
||||
|
||||
// Send the connect handle for the next rank in the AllGather ring
|
||||
for (int peer = 0; peer < nRanks_; ++peer) {
|
||||
sendHandleToPeer(peer, rankAddresses, rankAddressesRoot);
|
||||
}
|
||||
|
||||
TRACE("DONE bootstrapRoot\n");
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface,
|
||||
SocketAddress& netIfAddr) {
|
||||
char netIfName[MAX_IF_NAME_SIZE + 1];
|
||||
if (!ipPortPair.empty()) {
|
||||
if (interface != "") {
|
||||
// we know the <interface>
|
||||
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str());
|
||||
if (ret <= 0) {
|
||||
ERROR("NET/Socket : No interface named %s found\n", interface.c_str());
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// we do not know the <interface> try to match it next
|
||||
SocketAddress remoteAddr;
|
||||
SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
|
||||
if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
|
||||
ERROR("NET/Socket : No usable listening interface found\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1);
|
||||
if (ret <= 0) {
|
||||
ERROR("TcpBootstrap : no socket interface found\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
|
||||
std::sprintf(line, " %s:", netIfName);
|
||||
SocketToString(&netIfAddr, line + strlen(line));
|
||||
TRACE("TcpBootstrap : Using%s", line);
|
||||
}
|
||||
|
||||
// Evaluate __exp; an Error carrying ErrorCode::Timeout is replaced by a
// bootstrap-specific timeout Error, every other exception is rethrown as is.
//
// No trailing semicolon after `while (0)`: the invocation supplies it
// (`TIMEOUT(expr);`), which keeps the macro usable as a single statement,
// e.g. in an unbraced if/else.
#define TIMEOUT(__exp)                                                      \
  do {                                                                      \
    try {                                                                   \
      __exp;                                                                \
    } catch (const Error& e) {                                              \
      if (e.getErrorCode() == ErrorCode::Timeout) {                         \
        throw Error("TcpBootstrap connection timeout", ErrorCode::Timeout); \
      }                                                                     \
      throw;                                                                \
    }                                                                       \
  } while (0)
|
||||
|
||||
void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
|
||||
const int64_t connectionTimeoutUs = timeoutSec * 1000000;
|
||||
Timer timer;
|
||||
SocketAddress nextAddr;
|
||||
ExtInfo info;
|
||||
|
||||
TRACE("establishConnections: rank %d nranks %d\n", rank_, nRanks_);
|
||||
|
||||
auto getLeftTime = [&]() {
|
||||
if (connectionTimeoutUs < 0) {
|
||||
// no timeout: always return a large number
|
||||
return int64_t(1e9);
|
||||
}
|
||||
int64_t timeout = connectionTimeoutUs - timer.elapsed();
|
||||
if (timeout <= 0) {
|
||||
ERROR("TcpBootstrap connection timeout\n");
|
||||
return (long int)-1;
|
||||
}
|
||||
return timeout;
|
||||
};
|
||||
|
||||
info.rank = rank_;
|
||||
info.nRanks = nRanks_;
|
||||
|
||||
uint64_t magic = uniqueId_.magic;
|
||||
// Create socket for other ranks to contact me
|
||||
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
|
||||
listenSock_->bindAndListen();
|
||||
info.extAddressListen = listenSock_->getAddr();
|
||||
|
||||
{
|
||||
// Create socket for root to contact me
|
||||
Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
|
||||
lsock.bindAndListen();
|
||||
info.extAddressListenRoot = lsock.getAddr();
|
||||
|
||||
// stagger connection times to avoid an overload of the root
|
||||
auto randomSleep = [](int rank) {
|
||||
timespec tv;
|
||||
tv.tv_sec = rank / 1000;
|
||||
tv.tv_nsec = 1000000 * (rank % 1000);
|
||||
TRACE("rank %d delaying connection to root by %ld sec %ld nsec\n", rank,
|
||||
tv.tv_sec, tv.tv_nsec);
|
||||
(void)nanosleep(&tv, NULL);
|
||||
};
|
||||
if (nRanks_ > 128) {
|
||||
randomSleep(rank_);
|
||||
}
|
||||
|
||||
// send info on my listening socket to root
|
||||
{
|
||||
Socket sock(&uniqueId_.addr, magic, SocketTypeBootstrap, abortFlag_);
|
||||
//TIMEOUT(sock.connect(getLeftTime()));
|
||||
sock.connect(getLeftTime());
|
||||
netSend(&sock, &info, sizeof(info));
|
||||
}
|
||||
|
||||
// get info on my "next" rank in the bootstrap ring from root
|
||||
{
|
||||
Socket sock(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_);
|
||||
//TIMEOUT(sock.accept(&lsock, getLeftTime()));
|
||||
sock.accept(&lsock, getLeftTime());
|
||||
netRecv(&sock, &nextAddr, sizeof(SocketAddress));
|
||||
}
|
||||
}
|
||||
|
||||
ringSendSocket_ = std::make_unique<Socket>(&nextAddr, magic, SocketTypeBootstrap, abortFlag_);
|
||||
//TIMEOUT(ringSendSocket_->connect(getLeftTime()));
|
||||
ringSendSocket_->connect(getLeftTime());
|
||||
// Accept the connect request from the previous rank in the AllGather ring
|
||||
ringRecvSocket_ = std::make_unique<Socket>(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown,
|
||||
abortFlag_);
|
||||
//TIMEOUT(ringRecvSocket_->accept(listenSock_.get(), getLeftTime()));
|
||||
ringRecvSocket_->accept(listenSock_.get(), getLeftTime());
|
||||
|
||||
// AllGather all listen handlers
|
||||
peerCommAddresses_[rank_] = listenSock_->getAddr();
|
||||
allGather(peerCommAddresses_.data(), sizeof(SocketAddress));
|
||||
|
||||
TRACE("rank %d nranks %d - DONE\n", rank_, nRanks_);
|
||||
}
|
||||
|
||||
int TcpBootstrap::Impl::getNranksPerNode() {
|
||||
if (nRanksPerNode_ > 0) return nRanksPerNode_;
|
||||
int nRanksPerNode = 0;
|
||||
bool useIpv4 = peerCommAddresses_[rank_].sa.sa_family == AF_INET;
|
||||
for (int i = 0; i < nRanks_; i++) {
|
||||
if (useIpv4) {
|
||||
if (peerCommAddresses_[i].sin.sin_addr.s_addr ==
|
||||
peerCommAddresses_[rank_].sin.sin_addr.s_addr) {
|
||||
nRanksPerNode++;
|
||||
}
|
||||
} else {
|
||||
if (std::memcmp(&(peerCommAddresses_[i].sin6.sin6_addr),
|
||||
&(peerCommAddresses_[rank_].sin6.sin6_addr),
|
||||
sizeof(in6_addr)) == 0) {
|
||||
nRanksPerNode++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nRanksPerNode_ = nRanksPerNode;
|
||||
return nRanksPerNode_;
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::allGather(void* allData, int size) {
|
||||
char* data = static_cast<char*>(allData);
|
||||
int rank = rank_;
|
||||
int nRanks = nRanks_;
|
||||
|
||||
TRACE("allGather: rank %d nranks %d size %d\n", rank, nRanks, size);
|
||||
|
||||
/* Simple ring based AllGather
|
||||
* At each step i receive data from (rank-i-1) from left
|
||||
* and send previous step's data from (rank-i) to right
|
||||
*/
|
||||
for (int i = 0; i < nRanks - 1; i++) {
|
||||
size_t rSlice = (rank - i - 1 + nRanks) % nRanks;
|
||||
size_t sSlice = (rank - i + nRanks) % nRanks;
|
||||
|
||||
// Send slice to the right
|
||||
netSend(ringSendSocket_.get(), data + sSlice * size, size);
|
||||
// Recv slice from the left
|
||||
netRecv(ringRecvSocket_.get(), data + rSlice * size, size);
|
||||
}
|
||||
|
||||
TRACE("allGather: rank %d nranks %d size %d - DONE\n", rank, nRanks, size);
|
||||
}
|
||||
|
||||
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) {
|
||||
auto it = peerSendSockets_.find(std::make_pair(peer, tag));
|
||||
if (it != peerSendSockets_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
auto sock = std::make_shared<Socket>(&peerCommAddresses_[peer], uniqueId_.magic,
|
||||
SocketTypeBootstrap, abortFlag_);
|
||||
sock->connect();
|
||||
netSend(sock.get(), &rank_, sizeof(int));
|
||||
netSend(sock.get(), &tag, sizeof(int));
|
||||
peerSendSockets_[std::make_pair(peer, tag)] = sock;
|
||||
return sock;
|
||||
}
|
||||
|
||||
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag) {
|
||||
auto it = peerRecvSockets_.find(std::make_pair(peer, tag));
|
||||
if (it != peerRecvSockets_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
for (;;) {
|
||||
auto sock = std::make_shared<Socket>(nullptr, ROCSHMEM_SOCKET_MAGIC, SocketTypeUnknown,
|
||||
abortFlag_);
|
||||
sock->accept(listenSock_.get());
|
||||
int recvPeer, recvTag;
|
||||
netRecv(sock.get(), &recvPeer, sizeof(int));
|
||||
netRecv(sock.get(), &recvTag, sizeof(int));
|
||||
peerRecvSockets_[std::make_pair(recvPeer, recvTag)] = sock;
|
||||
if (recvPeer == peer && recvTag == tag) {
|
||||
return sock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a size-prefixed message: first @p size as an int, then @p size bytes
/// of @p data.
///
/// @param sock Connected socket.
/// @param data Payload.
/// @param size Payload size in bytes.
void TcpBootstrap::Impl::netSend(Socket* sock, const void* data, int size) {
  // Header: the payload length.
  sock->send(&size, sizeof(int));
  // Payload. Socket::send takes a non-const pointer.
  void* payload = const_cast<void*>(data);
  sock->send(payload, size);
}
|
||||
|
||||
/// Receive a size-prefixed message written by netSend().
///
/// Reads the announced payload size and then the payload. A message whose
/// announced size is negative or larger than @p size is reported as an error
/// and its payload is not read.
///
/// @param sock Connected socket.
/// @param data Destination buffer.
/// @param size Capacity of @p data in bytes.
void TcpBootstrap::Impl::netRecv(Socket* sock, void* data, int size) {
  // Initialized so a receive that delivers nothing cannot leave an
  // indeterminate length behind.
  int recvSize = 0;
  sock->recv(&recvSize, sizeof(int));

  // The length comes off the wire: never pass a negative value on as a size.
  if (recvSize < 0) {
    ERROR("Invalid message size received : %d bytes\n", recvSize);
    return;
  }
  if (recvSize > size) {
    ERROR("Message truncated : received %d bytes instead of %d\n", recvSize, size);
    return;
  }

  // 0 <= recvSize <= size holds here.
  sock->recv(data, recvSize);
}
|
||||
|
||||
void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) {
|
||||
auto sock = getPeerSendSocket(peer, tag);
|
||||
netSend(sock.get(), data, size);
|
||||
}
|
||||
|
||||
void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) {
|
||||
auto sock = getPeerRecvSocket(peer, tag);
|
||||
netRecv(sock.get(), data, size);
|
||||
}
|
||||
|
||||
/// Barrier across all ranks, implemented as an all-gather of one int per rank.
void TcpBootstrap::Impl::barrier() {
  allGather(barrierArr_.data(), sizeof(int));
}
|
||||
|
||||
/// Release all sockets owned by this object: the listen sockets, the ring
/// sockets and both point-to-point socket caches.
void TcpBootstrap::Impl::close() {
  // Listen sockets.
  listenSockRoot_.reset();
  listenSock_.reset();

  // Ring sockets.
  ringRecvSocket_.reset();
  ringSendSocket_.reset();

  // Cached point-to-point sockets.
  peerSendSockets_.clear();
  peerRecvSockets_.clear();
}
|
||||
|
||||
rocshmem_uniqueid_t TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); }
|
||||
|
||||
TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
|
||||
|
||||
rocshmem_uniqueid_t TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
|
||||
|
||||
int TcpBootstrap::getRank() { return pimpl_->getRank(); }
|
||||
|
||||
int TcpBootstrap::getNranks() { return pimpl_->getNranks(); }
|
||||
|
||||
int TcpBootstrap::getNranksPerNode() { return pimpl_->getNranksPerNode(); }
|
||||
|
||||
void TcpBootstrap::send(void* data, int size, int peer, int tag) {
|
||||
pimpl_->send(data, size, peer, tag);
|
||||
}
|
||||
|
||||
void TcpBootstrap::recv(void* data, int size, int peer, int tag) {
|
||||
pimpl_->recv(data, size, peer, tag);
|
||||
}
|
||||
|
||||
void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }
|
||||
|
||||
void TcpBootstrap::initialize(rocshmem_uniqueid_t uniqueId, int64_t timeoutSec) {
|
||||
pimpl_->initialize(uniqueId, timeoutSec);
|
||||
}
|
||||
|
||||
void TcpBootstrap::initialize(const std::string& ipPortPair, int64_t timeoutSec) {
|
||||
pimpl_->initialize(ipPortPair, timeoutSec);
|
||||
}
|
||||
|
||||
void TcpBootstrap::barrier() { pimpl_->barrier(); }
|
||||
|
||||
TcpBootstrap::~TcpBootstrap() { pimpl_->close(); }
|
||||
|
||||
} // namespace rocshmem
|
||||
@@ -0,0 +1,123 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef ROCSHMEM_BOOTSTRAP_HPP_
|
||||
#define ROCSHMEM_BOOTSTRAP_HPP_
|
||||
|
||||
|
||||
#include <array>
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "rocshmem/rocshmem_common.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
/// Return a version string.
|
||||
std::string version();
|
||||
|
||||
/// Abstract interface of a bootstrap network: rank queries, tagged
/// point-to-point messages, all-gather and barrier.
class Bootstrap {
 public:
  Bootstrap() = default;
  virtual ~Bootstrap() = default;

  // --- to be provided by implementations -----------------------------------
  virtual int getRank() = 0;
  virtual int getNranks() = 0;
  virtual int getNranksPerNode() = 0;
  virtual void send(void* data, int size, int peer, int tag) = 0;
  virtual void recv(void* data, int size, int peer, int tag) = 0;
  virtual void allGather(void* allData, int size) = 0;
  virtual void barrier() = 0;

  // --- helpers built on top of send()/recv() -------------------------------

  /// Barrier among the given ranks only.
  void groupBarrier(const std::vector<int>& ranks);

  /// Send a byte vector: its length goes out on @p tag, its content on tag + 1.
  void send(const std::vector<char>& data, int peer, int tag);

  /// Receive a byte vector sent with the matching vector send(); @p data is
  /// resized to the received length.
  void recv(std::vector<char>& data, int peer, int tag);
};
|
||||
|
||||
/// A native implementation of the bootstrap using TCP sockets.
|
||||
class TcpBootstrap : public Bootstrap {
|
||||
public:
|
||||
/// Create a random unique ID.
|
||||
/// @return The created unique ID.
|
||||
static rocshmem_uniqueid_t createUniqueId();
|
||||
|
||||
/// Constructor.
|
||||
/// @param rank The rank of the process.
|
||||
/// @param nRanks The total number of ranks.
|
||||
TcpBootstrap(int rank, int nRanks);
|
||||
|
||||
/// Destructor.
|
||||
~TcpBootstrap();
|
||||
|
||||
/// Return the unique ID stored in the @ref TcpBootstrap.
|
||||
/// @return The unique ID stored in the @ref TcpBootstrap.
|
||||
rocshmem_uniqueid_t getUniqueId() const;
|
||||
|
||||
/// Initialize the @ref TcpBootstrap with a given unique ID.
|
||||
/// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with.
|
||||
/// @param timeoutSec The connection timeout in seconds.
|
||||
void initialize(rocshmem_uniqueid_t uniqueId, int64_t timeoutSec = 30);
|
||||
|
||||
/// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port".
|
||||
/// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port".
|
||||
/// @param timeoutSec The connection timeout in seconds.
|
||||
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);
|
||||
|
||||
/// Return the rank of the process.
|
||||
int getRank() override;
|
||||
|
||||
/// Return the total number of ranks.
|
||||
int getNranks() override;
|
||||
|
||||
/// Return the total number of ranks per node.
|
||||
int getNranksPerNode() override;
|
||||
|
||||
/// Send data to another process.
|
||||
///
|
||||
/// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
|
||||
/// senderRank, tag)`.
|
||||
///
|
||||
/// @param data The data to send.
|
||||
/// @param size The size of the data to send.
|
||||
/// @param peer The rank of the process to send the data to.
|
||||
/// @param tag The tag to send the data with.
|
||||
void send(void* data, int size, int peer, int tag) override;
|
||||
|
||||
/// Receive data from another process.
|
||||
///
|
||||
/// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
|
||||
/// senderRank, tag)`.
|
||||
///
|
||||
/// @param data The buffer to write the received data to.
|
||||
/// @param size The size of the data to receive.
|
||||
/// @param peer The rank of the process to receive the data from.
|
||||
/// @param tag The tag to receive the data with.
|
||||
void recv(void* data, int size, int peer, int tag) override;
|
||||
|
||||
/// Gather data from all processes.
|
||||
///
|
||||
/// When called by rank `r`, this sends data from `allData[r * size]` to `allData[(r + 1) * size - 1]` to all other
|
||||
/// ranks. The data sent by rank `r` is received into `allData[r * size]` of other ranks.
|
||||
///
|
||||
/// @param allData The buffer to write the received data to.
|
||||
/// @param size The size of the data each rank sends.
|
||||
void allGather(void* allData, int size) override;
|
||||
|
||||
/// Synchronize all processes.
|
||||
void barrier() override;
|
||||
|
||||
private:
|
||||
// The interal implementation.
|
||||
class Impl;
|
||||
|
||||
// Pointer to the internal implementation.
|
||||
std::unique_ptr<Impl> pimpl_;
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif // ROCSHMEM_BOOTSTRAP_HPP_
|
||||
@@ -0,0 +1,74 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <type_traits>
|
||||
|
||||
#include "env.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
/// Read an environment variable and convert it to type T.
///
/// @tparam T `int`, `bool`, or any type constructible from `const char*` (e.g. std::string).
/// @param envName Name of the environment variable to look up.
/// @param defaultValue Value returned when the variable is not set.
/// @return `atoi(value)` for int; `value != "0"` for bool; `T(value)` otherwise.
template <typename T>
T readEnv(const std::string &envName, const T &defaultValue) {
  const char *envCstr = getenv(envName.c_str());
  if (envCstr == nullptr) return defaultValue;
  if constexpr (std::is_same_v<T, int>) {
    return atoi(envCstr);
  } else if constexpr (std::is_same_v<T, bool>) {
    return (std::string(envCstr) != "0");
  } else {
    // Must sit in the discarded branch: outside the `if constexpr` chain this
    // statement is still instantiated for T = int, where a pointer-to-int
    // functional cast is ill-formed.
    return T(envCstr);
  }
}
|
||||
|
||||
/// Overwrite `env` with the value of an environment variable, if it is set.
///
/// @param envName Name of the environment variable to look up.
/// @param env Destination; left untouched when the variable is not set.
///            bool: set to `value != "0"`; int: set to `atoi(value)`;
///            anything else: assigned a std::string holding the raw value.
template <typename T>
void readAndSetEnv(const std::string &envName, T &env) {
  if (const char *raw = getenv(envName.c_str()); raw != nullptr) {
    if constexpr (std::is_same_v<T, bool>) {
      env = (std::string(raw) != "0");
    } else if constexpr (std::is_same_v<T, int>) {
      env = atoi(raw);
    } else {
      env = std::string(raw);
    }
  }
}
|
||||
|
||||
template <typename T>
|
||||
void logEnv(const std::string &envName, const T &env) {
|
||||
if (!getenv(envName.c_str())) return;
|
||||
INFO("%s=%d", envName.c_str(), env);
|
||||
}
|
||||
|
||||
template <>
|
||||
void logEnv(const std::string &envName, const std::string &env) {
|
||||
if (!getenv(envName.c_str())) return;
|
||||
INFO("%s=%s", envName.c_str(), env.c_str());
|
||||
}
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
Env::Env()
|
||||
: debug(readEnv<std::string>("ROCSHMEM_DEBUG", "")),
|
||||
debugSubsys(readEnv<std::string>("ROCSHMEM_DEBUG_SUBSYS", "")),
|
||||
debugFile(readEnv<std::string>("ROCSHMEM_DEBUG_FILE", "")),
|
||||
hostid(readEnv<std::string>("ROCSHMEM_HOSTID", "")),
|
||||
socketFamily(readEnv<std::string>("ROCSHMEM_SOCKET_FAMILY", "")),
|
||||
socketIfname(readEnv<std::string>("ROCSHMEM_SOCKET_IFNAME", "")) {}
|
||||
|
||||
std::shared_ptr<Env> env() {
|
||||
static std::shared_ptr<Env> globalEnv = std::shared_ptr<Env>(new Env());
|
||||
static bool logged = false;
|
||||
if (!logged) {
|
||||
logged = true;
|
||||
// cannot log inside the constructor because of circular dependency
|
||||
logEnv("ROCSHMEM_DEBUG", globalEnv->debug);
|
||||
logEnv("ROCSHMEM_DEBUG_SUBSYS", globalEnv->debugSubsys);
|
||||
logEnv("ROCSHMEM_DEBUG_FILE", globalEnv->debugFile);
|
||||
logEnv("ROCSHMEM_HOSTID", globalEnv->hostid);
|
||||
logEnv("ROCSHMEM_SOCKET_FAMILY", globalEnv->socketFamily);
|
||||
logEnv("ROCSHMEM_SOCKET_IFNAME", globalEnv->socketIfname);
|
||||
}
|
||||
return globalEnv;
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc,
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef ROCSHMEM_ENV_HPP_
|
||||
#define ROCSHMEM_ENV_HPP_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
class Env;
|
||||
|
||||
/// Get the environment.
|
||||
/// @return A shared pointer to the global environment object.
|
||||
std::shared_ptr<Env> env();
|
||||
|
||||
/// The constructor reads environment variables and sets the corresponding fields.
|
||||
/// Use the @ref env() function to get the environment object.
|
||||
class Env {
|
||||
public:
|
||||
const std::string debug;
|
||||
const std::string debugSubsys;
|
||||
const std::string debugFile;
|
||||
const std::string hostid;
|
||||
const std::string socketFamily;
|
||||
const std::string socketIfname;
|
||||
|
||||
private:
|
||||
Env();
|
||||
|
||||
friend std::shared_ptr<Env> env();
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif // ROCSHMEM_ENV_HPP_
|
||||
@@ -0,0 +1,774 @@
|
||||
// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
// Modifications Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <errno.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "socket.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "env.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
#define ROCSHMEM_SOCKET_SEND 0
|
||||
#define ROCSHMEM_SOCKET_RECV 1
|
||||
|
||||
/* Format a string representation of a (union SocketAddress *)
|
||||
* socket address using getnameinfo()
|
||||
*
|
||||
* Output: "IPv4/IPv6 address<port>"
|
||||
*/
|
||||
const char* SocketToString(union SocketAddress* addr, char* buf,
|
||||
const int numericHostForm /*= 1*/) {
|
||||
if (buf == NULL || addr == NULL) return NULL;
|
||||
struct sockaddr* saddr = &addr->sa;
|
||||
if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
|
||||
buf[0] = '\0';
|
||||
return buf;
|
||||
}
|
||||
char host[NI_MAXHOST], service[NI_MAXSERV];
|
||||
/* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
|
||||
* (When not set, this will still happen in case the node's name cannot be determined.)
|
||||
*/
|
||||
int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
|
||||
(void)getnameinfo(saddr, sizeof(union SocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag);
|
||||
sprintf(buf, "%s<%s>", host, service);
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Equivalent with ($ cat /proc/sys/net/ipv4/tcp_fin_timeout)
|
||||
static int getTcpFinTimeout() {
|
||||
std::ifstream ifs("/proc/sys/net/ipv4/tcp_fin_timeout");
|
||||
if (!ifs.is_open()) {
|
||||
ERROR("open /proc/sys/net/ipv4/tcp_fin_timeout failed errno %d\n", errno);
|
||||
return -1;
|
||||
}
|
||||
int timeout;
|
||||
ifs >> timeout;
|
||||
return timeout;
|
||||
}
|
||||
|
||||
static uint16_t socketToPort(union SocketAddress* addr) {
|
||||
struct sockaddr* saddr = &addr->sa;
|
||||
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
|
||||
}
|
||||
|
||||
/* Allow the user to force the IPv4/IPv6 interface selection */
|
||||
static int envSocketFamily(void) {
|
||||
int family = -1; // Family selection is not forced, will use first one found
|
||||
const std::string& socketFamily = env()->socketFamily;
|
||||
if (socketFamily == "") return family;
|
||||
|
||||
if (socketFamily == "AF_INET")
|
||||
family = AF_INET; // IPv4
|
||||
else if (socketFamily == "AF_INET6")
|
||||
family = AF_INET6; // IPv6
|
||||
return family;
|
||||
}
|
||||
|
||||
static int findInterfaces(const char* prefixList, char* names, union SocketAddress* addrs,
|
||||
int sock_family, int maxIfNameSize, int maxIfs) {
|
||||
#ifdef ROCSHMEM_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
struct netIf userIfs[MAX_IFS];
|
||||
bool searchNot = prefixList && prefixList[0] == '^';
|
||||
if (searchNot) prefixList++;
|
||||
bool searchExact = prefixList && prefixList[0] == '=';
|
||||
if (searchExact) prefixList++;
|
||||
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
|
||||
|
||||
int found = 0;
|
||||
struct ifaddrs *interfaces, *interface;
|
||||
getifaddrs(&interfaces);
|
||||
for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
|
||||
if (interface->ifa_addr == NULL) continue;
|
||||
|
||||
/* We only support IPv4 & IPv6 */
|
||||
int family = interface->ifa_addr->sa_family;
|
||||
if (family != AF_INET && family != AF_INET6) continue;
|
||||
|
||||
TRACE("Found interface %s:%s\n", interface->ifa_name,
|
||||
SocketToString((union SocketAddress*)interface->ifa_addr, line));
|
||||
|
||||
/* Allow the caller to force the socket family type */
|
||||
if (sock_family != -1 && family != sock_family) continue;
|
||||
|
||||
/* We also need to skip IPv6 loopback interfaces */
|
||||
if (family == AF_INET6) {
|
||||
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
|
||||
if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue;
|
||||
}
|
||||
|
||||
// check against user specified interfaces
|
||||
if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check that this interface has not already been saved
|
||||
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
|
||||
bool duplicate = false;
|
||||
for (int i = 0; i < found; i++) {
|
||||
if (strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
|
||||
duplicate = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!duplicate) {
|
||||
// Store the interface name
|
||||
strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
|
||||
// Store the IP address
|
||||
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
|
||||
memcpy(addrs + found, interface->ifa_addr, salen);
|
||||
found++;
|
||||
}
|
||||
}
|
||||
|
||||
freeifaddrs(interfaces);
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool matchSubnet(struct ifaddrs local_if, union SocketAddress* remote) {
|
||||
/* Check family first */
|
||||
int family = local_if.ifa_addr->sa_family;
|
||||
if (family != remote->sa.sa_family) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (family == AF_INET) {
|
||||
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
|
||||
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
|
||||
struct sockaddr_in& remote_addr = remote->sin;
|
||||
struct in_addr local_subnet, remote_subnet;
|
||||
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
|
||||
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
|
||||
return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
|
||||
} else if (family == AF_INET6) {
|
||||
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
|
||||
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
|
||||
struct sockaddr_in6& remote_addr = remote->sin6;
|
||||
struct in6_addr& local_in6 = local_addr->sin6_addr;
|
||||
struct in6_addr& mask_in6 = mask->sin6_addr;
|
||||
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
|
||||
bool same = true;
|
||||
int len = 16; // IPv6 address is 16 unsigned char
|
||||
for (int c = 0; c < len; c++) { // Network byte order is big-endian
|
||||
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
|
||||
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
|
||||
if (c1 ^ c2) {
|
||||
same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// At last, we need to compare scope id
|
||||
// Two Link-type addresses can have the same subnet address even though they are not in the same scope
|
||||
// For Global type, this field is 0, so a comparison wouldn't matter
|
||||
same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
|
||||
return same;
|
||||
} else {
|
||||
ERROR("Net : Unsupported address family type\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int FindInterfaceMatchSubnet(char* ifNames, union SocketAddress* localAddrs, union SocketAddress* remoteAddr,
|
||||
int ifNameMaxSize, int maxIfs) {
|
||||
#ifdef ROCSHMEM_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
char line_a[SOCKET_NAME_MAXLEN + 1];
|
||||
int found = 0;
|
||||
struct ifaddrs *interfaces, *interface;
|
||||
getifaddrs(&interfaces);
|
||||
for (interface = interfaces; interface && !found; interface = interface->ifa_next) {
|
||||
if (interface->ifa_addr == NULL) continue;
|
||||
|
||||
/* We only support IPv4 & IPv6 */
|
||||
int family = interface->ifa_addr->sa_family;
|
||||
if (family != AF_INET && family != AF_INET6) continue;
|
||||
|
||||
// check against user specified interfaces
|
||||
if (!matchSubnet(*interface, remoteAddr)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Store the local IP address
|
||||
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
|
||||
memcpy(localAddrs + found, interface->ifa_addr, salen);
|
||||
|
||||
// Store the interface name
|
||||
strncpy(ifNames + found * ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
|
||||
|
||||
TRACE("NET : Found interface %s:%s in the same subnet as remote address %s\n",
|
||||
interface->ifa_name, SocketToString(localAddrs + found, line), SocketToString(remoteAddr, line_a));
|
||||
found++;
|
||||
if (found == maxIfs) break;
|
||||
}
|
||||
|
||||
if (found == 0) {
|
||||
ERROR("Net : No interface found in the same subnet as remote address %s\n",
|
||||
SocketToString(remoteAddr, line_a));
|
||||
}
|
||||
freeifaddrs(interfaces);
|
||||
return found;
|
||||
}
|
||||
|
||||
void SocketGetAddrFromString(union SocketAddress* ua, const char* ip_port_pair) {
|
||||
if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
|
||||
ERROR("Net : string is null\n");
|
||||
return;
|
||||
}
|
||||
|
||||
bool ipv6 = ip_port_pair[0] == '[';
|
||||
/* Construct the sockaddress structure */
|
||||
if (!ipv6) {
|
||||
struct netIf ni;
|
||||
// parse <ip_or_hostname>:<port> string, expect one pair
|
||||
if (parseStringList(ip_port_pair, &ni, 1) != 1) {
|
||||
ERROR("Net : No valid <IPv4_or_hostname>:<port> pair found\n");
|
||||
return;
|
||||
}
|
||||
|
||||
struct addrinfo hints, *p;
|
||||
int rv;
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
if ((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
|
||||
ERROR("Net : error encountered when getting address info : %s\n", gai_strerror(rv));
|
||||
return;
|
||||
}
|
||||
|
||||
// use the first
|
||||
if (p->ai_family == AF_INET) {
|
||||
struct sockaddr_in& sin = ua->sin;
|
||||
memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
|
||||
sin.sin_family = AF_INET; // IPv4
|
||||
// inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
|
||||
sin.sin_port = htons(ni.port); // port
|
||||
} else if (p->ai_family == AF_INET6) {
|
||||
struct sockaddr_in6& sin6 = ua->sin6;
|
||||
memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
sin6.sin6_port = htons(ni.port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = 0; // should be global scope, set to 0
|
||||
} else {
|
||||
ERROR("Net : unsupported IP family\n");
|
||||
return;
|
||||
}
|
||||
|
||||
freeaddrinfo(p); // all done with this structure
|
||||
|
||||
} else {
|
||||
int i, j = -1, len = strlen(ip_port_pair);
|
||||
for (i = 1; i < len; i++) {
|
||||
if (ip_port_pair[i] == '%') j = i;
|
||||
if (ip_port_pair[i] == ']') break;
|
||||
}
|
||||
if (i == len) {
|
||||
ERROR("Net : No valid [IPv6]:port pair found\n");
|
||||
return;
|
||||
}
|
||||
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
|
||||
|
||||
char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
|
||||
memset(ip_str, '\0', sizeof(ip_str));
|
||||
memset(port_str, '\0', sizeof(port_str));
|
||||
memset(if_name, '\0', sizeof(if_name));
|
||||
strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1);
|
||||
strncpy(port_str, ip_port_pair + i + 2, len - i - 1);
|
||||
int port = atoi(port_str);
|
||||
|
||||
// If not global scope, we need the intf name
|
||||
if (!global_scope)
|
||||
strncpy(if_name, ip_port_pair + j + 1, i - j - 1);
|
||||
|
||||
struct sockaddr_in6& sin6 = ua->sin6;
|
||||
sin6.sin6_family = AF_INET6; // IPv6
|
||||
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
|
||||
sin6.sin6_port = htons(port); // port
|
||||
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
|
||||
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
|
||||
}
|
||||
}
|
||||
|
||||
int FindInterfaces(char* ifNames, union SocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs,
|
||||
const char* inputIfName) {
|
||||
static int shownIfName = 0;
|
||||
int nIfs = 0;
|
||||
|
||||
// Allow user to force the INET socket family selection
|
||||
int sock_family = envSocketFamily();
|
||||
|
||||
// User specified interface
|
||||
const std::string& socketIfname = env()->socketIfname;
|
||||
if (inputIfName) {
|
||||
TRACE("using iterface %s", inputIfName);
|
||||
nIfs = findInterfaces(inputIfName, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
|
||||
} else if (socketIfname != "") {
|
||||
// Specified by user : find or fail
|
||||
if (shownIfName++ == 0) TRACE ("ROCSHMEM_SOCKET_IFNAME set to %s", socketIfname.c_str());
|
||||
nIfs = findInterfaces(socketIfname.c_str(), ifNames, ifAddrs, sock_family,
|
||||
ifNameMaxSize, maxIfs);
|
||||
} else {
|
||||
// Try to automatically pick the right one
|
||||
// Look for anything (but not docker or lo)
|
||||
if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family,
|
||||
ifNameMaxSize, maxIfs);
|
||||
// Finally look for docker, then lo.
|
||||
if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family,
|
||||
ifNameMaxSize, maxIfs);
|
||||
if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family,
|
||||
ifNameMaxSize, maxIfs);
|
||||
}
|
||||
return nIfs;
|
||||
}
|
||||
|
||||
/// Set up the socket bookkeeping and, when an address is given, create the
/// underlying TCP socket for it.
///
/// @param addr      Address to bind/connect to, or null for a socket that gets
///                  its descriptor later (no fd is created; addr_ is zeroed).
/// @param magic     Value stored in magic_.
/// @param type      Value stored in type_.
/// @param abortFlag Optional flag polled by the blocking loops.
/// @param asyncFlag Non-zero to run operations asynchronously.
///
/// The descriptor is switched to non-blocking mode when either asyncFlag or
/// abortFlag is set. On failure an error is logged and construction stops early.
Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type, volatile uint32_t* abortFlag,
               int asyncFlag) {
  fd_ = -1;
  acceptFd_ = -1;
  connectRetries_ = 0;
  acceptRetries_ = 0;
  abortFlag_ = abortFlag;
  asyncFlag_ = asyncFlag;
  state_ = SocketStateInitialized;
  magic_ = magic;
  type_ = type;

  if (addr == nullptr) {
    memset(&addr_, 0, sizeof(union SocketAddress));
  } else {
    /* IPv4/IPv6 support */
    memcpy(&addr_, addr, sizeof(union SocketAddress));
    const int family = addr_.sa.sa_family;
    if (family != AF_INET && family != AF_INET6) {
      char line[SOCKET_NAME_MAXLEN + 1];
      ERROR("SocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)\n",
            SocketToString(&addr_, line), family, (int)AF_INET, (int)AF_INET6);
      return;
    }
    salen_ = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);

    /* Connect to a hostname / port */
    fd_ = ::socket(family, SOCK_STREAM, 0);
    if (fd_ == -1) {
      ERROR("socket creation failed %d\n", errno);
      return;
    }
  }

  /* Set socket as non-blocking if async or if we need to be able to abort */
  const bool wantNonBlocking = asyncFlag_ || abortFlag_;
  if (!wantNonBlocking || fd_ < 0) return;

  const int flags = fcntl(fd_, F_GETFL);
  if (flags == -1) {
    ERROR("fcntl(F_GETFL) failed errno %d\n", errno);
    return;
  }
  if (fcntl(fd_, F_SETFL, flags | O_NONBLOCK) == -1) {
    ERROR("fcntl(F_SETFL) failed errno %d\n", errno);
    return;
  }
}
|
||||
|
||||
// Releases the descriptor (if any) through close().
Socket::~Socket() {
  this->close();
}
|
||||
|
||||
void Socket::bind() {
|
||||
if (fd_ == -1) {
|
||||
ERROR("file descriptor is -1\n");
|
||||
return;
|
||||
}
|
||||
|
||||
if (socketToPort(&addr_)) {
|
||||
// Port is forced by env. Make sure we get the port.
|
||||
int opt = 1;
|
||||
#if defined(SO_REUSEPORT)
|
||||
if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)) != 0) {
|
||||
ERROR("::setsockopt(SO_REUSEADDR | SO_REUSEPORT) failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) != 0) {
|
||||
ERROR("setsockopt(SO_REUSEADDR) failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
int finTimeout = getTcpFinTimeout();
|
||||
int retrySecs = finTimeout + 1;
|
||||
int remainSecs = retrySecs;
|
||||
|
||||
// addr port should be 0 (Any port)
|
||||
while (::bind(fd_, &addr_.sa, salen_) != 0) {
|
||||
// upon EADDRINUSE, retry up to for (finTimeout + 1) seconds
|
||||
if (errno != EADDRINUSE) {
|
||||
ERROR("bind failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
if (remainSecs > 0) {
|
||||
TRACE("No available ephemeral ports found, will retry after 1 second");
|
||||
sleep(1);
|
||||
remainSecs--;
|
||||
} else {
|
||||
ERROR("No available ephemeral ports found for %d seconds \n", retrySecs);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* Get the assigned Port */
|
||||
socklen_t size = salen_;
|
||||
if (::getsockname(fd_, &addr_.sa, &size) != 0) {
|
||||
ERROR("getsockname failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
state_ = SocketStateBound;
|
||||
}
|
||||
|
||||
void Socket::bindAndListen() {
|
||||
#ifdef ROCSHMEM_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
bind();
|
||||
TRACE("Listening on socket %s\n", SocketToString(&addr_, line));
|
||||
|
||||
/* Put the socket in listen mode
|
||||
* NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
|
||||
*/
|
||||
if (::listen(fd_, 16384) != 0) {
|
||||
ERROR("listen failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
state_ = SocketStateReady;
|
||||
}
|
||||
|
||||
void Socket::connect(int64_t timeout) {
|
||||
#ifdef ROCSHMEM_ENABLE_TRACE
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
#endif
|
||||
Timer timer;
|
||||
const int one = 1;
|
||||
|
||||
if (fd_ == -1) {
|
||||
ERROR("file descriptor is -1\n");
|
||||
return;
|
||||
}
|
||||
|
||||
if (state_ != SocketStateInitialized) {
|
||||
ERROR("wrong socket state %d\n", state_);
|
||||
return;
|
||||
}
|
||||
TRACE("Connecting to socket %s \n", SocketToString(&addr_, line));
|
||||
|
||||
if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)) != 0) {
|
||||
INFO("setsockopt(TCP_NODELAY) failed, errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
state_ = SocketStateConnecting;
|
||||
do {
|
||||
progressState();
|
||||
if (timeout > 0 && timer.elapsed() > timeout) {
|
||||
ERROR("connect timeout\n");
|
||||
return;
|
||||
}
|
||||
} while (asyncFlag_ == 0 && (abortFlag_ == NULL || *abortFlag_ == 0) &&
|
||||
(state_ == SocketStateConnecting || state_ == SocketStateConnectPolling || state_ == SocketStateConnected));
|
||||
|
||||
if (abortFlag_ && *abortFlag_ != 0) {
|
||||
ERROR("aborted\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::accept(const Socket* listenSocket, int64_t timeout) {
|
||||
Timer timer;
|
||||
|
||||
if (listenSocket == NULL) {
|
||||
ERROR("listenSocket is NULL\n");
|
||||
return;
|
||||
}
|
||||
if (listenSocket->getState() != SocketStateReady) {
|
||||
ERROR("listenSocket is in error state %u\n", listenSocket->getState());
|
||||
return;
|
||||
}
|
||||
|
||||
if (acceptFd_ == -1) {
|
||||
fd_ = listenSocket->getFd();
|
||||
connectRetries_ = listenSocket->getConnectRetries();
|
||||
acceptRetries_ = listenSocket->getAcceptRetries();
|
||||
abortFlag_ = listenSocket->getAbortFlag();
|
||||
asyncFlag_ = listenSocket->getAsyncFlag();
|
||||
magic_ = listenSocket->getMagic();
|
||||
type_ = listenSocket->getType();
|
||||
addr_ = listenSocket->getAddr();
|
||||
salen_ = listenSocket->getSalen();
|
||||
|
||||
acceptFd_ = listenSocket->getFd();
|
||||
state_ = SocketStateAccepting;
|
||||
}
|
||||
|
||||
do {
|
||||
progressState();
|
||||
if (timeout > 0 && timer.elapsed() > timeout) {
|
||||
ERROR("accept timeout\n");
|
||||
return;
|
||||
}
|
||||
} while (asyncFlag_ == 0 && (abortFlag_ == NULL || *abortFlag_ == 0) &&
|
||||
(state_ == SocketStateAccepting || state_ == SocketStateAccepted));
|
||||
|
||||
if (abortFlag_ && *abortFlag_ != 0) {
|
||||
ERROR("aborted\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::send(void* ptr, int size) {
|
||||
int offset = 0;
|
||||
if (state_ != SocketStateReady) {
|
||||
ERROR("socket state (%d) is not ready\n", state_);
|
||||
return;
|
||||
}
|
||||
socketWait(ROCSHMEM_SOCKET_SEND, ptr, size, &offset);
|
||||
}
|
||||
|
||||
void Socket::recv(void* ptr, int size) {
|
||||
int offset = 0;
|
||||
if (state_ != SocketStateReady) {
|
||||
ERROR("socket state (%d) is not read\n", state_);
|
||||
return;
|
||||
}
|
||||
socketWait(ROCSHMEM_SOCKET_RECV, ptr, size, &offset);
|
||||
}
|
||||
|
||||
void Socket::recvUntilEnd(void* ptr, int size, int* closed) {
|
||||
int offset = 0;
|
||||
*closed = 0;
|
||||
if (state_ != SocketStateReady) {
|
||||
ERROR("socket state (%d) is not ready in recvUntilEnd\n", state_);
|
||||
return;
|
||||
}
|
||||
|
||||
int bytes = 0;
|
||||
char* data = (char*)ptr;
|
||||
|
||||
do {
|
||||
bytes = ::recv(fd_, data + (offset), size - (offset), 0);
|
||||
if (bytes == 0) {
|
||||
*closed = 1;
|
||||
return;
|
||||
}
|
||||
if (bytes == -1) {
|
||||
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) {
|
||||
ERROR("recv until end failed errno %d\n", errno);
|
||||
return;
|
||||
} else {
|
||||
bytes = 0;
|
||||
}
|
||||
}
|
||||
(offset) += bytes;
|
||||
if (abortFlag_ && *abortFlag_ != 0) {
|
||||
ERROR("aborted\n");
|
||||
return;
|
||||
}
|
||||
} while (bytes > 0 && (offset) < size);
|
||||
}
|
||||
|
||||
/// Close the descriptor (if one is open) and mark the socket closed.
void Socket::close() {
  const int openFd = fd_;
  fd_ = -1;
  state_ = SocketStateClosed;
  if (openFd >= 0) {
    ::close(openFd);
  }
}
|
||||
|
||||
void Socket::progressState() {
|
||||
if (state_ == SocketStateAccepting) {
|
||||
tryAccept();
|
||||
}
|
||||
if (state_ == SocketStateAccepted) {
|
||||
finalizeAccept();
|
||||
}
|
||||
if (state_ == SocketStateConnecting) {
|
||||
startConnect();
|
||||
}
|
||||
if (state_ == SocketStateConnectPolling) {
|
||||
pollConnect();
|
||||
}
|
||||
if (state_ == SocketStateConnected) {
|
||||
finalizeConnect();
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::tryAccept() {
|
||||
socklen_t socklen = sizeof(union SocketAddress);
|
||||
fd_ = ::accept(acceptFd_, &addr_.sa, &socklen);
|
||||
if (fd_ != -1) {
|
||||
state_ = SocketStateAccepted;
|
||||
} else if (errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
ERROR("accept failed (fd %d) errno %d\n", acceptFd_, errno);
|
||||
} else {
|
||||
usleep(SLEEP_INT);
|
||||
if (++acceptRetries_ % 1000 == 0)
|
||||
INFO("tryAccept: Call to try accept returned %s, retrying", strerror(errno));
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::finalizeAccept() {
|
||||
uint64_t magic;
|
||||
enum SocketType type;
|
||||
int received = 0;
|
||||
socketProgress(ROCSHMEM_SOCKET_RECV, &magic, sizeof(magic), &received);
|
||||
if (received == 0) return;
|
||||
socketWait(ROCSHMEM_SOCKET_RECV, &magic, sizeof(magic), &received);
|
||||
if (magic != magic_) {
|
||||
ERROR("finalizeAccept: wrong magic %lx != %lx\n", magic, magic_);
|
||||
::close(fd_);
|
||||
fd_ = -1;
|
||||
// Ignore spurious connection and accept again
|
||||
state_ = SocketStateAccepting;
|
||||
return;
|
||||
} else {
|
||||
received = 0;
|
||||
socketWait(ROCSHMEM_SOCKET_RECV, &type, sizeof(type), &received);
|
||||
if (type != type_) {
|
||||
state_ = SocketStateError;
|
||||
::close(fd_);
|
||||
fd_ = -1;
|
||||
ERROR("wrong socket type %d != %d \n", type, type_);
|
||||
return;
|
||||
} else {
|
||||
state_ = SocketStateReady;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::startConnect() {
|
||||
/* blocking/non-blocking connect() is determined by asyncFlag. */
|
||||
int ret = ::connect(fd_, &addr_.sa, salen_);
|
||||
if (ret == 0) {
|
||||
state_ = SocketStateConnected;
|
||||
return;
|
||||
} else if (errno == EINPROGRESS) {
|
||||
state_ = SocketStateConnectPolling;
|
||||
return;
|
||||
} else if (errno == ECONNREFUSED || errno == ETIMEDOUT) {
|
||||
usleep(SLEEP_INT);
|
||||
if (++connectRetries_ % 1000 == 0) INFO("Call to connect returned %s, retrying", strerror(errno));
|
||||
return;
|
||||
} else {
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
state_ = SocketStateError;
|
||||
ERROR("connect to %s failed, errno %d\n", SocketToString(&addr_, line), errno);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::pollConnect() {
|
||||
struct pollfd pfd;
|
||||
int timeout = 1, ret;
|
||||
socklen_t rlen = sizeof(int);
|
||||
|
||||
memset(&pfd, 0, sizeof(struct pollfd));
|
||||
pfd.fd = fd_;
|
||||
pfd.events = POLLOUT;
|
||||
ret = ::poll(&pfd, 1, timeout);
|
||||
if (ret == -1) {
|
||||
ERROR("poll failed errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
if (ret == 0) return;
|
||||
|
||||
/* check socket status */
|
||||
if ((ret == 1 && (pfd.revents & POLLOUT)) == 0) {
|
||||
ERROR("poll failed\n");
|
||||
return;
|
||||
}
|
||||
if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen) == -1) {
|
||||
ERROR("getsockopt failed, errno %d\n", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ret == 0) {
|
||||
state_ = SocketStateConnected;
|
||||
} else if (ret == ECONNREFUSED || ret == ETIMEDOUT) {
|
||||
if (++connectRetries_ % 1000 == 0) {
|
||||
INFO("Call to connect returned %s, retrying", strerror(errno));
|
||||
}
|
||||
usleep(SLEEP_INT);
|
||||
|
||||
::close(fd_);
|
||||
fd_ = ::socket(addr_.sa.sa_family, SOCK_STREAM, 0);
|
||||
state_ = SocketStateConnecting;
|
||||
} else if (ret != EINPROGRESS) {
|
||||
state_ = SocketStateError;
|
||||
ERROR("connect failed \n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete the handshake on a connected socket: send magic_, then type_, and
/// switch to SocketStateReady. Returns without changing state when not even
/// the first byte of the magic could be sent yet.
void Socket::finalizeConnect() {
  int magicSent = 0;
  socketProgress(ROCSHMEM_SOCKET_SEND, &magic_, sizeof(magic_), &magicSent);
  if (magicSent != 0) {
    socketWait(ROCSHMEM_SOCKET_SEND, &magic_, sizeof(magic_), &magicSent);

    int typeSent = 0;
    socketWait(ROCSHMEM_SOCKET_SEND, &type_, sizeof(type_), &typeSent);
    state_ = SocketStateReady;
  }
}
|
||||
|
||||
void Socket::socketProgressOpt(int op, void* ptr, int size, int* offset, int block, int* closed) {
|
||||
int bytes = 0;
|
||||
*closed = 0;
|
||||
char* data = (char*)ptr;
|
||||
|
||||
do {
|
||||
if (op == ROCSHMEM_SOCKET_RECV) bytes = ::recv(fd_, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
|
||||
if (op == ROCSHMEM_SOCKET_SEND)
|
||||
bytes = ::send(fd_, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
|
||||
if (op == ROCSHMEM_SOCKET_RECV && bytes == 0) {
|
||||
*closed = 1;
|
||||
return;
|
||||
}
|
||||
if (bytes == -1) {
|
||||
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
|
||||
ERROR("recv failed, errno %d\n", errno);
|
||||
return;
|
||||
} else {
|
||||
bytes = 0;
|
||||
}
|
||||
}
|
||||
(*offset) += bytes;
|
||||
if (abortFlag_ && *abortFlag_ != 0) {
|
||||
ERROR("aborted\n");
|
||||
return;
|
||||
}
|
||||
} while (bytes > 0 && (*offset) < size);
|
||||
}
|
||||
|
||||
void Socket::socketProgress(int op, void* ptr, int size, int* offset) {
|
||||
int closed;
|
||||
socketProgressOpt(op, ptr, size, offset, 0, &closed);
|
||||
if (closed) {
|
||||
char line[SOCKET_NAME_MAXLEN + 1];
|
||||
ERROR("connection closed by remote peer %s\n", SocketToString(&addr_, line, 0));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::socketWait(int op, void* ptr, int size, int* offset) {
|
||||
while (*offset < size) socketProgress(op, ptr, size, offset);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef ROCSHMEM_SOCKET_H_
|
||||
#define ROCSHMEM_SOCKET_H_
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <stddef.h>
|
||||
#include <sys/socket.h>
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
#define MAX_IFS 16
|
||||
#define MAX_IF_NAME_SIZE 16
|
||||
#define SLEEP_INT 1000 // connection retry sleep interval in usec
|
||||
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
|
||||
#define ROCSHMEM_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL
|
||||
|
||||
/* Common socket address storage structure for IPv4/IPv6.
 * `sa` gives generic access (e.g. to sa_family); `sin` and `sin6` are the
 * IPv4 and IPv6 views of the same storage. */
|
||||
union SocketAddress {
|
||||
struct sockaddr sa;
|
||||
struct sockaddr_in sin;
|
||||
struct sockaddr_in6 sin6;
|
||||
};
|
||||
|
||||
// States a Socket can be in. The numeric values run consecutively from 0
// (SocketStateNone) to 11 (SocketStateNum); SocketStateNum is the count of
// the real states that precede it.
enum SocketState {
  SocketStateNone = 0,
  SocketStateInitialized,
  SocketStateAccepting,
  SocketStateAccepted,
  SocketStateConnecting,
  SocketStateConnectPolling,
  SocketStateConnected,
  SocketStateBound,
  SocketStateReady,
  SocketStateClosed,
  SocketStateError,
  SocketStateNum
};
|
||||
|
||||
// Tag describing what a Socket is used for. Values run consecutively from
// 0 (SocketTypeUnknown) to 4 (SocketTypeNetIb).
enum SocketType {
  SocketTypeUnknown = 0,
  SocketTypeBootstrap,
  SocketTypeProxy,
  SocketTypeNetSocket,
  SocketTypeNetIb
};
|
||||
|
||||
// Free helper functions of the socket layer. Their definitions are not part
// of this header.
// NOTE(review): the descriptions below are inferred from names and from the
// one visible call site - confirm against the implementation.

// Presumably renders `addr` as text into `buf` and returns it; callers pass a
// buffer of SOCKET_NAME_MAXLEN + 1 bytes.
const char* SocketToString(union SocketAddress* addr, char* buf,
                           const int numericHostForm = 1);

// Presumably parses an "ip:port" style string into `ua`.
void SocketGetAddrFromString(union SocketAddress* ua,
                             const char* ip_port_pair);

// Presumably looks for local interfaces on the same subnet as `remoteAddr`.
int FindInterfaceMatchSubnet(char* ifNames,
                             union SocketAddress* localAddrs,
                             union SocketAddress* remoteAddr,
                             int ifNameMaxSize,
                             int maxIfs);

// Presumably enumerates usable interfaces, optionally restricted to
// `inputIfName`.
int FindInterfaces(char* ifNames,
                   union SocketAddress* ifAddrs,
                   int ifNameMaxSize,
                   int maxIfs,
                   const char* inputIfName = nullptr);
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
Socket(const SocketAddress* addr = nullptr, uint64_t magic = ROCSHMEM_SOCKET_MAGIC,
|
||||
enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0);
|
||||
~Socket();
|
||||
|
||||
void bind();
|
||||
void bindAndListen();
|
||||
void connect(int64_t timeout = -1);
|
||||
void accept(const Socket* listenSocket, int64_t timeout = -1);
|
||||
void send(void* ptr, int size);
|
||||
void recv(void* ptr, int size);
|
||||
void recvUntilEnd(void* ptr, int size, int* closed);
|
||||
void close();
|
||||
|
||||
int getFd() const { return fd_; }
|
||||
int getAcceptFd() const { return acceptFd_; }
|
||||
int getConnectRetries() const { return connectRetries_; }
|
||||
int getAcceptRetries() const { return acceptRetries_; }
|
||||
volatile uint32_t* getAbortFlag() const { return abortFlag_; }
|
||||
int getAsyncFlag() const { return asyncFlag_; }
|
||||
enum SocketState getState() const { return state_; }
|
||||
uint64_t getMagic() const { return magic_; }
|
||||
enum SocketType getType() const { return type_; }
|
||||
SocketAddress getAddr() const { return addr_; }
|
||||
int getSalen() const { return salen_; }
|
||||
|
||||
private:
|
||||
void tryAccept();
|
||||
void finalizeAccept();
|
||||
void startConnect();
|
||||
void pollConnect();
|
||||
void finalizeConnect();
|
||||
void progressState();
|
||||
|
||||
void socketProgressOpt(int op, void* ptr, int size, int* offset, int block, int* closed);
|
||||
void socketProgress(int op, void* ptr, int size, int* offset);
|
||||
void socketWait(int op, void* ptr, int size, int* offset);
|
||||
|
||||
int fd_;
|
||||
int acceptFd_;
|
||||
int connectRetries_;
|
||||
int acceptRetries_;
|
||||
volatile uint32_t* abortFlag_;
|
||||
int asyncFlag_;
|
||||
enum SocketState state_;
|
||||
uint64_t magic_;
|
||||
enum SocketType type_;
|
||||
|
||||
union SocketAddress addr_;
|
||||
int salen_;
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif // ROCSHMEM_SOCKET_H_
|
||||
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <unistd.h>
|
||||
#include <signal.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "utils.hpp"
|
||||
#include "env.hpp"
|
||||
|
||||
// File read by computeHostHash() to make the host identifier unique per boot.
// The array size (32, including the terminating NUL) is deduced from the literal.
constexpr char HOSTID_FILE[] = "/proc/sys/kernel/random/boot_id";
|
||||
|
||||
// Compares an interface name against a reference name.
//   matchExact == true  : both strings must be identical (the terminating
//                         NUL of `string` takes part in the comparison).
//   matchExact == false : `ref` only has to be a prefix of `string`.
static bool matchIf(const char* string, const char* ref, bool matchExact) {
  int compareLen;
  if (matchExact) {
    compareLen = strlen(string) + 1;
  } else {
    compareLen = strlen(ref);
  }
  const bool same = (strncmp(string, ref, compareLen) == 0);
  return same;
}
|
||||
|
||||
// Two ports match when they are equal or when either one is the wildcard -1.
static bool matchPort(const int port1, const int port2) {
  const bool anyWildcard = (port1 == -1) || (port2 == -1);
  return anyWildcard || (port1 == port2);
}
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
// Formats a packed PCI bus id as "dddd:bb:dd.f" (lower-case hex).
// Bit layout of `id`: bits 20 and up -> first field, bits 12-19 -> second,
// bits 4-11 -> third, bits 0-3 -> last.
std::string int64ToBusId(int64_t id) {
  const int64_t domain = id >> 20;
  const int64_t bus = (id & 0xff000) >> 12;
  const int64_t device = (id & 0xff0) >> 4;
  const int64_t function = id & 0xf;

  char text[20];
  std::snprintf(text, sizeof(text), "%04lx:%02lx:%02lx.%01lx", domain, bus, device, function);
  return std::string(text);
}
|
||||
|
||||
// Parses a PCI bus id string such as "0000:03:00.0" into an integer.
// ':' and '.' separators are skipped, hex digits are collected (at most 16),
// and parsing stops at the first other character. The collected digits are
// then interpreted as one hexadecimal number.
int64_t busIdToInt64(const std::string busId) {
  constexpr size_t kMaxHexDigits = 16;  // longest possible int64 hex string

  std::string digits;
  digits.reserve(kMaxHexDigits);
  for (const char ch : busId) {
    if (digits.size() >= kMaxHexDigits) break;
    if (ch == ':' || ch == '.') continue;

    const bool isDecimal = (ch >= '0' && ch <= '9');
    const bool isUpperHex = (ch >= 'A' && ch <= 'F');
    const bool isLowerHex = (ch >= 'a' && ch <= 'f');
    if (!(isDecimal || isUpperHex || isLowerHex)) break;

    digits.push_back(ch);
  }
  return std::strtol(digits.c_str(), NULL, 16);
}
|
||||
|
||||
// DJB2a-style hash over the first n bytes of `string`:
//   result = (result * 33) ^ byte, starting from 5381.
// Bytes are combined as plain `char`, exactly as stored.
uint64_t getHash(const char* string, int n) {
  uint64_t acc = 5381;
  int idx = 0;
  while (idx < n) {
    acc = (acc * 33) ^ string[idx];
    ++idx;
  }
  return acc;
}
|
||||
|
||||
/* Generate a hash of the unique identifying string for this host
|
||||
* that will be unique for both bare-metal and container instances
|
||||
* Equivalent of a hash of;
|
||||
*
|
||||
* $(hostname)$(cat /proc/sys/kernel/random/boot_id)
|
||||
*
|
||||
* This string can be overridden by using the ROCSHMEM_HOSTID env var.
|
||||
*/
|
||||
uint64_t computeHostHash(void) {
|
||||
const size_t hashLen = 1024;
|
||||
char hostHash[hashLen];
|
||||
|
||||
memset(hostHash, 0, hashLen);
|
||||
|
||||
std::string hostName = getHostName(hashLen, '\0');
|
||||
strncpy(hostHash, hostName.c_str(), hostName.size());
|
||||
|
||||
std::string hostid = env()->hostid;
|
||||
if (hostid != "") {
|
||||
strncpy(hostHash, hostid.c_str(), hashLen);
|
||||
} else if (hostName.size() < hashLen) {
|
||||
std::ifstream file(HOSTID_FILE, std::ios::binary);
|
||||
if (file.is_open()) {
|
||||
file.read(hostHash + hostName.size(), hashLen - hostName.size());
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the string is terminated
|
||||
hostHash[sizeof(hostHash) - 1] = '\0';
|
||||
TRACE("unique hostname '%s'", hostHash);
|
||||
return getHash(hostHash, strlen(hostHash));
|
||||
}
|
||||
|
||||
uint64_t getHostHash(void) {
|
||||
thread_local std::unique_ptr<uint64_t> hostHash = std::make_unique<uint64_t>(computeHostHash());
|
||||
// avoid crash on static destruction
|
||||
if (hostHash == nullptr) {
|
||||
hostHash = std::make_unique<uint64_t>(computeHostHash());
|
||||
}
|
||||
return *hostHash;
|
||||
}
|
||||
|
||||
/* Generate a hash of the unique identifying string for this process
|
||||
* that will be unique for both bare-metal and container instances
|
||||
* Equivalent of a hash of;
|
||||
*
|
||||
* $$ $(readlink /proc/self/ns/pid)
|
||||
*/
|
||||
uint64_t computePidHash(void) {
|
||||
char pname[1024];
|
||||
// Start off with our pid ($$)
|
||||
std::snprintf(pname, sizeof(pname), "%ld", (long)getpid());
|
||||
int plen = strlen(pname);
|
||||
int len = readlink("/proc/self/ns/pid", pname + plen, sizeof(pname) - 1 - plen);
|
||||
if (len < 0) len = 0;
|
||||
|
||||
pname[plen + len] = '\0';
|
||||
TRACE("unique PID '%s'", pname);
|
||||
|
||||
return getHash(pname, strlen(pname));
|
||||
}
|
||||
|
||||
uint64_t getPidHash(void) {
|
||||
thread_local std::unique_ptr<uint64_t> pidHash = std::make_unique<uint64_t>(computePidHash());
|
||||
// avoid crash on static destruction
|
||||
if (pidHash == nullptr) {
|
||||
pidHash = std::make_unique<uint64_t>(computePidHash());
|
||||
}
|
||||
return *pidHash;
|
||||
}
|
||||
|
||||
int parseStringList(const char* string, netIf* ifList, int maxList) {
|
||||
if (!string) return 0;
|
||||
|
||||
const char* ptr = string;
|
||||
|
||||
int ifNum = 0;
|
||||
int ifC = 0;
|
||||
char c;
|
||||
do {
|
||||
c = *ptr;
|
||||
if (c == ':') {
|
||||
if (ifC > 0) {
|
||||
ifList[ifNum].prefix[ifC] = '\0';
|
||||
ifList[ifNum].port = atoi(ptr + 1);
|
||||
ifNum++;
|
||||
ifC = 0;
|
||||
}
|
||||
while (c != ',' && c != '\0') c = *(++ptr);
|
||||
} else if (c == ',' || c == '\0') {
|
||||
if (ifC > 0) {
|
||||
ifList[ifNum].prefix[ifC] = '\0';
|
||||
ifList[ifNum].port = -1;
|
||||
ifNum++;
|
||||
ifC = 0;
|
||||
}
|
||||
} else {
|
||||
ifList[ifNum].prefix[ifC] = c;
|
||||
ifC++;
|
||||
}
|
||||
ptr++;
|
||||
} while (ifNum < maxList && c);
|
||||
return ifNum;
|
||||
}
|
||||
|
||||
bool matchIfList(const char* string, int port, netIf* ifList, int listSize, bool matchExact) {
|
||||
// Make an exception for the case where no user list is defined
|
||||
if (listSize == 0) return true;
|
||||
|
||||
for (int i = 0; i < listSize; i++) {
|
||||
if (matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* get any bytes of random data from /dev/urandom */
|
||||
void getRandomData(void* buffer, size_t bytes) {
|
||||
if (bytes > 0) {
|
||||
const size_t one = 1UL;
|
||||
FILE* fp = fopen("/dev/urandom", "r");
|
||||
if (buffer == NULL || fp == NULL || fread(buffer, bytes, one, fp) != one) {
|
||||
ERROR("Failed to read random data\n");
|
||||
return;
|
||||
}
|
||||
if (fp) fclose(fp);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
// SIGALRM handler installed by Timer::set(): reports the timeout on stderr
// and aborts the process.
//
// Only async-signal-safe calls (signal, write, abort) are used here; stdio
// functions such as fprintf must not be called from a signal handler.
static void sigalrmTimeoutHandler(int) {
  signal(SIGALRM, SIG_IGN);
  static const char kMessage[] = "Timer timed out\n";
  // Best effort: there is nothing sensible to do if the write fails, the
  // process is terminated right below anyway.
  const ssize_t written = write(STDERR_FILENO, kMessage, sizeof(kMessage) - 1);
  (void)written;
  abort();
}
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
// Starts the timer; see set() for the meaning of `timeout`.
Timer::Timer(int timeout) {
  set(timeout);
}
|
||||
|
||||
// Cancels a pending alarm and restores the default SIGALRM disposition, but
// only if this timer armed one (timeout_ > 0).
Timer::~Timer() {
  if (timeout_ <= 0) {
    return;
  }
  alarm(0);
  signal(SIGALRM, SIG_DFL);
}
|
||||
|
||||
int64_t Timer::elapsed() const {
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(end - start_).count();
|
||||
}
|
||||
|
||||
void Timer::set(int timeout) {
|
||||
timeout_ = timeout;
|
||||
if (timeout > 0) {
|
||||
signal(SIGALRM, sigalrmTimeoutHandler);
|
||||
alarm(timeout);
|
||||
}
|
||||
start_ = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
// Restarts the timer with the timeout it was last set to.
void Timer::reset() {
  set(timeout_);
}
|
||||
|
||||
void Timer::print(const std::string& name) {
|
||||
auto us = elapsed();
|
||||
printf("%s : %ld\n", name.c_str(), us);
|
||||
}
|
||||
|
||||
// Remembers the label that the destructor prints; the Timer base is
// default-constructed.
ScopedTimer::ScopedTimer(const std::string& name)
    : name_(name) {
}
|
||||
|
||||
// Prints the elapsed time under the stored label when the scope ends.
ScopedTimer::~ScopedTimer() {
  print(name_);
}
|
||||
|
||||
std::string getHostName(int maxlen, const char delim) {
|
||||
std::string hostname(maxlen + 1, '\0');
|
||||
if (gethostname(const_cast<char*>(hostname.data()), maxlen) != 0) {
|
||||
ERROR("gethostname failed\n");
|
||||
return nullptr;
|
||||
}
|
||||
int i = 0;
|
||||
while ((hostname[i] != delim) && (hostname[i] != '\0') &&
|
||||
(i < maxlen - 1)) i++;
|
||||
hostname[i] = '\0';
|
||||
return hostname.substr(0, i);
|
||||
}
|
||||
|
||||
} // namespace rocshmem
|
||||
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
// Modifications Copyright (c) Microsoft Corporation.
|
||||
// Modifications Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef ROCSHMEM_UTILS_HPP_
|
||||
#define ROCSHMEM_UTILS_HPP_
|
||||
|
||||
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <string>
#include <utility>
|
||||
|
||||
// Prints a printf-style message to stderr and aborts the process.
// Wrapped in do { } while (0) so that `ERROR(...);` behaves like a single
// statement, e.g. inside an un-braced if/else.
#define ERROR(...)                \
  do {                            \
    fprintf(stderr, __VA_ARGS__); \
    abort();                      \
  } while (0)
|
||||
|
||||
// TRACE(fmt, ...): printf-style tracing, compiled in only when
// ROCSHMEM_ENABLE_TRACE is defined; otherwise it expands to nothing.
#if defined(ROCSHMEM_ENABLE_TRACE)
#define TRACE(...) printf(__VA_ARGS__)
#else
#define TRACE(...)
#endif

// INFO(flags, fmt, ...): printf-style info output, compiled in only when
// ROCSHMEM_ENABLE_INFO is defined. The first argument is ignored.
#if defined(ROCSHMEM_ENABLE_INFO)
#define INFO(FLAGS, ...) printf(__VA_ARGS__)
#else
#define INFO(...)
#endif
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
/// Stopwatch on std::chrono::steady_clock with an optional SIGALRM-based
/// timeout.
struct Timer {
  /// Time of the last set()/reset().
  std::chrono::steady_clock::time_point start_;
  /// Timeout passed to set(); an alarm is only armed for values > 0.
  int timeout_;

  /// Starts the timer (equivalent to set(timeout)).
  Timer(int timeout = -1);

  /// Cancels the alarm if one was armed.
  ~Timer();

  /// Returns the elapsed time in microseconds.
  int64_t elapsed() const;

  /// Restarts the timer and, for timeout > 0, arms an alarm.
  void set(int timeout);

  /// Restarts the timer with the current timeout.
  void reset();

  /// Prints "<name> : <elapsed microseconds>" to stdout.
  void print(const std::string& name);
};
|
||||
|
||||
/// Timer that prints its elapsed time, labelled with `name_`, when it is
/// destroyed.
struct ScopedTimer : public Timer {
  /// Label used by the destructor's print().
  const std::string name_;

  ScopedTimer(const std::string& name);

  ~ScopedTimer();
};
|
||||
|
||||
/// Host name from gethostname(), truncated at the first `delim`.
std::string getHostName(int maxlen, const char delim);

// PCI Bus ID <-> int64 conversion functions
/// Formats a packed bus id as "dddd:bb:dd.f".
std::string int64ToBusId(int64_t id);
/// Parses the hex digits of a bus id string, skipping ':' and '.'.
int64_t busIdToInt64(const std::string busId);

/// DJB2a-style hash over the first n bytes of `string`.
uint64_t getHash(const char* string, int n);
/// Hash identifying this host (cached per thread).
uint64_t getHostHash();
/// Hash identifying this process (cached per thread).
uint64_t getPidHash();
/// Fills `buffer` with `bytes` bytes from /dev/urandom.
void getRandomData(void* buffer, size_t bytes);
|
||||
|
||||
/// One entry of a parsed interface list (see parseStringList).
struct netIf {
  char prefix[64];  ///< NUL-terminated interface name or name prefix
  int port;         ///< port number, or -1 when none was given
};
|
||||
|
||||
/// Parses a comma-separated "<prefix>[:<port>]" list into ifList (at most
/// maxList entries) and returns the number of entries written.
int parseStringList(const char* string, struct netIf* ifList, int maxList);

/// True if (string, port) matches any of the listSize entries in ifList;
/// an empty list matches everything.
bool matchIfList(const char* string,
                 int port,
                 struct netIf* ifList,
                 int listSize,
                 bool matchExact);
|
||||
|
||||
/// Mixes the std::hash of `v` into the running hash value `hash`.
template <class T>
inline void hashCombine(std::size_t& hash, const T& v) {
  const std::size_t valueHash = std::hash<T>{}(v);
  const std::size_t mixed = valueHash + 0x9e3779b9 + (hash << 6) + (hash >> 2);
  hash ^= mixed;
}
|
||||
|
||||
/// Hash functor for std::pair: combines the hashes of `first` and `second`
/// (in that order) with hashCombine(), starting from 0.
struct PairHash {
 public:
  template <typename T, typename U>
  std::size_t operator()(const std::pair<T, U>& x) const {
    std::size_t seed = 0;
    hashCombine(seed, x.first);
    hashCombine(seed, x.second);
    return seed;
  }
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
#endif  // ROCSHMEM_UTILS_HPP_
|
||||
+65
-29
@@ -49,6 +49,7 @@
|
||||
#include "team.hpp"
|
||||
#include "templates_host.hpp"
|
||||
#include "util.hpp"
|
||||
#include "bootstrap/bootstrap.hpp"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
@@ -102,7 +103,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
[[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int flags,
|
||||
rocshmem_init_attr_t *attr) {
|
||||
MPI_Comm comm = MPI_COMM_WORLD;
|
||||
MPI_Comm comm = MPI_COMM_NULL;
|
||||
|
||||
if ((attr == nullptr) ||
|
||||
((flags != ROCSHMEM_INIT_WITH_UNIQUEID) &&
|
||||
@@ -115,22 +116,69 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
|
||||
if (flags == ROCSHMEM_INIT_WITH_MPI_COMM) {
|
||||
comm = *(static_cast<MPI_Comm*>(attr->mpi_comm));
|
||||
library_init(comm);
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
// As of right now, we require initialization through the MPI library.
|
||||
library_init(comm);
|
||||
|
||||
// The unique Id can be used to verify that the processes participating matches
|
||||
// (i.e. they all need to have the same unique Id, as well as the number of ranks.
|
||||
if (flags == ROCSHMEM_INIT_WITH_UNIQUEID) {
|
||||
int worldsize = backend->getNumPEs();
|
||||
if (worldsize != attr->nranks) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
"Call 'rocshmem_init_attr: mismatch between world-team size and "
|
||||
"attribute value'", __FILE__, __LINE__);
|
||||
// This is a fatal error, a fundamental mismatch between what was requested
|
||||
// and what we have.
|
||||
abort();
|
||||
int initialized;
|
||||
int world_size = -1;
|
||||
MPI_Initialized(&initialized);
|
||||
|
||||
if (!initialized) {
|
||||
// This is an Open MPI specific solution to retrieve the number of
|
||||
// processes that have been started, value can be checked before MPI_Init
|
||||
char *value = getenv("OMPI_COMM_WORLD_SIZE");
|
||||
if (value != NULL) {
|
||||
world_size = atoi(value);
|
||||
}
|
||||
if (world_size != attr->nranks) {
|
||||
// This solution will require MPI_Sessions. This is planned for the
|
||||
// future, but is not supported in the current version.
|
||||
fprintf (stderr, "Unsupported configuration to initialize rocSHMEM. Please "
|
||||
"initialize the MPI library using MPI_Init first, if you want to "
|
||||
"initialize rocSHMEM with a subset of the processes\n");
|
||||
abort();
|
||||
}
|
||||
} else {
|
||||
MPI_Comm_size (MPI_COMM_WORLD, &world_size);
|
||||
}
|
||||
|
||||
if (world_size == attr->nranks) {
|
||||
library_init(MPI_COMM_WORLD);
|
||||
} else {
|
||||
MPI_Group world_group;
|
||||
int world_rank;
|
||||
|
||||
MPI_Comm_rank (MPI_COMM_WORLD, &world_rank);
|
||||
MPI_Comm_group (MPI_COMM_WORLD, &world_group);
|
||||
|
||||
TcpBootstrap bootstr(attr->rank, attr->nranks);
|
||||
|
||||
int timeout = 5;
|
||||
char *value;
|
||||
value = getenv("ROCSHMEM_BOOTSTRAP_TIMEOUT");
|
||||
if (value != nullptr) {
|
||||
timeout = atoi(value);
|
||||
}
|
||||
|
||||
bootstr.initialize(attr->uid, timeout);
|
||||
int *inc_ranks = new int[attr->nranks];
|
||||
inc_ranks[attr->rank] = world_rank;
|
||||
|
||||
bootstr.allGather (inc_ranks, sizeof(int));
|
||||
|
||||
MPI_Group sub_group;
|
||||
MPI_Comm sub_comm;
|
||||
MPI_Group_incl (world_group, attr->nranks, inc_ranks, &sub_group);
|
||||
MPI_Comm_create_group (MPI_COMM_WORLD, sub_group, 1234, &sub_comm);
|
||||
|
||||
library_init(sub_comm);
|
||||
|
||||
MPI_Group_free (&sub_group);
|
||||
MPI_Group_free (&world_group);
|
||||
MPI_Comm_free (&sub_comm);
|
||||
delete[] inc_ranks;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +206,7 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
// Note: this function will be called before rocshmem_init_*, so one
|
||||
// cannot assume that a backend is already set
|
||||
[[maybe_unused]] __host__ int rocshmem_get_uniqueid(rocshmem_uniqueid_t *uid) {
|
||||
rocshmem_uniqueid_t tuid;
|
||||
if (uid == nullptr) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
"Call 'rocshmem_get_uniqueid: invalid input argument'",
|
||||
@@ -165,21 +214,8 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
return ROCSHMEM_ERROR;
|
||||
}
|
||||
|
||||
std::random_device dev;
|
||||
std::mt19937_64 rng(dev());
|
||||
std::uniform_int_distribution<uint64_t> dist(0, std::numeric_limits<uint64_t>::max());
|
||||
|
||||
char hostname[HOST_NAME_MAX+1];
|
||||
if (0 != gethostname(hostname, HOST_NAME_MAX)) {
|
||||
fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n",
|
||||
"Call 'rocshmem_get_uniqueid: could not get hostname'",
|
||||
__FILE__, __LINE__);
|
||||
return ROCSHMEM_ERROR;
|
||||
}
|
||||
|
||||
uid->random = dist(rng);
|
||||
std::memcpy(uid->hostname, hostname, ROCSHMEM_HOSTNAME_LEN);
|
||||
uid->pid = static_cast<uint32_t>(getpid());
|
||||
tuid = TcpBootstrap::createUniqueId();
|
||||
*uid = tuid;
|
||||
|
||||
return ROCSHMEM_SUCCESS;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user