diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index da147f1c7c..74acd45ce0 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -29,6 +29,7 @@ set(EXAMPLE_SOURCES rocshmem_broadcast_test.cc rocshmem_getmem_test.cc rocshmem_put_signal_test.cc + rocshmem_init_attr_test.cc ) foreach(SOURCE_FILE IN LISTS EXAMPLE_SOURCES) diff --git a/examples/rocshmem_init_attr_test.cc b/examples/rocshmem_init_attr_test.cc new file mode 100644 index 0000000000..d57d37f75a --- /dev/null +++ b/examples/rocshmem_init_attr_test.cc @@ -0,0 +1,68 @@ +/* +hipcc -c -fgpu-rdc -x hip rocshmem_init_attr_test.cc \ + -I/opt/rocm/include \ + -I$ROCSHMEM_INSTALL_DIR/include \ + -I$OPENMPI_UCX_INSTALL_DIR/include/ + +hipcc -fgpu-rdc --hip-link rocshmem_init_attr_test.o -o rocshmem_init_attr_test \ + $ROCSHMEM_INSTALL_DIR/lib/librocshmem.a \ + $OPENMPI_UCX_INSTALL_DIR/lib/libmpi.so \ + -L/opt/rocm/lib -lamdhip64 -lhsa-runtime64 + +ROCSHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 ./rocshmem_init_attr_test +*/ + +#include + +#include +#include +#include + +#define CHECK_HIP(condition) { \ + hipError_t error = condition; \ + if(error != hipSuccess){ \ + fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ + MPI_Abort(MPI_COMM_WORLD, error); \ + } \ + } + +using namespace rocshmem; + +int main (int argc, char **argv) +{ + int rank, nranks; + int ret; + rocshmem_uniqueid_t uid; + rocshmem_init_attr_t attr; + + MPI_Init(&argc, &argv); + MPI_Comm_rank (MPI_COMM_WORLD, &rank); + MPI_Comm_size (MPI_COMM_WORLD, &nranks); + + if (rank == 0) { + ret = rocshmem_get_uniqueid (&uid); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_get_uniqueid. Aborting.\n"; + MPI_Abort (MPI_COMM_WORLD, ret); + } + } + + MPI_Bcast (&uid, sizeof(rocshmem_uniqueid_t), MPI_BYTE, 0, MPI_COMM_WORLD); + ret = rocshmem_set_attr_uniqueid_args(rank, nranks, &uid, &attr); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_set_attr_uniqueid_args. 
Aborting.\n"; + MPI_Abort (MPI_COMM_WORLD, ret); + } + + ret = rocshmem_init_attr(ROCSHMEM_INIT_WITH_UNIQUEID, &attr); + if (ret != ROCSHMEM_SUCCESS) { + std::cout << rank << ": Error in rocshmem_init_attr. Aborting.\n"; + MPI_Abort (MPI_COMM_WORLD, ret); + } + + std::cout << rank << ": rocshmem_init_attr SUCCESS\n"; + + rocshmem_finalize(); + MPI_Finalize(); + return 0; +} diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index d025c6b0d7..b94ab11613 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -71,10 +71,50 @@ __host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD); * to requested thread mode. * @param[in] comm (Optional) MPI Communicator that rocSHMEM will be using * If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD + * + * @return int returns 0 upon success; otherwise, it returns a nonzero + * value */ -__host__ void rocshmem_init_thread(int requested, int *provided, - MPI_Comm comm = MPI_COMM_WORLD); +__host__ int rocshmem_init_thread(int requested, int *provided, + MPI_Comm comm = MPI_COMM_WORLD); +/** + * @brief Initialize the rocSHMEM runtime and underlying transport layer + * using the provided mode and attributes + * + * @param[in] flags initialization method to be used. + * Valid values are ROCSHMEM_INIT_WITH_UNIQUEID and + * ROCSHMEM_INIT_WITH_MPI_COMM + * @param[in] attr attribute structure specifying input characteristics + * + * @return int returns 0 upon success; otherwise, it returns a nonzero + * value + */ +__host__ int rocshmem_init_attr(unsigned int flags, rocshmem_init_attr_t *attr); + +/** + * @brief Generate a unique ID and store it in the object pointed to by uid + * + * @return int returns 0 upon success; otherwise, it returns a nonzero + * value + */ +__host__ int rocshmem_get_uniqueid(rocshmem_uniqueid_t *uid); + +/** + * @brief Fill an attribute structure for unique-ID based initialization.
+ * + * @param[in] rank rank of the calling process + * @param[in] nranks number of pes + * @param[in] uid unique ID used to identify the group of processes. + * All processes that initialize together pass the same ID + * @param[out] attr attribute structure to be passed to rocshmem_init_attr + * + * @return int returns 0 upon success; otherwise, it returns a nonzero + * value + */ +__host__ int rocshmem_set_attr_uniqueid_args(int rank, int nranks, + rocshmem_uniqueid_t *uid, + rocshmem_init_attr_t *attr); /** * @brief Query the thread mode used by the runtime. * diff --git a/include/rocshmem/rocshmem_common.hpp b/include/rocshmem/rocshmem_common.hpp index baea438244..e2cd646e4e 100644 --- a/include/rocshmem/rocshmem_common.hpp +++ b/include/rocshmem/rocshmem_common.hpp @@ -125,6 +125,32 @@ extern rocshmem_team_t ROCSHMEM_TEAM_WORLD; const rocshmem_team_t ROCSHMEM_TEAM_INVALID = nullptr; +/** + * @brief Data structure defining the uniqueId + */ +constexpr int ROCSHMEM_HOSTNAME_LEN = 20; +struct rocshmem_uniqueid_t { + uint64_t random; + char hostname[ROCSHMEM_HOSTNAME_LEN]; + uint32_t pid; +}; +typedef struct rocshmem_uniqueid_t rocshmem_unique_id_t; + +/** + * @brief Data structure used for attribute based + * initialization + */ +struct rocshmem_init_attr_t { + int32_t rank; + int32_t nranks; + rocshmem_uniqueid_t uid; + void* mpi_comm; +}; +typedef struct rocshmem_init_attr_t rocshmem_init_attr_t; + +constexpr unsigned int ROCSHMEM_INIT_WITH_MPI_COMM = 0; +constexpr unsigned int ROCSHMEM_INIT_WITH_UNIQUEID = 1; + } // namespace rocshmem #endif // LIBRARY_INCLUDE_ROCSHMEM_COMMON_HPP diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index 931e4c83a0..9d2844cbc3 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -32,7 +32,9 @@ #include "rocshmem/rocshmem.hpp" #include +#include #include +#include #include "backend_bc.hpp" #include "context_incl.hpp" @@ -48,6 +50,8 @@ #include "templates_host.hpp" #include "util.hpp" +#include + namespace rocshmem { #define VERIFY_BACKEND() \ @@ -96,14 +100,100 @@
rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; } } +[[maybe_unused]] __host__ int rocshmem_init_attr(unsigned int flags, + rocshmem_init_attr_t *attr) { + MPI_Comm comm = MPI_COMM_WORLD; + + if ((attr == nullptr) || + ((flags != ROCSHMEM_INIT_WITH_UNIQUEID) && + ((flags != ROCSHMEM_INIT_WITH_MPI_COMM) || (attr->mpi_comm == nullptr)))) { + fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", + "Call 'rocshmem_init_attr: invalid input argument'", + __FILE__, __LINE__); + return ROCSHMEM_ERROR; + } + + if (flags == ROCSHMEM_INIT_WITH_MPI_COMM) { + comm = *(static_cast(attr->mpi_comm)); + } + + // As of right now, we require initialization through the MPI library. + library_init(comm); + + // The unique Id can be used to verify that the participating processes match + // (i.e. they all need to have the same unique Id and the same number of ranks). + if (flags == ROCSHMEM_INIT_WITH_UNIQUEID) { + int worldsize = backend->getNumPEs(); + if (worldsize != attr->nranks) { + fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", + "Call 'rocshmem_init_attr: mismatch between world-team size and " + "attribute value'", __FILE__, __LINE__); + // This is a fatal error, a fundamental mismatch between what was requested + // and what we have.
+ abort(); + } + } + + return ROCSHMEM_SUCCESS; +} + +[[maybe_unused]] __host__ int rocshmem_set_attr_uniqueid_args(int rank, int nranks, + rocshmem_uniqueid_t *uid, + rocshmem_init_attr_t *attr) { + if (uid == nullptr || attr == nullptr) { + fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", + "Call 'rocshmem_set_attr_uniqueid_args: invalid input argument'", + __FILE__, __LINE__); + return ROCSHMEM_ERROR; + } + + attr->rank = rank; + attr->nranks = nranks; + attr->uid = *uid; + attr->mpi_comm = nullptr; + + return ROCSHMEM_SUCCESS; +} + +// Note: this function will be called before rocshmem_init_*, so one +// cannot assume that a backend is already set +[[maybe_unused]] __host__ int rocshmem_get_uniqueid(rocshmem_uniqueid_t *uid) { + if (uid == nullptr) { + fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", + "Call 'rocshmem_get_uniqueid: invalid input argument'", + __FILE__, __LINE__); + return ROCSHMEM_ERROR; + } + + std::random_device dev; + std::mt19937_64 rng(dev()); + std::uniform_int_distribution dist(0, std::numeric_limits::max()); + + char hostname[HOST_NAME_MAX+1] = {};  // zero-fill bytes past the name + if (0 != gethostname(hostname, HOST_NAME_MAX)) { + fprintf(stderr, "ROCSHMEM_ERROR: %s in file '%s' in line %d\n", + "Call 'rocshmem_get_uniqueid: could not get hostname'", + __FILE__, __LINE__); + return ROCSHMEM_ERROR; + } + + uid->random = dist(rng); + std::memcpy(uid->hostname, hostname, ROCSHMEM_HOSTNAME_LEN); + uid->pid = static_cast(getpid()); + + return ROCSHMEM_SUCCESS; +} + [[maybe_unused]] __host__ void rocshmem_init(MPI_Comm comm) { library_init(comm); } -[[maybe_unused]] __host__ void rocshmem_init_thread( +[[maybe_unused]] __host__ int rocshmem_init_thread( [[maybe_unused]] int required, int *provided, MPI_Comm comm) { library_init(comm); rocshmem_query_thread(provided); + + return ROCSHMEM_SUCCESS; } [[maybe_unused]] __host__ int rocshmem_my_pe() {