GDA support for alltoall via rocshmem integration (#2099)
* ROCSHMEM linking/building to match MSCCL++ style * add rocSHMEM as a submodule * Move rocSHMEM submodule to ext-src/rocSHMEM * Adding submodule support proper, as well as a patch for rocshmem * Cleaning up INCLUDE_DIR vs INCLUDE_DIRS mixup * updating patch file * Pointing rocshmem submodule to edgars fixup patch * Adding IBVERBS link to the submodule build * More IBVERBS patching * pin rocshmem submodule tob534423* Adding IPC support in rocSHMEM build * updating rocshmem submodule to resolve CQ errors * Updating submodule to include recent a2a optimizations * invoke rocshmem alltoall from rccl * Updating submodule to CQ error number hang * Updating submodule to include a2a improvements and bug fixes * Updating submodule to point to Yiltan's fork and doorbell ring removal commit * Updating hash to correspond with submodule change * Updating to no-ctx wg call and updating submodule * copy-in/copy-out using multiples CUs * Updating rocSHMEM submodule to include doorbell improvs * updating gitmodule to point to upstream * code cleanup and adjust threashold * guard rocshmem a2a invocation * Only build with rocshmem when specified * code cleanup * address review comments * Removing debugging failure case Signed-off-by: Thomas Huber <thomas.huber@amd.com> * whitespace fix * Adding rocshmem compile guard * Removing unneccesary comment Signed-off-by: Thomas Huber <thomas.huber@amd.com> * remove commented lines * address review comments * cleanup --------- Signed-off-by: Thomas Huber <thomas.huber@amd.com> Co-authored-by: Thomas Huber <thomas.huber@amd.com> Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k12-27.cs-aus.dcgpu> Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k13-09.cs-aus.dcgpu> Co-authored-by: Islam <nusislam@amd.com> Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k13-03.cs-aus.dcgpu> [ROCm/rccl commit:27648b0900]
这个提交包含在:
@@ -8,3 +8,7 @@
|
||||
url = https://github.com/nlohmann/json.git
|
||||
ignore = dirty
|
||||
shallow = true
|
||||
[submodule "ext-src/rocSHMEM"]
|
||||
path = ext-src/rocSHMEM
|
||||
url = https://github.com/ROCm/rocSHMEM.git
|
||||
branch = develop
|
||||
|
||||
@@ -42,6 +42,7 @@ option(TIMETRACE "Enable time-trace during compila
|
||||
option(TRACE "Enable additional tracing" OFF)
|
||||
option(FAULT_INJECTION "Enable fault injection" ON)
|
||||
option(QUIET_WARNINGS "Supress compiler warnings" OFF)
|
||||
option(ENABLE_ROCSHMEM "Enable rocSHMEM support in RCCL" OFF)
|
||||
|
||||
# Default GPU architectures to build
|
||||
#==================================================================================================
|
||||
@@ -65,6 +66,11 @@ include(CheckSymbolExists)
|
||||
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
|
||||
include(cmake/CheckSymbolExistsNoWarn.cmake)
|
||||
|
||||
# Include rocSHMEM build module only if enabled
|
||||
if(ENABLE_ROCSHMEM)
|
||||
include(cmake/ROCSHMEM.cmake)
|
||||
endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
# Build only for local GPU architecture
|
||||
@@ -321,6 +327,8 @@ if(BUILD_BFD)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
# Check for --amdgpu-kernarg-preload-count
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-kernarg-preload-count=16" HAVE_KERNARG_PRELOAD)
|
||||
if (HAVE_KERNARG_PRELOAD)
|
||||
@@ -472,6 +480,7 @@ set(SRC_FILES
|
||||
src/device/all_gather.h
|
||||
src/device/all_reduce.h
|
||||
src/device/alltoall_pivot.h
|
||||
src/device/alltoall_gda.h
|
||||
src/device/broadcast.h
|
||||
src/device/common.h
|
||||
src/device/common_kernel.h
|
||||
@@ -884,6 +893,7 @@ if(ROCTX_ENABLE)
|
||||
target_include_directories(rccl PRIVATE ${ROCTRACER_INCLUDE_DIR})
|
||||
endif()
|
||||
|
||||
|
||||
## Set RCCL compile definitions
|
||||
if(COLLTRACE)
|
||||
target_compile_definitions(rccl PRIVATE ENABLE_COLLTRACE)
|
||||
@@ -903,6 +913,28 @@ endif()
|
||||
if(ENABLE_WARP_SPEED)
|
||||
target_compile_definitions(rccl PRIVATE ENABLE_WARP_SPEED)
|
||||
endif()
|
||||
# ==== rocSHMEM integration (optional) ====
# Everything rocSHMEM-specific for the rccl target (except linking) lives behind
# this single guard:
#   1. expose ENABLE_ROCSHMEM to the C++/HIP sources,
#   2. locate or build rocSHMEM via add_rocshmem_targets(),
#   3. order the external build before rccl and add the rocSHMEM headers.
# Linking against ${ROCSHMEM_LIBRARY} and ${IBVERBS} is done once, further down,
# next to the MSCCL++ target_link_libraries() call, so it is not repeated here.
if(ENABLE_ROCSHMEM)
  target_compile_definitions(rccl PRIVATE ENABLE_ROCSHMEM)

  add_rocshmem_targets()

  # rocshmem_ext only exists when rocSHMEM is built from the submodule; in that
  # case make sure it is fully built/installed before compiling rccl.
  if(TARGET rocshmem_ext)
    add_dependencies(rccl rocshmem_ext)
  endif()

  if(ROCSHMEM_INCLUDE_DIR)
    target_include_directories(rccl PRIVATE ${ROCSHMEM_INCLUDE_DIR})
  endif()
endif()
|
||||
|
||||
# NPKit flags
|
||||
## May be better to move these to a separate file
|
||||
@@ -1234,6 +1266,10 @@ target_link_libraries(rccl PRIVATE fmt::fmt-header-only)
|
||||
if(ENABLE_MSCCLPP)
|
||||
target_link_libraries(rccl PRIVATE mscclpp_nccl)
|
||||
endif()
|
||||
if(ENABLE_ROCSHMEM)
|
||||
target_link_libraries(rccl PRIVATE ${ROCSHMEM_LIBRARY})
|
||||
target_link_libraries(rccl PRIVATE ${IBVERBS})
|
||||
endif()
|
||||
|
||||
## Set RCCL link options
|
||||
## Find out available memory
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# Find module for a pre-installed rocSHMEM.
#
# Input:
#   ROCSHMEM_INSTALL_DIR  - install prefix used as a search hint
# Result variables:
#   rocshmem_static_FOUND - set by find_package_handle_standard_args()
#   ROCSHMEM_INCLUDE_DIR  - directory containing rocshmem/rocshmem.hpp (or .h)
#   ROCSHMEM_LIBRARY      - the rocshmem library
#   IBVERBS               - libibverbs, if it could be located

# find_package_handle_standard_args() is provided by this module; include it
# explicitly rather than relying on another find module having loaded it first.
include(FindPackageHandleStandardArgs)

find_path(ROCSHMEM_INCLUDE_DIR
  NAMES rocshmem/rocshmem.hpp rocshmem/rocshmem.h
  HINTS ${ROCSHMEM_INSTALL_DIR}/include/)

find_library(ROCSHMEM_LIBRARY
  NAMES rocshmem
  HINTS ${ROCSHMEM_INSTALL_DIR}/lib)

# TODO: decide whether IBVERBS should also be listed as a required variable in
# the find_package_handle_standard_args() call below. For now it is looked up
# but does not influence rocshmem_static_FOUND.
find_library(IBVERBS ibverbs)

find_package_handle_standard_args(rocshmem_static DEFAULT_MSG ROCSHMEM_INCLUDE_DIR ROCSHMEM_LIBRARY)

# Keep the cache entries out of the default cmake-gui / ccmake view.
mark_as_advanced(ROCSHMEM_INCLUDE_DIR ROCSHMEM_LIBRARY IBVERBS)
|
||||
@@ -0,0 +1,113 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
include(ExternalProject)
|
||||
|
||||
# add_rocshmem_targets()
#
# Makes rocSHMEM available to the caller, either from an existing installation
# or by building the ext-src/rocSHMEM submodule.
#
# Input:
#   ROCSHMEM_INSTALL_DIR - optional prefix of an existing rocSHMEM installation
# Output (set in the caller's scope):
#   ROCSHMEM_INCLUDE_DIR - rocSHMEM header directory
#   ROCSHMEM_LIBRARY     - rocSHMEM library to link against
#   IBVERBS              - libibverbs (configuration fails if it is not found)
# Targets (created only when building from the submodule):
#   rocshmem_checkout_submodule, rocshmem_ext, rocshmem_static
function(add_rocshmem_targets)

  # Check for an existing installation via the user-provided prefix ROCSHMEM_INSTALL_DIR
  if(ROCSHMEM_INSTALL_DIR)
    list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
    find_package(rocshmem_static)
  endif()

  # If no pre-existing installation, build from submodule into ext/rocshmem
  if(NOT rocshmem_static_FOUND)
    set(_rccl_root "${CMAKE_SOURCE_DIR}")
    set(ROCSHMEM_SOURCE "${_rccl_root}/ext-src/rocSHMEM")
    set(ROCSHMEM_INSTALL_DIR "${_rccl_root}/ext/rocshmem")
    set(_rocshmem_archive "${ROCSHMEM_INSTALL_DIR}/lib/librocshmem.a")

    # Make sure submodule exists (same style as MSCCL++: custom rule + target)
    add_custom_command(
      OUTPUT "${ROCSHMEM_SOURCE}/CMakeLists.txt"
      COMMAND git submodule update --init --recursive ext-src/rocSHMEM
      WORKING_DIRECTORY "${_rccl_root}"
      COMMENT "Checking out submodule: ext-src/rocSHMEM"
      VERBATIM
    )

    add_custom_target(rocshmem_checkout_submodule
      DEPENDS "${ROCSHMEM_SOURCE}/CMakeLists.txt")

    # Where our patch files live (like MSCCL++)
    set(EXT_SOURCE "${_rccl_root}/ext-src")

    # Build and install rocSHMEM. We run `../scripts/build_configs/gda_bnxt`
    # from a 'build' dir just like the README shows, then re-run CMake in that
    # directory to point the install prefix at <INSTALL_DIR>.
    #
    # The rocSHMEM revision is whatever the ext-src/rocSHMEM submodule is pinned
    # to (b28a56bd54 when this integration was added).
    #
    # Notes on the command lines below:
    #  * Steps are chained with COMMAND, the separator ExternalProject documents,
    #    instead of a shell `&&`.
    #  * Build and install go through `cmake --build`, so they use whichever
    #    generator the inner build tree was configured with, rather than the
    #    outer project's CMAKE_MAKE_PROGRAM (which may be ninja).
    #  * BUILD_BYPRODUCTS declares the archive rccl links against so generators
    #    such as Ninja know which rule produces it.
    ExternalProject_Add(rocshmem_ext
      SOURCE_DIR          "${ROCSHMEM_SOURCE}"
      INSTALL_DIR         "${ROCSHMEM_INSTALL_DIR}"
      UPDATE_DISCONNECTED TRUE
      LOG_DOWNLOAD        FALSE
      LOG_CONFIGURE       FALSE
      LOG_BUILD           FALSE
      LOG_INSTALL         FALSE
      BUILD_IN_SOURCE     TRUE
      DOWNLOAD_COMMAND    ""   # using the submodule checkout above
      TEST_COMMAND        ""
      DEPENDS             rocshmem_checkout_submodule
      BUILD_BYPRODUCTS    "${_rocshmem_archive}"

      # The project has its own scripts; we replicate the README sequence:
      CONFIGURE_COMMAND ""
      BUILD_COMMAND
        ${CMAKE_COMMAND} -E make_directory build
      COMMAND
        ${CMAKE_COMMAND} -E chdir build bash -lc "../scripts/build_configs/gda_bnxt -DUSE_EXTERNAL_MPI=OFF -DUSE_IPC=ON -DBUILD_EXAMPLES=OFF "
      COMMAND
        ${CMAKE_COMMAND} -E chdir build ${CMAKE_COMMAND}
          -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
          -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
          -DBUILD_EXAMPLES=OFF ..
      COMMAND
        ${CMAKE_COMMAND} -E chdir build ${CMAKE_COMMAND} --build . --parallel
      INSTALL_COMMAND
        ${CMAKE_COMMAND} -E chdir build ${CMAKE_COMMAND} --build . --target install
    )

    # After build, define the variables RCCL expects
    set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INSTALL_DIR}/include" PARENT_SCOPE)
    set(ROCSHMEM_LIBRARY "${_rocshmem_archive}" PARENT_SCOPE)

    # Provide a dummy target other code can depend on
    add_custom_target(rocshmem_static ALL DEPENDS rocshmem_ext)
  else()
    # We found a prebuilt rocSHMEM; export variables upward as-is
    set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE)
    set(ROCSHMEM_LIBRARY "${ROCSHMEM_LIBRARY}" PARENT_SCOPE)
  endif()

  # libibverbs is required in both cases; look it up once and hand it to the
  # caller, failing the configure step with an actionable message otherwise.
  find_library(_IBVERBS ibverbs)
  if(NOT _IBVERBS)
    message(FATAL_ERROR "libibverbs not found (install rdma-core/libibverbs-dev)")
  endif()
  set(IBVERBS ${_IBVERBS} PARENT_SCOPE)

endfunction()
|
||||
子模块 projects/rccl/ext-src/rocSHMEM 已添加到 b28a56bd54
+12
-1
@@ -41,6 +41,7 @@ force_reduce_pipeline=false
|
||||
generate_sym_kernels=false
|
||||
warp_speed_enabled=true # note that this flag will be overridden to false for non MI350/MI300 platforms
|
||||
quiet_warnings=false
|
||||
build_rocshmem_support=false
|
||||
|
||||
# #################################################
|
||||
# helper functions
|
||||
@@ -82,6 +83,7 @@ function display_help()
|
||||
echo " --force-reduce-pipeline Force reduce_copy sw pipeline to be used for every reduce-based collectives and datatypes"
|
||||
echo " --generate-sym-kernels Generate symmetric memory kernels"
|
||||
echo " -q|--quiet-warnings Suppress majority of compiler warnings (not recommended)"
|
||||
echo " --rocshmem Build with rocSHMEM support"
|
||||
}
|
||||
|
||||
# #################################################
|
||||
@@ -91,7 +93,7 @@ function display_help()
|
||||
# check if we have a modern version of getopt that can handle whitespace and long parameters
|
||||
getopt -T
|
||||
if [[ "$?" -eq 4 ]]; then
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose -- "$@")
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose,rocshmem -- "$@")
|
||||
else
|
||||
echo "Need a new version of getopt"
|
||||
exit 1
|
||||
@@ -140,6 +142,7 @@ while true; do
|
||||
--generate-sym-kernels) generate_sym_kernels=true; shift ;;
|
||||
--disable-warp-speed) warp_speed_enabled=false; shift ;;
|
||||
-q | --quiet-warnings) quiet_warnings=true; shift ;;
|
||||
--rocshmem) build_rocshmem_support=true; shift ;;
|
||||
--) shift ; break ;;
|
||||
*) echo "Unexpected command line parameter received; aborting";
|
||||
exit 1
|
||||
@@ -329,6 +332,14 @@ if [[ "${quiet_warnings}" == true ]]; then
|
||||
fi
|
||||
|
||||
|
||||
# Enable rocSHMEM support
|
||||
if [[ "${build_rocshmem_support}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=ON"
|
||||
cmake_common_options="${cmake_common_options} -DROCSHMEM_INSTALL_DIR=${ROCSHMEM_INSTALL_DIR}"
|
||||
else
|
||||
cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=OFF"
|
||||
fi
|
||||
|
||||
check_exit_code "$?"
|
||||
|
||||
# Enable ninja build for time tracing
|
||||
|
||||
@@ -13,6 +13,10 @@
|
||||
#include "nvtx_payload_schemas.h"
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#endif
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
const char* ncclFuncToString(ncclFunc_t fn) {
|
||||
@@ -222,6 +226,8 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
|
||||
|
||||
size_t rankOffset = count * ncclTypeSize(datatype);
|
||||
size_t rankAlign = rankOffset & ((~rankOffset) + 1);
|
||||
size_t msgSize = count * ncclTypeSize(datatype) * comm->nRanks;
|
||||
|
||||
// Determine Pivot A2A support now that we know number of channels
|
||||
if (comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2 &&
|
||||
rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) {
|
||||
@@ -230,7 +236,17 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
|
||||
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
|
||||
return ncclEnqueueCheck(&info);
|
||||
} else {
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
if (rcclUseAllToAllGda(comm) && msgSize <= comm->rocshmemThreshold) {
|
||||
struct ncclInfo info = { ncclFuncAllToAllGda, "AllToAllGda",
|
||||
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream,
|
||||
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
|
||||
|
||||
return ncclEnqueueCheck(&info);
|
||||
}
|
||||
#endif
|
||||
int nRanks;
|
||||
//comm->isA2a = 0;
|
||||
NCCLCHECK(ncclCommCount(comm, &nRanks));
|
||||
if (count == 0) return ncclSuccess;
|
||||
if (!mscclIsCaller()) Recorder::instance().skip(true);
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "device.h"
|
||||
#include "collectives.h"
|
||||
#include "primitives.h"
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncAllToAllGda, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nThreads, struct ncclDevWorkColl* work) {
|
||||
if (blockIdx.x == 0) {
|
||||
int num_pes = rocshmem::rocshmem_n_pes();
|
||||
|
||||
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
|
||||
tid, nThreads, 0, nullptr, false, 1, (void **)&work->sendbuff, 1, (void **)&work->sndbuff,
|
||||
(work->size*num_pes));
|
||||
|
||||
rocshmem::rocshmem_char_alltoall_wg(work->team, ((char*)work->tempbuff), ((char*)work->sndbuff), work->size);
|
||||
|
||||
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
|
||||
tid, nThreads, 0, nullptr, false, 1, (void **)&work->tempbuff, 1, (void **)&work->recvbuff,
|
||||
(work->size*num_pes));
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -5,7 +5,7 @@ import subprocess
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Order of colls, redops, tys, protos, algos must match src/include/device.h
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"]
|
||||
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot", "AllToAllGda"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
@@ -79,7 +79,7 @@ func_pattern = sys.argv[6:7]
|
||||
if func_pattern and func_pattern[0]:
|
||||
func_pattern = func_pattern[0]
|
||||
else:
|
||||
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
func_pattern = "AllGather|AllReduce|AllToAllPivot|AllToAllGda|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
|
||||
################################################################################
|
||||
|
||||
@@ -87,6 +87,7 @@ algos_of_coll = {
|
||||
"AllGather": ["RING", "PAT"],
|
||||
"AllReduce": ["RING", "TREE"],
|
||||
"AllToAllPivot": ["RING"],
|
||||
"AllToAllGda": ["RING"],
|
||||
"Broadcast": ["RING"],
|
||||
"Reduce": ["RING"],
|
||||
"ReduceScatter": ["RING", "PAT"],
|
||||
@@ -97,6 +98,7 @@ protos_of_coll = {
|
||||
"AllGather": all_protos,
|
||||
"AllReduce": all_protos,
|
||||
"AllToAllPivot": ["SIMPLE"],
|
||||
"AllToAllGda": ["SIMPLE"],
|
||||
"Broadcast": all_protos,
|
||||
"Reduce": all_protos,
|
||||
"ReduceScatter": all_protos,
|
||||
@@ -107,6 +109,7 @@ redops_of_coll = {
|
||||
"AllGather": ["Sum"],
|
||||
"AllReduce": all_redops,
|
||||
"AllToAllPivot": ["Sum"],
|
||||
"AllToAllGda": ["Sum"],
|
||||
"Broadcast": ["Sum"],
|
||||
"Reduce": all_redops,
|
||||
"ReduceScatter": all_redops,
|
||||
@@ -117,6 +120,7 @@ tys_of_coll = {
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllToAllPivot": ["i8"],
|
||||
"AllToAllGda": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
"ReduceScatter": all_tys,
|
||||
@@ -127,6 +131,7 @@ acc_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": all_accs,
|
||||
"AllToAllPivot": ["0"],
|
||||
"AllToAllGda": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": ["0"],
|
||||
"ReduceScatter": ["0"],
|
||||
@@ -137,6 +142,7 @@ pipelines_of_coll = {
|
||||
"AllGather": ["0"],
|
||||
"AllReduce": all_pipelines,
|
||||
"AllToAllPivot": ["0"],
|
||||
"AllToAllGda": ["0"],
|
||||
"Broadcast": ["0"],
|
||||
"Reduce": all_pipelines,
|
||||
"ReduceScatter": all_pipelines,
|
||||
@@ -148,6 +154,7 @@ coll_camel_to_lower = {
|
||||
"AllGather": "all_gather",
|
||||
"AllReduce": "all_reduce",
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"AllToAllGda": "alltoall_gda",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
"ReduceScatter": "reduce_scatter",
|
||||
@@ -503,7 +510,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
)
|
||||
if fn.coll == "Broadcast":
|
||||
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
|
||||
if fn.coll in ["SendRecv", "AllToAllPivot"]:
|
||||
if fn.coll in ["SendRecv", "AllToAllPivot", "AllToAllGda"]:
|
||||
key = ((coll_idx & 0x3F))
|
||||
|
||||
out(f' {{{key}, {fn_id}}}, {comment}\n')
|
||||
|
||||
+35
-8
@@ -11,6 +11,7 @@
|
||||
#include "coll_net.h"
|
||||
#include "graph/topo.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <hip/hip_ext.h>
|
||||
#include "gdrwrap.h"
|
||||
#include "bootstrap.h"
|
||||
@@ -29,6 +30,10 @@
|
||||
#include <cassert>
|
||||
#include "latency_profiler/CollTraceFunc.h"
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#endif
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
struct ncclKernelMatch {
|
||||
@@ -36,6 +41,7 @@ struct ncclKernelMatch {
|
||||
bool specialized;
|
||||
};
|
||||
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0))
|
||||
static ncclKernelMatch const ncclKerns[6] = {
|
||||
@@ -390,6 +396,19 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
|
||||
devWork.rcclUseOneSlice = comm->rcclUseOneSlice;
|
||||
//[Added-comment] opCount is missing for collDevWork, adding here
|
||||
devWork.opCount = task->opCount;
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
if (comm->enableRocshmem && task->func == ncclFuncAllToAllGda) {
|
||||
devWork.enableRocshmem = comm->enableRocshmem;
|
||||
devWork.team = comm->team_reduce_world_dup;
|
||||
|
||||
devWork.sndbuff = (void*)comm->sourceRshmem[comm->symId];
|
||||
devWork.tempbuff = (void*)comm->destRshmem[comm->symId];
|
||||
|
||||
comm->symId = (comm->symId + 1) % comm->numSymBuf;
|
||||
|
||||
devWork.size = task->count;
|
||||
}
|
||||
#endif
|
||||
|
||||
devWork.isOneRPN = comm->isOneRPN;
|
||||
devWork.netRegUsed = devWork.regUsed = 0;
|
||||
@@ -730,8 +749,10 @@ static ncclResult_t scheduleCollTasksToPlan(
|
||||
proxyOp.incWorkCounter = true;
|
||||
addWorkBatchToPlan(comm, plan, c, workNode->workType, task->devFuncId, plan->workBytes);
|
||||
// Set pattern to profiler to add a proxy profiler for kernel events
|
||||
NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp));
|
||||
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp));
|
||||
if (task->func != ncclFuncAllToAllGda) {
|
||||
NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp));
|
||||
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp));
|
||||
}
|
||||
}
|
||||
} else { // not task->isCollnet
|
||||
int trafficPerByte = ncclFuncTrafficPerByte(task->func, comm->nRanks);
|
||||
@@ -877,8 +898,10 @@ static ncclResult_t scheduleCollTasksToPlan(
|
||||
// Coverity reports "proxyOp->connection" as being possibly uninitialized. It's hard to
|
||||
// determine if that's actually true but it's also not clear if that would be an issue.
|
||||
// coverity[uninit_use_in_call:FALSE]
|
||||
NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp));
|
||||
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp));
|
||||
if (task->func != ncclFuncAllToAllGda) {
|
||||
NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp));
|
||||
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1508,8 +1531,8 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP
|
||||
NCCLCHECK(ncclProfilerStartGroupEvent(plan));
|
||||
NCCLCHECK(ncclProfilerStartTaskEvents(plan));
|
||||
if (ncclIntruQueueHead(&plan->proxyOpQueue)) {
|
||||
NCCLCHECK(uploadProxyOps(comm, plan));
|
||||
NCCLCHECK(ncclProxyStart(comm));
|
||||
NCCLCHECK(uploadProxyOps(comm, plan));
|
||||
NCCLCHECK(ncclProxyStart(comm));
|
||||
}
|
||||
NCCLCHECK(ncclProfilerStopTaskEvents(plan));
|
||||
NCCLCHECK(ncclProfilerStopGroupEvent(plan));
|
||||
@@ -1788,6 +1811,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan
|
||||
latency_profiler::collTraceRecordStartEvent(comm, launchStream, event.get());
|
||||
comm->lastStream = planner->streams->stream;
|
||||
CUDACHECKGOTO(hipExtLaunchKernel(plan->kernelFn, grid, block, extra, 0, launchStream, NULL, comm->doneEvent, 0), ret, do_return);
|
||||
|
||||
latency_profiler::collTraceRecordEndEvent(comm, plan, launchStream, std::move(event));
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -2023,7 +2047,7 @@ static ncclResult_t updateCollCostTable(
|
||||
float** collCostTable) {
|
||||
float (*table)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable;
|
||||
|
||||
if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot) {
|
||||
if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot || info->func == ncclFuncAllToAllGda) {
|
||||
table[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] = 0.0;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -2327,6 +2351,9 @@ static ncclResult_t calcCollChunking(
|
||||
case ncclFuncAllToAllPivot:
|
||||
pattern = ncclPatternRing;
|
||||
break;
|
||||
case ncclFuncAllToAllGda:
|
||||
pattern = ncclPatternRing;
|
||||
break;
|
||||
case ncclFuncAllReduce:
|
||||
pattern =
|
||||
info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
|
||||
@@ -2749,7 +2776,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {
|
||||
t->root = info->root;
|
||||
t->datatype = info->datatype;
|
||||
size_t elementSize = ncclTypeSize(t->datatype);
|
||||
if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot) {
|
||||
if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot || t->func == ncclFuncAllToAllGda) {
|
||||
t->count *= elementSize;
|
||||
t->datatype = ncclInt8;
|
||||
elementSize = 1;
|
||||
|
||||
@@ -24,6 +24,10 @@
|
||||
#include "rccl_common.h"
|
||||
#include "recorder.h"
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#endif
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
|
||||
#define HIPRT_CB
|
||||
#else
|
||||
@@ -725,6 +729,17 @@ struct ncclComm {
|
||||
// multiProcessorCount from hipDeviceProp_t [RCCL]
|
||||
int cuCount;
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
// circular ring buffer in rocshmem symmetric heap
|
||||
void** sourceRshmem;
|
||||
void** destRshmem;
|
||||
rocshmem::rocshmem_team_t team_reduce_world_dup;
|
||||
int enableRocshmem;
|
||||
int rocshmemThreshold;
|
||||
int numSymBuf;
|
||||
int symId;
|
||||
#endif
|
||||
|
||||
uint64_t endMagic;
|
||||
};
|
||||
|
||||
|
||||
@@ -38,7 +38,11 @@
|
||||
#include <string>
|
||||
#include "debug.h"
|
||||
|
||||
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2];
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#endif
|
||||
|
||||
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3];
|
||||
|
||||
extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
|
||||
|
||||
@@ -397,6 +401,14 @@ struct alignas(16) ncclDevWorkColl {
|
||||
};
|
||||
uint64_t redOpArg;
|
||||
uint64_t opCount;
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
rocshmem::rocshmem_team_t team;
|
||||
int enableRocshmem;
|
||||
void* tempbuff;
|
||||
void* sndbuff;
|
||||
int size;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -784,7 +796,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto,
|
||||
if (coll == ncclFuncBroadcast) {
|
||||
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) |
|
||||
((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT);
|
||||
} else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot) {
|
||||
} else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot || coll == ncclFuncAllToAllGda) {
|
||||
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT );
|
||||
} else {
|
||||
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) |
|
||||
|
||||
@@ -64,7 +64,8 @@ typedef enum {
|
||||
ncclFuncSend = 6,
|
||||
ncclFuncRecv = 7,
|
||||
ncclFuncAllToAllPivot = 8,
|
||||
ncclNumFuncs = 9
|
||||
ncclFuncAllToAllGda = 9,
|
||||
ncclNumFuncs = 10
|
||||
} ncclFunc_t;
|
||||
|
||||
#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*/PAT
|
||||
|
||||
@@ -112,6 +112,7 @@ NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll,
|
||||
NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName);
|
||||
NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName);
|
||||
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize);
|
||||
bool rcclUseAllToAllGda(struct ncclComm* comm);
|
||||
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
|
||||
void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
|
||||
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
|
||||
|
||||
+88
-1
@@ -56,6 +56,12 @@
|
||||
#include "rccl_common.h"
|
||||
// [/RCCL]
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
#define NUM_SYM_BUF 8
|
||||
#endif
|
||||
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
#include "msccl/msccl_status.h"
|
||||
#include "latency_profiler/CollTrace.h"
|
||||
@@ -78,7 +84,7 @@
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3] = { "AllGather", "AllReduce", "AllToAllPivot", "AllToAllGda", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
|
||||
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" };
|
||||
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
|
||||
const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "MinMax", "PreMulSum", "SumPostDiv" };
|
||||
@@ -97,6 +103,13 @@ NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", NCCL_CONFIG_UNDEF_INT);
|
||||
struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {};
|
||||
ncclResult_t commReclaim(ncclComm_t comm);
|
||||
|
||||
|
||||
// File-scope state for the optional rocSHMEM integration; compiled only when
// ENABLE_ROCSHMEM is defined.
#ifdef ENABLE_ROCSHMEM
|
||||
// Tunable threshold with a default of 262144. Its value is copied into
// comm->rocshmemThreshold during communicator initialization.
// NOTE(review): presumably a message size in bytes — confirm against the caller
// that compares against it.
RCCL_PARAM(RocshmemThreshold, "ROCSHMEM_THRESHOLD", (size_t)(262144));
|
||||
// On/off switch (default 1) that gates rocSHMEM initialization for a communicator.
// NOTE(review): if RCCL_PARAM prefixes the name with "RCCL_" (as NCCL_PARAM does
// with "NCCL_"), the variable to set is RCCL_ROCSHMEM_ENABLE rather than
// ROCSHMEM_ENABLE — confirm against the macro definition.
RCCL_PARAM(RocshmemEnabled, "ROCSHMEM_ENABLE", 1);
|
||||
// Maps each communicator to the rocSHMEM team created for it. The destroy path
// removes the entry and finalizes rocSHMEM once the map is empty.
// NOTE(review): no lock around accesses is visible in this view — confirm
// thread-safety if communicators can be created/destroyed concurrently.
std::unordered_map<ncclComm_t, rocshmem::rocshmem_team_t> ncclCommToRshmemTeam;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
|
||||
return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES);
|
||||
@@ -2107,6 +2120,58 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
if (rcclParamRocshmemEnabled()) { // @TODO - This doesn't seem to disable when I set ROCSHMEM_ENABLE=0 on command line
|
||||
INFO(NCCL_INIT,"Initializing rocSHMEM inside of RCCL");
|
||||
int ret;
|
||||
rocshmem::rocshmem_uniqueid_t rocshmemUniqueId;
|
||||
rocshmem::rocshmem_init_attr_t rocshmemAttr;
|
||||
|
||||
if(comm->rank == 0 ) {
|
||||
ret = rocshmem::rocshmem_get_uniqueid (&rocshmemUniqueId);
|
||||
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
|
||||
ERROR("Error in rocshmem_get_uniqueid, Rocshmem cannot be initialized.");
|
||||
return ncclSystemError;
|
||||
}
|
||||
}
|
||||
|
||||
NCCLCHECKGOTO(bootstrapBroadcast(comm->bootstrap, comm->rank, comm->nRanks, 0, &rocshmemUniqueId,
|
||||
sizeof(rocshmemUniqueId)), res, fail);
|
||||
ret = rocshmem::rocshmem_set_attr_uniqueid_args(job->myrank, job->nranks, &rocshmemUniqueId, &rocshmemAttr);
|
||||
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
|
||||
ERROR("Error in rocshmem_set_attr_uniqueid_args, Rocshmem cannot be initialized.");
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
ret = rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &rocshmemAttr);
|
||||
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
|
||||
ERROR("Error in rocshmem_init_attr, Rocshmem cannot be initialized.");
|
||||
return ncclSystemError;
|
||||
}
|
||||
|
||||
comm->sourceRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *));
|
||||
comm->destRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *));
|
||||
|
||||
for (int i = 0; i < NUM_SYM_BUF; i++) {
|
||||
comm->sourceRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024));
|
||||
comm->destRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024));
|
||||
}
|
||||
|
||||
comm->enableRocshmem = rcclParamRocshmemEnabled();
|
||||
comm->rocshmemThreshold = rcclParamRocshmemThreshold();
|
||||
comm->numSymBuf = NUM_SYM_BUF;
|
||||
comm->symId = 0;
|
||||
//rocshmem::rocshmem_team_t team_reduce_world_dup;
|
||||
comm->team_reduce_world_dup = rocshmem::ROCSHMEM_TEAM_INVALID;
|
||||
rocshmem::rocshmem_team_split_strided(rocshmem::ROCSHMEM_TEAM_WORLD, 0, 1, job->nranks, nullptr, 0,
|
||||
&(comm->team_reduce_world_dup));
|
||||
|
||||
ncclCommToRshmemTeam[comm] = comm->team_reduce_world_dup;
|
||||
CUDACHECK(hipDeviceSynchronize());
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (job->parent) {
|
||||
if (job->parent->mscclppCompatible) {
|
||||
@@ -2935,6 +3000,28 @@ ncclResult_t ncclCommDestroy_impl(ncclComm_t comm) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ROCSHMEM
|
||||
if (comm->enableRocshmem) {
|
||||
for (int i = 0; i < NUM_SYM_BUF; i++) {
|
||||
rocshmem::rocshmem_free(comm->sourceRshmem[i]);
|
||||
rocshmem::rocshmem_free(comm->destRshmem[i]);
|
||||
}
|
||||
free(comm->sourceRshmem);
|
||||
free(comm->destRshmem);
|
||||
|
||||
//TODO: subcomm check
|
||||
rocshmem::rocshmem_team_t team;
|
||||
if (!ncclCommToRshmemTeam.empty()) {
|
||||
team = ncclCommToRshmemTeam[comm];
|
||||
rocshmem::rocshmem_team_destroy(team);
|
||||
ncclCommToRshmemTeam.erase(comm);
|
||||
}
|
||||
if (ncclCommToRshmemTeam.empty()) {
|
||||
rocshmem::rocshmem_finalize();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev;
|
||||
struct ncclCommFinalizeAsyncJob *job = NULL;
|
||||
ncclResult_t res = ncclSuccess;
|
||||
|
||||
@@ -386,6 +386,18 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
/**
 * Reports whether the rocSHMEM-backed (GDA) all-to-all path may be used for `comm`.
 *
 * When built with ENABLE_ROCSHMEM, returns true only if every condition holds:
 *   - rocSHMEM is enabled on the communicator,
 *   - the first GPU node in the topology matches the "gfx942" architecture,
 *   - the job spans more than one node,
 *   - the integer quotient nRanks / nNodes equals 8,
 *   - the communicator's rocSHMEM threshold does not exceed 1048576.
 * Without ENABLE_ROCSHMEM the function always returns false.
 *
 * @param comm communicator to inspect; dereferenced only in ENABLE_ROCSHMEM builds.
 * @return true when the GDA all-to-all path is eligible, false otherwise.
 */
bool rcclUseAllToAllGda(struct ncclComm* comm) {
#ifdef ENABLE_ROCSHMEM
  //TODO: enable on MI350; currently tested on MI300X
  // Guards are evaluated in the same order as a short-circuit chain, so the
  // division below is reached only after nNodes > 1 has been established.
  if (!comm->enableRocshmem) {
    return false;
  }
  if (!IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) {
    return false;
  }
  if (!(comm->nNodes > 1)) {
    return false;
  }
  if (!(comm->nRanks/comm->nNodes == 8)) {
    return false;
  }
  if (!(comm->rocshmemThreshold <= 1048576)) {
    return false;
  }

  INFO(NCCL_INIT, "Enabling GDA alltoall for RCCL");
  return true;
#else
  return false;
#endif
}
|
||||
|
||||
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
|
||||
// Check if user explicitly disabled direct AllGather
|
||||
static int userDirectAllGatherInput = -2;
|
||||
|
||||
在新工单中引用
屏蔽一个用户