коммит произвёл
GitHub
родитель
05850e89f2
Коммит
28d9b170c9
@@ -17,7 +17,7 @@ def runCI =
|
||||
def prj = new rocProject('rccl', 'Extended')
|
||||
|
||||
prj.timeout.test = 600
|
||||
prj.paths.build_command = './install.sh -t'
|
||||
prj.paths.build_command = './install.sh -tj 16'
|
||||
|
||||
// Define test architectures, optional rocm version argument is available
|
||||
def nodes = new dockerNodes(nodeDetails, jobName, prj)
|
||||
|
||||
@@ -18,7 +18,7 @@ def runCI =
|
||||
def prj = new rocProject('rccl', 'PreCheckin')
|
||||
|
||||
prj.timeout.test = 300
|
||||
prj.paths.build_command = './install.sh -t --fast'
|
||||
prj.paths.build_command = './install.sh -tj 16 --fast'
|
||||
|
||||
// Define test architectures, optional rocm version argument is available
|
||||
def nodes = new dockerNodes(nodeDetails, jobName, prj)
|
||||
|
||||
@@ -12,7 +12,7 @@ def runCI =
|
||||
def prj = new rocProject('rccl', 'Static Library PreCheckin')
|
||||
|
||||
prj.timeout.test = 1440
|
||||
prj.paths.build_command = './install.sh -t --static'
|
||||
prj.paths.build_command = './install.sh -tj 16 --static'
|
||||
|
||||
def nodes = new dockerNodes(nodeDetails, jobName, prj)
|
||||
|
||||
|
||||
+9
-32
@@ -10,7 +10,6 @@ project(rccl CXX)
|
||||
# Build options
|
||||
#==================================================================================================
|
||||
option(BUILD_ADDRESS_SANITIZER "Enable address sanitizer" OFF)
|
||||
option(BUILD_ALLREDUCE_ONLY "AllReduce(sum,float) kernel only" OFF)
|
||||
option(BUILD_BFD "Enable custom backtrace (if bfd.h exists)" OFF)
|
||||
option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "File/folder reorg with backward compatibility" OFF)
|
||||
option(BUILD_LOCAL_GPU_TARGET_ONLY "Build only for GPUs detected on this machine" OFF)
|
||||
@@ -46,6 +45,7 @@ set(DEFAULT_GPUS
|
||||
include(CheckIncludeFiles)
|
||||
include(CheckSymbolExists)
|
||||
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
|
||||
include(cmake/Generator.cmake) # Configure functions that goes into RCCL
|
||||
|
||||
# Build only for local GPU architecture
|
||||
if (BUILD_LOCAL_GPU_TARGET_ONLY)
|
||||
@@ -309,6 +309,7 @@ set(SRC_FILES
|
||||
src/collectives/device/reduce_kernel.h
|
||||
src/collectives/device/reduce_scatter.h
|
||||
src/collectives/device/sendrecv.h
|
||||
src/collectives/device/onerank_reduce.cu
|
||||
src/collectives/gather.cc
|
||||
src/collectives/reduce.cc
|
||||
src/collectives/reduce_scatter.cc
|
||||
@@ -447,31 +448,6 @@ set(SRC_FILES
|
||||
src/transport/shm.cc
|
||||
)
|
||||
|
||||
## Add kernel files
|
||||
## E.g: find src -type f \( -name "*.u" \) | sort
|
||||
if (BUILD_ALLREDUCE_ONLY)
|
||||
add_definitions(-DBUILD_ALLREDUCE_ONLY)
|
||||
set(CU_SOURCES
|
||||
# src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/functions.cu
|
||||
# src/collectives/device/msccl_kernel.cu
|
||||
)
|
||||
else()
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_gather.cu
|
||||
# src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/alltoall_pivot.cu
|
||||
src/collectives/device/broadcast.cu
|
||||
src/collectives/device/functions.cu
|
||||
# src/collectives/device/msccl_kernel.cu
|
||||
src/collectives/device/onerank_reduce.cu
|
||||
# src/collectives/device/reduce.cu
|
||||
# src/collectives/device/reduce_scatter.cu
|
||||
src/collectives/device/sendrecv.cu)
|
||||
endif()
|
||||
list(APPEND SRC_FILES ${CU_SOURCES})
|
||||
|
||||
if (ENABLE_MSCCL_KERNEL)
|
||||
set(MSCCL_KERNEL_SOURCES
|
||||
src/collectives/device/msccl_kernel_impl.h
|
||||
@@ -509,11 +485,12 @@ foreach(SRC_FILE ${SRC_FILES})
|
||||
)
|
||||
endforeach()
|
||||
|
||||
expand_collectives("all_reduce" "AllReduce")
|
||||
expand_collectives("reduce" "Reduce")
|
||||
expand_collectives("reduce_scatter" "ReduceScatter")
|
||||
if(ENABLE_MSCCL_KERNEL)
|
||||
expand_collectives("msccl_kernel" "MscclKernel")
|
||||
if(ONLY_FUNCS)
|
||||
## Generate only the specified functions
|
||||
gen_functions(${ONLY_FUNCS})
|
||||
else()
|
||||
# Generate all the functions
|
||||
gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv")
|
||||
endif()
|
||||
|
||||
# Create an initial git_version.cpp file (that will be updated with latest git version)
|
||||
@@ -637,7 +614,7 @@ if (HAS_BFD)
|
||||
target_link_libraries(rccl PRIVATE iberty z)
|
||||
endif()
|
||||
endif()
|
||||
target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code
|
||||
target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code
|
||||
|
||||
## Set RCCL link options
|
||||
target_link_options(rccl PRIVATE -parallel-jobs=16) # Use multiple threads to link
|
||||
|
||||
+1
-2
@@ -25,7 +25,6 @@ The root of this repository has a helper script 'install.sh' to build and instal
|
||||
|
||||
Options:
|
||||
--address-sanitizer Build with address sanitizer enabled
|
||||
--build_allreduce_only Build only AllReduce + sum + float kernel
|
||||
-d|--dependencies Install RCCL depdencencies
|
||||
--debug Build debug library
|
||||
--enable_backtrace Build with custom backtrace support
|
||||
@@ -34,7 +33,7 @@ The root of this repository has a helper script 'install.sh' to build and instal
|
||||
-f|--fast Quick-build RCCL (local gpu arch only, no backtrace, and collective trace support)
|
||||
-h|--help Prints this help message
|
||||
-i|--install Install RCCL library (see --prefix argument below)
|
||||
-j|--jobs Specify how many parallel compilation jobs to run (16 by default)
|
||||
-j|--jobs Specify how many parallel compilation jobs to run (nproc by default)
|
||||
-l|--local_gpu_only Only compile for local GPU architecture
|
||||
--no_clean Don't delete files if they already exist
|
||||
--npkit-enable Compile with npkit enabled
|
||||
|
||||
@@ -70,55 +70,6 @@ if(NOT GTest_FOUND AND BUILD_TESTS OR INSTALL_DEPENDENCIES)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(DATATYPES_INT
|
||||
"int8_t"
|
||||
"uint8_t"
|
||||
"int32_t"
|
||||
"uint32_t"
|
||||
"int64_t"
|
||||
"uint64_t"
|
||||
)
|
||||
set(DATATYPES_FLOAT
|
||||
"half"
|
||||
"float"
|
||||
"double"
|
||||
"rccl_bfloat16"
|
||||
)
|
||||
|
||||
function(expand_collectives FILE FUNC)
|
||||
set(REDOP Sum Prod Min Max PreMulSum SumPostDiv)
|
||||
if (FUNC STREQUAL "MscclKernel")
|
||||
set(REDOP_FILTERED Sum Prod Min Max PreMulSum SumPostDiv)
|
||||
else()
|
||||
set(REDOP_FILTERED ${REDOP})
|
||||
endif()
|
||||
foreach(REDOP_CURRENT IN LISTS REDOP_FILTERED)
|
||||
foreach(DATA_TYPE ${DATATYPES_INT} ${DATATYPES_FLOAT})
|
||||
if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST DATATYPES_FLOAT)
|
||||
continue() # Skip the iteration for DATATYPES_FLOAT when REDOP_CURRENT is SumPostDiv
|
||||
endif()
|
||||
set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/${FILE}_${REDOP_CURRENT}_${DATA_TYPE}.cpp")
|
||||
message(STATUS "Generating ${FILE_NAME}")
|
||||
if (FUNC STREQUAL "MscclKernel")
|
||||
file(WRITE ${FILE_NAME}
|
||||
"#include \"${FILE}_impl.h\"
|
||||
#include \"primitives.h\"
|
||||
#include \"collectives.h\"
|
||||
#include \"devcomm.h\"
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);")
|
||||
else()
|
||||
file(WRITE ${FILE_NAME}
|
||||
"#include \"${FILE}.h\"
|
||||
#include \"common.h\"
|
||||
#include \"collectives.h\"
|
||||
IMPL_COLL3(${FUNC}, ${REDOP_CURRENT}, ${DATA_TYPE});")
|
||||
endif()
|
||||
list(APPEND HIP_SOURCES ${FILE_NAME})
|
||||
endforeach()
|
||||
endforeach()
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Find or download/install rocm-cmake project
|
||||
set( PROJECT_EXTERN_DIR ${CMAKE_CURRENT_BINARY_DIR}/extern )
|
||||
find_package(ROCM 0.7.3 QUIET CONFIG PATHS /opt/rocm)
|
||||
|
||||
@@ -0,0 +1,383 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
set(ALL_PARAMS "ALL_COLLS" "ALL_ALGOS" "ALL_PROTOS" "ALL_REDOPS" "ALL_TYPES")
|
||||
|
||||
set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "ReduceScatter" "SendRecv")
|
||||
set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN")
|
||||
set(ALL_PROTOS "LL" "LL128" "SIMPLE")
|
||||
set(ALL_REDOPS "Sum" "Prod" "Max" "Min" "PreMulSum" "SumPostDiv")
|
||||
set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "rccl_bfloat16")
|
||||
|
||||
set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16")
|
||||
|
||||
################################################################################
|
||||
# The command line argument is used as a regex to filter the functions
|
||||
# which make it into librccl. This is helpful for reducing the binary when
|
||||
# developing device code. The regex supports non-space containing globs '*',
|
||||
# and union 'a|b'. The string representing the function has the form:
|
||||
#
|
||||
# <coll> <algo> <proto> <redop> <type>
|
||||
#
|
||||
# The possible values for redop, type, algo, proto can be found in the all_<foo>
|
||||
# lists at the top of this file.
|
||||
#
|
||||
# Example use-cases:
|
||||
#
|
||||
# # Only send/recv:
|
||||
# make ONLY_FUNCS="SendRecv"
|
||||
#
|
||||
# # Only AllReduce and Reduce
|
||||
# make ONLY_FUNCS="AllReduce|Reduce"
|
||||
#
|
||||
# # Only non-reductions:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum int32_t (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
|
||||
#
|
||||
# # Only AllReduce RING Max float (but all protos)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max float"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
|
||||
#############################################################################################################
|
||||
## A recursive helper macro to generate functions and kernels based on the input given
|
||||
#############################################################################################################
|
||||
macro(filter_functions FUNCTION_PARAMS current_idx)
|
||||
## Check if the current_idx does not exceed the max depth
|
||||
if(${current_idx} LESS 5)
|
||||
## current_element is the config parameter
|
||||
list(GET FUNCTION_PARAMS ${current_idx} current_element)
|
||||
|
||||
## If the parameter is equal to '*', include all the possible cases for it
|
||||
if(${current_element} STREQUAL "*")
|
||||
if(${current_idx} EQUAL 0)
|
||||
message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.")
|
||||
endif()
|
||||
## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> <coll> <algo> <proto> <redop> <type>
|
||||
## Find the respective parameter list from ALL_PARAMS list
|
||||
list(GET ALL_PARAMS ${current_idx} current_list)
|
||||
|
||||
## Iterate over the items int the current_list
|
||||
foreach(item IN LISTS ${current_list})
|
||||
## Add item to ITEM_LIST which will be used in the inner most loop
|
||||
list(APPEND ITEM_LIST ${item})
|
||||
math(EXPR new_idx "${current_idx} + 1")
|
||||
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
|
||||
|
||||
## For each loop layer remove the last element in ITEM_LIST
|
||||
list(REMOVE_AT ITEM_LIST -1)
|
||||
endforeach()
|
||||
else()
|
||||
## Check if the current element is recognized
|
||||
list(GET ALL_PARAMS ${current_idx} current_param)
|
||||
list(FIND ${current_param} ${current_element} is_valid)
|
||||
if(${is_valid} EQUAL -1)
|
||||
message(FATAL_ERROR "Error: ${current_element} is unrecognized or does not belong to this category.")
|
||||
endif()
|
||||
|
||||
## If not '*', no need to iterate. Add the current_element to ITEM_LIST
|
||||
list(APPEND ITEM_LIST ${current_element})
|
||||
math(EXPR new_idx "${current_idx} + 1")
|
||||
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
|
||||
|
||||
list(REMOVE_AT ITEM_LIST -1)
|
||||
endif()
|
||||
else()
|
||||
## This is the inner most loop where the file is generated
|
||||
## Unzip ITEM_LIST
|
||||
list(GET ITEM_LIST 0 COLL)
|
||||
list(GET ITEM_LIST 1 ALGO)
|
||||
list(GET ITEM_LIST 2 PROTO)
|
||||
list(GET ITEM_LIST 3 REDOP)
|
||||
list(GET ITEM_LIST 4 TYPE)
|
||||
|
||||
## Need to check if these conditions are met prior to file generation
|
||||
if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING")
|
||||
continue()
|
||||
elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t"))
|
||||
continue()
|
||||
elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST)
|
||||
continue()
|
||||
endif()
|
||||
|
||||
list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}")
|
||||
set(COLL_LIST ${COLL_LIST} PARENT_SCOPE)
|
||||
|
||||
## Append the newly formed function/kernel to list
|
||||
list(APPEND FUNC_LIST "ncclFunction_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
|
||||
list(APPEND KERN_LIST "ncclKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
|
||||
set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE)
|
||||
set(KERN_LIST ${KERN_LIST} PARENT_SCOPE)
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
#####################################################################################################
|
||||
## Function to generate device table
|
||||
#####################################################################################################
|
||||
function(gen_device_table)
|
||||
set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/device_table.cpp")
|
||||
message(STATUS "Generating ${DEVICE_TABLE_FILE}")
|
||||
|
||||
## Generate device table and list all the functions
|
||||
file(WRITE ${DEVICE_TABLE_FILE} "#include \"common.h\"\n#include \"collectives.h\"\n\n")
|
||||
|
||||
## Declaration of device functions
|
||||
foreach(func IN LISTS FUNC_LIST)
|
||||
if(ENABLE_IFC)
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "__device__ void ${func}();\n")
|
||||
else()
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "__device__ __attribute__((noinline)) void ${func}();\n")
|
||||
endif()
|
||||
endforeach()
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "\n")
|
||||
|
||||
if(ENABLE_IFC)
|
||||
## Undirect function call
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "__device__ ncclKernelFunc_t const ncclFuncs[] = {\n")
|
||||
foreach(func ${FUNC_LIST})
|
||||
file(APPEND ${DEVICE_TABLE_FILE} " ${func},\n")
|
||||
endforeach()
|
||||
## Add OneRankReduce functions at the end
|
||||
foreach(type IN LISTS ALL_TYPES)
|
||||
file(APPEND ${DEVICE_TABLE_FILE} " ncclFunction_OneRankReduce_PreMulSum_${type},\n")
|
||||
endforeach()
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "nullptr};\n\n")
|
||||
else()
|
||||
## Direct functions calls
|
||||
file(APPEND ${DEVICE_TABLE_FILE} "__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n switch(funcIndex) {\n")
|
||||
set(index 0)
|
||||
foreach(func IN LISTS FUNC_LIST)
|
||||
file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ${func}();\n break;\n")
|
||||
math(EXPR index "${index} + 1")
|
||||
endforeach()
|
||||
## Add OneRankReduce functions at the end
|
||||
foreach(type IN LISTS ALL_TYPES)
|
||||
file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ncclFunction_OneRankReduce_PreMulSum_${type}();\n break;\n")
|
||||
math(EXPR index "${index} + 1")
|
||||
endforeach()
|
||||
file(APPEND ${DEVICE_TABLE_FILE} " }\n}\n")
|
||||
endif()
|
||||
|
||||
## Add the device_table file to HIP_SOURCES
|
||||
list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE})
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
######################################################################################################
|
||||
## Function to generate host-side table
|
||||
######################################################################################################
|
||||
function(gen_host_table)
|
||||
set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/host_table.cpp")
|
||||
message(STATUS "Generating ${HOST_TABLE_FILE}")
|
||||
|
||||
file(WRITE ${HOST_TABLE_FILE} "#include \"devcomm.h\"\n\n")
|
||||
|
||||
## The mapping from function rows to valid function ids
|
||||
file(APPEND ${HOST_TABLE_FILE} "extern int const ncclFuncRowToId[] = {\n")
|
||||
set(idx 0)
|
||||
foreach(coll IN LISTS ALL_COLLS)
|
||||
foreach(algo IN LISTS ALL_ALGOS)
|
||||
foreach(proto IN LISTS ALL_PROTOS)
|
||||
foreach(redop IN LISTS ALL_REDOPS)
|
||||
foreach(type IN LISTS ALL_TYPES)
|
||||
if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING")
|
||||
continue()
|
||||
elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t"))
|
||||
continue()
|
||||
elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
|
||||
continue()
|
||||
endif()
|
||||
|
||||
list(FIND FUNC_LIST "ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
|
||||
if(NOT ${fn_id} EQUAL -1)
|
||||
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}\n")
|
||||
else()
|
||||
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n")
|
||||
endif()
|
||||
math(EXPR idx "${idx} + 1")
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
math(EXPR fn_id "${fn_id} + 1")
|
||||
## Add OneRankReduce function ids at the end
|
||||
foreach(type IN LISTS ALL_TYPES)
|
||||
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_OneRankReduce_PreMulSum_${type}\n")
|
||||
|
||||
## Increment the index and func id for each OneRankReduce
|
||||
math(EXPR idx "${idx} + 1")
|
||||
math(EXPR fn_id "${fn_id} + 1")
|
||||
endforeach()
|
||||
file(APPEND ${HOST_TABLE_FILE} "-1};\n\n")
|
||||
|
||||
## Add the host_table file to HIP_SOURCES
|
||||
list(APPEND HIP_SOURCES ${HOST_TABLE_FILE})
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
###########################################################################################################
|
||||
## Function to generate MSCCL Kernels
|
||||
###########################################################################################################
|
||||
function(gen_msccl_kernels)
|
||||
foreach(REDOP_CURRENT IN LISTS ALL_REDOPS)
|
||||
foreach(DATA_TYPE ${ALL_TYPES})
|
||||
if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST FLOATS_LIST)
|
||||
continue() # Skip the iteration for FLOATS_LIST when REDOP_CURRENT is SumPostDiv
|
||||
endif()
|
||||
set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/msccl_kernel_${REDOP_CURRENT}_${DATA_TYPE}.cpp")
|
||||
message(STATUS "Generating ${FILE_NAME}")
|
||||
file(WRITE ${FILE_NAME}
|
||||
"#include \"msccl_kernel_impl.h\"
|
||||
#include \"primitives.h\"
|
||||
#include \"collectives.h\"
|
||||
#include \"devcomm.h\"
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
|
||||
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
|
||||
list(APPEND HIP_SOURCES ${FILE_NAME})
|
||||
endforeach()
|
||||
endforeach()
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
###########################################################################################################
|
||||
## Function to generate collectives
|
||||
###########################################################################################################
|
||||
function(gen_collectives)
|
||||
# Iterate over each item in the original list
|
||||
foreach(item ${COLL_LIST})
|
||||
# Split the string into components
|
||||
string(REPLACE "-" ";" item_components ${item})
|
||||
|
||||
# Extract COLL, ALGO, and PROTO components
|
||||
list(GET item_components 0 coll_prefix)
|
||||
list(GET item_components 1 algo_prefix)
|
||||
list(GET item_components 2 proto_prefix)
|
||||
list(GET item_components 3 redop_prefix)
|
||||
|
||||
# Create a list name using COLL, ALGO, and PROTO
|
||||
set(list_name "${coll_prefix}_${algo_prefix}_${proto_prefix}_${redop_prefix}")
|
||||
|
||||
# Add the item to the corresponding list
|
||||
list(APPEND ${list_name} ${item})
|
||||
|
||||
# Add the list name to the map if it doesn't exist
|
||||
if(NOT list_name IN_LIST divided_lists)
|
||||
list(APPEND divided_lists ${list_name})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(index 0)
|
||||
foreach(list_name IN LISTS divided_lists)
|
||||
foreach(item IN LISTS ${list_name})
|
||||
# Convert to a list
|
||||
string(REPLACE "-" ";" components ${item})
|
||||
|
||||
list(GET components 0 coll)
|
||||
list(GET components 1 algo)
|
||||
list(GET components 2 proto)
|
||||
list(GET components 3 redop)
|
||||
list(GET components 4 type)
|
||||
|
||||
list(APPEND IMPL_LIST "IMPL_COLL_FUNC(${coll}, ${algo}, ${proto}, ${redop}, ${type})\n")
|
||||
|
||||
# Increment the function id
|
||||
math(EXPR index "${index} + 1")
|
||||
endforeach()
|
||||
## Store lower-case version of COLL
|
||||
string(TOLOWER ${coll} COLL_LOWER)
|
||||
string(REPLACE "scatter" "_scatter" COLL_LOWER ${COLL_LOWER})
|
||||
if(NOT ${coll} STREQUAL "AllToAllPivot")
|
||||
string(REPLACE "all" "all_" COLL_LOWER ${COLL_LOWER})
|
||||
else()
|
||||
string(REPLACE "pivot" "_pivot" COLL_LOWER ${COLL_LOWER})
|
||||
endif()
|
||||
|
||||
## Set name/path of the file
|
||||
set(FILE_PATH "${HIPIFY_DIR}/src/collectives/device/${list_name}.cpp")
|
||||
message(STATUS "Generating ${FILE_PATH}")
|
||||
|
||||
## Construct the file
|
||||
file(WRITE ${FILE_PATH} "#include \"${COLL_LOWER}.h\"\n#include \"common.h\"\n#include \"collectives.h\"\n")
|
||||
foreach(IMPL IN LISTS IMPL_LIST)
|
||||
file(APPEND ${FILE_PATH} "${IMPL}")
|
||||
endforeach()
|
||||
|
||||
## Append the file to HIP sources list which will be added to source list
|
||||
list(APPEND HIP_SOURCES ${FILE_PATH})
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
|
||||
# Clear the IMPL list for the next iteration
|
||||
set(IMPL_LIST)
|
||||
endforeach()
|
||||
endfunction()
|
||||
|
||||
###################################################################################################################
|
||||
## Function to generate all the functions that are going to be in librccl.so
|
||||
###################################################################################################################
|
||||
function(gen_functions CONFIG_INPUT)
|
||||
string(REPLACE "|" ";" INPUT_LIST ${CONFIG_INPUT})
|
||||
## Sort the input so that it matches ALL_COLLS
|
||||
list(SORT INPUT_LIST)
|
||||
|
||||
foreach(INPUT ${INPUT_LIST})
|
||||
# Parse the the config string and make it a list
|
||||
string(REPLACE " " ";" FUNCTION_PARAMS ${INPUT})
|
||||
|
||||
# Get the number of parameters in the input
|
||||
list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)
|
||||
|
||||
# Assume all if a parameter is missing
|
||||
while(${PARAMS_LENGTH} LESS 5)
|
||||
list(APPEND FUNCTION_PARAMS "*")
|
||||
list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)
|
||||
endwhile()
|
||||
|
||||
## Filter functions/kernels based on input
|
||||
filter_functions(FUNCTION_PARAMS 0)
|
||||
endforeach()
|
||||
|
||||
gen_collectives() ## Generate collective files
|
||||
if(ENABLE_MSCCL_KERNEL)
|
||||
gen_msccl_kernels() ## Generate msccl files (not configurable)
|
||||
endif()
|
||||
gen_device_table() ## Generate device_table.cpp
|
||||
gen_host_table() ## Generate host_table.cpp
|
||||
|
||||
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
|
||||
endfunction()
|
||||
+4
-12
@@ -8,7 +8,6 @@ ROCM_PATH=${ROCM_PATH:="/opt/rocm"}
|
||||
|
||||
# Default values
|
||||
build_address_sanitizer=false
|
||||
build_allreduce_only=false
|
||||
build_bfd=false
|
||||
build_freorg_bkwdcomp=false
|
||||
build_local_gpu_only=false
|
||||
@@ -24,7 +23,7 @@ enable_ninja=""
|
||||
install_dependencies=false
|
||||
install_library=false
|
||||
msccl_kernel_enabled=true
|
||||
num_parallel_jobs=16
|
||||
num_parallel_jobs=$(nproc)
|
||||
npkit_enabled=false
|
||||
run_tests=false
|
||||
run_tests_all=false
|
||||
@@ -38,7 +37,6 @@ function display_help()
|
||||
echo "RCCL build & installation helper script"
|
||||
echo " Options:"
|
||||
echo " --address-sanitizer Build with address sanitizer enabled"
|
||||
echo " --build_allreduce_only Build only AllReduce + sum + float kernel"
|
||||
echo " -d|--dependencies Install RCCL depdencencies"
|
||||
echo " --debug Build debug library"
|
||||
echo " --enable_backtrace Build with custom backtrace support"
|
||||
@@ -70,7 +68,7 @@ function display_help()
|
||||
# check if we have a modern version of getopt that can handle whitespace and long parameters
|
||||
getopt -T
|
||||
if [[ $? -eq 4 ]]; then
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,build_allreduce_only,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
|
||||
else
|
||||
echo "Need a new version of getopt"
|
||||
exit 1
|
||||
@@ -86,7 +84,6 @@ eval set -- "${GETOPT_PARSE}"
|
||||
while true; do
|
||||
case "${1}" in
|
||||
--address-sanitizer) build_address_sanitizer=true; shift ;;
|
||||
--build_allreduce_only) build_allreduce_only=true; shift ;;
|
||||
-d | --dependencies) install_dependencies=true; shift ;;
|
||||
--debug) build_release=false; shift ;;
|
||||
--enable_backtrace) build_bfd=true; shift ;;
|
||||
@@ -182,11 +179,6 @@ if [[ "${build_address_sanitizer}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DBUILD_ADDRESS_SANITIZER=ON"
|
||||
fi
|
||||
|
||||
# AllReduce only
|
||||
if [[ "${build_allreduce_only}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DBUILD_ALLREDUCE_ONLY=ON"
|
||||
fi
|
||||
|
||||
# Backtrace support
|
||||
if [[ "${build_bfd}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DBUILD_BFD=ON"
|
||||
@@ -353,9 +345,9 @@ else
|
||||
fi
|
||||
|
||||
if ($build_tests) || (($run_tests) && [[ ! -f ./test/rccl-UnitTests ]]); then
|
||||
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../.
|
||||
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../.
|
||||
else
|
||||
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../.
|
||||
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../.
|
||||
fi
|
||||
check_exit_code "$?"
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "all_gather.h"
|
||||
#include "common.h"
|
||||
#include "collectives.h"
|
||||
|
||||
IMPL_COLL_C(AllGather);
|
||||
@@ -1,12 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
/*This file is now generated in CMake*/
|
||||
|
||||
// #include "all_reduce.h"
|
||||
// #include "common.h"
|
||||
// #include "collectives.h"
|
||||
|
||||
// IMPL_COLL_R(AllReduce);
|
||||
@@ -1,11 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "alltoall_pivot.h"
|
||||
#include "common.h"
|
||||
#include "collectives.h"
|
||||
|
||||
IMPL_COLL_F(AllToAllPivot);
|
||||
@@ -1,11 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "broadcast.h"
|
||||
#include "common.h"
|
||||
#include "collectives.h"
|
||||
|
||||
IMPL_COLL_C(Broadcast);
|
||||
@@ -32,243 +32,14 @@
|
||||
{ __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_LL128) && defined(__gfx90a__)
|
||||
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
#else
|
||||
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
#endif
|
||||
|
||||
#define NCCL_FUNC4(func, devredop, type, nullify) \
|
||||
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, RING, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, NVLS, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify)
|
||||
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
|
||||
// Must be consistent with ncclRedOp_t
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
|
||||
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
using ncclKernelFunc_t = void (*)();
|
||||
|
||||
static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
|
||||
// Don't try to initialize the host shadow copy of this device-side global
|
||||
// variable. There is no host pointer to a device-side function, which
|
||||
// confuses clang. This will be fixed in the next clang release.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(BUILD_ALLREDUCE_ONLY)
|
||||
NCCL_FUNC4(AllReduce, Sum, float, 0),
|
||||
#else
|
||||
NCCL_FUNCS2B(Broadcast),
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
NCCL_FUNCS2B(AllGather),
|
||||
NCCL_FUNCS2A(ReduceScatter),
|
||||
NCCL_FUNCS2A(AllReduce),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
|
||||
#endif
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
|
||||
NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
// Defined in device_table.cpp
|
||||
extern __device__ ncclKernelFunc_t const ncclFuncs[];
|
||||
|
||||
static_assert(FUNC_INDEX_P2P == 5410, "Wrong P2P function index");
|
||||
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 5411, "Wrong AllToAllPivot function index");
|
||||
|
||||
#if !defined(USE_INDIRECT_FUNCTION_CALL) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
template<unsigned short f, unsigned short l, bool u>
|
||||
struct Caller {
|
||||
static __forceinline__ __device__ __host__
|
||||
void call(unsigned short funcIndex) noexcept
|
||||
{
|
||||
constexpr unsigned short m = f + (l - f) / 2;
|
||||
|
||||
return (funcIndex < m) ? Caller<f, m, u>::call(funcIndex) : Caller<m, l, u>::call(funcIndex);
|
||||
}
|
||||
};
|
||||
|
||||
template<unsigned short f, bool u>
|
||||
struct Caller<f, f + 1, u>{
|
||||
static __forceinline__ __device__ __host__
|
||||
void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); }
|
||||
};
|
||||
|
||||
template<bool USING_LL128>
|
||||
__forceinline__
|
||||
__device__
|
||||
void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {
|
||||
#if defined(BUILD_ALLREDUCE_ONLY)
|
||||
if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_RING_SIMPLE_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_RING_LL_Sum_float();
|
||||
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_RING_LL128_Sum_float();
|
||||
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_RING_LL_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_TREE_SIMPLE_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_TREE_LL_Sum_float();
|
||||
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_TREE_LL128_Sum_float();
|
||||
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_TREE_LL_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_COLLNET_DIRECT_SIMPLE_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float();
|
||||
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_COLLNET_DIRECT_LL128_Sum_float();
|
||||
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_COLLNET_CHAIN_SIMPLE_Sum_float();
|
||||
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float();
|
||||
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_COLLNET_CHAIN_LL128_Sum_float();
|
||||
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float();
|
||||
else
|
||||
assert("Unsupported function index");
|
||||
#else
|
||||
if (funcIndex < 1080) {
|
||||
if (funcIndex % 18 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 6) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 8) ncclFunction_Broadcast_COLLNET_DIRECT_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 9) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t();
|
||||
else ncclFunction_Broadcast_COLLNET_CHAIN_SIMPLE_Sum_int8_t();
|
||||
}
|
||||
else if (funcIndex < 2160) Caller<1080, 2160, USING_LL128>::call(funcIndex);
|
||||
else if (funcIndex < 3240) {
|
||||
if (funcIndex % 18 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 6) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 8) ncclFunction_AllGather_COLLNET_DIRECT_SIMPLE_Sum_int8_t();
|
||||
else if (funcIndex % 18 == 9) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t();
|
||||
else if (USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL128_Sum_int8_t();
|
||||
else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t();
|
||||
else ncclFunction_AllGather_COLLNET_CHAIN_SIMPLE_Sum_int8_t();
|
||||
}
|
||||
else if (funcIndex < 5400) Caller<3240, 5400, USING_LL128>::call(funcIndex);
|
||||
else {
|
||||
switch (funcIndex - 5400) {
|
||||
case 0:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int8_t();
|
||||
break;
|
||||
case 1:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint8_t();
|
||||
break;
|
||||
case 2:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int32_t();
|
||||
break;
|
||||
case 3:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint32_t();
|
||||
break;
|
||||
case 4:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int64_t();
|
||||
break;
|
||||
case 5:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint64_t();
|
||||
break;
|
||||
case 6:
|
||||
ncclFunction_OneRankReduce_PreMulSum_half();
|
||||
break;
|
||||
case 7:
|
||||
ncclFunction_OneRankReduce_PreMulSum_float();
|
||||
break;
|
||||
case 8:
|
||||
ncclFunction_OneRankReduce_PreMulSum_double();
|
||||
break;
|
||||
case 9:
|
||||
ncclFunction_OneRankReduce_PreMulSum_rccl_bfloat16();
|
||||
break;
|
||||
case 10:
|
||||
ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t();
|
||||
break;
|
||||
case 11:
|
||||
ncclFunction_AllToAllPivot_RING_SIMPLE_Sum_int8_t();
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifndef USE_INDIRECT_FUNCTION_CALL
|
||||
__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept;
|
||||
#endif
|
||||
|
||||
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
|
||||
@@ -464,7 +235,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we
|
||||
}
|
||||
}
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex, bool COLLTRACE>
|
||||
template<bool COLLTRACE>
|
||||
__forceinline__ __device__ void ncclKernel(
|
||||
struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead
|
||||
) {
|
||||
@@ -565,19 +336,12 @@ __forceinline__ __device__ void ncclKernel(
|
||||
__synclds();
|
||||
|
||||
if (tid == 0) __insert_timestamp(__LINE__);
|
||||
if (ncclShmem.work.header.funcIndex == FnIndex) {
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&ncclShmem.work);
|
||||
} else {
|
||||
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
|
||||
ncclFuncs[ncclShmem.work.header.funcIndex]();
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
ncclFuncs[ncclShmem.work.header.funcIndex]();
|
||||
#else
|
||||
#if defined(ENABLE_LL128) && defined(__gfx90a__)
|
||||
NCCL_CALL_FUNCTIONS<1>(ncclShmem.work.header.funcIndex);
|
||||
#else
|
||||
NCCL_CALL_FUNCTIONS<0>(ncclShmem.work.header.funcIndex);
|
||||
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
int workIxNext = ncclShmem.work.header.workNext;
|
||||
__synclds();
|
||||
@@ -606,28 +370,25 @@ __forceinline__ __device__ void ncclKernel(
|
||||
}
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
|
||||
#define IMPL_MAIN_KERN() \
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \
|
||||
__global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<false>(comm, channelMask, workHead); \
|
||||
} \
|
||||
\
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
|
||||
__global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, channelMask, workHead); \
|
||||
__global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<true>(comm, channelMask, workHead); \
|
||||
}
|
||||
#else
|
||||
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
|
||||
#define IMPL_MAIN_KERN() \
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \
|
||||
__global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
|
||||
ncclKernel<false>(comm, channelMask, workHead); \
|
||||
}
|
||||
#endif
|
||||
|
||||
// Examples : AllReduce, RING, LL, Sum, uint8
|
||||
/* Functions for aggregation case */
|
||||
|
||||
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \
|
||||
__device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \
|
||||
RunWork<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \
|
||||
@@ -639,67 +400,6 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev
|
||||
}
|
||||
#endif
|
||||
|
||||
// Only generate inline kernels for LL
|
||||
#if defined(ENABLE_LL128) && defined(__gfx90a__)
|
||||
#define IMPL_COLL4(func, algo, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type)
|
||||
#else
|
||||
#define IMPL_COLL4(func, algo, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type)
|
||||
#endif
|
||||
|
||||
#define IMPL_COLL3(func, devredop, type) \
|
||||
IMPL_COLL4(func, TREE, devredop, type) \
|
||||
IMPL_COLL4(func, RING, devredop, type) \
|
||||
IMPL_COLL4(func, COLLNET_DIRECT, devredop, type) \
|
||||
IMPL_COLL4(func, COLLNET_CHAIN, devredop, type) \
|
||||
IMPL_COLL4(func, NVLS, devredop, type) \
|
||||
IMPL_COLL4(func, NVLS_TREE, devredop, type)
|
||||
|
||||
#define IMPL_COLL2(func, devredop) \
|
||||
IMPL_COLL3(func, devredop, int8_t) \
|
||||
IMPL_COLL3(func, devredop, uint8_t) \
|
||||
IMPL_COLL3(func, devredop, int32_t) \
|
||||
IMPL_COLL3(func, devredop, uint32_t) \
|
||||
IMPL_COLL3(func, devredop, int64_t) \
|
||||
IMPL_COLL3(func, devredop, uint64_t) \
|
||||
IMPL_COLL3(func, devredop, half) \
|
||||
IMPL_COLL3(func, devredop, float) \
|
||||
IMPL_COLL3(func, devredop, double) \
|
||||
IMPL_COLL3(func, devredop, rccl_bfloat16)
|
||||
|
||||
#define IMPL_COLL2A(func, devredop) \
|
||||
IMPL_COLL3(func, devredop, int8_t) \
|
||||
IMPL_COLL3(func, devredop, uint8_t) \
|
||||
IMPL_COLL3(func, devredop, int32_t) \
|
||||
IMPL_COLL3(func, devredop, uint32_t) \
|
||||
IMPL_COLL3(func, devredop, int64_t) \
|
||||
IMPL_COLL3(func, devredop, uint64_t)
|
||||
|
||||
// Reduction define all functions
|
||||
#define IMPL_COLL_R(func) \
|
||||
IMPL_COLL2(func, Sum) \
|
||||
IMPL_COLL2(func, Prod) \
|
||||
IMPL_COLL2(func, Min) \
|
||||
IMPL_COLL2(func, Max) \
|
||||
IMPL_COLL2(func, PreMulSum) \
|
||||
IMPL_COLL2A(func, SumPostDiv)
|
||||
|
||||
// Copy primitives only define one function for copy
|
||||
#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t);
|
||||
|
||||
// Point-to-point primitives only have one function/kernel.
|
||||
#define IMPL_COLL_P(func) \
|
||||
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \
|
||||
IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, FUNC_INDEX_P2P);
|
||||
|
||||
// AllToAll Pivot primitive only has one function.
|
||||
#define IMPL_COLL_F(func) \
|
||||
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t);
|
||||
|
||||
#define NCCL_NVLS_ENABLED (__CUDA_ARCH__ >= 900 && NCCL_NVLS_SUPPORTS(NCCL_TYPE, NCCL_OP))
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -1,126 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "collectives.h"
|
||||
#include "common.h"
|
||||
|
||||
__shared__ ncclShmemData ncclShmem;
|
||||
#if __CUDA_ARCH__ < 700
|
||||
__shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
|
||||
#endif
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#else
|
||||
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
|
||||
#define NCCL_FUNC4(func, devredop, type, nullify) \
|
||||
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, RING, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, NVLS, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, __nv_bfloat16, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
#endif
|
||||
|
||||
// Must be consistent with ncclRedOp_t
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
|
||||
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
|
||||
// Don't try to initialize the host shadow copy of this device-side global
|
||||
// variable. There is no host pointer to a device-side function, which
|
||||
// confuses clang. This will be fixed in the next clang release.
|
||||
#if __CUDA_ARCH__
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, __nv_bfloat16),
|
||||
#endif
|
||||
NCCL_FUNCS2B(Broadcast),
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
NCCL_FUNCS2B(AllGather),
|
||||
NCCL_FUNCS2A(ReduceScatter),
|
||||
NCCL_FUNCS2A(AllReduce)
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
// Workaround for https://reviews.llvm.org/D55580
|
||||
__device__ void ncclWorkaroundClangD55580() {}
|
||||
@@ -1,13 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/*This file is now generated in CMake*/
|
||||
|
||||
// #include "reduce.h"
|
||||
// #include "common.h"
|
||||
// #include "collectives.h"
|
||||
|
||||
// IMPL_COLL_R(Reduce);
|
||||
@@ -1,13 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/*This file is now generated in CMake*/
|
||||
|
||||
// #include "reduce_scatter.h"
|
||||
// #include "common.h"
|
||||
// #include "collectives.h"
|
||||
|
||||
// IMPL_COLL_R(ReduceScatter);
|
||||
@@ -1,11 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "sendrecv.h"
|
||||
#include "common.h"
|
||||
#include "collectives.h"
|
||||
|
||||
IMPL_COLL_P(SendRecv);
|
||||
+11
-10
@@ -20,8 +20,6 @@
|
||||
#include <cstring> // std::memcpy
|
||||
#include <cinttypes> // PRIx64
|
||||
|
||||
static void* const ncclKernelGeneric = (void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t);
|
||||
|
||||
struct ncclKernelMatch {
|
||||
void* kernelFn;
|
||||
bool specialized;
|
||||
@@ -29,15 +27,18 @@ struct ncclKernelMatch {
|
||||
|
||||
typedef void(*ncclKern_t)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
|
||||
|
||||
// Definition of rccl_main_kernel which is only used in here
|
||||
IMPL_MAIN_KERN();
|
||||
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
static ncclKernelMatch const ncclKerns[2] = {
|
||||
{(void *)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true},
|
||||
{(void *)NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), true},
|
||||
{(void *)rccl_main_kernel, true},
|
||||
{(void *)rccl_main_kernel_debug, true},
|
||||
};
|
||||
#else
|
||||
static ncclKernelMatch const ncclKerns[1] = {
|
||||
{(void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true}
|
||||
{(void*)rccl_main_kernel, true}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -169,7 +170,7 @@ static void appendWorkElemP2p(
|
||||
struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
|
||||
struct ncclWorkElemP2p const *elem, bool fuseOk
|
||||
) {
|
||||
constexpr int funcIndex = FUNC_INDEX_P2P;
|
||||
int funcIndex = ncclFuncId_P2p();
|
||||
struct ncclKernelPlan::Channel* chan = &plan->channels[channelId];
|
||||
struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue);
|
||||
if (q && funcIndex == q->work.header.funcIndex) {
|
||||
@@ -191,7 +192,7 @@ static void appendWorkElemP2p(
|
||||
}
|
||||
q = ncclMemoryStackAlloc<struct ncclWorkList>(&comm->memScoped);
|
||||
q->work.header.type = ncclWorkTypeP2p;
|
||||
q->work.header.funcIndex = FUNC_INDEX_P2P;
|
||||
q->work.header.funcIndex = ncclFuncId_P2p();
|
||||
chan->p2pTailElem[ncclWorkP2pTypeRecv-1] = 0;
|
||||
chan->p2pTailElem[ncclWorkP2pTypeSend-1] = 1;
|
||||
q->work.p2pElems[chan->p2pTailElem[elem->p2pType-1]] = *elem; // C++ struct assignment
|
||||
@@ -1313,12 +1314,12 @@ comp_next:
|
||||
|
||||
if (info->comm->nRanks == 1) {
|
||||
// one-rank reduce index
|
||||
*workFuncIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype);
|
||||
*workFuncIndex = ncclFuncId_P2p() + int(info->datatype);
|
||||
return ncclSuccess;
|
||||
} else if (info->coll == ncclFuncAllToAllPivot) {
|
||||
*workFuncIndex = FUNC_INDEX_ALLTOALL_PIVOT;
|
||||
*workFuncIndex = ncclFuncId_AllToAllPivot();
|
||||
} else {
|
||||
*workFuncIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
|
||||
*workFuncIndex = ncclFuncId(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
|
||||
}
|
||||
|
||||
work->connIndex = 0;
|
||||
|
||||
@@ -19,9 +19,8 @@ struct ncclDevRedOpFull {
|
||||
uint64_t scalarArg;
|
||||
};
|
||||
|
||||
#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps)
|
||||
#define FUNC_INDEX_ALLTOALL_PIVOT (FUNC_INDEX_P2P+1)
|
||||
#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
|
||||
#define FUNC_INDEX_P2P 1015
|
||||
#define FUNC_INDEX_ALLTOALL_PIVOT 675
|
||||
|
||||
#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \
|
||||
ncclFunction_##func##_##algo##_##proto##_##devredop##_##type
|
||||
@@ -38,79 +37,13 @@ struct ncclDevRedOpFull {
|
||||
#define NCCL_IMPL_NAME(func, algo, proto) \
|
||||
nccl##func##algo##proto
|
||||
|
||||
/* Declare all collective operations */
|
||||
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
|
||||
#define DECL5(func, algo, proto, devredop, type) \
|
||||
extern __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
|
||||
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
|
||||
#else
|
||||
#define DECL5(func, algo, proto, devredop, type) \
|
||||
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
|
||||
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
|
||||
// Declare rccl main/general kernel
|
||||
extern __global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
extern __global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
|
||||
#endif
|
||||
|
||||
#define SINGLE_ARG(...) __VA_ARGS__
|
||||
#define CONCAT(a,b) a##b
|
||||
#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(SINGLE_ARG(t), SINGLE_ARG(f))
|
||||
#define MACRO_IF_0(t, f) f
|
||||
#define MACRO_IF_1(t, f) t
|
||||
|
||||
#define DECL4(func, algo, devredop, type, undef) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, SIMPLE, devredop, type)) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL, devredop, type)) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type))
|
||||
|
||||
#define DECL3(func, devredop, type, undef) \
|
||||
DECL4(func, RING, devredop, type, undef) \
|
||||
DECL4(func, TREE, devredop, type, undef) \
|
||||
DECL4(func, COLLNET_DIRECT, devredop, type, undef) \
|
||||
DECL4(func, COLLNET_CHAIN, devredop, type, undef) \
|
||||
DECL4(func, NVLS, devredop, type, undef) \
|
||||
DECL4(func, NVLS_TREE, devredop, type, undef)
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
#define DECL2(func, devredop, undefForFloat) \
|
||||
DECL3(func, devredop, int8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, double, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, rccl_bfloat16, /*undef=*/undefForFloat)
|
||||
#else
|
||||
#define DECL2(func, devredop, undefForFloat) \
|
||||
DECL3(func, devredop, int8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, double, /*undef=*/undefForFloat)
|
||||
#endif
|
||||
|
||||
#define DECL(func) \
|
||||
DECL2(func, Sum, /*undefForFloat=*/0) \
|
||||
DECL2(func, Prod, /*undefForFloat=*/0) \
|
||||
DECL2(func, Min, /*undefForFloat=*/0) \
|
||||
DECL2(func, Max, /*undefForFloat=*/0) \
|
||||
DECL2(func, PreMulSum, /*undefForFloat=*/0) \
|
||||
DECL2(func, SumPostDiv, /*undefForFloat=*/1)
|
||||
|
||||
DECL2(Broadcast, Sum, /*undefForFloat=*/0)
|
||||
DECL(Reduce)
|
||||
DECL2(AllGather, Sum, /*undefForFloat=*/0)
|
||||
DECL(ReduceScatter)
|
||||
DECL(AllReduce)
|
||||
DECL5(SendRecv, RING, SIMPLE, Sum, int8_t)
|
||||
DECL5(AllToAllPivot, RING, SIMPLE, Sum, int8_t)
|
||||
|
||||
// Declare OneRankReduce
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t)();
|
||||
@@ -118,9 +51,7 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, half)();
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16)();
|
||||
#endif
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)();
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)();
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "nccl.h"
|
||||
#include "rccl_bfloat16.h"
|
||||
#include "align.h"
|
||||
#include "collectives.h"
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit_struct.h"
|
||||
#endif
|
||||
@@ -483,4 +484,62 @@ __host__ __device__ constexpr int ncclShmemDynamicSize(int cudaArch = NCCL_CUDA_
|
||||
return cudaArch < 700 ? 0 : ncclShmemScratchWarpSize(cudaArch)*(NCCL_MAX_NTHREADS/WARP_SIZE);
|
||||
}
|
||||
|
||||
// Map the rowIdx to funcIdx
|
||||
extern int const ncclFuncRowToId[];
|
||||
|
||||
// `ncclFuncIndex()` needs to be in sync with 'ALL_COLLS' in Generate.cmake
|
||||
inline int ncclFuncId(int coll, int devRedOp, int type, int algo, int proto) {
|
||||
int row = 0;
|
||||
|
||||
// RING / <all_protos> / Sum / int8_t
|
||||
if (coll == ncclFuncAllGather) {
|
||||
row += proto;
|
||||
goto have_row;
|
||||
}
|
||||
row += NCCL_NUM_PROTOCOLS;
|
||||
|
||||
// <all_algos> / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncAllReduce) {
|
||||
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * (algo * NCCL_NUM_PROTOCOLS + proto);
|
||||
goto have_row;
|
||||
}
|
||||
row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
|
||||
|
||||
// RING / SIMPLE / Sum / int8_t
|
||||
if (coll == ncclFuncAllToAllPivot) goto have_row;
|
||||
row += 1;
|
||||
|
||||
// RING / <all_protos> / Sum / int8_t
|
||||
if (coll == ncclFuncBroadcast) {
|
||||
row += proto;
|
||||
goto have_row;
|
||||
}
|
||||
row += NCCL_NUM_PROTOCOLS;
|
||||
|
||||
// RING / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncReduce) {
|
||||
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto;
|
||||
goto have_row;
|
||||
}
|
||||
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
|
||||
|
||||
// RING / <all_protos> / <all_redops> / <all_types>
|
||||
if (coll == ncclFuncReduceScatter) {
|
||||
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto;
|
||||
goto have_row;
|
||||
}
|
||||
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
|
||||
|
||||
// RING / SIMPLE / Sum / int8_t
|
||||
if (coll == ncclFuncSendRecv) goto have_row;
|
||||
row += 1;
|
||||
|
||||
have_row:
|
||||
return ncclFuncRowToId[row];
|
||||
}
|
||||
|
||||
inline int ncclFuncId_P2p() { return ncclFuncRowToId[FUNC_INDEX_P2P]; }
|
||||
|
||||
inline int ncclFuncId_AllToAllPivot() { return ncclFuncRowToId[FUNC_INDEX_ALLTOALL_PIVOT]; }
|
||||
|
||||
#endif
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "comm.h"
|
||||
#include "group.h"
|
||||
#include "collectives.h"
|
||||
#include "common.h"
|
||||
#include "utils.h"
|
||||
|
||||
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
|
||||
|
||||
+58
-20
@@ -18,6 +18,7 @@
|
||||
#include "enqueue.h"
|
||||
#include "graph.h"
|
||||
#include "argcheck.h"
|
||||
#include "devcomm.h"
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
@@ -31,6 +32,7 @@
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <cstdarg>
|
||||
#include "graph/topo.h"
|
||||
#include "graph/xml.h"
|
||||
#include "archinfo.h"
|
||||
@@ -54,7 +56,7 @@
|
||||
#define NCCL_GROUP_CUDA_STREAM 1 // CGMD: CUDA 9.0,9.1 Need to use an internal CUDA stream
|
||||
#endif
|
||||
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" };
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
|
||||
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree" };
|
||||
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
|
||||
const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" };
|
||||
@@ -177,6 +179,18 @@ void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) {
|
||||
RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0);
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
#define MAX_NAME_LENGTH 64
|
||||
// Helper function to generate function names and update funcIdx
|
||||
void generateFunctionName(char* func_names, int& funcIdx, const char* format, ...) {
|
||||
char* line = func_names + MAX_NAME_LENGTH * funcIdx;
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
vsnprintf(line, MAX_NAME_LENGTH, format, args);
|
||||
va_end(args);
|
||||
funcIdx++;
|
||||
}
|
||||
|
||||
// Should be in sync with 'ALL_COLLS' in Generator.cmake
|
||||
void *ncclCommThreadMain(void *arg) {
|
||||
ncclComm_t comm = (ncclComm_t)arg;
|
||||
int head[MAXCHANNELS];
|
||||
@@ -184,29 +198,53 @@ void *ncclCommThreadMain(void *arg) {
|
||||
|
||||
memset(head, 0, sizeof(int)*MAXCHANNELS);
|
||||
vega_gpu_rtc_freq = GetDeviceWallClockRateInKhz(comm->cudaDev) * 1.0E3;
|
||||
#define MAX_NAME_LENGTH 64
|
||||
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(FUNC_INDEX_P2P+2));
|
||||
for (int func = 0; func < NCCL_NUM_FUNCTIONS; func++) {
|
||||
for (int al = 0; al < NCCL_NUM_ALGORITHMS; al++) {
|
||||
for (int type = 0; type < ncclNumTypes; type++) {
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
for (int devredop = 0; devredop < ncclNumDevRedOps; devredop++) {
|
||||
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, devredop, type, al, pr);
|
||||
sprintf(line, "%s%s%s%s%s", ncclFuncStr[func], ncclAlgoStr[al], ncclProtoStr[pr],
|
||||
ncclDevRedOpStr[devredop], ncclTypeStr[type]);
|
||||
}
|
||||
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(ncclFuncId_P2p()+/*OneRankReduce*/11));
|
||||
int funcIdx = 0;
|
||||
// AllGather --> RING / <all_protos> / Sum / int8_t
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
generateFunctionName(func_names, funcIdx, "AllGatherRing%sSum_i8", ncclProtoStr[pr]);
|
||||
}
|
||||
// AllReduce --> <all_algos> / <all_protos> / <all_redops> / <all_types>
|
||||
for (int al = 0; al < NCCL_NUM_ALGORITHMS - 2; al++) {
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
|
||||
for (int ty = 0; ty < ncclNumTypes; ty++) {
|
||||
if (redop == 5 && ty > 5) continue;
|
||||
generateFunctionName(func_names, funcIdx, "AllReduce%s%s%s%s", ncclAlgoStr[al], ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int type = 0; type < ncclNumTypes; type++) {
|
||||
char* line = func_names+MAX_NAME_LENGTH*(FUNC_INDEX_P2P-ncclNumTypes+type);
|
||||
sprintf(line, "OneRankReducePreMulSum%s", ncclTypeStr[type]);
|
||||
// AllToAllPivot --> RING / SIMPLE / Sum / int8_t
|
||||
generateFunctionName(func_names, funcIdx, "AllToAllPivotRingSimpleSum_i8");
|
||||
// Broadcast --> RING / <all_protos> / Sum / int8_t
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
generateFunctionName(func_names, funcIdx, "BroadcastRing%sSum_i8", ncclProtoStr[pr]);
|
||||
}
|
||||
// Reduce --> RING / <all_protos> / <all_redops> / <all_types>
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
|
||||
for (int ty = 0; ty < ncclNumTypes; ty++) {
|
||||
if (redop == 5 && ty > 5) continue;
|
||||
generateFunctionName(func_names, funcIdx, "ReduceRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// ReduceScatter --> RING / <all_protos> / <all_redops> / <all_types>
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
|
||||
for (int ty = 0; ty < ncclNumTypes; ty++) {
|
||||
if (redop == 5 && ty > 5) continue;
|
||||
generateFunctionName(func_names, funcIdx, "ReduceScatterRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// SendRecv --> RING / SIMPLE / Sum / int8_t
|
||||
generateFunctionName(func_names, funcIdx, "SendRecvRingSimpleSum_i8");
|
||||
// OneRankReduce --> PreMulSum / <all_types>
|
||||
for (int ty = 0; ty < ncclNumTypes; ty++) {
|
||||
generateFunctionName(func_names, funcIdx, "OneRankReducePreMulSum%s", ncclTypeStr[ty]);
|
||||
}
|
||||
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX_P2P;
|
||||
sprintf(line, "SendRecvRingSimpleSum_i8");
|
||||
line += MAX_NAME_LENGTH;
|
||||
sprintf(line, "AllToAllPivotRingSimpleSum_i8");
|
||||
do {
|
||||
for (int channel = 0; channel < MAXCHANNELS; channel++) {
|
||||
int tail = comm->collTraceTail[channel].tail%COLLTRACE_NUM_ITEMS;
|
||||
@@ -232,7 +270,7 @@ void *ncclCommThreadMain(void *arg) {
|
||||
(double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid,
|
||||
fIdx, td->data_0, td->opCount, td->data_1);
|
||||
} else {
|
||||
if (fIdx == FUNC_INDEX_P2P || type == ncclCollTraceP2pElemType)
|
||||
if (fIdx == ncclFuncId_P2p() || type == ncclCollTraceP2pElemType)
|
||||
sprintf(line, "## [%012.6f] [%02d:%02d] %06x-%06x", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->p2pOpCount[0], td->p2pOpCount[1]);
|
||||
else
|
||||
sprintf(line, "## [%012.6f] [%02d:%02d] %06lx", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->opCount);
|
||||
|
||||
@@ -37,10 +37,18 @@ if(BUILD_TESTS)
|
||||
)
|
||||
|
||||
# Collect source files for tests
|
||||
if(BUILD_ALLREDUCE_ONLY)
|
||||
set(TEST_SOURCE_FILES
|
||||
AllReduce_Tests.cpp
|
||||
)
|
||||
if(ONLY_FUNCS)
|
||||
# Convert input string to a list
|
||||
string(REPLACE "|" ";" CONFIG_LIST ${ONLY_FUNCS})
|
||||
|
||||
# For each config in config list
|
||||
foreach(item ${CONFIG_LIST})
|
||||
string(REPLACE " " ";" CONFIG_PARAMS ${item})
|
||||
list(GET CONFIG_PARAMS 0 COLL)
|
||||
|
||||
set(TEST_FILE "${COLL}Tests.cpp")
|
||||
list(APPEND TEST_SOURCE_FILES ${TEST_FILE})
|
||||
endforeach()
|
||||
else()
|
||||
set(TEST_SOURCE_FILES
|
||||
AllGatherTests.cpp
|
||||
|
||||
@@ -65,12 +65,9 @@ namespace RcclUnitTesting
|
||||
useInteractive = GetEnvVar("UT_INTERACTIVE", 0);
|
||||
timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000);
|
||||
|
||||
// Limit number of supported reduction operators to just ncclSum if only allReduce is built
|
||||
#ifdef BUILD_ALLREDUCE_ONLY
|
||||
int numOps = 1;
|
||||
#else
|
||||
// Total number of reduction ops
|
||||
int numOps = ncclNumOps;
|
||||
#endif
|
||||
|
||||
std::vector<std::string> redOpStrings = GetEnvVarsList("UT_REDOPS");
|
||||
for (auto s : redOpStrings)
|
||||
{
|
||||
@@ -98,12 +95,7 @@ namespace RcclUnitTesting
|
||||
{
|
||||
if (!strcmp(s.c_str(), ncclDataTypeNames[i]))
|
||||
{
|
||||
#ifdef BUILD_ALLREDUCE_ONLY
|
||||
if (i == ncclFloat32)
|
||||
#endif
|
||||
{
|
||||
dataTypes.push_back((ncclDataType_t)i);
|
||||
}
|
||||
dataTypes.push_back((ncclDataType_t)i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,8 +104,6 @@ namespace RcclUnitTesting
|
||||
if (dataTypes.empty())
|
||||
{
|
||||
dataTypes.push_back(ncclFloat32);
|
||||
// Skip all but 32-bit floats if only AllReduce is being built
|
||||
#ifndef BUILD_ALLREDUCE_ONLY
|
||||
dataTypes.push_back(ncclInt8);
|
||||
dataTypes.push_back(ncclUint8);
|
||||
dataTypes.push_back(ncclInt32);
|
||||
@@ -124,7 +114,6 @@ namespace RcclUnitTesting
|
||||
dataTypes.push_back(ncclFloat32);
|
||||
dataTypes.push_back(ncclFloat64);
|
||||
dataTypes.push_back(ncclBfloat16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Build list of possible # GPU ranks based on env vars
|
||||
|
||||
Ссылка в новой задаче
Block a user