From d0c2845031e1cf11f2742c09297448abcd6f007b Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Wed, 22 Oct 2025 16:04:58 -0500 Subject: [PATCH] add support for GPUs using wavefront size of 32 (#285) * add gfx1100 support Add support for Radeon 7900 GPUs (RX and PRO), and 7800 PRO. I was contemplating to add gfx1101 and gfx1102 GPUs as well, but those are the lower end models that are more unlikely to be used for compute intensive jobs. In addition, I do not have access to them to test the support. * update WF_SIZe for different options Radeon systems use a WarpSize of 32, unlike current Instinct systems, which use a warp size of 64. For the device side, a gfx specific ifdef is sufficient. For the host side, we need to query the device properties. * adjust functional tests to wf_size of 32 * update unit tests to handle wf_size of 32 * address reviewer comments --- CMakeLists.txt | 1 + scripts/functional_tests/driver.sh | 19 ++++++-- src/assembly.hpp | 16 +++---- src/constants.hpp | 4 ++ src/gda/queue_pair.cpp | 6 ++- src/reverse_offload/backend_ro.cpp | 7 +-- src/util.cpp | 17 +++++++ src/util.hpp | 20 ++++++++ .../default_ctx_primitive_tester.cpp | 9 ++-- tests/functional_tests/primitive_tester.cpp | 8 ++-- tests/functional_tests/shmem_ptr_tester.cpp | 8 ++-- .../team_ctx_primitive_tester.cpp | 8 ++-- tests/unit_tests/bitwise_gtest.cpp | 14 ++---- tests/unit_tests/bitwise_gtest.hpp | 3 ++ tests/unit_tests/free_list_gtest.cpp | 10 ++-- tests/unit_tests/free_list_gtest.hpp | 3 ++ tests/unit_tests/wavefront_size_gtest.cpp | 38 +++++++++++---- tests/unit_tests/wavefront_size_gtest.hpp | 11 ++++- tests/unit_tests/wf_size.hpp | 46 +++++++++++++++++++ 19 files changed, 192 insertions(+), 56 deletions(-) create mode 100644 tests/unit_tests/wf_size.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 35af099c3a..eb89994ece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,6 +108,7 @@ include(cmake/rocm_local_targets.cmake) set(DEFAULT_GPUS gfx90a:xnack-; gfx90a:xnack+; + gfx1100; gfx942) if(${ROCM_MAJOR_VERSION} GREATER 6) diff --git a/scripts/functional_tests/driver.sh b/scripts/functional_tests/driver.sh index 9e91fb438c..d67694d636 100755 --- a/scripts/functional_tests/driver.sh +++ b/scripts/functional_tests/driver.sh @@ -118,6 +118,15 @@ ExecTest() { NUM_THREADS=$4 MAX_MSG_SIZE=$5 + if command -v amd-smi >/dev/null + then + NUM_GPUS=$(amd-smi list | grep GPU | wc -l) + elif command -v rocm-smi >/dev/null + then + NUM_GPUS=$(rocm-smi --showserial | grep GPU | wc -l) + else + NUM_GPUS=64 + fi TIMEOUT=$((5 * 60)) # Timeout in seconds TEST_NUM=${TEST_NUMBERS[$TEST_NAME]} @@ -159,9 +168,13 @@ ExecTest() { CMD+=" >> $LOG_DIR/$TEST_LOG_NAME.log 2>&1" # Run Test - echo $TEST_LOG_NAME - echo "# $CMD" >"$LOG_DIR/$TEST_LOG_NAME.log" - eval $CMD + if [ $NUM_RANKS -le $NUM_GPUS ] && [[ "" == "$HOSTFILE" ]]; then + echo $TEST_LOG_NAME + echo "# $CMD" >"$LOG_DIR/$TEST_LOG_NAME.log" + eval $CMD + else + echo "Skipping test $TEST_LOG_NAME" + fi # Validate Test if [ $? -ne 0 ] diff --git a/src/assembly.hpp b/src/assembly.hpp index 67827f3e9d..2dd5e8dea0 100644 --- a/src/assembly.hpp +++ b/src/assembly.hpp @@ -46,7 +46,7 @@ __device__ __forceinline__ int uncached_load_ubyte(uint8_t* src) { #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile( "global_load_ubyte %0 %1 off glc slc \n" "s_waitcnt vmcnt(0)" @@ -69,7 +69,7 @@ __device__ __forceinline__ void refresh_volatile_sbyte(volatile int *assigned_va #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile( "global_load_sbyte %0 %1 off glc slc\n " "s_waitcnt vmcnt(0)" @@ -91,7 +91,7 @@ __device__ __forceinline__ void refresh_volatile_dwordx2(volatile uint64_t *assi #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile( "global_load_dwordx2 %0 %1 off glc slc\n " "s_waitcnt vmcnt(0)" @@ -122,7 +122,7 @@ NOWARN(-Wdeprecated-volatile, #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile( "global_load_dword %0 %1 off glc slc \n" "s_waitcnt vmcnt(0)" @@ -142,7 +142,7 @@ NOWARN(-Wdeprecated-volatile, #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile( "global_load_dwordx2 %0 %1 off glc slc \n" "s_waitcnt vmcnt(0)" @@ -191,7 +191,7 @@ __device__ __forceinline__ void store_asm(uint8_t* val, uint8_t* dst, #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile("flat_store_short %0 %1 glc slc" : : "v"(dst), "v"(val16)); #endif #if defined(__gfx942__) || defined(__gfx950__) @@ -205,7 +205,7 @@ __device__ __forceinline__ void store_asm(uint8_t* val, uint8_t* dst, #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile("flat_store_dword %0 %1 glc slc" : : "v"(dst), "v"(val32)); #endif #if defined(__gfx942__) || defined(__gfx950__) @@ -219,7 +219,7 @@ __device__ __forceinline__ void store_asm(uint8_t* val, uint8_t* dst, #endif #if defined(__gfx908__) #endif -#if defined(__gfx90a__) +#if defined(__gfx90a__) || defined (__gfx1100__) asm volatile("flat_store_dwordx2 %0 %1 glc slc" : : "v"(dst), "v"(val64)); #endif #if defined(__gfx942__) || defined(__gfx950__) diff --git a/src/constants.hpp b/src/constants.hpp index e3c796c4a7..0a3b54de98 100644 --- a/src/constants.hpp +++ b/src/constants.hpp @@ -57,7 +57,11 @@ inline const unsigned MAX_WG_SIZE{1024}; * * @note Wavefront size on most systems is either 32 or 64. */ +#if defined(__gfx90a__) || defined(__gfx942__) || defined (__gfx950__) inline const int WF_SIZE{64}; +#else +inline const int WF_SIZE{32}; +#endif } // namespace rocshmem diff --git a/src/gda/queue_pair.cpp b/src/gda/queue_pair.cpp index e8c670b90b..c268ae99e6 100644 --- a/src/gda/queue_pair.cpp +++ b/src/gda/queue_pair.cpp @@ -28,6 +28,7 @@ #include "backend_gda.hpp" #include "constants.hpp" +#include "util.hpp" namespace rocshmem { @@ -59,7 +60,10 @@ QueuePair::QueuePair(struct ibv_pd* pd, int gda_provider) { fetching_atomic_lkey = mr_fetching_atomic->lkey; } - for(int i{0}; i < FETCHING_ATOMIC_CNT; i+=WF_SIZE) { + int deviceId; + CHECK_HIP(hipGetDevice(&deviceId)); + int wf_size = get_wf_size(deviceId); + for(int i{0}; i < FETCHING_ATOMIC_CNT; i+=wf_size) { fetching_atomic_freelist->push_back(fetching_atomic + i); } diff --git a/src/reverse_offload/backend_ro.cpp b/src/reverse_offload/backend_ro.cpp index 96ff7c49d4..769c988db4 100644 --- a/src/reverse_offload/backend_ro.cpp +++ b/src/reverse_offload/backend_ro.cpp @@ -56,14 +56,11 @@ ROBackend::ROBackend(MPI_Comm comm) profiler_proxy_ = ProfilerProxyT(envvar::max_num_contexts); int device_id; - hipDeviceProp_t device_props; - CHECK_HIP(hipGetDevice(&device_id)); - CHECK_HIP(hipGetDeviceProperties(&device_props, device_id)); - max_wg_size_ = device_props.maxThreadsPerBlock; + max_wg_size_ = get_threads_per_block(device_id); - wf_size_ = device_props.warpSize; + wf_size_ = get_wf_size(device_id); setup_default_ctx_buffers(); diff --git a/src/util.cpp b/src/util.cpp index 840e84d841..d3c7de0d1d 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -44,6 +44,21 @@ typedef struct device_agent { std::vector gpu_agents; std::vector cpu_agents; +std::vector device_properties; + +static void device_properties_init(void) { + int numDevices; + CHECK_HIP(hipGetDeviceCount(&numDevices)); + + device_prop_t prop; + hipDeviceProp_t hipprop; + for (int i=0; i #include +#include +#include #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) #include "constants.hpp" @@ -146,6 +148,24 @@ do { \ extern const int gpu_clock_freq_mhz; + +typedef struct device_prop { + int warpSize; + int maxThreadsPerBlock; +} device_prop_t; + +extern std::vector device_properties; + +static int get_threads_per_block(int device_id) { + assert(device_properties.size() > device_id); + return device_properties[device_id].maxThreadsPerBlock; +} + +static int get_wf_size(int device_id) { + assert(device_properties.size() > device_id); + return device_properties[device_id].warpSize; +} + /* Device-side internal functions */ __device__ __forceinline__ uint32_t lowerID() { return __ffsll(__ballot(1)) - 1; diff --git a/tests/functional_tests/default_ctx_primitive_tester.cpp b/tests/functional_tests/default_ctx_primitive_tester.cpp index f8b3ce4db5..ef72b4d8ba 100644 --- a/tests/functional_tests/default_ctx_primitive_tester.cpp +++ b/tests/functional_tests/default_ctx_primitive_tester.cpp @@ -42,10 +42,12 @@ /** * Shared array to capture the start time for each wavefront - * Max threads per block = 1024, wavefront size = 64 (in most GPUs) - * Maximum array size required = 1024/64 = 16 + * Max threads per block = 1024, wavefront size = 64 or 32 depending + * on the GPUs. Using 32 since its safer for the dimensioning of the array, + * the last 16 elements will not be used on GPUs with a wf size of 64. + * Maximum array size required = 1024/32 = 32 */ - __shared__ long long int wf_start_time[16]; + __shared__ long long int wf_start_time[32]; /** * Calculate start index for each thread within the grid @@ -190,4 +192,3 @@ } } } - \ No newline at end of file diff --git a/tests/functional_tests/primitive_tester.cpp b/tests/functional_tests/primitive_tester.cpp index 6cef5e39b4..1f8faa650c 100644 --- a/tests/functional_tests/primitive_tester.cpp +++ b/tests/functional_tests/primitive_tester.cpp @@ -43,10 +43,12 @@ __global__ void PrimitiveTest(int loop, int skip, long long int *start_time, /** * Shared array to capture the start time for each wavefront - * Max threads per block = 1024, wavefront size = 64 (in most GPUs) - * Maximum array size required = 1024/64 = 16 + * Max threads per block = 1024, wavefront size = 64 or 32 depending + * on the GPUs. Using 32 since its safer for the dimensioning of the array, + * the last 16 elements will not be used on GPUs with a wf size of 64. + * Maximum array size required = 1024/32 = 32 */ - __shared__ long long int wf_start_time[16]; + __shared__ long long int wf_start_time[32]; /** * Calculate start index for each thread within the grid diff --git a/tests/functional_tests/shmem_ptr_tester.cpp b/tests/functional_tests/shmem_ptr_tester.cpp index 158b6d0878..86608e37d0 100644 --- a/tests/functional_tests/shmem_ptr_tester.cpp +++ b/tests/functional_tests/shmem_ptr_tester.cpp @@ -43,10 +43,12 @@ __global__ void ShmemPtrTest(int loop, int skip, long long int *start_time, /** * Shared array to capture the start time for each wavefront - * Max threads per block = 1024, wavefront size = 64 (in most GPUs) - * Maximum array size required = 1024/64 = 16 + * Max threads per block = 1024, wavefront size = 64 or 32 depending + * on the GPUs. Using 32 since its safer for the dimensioning of the array, + * the last 16 elements will not be used on GPUs with a wf size of 64. + * Maximum array size required = 1024/32 = 32 */ - __shared__ long long int wf_start_time[16]; + __shared__ long long int wf_start_time[32]; /** diff --git a/tests/functional_tests/team_ctx_primitive_tester.cpp b/tests/functional_tests/team_ctx_primitive_tester.cpp index 673a536d58..16faf58e28 100644 --- a/tests/functional_tests/team_ctx_primitive_tester.cpp +++ b/tests/functional_tests/team_ctx_primitive_tester.cpp @@ -47,10 +47,12 @@ __global__ void TeamCtxPrimitiveTest(int loop, int skip, long long int *start_ti /** * Shared array to capture the start time for each wavefront - * Max threads per block = 1024, wavefront size = 64 (in most GPUs) - * Maximum array size required = 1024/64 = 16 + * Max threads per block = 1024, wavefront size = 64 or 32 depending + * on the GPUs. Using 32 since its safer for the dimensioning of the array, + * the last 16 elements will not be used on GPUs with a wf size of 64. + * Maximum array size required = 1024/32 = 32 */ - __shared__ long long int wf_start_time[16]; + __shared__ long long int wf_start_time[32]; /** * Calculate start index for each thread within the grid diff --git a/tests/unit_tests/bitwise_gtest.cpp b/tests/unit_tests/bitwise_gtest.cpp index 23d30a59d9..822bcef25e 100644 --- a/tests/unit_tests/bitwise_gtest.cpp +++ b/tests/unit_tests/bitwise_gtest.cpp @@ -144,12 +144,6 @@ TEST_F(BitwiseTestFixture, verify_host_warp_matrix_init_1024_8) { verify_zeroed_warp_matrix(); } -TEST_F(BitwiseTestFixture, verify_warp_size_64) { - setup_fixture({1, 1, 1}, {1, 1, 1}); - - ASSERT_EQ(WF_SIZE, 64); -} - /***************************************************************************** ************************** Activate Lane Helper****************************** *****************************************************************************/ @@ -1402,7 +1396,7 @@ TEST_F(BitwiseTestFixture, fetch_incr_kernel_4_1) { for (size_t i = 0; i < _warp_matrix->rows(); i++) { for (size_t j = 0; j < _warp_matrix->columns(); j++) { auto *elem = _warp_matrix->access(i, j); - ASSERT_EQ(*elem % WF_SIZE, 0); + ASSERT_EQ(*elem % this->_wf_size, 0); } } } @@ -1426,7 +1420,7 @@ TEST_F(BitwiseTestFixture, fetch_incr_kernel_64_1) { for (size_t i = 0; i < _warp_matrix->rows(); i++) { for (size_t j = 0; j < _warp_matrix->columns(); j++) { auto *elem = _warp_matrix->access(i, j); - ASSERT_EQ(*elem % WF_SIZE, 0); + ASSERT_EQ(*elem % this->_wf_size, 0); } } } @@ -1498,7 +1492,7 @@ TEST_F(BitwiseTestFixture, fetch_incr_kernel_1024_1024) { for (size_t i = 0; i < _warp_matrix->rows(); i++) { for (size_t j = 0; j < _warp_matrix->columns(); j++) { auto *elem = _warp_matrix->access(i, j); - ASSERT_EQ(*elem % WF_SIZE, 0); + ASSERT_EQ(*elem % this->_wf_size, 0); } } } @@ -1522,7 +1516,7 @@ TEST_F(BitwiseTestFixture, fetch_incr_logical_1_kernel_1024_1024) { for (size_t i = 0; i < _warp_matrix->rows(); i++) { for (size_t j = 0; j < _warp_matrix->columns(); j++) { auto *elem = _warp_matrix->access(i, j); - ASSERT_EQ(*elem % WF_SIZE, 1); + ASSERT_EQ(*elem % this->_wf_size, 1); } } } diff --git a/tests/unit_tests/bitwise_gtest.hpp b/tests/unit_tests/bitwise_gtest.hpp index cb69b579a0..6ba17c894a 100644 --- a/tests/unit_tests/bitwise_gtest.hpp +++ b/tests/unit_tests/bitwise_gtest.hpp @@ -28,6 +28,7 @@ #define HIP_ENABLE_PRINTF #include "gtest/gtest.h" +#include "wf_size.hpp" #include "../src/memory/hip_allocator.hpp" #include "containers/matrix.hpp" @@ -259,6 +260,7 @@ class BitwiseTestFixture : public ::testing::Test { _hip_block_dim = block_dim; _hip_grid_dim = grid_dim; + _wf_size = get_wf_size(); assert(_device_methods == nullptr); _hip_allocator.allocate(reinterpret_cast(&_device_methods), @@ -341,6 +343,7 @@ class BitwiseTestFixture : public ::testing::Test HIPAllocator _hip_allocator {}; WarpMatrix *_warp_matrix = nullptr; BitwiseDeviceMethods *_device_methods = nullptr; + int _wf_size; }; } // namespace rocshmem diff --git a/tests/unit_tests/free_list_gtest.cpp b/tests/unit_tests/free_list_gtest.cpp index 2be4bd01b2..43454fa24e 100644 --- a/tests/unit_tests/free_list_gtest.cpp +++ b/tests/unit_tests/free_list_gtest.cpp @@ -108,7 +108,7 @@ TYPED_TEST(FreeListTestFixture, push_host_pop_device) { CHECK_HIP(hipMemset(results, 0, size_bytes)); is_empty = reinterpret_cast(results + h_input.size()); - const auto block_size = WF_SIZE; + const auto block_size = this->wf_size; rocshmem::pop_all<<<1, block_size>>>(free_list, results, h_input.size()); CHECK_HIP(hipDeviceSynchronize()); @@ -140,7 +140,7 @@ TYPED_TEST(FreeListTestFixture, push_host_concurrent_pop_device) { CHECK_HIP(hipMemset(results, 0, size_bytes)); is_empty = reinterpret_cast(results + h_input.size()); const auto num_blocks = h_input.size(); - const auto block_size = WF_SIZE; + const auto block_size = this->wf_size; rocshmem::pop_all<<>>( free_list, results, h_input.size()); CHECK_HIP(hipDeviceSynchronize()); @@ -184,7 +184,7 @@ TYPED_TEST(FreeListTestFixture, push_host_pop_push_device) { CHECK_HIP(hipMemset(results, 0, size_bytes)); d_input = reinterpret_cast(results + h_input.size()); is_empty = reinterpret_cast(d_input + h_input.size()); - const auto block_size = WF_SIZE; + const auto block_size = this->wf_size; CHECK_HIP(hipMemcpy(d_input, h_input.data(), sizeof(T) * h_input.size(), hipMemcpyHostToDevice)); @@ -223,7 +223,7 @@ TYPED_TEST(FreeListTestFixture, push_host_pop_concurrent_push_device) { CHECK_HIP(hipMemset(results, 0, size_bytes)); d_input = reinterpret_cast(results + h_input.size()); - const auto block_size = WF_SIZE; + const auto block_size = this->wf_size; CHECK_HIP(hipMemcpy(d_input, h_input.data(), sizeof(T) * h_input.size(), hipMemcpyHostToDevice)); @@ -277,7 +277,7 @@ TYPED_TEST(FreeListTestFixture, push_host_concurrent_pop_push_device) { CHECK_HIP(hipMemcpy(d_input, h_input.data(), sizeof(T) * h_input.size(), hipMemcpyHostToDevice)); - const auto block_size = WF_SIZE; + const auto block_size = this->wf_size; rocshmem::pop_all<<<1, block_size>>>( free_list, nullptr, h_input.size()); CHECK_HIP(hipDeviceSynchronize()); diff --git a/tests/unit_tests/free_list_gtest.hpp b/tests/unit_tests/free_list_gtest.hpp index 0541a4c434..b63ee3d1bc 100644 --- a/tests/unit_tests/free_list_gtest.hpp +++ b/tests/unit_tests/free_list_gtest.hpp @@ -31,6 +31,7 @@ #include "../src/containers/free_list_impl.hpp" #include "gtest/gtest.h" #include "../src/memory/hip_allocator.hpp" +#include "wf_size.hpp" namespace rocshmem { @@ -45,6 +46,7 @@ class FreeListTestFixture : public ::testing::Test { protected: void SetUp() override { free_list->push_back_range(h_input.begin(), h_input.end()); + wf_size = get_wf_size(); } using T = ValueType; @@ -52,6 +54,7 @@ class FreeListTestFixture : public ::testing::Test { Allocator hip_allocator_ {}; const std::size_t num_elements{32}; std::vector h_input{}; + int wf_size; FreeListProxy list_proxy{}; FreeList* free_list{}; diff --git a/tests/unit_tests/wavefront_size_gtest.cpp b/tests/unit_tests/wavefront_size_gtest.cpp index 96a6c35373..47c0fa2e8e 100644 --- a/tests/unit_tests/wavefront_size_gtest.cpp +++ b/tests/unit_tests/wavefront_size_gtest.cpp @@ -28,16 +28,34 @@ using namespace rocshmem; -TEST_F(WavefrontSizeTestFixture, constant_matches_runtime) { - int device_count = 0; - hipDeviceProp_t prop; - - CHECK_HIP(hipGetDeviceCount(&device_count)); - ASSERT_GT(device_count, 0); - - for (int i = 0; i < device_count; i++) { - CHECK_HIP(hipGetDeviceProperties(&prop, i)); - ASSERT_EQ(WF_SIZE, prop.warpSize); +__global__ void check_wf_size(int wf_size_prop, int *ret) { + if (wf_size_prop == WF_SIZE) { + *ret = 0; + } else { + *ret = 1; } } + +TEST_F(WavefrontSizeTestFixture, constant_matches_runtime) { + int device_count = 0; + hipDeviceProp_t prop; + int *ret; + + CHECK_HIP(hipGetDeviceCount(&device_count)); + ASSERT_GT(device_count, 0); + CHECK_HIP(hipHostMalloc(&ret, sizeof(int), 0)); + + for (int i = 0; i < device_count; i++) { + *ret = -1; + CHECK_HIP(hipSetDevice(i)); + CHECK_HIP(hipGetDeviceProperties(&prop, i)); + + check_wf_size<<<1, 1>>>(prop.warpSize, ret); + CHECK_HIP(hipDeviceSynchronize()); + + ASSERT_EQ(*ret, 0); + } + CHECK_HIP(hipHostFree(ret)); +} + diff --git a/tests/unit_tests/wavefront_size_gtest.hpp b/tests/unit_tests/wavefront_size_gtest.hpp index d083011a6d..431a542c57 100644 --- a/tests/unit_tests/wavefront_size_gtest.hpp +++ b/tests/unit_tests/wavefront_size_gtest.hpp @@ -26,10 +26,19 @@ #define ROCSHMEM_WAVEFRONT_SIZE_GTEST_HPP #include "gtest/gtest.h" +#include "wf_size.hpp" namespace rocshmem { -class WavefrontSizeTestFixture : public ::testing::Test { }; +class WavefrontSizeTestFixture : public ::testing::Test { +public: + void SetUp() override { + wf_size = get_wf_size(); + } + +protected: + int wf_size; +}; } // namespace rocshmem diff --git a/tests/unit_tests/wf_size.hpp b/tests/unit_tests/wf_size.hpp new file mode 100644 index 0000000000..35e3ce07ac --- /dev/null +++ b/tests/unit_tests/wf_size.hpp @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef ROCSHMEM_WF_SIZE_HPP +#define ROCSHMEM_WF_SIZE_HPP + +#include +#include "mpi.h" + +#define CHECK_HIP_MPI(cond) { \ + if(cond != hipSuccess){ \ + fprintf(stderr,"HIP error: %d line: %d\n", cond, __LINE__); \ + MPI_Abort(MPI_COMM_WORLD, 1); \ + } \ +} + +static int get_wf_size() { + int deviceId; + hipDeviceProp_t prop; + CHECK_HIP_MPI(hipGetDevice(&deviceId)); + CHECK_HIP_MPI(hipGetDeviceProperties(&prop, deviceId)); + return prop.warpSize; +} + +#endif