diff --git a/projects/rccl/.gitmodules b/projects/rccl/.gitmodules index eae2c515ce..e839548619 100644 --- a/projects/rccl/.gitmodules +++ b/projects/rccl/.gitmodules @@ -8,3 +8,7 @@ url = https://github.com/nlohmann/json.git ignore = dirty shallow = true +[submodule "ext-src/rocSHMEM"] + path = ext-src/rocSHMEM + url = https://github.com/ROCm/rocSHMEM.git + branch = develop diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index 933ae4c47d..e36d868243 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -42,6 +42,7 @@ option(TIMETRACE "Enable time-trace during compila option(TRACE "Enable additional tracing" OFF) option(FAULT_INJECTION "Enable fault injection" ON) option(QUIET_WARNINGS "Supress compiler warnings" OFF) +option(ENABLE_ROCSHMEM "Enable rocSHMEM support in RCCL" OFF) # Default GPU architectures to build #================================================================================================== @@ -65,6 +66,11 @@ include(CheckSymbolExists) include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets include(cmake/CheckSymbolExistsNoWarn.cmake) +# Include rocSHMEM build module only if enabled +if(ENABLE_ROCSHMEM) + include(cmake/ROCSHMEM.cmake) +endif() + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Build only for local GPU architecture @@ -321,6 +327,8 @@ if(BUILD_BFD) endif() endif() + + # Check for --amdgpu-kernarg-preload-count check_cxx_compiler_flag("-mllvm --amdgpu-kernarg-preload-count=16" HAVE_KERNARG_PRELOAD) if (HAVE_KERNARG_PRELOAD) @@ -472,6 +480,7 @@ set(SRC_FILES src/device/all_gather.h src/device/all_reduce.h src/device/alltoall_pivot.h + src/device/alltoall_gda.h src/device/broadcast.h src/device/common.h src/device/common_kernel.h @@ -884,6 +893,7 @@ if(ROCTX_ENABLE) target_include_directories(rccl PRIVATE ${ROCTRACER_INCLUDE_DIR}) endif() + ## Set RCCL compile definitions if(COLLTRACE) target_compile_definitions(rccl PRIVATE ENABLE_COLLTRACE) @@ -903,6 +913,28 @@ endif() if(ENABLE_WARP_SPEED) target_compile_definitions(rccl PRIVATE ENABLE_WARP_SPEED) endif() +if(ENABLE_ROCSHMEM) + target_compile_definitions(rccl PRIVATE ENABLE_ROCSHMEM) +endif() + +# ==== rocSHMEM integration (optional) ==== + +if (ENABLE_ROCSHMEM) + add_rocshmem_targets() + # Ensure rocSHMEM is fully built/installed before compiling rccl + if (TARGET rocshmem_ext) + add_dependencies(rccl rocshmem_ext) + endif() + + if (ROCSHMEM_INCLUDE_DIR) + target_include_directories(rccl PRIVATE ${ROCSHMEM_INCLUDE_DIR}) + endif() + + # Moved to where MSCCL target_links + ## target_link_libraries(rccl PRIVATE ${ROCSHMEM_LIBRARY}) + target_link_libraries(rccl PRIVATE ${IBVERBS}) + +endif() # NPKit flags ## May be better to move these to a separate file @@ -1234,6 +1266,10 @@ target_link_libraries(rccl PRIVATE fmt::fmt-header-only) if(ENABLE_MSCCLPP) target_link_libraries(rccl PRIVATE mscclpp_nccl) endif() +if(ENABLE_ROCSHMEM) + target_link_libraries(rccl PRIVATE ${ROCSHMEM_LIBRARY}) + target_link_libraries(rccl PRIVATE ${IBVERBS}) +endif() ## Set RCCL link options ## Find out available memory diff --git a/projects/rccl/cmake/Findrocshmem_static.cmake b/projects/rccl/cmake/Findrocshmem_static.cmake new file mode 100644 index 0000000000..27f4c3a0ed --- /dev/null +++ b/projects/rccl/cmake/Findrocshmem_static.cmake @@ -0,0 +1,35 @@ +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +find_path(ROCSHMEM_INCLUDE_DIR + NAMES rocshmem/rocshmem.hpp rocshmem/rocshmem.h + HINTS ${ROCSHMEM_INSTALL_DIR}/include/) + +find_library(ROCSHMEM_LIBRARY + NAMES rocshmem + HINTS ${ROCSHMEM_INSTALL_DIR}/lib) + +## -- todo --- what to do with verbs? add to handle args call below? -- ## +find_library(IBVERBS ibverbs) + +find_package_handle_standard_args(rocshmem_static DEFAULT_MSG ROCSHMEM_INCLUDE_DIR ROCSHMEM_LIBRARY) +## mark_as_advanced(MSCCLPP_INCLUDE_DIRS MSCCLPP_NCCL_STATIC_LIB) add this for Rocshmem? diff --git a/projects/rccl/cmake/ROCSHMEM.cmake b/projects/rccl/cmake/ROCSHMEM.cmake new file mode 100644 index 0000000000..f1ed469dde --- /dev/null +++ b/projects/rccl/cmake/ROCSHMEM.cmake @@ -0,0 +1,113 @@ +# MIT License +# +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +include(ExternalProject) + +function(add_rocshmem_targets) + + # Check for an existing installation via the user-provided prefix ROCSHMEM_INSTALL DIR + if(ROCSHMEM_INSTALL_DIR) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + find_package(rocshmem_static) + if(NOT IBVERBS) + find_library(IBVERBS ibverbs) + if(IBVERBS) + set(IBVERBS ${IBVERBS} PARENT_SCOPE) + endif() + endif() + endif() + + # If no pre-existing installation, build from submodule into ext/rocshmem + if(NOT rocshmem_static_FOUND) + set(_rccl_root "${CMAKE_SOURCE_DIR}") + set(ROCSHMEM_SOURCE "${_rccl_root}/ext-src/rocSHMEM") + set(ROCSHMEM_INSTALL_DIR "${_rccl_root}/ext/rocshmem") + + # Make sure submodule exists (same style as MSCCL++: custom rule + target) + add_custom_command( + OUTPUT "${ROCSHMEM_SOURCE}/CMakeLists.txt" + COMMAND git submodule update --init --recursive ext-src/rocSHMEM + WORKING_DIRECTORY "${_rccl_root}" + COMMENT "Checking out submodule: ext-src/rocSHMEM" + VERBATIM + ) + + add_custom_target(rocshmem_checkout_submodule + DEPENDS "${ROCSHMEM_SOURCE}/CMakeLists.txt") + + # Where our patch files live (like MSCCL++) + set(EXT_SOURCE "${_rccl_root}/ext-src") + + # Build and install rocSHMEM. We run `../build_scripts/gdx_bxnt` + # from a 'build' dir just like the README shows. + ExternalProject_Add(rocshmem_ext + SOURCE_DIR "${ROCSHMEM_SOURCE}" + INSTALL_DIR "${ROCSHMEM_INSTALL_DIR}" + UPDATE_DISCONNECTED TRUE + LOG_DOWNLOAD FALSE + LOG_CONFIGURE FALSE + LOG_BUILD FALSE + LOG_INSTALL FALSE + BUILD_IN_SOURCE TRUE + DOWNLOAD_COMMAND "" # using the submodule checkout above + TEST_COMMAND "" + DEPENDS rocshmem_checkout_submodule + + # Rocshmem submodule commit hash -> commit b28a56bd54ccc581d05a439ffa466c3dacb3385 + # The project has its own scripts; we replicate the README sequence: + CONFIGURE_COMMAND "" + BUILD_COMMAND + ${CMAKE_COMMAND} -E make_directory build + && ${CMAKE_COMMAND} -E chdir build bash -lc "../scripts/build_configs/gda_bnxt -DUSE_EXTERNAL_MPI=OFF -DUSE_IPC=ON -DBUILD_EXAMPLES=OFF " + && ${CMAKE_COMMAND} -E chdir build ${CMAKE_COMMAND} + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DCMAKE_INSTALL_PREFIX= + -DBUILD_EXAMPLES=OFF .. + && ${CMAKE_COMMAND} -E chdir build ${CMAKE_MAKE_PROGRAM} -j + INSTALL_COMMAND + ${CMAKE_COMMAND} -E chdir build ${CMAKE_MAKE_PROGRAM} install + ) + + # After build, define the variables RCCL expects + set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INSTALL_DIR}/include" PARENT_SCOPE) + set(ROCSHMEM_LIBRARY "${ROCSHMEM_INSTALL_DIR}/lib/librocshmem.a" PARENT_SCOPE) + find_library(_IBVERBS ibverbs) + if(NOT _IBVERBS) + message(FATAL_ERROR "libibverbs not found (install rdma-core/libibverbs-dev)") + endif() + set(IBVERBS ${_IBVERBS} PARENT_SCOPE) + + # Provide a dummy target other code can depend on + add_custom_target(rocshmem_static ALL DEPENDS rocshmem_ext) + else() + # We found a prebuilt rocSHMEM; export variables upward as-is + set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE) + set(ROCSHMEM_LIBRARY "${ROCSHMEM_LIBRARY}" PARENT_SCOPE) + + find_library(_IBVERBS ibverbs) + if(NOT _IBVERBS) + message(FATAL_ERROR "libibverbs not found") + endif() + set(IBVERBS ${_IBVERBS} PARENT_SCOPE) + endif() + +endfunction() diff --git a/projects/rccl/ext-src/rocSHMEM b/projects/rccl/ext-src/rocSHMEM new file mode 160000 index 0000000000..b28a56bd54 --- /dev/null +++ b/projects/rccl/ext-src/rocSHMEM @@ -0,0 +1 @@ +Subproject commit b28a56bd54ccc581d05a439ffa466c3dacb33853 diff --git a/projects/rccl/install.sh b/projects/rccl/install.sh index db62ad8dcc..cc57480478 100755 --- a/projects/rccl/install.sh +++ b/projects/rccl/install.sh @@ -41,6 +41,7 @@ force_reduce_pipeline=false generate_sym_kernels=false warp_speed_enabled=true # note that this flag will be overridden to false for non MI350/MI300 platforms quiet_warnings=false +build_rocshmem_support=false # ################################################# # helper functions @@ -82,6 +83,7 @@ function display_help() echo " --force-reduce-pipeline Force reduce_copy sw pipeline to be used for every reduce-based collectives and datatypes" echo " --generate-sym-kernels Generate symmetric memory kernels" echo " -q|--quiet-warnings Suppress majority of compiler warnings (not recommended)" + echo " --rocshmem Build with rocSHMEM support" } # ################################################# @@ -91,7 +93,7 @@ function display_help() # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ "$?" -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose,rocshmem -- "$@") else echo "Need a new version of getopt" exit 1 @@ -140,6 +142,7 @@ while true; do --generate-sym-kernels) generate_sym_kernels=true; shift ;; --disable-warp-speed) warp_speed_enabled=false; shift ;; -q | --quiet-warnings) quiet_warnings=true; shift ;; + --rocshmem) build_rocshmem_support=true; shift ;; --) shift ; break ;; *) echo "Unexpected command line parameter received; aborting"; exit 1 @@ -329,6 +332,14 @@ if [[ "${quiet_warnings}" == true ]]; then fi +# Enable rocSHMEM support +if [[ "${build_rocshmem_support}" == true ]]; then + cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=ON" + cmake_common_options="${cmake_common_options} -DROCSHMEM_INSTALL_DIR=${ROCSHMEM_INSTALL_DIR}" +else + cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=OFF" +fi + check_exit_code "$?" # Enable ninja build for time tracing diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index 7b1de9726e..33a3d5f1cd 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -13,6 +13,10 @@ #include "nvtx_payload_schemas.h" #include "msccl/msccl_lifecycle.h" +#ifdef ENABLE_ROCSHMEM +#include +#endif + using namespace rccl; const char* ncclFuncToString(ncclFunc_t fn) { @@ -222,6 +226,8 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun size_t rankOffset = count * ncclTypeSize(datatype); size_t rankAlign = rankOffset & ((~rankOffset) + 1); + size_t msgSize = count * ncclTypeSize(datatype) * comm->nRanks; + // Determine Pivot A2A support now that we know number of channels if (comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2 && rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) { @@ -230,7 +236,17 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr }; return ncclEnqueueCheck(&info); } else { +#ifdef ENABLE_ROCSHMEM + if (rcclUseAllToAllGda(comm) && msgSize <= comm->rocshmemThreshold) { + struct ncclInfo info = { ncclFuncAllToAllGda, "AllToAllGda", + sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, + ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr }; + + return ncclEnqueueCheck(&info); + } +#endif int nRanks; + //comm->isA2a = 0; NCCLCHECK(ncclCommCount(comm, &nRanks)); if (count == 0) return ncclSuccess; if (!mscclIsCaller()) Recorder::instance().skip(true); diff --git a/projects/rccl/src/device/alltoall_gda.h b/projects/rccl/src/device/alltoall_gda.h new file mode 100644 index 0000000000..7b9f7fb467 --- /dev/null +++ b/projects/rccl/src/device/alltoall_gda.h @@ -0,0 +1,33 @@ +/************************************************************************* + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "device.h" +#include "collectives.h" +#include "primitives.h" + +#ifdef ENABLE_ROCSHMEM +#include + +template +struct RunWorkColl { + __device__ __forceinline__ void run(int tid, int nThreads, struct ncclDevWorkColl* work) { + if (blockIdx.x == 0) { + int num_pes = rocshmem::rocshmem_n_pes(); + + reduceCopy( + tid, nThreads, 0, nullptr, false, 1, (void **)&work->sendbuff, 1, (void **)&work->sndbuff, + (work->size*num_pes)); + + rocshmem::rocshmem_char_alltoall_wg(work->team, ((char*)work->tempbuff), ((char*)work->sndbuff), work->size); + + reduceCopy( + tid, nThreads, 0, nullptr, false, 1, (void **)&work->tempbuff, 1, (void **)&work->recvbuff, + (work->size*num_pes)); + } + } +}; +#endif + diff --git a/projects/rccl/src/device/generate.py b/projects/rccl/src/device/generate.py index ef2ecf91b2..42e9075289 100755 --- a/projects/rccl/src/device/generate.py +++ b/projects/rccl/src/device/generate.py @@ -5,7 +5,7 @@ import subprocess from dataclasses import dataclass # Order of colls, redops, tys, protos, algos must match src/include/device.h -all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"] +all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot", "AllToAllGda"] all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"] all_protos = ["LL","LL128","SIMPLE"] @@ -79,7 +79,7 @@ func_pattern = sys.argv[6:7] if func_pattern and func_pattern[0]: func_pattern = func_pattern[0] else: - func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" + func_pattern = "AllGather|AllReduce|AllToAllPivot|AllToAllGda|Broadcast|Reduce|ReduceScatter|SendRecv" ################################################################################ @@ -87,6 +87,7 @@ algos_of_coll = { "AllGather": ["RING", "PAT"], "AllReduce": ["RING", "TREE"], "AllToAllPivot": ["RING"], + "AllToAllGda": ["RING"], "Broadcast": ["RING"], "Reduce": ["RING"], "ReduceScatter": ["RING", "PAT"], @@ -97,6 +98,7 @@ protos_of_coll = { "AllGather": all_protos, "AllReduce": all_protos, "AllToAllPivot": ["SIMPLE"], + "AllToAllGda": ["SIMPLE"], "Broadcast": all_protos, "Reduce": all_protos, "ReduceScatter": all_protos, @@ -107,6 +109,7 @@ redops_of_coll = { "AllGather": ["Sum"], "AllReduce": all_redops, "AllToAllPivot": ["Sum"], + "AllToAllGda": ["Sum"], "Broadcast": ["Sum"], "Reduce": all_redops, "ReduceScatter": all_redops, @@ -117,6 +120,7 @@ tys_of_coll = { "AllGather": ["i8"], "AllReduce": all_tys, "AllToAllPivot": ["i8"], + "AllToAllGda": ["i8"], "Broadcast": ["i8"], "Reduce": all_tys, "ReduceScatter": all_tys, @@ -127,6 +131,7 @@ acc_of_coll = { "AllGather": ["0"], "AllReduce": all_accs, "AllToAllPivot": ["0"], + "AllToAllGda": ["0"], "Broadcast": ["0"], "Reduce": ["0"], "ReduceScatter": ["0"], @@ -137,6 +142,7 @@ pipelines_of_coll = { "AllGather": ["0"], "AllReduce": all_pipelines, "AllToAllPivot": ["0"], + "AllToAllGda": ["0"], "Broadcast": ["0"], "Reduce": all_pipelines, "ReduceScatter": all_pipelines, @@ -148,6 +154,7 @@ coll_camel_to_lower = { "AllGather": "all_gather", "AllReduce": "all_reduce", "AllToAllPivot": "alltoall_pivot", + "AllToAllGda": "alltoall_gda", "Broadcast": "broadcast", "Reduce": "reduce", "ReduceScatter": "reduce_scatter", @@ -503,7 +510,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: ) if fn.coll == "Broadcast": key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8)) - if fn.coll in ["SendRecv", "AllToAllPivot"]: + if fn.coll in ["SendRecv", "AllToAllPivot", "AllToAllGda"]: key = ((coll_idx & 0x3F)) out(f' {{{key}, {fn_id}}}, {comment}\n') diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 578dc44538..dd42b9c1fc 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -11,6 +11,7 @@ #include "coll_net.h" #include "graph/topo.h" #include +#include #include #include "gdrwrap.h" #include "bootstrap.h" @@ -29,6 +30,10 @@ #include #include "latency_profiler/CollTraceFunc.h" +#ifdef ENABLE_ROCSHMEM +#include +#endif + using namespace rccl; struct ncclKernelMatch { @@ -36,6 +41,7 @@ struct ncclKernelMatch { bool specialized; }; + #ifdef ENABLE_COLLTRACE #define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0)) static ncclKernelMatch const ncclKerns[6] = { @@ -390,6 +396,19 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.rcclUseOneSlice = comm->rcclUseOneSlice; //[Added-comment] opCount is missing for collDevWork, adding here devWork.opCount = task->opCount; +#ifdef ENABLE_ROCSHMEM + if (comm->enableRocshmem && task->func == ncclFuncAllToAllGda) { + devWork.enableRocshmem = comm->enableRocshmem; + devWork.team = comm->team_reduce_world_dup; + + devWork.sndbuff = (void*)comm->sourceRshmem[comm->symId]; + devWork.tempbuff = (void*)comm->destRshmem[comm->symId]; + + comm->symId = (comm->symId + 1) % comm->numSymBuf; + + devWork.size = task->count; + } +#endif devWork.isOneRPN = comm->isOneRPN; devWork.netRegUsed = devWork.regUsed = 0; @@ -730,8 +749,10 @@ static ncclResult_t scheduleCollTasksToPlan( proxyOp.incWorkCounter = true; addWorkBatchToPlan(comm, plan, c, workNode->workType, task->devFuncId, plan->workBytes); // Set pattern to profiler to add a proxy profiler for kernel events - NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp)); - NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp)); + if (task->func != ncclFuncAllToAllGda) { + NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp)); + NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp)); + } } } else { // not task->isCollnet int trafficPerByte = ncclFuncTrafficPerByte(task->func, comm->nRanks); @@ -877,8 +898,10 @@ static ncclResult_t scheduleCollTasksToPlan( // Coverity reports "proxyOp->connection" as being possibly uninitialized. It's hard to // determine if that's actually true but it's also not clear if that would be an issue. // coverity[uninit_use_in_call:FALSE] - NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp)); - NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp)); + if (task->func != ncclFuncAllToAllGda) { + NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp)); + NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp)); + } } } @@ -1508,8 +1531,8 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP NCCLCHECK(ncclProfilerStartGroupEvent(plan)); NCCLCHECK(ncclProfilerStartTaskEvents(plan)); if (ncclIntruQueueHead(&plan->proxyOpQueue)) { - NCCLCHECK(uploadProxyOps(comm, plan)); - NCCLCHECK(ncclProxyStart(comm)); + NCCLCHECK(uploadProxyOps(comm, plan)); + NCCLCHECK(ncclProxyStart(comm)); } NCCLCHECK(ncclProfilerStopTaskEvents(plan)); NCCLCHECK(ncclProfilerStopGroupEvent(plan)); @@ -1788,6 +1811,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan latency_profiler::collTraceRecordStartEvent(comm, launchStream, event.get()); comm->lastStream = planner->streams->stream; CUDACHECKGOTO(hipExtLaunchKernel(plan->kernelFn, grid, block, extra, 0, launchStream, NULL, comm->doneEvent, 0), ret, do_return); + latency_profiler::collTraceRecordEndEvent(comm, plan, launchStream, std::move(event)); return ncclSuccess; } @@ -2023,7 +2047,7 @@ static ncclResult_t updateCollCostTable( float** collCostTable) { float (*table)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable; - if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot) { + if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot || info->func == ncclFuncAllToAllGda) { table[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] = 0.0; return ncclSuccess; } @@ -2327,6 +2351,9 @@ static ncclResult_t calcCollChunking( case ncclFuncAllToAllPivot: pattern = ncclPatternRing; break; + case ncclFuncAllToAllGda: + pattern = ncclPatternRing; + break; case ncclFuncAllReduce: pattern = info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : @@ -2749,7 +2776,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { t->root = info->root; t->datatype = info->datatype; size_t elementSize = ncclTypeSize(t->datatype); - if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot) { + if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot || t->func == ncclFuncAllToAllGda) { t->count *= elementSize; t->datatype = ncclInt8; elementSize = 1; diff --git a/projects/rccl/src/include/comm.h b/projects/rccl/src/include/comm.h index 1f8d38b0ba..d9fedeaa89 100644 --- a/projects/rccl/src/include/comm.h +++ b/projects/rccl/src/include/comm.h @@ -24,6 +24,10 @@ #include "rccl_common.h" #include "recorder.h" +#ifdef ENABLE_ROCSHMEM +#include +#endif + #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) #define HIPRT_CB #else @@ -725,6 +729,17 @@ struct ncclComm { // multiProcessorCount from hipDeviceProp_t [RCCL] int cuCount; +#ifdef ENABLE_ROCSHMEM + // circular ring buffer in rocshmem symmetric heap + void** sourceRshmem; + void** destRshmem; + rocshmem::rocshmem_team_t team_reduce_world_dup; + int enableRocshmem; + int rocshmemThreshold; + int numSymBuf; + int symId; +#endif + uint64_t endMagic; }; diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index f722f1fe21..f19038d9c9 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -38,7 +38,11 @@ #include #include "debug.h" -extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2]; +#ifdef ENABLE_ROCSHMEM +#include +#endif + +extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3]; extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS]; @@ -397,6 +401,14 @@ struct alignas(16) ncclDevWorkColl { }; uint64_t redOpArg; uint64_t opCount; + +#ifdef ENABLE_ROCSHMEM + rocshmem::rocshmem_team_t team; + int enableRocshmem; + void* tempbuff; + void* sndbuff; + int size; +#endif }; @@ -784,7 +796,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto, if (coll == ncclFuncBroadcast) { key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) | ((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT); - } else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot) { + } else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot || coll == ncclFuncAllToAllGda) { key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ); } else { key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) | diff --git a/projects/rccl/src/include/nccl_common.h b/projects/rccl/src/include/nccl_common.h index c2140289ed..b1b419c1d7 100644 --- a/projects/rccl/src/include/nccl_common.h +++ b/projects/rccl/src/include/nccl_common.h @@ -64,7 +64,8 @@ typedef enum { ncclFuncSend = 6, ncclFuncRecv = 7, ncclFuncAllToAllPivot = 8, - ncclNumFuncs = 9 + ncclFuncAllToAllGda = 9, + ncclNumFuncs = 10 } ncclFunc_t; #define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*/PAT diff --git a/projects/rccl/src/include/rccl_common.h b/projects/rccl/src/include/rccl_common.h index 9df5a2c9e2..e0409e2368 100644 --- a/projects/rccl/src/include/rccl_common.h +++ b/projects/rccl/src/include/rccl_common.h @@ -112,6 +112,7 @@ NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll, NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName); NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName); bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize); +bool rcclUseAllToAllGda(struct ncclComm* comm); void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable); void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize); ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount); diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index a139da61eb..014afe2d3c 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -56,6 +56,12 @@ #include "rccl_common.h" // [/RCCL] +#ifdef ENABLE_ROCSHMEM +#include +#define NUM_SYM_BUF 8 +#endif + + #include "msccl/msccl_lifecycle.h" #include "msccl/msccl_status.h" #include "latency_profiler/CollTrace.h" @@ -78,7 +84,7 @@ using namespace rccl; -const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"}; +const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3] = { "AllGather", "AllReduce", "AllToAllPivot", "AllToAllGda", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"}; const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "MinMax", "PreMulSum", "SumPostDiv" }; @@ -97,6 +103,13 @@ NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", NCCL_CONFIG_UNDEF_INT); struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {}; ncclResult_t commReclaim(ncclComm_t comm); + +#ifdef ENABLE_ROCSHMEM +RCCL_PARAM(RocshmemThreshold, "ROCSHMEM_THRESHOLD", (size_t)(262144)); +RCCL_PARAM(RocshmemEnabled, "ROCSHMEM_ENABLE", 1); +std::unordered_map ncclCommToRshmemTeam; +#endif + #ifdef ENABLE_MSCCLPP size_t std::hash::operator ()(const ncclUniqueId& uniqueId) const noexcept { return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES); @@ -2107,6 +2120,58 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { // RCCL: determine and set unroll factor for comm NCCLCHECK(commSetUnrollFactor(comm)); +#ifdef ENABLE_ROCSHMEM + if (rcclParamRocshmemEnabled()) { // @TODO - This doesn't seem to disable when I set ROCSHMEM_ENABLE=0 on command line + INFO(NCCL_INIT,"Initializing rocSHMEM inside of RCCL"); + int ret; + rocshmem::rocshmem_uniqueid_t rocshmemUniqueId; + rocshmem::rocshmem_init_attr_t rocshmemAttr; + + if(comm->rank == 0 ) { + ret = rocshmem::rocshmem_get_uniqueid (&rocshmemUniqueId); + if (ret != rocshmem::ROCSHMEM_SUCCESS) { + ERROR("Error in rocshmem_get_uniqueid, Rocshmem cannot be initialized."); + return ncclSystemError; + } + } + + NCCLCHECKGOTO(bootstrapBroadcast(comm->bootstrap, comm->rank, comm->nRanks, 0, &rocshmemUniqueId, + sizeof(rocshmemUniqueId)), res, fail); + ret = rocshmem::rocshmem_set_attr_uniqueid_args(job->myrank, job->nranks, &rocshmemUniqueId, &rocshmemAttr); + if (ret != rocshmem::ROCSHMEM_SUCCESS) { + ERROR("Error in rocshmem_set_attr_uniqueid_args, Rocshmem cannot be initialized."); + return ncclSystemError; + } + + ret = rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &rocshmemAttr); + if (ret != rocshmem::ROCSHMEM_SUCCESS) { + ERROR("Error in rocshmem_init_attr, Rocshmem cannot be initialized."); + return ncclSystemError; + } + + comm->sourceRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *)); + comm->destRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *)); + + for (int i = 0; i < NUM_SYM_BUF; i++) { + comm->sourceRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024)); + comm->destRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024)); + } + + comm->enableRocshmem = rcclParamRocshmemEnabled(); + comm->rocshmemThreshold = rcclParamRocshmemThreshold(); + comm->numSymBuf = NUM_SYM_BUF; + comm->symId = 0; + //rocshmem::rocshmem_team_t team_reduce_world_dup; + comm->team_reduce_world_dup = rocshmem::ROCSHMEM_TEAM_INVALID; + rocshmem::rocshmem_team_split_strided(rocshmem::ROCSHMEM_TEAM_WORLD, 0, 1, job->nranks, nullptr, 0, + &(comm->team_reduce_world_dup)); + + ncclCommToRshmemTeam[comm] = comm->team_reduce_world_dup; + CUDACHECK(hipDeviceSynchronize()); + } +#endif + + #ifdef ENABLE_MSCCLPP if (job->parent) { if (job->parent->mscclppCompatible) { @@ -2935,6 +3000,28 @@ ncclResult_t ncclCommDestroy_impl(ncclComm_t comm) { } #endif +#ifdef ENABLE_ROCSHMEM + if (comm->enableRocshmem) { + for (int i = 0; i < NUM_SYM_BUF; i++) { + rocshmem::rocshmem_free(comm->sourceRshmem[i]); + rocshmem::rocshmem_free(comm->destRshmem[i]); + } + free(comm->sourceRshmem); + free(comm->destRshmem); + + //TODO: subcomm check + rocshmem::rocshmem_team_t team; + if (!ncclCommToRshmemTeam.empty()) { + team = ncclCommToRshmemTeam[comm]; + rocshmem::rocshmem_team_destroy(team); + ncclCommToRshmemTeam.erase(comm); + } + if (ncclCommToRshmemTeam.empty()) { + rocshmem::rocshmem_finalize(); + } + } +#endif + int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev; struct ncclCommFinalizeAsyncJob *job = NULL; ncclResult_t res = ncclSuccess; diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index 2bce6327aa..5b3c312bdd 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -386,6 +386,18 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) { return ncclSuccess; } +bool rcclUseAllToAllGda(struct ncclComm* comm) { + + //TODO: enable on MI350; currently tested on MI300X +#ifdef ENABLE_ROCSHMEM + if (comm->enableRocshmem && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") && comm->nNodes > 1 && (comm->nRanks/comm->nNodes == 8) && comm->rocshmemThreshold <= 1048576) { + INFO(NCCL_INIT, "Enabling GDA alltoall for RCCL"); + return true; + } +#endif + return false; +} + bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) { // Check if user explicitly disabled direct AllGather static int userDirectAllGatherInput = -2;