diff --git a/projects/rocshmem/CMakeLists.txt b/projects/rocshmem/CMakeLists.txt index f9c3d83d10..5cf74fb6cb 100644 --- a/projects/rocshmem/CMakeLists.txt +++ b/projects/rocshmem/CMakeLists.txt @@ -68,6 +68,7 @@ option(USE_FUNC_CALL "Force compiler to use function calls on library API" OFF) option(USE_SHARED_CTX "Request support for shared ctx between WG" OFF) option(USE_SINGLE_NODE "Enable single node support only." OFF) option(USE_HOST_SIDE_HDP_FLUSH "Use a polling thread to flush the HDP cache on the host." OFF) +option(USE_COOPERATIVE_GROUPS "Use cooperative groups for internal syncronization" OFF) option(BUILD_FUNCTIONAL_TESTS "Build the functional tests" ON) option(BUILD_SOS_TESTS "Build the host-facing tests" OFF) option(BUILD_UNIT_TESTS "Build the unit tests" ON) diff --git a/projects/rocshmem/cmake/config.h.in b/projects/rocshmem/cmake/config.h.in index 8067651a01..21619a6c3e 100644 --- a/projects/rocshmem/cmake/config.h.in +++ b/projects/rocshmem/cmake/config.h.in @@ -14,3 +14,4 @@ #cmakedefine USE_FUNC_CALL #cmakedefine USE_SINGLE_NODE #cmakedefine USE_HOST_SIDE_HDP_FLUSH +#cmakedefine USE_COOPERATIVE_GROUPS diff --git a/projects/rocshmem/scripts/build_configs/ipc_single_cg b/projects/rocshmem/scripts/build_configs/ipc_single_cg new file mode 100755 index 0000000000..c0904bad93 --- /dev/null +++ b/projects/rocshmem/scripts/build_configs/ipc_single_cg @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +if [ -z $1 ] +then + install_path=~/rocshmem +else + install_path=$1 +fi + +src_path=$(dirname "$(realpath $0)")/../../ + +cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=$install_path \ + -DCMAKE_VERBOSE_MAKEFILE=OFF \ + -DDEBUG=OFF \ + -DPROFILE=OFF \ + -DUSE_GPU_IB=OFF \ + -DUSE_RO=OFF \ + -DUSE_DC=OFF \ + -DUSE_IPC=ON \ + -DUSE_COHERENT_HEAP=ON \ + -DUSE_THREADS=OFF \ + -DUSE_WF_COAL=OFF \ + -DUSE_SINGLE_NODE=ON \ + -DUSE_HOST_SIDE_HDP_FLUSH=OFF \ + -DUSE_COOPERATIVE_GROUPS=ON \ + $src_path +cmake --build . --parallel 8 +cmake --install . diff --git a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp index 1465b8ffe0..5e1f950c3b 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp @@ -26,6 +26,11 @@ #include "../util.hpp" #include "ipc_team.hpp" +#ifdef USE_COOPERATIVE_GROUPS +#include +namespace cg = cooperative_groups; +#endif /* USE_COOPERATIVE_GROUPS */ + namespace rocshmem { __device__ void IPCContext::internal_direct_barrier(int pe, int PE_start, @@ -84,8 +89,14 @@ __device__ void IPCContext::internal_atomic_barrier(int pe, int PE_start, // Uses PE values that are relative to world __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride, int PE_size, int64_t *pSync) { +#ifdef USE_COOPERATIVE_GROUPS + cg::grid_group grid = cg::this_grid(); + grid.sync(); + if (0 == grid.thread_rank()) { +#else __syncthreads(); if (is_thread_zero_in_block()) { +#endif /* USE_COOPERATIVE_GROUPS */ if (PE_size < 64) { internal_direct_barrier(pe, PE_start, stride, PE_size, pSync); } else { @@ -93,7 +104,11 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride, } } __threadfence(); +#ifdef USE_COOPERATIVE_GROUPS + grid.sync(); +#else __syncthreads(); +#endif /* USE_COOPERATIVE_GROUPS */ } __device__ void IPCContext::sync(roc_shmem_team_t team) { diff --git a/projects/rocshmem/tests/functional_tests/sync_tester.cpp b/projects/rocshmem/tests/functional_tests/sync_tester.cpp index 54d4a99f99..1afecaa45c 100644 --- a/projects/rocshmem/tests/functional_tests/sync_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/sync_tester.cpp @@ -83,9 +83,21 @@ void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, roc_shmem_team_split_strided(ROC_SHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, &team_sync_world_dup); +#ifdef USE_COOPERATIVE_GROUPS + void* kernelParams[] = {(void*)&loop, + (void*)&args.skip, + (void*)&timer, + (void*)&_type, + (void*)&_shmem_context, + (void*)&team_sync_world_dup}; + + CHECK_HIP(hipLaunchCooperativeKernel(SyncTest, gridSize, blockSize, + kernelParams, shared_bytes, stream)); +#else hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, timer, _type, _shmem_context, team_sync_world_dup); +#endif /* USE_COOPERATIVE_GROUPS */ num_msgs = (loop + args.skip) * gridSize.x; num_timed_msgs = loop;