diff --git a/.jenkins/extended.groovy b/.jenkins/extended.groovy index 7a2ff008c2..51516d2754 100644 --- a/.jenkins/extended.groovy +++ b/.jenkins/extended.groovy @@ -17,7 +17,7 @@ def runCI = def prj = new rocProject('rccl', 'Extended') prj.timeout.test = 600 - prj.paths.build_command = './install.sh -t' + prj.paths.build_command = './install.sh -tj 16' // Define test architectures, optional rocm version argument is available def nodes = new dockerNodes(nodeDetails, jobName, prj) diff --git a/.jenkins/precheckin.groovy b/.jenkins/precheckin.groovy index 92bef95467..f2c30c1cb7 100644 --- a/.jenkins/precheckin.groovy +++ b/.jenkins/precheckin.groovy @@ -18,7 +18,7 @@ def runCI = def prj = new rocProject('rccl', 'PreCheckin') prj.timeout.test = 300 - prj.paths.build_command = './install.sh -t --fast' + prj.paths.build_command = './install.sh -tj 16 --fast' // Define test architectures, optional rocm version argument is available def nodes = new dockerNodes(nodeDetails, jobName, prj) diff --git a/.jenkins/staticlibrary.groovy b/.jenkins/staticlibrary.groovy index e75ff7ec97..4aac31934f 100644 --- a/.jenkins/staticlibrary.groovy +++ b/.jenkins/staticlibrary.groovy @@ -12,7 +12,7 @@ def runCI = def prj = new rocProject('rccl', 'Static Library PreCheckin') prj.timeout.test = 1440 - prj.paths.build_command = './install.sh -t --static' + prj.paths.build_command = './install.sh -tj 16 --static' def nodes = new dockerNodes(nodeDetails, jobName, prj) diff --git a/CMakeLists.txt b/CMakeLists.txt index ff16682ce4..609c582583 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,6 @@ project(rccl CXX) # Build options #================================================================================================== option(BUILD_ADDRESS_SANITIZER "Enable address sanitizer" OFF) -option(BUILD_ALLREDUCE_ONLY "AllReduce(sum,float) kernel only" OFF) option(BUILD_BFD "Enable custom backtrace (if bfd.h exists)" OFF) option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "File/folder 
reorg with backward compatibility" OFF) option(BUILD_LOCAL_GPU_TARGET_ONLY "Build only for GPUs detected on this machine" OFF) @@ -46,6 +45,7 @@ set(DEFAULT_GPUS include(CheckIncludeFiles) include(CheckSymbolExists) include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets +include(cmake/Generator.cmake) # Configure functions that goes into RCCL # Build only for local GPU architecture if (BUILD_LOCAL_GPU_TARGET_ONLY) @@ -309,6 +309,7 @@ set(SRC_FILES src/collectives/device/reduce_kernel.h src/collectives/device/reduce_scatter.h src/collectives/device/sendrecv.h + src/collectives/device/onerank_reduce.cu src/collectives/gather.cc src/collectives/reduce.cc src/collectives/reduce_scatter.cc @@ -447,31 +448,6 @@ set(SRC_FILES src/transport/shm.cc ) -## Add kernel files -## E.g: find src -type f \( -name "*.u" \) | sort -if (BUILD_ALLREDUCE_ONLY) - add_definitions(-DBUILD_ALLREDUCE_ONLY) - set(CU_SOURCES - # src/collectives/device/all_reduce.cu - src/collectives/device/sendrecv.cu - src/collectives/device/functions.cu - # src/collectives/device/msccl_kernel.cu - ) -else() - set(CU_SOURCES - src/collectives/device/all_gather.cu - # src/collectives/device/all_reduce.cu - src/collectives/device/alltoall_pivot.cu - src/collectives/device/broadcast.cu - src/collectives/device/functions.cu - # src/collectives/device/msccl_kernel.cu - src/collectives/device/onerank_reduce.cu - # src/collectives/device/reduce.cu - # src/collectives/device/reduce_scatter.cu - src/collectives/device/sendrecv.cu) -endif() -list(APPEND SRC_FILES ${CU_SOURCES}) - if (ENABLE_MSCCL_KERNEL) set(MSCCL_KERNEL_SOURCES src/collectives/device/msccl_kernel_impl.h @@ -509,11 +485,12 @@ foreach(SRC_FILE ${SRC_FILES}) ) endforeach() -expand_collectives("all_reduce" "AllReduce") -expand_collectives("reduce" "Reduce") -expand_collectives("reduce_scatter" "ReduceScatter") -if(ENABLE_MSCCL_KERNEL) - expand_collectives("msccl_kernel" "MscclKernel") +if(ONLY_FUNCS) + ## Generate only the specified 
functions + gen_functions(${ONLY_FUNCS}) +else() + # Generate all the functions + gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv") endif() # Create an initial git_version.cpp file (that will be updated with latest git version) @@ -637,7 +614,7 @@ if (HAS_BFD) target_link_libraries(rccl PRIVATE iberty z) endif() endif() -target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code +target_link_libraries(rccl PRIVATE -fgpu-rdc) # Required when linking relocatable device code ## Set RCCL link options target_link_options(rccl PRIVATE -parallel-jobs=16) # Use multiple threads to link diff --git a/README.md b/README.md index bd7e9f7964..d6feae34fe 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,6 @@ The root of this repository has a helper script 'install.sh' to build and instal Options: --address-sanitizer Build with address sanitizer enabled - --build_allreduce_only Build only AllReduce + sum + float kernel -d|--dependencies Install RCCL depdencencies --debug Build debug library --enable_backtrace Build with custom backtrace support @@ -34,7 +33,7 @@ The root of this repository has a helper script 'install.sh' to build and instal -f|--fast Quick-build RCCL (local gpu arch only, no backtrace, and collective trace support) -h|--help Prints this help message -i|--install Install RCCL library (see --prefix argument below) - -j|--jobs Specify how many parallel compilation jobs to run (16 by default) + -j|--jobs Specify how many parallel compilation jobs to run (nproc by default) -l|--local_gpu_only Only compile for local GPU architecture --no_clean Don't delete files if they already exist --npkit-enable Compile with npkit enabled diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3d50c63615..f822f4f55a 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -70,55 +70,6 @@ if(NOT GTest_FOUND AND BUILD_TESTS OR INSTALL_DEPENDENCIES) endif() endif() 
-set(DATATYPES_INT -"int8_t" -"uint8_t" -"int32_t" -"uint32_t" -"int64_t" -"uint64_t" - ) -set(DATATYPES_FLOAT - "half" - "float" - "double" - "rccl_bfloat16" - ) - -function(expand_collectives FILE FUNC) - set(REDOP Sum Prod Min Max PreMulSum SumPostDiv) - if (FUNC STREQUAL "MscclKernel") - set(REDOP_FILTERED Sum Prod Min Max PreMulSum SumPostDiv) - else() - set(REDOP_FILTERED ${REDOP}) - endif() - foreach(REDOP_CURRENT IN LISTS REDOP_FILTERED) - foreach(DATA_TYPE ${DATATYPES_INT} ${DATATYPES_FLOAT}) - if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST DATATYPES_FLOAT) - continue() # Skip the iteration for DATATYPES_FLOAT when REDOP_CURRENT is SumPostDiv - endif() - set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/${FILE}_${REDOP_CURRENT}_${DATA_TYPE}.cpp") - message(STATUS "Generating ${FILE_NAME}") - if (FUNC STREQUAL "MscclKernel") - file(WRITE ${FILE_NAME} - "#include \"${FILE}_impl.h\" - #include \"primitives.h\" - #include \"collectives.h\" - #include \"devcomm.h\" - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);") - else() - file(WRITE ${FILE_NAME} - "#include \"${FILE}.h\" - #include \"common.h\" - #include \"collectives.h\" - IMPL_COLL3(${FUNC}, ${REDOP_CURRENT}, ${DATA_TYPE});") - endif() - list(APPEND HIP_SOURCES ${FILE_NAME}) - endforeach() - endforeach() - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() - # Find or download/install rocm-cmake project set( PROJECT_EXTERN_DIR ${CMAKE_CURRENT_BINARY_DIR}/extern ) find_package(ROCM 0.7.3 QUIET CONFIG PATHS /opt/rocm) diff --git a/cmake/Generator.cmake b/cmake/Generator.cmake new file mode 100644 index 0000000000..7871d01337 --- /dev/null +++ b/cmake/Generator.cmake @@ -0,0 +1,383 @@ +# MIT License +# +# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +set(ALL_PARAMS "ALL_COLLS" "ALL_ALGOS" "ALL_PROTOS" "ALL_REDOPS" "ALL_TYPES") + +set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "ReduceScatter" "SendRecv") +set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN") +set(ALL_PROTOS "LL" "LL128" "SIMPLE") +set(ALL_REDOPS "Sum" "Prod" "Max" "Min" "PreMulSum" "SumPostDiv") +set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "rccl_bfloat16") + +set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16") + +################################################################################ +# The command line argument is used as a regex to filter the functions +# which make it into librccl. This is helpful for reducing the binary when +# developing device code. The regex supports non-space containing globs '*', +# and union 'a|b'. 
The string representing the function has the form: +# +# COLL ALGO PROTO REDOP TYPE +# +# The possible values for coll, algo, proto, redop, type can be found in the ALL_* +# lists at the top of this file. +# +# Example use-cases (NOTE(review): the visible build reads ONLY_FUNCS as a CMake variable, e.g. -DONLY_FUNCS="...", rather than as a make argument): +# +# # Only send/recv: +# make ONLY_FUNCS="SendRecv" +# +# # Only AllReduce and Reduce +# make ONLY_FUNCS="AllReduce|Reduce" +# +# # Only non-reductions: +# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" +# +# # Only AllReduce Sum int32_t (but all algos, protos) +# make ONLY_FUNCS="AllReduce * * Sum int32_t" +# +# # Only AllReduce RING Max float (but all protos) +# make ONLY_FUNCS="AllReduce RING * Max float" +# +# # AllReduce TREE LL128 Prod rccl_bfloat16 +# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" +# +# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" +# --- or --- +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" + +############################################################################################################# +## A recursive helper macro to generate functions and kernels based on the input given +############################################################################################################# +macro(filter_functions FUNCTION_PARAMS current_idx) + ## Check if the current_idx does not exceed the max depth + if(${current_idx} LESS 5) + ## current_element is the config parameter + list(GET FUNCTION_PARAMS ${current_idx} current_element) + + ## If the parameter is equal to '*', include all the possible cases for it + if(${current_element} STREQUAL "*") + if(${current_idx} EQUAL 0) + message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.") + endif() + ## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> + ## Find the respective parameter list from ALL_PARAMS list + list(GET ALL_PARAMS ${current_idx} current_list) + + 
## Iterate over the items in the current_list + foreach(item IN LISTS ${current_list}) + ## Add item to ITEM_LIST which will be used in the inner most loop + list(APPEND ITEM_LIST ${item}) + math(EXPR new_idx "${current_idx} + 1") + filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN}) + + ## For each loop layer remove the last element in ITEM_LIST + list(REMOVE_AT ITEM_LIST -1) + endforeach() + else() + ## Check if the current element is recognized + list(GET ALL_PARAMS ${current_idx} current_param) + list(FIND ${current_param} ${current_element} is_valid) + if(${is_valid} EQUAL -1) + message(FATAL_ERROR "Error: ${current_element} is unrecognized or does not belong to this category.") + endif() + + ## If not '*', no need to iterate. Add the current_element to ITEM_LIST + list(APPEND ITEM_LIST ${current_element}) + math(EXPR new_idx "${current_idx} + 1") + filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN}) + + list(REMOVE_AT ITEM_LIST -1) + endif() + else() + ## This is the innermost level (all 5 parameters chosen), where the function/kernel names are recorded + ## Unzip ITEM_LIST + list(GET ITEM_LIST 0 COLL) + list(GET ITEM_LIST 1 ALGO) + list(GET ITEM_LIST 2 PROTO) + list(GET ITEM_LIST 3 REDOP) + list(GET ITEM_LIST 4 TYPE) + + ## Need to check if these conditions are met prior to file generation + if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING") + continue() + elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t")) + continue() + elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE") + continue() + endif() + + if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST) + continue() + endif() + + list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}") + set(COLL_LIST ${COLL_LIST} PARENT_SCOPE) + + ## Append the newly formed function/kernel to list + list(APPEND 
FUNC_LIST "ncclFunction_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}") + list(APPEND KERN_LIST "ncclKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}") + set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE) + set(KERN_LIST ${KERN_LIST} PARENT_SCOPE) + endif() +endmacro() + +##################################################################################################### +## Function to generate device table +##################################################################################################### +function(gen_device_table) + set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/device_table.cpp") + message(STATUS "Generating ${DEVICE_TABLE_FILE}") + + ## Generate device table and list all the functions + file(WRITE ${DEVICE_TABLE_FILE} "#include \"common.h\"\n#include \"collectives.h\"\n\n") + + ## Declaration of device functions + foreach(func IN LISTS FUNC_LIST) + if(ENABLE_IFC) + file(APPEND ${DEVICE_TABLE_FILE} "__device__ void ${func}();\n") + else() + file(APPEND ${DEVICE_TABLE_FILE} "__device__ __attribute__((noinline)) void ${func}();\n") + endif() + endforeach() + file(APPEND ${DEVICE_TABLE_FILE} "\n") + + if(ENABLE_IFC) + ## Indirect function call (emits the ncclFuncs[] function-pointer table) + file(APPEND ${DEVICE_TABLE_FILE} "__device__ ncclKernelFunc_t const ncclFuncs[] = {\n") + foreach(func ${FUNC_LIST}) + file(APPEND ${DEVICE_TABLE_FILE} " ${func},\n") + endforeach() + ## Add OneRankReduce functions at the end + foreach(type IN LISTS ALL_TYPES) + file(APPEND ${DEVICE_TABLE_FILE} " ncclFunction_OneRankReduce_PreMulSum_${type},\n") + endforeach() + file(APPEND ${DEVICE_TABLE_FILE} "nullptr};\n\n") + else() + ## Direct function calls (emits a switch over funcIndex) + file(APPEND ${DEVICE_TABLE_FILE} "__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n switch(funcIndex) {\n") + set(index 0) + foreach(func IN LISTS FUNC_LIST) + file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ${func}();\n break;\n") + math(EXPR index "${index} + 1") + endforeach() + ## Add OneRankReduce functions at the end + 
foreach(type IN LISTS ALL_TYPES) + file(APPEND ${DEVICE_TABLE_FILE} " case ${index}:\n ncclFunction_OneRankReduce_PreMulSum_${type}();\n break;\n") + math(EXPR index "${index} + 1") + endforeach() + file(APPEND ${DEVICE_TABLE_FILE} " }\n}\n") + endif() + + ## Add the device_table file to HIP_SOURCES + list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE}) + set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) +endfunction() + +###################################################################################################### +## Function to generate host-side table +###################################################################################################### +function(gen_host_table) + set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/collectives/device/host_table.cpp") + message(STATUS "Generating ${HOST_TABLE_FILE}") + + file(WRITE ${HOST_TABLE_FILE} "#include \"devcomm.h\"\n\n") + + ## The mapping from function rows to valid function ids (-1 marks a row whose function is not in FUNC_LIST) + file(APPEND ${HOST_TABLE_FILE} "extern int const ncclFuncRowToId[] = {\n") + set(idx 0) + foreach(coll IN LISTS ALL_COLLS) + foreach(algo IN LISTS ALL_ALGOS) + foreach(proto IN LISTS ALL_PROTOS) + foreach(redop IN LISTS ALL_REDOPS) + foreach(type IN LISTS ALL_TYPES) + if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING") + continue() + elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t")) + continue() + elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE") + continue() + endif() + + if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST) + continue() + endif() + + list(FIND FUNC_LIST "ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}" fn_id) + if(NOT ${fn_id} EQUAL -1) + file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_${coll}_${algo}_${proto}_${redop}_${type}\n") + else() + file(APPEND 
${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n") + endif() + math(EXPR idx "${idx} + 1") + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + list(LENGTH FUNC_LIST fn_id) + ## Add OneRankReduce function ids at the end; gen_device_table emits them right after FUNC_LIST, so the first id is the length of FUNC_LIST (the last list(FIND) result is -1 whenever the final row was filtered out by ONLY_FUNCS) + foreach(type IN LISTS ALL_TYPES) + file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclFunction_OneRankReduce_PreMulSum_${type}\n") + + ## Increment the index and func id for each OneRankReduce + math(EXPR idx "${idx} + 1") + math(EXPR fn_id "${fn_id} + 1") + endforeach() + file(APPEND ${HOST_TABLE_FILE} "-1};\n\n") + + ## Add the host_table file to HIP_SOURCES + list(APPEND HIP_SOURCES ${HOST_TABLE_FILE}) + set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) +endfunction() + +########################################################################################################### +## Function to generate MSCCL Kernels +########################################################################################################### +function(gen_msccl_kernels) + foreach(REDOP_CURRENT IN LISTS ALL_REDOPS) + foreach(DATA_TYPE ${ALL_TYPES}) + if (REDOP_CURRENT STREQUAL "SumPostDiv" AND DATA_TYPE IN_LIST FLOATS_LIST) + continue() # Skip the iteration for FLOATS_LIST when REDOP_CURRENT is SumPostDiv + endif() + set(FILE_NAME "${HIPIFY_DIR}/src/collectives/device/msccl_kernel_${REDOP_CURRENT}_${DATA_TYPE}.cpp") + message(STATUS "Generating ${FILE_NAME}") + file(WRITE ${FILE_NAME} + "#include \"msccl_kernel_impl.h\" + #include \"primitives.h\" + #include \"collectives.h\" + #include \"devcomm.h\" + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false); + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, true);") + list(APPEND HIP_SOURCES ${FILE_NAME}) + endforeach() + endforeach() + set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) +endfunction() + +########################################################################################################### +## Function to generate 
collectives +########################################################################################################### +function(gen_collectives) + # Iterate over each item in the original list + foreach(item ${COLL_LIST}) + # Split the string into components + string(REPLACE "-" ";" item_components ${item}) + + # Extract COLL, ALGO, PROTO, and REDOP components + list(GET item_components 0 coll_prefix) + list(GET item_components 1 algo_prefix) + list(GET item_components 2 proto_prefix) + list(GET item_components 3 redop_prefix) + + # Create a list name using COLL, ALGO, PROTO, and REDOP (one generated file per such group) + set(list_name "${coll_prefix}_${algo_prefix}_${proto_prefix}_${redop_prefix}") + + # Add the item to the corresponding list + list(APPEND ${list_name} ${item}) + + # Add the list name to divided_lists if it is not there yet + if(NOT list_name IN_LIST divided_lists) + list(APPEND divided_lists ${list_name}) + endif() + endforeach() + + set(index 0) + foreach(list_name IN LISTS divided_lists) + foreach(item IN LISTS ${list_name}) + # Convert to a list + string(REPLACE "-" ";" components ${item}) + + list(GET components 0 coll) + list(GET components 1 algo) + list(GET components 2 proto) + list(GET components 3 redop) + list(GET components 4 type) + + list(APPEND IMPL_LIST "IMPL_COLL_FUNC(${coll}, ${algo}, ${proto}, ${redop}, ${type})\n") + + # Increment the function id (this counter is not read anywhere else in this function) + math(EXPR index "${index} + 1") + endforeach() + ## Store lower-case version of COLL (used to build the header name, e.g. AllReduce -> all_reduce) + string(TOLOWER ${coll} COLL_LOWER) + string(REPLACE "scatter" "_scatter" COLL_LOWER ${COLL_LOWER}) + if(NOT ${coll} STREQUAL "AllToAllPivot") + string(REPLACE "all" "all_" COLL_LOWER ${COLL_LOWER}) + else() + string(REPLACE "pivot" "_pivot" COLL_LOWER ${COLL_LOWER}) + endif() + + ## Set name/path of the file + set(FILE_PATH "${HIPIFY_DIR}/src/collectives/device/${list_name}.cpp") + message(STATUS "Generating ${FILE_PATH}") + + ## Construct the file + file(WRITE ${FILE_PATH} "#include \"${COLL_LOWER}.h\"\n#include \"common.h\"\n#include 
\"collectives.h\"\n") + foreach(IMPL IN LISTS IMPL_LIST) + file(APPEND ${FILE_PATH} "${IMPL}") + endforeach() + + ## Append the file to HIP sources list which will be added to source list + list(APPEND HIP_SOURCES ${FILE_PATH}) + set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) + + # Clear the IMPL list for the next iteration + set(IMPL_LIST) + endforeach() +endfunction() + +################################################################################################################### +## Function to generate all the functions that are going to be in librccl.so +################################################################################################################### +function(gen_functions CONFIG_INPUT) + string(REPLACE "|" ";" INPUT_LIST ${CONFIG_INPUT}) + ## Sort the input alphabetically so that it matches the order of ALL_COLLS + list(SORT INPUT_LIST) + + foreach(INPUT ${INPUT_LIST}) + # Parse the config string and make it a list + string(REPLACE " " ";" FUNCTION_PARAMS ${INPUT}) + + # Get the number of parameters in the input + list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH) + + # Assume all ('*') for every missing trailing parameter + while(${PARAMS_LENGTH} LESS 5) + list(APPEND FUNCTION_PARAMS "*") + list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH) + endwhile() + + ## Filter functions/kernels based on input + filter_functions(FUNCTION_PARAMS 0) + endforeach() + + gen_collectives() ## Generate collective files + if(ENABLE_MSCCL_KERNEL) + gen_msccl_kernels() ## Generate msccl files (not configurable) + endif() + gen_device_table() ## Generate device_table.cpp + gen_host_table() ## Generate host_table.cpp + + set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/install.sh b/install.sh index 2421714281..3cc42a3649 100755 --- a/install.sh +++ b/install.sh @@ -8,7 +8,6 @@ ROCM_PATH=${ROCM_PATH:="/opt/rocm"} # Default values build_address_sanitizer=false -build_allreduce_only=false build_bfd=false build_freorg_bkwdcomp=false build_local_gpu_only=false @@ -24,7 
+23,7 @@ enable_ninja="" install_dependencies=false install_library=false msccl_kernel_enabled=true -num_parallel_jobs=16 +num_parallel_jobs=$(nproc) npkit_enabled=false run_tests=false run_tests_all=false @@ -38,7 +37,6 @@ function display_help() echo "RCCL build & installation helper script" echo " Options:" echo " --address-sanitizer Build with address sanitizer enabled" - echo " --build_allreduce_only Build only AllReduce + sum + float kernel" echo " -d|--dependencies Install RCCL depdencencies" echo " --debug Build debug library" echo " --enable_backtrace Build with custom backtrace support" @@ -70,7 +68,7 @@ function display_help() # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ $? -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,build_allreduce_only,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") else echo "Need a new version of getopt" exit 1 @@ -86,7 +84,6 @@ eval set -- "${GETOPT_PARSE}" while true; do case "${1}" in --address-sanitizer) build_address_sanitizer=true; shift ;; - --build_allreduce_only) build_allreduce_only=true; shift ;; -d | --dependencies) install_dependencies=true; shift ;; --debug) build_release=false; shift ;; --enable_backtrace) build_bfd=true; shift ;; @@ -182,11 +179,6 @@ if [[ "${build_address_sanitizer}" == true ]]; then 
cmake_common_options="${cmake_common_options} -DBUILD_ADDRESS_SANITIZER=ON" fi -# AllReduce only -if [[ "${build_allreduce_only}" == true ]]; then - cmake_common_options="${cmake_common_options} -DBUILD_ALLREDUCE_ONLY=ON" -fi - # Backtrace support if [[ "${build_bfd}" == true ]]; then cmake_common_options="${cmake_common_options} -DBUILD_BFD=ON" @@ -353,9 +345,9 @@ else fi if ($build_tests) || (($run_tests) && [[ ! -f ./test/rccl-UnitTests ]]); then - CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../. + CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=ON -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../. else - CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH $enable_ninja ../../. + CXX=$ROCM_BIN_PATH/hipcc $cmake_executable $cmake_common_options -DBUILD_TESTS=OFF -DNPKIT_FLAGS="${npkit_options}" -DCMAKE_INSTALL_PREFIX=$ROCM_PATH -DROCM_PATH=$ROCM_PATH -DONLY_FUNCS="$ONLY_FUNCS" $enable_ninja ../../. fi check_exit_code "$?" diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu deleted file mode 100644 index 4022e2e9f5..0000000000 --- a/src/collectives/device/all_gather.cu +++ /dev/null @@ -1,11 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "all_gather.h" -#include "common.h" -#include "collectives.h" - -IMPL_COLL_C(AllGather); diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu deleted file mode 100644 index 99de10a714..0000000000 --- a/src/collectives/device/all_reduce.cu +++ /dev/null @@ -1,12 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ -/*This file is now generated in CMake*/ - -// #include "all_reduce.h" -// #include "common.h" -// #include "collectives.h" - -// IMPL_COLL_R(AllReduce); diff --git a/src/collectives/device/alltoall_pivot.cu b/src/collectives/device/alltoall_pivot.cu deleted file mode 100644 index 403d979acc..0000000000 --- a/src/collectives/device/alltoall_pivot.cu +++ /dev/null @@ -1,11 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "alltoall_pivot.h" -#include "common.h" -#include "collectives.h" - -IMPL_COLL_F(AllToAllPivot); diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu deleted file mode 100644 index 77595858bf..0000000000 --- a/src/collectives/device/broadcast.cu +++ /dev/null @@ -1,11 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "broadcast.h" -#include "common.h" -#include "collectives.h" - -IMPL_COLL_C(Broadcast); diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index afb9b85ba0..de1e6d84ec 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -32,243 +32,14 @@ { __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); } #endif -#if defined(ENABLE_LL128) && defined(__gfx90a__) -#define NCCL_FUNC5(func, algo, devredop, type, nullify) \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) -#else -#define NCCL_FUNC5(func, algo, devredop, type, nullify) \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) -#endif - -#define NCCL_FUNC4(func, devredop, type, nullify) \ - NCCL_FUNC5(func, TREE, devredop, type, nullify), \ - NCCL_FUNC5(func, RING, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ - NCCL_FUNC5(func, NVLS, devredop, type, nullify), \ - NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify) - -// Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, uint8_t, 0), \ - NCCL_FUNC4(func, devredop, int32_t, 0), \ - NCCL_FUNC4(func, devredop, uint32_t, 0), \ - NCCL_FUNC4(func, devredop, int64_t, 0), \ - NCCL_FUNC4(func, devredop, uint64_t, 0), \ - NCCL_FUNC4(func, devredop, half, nullForFloat), \ - NCCL_FUNC4(func, 
devredop, float, nullForFloat), \ - NCCL_FUNC4(func, devredop, double, nullForFloat), \ - NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat) -#define NCCL_FUNCS3B(func, devredop) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0) - -// Must be consistent with ncclRedOp_t -#define NCCL_FUNCS2A(func) \ - NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1) - -#define NCCL_FUNCS2B(func) \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum) - // Must be consistent with the ncclFuncSet enum using ncclKernelFunc_t = void (*)(); -static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{ -// Don't try to initialize the host shadow copy of this device-side global -// variable. There is no host pointer to a device-side function, which -// confuses clang. This will be fixed in the next clang release. 
-#if defined(__HIP_DEVICE_COMPILE__) -#if defined(BUILD_ALLREDUCE_ONLY) - NCCL_FUNC4(AllReduce, Sum, float, 0), -#else - NCCL_FUNCS2B(Broadcast), - NCCL_FUNCS2A(Reduce), - NCCL_FUNCS2B(AllGather), - NCCL_FUNCS2A(ReduceScatter), - NCCL_FUNCS2A(AllReduce), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), -#if defined(RCCL_BFLOAT16) - NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16), -#endif - NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), - NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t), -#endif -#endif -}; +// Defined in device_table.cpp +extern __device__ ncclKernelFunc_t const ncclFuncs[]; -static_assert(FUNC_INDEX_P2P == 5410, "Wrong P2P function index"); -static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 5411, "Wrong AllToAllPivot function index"); - -#if !defined(USE_INDIRECT_FUNCTION_CALL) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) -template -struct Caller { - static __forceinline__ __device__ __host__ - void call(unsigned short funcIndex) noexcept - { - constexpr unsigned short m = f + (l - f) / 2; - - return (funcIndex < m) ? 
Caller::call(funcIndex) : Caller::call(funcIndex); - } -}; - -template -struct Caller{ - static __forceinline__ __device__ __host__ - void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); } -}; - -template -__forceinline__ -__device__ -void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept { -#if defined(BUILD_ALLREDUCE_ONLY) - if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE)) - ncclFunction_AllReduce_RING_SIMPLE_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL)) - ncclFunction_AllReduce_RING_LL_Sum_float(); - else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_RING_LL128_Sum_float(); - else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_RING_LL_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE)) - ncclFunction_AllReduce_TREE_SIMPLE_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL)) - ncclFunction_AllReduce_TREE_LL_Sum_float(); - else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_TREE_LL128_Sum_float(); - else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_TREE_LL_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_SIMPLE)) - ncclFunction_AllReduce_COLLNET_DIRECT_SIMPLE_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL)) - 
ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float(); - else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_COLLNET_DIRECT_LL128_Sum_float(); - else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_COLLNET_DIRECT_LL_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_SIMPLE)) - ncclFunction_AllReduce_COLLNET_CHAIN_SIMPLE_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL)) - ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float(); - else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_COLLNET_CHAIN_LL128_Sum_float(); - else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET_CHAIN, NCCL_PROTO_LL128)) - ncclFunction_AllReduce_COLLNET_CHAIN_LL_Sum_float(); - else - assert("Unsupported function index"); -#else - if (funcIndex < 1080) { - if (funcIndex % 18 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); - else if (funcIndex % 18 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); - else if (funcIndex % 18 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 6) 
ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (funcIndex % 18 == 8) ncclFunction_Broadcast_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 9) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); - else ncclFunction_Broadcast_COLLNET_CHAIN_SIMPLE_Sum_int8_t(); - } - else if (funcIndex < 2160) Caller<1080, 2160, USING_LL128>::call(funcIndex); - else if (funcIndex < 3240) { - if (funcIndex % 18 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); - else if (funcIndex % 18 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t(); - else if (funcIndex % 18 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 6) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (funcIndex % 18 == 8) ncclFunction_AllGather_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); - else if (funcIndex % 18 == 9) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); - else if 
(USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 18 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); - else ncclFunction_AllGather_COLLNET_CHAIN_SIMPLE_Sum_int8_t(); - } - else if (funcIndex < 5400) Caller<3240, 5400, USING_LL128>::call(funcIndex); - else { - switch (funcIndex - 5400) { - case 0: - ncclFunction_OneRankReduce_PreMulSum_int8_t(); - break; - case 1: - ncclFunction_OneRankReduce_PreMulSum_uint8_t(); - break; - case 2: - ncclFunction_OneRankReduce_PreMulSum_int32_t(); - break; - case 3: - ncclFunction_OneRankReduce_PreMulSum_uint32_t(); - break; - case 4: - ncclFunction_OneRankReduce_PreMulSum_int64_t(); - break; - case 5: - ncclFunction_OneRankReduce_PreMulSum_uint64_t(); - break; - case 6: - ncclFunction_OneRankReduce_PreMulSum_half(); - break; - case 7: - ncclFunction_OneRankReduce_PreMulSum_float(); - break; - case 8: - ncclFunction_OneRankReduce_PreMulSum_double(); - break; - case 9: - ncclFunction_OneRankReduce_PreMulSum_rccl_bfloat16(); - break; - case 10: - ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(); - break; - case 11: - ncclFunction_AllToAllPivot_RING_SIMPLE_Sum_int8_t(); - default: - break; - } - } -#endif -} +#ifndef USE_INDIRECT_FUNCTION_CALL +__device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept; #endif template @@ -464,7 +235,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we } } -template +template __forceinline__ __device__ void ncclKernel( struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead ) { @@ -565,19 +336,12 @@ __forceinline__ __device__ void ncclKernel( __synclds(); if (tid == 0) __insert_timestamp(__LINE__); - if (ncclShmem.work.header.funcIndex == FnIndex) { - RunWork().run(&ncclShmem.work); - } else { -#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) - 
ncclFuncs[ncclShmem.work.header.funcIndex](); + +#ifdef USE_INDIRECT_FUNCTION_CALL + ncclFuncs[ncclShmem.work.header.funcIndex](); #else -#if defined(ENABLE_LL128) && defined(__gfx90a__) - NCCL_CALL_FUNCTIONS<1>(ncclShmem.work.header.funcIndex); -#else - NCCL_CALL_FUNCTIONS<0>(ncclShmem.work.header.funcIndex); + NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex); #endif -#endif - } int workIxNext = ncclShmem.work.header.workNext; __synclds(); @@ -606,28 +370,25 @@ __forceinline__ __device__ void ncclKernel( } #ifdef ENABLE_COLLTRACE -#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \ +#define IMPL_MAIN_KERN() \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \ +__global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ + ncclKernel(comm, channelMask, workHead); \ } \ \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, channelMask, workHead); \ +__global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ + ncclKernel(comm, channelMask, workHead); \ } #else -#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \ +#define IMPL_MAIN_KERN() \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \ +__global__ void rccl_main_kernel(struct ncclDevComm* comm, 
uint64_t channelMask, struct ncclWork* workHead) { \ + ncclKernel(comm, channelMask, workHead); \ } #endif -// Examples : AllReduce, RING, LL, Sum, uint8 -/* Functions for aggregation case */ - -#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) +#ifdef USE_INDIRECT_FUNCTION_CALL #define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \ __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \ RunWork, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \ @@ -639,67 +400,6 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev } #endif -// Only generate inline kernels for LL -#if defined(ENABLE_LL128) && defined(__gfx90a__) -#define IMPL_COLL4(func, algo, devredop, type) \ - IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ - IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \ - IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) -#else -#define IMPL_COLL4(func, algo, devredop, type) \ - IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ - IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) -#endif - -#define IMPL_COLL3(func, devredop, type) \ - IMPL_COLL4(func, TREE, devredop, type) \ - IMPL_COLL4(func, RING, devredop, type) \ - IMPL_COLL4(func, COLLNET_DIRECT, devredop, type) \ - IMPL_COLL4(func, COLLNET_CHAIN, devredop, type) \ - IMPL_COLL4(func, NVLS, devredop, type) \ - IMPL_COLL4(func, NVLS_TREE, devredop, type) - -#define IMPL_COLL2(func, devredop) \ - IMPL_COLL3(func, devredop, int8_t) \ - IMPL_COLL3(func, devredop, uint8_t) \ - IMPL_COLL3(func, devredop, int32_t) \ - IMPL_COLL3(func, devredop, uint32_t) \ - IMPL_COLL3(func, devredop, int64_t) \ - IMPL_COLL3(func, devredop, uint64_t) \ - IMPL_COLL3(func, devredop, half) \ - IMPL_COLL3(func, devredop, float) \ - IMPL_COLL3(func, devredop, double) \ - IMPL_COLL3(func, devredop, rccl_bfloat16) - -#define IMPL_COLL2A(func, devredop) \ - IMPL_COLL3(func, devredop, int8_t) \ - IMPL_COLL3(func, devredop, 
uint8_t) \ - IMPL_COLL3(func, devredop, int32_t) \ - IMPL_COLL3(func, devredop, uint32_t) \ - IMPL_COLL3(func, devredop, int64_t) \ - IMPL_COLL3(func, devredop, uint64_t) - -// Reduction define all functions -#define IMPL_COLL_R(func) \ - IMPL_COLL2(func, Sum) \ - IMPL_COLL2(func, Prod) \ - IMPL_COLL2(func, Min) \ - IMPL_COLL2(func, Max) \ - IMPL_COLL2(func, PreMulSum) \ - IMPL_COLL2A(func, SumPostDiv) - -// Copy primitives only define one function for copy -#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t); - -// Point-to-point primitives only have one function/kernel. -#define IMPL_COLL_P(func) \ - IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); \ - IMPL_COLL_KERN(func, RING, SIMPLE, Sum, int8_t, FUNC_INDEX_P2P); - -// AllToAll Pivot primitive only has one function. -#define IMPL_COLL_F(func) \ - IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); - #define NCCL_NVLS_ENABLED (__CUDA_ARCH__ >= 900 && NCCL_NVLS_SUPPORTS(NCCL_TYPE, NCCL_OP)) -#endif +#endif \ No newline at end of file diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu deleted file mode 100644 index 85695f059e..0000000000 --- a/src/collectives/device/functions.cu +++ /dev/null @@ -1,126 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "devcomm.h" -#include "collectives.h" -#include "common.h" - -__shared__ ncclShmemData ncclShmem; -#if __CUDA_ARCH__ < 700 - __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; -#endif - -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) -#else -#define NCCL_FUNC5(func, algo, devredop, type, nullify) \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) - -#define NCCL_FUNC4(func, devredop, type, nullify) \ - NCCL_FUNC5(func, TREE, devredop, type, nullify), \ - NCCL_FUNC5(func, RING, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ - NCCL_FUNC5(func, NVLS, devredop, type, nullify), \ - NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify) - -#if defined(__CUDA_BF16_TYPES_EXIST__) -// Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, uint8_t, 0), \ - NCCL_FUNC4(func, devredop, int32_t, 0), \ - NCCL_FUNC4(func, devredop, uint32_t, 0), \ - NCCL_FUNC4(func, devredop, int64_t, 0), \ - NCCL_FUNC4(func, devredop, uint64_t, 0), \ - NCCL_FUNC4(func, devredop, half, nullForFloat), \ - NCCL_FUNC4(func, devredop, float, nullForFloat), \ - NCCL_FUNC4(func, devredop, double, nullForFloat), \ - NCCL_FUNC4(func, devredop, __nv_bfloat16, nullForFloat) -#define NCCL_FUNCS3B(func, devredop) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - 
NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0) -#else -// Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, uint8_t, 0), \ - NCCL_FUNC4(func, devredop, int32_t, 0), \ - NCCL_FUNC4(func, devredop, uint32_t, 0), \ - NCCL_FUNC4(func, devredop, int64_t, 0), \ - NCCL_FUNC4(func, devredop, uint64_t, 0), \ - NCCL_FUNC4(func, devredop, half, nullForFloat), \ - NCCL_FUNC4(func, devredop, float, nullForFloat), \ - NCCL_FUNC4(func, devredop, double, nullForFloat) -#define NCCL_FUNCS3B(func, devredop) \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0), \ - NCCL_FUNC4(func, devredop, int8_t, 0) -#endif - -// Must be consistent with ncclRedOp_t -#define NCCL_FUNCS2A(func) \ - NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1) - -#define NCCL_FUNCS2B(func) \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum) - -// Must be consistent with the ncclFuncSet enum -__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { -// Don't try to 
initialize the host shadow copy of this device-side global -// variable. There is no host pointer to a device-side function, which -// confuses clang. This will be fixed in the next clang release. -#if __CUDA_ARCH__ - NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), - #if defined(__CUDA_BF16_TYPES_EXIST__) - NCCL_ONERANK_REDUCE_NAME(PreMulSum, __nv_bfloat16), - #endif - NCCL_FUNCS2B(Broadcast), - NCCL_FUNCS2A(Reduce), - NCCL_FUNCS2B(AllGather), - NCCL_FUNCS2A(ReduceScatter), - NCCL_FUNCS2A(AllReduce) -#endif -}; -#endif - -// Workaround for https://reviews.llvm.org/D55580 -__device__ void ncclWorkaroundClangD55580() {} diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu deleted file mode 100644 index 235aade88c..0000000000 --- a/src/collectives/device/reduce.cu +++ /dev/null @@ -1,13 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -/*This file is now generated in CMake*/ - -// #include "reduce.h" -// #include "common.h" -// #include "collectives.h" - -// IMPL_COLL_R(Reduce); diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu deleted file mode 100644 index fa8202b4a6..0000000000 --- a/src/collectives/device/reduce_scatter.cu +++ /dev/null @@ -1,13 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -/*This file is now generated in CMake*/ - -// #include "reduce_scatter.h" -// #include "common.h" -// #include "collectives.h" - -// IMPL_COLL_R(ReduceScatter); diff --git a/src/collectives/device/sendrecv.cu b/src/collectives/device/sendrecv.cu deleted file mode 100644 index 59e38b528e..0000000000 --- a/src/collectives/device/sendrecv.cu +++ /dev/null @@ -1,11 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -#include "sendrecv.h" -#include "common.h" -#include "collectives.h" - -IMPL_COLL_P(SendRecv); diff --git a/src/enqueue.cc b/src/enqueue.cc index a100226f8d..04fc6bb220 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -20,8 +20,6 @@ #include // std::memcpy #include // PRIx64 -static void* const ncclKernelGeneric = (void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t); - struct ncclKernelMatch { void* kernelFn; bool specialized; @@ -29,15 +27,18 @@ struct ncclKernelMatch { typedef void(*ncclKern_t)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); +// Definition of rccl_main_kernel which is only used in here +IMPL_MAIN_KERN(); + // Must be consistent with the ncclFuncSet enum #ifdef ENABLE_COLLTRACE static ncclKernelMatch const ncclKerns[2] = { - {(void *)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true}, - {(void *)NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), true}, + {(void *)rccl_main_kernel, true}, + {(void *)rccl_main_kernel_debug, true}, }; #else static ncclKernelMatch const ncclKerns[1] = { - {(void*)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true} + {(void*)rccl_main_kernel, true} }; #endif @@ -169,7 +170,7 @@ static void appendWorkElemP2p( struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId, struct ncclWorkElemP2p const *elem, bool fuseOk ) { - constexpr int funcIndex = FUNC_INDEX_P2P; + int funcIndex = ncclFuncId_P2p(); struct ncclKernelPlan::Channel* chan = &plan->channels[channelId]; struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue); if (q && funcIndex == q->work.header.funcIndex) { @@ -191,7 +192,7 @@ static void appendWorkElemP2p( } q = ncclMemoryStackAlloc(&comm->memScoped); q->work.header.type = ncclWorkTypeP2p; - q->work.header.funcIndex = FUNC_INDEX_P2P; + q->work.header.funcIndex = ncclFuncId_P2p(); 
chan->p2pTailElem[ncclWorkP2pTypeRecv-1] = 0; chan->p2pTailElem[ncclWorkP2pTypeSend-1] = 1; q->work.p2pElems[chan->p2pTailElem[elem->p2pType-1]] = *elem; // C++ struct assignment @@ -1313,12 +1314,12 @@ comp_next: if (info->comm->nRanks == 1) { // one-rank reduce index - *workFuncIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype); + *workFuncIndex = ncclFuncId_P2p() + int(info->datatype); return ncclSuccess; } else if (info->coll == ncclFuncAllToAllPivot) { - *workFuncIndex = FUNC_INDEX_ALLTOALL_PIVOT; + *workFuncIndex = ncclFuncId_AllToAllPivot(); } else { - *workFuncIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); + *workFuncIndex = ncclFuncId(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); } work->connIndex = 0; diff --git a/src/include/collectives.h b/src/include/collectives.h index cd0684181c..9263b4d50e 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -19,9 +19,8 @@ struct ncclDevRedOpFull { uint64_t scalarArg; }; -#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps) -#define FUNC_INDEX_ALLTOALL_PIVOT (FUNC_INDEX_P2P+1) -#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) +#define FUNC_INDEX_P2P 1015 +#define FUNC_INDEX_ALLTOALL_PIVOT 675 #define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \ ncclFunction_##func##_##algo##_##proto##_##devredop##_##type @@ -38,79 +37,13 @@ struct ncclDevRedOpFull { #define NCCL_IMPL_NAME(func, algo, proto) \ nccl##func##algo##proto -/* Declare all collective operations */ -#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) -#define DECL5(func, algo, proto, devredop, type) \ - extern __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \ - 
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); -#else -#define DECL5(func, algo, proto, devredop, type) \ - extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \ - extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); +// Declare rccl main/general kernel +extern __global__ void rccl_main_kernel(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); +#ifdef ENABLE_COLLTRACE +extern __global__ void rccl_main_kernel_debug(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); #endif -#define SINGLE_ARG(...) 
__VA_ARGS__ -#define CONCAT(a,b) a##b -#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(SINGLE_ARG(t), SINGLE_ARG(f)) -#define MACRO_IF_0(t, f) f -#define MACRO_IF_1(t, f) t - -#define DECL4(func, algo, devredop, type, undef) \ - MACRO_IF(undef, /*undefined*/, DECL5(func, algo, SIMPLE, devredop, type)) \ - MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL, devredop, type)) \ - MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type)) - -#define DECL3(func, devredop, type, undef) \ - DECL4(func, RING, devredop, type, undef) \ - DECL4(func, TREE, devredop, type, undef) \ - DECL4(func, COLLNET_DIRECT, devredop, type, undef) \ - DECL4(func, COLLNET_CHAIN, devredop, type, undef) \ - DECL4(func, NVLS, devredop, type, undef) \ - DECL4(func, NVLS_TREE, devredop, type, undef) - -#if defined(RCCL_BFLOAT16) -#define DECL2(func, devredop, undefForFloat) \ - DECL3(func, devredop, int8_t, /*undef=*/0) \ - DECL3(func, devredop, uint8_t, /*undef=*/0) \ - DECL3(func, devredop, int32_t, /*undef=*/0) \ - DECL3(func, devredop, uint32_t, /*undef=*/0) \ - DECL3(func, devredop, int64_t, /*undef=*/0) \ - DECL3(func, devredop, uint64_t, /*undef=*/0) \ - DECL3(func, devredop, half, /*undef=*/undefForFloat) \ - DECL3(func, devredop, float, /*undef=*/undefForFloat) \ - DECL3(func, devredop, double, /*undef=*/undefForFloat) \ - DECL3(func, devredop, rccl_bfloat16, /*undef=*/undefForFloat) -#else -#define DECL2(func, devredop, undefForFloat) \ - DECL3(func, devredop, int8_t, /*undef=*/0) \ - DECL3(func, devredop, uint8_t, /*undef=*/0) \ - DECL3(func, devredop, int32_t, /*undef=*/0) \ - DECL3(func, devredop, uint32_t, /*undef=*/0) \ - DECL3(func, devredop, int64_t, /*undef=*/0) \ - DECL3(func, devredop, uint64_t, /*undef=*/0) \ - DECL3(func, devredop, half, /*undef=*/undefForFloat) \ - DECL3(func, devredop, float, /*undef=*/undefForFloat) \ - DECL3(func, devredop, double, /*undef=*/undefForFloat) -#endif - -#define DECL(func) \ - DECL2(func, Sum, /*undefForFloat=*/0) \ - 
DECL2(func, Prod, /*undefForFloat=*/0) \ - DECL2(func, Min, /*undefForFloat=*/0) \ - DECL2(func, Max, /*undefForFloat=*/0) \ - DECL2(func, PreMulSum, /*undefForFloat=*/0) \ - DECL2(func, SumPostDiv, /*undefForFloat=*/1) - -DECL2(Broadcast, Sum, /*undefForFloat=*/0) -DECL(Reduce) -DECL2(AllGather, Sum, /*undefForFloat=*/0) -DECL(ReduceScatter) -DECL(AllReduce) -DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) -DECL5(AllToAllPivot, RING, SIMPLE, Sum, int8_t) - +// Declare OneRankReduce extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t)(); @@ -118,9 +51,7 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, half)(); -#if defined(RCCL_BFLOAT16) extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16)(); -#endif extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)(); extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(); diff --git a/src/include/devcomm.h b/src/include/devcomm.h index dabd33f638..55d38f99f3 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -11,6 +11,7 @@ #include "nccl.h" #include "rccl_bfloat16.h" #include "align.h" +#include "collectives.h" #if defined(ENABLE_NPKIT) #include "npkit/npkit_struct.h" #endif @@ -483,4 +484,62 @@ __host__ __device__ constexpr int ncclShmemDynamicSize(int cudaArch = NCCL_CUDA_ return cudaArch < 700 ? 
0 : ncclShmemScratchWarpSize(cudaArch)*(NCCL_MAX_NTHREADS/WARP_SIZE); } +// Map the rowIdx to funcIdx +extern int const ncclFuncRowToId[]; + +// `ncclFuncId()` needs to be in sync with 'ALL_COLLS' in Generator.cmake +inline int ncclFuncId(int coll, int devRedOp, int type, int algo, int proto) { + int row = 0; + + // RING / (proto) / Sum / int8_t + if (coll == ncclFuncAllGather) { + row += proto; + goto have_row; + } + row += NCCL_NUM_PROTOCOLS; + + // (algo) / (proto) / (redop) / (type) + if (coll == ncclFuncAllReduce) { + row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * (algo * NCCL_NUM_PROTOCOLS + proto); + goto have_row; + } + row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4); + + // RING / SIMPLE / Sum / int8_t + if (coll == ncclFuncAllToAllPivot) goto have_row; + row += 1; + + // RING / (proto) / Sum / int8_t + if (coll == ncclFuncBroadcast) { + row += proto; + goto have_row; + } + row += NCCL_NUM_PROTOCOLS; + + // RING / (proto) / (redop) / (type) + if (coll == ncclFuncReduce) { + row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto; + goto have_row; + } + row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4); + + // RING / (proto) / (redop) / (type) + if (coll == ncclFuncReduceScatter) { + row += ((proto * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - /*floats for each SumPostDiv*/ 4 * proto; + goto have_row; + } + row += NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - /*floats for each SumPostDiv*/ 4); + + // RING / SIMPLE / Sum / int8_t + if (coll == ncclFuncSendRecv) goto have_row; + row += 1; + +have_row: + return ncclFuncRowToId[row]; +} + +inline int ncclFuncId_P2p() { return ncclFuncRowToId[FUNC_INDEX_P2P]; } + +inline int ncclFuncId_AllToAllPivot() { return ncclFuncRowToId[FUNC_INDEX_ALLTOALL_PIVOT]; } + #endif diff --git a/src/include/enqueue.h
b/src/include/enqueue.h index 634f037cb3..94e331e194 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -10,6 +10,7 @@ #include "comm.h" #include "group.h" #include "collectives.h" +#include "common.h" #include "utils.h" #define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64) diff --git a/src/init.cc b/src/init.cc index 24a99a2d4f..d9b4c1945d 100644 --- a/src/init.cc +++ b/src/init.cc @@ -18,6 +18,7 @@ #include "enqueue.h" #include "graph.h" #include "argcheck.h" +#include "devcomm.h" #if defined(ENABLE_NPKIT) #include "npkit/npkit.h" #endif @@ -31,6 +32,7 @@ #include #include #include +#include #include "graph/topo.h" #include "graph/xml.h" #include "archinfo.h" @@ -54,7 +56,7 @@ #define NCCL_GROUP_CUDA_STREAM 1 // CGMD: CUDA 9.0,9.1 Need to use an internal CUDA stream #endif -const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" }; +const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "AllGather", "AllReduce", "AllToAllPivot", "Broadcast", "Reduce", "ReduceScatter", "SendRecv"}; const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS", "NVLSTree" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" }; @@ -177,6 +179,18 @@ void NCCL_NO_OPTIMIZE commPoison(ncclComm_t comm) { RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0); #ifdef ENABLE_COLLTRACE +#define MAX_NAME_LENGTH 64 +// Helper: formats one name into slot funcIdx of func_names (MAX_NAME_LENGTH bytes per slot, truncated by vsnprintf) and increments funcIdx; funcIdx is not bounds-checked +void generateFunctionName(char* func_names, int& funcIdx, const char* format, ...)
{ + char* line = func_names + MAX_NAME_LENGTH * funcIdx; + va_list args; + va_start(args, format); + vsnprintf(line, MAX_NAME_LENGTH, format, args); + va_end(args); + funcIdx++; +} + +// Should be in sync with 'ALL_COLLS' in Generator.cmake void *ncclCommThreadMain(void *arg) { ncclComm_t comm = (ncclComm_t)arg; int head[MAXCHANNELS]; @@ -184,29 +198,53 @@ void *ncclCommThreadMain(void *arg) { memset(head, 0, sizeof(int)*MAXCHANNELS); vega_gpu_rtc_freq = GetDeviceWallClockRateInKhz(comm->cudaDev) * 1.0E3; - #define MAX_NAME_LENGTH 64 - char* func_names = (char *)malloc(MAX_NAME_LENGTH*(FUNC_INDEX_P2P+2)); - for (int func = 0; func < NCCL_NUM_FUNCTIONS; func++) { - for (int al = 0; al < NCCL_NUM_ALGORITHMS; al++) { - for (int type = 0; type < ncclNumTypes; type++) { - for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { - for (int devredop = 0; devredop < ncclNumDevRedOps; devredop++) { - char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, devredop, type, al, pr); - sprintf(line, "%s%s%s%s%s", ncclFuncStr[func], ncclAlgoStr[al], ncclProtoStr[pr], - ncclDevRedOpStr[devredop], ncclTypeStr[type]); - } + char* func_names = (char *)malloc(MAX_NAME_LENGTH*(ncclFuncId_P2p()+/*OneRankReduce*/11)); + int funcIdx = 0; + // AllGather --> RING / each proto / Sum / int8_t + for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { + generateFunctionName(func_names, funcIdx, "AllGatherRing%sSum_i8", ncclProtoStr[pr]); + } + // AllReduce --> each algo / each proto / each redop / each type (redop 5 = SumPostDiv is skipped for type index > 5) + for (int al = 0; al < NCCL_NUM_ALGORITHMS - 2; al++) { + for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { + for (int redop = 0; redop < ncclNumDevRedOps; redop++) { + for (int ty = 0; ty < ncclNumTypes; ty++) { + if (redop == 5 && ty > 5) continue; + generateFunctionName(func_names, funcIdx, "AllReduce%s%s%s%s", ncclAlgoStr[al], ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]); } } } } - for (int type = 0; type < ncclNumTypes; type++) { - char* line = func_names+MAX_NAME_LENGTH*(FUNC_INDEX_P2P-ncclNumTypes+type); - sprintf(line,
"OneRankReducePreMulSum%s", ncclTypeStr[type]); + // AllToAllPivot --> RING / SIMPLE / Sum / int8_t + generateFunctionName(func_names, funcIdx, "AllToAllPivotRingSimpleSum_i8"); + // Broadcast --> RING / each proto / Sum / int8_t + for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { + generateFunctionName(func_names, funcIdx, "BroadcastRing%sSum_i8", ncclProtoStr[pr]); + } + // Reduce --> RING / each proto / each redop / each type (redop 5 = SumPostDiv is skipped for type index > 5) + for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { + for (int redop = 0; redop < ncclNumDevRedOps; redop++) { + for (int ty = 0; ty < ncclNumTypes; ty++) { + if (redop == 5 && ty > 5) continue; + generateFunctionName(func_names, funcIdx, "ReduceRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]); + } + } + } + // ReduceScatter --> RING / each proto / each redop / each type (redop 5 = SumPostDiv is skipped for type index > 5) + for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { + for (int redop = 0; redop < ncclNumDevRedOps; redop++) { + for (int ty = 0; ty < ncclNumTypes; ty++) { + if (redop == 5 && ty > 5) continue; + generateFunctionName(func_names, funcIdx, "ReduceScatterRing%s%s%s", ncclProtoStr[pr], ncclDevRedOpStr[redop], ncclTypeStr[ty]); + } + } + } + // SendRecv --> RING / SIMPLE / Sum / int8_t + generateFunctionName(func_names, funcIdx, "SendRecvRingSimpleSum_i8"); + // OneRankReduce --> PreMulSum / each type + for (int ty = 0; ty < ncclNumTypes; ty++) { + generateFunctionName(func_names, funcIdx, "OneRankReducePreMulSum%s", ncclTypeStr[ty]); } - char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX_P2P; - sprintf(line, "SendRecvRingSimpleSum_i8"); - line += MAX_NAME_LENGTH; - sprintf(line, "AllToAllPivotRingSimpleSum_i8"); do { for (int channel = 0; channel < MAXCHANNELS; channel++) { int tail = comm->collTraceTail[channel].tail%COLLTRACE_NUM_ITEMS; @@ -232,7 +270,7 @@ void *ncclCommThreadMain(void *arg) { (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, fIdx, td->data_0, td->opCount, td->data_1); } else { - if (fIdx == FUNC_INDEX_P2P || type == ncclCollTraceP2pElemType) + if (fIdx == ncclFuncId_P2p() || type ==
ncclCollTraceP2pElemType) sprintf(line, "## [%012.6f] [%02d:%02d] %06x-%06x", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->p2pOpCount[0], td->p2pOpCount[1]); else sprintf(line, "## [%012.6f] [%02d:%02d] %06lx", (double)(td->timeStamp)/vega_gpu_rtc_freq, comm->rank, td->bid, td->opCount); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3ad45217b8..5108e4dce4 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -37,10 +37,18 @@ if(BUILD_TESTS) ) # Collect source files for tests - if(BUILD_ALLREDUCE_ONLY) - set(TEST_SOURCE_FILES - AllReduce_Tests.cpp - ) + if(ONLY_FUNCS) + # Convert input string to a list + string(REPLACE "|" ";" CONFIG_LIST ${ONLY_FUNCS}) + + # For each config in config list + foreach(item ${CONFIG_LIST}) + string(REPLACE " " ";" CONFIG_PARAMS ${item}) + list(GET CONFIG_PARAMS 0 COLL) + + set(TEST_FILE "${COLL}Tests.cpp") + list(APPEND TEST_SOURCE_FILES ${TEST_FILE}) + endforeach() else() set(TEST_SOURCE_FILES AllGatherTests.cpp diff --git a/test/common/EnvVars.cpp b/test/common/EnvVars.cpp index b6cc407e5d..07a9561d7b 100644 --- a/test/common/EnvVars.cpp +++ b/test/common/EnvVars.cpp @@ -65,12 +65,9 @@ namespace RcclUnitTesting useInteractive = GetEnvVar("UT_INTERACTIVE", 0); timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000); - // Limit number of supported reduction operators to just ncclSum if only allReduce is built -#ifdef BUILD_ALLREDUCE_ONLY - int numOps = 1; -#else + // Total number of reduction ops int numOps = ncclNumOps; -#endif + std::vector redOpStrings = GetEnvVarsList("UT_REDOPS"); for (auto s : redOpStrings) { @@ -98,12 +95,7 @@ namespace RcclUnitTesting { if (!strcmp(s.c_str(), ncclDataTypeNames[i])) { -#ifdef BUILD_ALLREDUCE_ONLY - if (i == ncclFloat32) -#endif - { - dataTypes.push_back((ncclDataType_t)i); - } + dataTypes.push_back((ncclDataType_t)i); } } } @@ -112,8 +104,6 @@ namespace RcclUnitTesting if (dataTypes.empty()) { dataTypes.push_back(ncclFloat32); - // Skip all but 32-bit 
floats if only AllReduce is being built -#ifndef BUILD_ALLREDUCE_ONLY dataTypes.push_back(ncclInt8); dataTypes.push_back(ncclUint8); dataTypes.push_back(ncclInt32); @@ -124,7 +114,6 @@ namespace RcclUnitTesting dataTypes.push_back(ncclFloat32); dataTypes.push_back(ncclFloat64); dataTypes.push_back(ncclBfloat16); -#endif } // Build list of possible # GPU ranks based on env vars