[DEV] Configure functions in RCCL (#986)

* configure functions in rccl
Этот коммит содержится в:
Bertan Dogancay
2024-01-18 15:07:16 -07:00
коммит произвёл GitHub
родитель 05850e89f2
Коммит 28d9b170c9
24 изменённых файлов: 570 добавлений и 749 удалений
+1 -1
Просмотреть файл
@@ -17,7 +17,7 @@ def runCI =
def prj = new rocProject('rccl', 'Extended')
prj.timeout.test = 600
prj.paths.build_command = './install.sh -t'
prj.paths.build_command = './install.sh -tj 16'
// Define test architectures, optional rocm version argument is available
def nodes = new dockerNodes(nodeDetails, jobName, prj)
+1 -1
Просмотреть файл
@@ -18,7 +18,7 @@ def runCI =
def prj = new rocProject('rccl', 'PreCheckin')
prj.timeout.test = 300
prj.paths.build_command = './install.sh -t --fast'
prj.paths.build_command = './install.sh -tj 16 --fast'
// Define test architectures, optional rocm version argument is available
def nodes = new dockerNodes(nodeDetails, jobName, prj)
+1 -1
Просмотреть файл
@@ -12,7 +12,7 @@ def runCI =
def prj = new rocProject('rccl', 'Static Library PreCheckin')
prj.timeout.test = 1440
prj.paths.build_command = './install.sh -t --static'
prj.paths.build_command = './install.sh -tj 16 --static'
def nodes = new dockerNodes(nodeDetails, jobName, prj)
+9 -32
Просмотреть файл
@@ -10,7 +10,6 @@ project(rccl CXX)
# Build options
#==================================================================================================
option(BUILD_ADDRESS_SANITIZER "Enable address sanitizer" OFF)
option(BUILD_ALLREDUCE_ONLY "AllReduce(sum,float) kernel only" OFF)
option(BUILD_BFD "Enable custom backtrace (if bfd.h exists)" OFF)
option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "File/folder reorg with backward compatibility" OFF)
option(BUILD_LOCAL_GPU_TARGET_ONLY "Build only for GPUs detected on this machine" OFF)
@@ -46,6 +45,7 @@ set(DEFAULT_GPUS
include(CheckIncludeFiles)
include(CheckSymbolExists)
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
include(cmake/Generator.cmake) # Configure functions that goes into RCCL
# Build only for local GPU architecture
if (BUILD_LOCAL_GPU_TARGET_ONLY)
@@ -309,6 +309,7 @@ set(SRC_FILES
src/collectives/device/reduce_kernel.h
src/collectives/device/reduce_scatter.h
src/collectives/device/sendrecv.h
src/collectives/device/onerank_reduce.cu
src/collectives/gather.cc
src/collectives/reduce.cc
src/collectives/reduce_scatter.cc
@@ -447,31 +448,6 @@ set(SRC_FILES
src/transport/shm.cc
)
## Add kernel files
## E.g: find src -type f \( -name "*.u" \) | sort
if (BUILD_ALLREDUCE_ONLY)
add_definitions(-DBUILD_ALLREDUCE_ONLY)
set(CU_SOURCES
# src/collectives/device/all_reduce.cu
src/collectives/device/sendrecv.cu
src/collectives/device/functions.cu
# src/collectives/device/msccl_kernel.cu
)
else()
set(CU_SOURCES
src/collectives/device/all_gather.cu
# src/collectives/device/all_reduce.cu
src/collectives/device/alltoall_pivot.cu
src/collectives/device/broadcast.cu
src/collectives/device/functions.cu
# src/collectives/device/msccl_kernel.cu
src/collectives/device/onerank_reduce.cu
# src/collectives/device/reduce.cu
# src/collectives/device/reduce_scatter.cu
src/collectives/device/sendrecv.cu)
endif()
list(APPEND SRC_FILES ${CU_SOURCES})
if (ENABLE_MSCCL_KERNEL)
set(MSCCL_KERNEL_SOURCES
src/collectives/device/msccl_kernel_impl.h
@@ -509,11 +485,12 @@ foreach(SRC_FILE ${SRC_FILES})
)
endforeach()
expand_collectives("all_reduce" "AllReduce")
expand_collectives("reduce" "Reduce")
expand_collectives("reduce_scatter" "ReduceScatter")
if(ENABLE_MSCCL_KERNEL)
expand_collectives("msccl_kernel" "MscclKernel")
if(ONLY_FUNCS)
## Generate only the specified functions
gen_functions(${ONLY_FUNCS})
else()
# Generate all the functions
gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv")
endif()
# Create an initial git_version.cpp file (that will be updated with latest git version)
@@ -637,7 +614,7 @@ if (HAS_BFD)
target_link_libraries(rccl PRIVATE iberty z)
endif()
endif()
target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code
target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code
## Set RCCL link options
target_link_options(rccl PRIVATE -parallel-jobs=16) # Use multiple threads to link
+1 -2
Просмотреть файл
@@ -25,7 +25,6 @@ The root of this repository has a helper script 'install.sh' to build and instal
Options:
--address-sanitizer Build with address sanitizer enabled
--build_allreduce_only Build only AllReduce + sum + float kernel
-d|--dependencies Install RCCL depdencencies
--debug Build debug library
--enable_backtrace Build with custom backtrace support
@@ -34,7 +33,7 @@ The root of this repository has a helper script 'install.sh' to build and instal
-f|--fast Quick-build RCCL (local gpu arch only, no backtrace, and collective trace support)
-h|--help Prints this help message
-i|--install Install RCCL library (see --prefix argument below)
-j|--jobs Specify how many parallel compilation jobs to run (16 by default)
-j|--jobs Specify how many parallel compilation jobs to run (nproc by default)
-l|--local_gpu_only Only compile for local GPU architecture
--no_clean Don't delete files if they already exist
--npkit-enable Compile with npkit enabled
-49
Просмотреть файл
@@ -70,55 +70,6 @@ if(NOT GTest_FOUND AND BUILD_TESTS OR INSTALL_DEPENDENCIES)
endif()
endif()
set(DATATYPES_INT
"int8_t"
"uint8_t"
"int32_t"
"uint32_t"
"int64_t"
"uint64_t"
)
set(DATATYPES_FLOAT
"half"
"float"
"double"
"rccl_bfloat16"
)
function(expand_collectives FILE FUNC)
set(REDOP Sum Prod Min Max PreMulSum SumPostDiv)
if (FUNC STREQUAL "MscclKernel")
set(REDOP_FILTERED Sum Prod Min Max PreMulSum SumPostDiv)
else()
set(REDOP_FILTERED ${REDOP})
endif()
foreach(REDOP_CURRENT IN LISTS REDOP_FILTERED)
foreach(DATA_TYPE ${DATATYPES_INT} ${DATATYPES_FLOAT})
if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST DATATYPES_FLOAT)
continue() # Skip the iteration for DATATYPES_FLOAT when REDOP_CURRENT is SumPostDiv
endif()
set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/${FILE}_${REDOP_CURRENT}_${DATA_TYPE}.cpp")
message(STATUS "Generating ${FILE_NAME}")
if (FUNC STREQUAL "MscclKernel")
file(WRITE ${FILE_NAME}
"#include \"${FILE}_impl.h\"
#include \"primitives.h\"
#include \"collectives.h\"
#include \"devcomm.h\"
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);")
else()
file(WRITE ${FILE_NAME}
"#include \"${FILE}.h\"
#include \"common.h\"
#include \"collectives.h\"
IMPL_COLL3(${FUNC}, ${REDOP_CURRENT}, ${DATA_TYPE});")
endif()
list(APPEND HIP_SOURCES ${FILE_NAME})
endforeach()
endforeach()
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
# Find or download/install rocm-cmake project
set( PROJECT_EXTERN_DIR ${CMAKE_CURRENT_BINARY_DIR}/extern )
find_package(ROCM 0.7.3 QUIET CONFIG PATHS /opt/rocm)
+383
Просмотреть файл
@@ -0,0 +1,383 @@
# MIT License
#
# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
set(ALL_PARAMS "ALL_COLLS" "ALL_ALGOS" "ALL_PROTOS" "ALL_REDOPS" "ALL_TYPES")
set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "ReduceScatter" "SendRecv")
set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN")
set(ALL_PROTOS "LL" "LL128" "SIMPLE")
set(ALL_REDOPS "Sum" "Prod" "Max" "Min" "PreMulSum" "SumPostDiv")
set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "rccl_bfloat16")
set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16")
################################################################################
# The command line argument is used as a regex to filter the functions
# which make it into librccl. This is helpful for reducing the binary when
# developing device code. The regex supports non-space containing globs '*',
# and union 'a|b'. The string representing the function has the form:
#
# <coll> <algo> <proto> <redop> <type>
#
# The possible values for redop, type, algo, proto can be found in the all_<foo>
# lists at the top of this file.
#
# Example use-cases:
#
# # Only send/recv:
# make ONLY_FUNCS="SendRecv"
#
# # Only AllReduce and Reduce
# make ONLY_FUNCS="AllReduce|Reduce"
#
# # Only non-reductions:
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
#
# # Only AllReduce Sum int32_t (but all algos, protos)
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
#
# # Only AllReduce RING Max float (but all protos)
# make ONLY_FUNCS="AllReduce RING * Max float"
#
# # AllReduce TREE LL128 Prod rccl_bfloat16
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
#############################################################################################################
## A recursive helper macro to generate functions and kernels based on the input given
#############################################################################################################
macro(filter_functions FUNCTION_PARAMS current_idx)
## Check if the current_idx does not exceed the max depth
if(${current_idx} LESS 5)
## current_element is the config parameter
list(GET FUNCTION_PARAMS ${current_idx} current_element)
## If the parameter is equal to '*', include all the possible cases for it
if(${current_element} STREQUAL "*")
if(${current_idx} EQUAL 0)
message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.")
endif()
## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> <coll> <algo> <proto> <redop> <type>
## Find the respective parameter list from ALL_PARAMS list
list(GET ALL_PARAMS ${current_idx} current_list)
## Iterate over the items int the current_list
foreach(item IN LISTS ${current_list})
## Add item to ITEM_LIST which will be used in the inner most loop
list(APPEND ITEM_LIST ${item})
math(EXPR new_idx "${current_idx} + 1")
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
## For each loop layer remove the last element in ITEM_LIST
list(REMOVE_AT ITEM_LIST -1)
endforeach()
else()
## Check if the current element is recognized
list(GET ALL_PARAMS ${current_idx} current_param)
list(FIND ${current_param} ${current_element} is_valid)
if(${is_valid} EQUAL -1)
message(FATAL_ERROR "Error: ${current_element} is unrecognized or does not belong to this category.")
endif()
## If not '*', no need to iterate. Add the current_element to ITEM_LIST
list(APPEND ITEM_LIST ${current_element})
math(EXPR new_idx "${current_idx} + 1")
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
list(REMOVE_AT ITEM_LIST -1)
endif()
else()
## This is the inner most loop where the file is generated
## Unzip ITEM_LIST
list(GET ITEM_LIST 0 COLL)
list(GET ITEM_LIST 1 ALGO)
list(GET ITEM_LIST 2 PROTO)
list(GET ITEM_LIST 3 REDOP)
list(GET ITEM_LIST 4 TYPE)
## Need to check if these conditions are met prior to file generation
if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING")
continue()
elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t"))
continue()
elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE")
continue()
endif()
if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST)
continue()
endif()
list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}")
set(COLL_LIST ${COLL_LIST} PARENT_SCOPE)
## Append the newly formed function/kernel to list
list(APPEND FUNC_LIST "ncclFunction_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
list(APPEND KERN_LIST "ncclKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE)
set(KERN_LIST ${KERN_LIST} PARENT_SCOPE)
endif()
endmacro()
#####################################################################################################
## Function to generate device table
#####################################################################################################
function(gen_device_table)
set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/device_table.cpp")
message(STATUS "Generating ${DEVICE_TABLE_FILE}")
## Generate device table and list all the functions
file(WRITE ${DEVICE_TABLE_FILE} "#include \"common.h\"\n#include \"collectives.h\"\n\n")
## Declaration of device functions
foreach(func IN LISTS FUNC_LIST)
if(ENABLE_IFC)
file(APPEND ${DEVICE_TABLE_FILE} "__device__ void ${func}();\n")
else()
file(APPEND ${DEVICE_TABLE_FILE} "__device__ __attribute__((noinline)) void ${func}();\n")
endif()
endforeach()
file(APPEND ${DEVICE_TABLE_FILE} "\n")
if(ENABLE_IFC)
## Undirect function call
file(APPEND ${DEVICE_TABLE_FILE} "__device__ ncclKernelFunc_t const ncclFuncs[] = {\n")
foreach(func ${FUNC_LIST})
file(APPEND ${DEVICE_TABLE_FILE} " ${func},\n")
endforeach()
## Add OneRankReduce functions at the end
foreach(type IN LISTS ALL_TYPES)
file(APPEND ${DEVICE_TABLE_FILE} " ncclFunction_OneRankReduce_PreMulSum_${type},\n")
endforeach()
file(APPEND ${DEVICE_TABLE_FILE} "nullptr};\n\n")
else()
## Direct functions calls
file(APPEND ${DEVICE_TABLE_FILE} "__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n switch(funcIndex) {\n")
set(index 0)
foreach(func IN LISTS FUNC_LIST)
file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ${func}();\n break;\n")
math(EXPR index "${index} + 1")
endforeach()
## Add OneRankReduce functions at the end
foreach(type IN LISTS ALL_TYPES)
file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ncclFunction_OneRankReduce_PreMulSum_${type}();\n break;\n")
math(EXPR index "${index} + 1")
endforeach()
file(APPEND ${DEVICE_TABLE_FILE} " }\n}\n")
endif()
## Add the device_table file to HIP_SOURCES
list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE})
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
######################################################################################################
## Function to generate host-side table
######################################################################################################
function(gen_host_table)
set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/host_table.cpp")
message(STATUS "Generating ${HOST_TABLE_FILE}")
file(WRITE ${HOST_TABLE_FILE} "#include \"devcomm.h\"\n\n")
## The mapping from function rows to valid function ids
file(APPEND ${HOST_TABLE_FILE} "extern int const ncclFuncRowToId[] = {\n")
set(idx 0)
foreach(coll IN LISTS ALL_COLLS)
foreach(algo IN LISTS ALL_ALGOS)
foreach(proto IN LISTS ALL_PROTOS)
foreach(redop IN LISTS ALL_REDOPS)
foreach(type IN LISTS ALL_TYPES)
if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING")
continue()
elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t"))
continue()
elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE")
continue()
endif()
if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
continue()
endif()
list(FIND FUNC_LIST "ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
if(NOT ${fn_id} EQUAL -1)
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}\n")
else()
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n")
endif()
math(EXPR idx "${idx} + 1")
endforeach()
endforeach()
endforeach()
endforeach()
endforeach()
math(EXPR fn_id "${fn_id} + 1")
## Add OneRankReduce function ids at the end
foreach(type IN LISTS ALL_TYPES)
file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_OneRankReduce_PreMulSum_${type}\n")
## Increment the index and func id for each OneRankReduce
math(EXPR idx "${idx} + 1")
math(EXPR fn_id "${fn_id} + 1")
endforeach()
file(APPEND ${HOST_TABLE_FILE} "-1};\n\n")
## Add the host_table file to HIP_SOURCES
list(APPEND HIP_SOURCES ${HOST_TABLE_FILE})
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
###########################################################################################################
## Function to generate MSCCL Kernels
###########################################################################################################
function(gen_msccl_kernels)
foreach(REDOP_CURRENT IN LISTS ALL_REDOPS)
foreach(DATA_TYPE ${ALL_TYPES})
if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST FLOATS_LIST)
continue() # Skip the iteration for FLOATS_LIST when REDOP_CURRENT is SumPostDiv
endif()
set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/msccl_kernel_${REDOP_CURRENT}_${DATA_TYPE}.cpp")
message(STATUS "Generating ${FILE_NAME}")
file(WRITE ${FILE_NAME}
"#include \"msccl_kernel_impl.h\"
#include \"primitives.h\"
#include \"collectives.h\"
#include \"devcomm.h\"
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);")
list(APPEND HIP_SOURCES ${FILE_NAME})
endforeach()
endforeach()
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
###########################################################################################################
## Function to generate collectives
###########################################################################################################
function(gen_collectives)
# Iterate over each item in the original list
foreach(item ${COLL_LIST})
# Split the string into components
string(REPLACE "-" ";" item_components ${item})
# Extract COLL, ALGO, and PROTO components
list(GET item_components 0 coll_prefix)
list(GET item_components 1 algo_prefix)
list(GET item_components 2 proto_prefix)
list(GET item_components 3 redop_prefix)
# Create a list name using COLL, ALGO, and PROTO
set(list_name "${coll_prefix}_${algo_prefix}_${proto_prefix}_${redop_prefix}")
# Add the item to the corresponding list
list(APPEND ${list_name} ${item})
# Add the list name to the map if it doesn't exist
if(NOT list_name IN_LIST divided_lists)
list(APPEND divided_lists ${list_name})
endif()
endforeach()
set(index 0)
foreach(list_name IN LISTS divided_lists)
foreach(item IN LISTS ${list_name})
# Convert to a list
string(REPLACE "-" ";" components ${item})
list(GET components 0 coll)
list(GET components 1 algo)
list(GET components 2 proto)
list(GET components 3 redop)
list(GET components 4 type)
list(APPEND IMPL_LIST "IMPL_COLL_FUNC(${coll}, ${algo}, ${proto}, ${redop}, ${type})\n")
# Increment the function id
math(EXPR index "${index} + 1")
endforeach()
## Store lower-case version of COLL
string(TOLOWER ${coll} COLL_LOWER)
string(REPLACE "scatter" "_scatter" COLL_LOWER ${COLL_LOWER})
if(NOT ${coll} STREQUAL "AllToAllPivot")
string(REPLACE "all" "all_" COLL_LOWER ${COLL_LOWER})
else()
string(REPLACE "pivot" "_pivot" COLL_LOWER ${COLL_LOWER})
endif()
## Set name/path of the file
set(FILE_PATH "${HIPIFY_DIR}/src/collectives/device/${list_name}.cpp")
message(STATUS "Generating ${FILE_PATH}")
## Construct the file
file(WRITE ${FILE_PATH} "#include \"${COLL_LOWER}.h\"\n#include \"common.h\"\n#include \"collectives.h\"\n")
foreach(IMPL IN LISTS IMPL_LIST)
file(APPEND ${FILE_PATH} "${IMPL}")
endforeach()
## Append the file to HIP sources list which will be added to source list
list(APPEND HIP_SOURCES ${FILE_PATH})
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
# Clear the IMPL list for the next iteration
set(IMPL_LIST)
endforeach()
endfunction()
###################################################################################################################
## Function to generate all the functions that are going to be in librccl.so
###################################################################################################################
function(gen_functions CONFIG_INPUT)
string(REPLACE "|" ";" INPUT_LIST ${CONFIG_INPUT})
## Sort the input so that it matches ALL_COLLS
list(SORT INPUT_LIST)
foreach(INPUT ${INPUT_LIST})
# Parse the the config string and make it a list
string(REPLACE " " ";" FUNCTION_PARAMS ${INPUT})
# Get the number of parameters in the input
list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)
# Assume all if a parameter is missing
while(${PARAMS_LENGTH} LESS 5)
list(APPEND FUNCTION_PARAMS "*")
list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)
endwhile()
## Filter functions/kernels based on input
filter_functions(FUNCTION_PARAMS 0)
endforeach()
gen_collectives() ## Generate collective files
if(ENABLE_MSCCL_KERNEL)
gen_msccl_kernels() ## Generate msccl files (not configurable)
endif()
gen_device_table() ## Generate device_table.cpp
gen_host_table() ## Generate host_table.cpp
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
+4 -12
Просмотреть файл
@@ -8,7 +8,6 @@ ROCM_PATH=${ROCM_PATH:="/opt/rocm"}
# Default values
build_address_sanitizer=false
build_allreduce_only=false
build_bfd=false
build_freorg_bkwdcomp=false
build_local_gpu_only=false
@@ -24,7 +23,7 @@ enable_ninja=""
install_dependencies=false
install_library=false
msccl_kernel_enabled=true
num_parallel_jobs=16
num_parallel_jobs=$(nproc)
npkit_enabled=false
run_tests=false
run_tests_all=false
@@ -38,7 +37,6 @@ function display_help()
echo "RCCL build & installation helper script"
echo " Options:"
echo " --address-sanitizer Build with address sanitizer enabled"
echo " --build_allreduce_only Build only AllReduce + sum + float kernel"
echo " -d|--dependencies Install RCCL depdencencies"
echo " --debug Build debug library"
echo " --enable_backtrace Build with custom backtrace support"
@@ -70,7 +68,7 @@ function display_help()
# check if we have a modern version of getopt that can handle whitespace and long parameters
getopt -T
if [[ $? -eq 4 ]]; then
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,build_allreduce_only,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
else
echo "Need a new version of getopt"
exit 1
@@ -86,7 +84,6 @@ eval set -- "${GETOPT_PARSE}"
while true; do
case "${1}" in
--address-sanitizer) build_address_sanitizer=true; shift ;;
--build_allreduce_only) build_allreduce_only=true; shift ;;
-d | --dependencies) install_dependencies=true; shift ;;
--debug) build_release=false; shift ;;
--enable_backtrace) build_bfd=true; shift ;;
@@ -182,11 +179,6 @@ if [[ "${build_address_sanitizer}" == true ]]; then
cmake_common_options="${cmake_common_options} -DBUILD_ADDRESS_SANITIZER=ON"
fi
# AllReduce only
if [[ "${build_allreduce_only}" == true ]]; then
cmake_common_options="${cmake_common_options} -DBUILD_ALLREDUCE_ONLY=ON"
fi
# Backtrace support
if [[ "${build_bfd}" == true ]]; then
cmake_common_options="${cmake_common_options} -DBUILD_BFD=ON"
@@ -353,9 +345,9 @@ else
fi
if ($build_tests) || (($run_tests) && [[ ! -f ./test/rccl-UnitTests ]]); then
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../.
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../.
else
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../.
CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../.
fi
check_exit_code "$?"
-11
Просмотреть файл
@@ -1,11 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "all_gather.h"
#include "common.h"
#include "collectives.h"
IMPL_COLL_C(AllGather);
-12
Просмотреть файл
@@ -1,12 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
/*This file is now generated in CMake*/
// #include "all_reduce.h"
// #include "common.h"
// #include "collectives.h"
// IMPL_COLL_R(AllReduce);
-11
Просмотреть файл
@@ -1,11 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "alltoall_pivot.h"
#include "common.h"
#include "collectives.h"
IMPL_COLL_F(AllToAllPivot);
-11
Просмотреть файл
@@ -1,11 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "broadcast.h"
#include "common.h"
#include "collectives.h"
IMPL_COLL_C(Broadcast);
+19 -319
Просмотреть файл
@@ -32,243 +32,14 @@
{ __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
#endif
#if defined(ENABLE_LL128) && defined(__gfx90a__)
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
#else
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
#endif
#define NCCL_FUNC4(func, devredop, type, nullify) \
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
NCCL_FUNC5(func, RING, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, uint8_t, 0), \
NCCL_FUNC4(func, devredop, int32_t, 0), \
NCCL_FUNC4(func, devredop, uint32_t, 0), \
NCCL_FUNC4(func, devredop, int64_t, 0), \
NCCL_FUNC4(func, devredop, uint64_t, 0), \
NCCL_FUNC4(func, devredop, half, nullForFloat), \
NCCL_FUNC4(func, devredop, float, nullForFloat), \
NCCL_FUNC4(func, devredop, double, nullForFloat), \
NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat)
#define NCCL_FUNCS3B(func, devredop) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0)
// Must be consistent with ncclRedOp_t
#define NCCL_FUNCS2A(func) \
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
#define NCCL_FUNCS2B(func) \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum)
// Must be consistent with the ncclFuncSet enum
using ncclKernelFunc_t = void (*)();
static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(BUILD_ALLREDUCE_ONLY)
NCCL_FUNC4(AllReduce, Sum, float, 0),
#else
NCCL_FUNCS2B(Broadcast),
NCCL_FUNCS2A(Reduce),
NCCL_FUNCS2B(AllGather),
NCCL_FUNCS2A(ReduceScatter),
NCCL_FUNCS2A(AllReduce),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
#if defined(RCCL_BFLOAT16)
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
#endif
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
#endif
#endif
};
// Defined in device_table.cpp
extern __device__ ncclKernelFunc_t const ncclFuncs[];
static_assert(FUNC_INDEX_P2P == 5410, "Wrong P2P function index");
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 5411, "Wrong AllToAllPivot function index");
#if !defined(USE_INDIRECT_FUNCTION_CALL) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
template<unsigned short f, unsigned short l, bool u>
struct Caller {
static __forceinline__ __device__ __host__
void call(unsigned short funcIndex) noexcept
{
constexpr unsigned short m = f + (l - f) / 2;
return (funcIndex < m) ? Caller<f, m, u>::call(funcIndex) : Caller<m, l, u>::call(funcIndex);
}
};
template<unsigned short f, bool u>
struct Caller<f, f + 1, u>{
static __forceinline__ __device__ __host__
void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); }
};
template<bool USING_LL128>
__forceinline__
__device__
void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {
#if defined(BUILD_ALLREDUCE_ONLY)
if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_RING_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL))
ncclFunction_AllReduce_RING_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
ncclFunction_AllReduce_RING_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
ncclFunction_AllReduce_RING_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_TREE_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL))
ncclFunction_AllReduce_TREE_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
ncclFunction_AllReduce_TREE_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
ncclFunction_AllReduce_TREE_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_COLLNET_DIRECT_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL))
ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_DIRECT_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_COLLNET_CHAIN_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL))
ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_CHAIN_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float();
else
assert("Unsupported function index");
#else
if (funcIndex < 1080) {
if (funcIndex % 18 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
else if (funcIndex % 18 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
else if (funcIndex % 18 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 6) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t();
else if (funcIndex % 18 == 8) ncclFunction_Broadcast_COLLNET_DIRECT_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 9) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t();
else ncclFunction_Broadcast_COLLNET_CHAIN_SIMPLE_Sum_int8_t();
}
else if (funcIndex < 2160) Caller<1080, 2160, USING_LL128>::call(funcIndex);
else if (funcIndex < 3240) {
if (funcIndex % 18 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
else if (funcIndex % 18 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t();
else if (funcIndex % 18 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 6) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t();
else if (funcIndex % 18 == 8) ncclFunction_AllGather_COLLNET_DIRECT_SIMPLE_Sum_int8_t();
else if (funcIndex % 18 == 9) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t();
else ncclFunction_AllGather_COLLNET_CHAIN_SIMPLE_Sum_int8_t();
}
else if (funcIndex < 5400) Caller<3240, 5400, USING_LL128>::call(funcIndex);
else {
switch (funcIndex - 5400) {
case 0:
ncclFunction_OneRankReduce_PreMulSum_int8_t();
break;
case 1:
ncclFunction_OneRankReduce_PreMulSum_uint8_t();
break;
case 2:
ncclFunction_OneRankReduce_PreMulSum_int32_t();
break;
case 3:
ncclFunction_OneRankReduce_PreMulSum_uint32_t();
break;
case 4:
ncclFunction_OneRankReduce_PreMulSum_int64_t();
break;
case 5:
ncclFunction_OneRankReduce_PreMulSum_uint64_t();
break;
case 6:
ncclFunction_OneRankReduce_PreMulSum_half();
break;
case 7:
ncclFunction_OneRankReduce_PreMulSum_float();
break;
case 8:
ncclFunction_OneRankReduce_PreMulSum_double();
break;
case 9:
ncclFunction_OneRankReduce_PreMulSum_rccl_bfloat16();
break;
case 10:
ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t();
break;
case 11:
ncclFunction_AllToAllPivot_RING_SIMPLE_Sum_int8_t();
default:
break;
}
}
#endif
}
#ifndef USE_INDIRECT_FUNCTION_CALL
__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept;
#endif
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
@@ -464,7 +235,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we
}
}
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex, bool COLLTRACE>
template<bool COLLTRACE>
__forceinline__ __device__ void ncclKernel(
struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead
) {
@@ -565,19 +336,12 @@ __forceinline__ __device__ void ncclKernel(
__synclds();
if (tid == 0) __insert_timestamp(__LINE__);
if (ncclShmem.work.header.funcIndex == FnIndex) {
RunWork<Fn, T, RedOp, Algo, Proto>().run(&ncclShmem.work);
} else {
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
ncclFuncs[ncclShmem.work.header.funcIndex]();
#ifdef USE_INDIRECT_FUNCTION_CALL
ncclFuncs[ncclShmem.work.header.funcIndex]();
#else
#if defined(ENABLE_LL128) && defined(__gfx90a__)
NCCL_CALL_FUNCTIONS<1>(ncclShmem.work.header.funcIndex);
#else
NCCL_CALL_FUNCTIONS<0>(ncclShmem.work.header.funcIndex);
NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex);
#endif
#endif
}
int workIxNext = ncclShmem.work.header.workNext;
__synclds();
@@ -606,28 +370,25 @@ __forceinline__ __device__ void ncclKernel(
}
#ifdef ENABLE_COLLTRACE
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
#define IMPL_MAIN_KERN() \
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \
__global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<false>(comm, channelMask, workHead); \
} \
\
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, channelMask, workHead); \
__global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<true>(comm, channelMask, workHead); \
}
#else
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
#define IMPL_MAIN_KERN() \
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \
__global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \
ncclKernel<false>(comm, channelMask, workHead); \
}
#endif
// Examples : AllReduce, RING, LL, Sum, uint8
/* Functions for aggregation case */
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#ifdef USE_INDIRECT_FUNCTION_CALL
#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \
__device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \
RunWork<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \
@@ -639,67 +400,6 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev
}
#endif
// Only generate inline kernels for LL
#if defined(ENABLE_LL128) && defined(__gfx90a__)
#define IMPL_COLL4(func, algo, devredop, type) \
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type)
#else
#define IMPL_COLL4(func, algo, devredop, type) \
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type)
#endif
#define IMPL_COLL3(func, devredop, type) \
IMPL_COLL4(func, TREE, devredop, type) \
IMPL_COLL4(func, RING, devredop, type) \
IMPL_COLL4(func, COLLNET_DIRECT, devredop, type) \
IMPL_COLL4(func, COLLNET_CHAIN, devredop, type) \
IMPL_COLL4(func, NVLS, devredop, type) \
IMPL_COLL4(func, NVLS_TREE, devredop, type)
#define IMPL_COLL2(func, devredop) \
IMPL_COLL3(func, devredop, int8_t) \
IMPL_COLL3(func, devredop, uint8_t) \
IMPL_COLL3(func, devredop, int32_t) \
IMPL_COLL3(func, devredop, uint32_t) \
IMPL_COLL3(func, devredop, int64_t) \
IMPL_COLL3(func, devredop, uint64_t) \
IMPL_COLL3(func, devredop, half) \
IMPL_COLL3(func, devredop, float) \
IMPL_COLL3(func, devredop, double) \
IMPL_COLL3(func, devredop, rccl_bfloat16)
#define IMPL_COLL2A(func, devredop) \
IMPL_COLL3(func, devredop, int8_t) \
IMPL_COLL3(func, devredop, uint8_t) \
IMPL_COLL3(func, devredop, int32_t) \
IMPL_COLL3(func, devredop, uint32_t) \
IMPL_COLL3(func, devredop, int64_t) \
IMPL_COLL3(func, devredop, uint64_t)
// Reduction define all functions
#define IMPL_COLL_R(func) \
IMPL_COLL2(func, Sum) \
IMPL_COLL2(func, Prod) \
IMPL_COLL2(func, Min) \
IMPL_COLL2(func, Max) \
IMPL_COLL2(func, PreMulSum) \
IMPL_COLL2A(func, SumPostDiv)
// Copy primitives only define one function for copy
#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t);
// Point-to-point primitives only have one function/kernel.
#define IMPL_COLL_P(func) \
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \
IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, FUNC_INDEX_P2P);
// AllToAll Pivot primitive only has one function.
#define IMPL_COLL_F(func) \
IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t);
#define NCCL_NVLS_ENABLED (__CUDA_ARCH__ >= 900 && NCCL_NVLS_SUPPORTS(NCCL_TYPE, NCCL_OP))
#endif
#endif
-126
Просмотреть файл
@@ -1,126 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "devcomm.h"
#include "collectives.h"
#include "common.h"
__shared__ ncclShmemData ncclShmem;
#if __CUDA_ARCH__ < 700
__shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
#endif
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
#else
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
#define NCCL_FUNC4(func, devredop, type, nullify) \
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
NCCL_FUNC5(func, RING, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \
NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS, devredop, type, nullify), \
NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify)
#if defined(__CUDA_BF16_TYPES_EXIST__)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, uint8_t, 0), \
NCCL_FUNC4(func, devredop, int32_t, 0), \
NCCL_FUNC4(func, devredop, uint32_t, 0), \
NCCL_FUNC4(func, devredop, int64_t, 0), \
NCCL_FUNC4(func, devredop, uint64_t, 0), \
NCCL_FUNC4(func, devredop, half, nullForFloat), \
NCCL_FUNC4(func, devredop, float, nullForFloat), \
NCCL_FUNC4(func, devredop, double, nullForFloat), \
NCCL_FUNC4(func, devredop, __nv_bfloat16, nullForFloat)
#define NCCL_FUNCS3B(func, devredop) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0)
#else
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, uint8_t, 0), \
NCCL_FUNC4(func, devredop, int32_t, 0), \
NCCL_FUNC4(func, devredop, uint32_t, 0), \
NCCL_FUNC4(func, devredop, int64_t, 0), \
NCCL_FUNC4(func, devredop, uint64_t, 0), \
NCCL_FUNC4(func, devredop, half, nullForFloat), \
NCCL_FUNC4(func, devredop, float, nullForFloat), \
NCCL_FUNC4(func, devredop, double, nullForFloat)
#define NCCL_FUNCS3B(func, devredop) \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0), \
NCCL_FUNC4(func, devredop, int8_t, 0)
#endif
// Must be consistent with ncclRedOp_t
#define NCCL_FUNCS2A(func) \
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
#define NCCL_FUNCS2B(func) \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum), \
NCCL_FUNCS3B(func, Sum)
// Must be consistent with the ncclFuncSet enum
__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
#if __CUDA_ARCH__
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
#if defined(__CUDA_BF16_TYPES_EXIST__)
NCCL_ONERANK_REDUCE_NAME(PreMulSum, __nv_bfloat16),
#endif
NCCL_FUNCS2B(Broadcast),
NCCL_FUNCS2A(Reduce),
NCCL_FUNCS2B(AllGather),
NCCL_FUNCS2A(ReduceScatter),
NCCL_FUNCS2A(AllReduce)
#endif
};
#endif
// Workaround for https://reviews.llvm.org/D55580
__device__ void ncclWorkaroundClangD55580() {}
-13
Просмотреть файл
@@ -1,13 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
/*This file is now generated in CMake*/
// #include "reduce.h"
// #include "common.h"
// #include "collectives.h"
// IMPL_COLL_R(Reduce);
-13
Просмотреть файл
@@ -1,13 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
/*This file is now generated in CMake*/
// #include "reduce_scatter.h"
// #include "common.h"
// #include "collectives.h"
// IMPL_COLL_R(ReduceScatter);
-11
Просмотреть файл
@@ -1,11 +0,0 @@
/*************************************************************************
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "sendrecv.h"
#include "common.h"
#include "collectives.h"
IMPL_COLL_P(SendRecv);
+11 -10
Просмотреть файл
@@ -20,8 +20,6 @@
#include <cstring> // std::memcpy
#include <cinttypes> // PRIx64
static void* const ncclKernelGeneric = (void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t);
struct ncclKernelMatch {
void* kernelFn;
bool specialized;
@@ -29,15 +27,18 @@ struct ncclKernelMatch {
typedef void(*ncclKern_t)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
// Definition of rccl_main_kernel which is only used in here
IMPL_MAIN_KERN();
// Must be consistent with the ncclFuncSet enum
#ifdef ENABLE_COLLTRACE
static ncclKernelMatch const ncclKerns[2] = {
{(void *)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true},
{(void *)NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), true},
{(void *)rccl_main_kernel, true},
{(void *)rccl_main_kernel_debug, true},
};
#else
static ncclKernelMatch const ncclKerns[1] = {
{(void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true}
{(void*)rccl_main_kernel, true}
};
#endif
@@ -169,7 +170,7 @@ static void appendWorkElemP2p(
struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
struct ncclWorkElemP2p const *elem, bool fuseOk
) {
constexpr int funcIndex = FUNC_INDEX_P2P;
int funcIndex = ncclFuncId_P2p();
struct ncclKernelPlan::Channel* chan = &plan->channels[channelId];
struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue);
if (q && funcIndex == q->work.header.funcIndex) {
@@ -191,7 +192,7 @@ static void appendWorkElemP2p(
}
q = ncclMemoryStackAlloc<struct ncclWorkList>(&comm->memScoped);
q->work.header.type = ncclWorkTypeP2p;
q->work.header.funcIndex = FUNC_INDEX_P2P;
q->work.header.funcIndex = ncclFuncId_P2p();
chan->p2pTailElem[ncclWorkP2pTypeRecv-1] = 0;
chan->p2pTailElem[ncclWorkP2pTypeSend-1] = 1;
q->work.p2pElems[chan->p2pTailElem[elem->p2pType-1]] = *elem; // C++ struct assignment
@@ -1313,12 +1314,12 @@ comp_next:
if (info->comm->nRanks == 1) {
// one-rank reduce index
*workFuncIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype);
*workFuncIndex = ncclFuncId_P2p() + int(info->datatype);
return ncclSuccess;
} else if (info->coll == ncclFuncAllToAllPivot) {
*workFuncIndex = FUNC_INDEX_ALLTOALL_PIVOT;
*workFuncIndex = ncclFuncId_AllToAllPivot();
} else {
*workFuncIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
*workFuncIndex = ncclFuncId(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
}
work->connIndex = 0;
+7 -76
Просмотреть файл
@@ -19,9 +19,8 @@ struct ncclDevRedOpFull {
uint64_t scalarArg;
};
#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps)
#define FUNC_INDEX_ALLTOALL_PIVOT (FUNC_INDEX_P2P+1)
#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
#define FUNC_INDEX_P2P 1015
#define FUNC_INDEX_ALLTOALL_PIVOT 675
#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \
ncclFunction_##func##_##algo##_##proto##_##devredop##_##type
@@ -38,79 +37,13 @@ struct ncclDevRedOpFull {
#define NCCL_IMPL_NAME(func, algo, proto) \
nccl##func##algo##proto
/* Declare all collective operations */
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#define DECL5(func, algo, proto, devredop, type) \
extern __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
#else
#define DECL5(func, algo, proto, devredop, type) \
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
// Declare rccl main/general kernel
extern __global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
#ifdef ENABLE_COLLTRACE
extern __global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
#endif
#define SINGLE_ARG(...) __VA_ARGS__
#define CONCAT(a,b) a##b
#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(SINGLE_ARG(t), SINGLE_ARG(f))
#define MACRO_IF_0(t, f) f
#define MACRO_IF_1(t, f) t
#define DECL4(func, algo, devredop, type, undef) \
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, SIMPLE, devredop, type)) \
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL, devredop, type)) \
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type))
#define DECL3(func, devredop, type, undef) \
DECL4(func, RING, devredop, type, undef) \
DECL4(func, TREE, devredop, type, undef) \
DECL4(func, COLLNET_DIRECT, devredop, type, undef) \
DECL4(func, COLLNET_CHAIN, devredop, type, undef) \
DECL4(func, NVLS, devredop, type, undef) \
DECL4(func, NVLS_TREE, devredop, type, undef)
#if defined(RCCL_BFLOAT16)
#define DECL2(func, devredop, undefForFloat) \
DECL3(func, devredop, int8_t, /*undef=*/0) \
DECL3(func, devredop, uint8_t, /*undef=*/0) \
DECL3(func, devredop, int32_t, /*undef=*/0) \
DECL3(func, devredop, uint32_t, /*undef=*/0) \
DECL3(func, devredop, int64_t, /*undef=*/0) \
DECL3(func, devredop, uint64_t, /*undef=*/0) \
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
DECL3(func, devredop, double, /*undef=*/undefForFloat) \
DECL3(func, devredop, rccl_bfloat16, /*undef=*/undefForFloat)
#else
#define DECL2(func, devredop, undefForFloat) \
DECL3(func, devredop, int8_t, /*undef=*/0) \
DECL3(func, devredop, uint8_t, /*undef=*/0) \
DECL3(func, devredop, int32_t, /*undef=*/0) \
DECL3(func, devredop, uint32_t, /*undef=*/0) \
DECL3(func, devredop, int64_t, /*undef=*/0) \
DECL3(func, devredop, uint64_t, /*undef=*/0) \
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
DECL3(func, devredop, double, /*undef=*/undefForFloat)
#endif
#define DECL(func) \
DECL2(func, Sum, /*undefForFloat=*/0) \
DECL2(func, Prod, /*undefForFloat=*/0) \
DECL2(func, Min, /*undefForFloat=*/0) \
DECL2(func, Max, /*undefForFloat=*/0) \
DECL2(func, PreMulSum, /*undefForFloat=*/0) \
DECL2(func, SumPostDiv, /*undefForFloat=*/1)
DECL2(Broadcast, Sum, /*undefForFloat=*/0)
DECL(Reduce)
DECL2(AllGather, Sum, /*undefForFloat=*/0)
DECL(ReduceScatter)
DECL(AllReduce)
DECL5(SendRecv, RING, SIMPLE, Sum, int8_t)
DECL5(AllToAllPivot, RING, SIMPLE, Sum, int8_t)
// Declare OneRankReduce
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t)();
@@ -118,9 +51,7 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, half)();
#if defined(RCCL_BFLOAT16)
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16)();
#endif
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)();
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)();
+59
Просмотреть файл
@@ -11,6 +11,7 @@
#include "nccl.h"
#include "rccl_bfloat16.h"
#include "align.h"
#include "collectives.h"
#if defined(ENABLE_NPKIT)
#include "npkit/npkit_struct.h"
#endif
@@ -483,4 +484,62 @@ __host__ __device__ constexpr int ncclShmemDynamicSize(int cudaArch = NCCL_CUDA_
return cudaArch < 700 ? 0 : ncclShmemScratchWarpSize(cudaArch)*(NCCL_MAX_NTHREADS/WARP_SIZE);
}
// Map the rowIdx to funcIdx
extern int const ncclFuncRowToId[];
// `ncclFuncIndex()` needs to be in sync with 'ALL_COLLS' in Generate.cmake
inline int ncclFuncId(int coll, int devRedOp, int type, int algo, int proto) {
int row = 0;
// RING / <all_protos> / Sum / int8_t
if (coll == ncclFuncAllGather) {
row += proto;
goto have_row;
}
row += NCCL_NUM_PROTOCOLS;
// <all_algos> / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncAllReduce) {
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * (algo * NCCL_NUM_PROTOCOLS + proto);
goto have_row;
}
row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
// RING / SIMPLE / Sum / int8_t
if (coll == ncclFuncAllToAllPivot) goto have_row;
row += 1;
// RING / <all_protos> / Sum / int8_t
if (coll == ncclFuncBroadcast) {
row += proto;
goto have_row;
}
row += NCCL_NUM_PROTOCOLS;
// RING / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncReduce) {
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto;
goto have_row;
}
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
// RING / <all_protos> / <all_redops> / <all_types>
if (coll == ncclFuncReduceScatter) {
row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto;
goto have_row;
}
row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4);
// RING / SIMPLE / Sum / int8_t
if (coll == ncclFuncSendRecv) goto have_row;
row += 1;
have_row:
return ncclFuncRowToId[row];
}
inline int ncclFuncId_P2p() { return ncclFuncRowToId[FUNC_INDEX_P2P]; }
inline int ncclFuncId_AllToAllPivot() { return ncclFuncRowToId[FUNC_INDEX_ALLTOALL_PIVOT]; }
#endif
+1
Просмотреть файл
@@ -10,6 +10,7 @@
#include "comm.h"
#include "group.h"
#include "collectives.h"
#include "common.h"
#include "utils.h"
#define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64)
+58 -20
Просмотреть файл
@@ -18,6 +18,7 @@
#include "enqueue.h"
#include "graph.h"
#include "argcheck.h"
#include "devcomm.h"
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
@@ -31,6 +32,7 @@
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdarg>
#include "graph/topo.h"
#include "graph/xml.h"
#include "archinfo.h"
@@ -54,7 +56,7 @@
#define NCCL_GROUP_CUDA_STREAM 1 // CGMD: CUDA 9.0,9.1 Need to use an internal CUDA stream
#endif
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" };
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"};
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree" };
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" };
@@ -177,6 +179,18 @@ void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) {
RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0);
#ifdef ENABLE_COLLTRACE
#define MAX_NAME_LENGTH 64
// Helper function to generate function names and update funcIdx
void generateFunctionName(char* func_names, int& funcIdx, const char* format, ...) {
char* line = func_names + MAX_NAME_LENGTH * funcIdx;
va_list args;
va_start(args, format);
vsnprintf(line, MAX_NAME_LENGTH, format, args);
va_end(args);
funcIdx++;
}
// Should be in sync with 'ALL_COLLS' in Generator.cmake
void *ncclCommThreadMain(void *arg) {
ncclComm_t comm = (ncclComm_t)arg;
int head[MAXCHANNELS];
@@ -184,29 +198,53 @@ void *ncclCommThreadMain(void *arg) {
memset(head, 0, sizeof(int)*MAXCHANNELS);
vega_gpu_rtc_freq = GetDeviceWallClockRateInKhz(comm->cudaDev) * 1.0E3;
#define MAX_NAME_LENGTH 64
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(FUNC_INDEX_P2P+2));
for (int func = 0; func < NCCL_NUM_FUNCTIONS; func++) {
for (int al = 0; al < NCCL_NUM_ALGORITHMS; al++) {
for (int type = 0; type < ncclNumTypes; type++) {
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int devredop = 0; devredop < ncclNumDevRedOps; devredop++) {
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, devredop, type, al, pr);
sprintf(line, "%s%s%s%s%s", ncclFuncStr[func], ncclAlgoStr[al], ncclProtoStr[pr],
ncclDevRedOpStr[devredop], ncclTypeStr[type]);
}
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(ncclFuncId_P2p()+/*OneRankReduce*/11));
int funcIdx = 0;
// AllGather --> RING / <all_protos> / Sum / int8_t
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
generateFunctionName(func_names, funcIdx, "AllGatherRing%sSum_i8", ncclProtoStr[pr]);
}
// AllReduce --> <all_algos> / <all_protos> / <all_redops> / <all_types>
for (int al = 0; al < NCCL_NUM_ALGORITHMS - 2; al++) {
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "AllReduce%s%s%s%s", ncclAlgoStr[al], ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
}
for (int type = 0; type < ncclNumTypes; type++) {
char* line = func_names+MAX_NAME_LENGTH*(FUNC_INDEX_P2P-ncclNumTypes+type);
sprintf(line, "OneRankReducePreMulSum%s", ncclTypeStr[type]);
// AllToAllPivot --> RING / SIMPLE / Sum / int8_t
generateFunctionName(func_names, funcIdx, "AllToAllPivotRingSimpleSum_i8");
// Broadcast --> RING / <all_protos> / Sum / int8_t
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
generateFunctionName(func_names, funcIdx, "BroadcastRing%sSum_i8", ncclProtoStr[pr]);
}
// Reduce --> RING / <all_protos> / <all_redops> / <all_types>
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "ReduceRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
// ReduceScatter --> RING / <all_protos> / <all_redops> / <all_types>
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
for (int redop = 0; redop < ncclNumDevRedOps; redop++) {
for (int ty = 0; ty < ncclNumTypes; ty++) {
if (redop == 5 && ty > 5) continue;
generateFunctionName(func_names, funcIdx, "ReduceScatterRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]);
}
}
}
// SendRecv --> RING / SIMPLE / Sum / int8_t
generateFunctionName(func_names, funcIdx, "SendRecvRingSimpleSum_i8");
// OneRankReduce --> PreMulSum / <all_types>
for (int ty = 0; ty < ncclNumTypes; ty++) {
generateFunctionName(func_names, funcIdx, "OneRankReducePreMulSum%s", ncclTypeStr[ty]);
}
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX_P2P;
sprintf(line, "SendRecvRingSimpleSum_i8");
line += MAX_NAME_LENGTH;
sprintf(line, "AllToAllPivotRingSimpleSum_i8");
do {
for (int channel = 0; channel < MAXCHANNELS; channel++) {
int tail = comm->collTraceTail[channel].tail%COLLTRACE_NUM_ITEMS;
@@ -232,7 +270,7 @@ void *ncclCommThreadMain(void *arg) {
(double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid,
fIdx, td->data_0, td->opCount, td->data_1);
} else {
if (fIdx == FUNC_INDEX_P2P || type == ncclCollTraceP2pElemType)
if (fIdx == ncclFuncId_P2p() || type == ncclCollTraceP2pElemType)
sprintf(line, "## [%012.6f] [%02d:%02d] %06x-%06x", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->p2pOpCount[0], td->p2pOpCount[1]);
else
sprintf(line, "## [%012.6f] [%02d:%02d] %06lx", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->opCount);
+12 -4
Просмотреть файл
@@ -37,10 +37,18 @@ if(BUILD_TESTS)
)
# Collect source files for tests
if(BUILD_ALLREDUCE_ONLY)
set(TEST_SOURCE_FILES
AllReduce_Tests.cpp
)
if(ONLY_FUNCS)
# Convert input string to a list
string(REPLACE "|" ";" CONFIG_LIST ${ONLY_FUNCS})
# For each config in config list
foreach(item ${CONFIG_LIST})
string(REPLACE " " ";" CONFIG_PARAMS ${item})
list(GET CONFIG_PARAMS 0 COLL)
set(TEST_FILE "${COLL}Tests.cpp")
list(APPEND TEST_SOURCE_FILES ${TEST_FILE})
endforeach()
else()
set(TEST_SOURCE_FILES
AllGatherTests.cpp
+3 -14
Просмотреть файл
@@ -65,12 +65,9 @@ namespace RcclUnitTesting
useInteractive = GetEnvVar("UT_INTERACTIVE", 0);
timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000);
// Limit number of supported reduction operators to just ncclSum if only allReduce is built
#ifdef BUILD_ALLREDUCE_ONLY
int numOps = 1;
#else
// Total number of reduction ops
int numOps = ncclNumOps;
#endif
std::vector<std::string> redOpStrings = GetEnvVarsList("UT_REDOPS");
for (auto s : redOpStrings)
{
@@ -98,12 +95,7 @@ namespace RcclUnitTesting
{
if (!strcmp(s.c_str(), ncclDataTypeNames[i]))
{
#ifdef BUILD_ALLREDUCE_ONLY
if (i == ncclFloat32)
#endif
{
dataTypes.push_back((ncclDataType_t)i);
}
dataTypes.push_back((ncclDataType_t)i);
}
}
}
@@ -112,8 +104,6 @@ namespace RcclUnitTesting
if (dataTypes.empty())
{
dataTypes.push_back(ncclFloat32);
// Skip all but 32-bit floats if only AllReduce is being built
#ifndef BUILD_ALLREDUCE_ONLY
dataTypes.push_back(ncclInt8);
dataTypes.push_back(ncclUint8);
dataTypes.push_back(ncclInt32);
@@ -124,7 +114,6 @@ namespace RcclUnitTesting
dataTypes.push_back(ncclFloat32);
dataTypes.push_back(ncclFloat64);
dataTypes.push_back(ncclBfloat16);
#endif
}
// Build list of possible # GPU ranks based on env vars