GDA support for alltoall via rocshmem integration (#2099)

* ROCSHMEM linking/building to match MSCCL++ style

* add rocSHMEM as a submodule

* Move rocSHMEM submodule to ext-src/rocSHMEM

* Adding submodule support proper, as well as a patch for rocshmem

* Cleaning up INCLUDE_DIR vs INCLUDE_DIRS mixup

* updating patch file

* Pointing rocshmem submodule to edgars fixup patch

* Adding IBVERBS link to the submodule build

* More IBVERBS patching

* pin rocshmem submodule to b534423

* Adding IPC support in rocSHMEM build

* updating rocshmem submodule to resolve CQ errors

* Updating submodule to include recent a2a optimizations

* invoke rocshmem alltoall from rccl

* Updating submodule to CQ error number hang

* Updating submodule to include a2a improvements and bug fixes

* Updating submodule to point to Yiltan's fork and doorbell ring removal commit

* Updating hash to correspond with submodule change

* Updating to no-ctx wg call and updating submodule

* copy-in/copy-out using multiple CUs

* Updating rocSHMEM submodule to include doorbell improvements

* updating gitmodule to point to upstream

* code cleanup and adjust threshold

* guard rocshmem a2a invocation

* Only build with rocshmem when specified

* code cleanup

* address review comments

* Removing debugging failure case

Signed-off-by: Thomas Huber <thomas.huber@amd.com>

* whitespace fix

* Adding rocshmem compile guard

* Removing unnecessary comment

Signed-off-by: Thomas Huber <thomas.huber@amd.com>

* remove commented lines

* address review comments

* cleanup

---------

Signed-off-by: Thomas Huber <thomas.huber@amd.com>
Co-authored-by: Thomas Huber <thomas.huber@amd.com>
Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k12-27.cs-aus.dcgpu>
Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k13-09.cs-aus.dcgpu>
Co-authored-by: Islam <nusislam@amd.com>
Co-authored-by: Nusrat Islam <nusislam@dell300x-ccs-aus-k13-03.cs-aus.dcgpu>

[ROCm/rccl commit: 27648b0900]
This commit is contained in:
Nusrat Islam
2026-01-09 14:04:54 -06:00
کامیت شده توسط GitHub
والد 87eec6427e
کامیت eb347a0dd3
16فایلهای تغییر یافته به همراه427 افزوده شده و 16 حذف شده
@@ -8,3 +8,7 @@
url = https://github.com/nlohmann/json.git
ignore = dirty
shallow = true
[submodule "ext-src/rocSHMEM"]
path = ext-src/rocSHMEM
url = https://github.com/ROCm/rocSHMEM.git
branch = develop
@@ -42,6 +42,7 @@ option(TIMETRACE "Enable time-trace during compila
option(TRACE "Enable additional tracing" OFF)
option(FAULT_INJECTION "Enable fault injection" ON)
option(QUIET_WARNINGS "Supress compiler warnings" OFF)
option(ENABLE_ROCSHMEM "Enable rocSHMEM support in RCCL" OFF)
# Default GPU architectures to build
#==================================================================================================
@@ -65,6 +66,11 @@ include(CheckSymbolExists)
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
include(cmake/CheckSymbolExistsNoWarn.cmake)
# Include rocSHMEM build module only if enabled
if(ENABLE_ROCSHMEM)
include(cmake/ROCSHMEM.cmake)
endif()
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
# Build only for local GPU architecture
@@ -321,6 +327,8 @@ if(BUILD_BFD)
endif()
endif()
# Check for --amdgpu-kernarg-preload-count
check_cxx_compiler_flag("-mllvm --amdgpu-kernarg-preload-count=16" HAVE_KERNARG_PRELOAD)
if (HAVE_KERNARG_PRELOAD)
@@ -472,6 +480,7 @@ set(SRC_FILES
src/device/all_gather.h
src/device/all_reduce.h
src/device/alltoall_pivot.h
src/device/alltoall_gda.h
src/device/broadcast.h
src/device/common.h
src/device/common_kernel.h
@@ -884,6 +893,7 @@ if(ROCTX_ENABLE)
target_include_directories(rccl PRIVATE ${ROCTRACER_INCLUDE_DIR})
endif()
## Set RCCL compile definitions
if(COLLTRACE)
target_compile_definitions(rccl PRIVATE ENABLE_COLLTRACE)
@@ -903,6 +913,28 @@ endif()
if(ENABLE_WARP_SPEED)
target_compile_definitions(rccl PRIVATE ENABLE_WARP_SPEED)
endif()
if(ENABLE_ROCSHMEM)
target_compile_definitions(rccl PRIVATE ENABLE_ROCSHMEM)
endif()
# ==== rocSHMEM integration (optional) ====
if (ENABLE_ROCSHMEM)
add_rocshmem_targets()
# Ensure rocSHMEM is fully built/installed before compiling rccl
if (TARGET rocshmem_ext)
add_dependencies(rccl rocshmem_ext)
endif()
if (ROCSHMEM_INCLUDE_DIR)
target_include_directories(rccl PRIVATE ${ROCSHMEM_INCLUDE_DIR})
endif()
# Moved to where MSCCL target_links
## target_link_libraries(rccl PRIVATE ${ROCSHMEM_LIBRARY})
target_link_libraries(rccl PRIVATE ${IBVERBS})
endif()
# NPKit flags
## May be better to move these to a separate file
@@ -1234,6 +1266,10 @@ target_link_libraries(rccl PRIVATE fmt::fmt-header-only)
if(ENABLE_MSCCLPP)
target_link_libraries(rccl PRIVATE mscclpp_nccl)
endif()
if(ENABLE_ROCSHMEM)
target_link_libraries(rccl PRIVATE ${ROCSHMEM_LIBRARY})
target_link_libraries(rccl PRIVATE ${IBVERBS})
endif()
## Set RCCL link options
## Find out available memory
@@ -0,0 +1,35 @@
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Find module for a pre-installed rocSHMEM.
#
# Input:
#   ROCSHMEM_INSTALL_DIR  - install prefix used as a search hint for headers and library
#
# Output:
#   rocshmem_static_FOUND - set by find_package_handle_standard_args when both
#                           ROCSHMEM_INCLUDE_DIR and ROCSHMEM_LIBRARY were located
#   ROCSHMEM_INCLUDE_DIR  - directory containing rocshmem/rocshmem.hpp (or rocshmem.h)
#   ROCSHMEM_LIBRARY      - path to the rocshmem library
#   IBVERBS               - path to libibverbs, if found (not part of the FOUND check)

# find_package_handle_standard_args is only defined after this include; do not
# rely on some other module having pulled it in first.
include(FindPackageHandleStandardArgs)

find_path(ROCSHMEM_INCLUDE_DIR
  NAMES rocshmem/rocshmem.hpp rocshmem/rocshmem.h
  HINTS ${ROCSHMEM_INSTALL_DIR}/include/)

find_library(ROCSHMEM_LIBRARY
  NAMES rocshmem
  HINTS ${ROCSHMEM_INSTALL_DIR}/lib)

# libibverbs is looked up here for convenience but is intentionally left out of
# the required-variable list below.
find_library(IBVERBS ibverbs)

find_package_handle_standard_args(rocshmem_static DEFAULT_MSG ROCSHMEM_INCLUDE_DIR ROCSHMEM_LIBRARY)

# Keep the resolved paths out of the default cache view in cmake-gui/ccmake.
mark_as_advanced(ROCSHMEM_INCLUDE_DIR ROCSHMEM_LIBRARY IBVERBS)
@@ -0,0 +1,113 @@
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
include(ExternalProject)
# add_rocshmem_targets()
#
# Makes rocSHMEM available to the RCCL build, either from an existing
# installation or by building the ext-src/rocSHMEM submodule.
#
# Input:
#   ROCSHMEM_INSTALL_DIR - optional prefix of an existing rocSHMEM installation
#
# Variables exported to the caller (PARENT_SCOPE):
#   ROCSHMEM_INCLUDE_DIR - rocSHMEM header directory
#   ROCSHMEM_LIBRARY     - rocSHMEM library to link
#   IBVERBS              - path to libibverbs
#
# Targets created only when building from the submodule:
#   rocshmem_checkout_submodule, rocshmem_ext, rocshmem_static
#
# Stops configuration with FATAL_ERROR when libibverbs cannot be found.
function(add_rocshmem_targets)
# Check for an existing installation via the user-provided prefix ROCSHMEM_INSTALL_DIR
if(ROCSHMEM_INSTALL_DIR)
# Make cmake/Findrocshmem_static.cmake visible to find_package below
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
find_package(rocshmem_static)
# Early, non-fatal libibverbs lookup; both branches further down repeat the
# lookup and treat a missing library as fatal.
if(NOT IBVERBS)
find_library(IBVERBS ibverbs)
if(IBVERBS)
set(IBVERBS ${IBVERBS} PARENT_SCOPE)
endif()
endif()
endif()
# If no pre-existing installation was found, build from the submodule and install into ext/rocshmem
if(NOT rocshmem_static_FOUND)
set(_rccl_root "${CMAKE_SOURCE_DIR}")
set(ROCSHMEM_SOURCE "${_rccl_root}/ext-src/rocSHMEM")
# Note: this overrides ROCSHMEM_INSTALL_DIR inside the function scope only
set(ROCSHMEM_INSTALL_DIR "${_rccl_root}/ext/rocshmem")
# Make sure submodule exists (same style as MSCCL++: custom rule + target)
add_custom_command(
OUTPUT "${ROCSHMEM_SOURCE}/CMakeLists.txt"
COMMAND git submodule update --init --recursive ext-src/rocSHMEM
WORKING_DIRECTORY "${_rccl_root}"
COMMENT "Checking out submodule: ext-src/rocSHMEM"
VERBATIM
)
add_custom_target(rocshmem_checkout_submodule
DEPENDS "${ROCSHMEM_SOURCE}/CMakeLists.txt")
# Where our patch files live (like MSCCL++)
set(EXT_SOURCE "${_rccl_root}/ext-src")
# Build and install rocSHMEM. We run `../scripts/build_configs/gda_bnxt`
# from a 'build' dir inside the submodule source tree, then re-run cmake there
# to set the build type and install prefix before building and installing.
ExternalProject_Add(rocshmem_ext
SOURCE_DIR "${ROCSHMEM_SOURCE}"
INSTALL_DIR "${ROCSHMEM_INSTALL_DIR}"
UPDATE_DISCONNECTED TRUE
LOG_DOWNLOAD FALSE
LOG_CONFIGURE FALSE
LOG_BUILD FALSE
LOG_INSTALL FALSE
BUILD_IN_SOURCE TRUE
DOWNLOAD_COMMAND "" # using the submodule checkout above
TEST_COMMAND ""
DEPENDS rocshmem_checkout_submodule
# Pinned rocSHMEM submodule commit: b28a56bd54 (abbreviated hash)
# The project has its own scripts; we replicate the README sequence:
CONFIGURE_COMMAND ""
BUILD_COMMAND
${CMAKE_COMMAND} -E make_directory build
&& ${CMAKE_COMMAND} -E chdir build bash -lc "../scripts/build_configs/gda_bnxt -DUSE_EXTERNAL_MPI=OFF -DUSE_IPC=ON -DBUILD_EXAMPLES=OFF "
&& ${CMAKE_COMMAND} -E chdir build ${CMAKE_COMMAND}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DBUILD_EXAMPLES=OFF ..
&& ${CMAKE_COMMAND} -E chdir build ${CMAKE_MAKE_PROGRAM} -j
INSTALL_COMMAND
${CMAKE_COMMAND} -E chdir build ${CMAKE_MAKE_PROGRAM} install
)
# Export the paths RCCL expects. They are set at configure time and refer to
# files under the install dir that rocshmem_ext is expected to produce.
set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INSTALL_DIR}/include" PARENT_SCOPE)
set(ROCSHMEM_LIBRARY "${ROCSHMEM_INSTALL_DIR}/lib/librocshmem.a" PARENT_SCOPE)
# libibverbs is mandatory on this path
find_library(_IBVERBS ibverbs)
if(NOT _IBVERBS)
message(FATAL_ERROR "libibverbs not found (install rdma-core/libibverbs-dev)")
endif()
set(IBVERBS ${_IBVERBS} PARENT_SCOPE)
# Provide a dummy target other code can depend on
add_custom_target(rocshmem_static ALL DEPENDS rocshmem_ext)
else()
# We found a prebuilt rocSHMEM; export variables upward as-is
set(ROCSHMEM_INCLUDE_DIR "${ROCSHMEM_INCLUDE_DIR}" PARENT_SCOPE)
set(ROCSHMEM_LIBRARY "${ROCSHMEM_LIBRARY}" PARENT_SCOPE)
# libibverbs is mandatory on this path as well
find_library(_IBVERBS ibverbs)
if(NOT _IBVERBS)
message(FATAL_ERROR "libibverbs not found")
endif()
set(IBVERBS ${_IBVERBS} PARENT_SCOPE)
endif()
endfunction()
Submodule projects/rccl/ext-src/rocSHMEM added at b28a56bd54
+12 -1
مشاهده پرونده
@@ -41,6 +41,7 @@ force_reduce_pipeline=false
generate_sym_kernels=false
warp_speed_enabled=true # note that this flag will be overridden to false for non MI350/MI300 platforms
quiet_warnings=false
build_rocshmem_support=false
# #################################################
# helper functions
@@ -82,6 +83,7 @@ function display_help()
echo " --force-reduce-pipeline Force reduce_copy sw pipeline to be used for every reduce-based collectives and datatypes"
echo " --generate-sym-kernels Generate symmetric memory kernels"
echo " -q|--quiet-warnings Suppress majority of compiler warnings (not recommended)"
echo " --rocshmem Build with rocSHMEM support"
}
# #################################################
@@ -91,7 +93,7 @@ function display_help()
# check if we have a modern version of getopt that can handle whitespace and long parameters
getopt -T
if [[ "$?" -eq 4 ]]; then
GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose -- "$@")
GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose,rocshmem -- "$@")
else
echo "Need a new version of getopt"
exit 1
@@ -140,6 +142,7 @@ while true; do
--generate-sym-kernels) generate_sym_kernels=true; shift ;;
--disable-warp-speed) warp_speed_enabled=false; shift ;;
-q | --quiet-warnings) quiet_warnings=true; shift ;;
--rocshmem) build_rocshmem_support=true; shift ;;
--) shift ; break ;;
*) echo "Unexpected command line parameter received; aborting";
exit 1
@@ -329,6 +332,14 @@ if [[ "${quiet_warnings}" == true ]]; then
fi
# Enable rocSHMEM support
if [[ "${build_rocshmem_support}" == true ]]; then
cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=ON"
cmake_common_options="${cmake_common_options} -DROCSHMEM_INSTALL_DIR=${ROCSHMEM_INSTALL_DIR}"
else
cmake_common_options="${cmake_common_options} -DENABLE_ROCSHMEM=OFF"
fi
check_exit_code "$?"
# Enable ninja build for time tracing
@@ -13,6 +13,10 @@
#include "nvtx_payload_schemas.h"
#include "msccl/msccl_lifecycle.h"
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
using namespace rccl;
const char* ncclFuncToString(ncclFunc_t fn) {
@@ -222,6 +226,8 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
size_t rankOffset = count * ncclTypeSize(datatype);
size_t rankAlign = rankOffset & ((~rankOffset) + 1);
size_t msgSize = count * ncclTypeSize(datatype) * comm->nRanks;
// Determine Pivot A2A support now that we know number of channels
if (comm->topo->pivotA2AEnabled && comm->nChannels >= comm->topo->pivotA2ANumBiRings * 2 &&
rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) {
@@ -230,7 +236,17 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
return ncclEnqueueCheck(&info);
} else {
#ifdef ENABLE_ROCSHMEM
if (rcclUseAllToAllGda(comm) && msgSize <= comm->rocshmemThreshold) {
struct ncclInfo info = { ncclFuncAllToAllGda, "AllToAllGda",
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream,
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
return ncclEnqueueCheck(&info);
}
#endif
int nRanks;
//comm->isA2a = 0;
NCCLCHECK(ncclCommCount(comm, &nRanks));
if (count == 0) return ncclSuccess;
if (!mscclIsCaller()) Recorder::instance().skip(true);
@@ -0,0 +1,33 @@
/*************************************************************************
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "device.h"
#include "collectives.h"
#include "primitives.h"
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
// Device-side entry point for the rocSHMEM-backed (GDA) all-to-all.
//
// Only thread block 0 does any work. It performs three steps in order:
//   1. reduceCopy (one source, one destination) from work->sendbuff into work->sndbuff;
//   2. a work-group rocSHMEM char all-to-all from work->sndbuff into work->tempbuff,
//      passing work->size as the element count;
//   3. reduceCopy (one source, one destination) from work->tempbuff into work->recvbuff.
// Both copies cover work->size * rocshmem_n_pes() bytes.
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncAllToAllGda, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(int tid, int nThreads, struct ncclDevWorkColl* work) {
    // Every block other than the first has nothing to do.
    if (blockIdx.x != 0) {
      return;
    }

    int peCount = rocshmem::rocshmem_n_pes();
    // `auto` keeps the exact arithmetic type that was previously passed inline.
    const auto totalBytes = (work->size*peCount);

    // Copy-in: user send buffer -> staging send buffer.
    void** copyInSrc = (void **)&work->sendbuff;
    void** copyInDst = (void **)&work->sndbuff;
    reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
      tid, nThreads, 0, nullptr, false, 1, copyInSrc, 1, copyInDst, totalBytes);

    // Exchange between PEs: staging send buffer -> staging receive buffer.
    char* exchangeDst = ((char*)work->tempbuff);
    char* exchangeSrc = ((char*)work->sndbuff);
    rocshmem::rocshmem_char_alltoall_wg(work->team, exchangeDst, exchangeSrc, work->size);

    // Copy-out: staging receive buffer -> user receive buffer.
    void** copyOutSrc = (void **)&work->tempbuff;
    void** copyOutDst = (void **)&work->recvbuff;
    reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
      tid, nThreads, 0, nullptr, false, 1, copyOutSrc, 1, copyOutDst, totalBytes);
  }
};
#endif
@@ -5,7 +5,7 @@ import subprocess
from dataclasses import dataclass
# Order of colls, redops, tys, protos, algos must match src/include/device.h
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot"]
all_colls = ["Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "", "", "AllToAllPivot", "AllToAllGda"]
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
all_protos = ["LL","LL128","SIMPLE"]
@@ -79,7 +79,7 @@ func_pattern = sys.argv[6:7]
if func_pattern and func_pattern[0]:
func_pattern = func_pattern[0]
else:
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
func_pattern = "AllGather|AllReduce|AllToAllPivot|AllToAllGda|Broadcast|Reduce|ReduceScatter|SendRecv"
################################################################################
@@ -87,6 +87,7 @@ algos_of_coll = {
"AllGather": ["RING", "PAT"],
"AllReduce": ["RING", "TREE"],
"AllToAllPivot": ["RING"],
"AllToAllGda": ["RING"],
"Broadcast": ["RING"],
"Reduce": ["RING"],
"ReduceScatter": ["RING", "PAT"],
@@ -97,6 +98,7 @@ protos_of_coll = {
"AllGather": all_protos,
"AllReduce": all_protos,
"AllToAllPivot": ["SIMPLE"],
"AllToAllGda": ["SIMPLE"],
"Broadcast": all_protos,
"Reduce": all_protos,
"ReduceScatter": all_protos,
@@ -107,6 +109,7 @@ redops_of_coll = {
"AllGather": ["Sum"],
"AllReduce": all_redops,
"AllToAllPivot": ["Sum"],
"AllToAllGda": ["Sum"],
"Broadcast": ["Sum"],
"Reduce": all_redops,
"ReduceScatter": all_redops,
@@ -117,6 +120,7 @@ tys_of_coll = {
"AllGather": ["i8"],
"AllReduce": all_tys,
"AllToAllPivot": ["i8"],
"AllToAllGda": ["i8"],
"Broadcast": ["i8"],
"Reduce": all_tys,
"ReduceScatter": all_tys,
@@ -127,6 +131,7 @@ acc_of_coll = {
"AllGather": ["0"],
"AllReduce": all_accs,
"AllToAllPivot": ["0"],
"AllToAllGda": ["0"],
"Broadcast": ["0"],
"Reduce": ["0"],
"ReduceScatter": ["0"],
@@ -137,6 +142,7 @@ pipelines_of_coll = {
"AllGather": ["0"],
"AllReduce": all_pipelines,
"AllToAllPivot": ["0"],
"AllToAllGda": ["0"],
"Broadcast": ["0"],
"Reduce": all_pipelines,
"ReduceScatter": all_pipelines,
@@ -148,6 +154,7 @@ coll_camel_to_lower = {
"AllGather": "all_gather",
"AllReduce": "all_reduce",
"AllToAllPivot": "alltoall_pivot",
"AllToAllGda": "alltoall_gda",
"Broadcast": "broadcast",
"Reduce": "reduce",
"ReduceScatter": "reduce_scatter",
@@ -503,7 +510,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
)
if fn.coll == "Broadcast":
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
if fn.coll in ["SendRecv", "AllToAllPivot"]:
if fn.coll in ["SendRecv", "AllToAllPivot", "AllToAllGda"]:
key = ((coll_idx & 0x3F))
out(f' {{{key}, {fn_id}}}, {comment}\n')
+35 -8
مشاهده پرونده
@@ -11,6 +11,7 @@
#include "coll_net.h"
#include "graph/topo.h"
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>
#include <hip/hip_ext.h>
#include "gdrwrap.h"
#include "bootstrap.h"
@@ -29,6 +30,10 @@
#include <cassert>
#include "latency_profiler/CollTraceFunc.h"
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
using namespace rccl;
struct ncclKernelMatch {
@@ -36,6 +41,7 @@ struct ncclKernelMatch {
bool specialized;
};
#ifdef ENABLE_COLLTRACE
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0))
static ncclKernelMatch const ncclKerns[6] = {
@@ -390,6 +396,19 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
devWork.rcclUseOneSlice = comm->rcclUseOneSlice;
//[Added-comment] opCount is missing for collDevWork, adding here
devWork.opCount = task->opCount;
#ifdef ENABLE_ROCSHMEM
if (comm->enableRocshmem && task->func == ncclFuncAllToAllGda) {
devWork.enableRocshmem = comm->enableRocshmem;
devWork.team = comm->team_reduce_world_dup;
devWork.sndbuff = (void*)comm->sourceRshmem[comm->symId];
devWork.tempbuff = (void*)comm->destRshmem[comm->symId];
comm->symId = (comm->symId + 1) % comm->numSymBuf;
devWork.size = task->count;
}
#endif
devWork.isOneRPN = comm->isOneRPN;
devWork.netRegUsed = devWork.regUsed = 0;
@@ -730,8 +749,10 @@ static ncclResult_t scheduleCollTasksToPlan(
proxyOp.incWorkCounter = true;
addWorkBatchToPlan(comm, plan, c, workNode->workType, task->devFuncId, plan->workBytes);
// Set pattern to profiler to add a proxy profiler for kernel events
NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp));
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp));
if (task->func != ncclFuncAllToAllGda) {
NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp));
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, &proxyOp));
}
}
} else { // not task->isCollnet
int trafficPerByte = ncclFuncTrafficPerByte(task->func, comm->nRanks);
@@ -877,8 +898,10 @@ static ncclResult_t scheduleCollTasksToPlan(
// Coverity reports "proxyOp->connection" as being possibly uninitialized. It's hard to
// determine if that's actually true but it's also not clear if that would be an issue.
// coverity[uninit_use_in_call:FALSE]
NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp));
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp));
if (task->func != ncclFuncAllToAllGda) {
NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp));
NCCLCHECK(addProfilerProxyOpIfNeeded(comm, plan, proxyOp));
}
}
}
@@ -1508,8 +1531,8 @@ static ncclResult_t hostStreamPlanTask(struct ncclComm* comm, struct ncclKernelP
NCCLCHECK(ncclProfilerStartGroupEvent(plan));
NCCLCHECK(ncclProfilerStartTaskEvents(plan));
if (ncclIntruQueueHead(&plan->proxyOpQueue)) {
NCCLCHECK(uploadProxyOps(comm, plan));
NCCLCHECK(ncclProxyStart(comm));
NCCLCHECK(uploadProxyOps(comm, plan));
NCCLCHECK(ncclProxyStart(comm));
}
NCCLCHECK(ncclProfilerStopTaskEvents(plan));
NCCLCHECK(ncclProfilerStopGroupEvent(plan));
@@ -1788,6 +1811,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan
latency_profiler::collTraceRecordStartEvent(comm, launchStream, event.get());
comm->lastStream = planner->streams->stream;
CUDACHECKGOTO(hipExtLaunchKernel(plan->kernelFn, grid, block, extra, 0, launchStream, NULL, comm->doneEvent, 0), ret, do_return);
latency_profiler::collTraceRecordEndEvent(comm, plan, launchStream, std::move(event));
return ncclSuccess;
}
@@ -2023,7 +2047,7 @@ static ncclResult_t updateCollCostTable(
float** collCostTable) {
float (*table)[NCCL_NUM_PROTOCOLS] = (float (*)[NCCL_NUM_PROTOCOLS])collCostTable;
if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot) {
if (comm->nRanks == 1 || info->func == ncclFuncAllToAllPivot || info->func == ncclFuncAllToAllGda) {
table[NCCL_ALGO_RING][NCCL_PROTO_SIMPLE] = 0.0;
return ncclSuccess;
}
@@ -2327,6 +2351,9 @@ static ncclResult_t calcCollChunking(
case ncclFuncAllToAllPivot:
pattern = ncclPatternRing;
break;
case ncclFuncAllToAllGda:
pattern = ncclPatternRing;
break;
case ncclFuncAllReduce:
pattern =
info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls :
@@ -2749,7 +2776,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {
t->root = info->root;
t->datatype = info->datatype;
size_t elementSize = ncclTypeSize(t->datatype);
if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot) {
if (t->func == ncclFuncAllGather || t->func == ncclFuncBroadcast || t->func == ncclFuncAllToAllPivot || t->func == ncclFuncAllToAllGda) {
t->count *= elementSize;
t->datatype = ncclInt8;
elementSize = 1;
@@ -24,6 +24,10 @@
#include "rccl_common.h"
#include "recorder.h"
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define HIPRT_CB
#else
@@ -725,6 +729,17 @@ struct ncclComm {
// multiProcessorCount from hipDeviceProp_t [RCCL]
int cuCount;
#ifdef ENABLE_ROCSHMEM
// circular ring buffer in rocshmem symmetric heap
void** sourceRshmem;
void** destRshmem;
rocshmem::rocshmem_team_t team_reduce_world_dup;
int enableRocshmem;
int rocshmemThreshold;
int numSymBuf;
int symId;
#endif
uint64_t endMagic;
};
@@ -38,7 +38,11 @@
#include <string>
#include "debug.h"
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2];
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3];
extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS];
@@ -397,6 +401,14 @@ struct alignas(16) ncclDevWorkColl {
};
uint64_t redOpArg;
uint64_t opCount;
#ifdef ENABLE_ROCSHMEM
rocshmem::rocshmem_team_t team;
int enableRocshmem;
void* tempbuff;
void* sndbuff;
int size;
#endif
};
@@ -784,7 +796,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto,
if (coll == ncclFuncBroadcast) {
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) |
((uint64_t)(proto & RCCL_FUNC_ID_MASK) << RCCL_PROTO_SHIFT);
} else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot) {
} else if (coll == ncclFuncSendRecv || coll == ncclFuncAllToAllPivot || coll == ncclFuncAllToAllGda) {
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT );
} else {
key = ((uint64_t)(coll & RCCL_FUNC_ID_MASK) << RCCL_COLL_SHIFT ) |
@@ -64,7 +64,8 @@ typedef enum {
ncclFuncSend = 6,
ncclFuncRecv = 7,
ncclFuncAllToAllPivot = 8,
ncclNumFuncs = 9
ncclFuncAllToAllGda = 9,
ncclNumFuncs = 10
} ncclFunc_t;
#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*/PAT
@@ -112,6 +112,7 @@ NCCL_API(ncclResult_t, rcclGetAlgoInfo, struct ncclComm* comm, ncclFunc_t coll,
NCCL_API(ncclResult_t, rcclGetAlgoName, int algo, const char** algoName);
NCCL_API(ncclResult_t, rcclGetProtocolName, int protocol, const char** algoName);
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize);
bool rcclUseAllToAllGda(struct ncclComm* comm);
void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
+88 -1
مشاهده پرونده
@@ -56,6 +56,12 @@
#include "rccl_common.h"
// [/RCCL]
#ifdef ENABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#define NUM_SYM_BUF 8
#endif
#include "msccl/msccl_lifecycle.h"
#include "msccl/msccl_status.h"
#include "latency_profiler/CollTrace.h"
@@ -78,7 +84,7 @@
using namespace rccl;
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+3] = { "AllGather", "AllReduce", "AllToAllPivot", "AllToAllGda", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree", "PAT" };
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "MinMax", "PreMulSum", "SumPostDiv" };
@@ -97,6 +103,13 @@ NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", NCCL_CONFIG_UNDEF_INT);
struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {};
ncclResult_t commReclaim(ncclComm_t comm);
#ifdef ENABLE_ROCSHMEM
RCCL_PARAM(RocshmemThreshold, "ROCSHMEM_THRESHOLD", (size_t)(262144));
RCCL_PARAM(RocshmemEnabled, "ROCSHMEM_ENABLE", 1);
std::unordered_map<ncclComm_t, rocshmem::rocshmem_team_t> ncclCommToRshmemTeam;
#endif
#ifdef ENABLE_MSCCLPP
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES);
@@ -2107,6 +2120,58 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
// RCCL: determine and set unroll factor for comm
NCCLCHECK(commSetUnrollFactor(comm));
#ifdef ENABLE_ROCSHMEM
if (rcclParamRocshmemEnabled()) { // @TODO - This doesn't seem to disable when I set ROCSHMEM_ENABLE=0 on command line
INFO(NCCL_INIT,"Initializing rocSHMEM inside of RCCL");
int ret;
rocshmem::rocshmem_uniqueid_t rocshmemUniqueId;
rocshmem::rocshmem_init_attr_t rocshmemAttr;
if(comm->rank == 0 ) {
ret = rocshmem::rocshmem_get_uniqueid (&rocshmemUniqueId);
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
ERROR("Error in rocshmem_get_uniqueid, Rocshmem cannot be initialized.");
return ncclSystemError;
}
}
NCCLCHECKGOTO(bootstrapBroadcast(comm->bootstrap, comm->rank, comm->nRanks, 0, &rocshmemUniqueId,
sizeof(rocshmemUniqueId)), res, fail);
ret = rocshmem::rocshmem_set_attr_uniqueid_args(job->myrank, job->nranks, &rocshmemUniqueId, &rocshmemAttr);
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
ERROR("Error in rocshmem_set_attr_uniqueid_args, Rocshmem cannot be initialized.");
return ncclSystemError;
}
ret = rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &rocshmemAttr);
if (ret != rocshmem::ROCSHMEM_SUCCESS) {
ERROR("Error in rocshmem_init_attr, Rocshmem cannot be initialized.");
return ncclSystemError;
}
comm->sourceRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *));
comm->destRshmem = (void**) malloc(NUM_SYM_BUF * sizeof(void *));
for (int i = 0; i < NUM_SYM_BUF; i++) {
comm->sourceRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024));
comm->destRshmem[i] = (void *)rocshmem::rocshmem_malloc((size_t)(1*1024*1024));
}
comm->enableRocshmem = rcclParamRocshmemEnabled();
comm->rocshmemThreshold = rcclParamRocshmemThreshold();
comm->numSymBuf = NUM_SYM_BUF;
comm->symId = 0;
//rocshmem::rocshmem_team_t team_reduce_world_dup;
comm->team_reduce_world_dup = rocshmem::ROCSHMEM_TEAM_INVALID;
rocshmem::rocshmem_team_split_strided(rocshmem::ROCSHMEM_TEAM_WORLD, 0, 1, job->nranks, nullptr, 0,
&(comm->team_reduce_world_dup));
ncclCommToRshmemTeam[comm] = comm->team_reduce_world_dup;
CUDACHECK(hipDeviceSynchronize());
}
#endif
#ifdef ENABLE_MSCCLPP
if (job->parent) {
if (job->parent->mscclppCompatible) {
@@ -2935,6 +3000,28 @@ ncclResult_t ncclCommDestroy_impl(ncclComm_t comm) {
}
#endif
#ifdef ENABLE_ROCSHMEM
if (comm->enableRocshmem) {
for (int i = 0; i < NUM_SYM_BUF; i++) {
rocshmem::rocshmem_free(comm->sourceRshmem[i]);
rocshmem::rocshmem_free(comm->destRshmem[i]);
}
free(comm->sourceRshmem);
free(comm->destRshmem);
//TODO: subcomm check
rocshmem::rocshmem_team_t team;
if (!ncclCommToRshmemTeam.empty()) {
team = ncclCommToRshmemTeam[comm];
rocshmem::rocshmem_team_destroy(team);
ncclCommToRshmemTeam.erase(comm);
}
if (ncclCommToRshmemTeam.empty()) {
rocshmem::rocshmem_finalize();
}
}
#endif
int rank = comm->rank, nranks = comm->nRanks, cudaDev = comm->cudaDev;
struct ncclCommFinalizeAsyncJob *job = NULL;
ncclResult_t res = ncclSuccess;
@@ -386,6 +386,18 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) {
return ncclSuccess;
}
/**
 * Decide whether the rocSHMEM-backed (GDA) all-to-all path may be used for `comm`.
 *
 * Returns true only when RCCL was built with ENABLE_ROCSHMEM and all of the
 * following hold:
 *   - rocSHMEM is enabled on the communicator (comm->enableRocshmem);
 *   - the first GPU node in the topology matches arch "gfx942";
 *   - the job spans more than one node;
 *   - the rank count is exactly 8 per node (nRanks == 8 * nNodes);
 *   - the configured rocSHMEM threshold does not exceed 1 MiB.
 * Otherwise returns false. Without ENABLE_ROCSHMEM it always returns false.
 *
 * @param comm communicator to inspect
 * @return true if the GDA all-to-all path is eligible, false otherwise
 */
bool rcclUseAllToAllGda(struct ncclComm* comm) {
  //TODO: enable on MI350; currently tested on MI300X
#ifdef ENABLE_ROCSHMEM
  // Upper bound accepted for the configured threshold, in bytes (1 MiB).
  constexpr int kMaxRocshmemThreshold = 1048576;
  // Required number of ranks on every node.
  constexpr int kRanksPerNode = 8;

  // The rank-count test uses a multiplication instead of nRanks / nNodes == 8:
  // integer division truncates, so e.g. 17 ranks on 2 nodes would wrongly pass.
  if (comm->enableRocshmem &&
      IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") &&
      comm->nNodes > 1 &&
      comm->nRanks == comm->nNodes * kRanksPerNode &&
      comm->rocshmemThreshold <= kMaxRocshmemThreshold) {
    INFO(NCCL_INIT, "Enabling GDA alltoall for RCCL");
    return true;
  }
#endif
  return false;
}
bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) {
// Check if user explicitly disabled direct AllGather
static int userDirectAllGatherInput = -2;