From 64cf812da02ebdf93b6015dcb9dc31c8664a05cb Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Wed, 10 Jan 2024 08:01:11 -0800 Subject: [PATCH] Re-enable LL128 on gfx90a if compiler supports it (#1036) [ROCm/rccl commit: 5851ae597425bc75482ec1656ce1b861958074e9] --- projects/rccl/CMakeLists.txt | 4 ++++ projects/rccl/src/collectives/device/common.h | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index 787111f7ab..ff16682ce4 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -601,6 +601,10 @@ if(DEMANGLE_DIR) target_compile_definitions(rccl PRIVATE "HAVE_CPLUS_DEMANGLE=1") target_compile_definitions(rccl PRIVATE "HAVE_DECL_BASENAME=1") endif() +if(${hipcc_version_string} VERSION_GREATER_EQUAL "6.1.33591") + target_compile_definitions(rccl PRIVATE ENABLE_LL128) + message(STATUS "RCCL LL128 protocol enabled") +endif() ## Set RCCL compile options target_compile_options(rccl PRIVATE -parallel-jobs=12) diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index 4d22026d3d..afb9b85ba0 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -32,7 +32,7 @@ { __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); } #endif -#ifdef ENABLE_LL128 +#if defined(ENABLE_LL128) && defined(__gfx90a__) #define NCCL_FUNC5(func, algo, devredop, type, nullify) \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ @@ -571,7 +571,7 @@ __forceinline__ __device__ void ncclKernel( #if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) ncclFuncs[ncclShmem.work.header.funcIndex](); #else -#ifdef ENABLE_LL128 +#if defined(ENABLE_LL128) && defined(__gfx90a__) 
NCCL_CALL_FUNCTIONS<1>(ncclShmem.work.header.funcIndex); #else NCCL_CALL_FUNCTIONS<0>(ncclShmem.work.header.funcIndex); @@ -640,7 +640,7 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev #endif // Only generate inline kernels for LL -#ifdef ENABLE_LL128 +#if defined(ENABLE_LL128) && defined(__gfx90a__) #define IMPL_COLL4(func, algo, devredop, type) \ IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \