diff --git a/CMakeLists.txt b/CMakeLists.txt index 4496c361d0..03c5f84115 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,6 @@ set(DEFAULT_GPUS include(CheckIncludeFiles) include(CheckSymbolExists) include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets -include(cmake/Generator.cmake) # Configure functions that goes into RCCL list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") @@ -207,6 +206,16 @@ if(ENABLE_IFC) set(IFC_ENABLED OFF) message(WARNING "Indirect function call disabled - requires HIP version >= 5.5.30201") endif() +else() + set(IFC_ENABLED OFF) +endif() + +## Check for LL128 support +if(${hip_version_string} VERSION_GREATER_EQUAL "6.1.33591") + set(LL128_ENABLED ON) + message(STATUS "RCCL LL128 protocol enabled") +else() + message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591") endif() ## Check for hsa-runtime64 @@ -574,7 +583,7 @@ foreach(SRC_FILE ${SRC_FILES}) OUTPUT ${HIP_FILE} COMMAND mkdir -p ${HIP_FILE_DIR} && ${hipify-perl_executable} -quiet-warnings ${CMAKE_SOURCE_DIR}/${SRC_FILE} -o ${HIP_FILE} - && ${CMAKE_COMMAND} -DHIP_FILE=${HIP_FILE} -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/AddUnroll.cmake + && ${CMAKE_COMMAND} -E env bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/add_unroll.sh ${HIP_FILE} MAIN_DEPENDENCY ${SRC_FILE} COMMENT "Hipifying ${SRC_FILE} -> ${HIP_FILE}" ) @@ -582,13 +591,23 @@ endforeach() # Generate device/host tables and all the collective functions that are going to be in librccl.so #================================================================================================== -if(ONLY_FUNCS) - ## Generate only the specified functions - gen_functions(${ONLY_FUNCS}) -else() - # Generate all the functions - gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv") -endif() +find_package(PythonInterp REQUIRED) +set(GEN_DIR "${HIPIFY_DIR}/gensrc") + +# Execute the python script to generate required files +execute_process( + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS} + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + RESULT_VARIABLE result +) + +# Find the generated files in the output directory +file(GLOB GENERATED_FILES "${GEN_DIR}/*") + +# Append all found generated files to the list +foreach(file ${GENERATED_FILES}) + list(APPEND HIP_SOURCES ${file}) +endforeach() # Create an initial git_version.cpp file (that will be updated with latest git version) #================================================================================================== @@ -616,6 +635,7 @@ target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src) # target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device) target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device/network/unpack) +target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/gensrc) target_include_directories(rccl PRIVATE ${HSA_INCLUDE_PATH}) target_include_directories(rccl PRIVATE ${ROCM_SMI_INCLUDE_DIR}) if(DEMANGLE_DIR) @@ -692,12 +712,8 @@ if(DEMANGLE_DIR) target_compile_definitions(rccl PRIVATE "HAVE_CPLUS_DEMANGLE=1") target_compile_definitions(rccl PRIVATE "HAVE_DECL_BASENAME=1") endif() -if(${hip_version_string} VERSION_GREATER_EQUAL "6.1.33591") - set(LL128_ENABLED ON) +if(LL128_ENABLED) target_compile_definitions(rccl PRIVATE ENABLE_LL128) - message(STATUS "RCCL LL128 protocol enabled") -else() - message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591") endif() ## Set RCCL compile options diff --git a/cmake/Generator.cmake b/cmake/Generator.cmake deleted file mode 100644 index daac5835ea..0000000000 --- a/cmake/Generator.cmake +++ /dev/null @@ -1,479 +0,0 @@ -# MIT License -# -# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -set(ALL_PARAMS "ALL_COLLS" "ALL_ALGOS" "ALL_PROTOS" "ALL_REDOPS" "ALL_TYPES") -set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "ReduceScatter" "SendRecv") -set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN") -set(ALL_PROTOS "LL" "LL128" "SIMPLE") -set(ALL_REDOPS "Sum" "Prod" "MinMax" "PreMulSum" "SumPostDiv") -set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8") - -set(FLOATS_LIST "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8") - -################################################################################ -# The command line argument is used as a regex to filter the functions -# which make it into librccl. This is helpful for reducing the binary when -# developing device code. The regex supports non-space containing globs '*', -# and union 'a|b'. The string representing the function has the form: -# -# -# -# The possible values for redop, type, algo, proto can be found in the all_ -# lists at the top of this file. -# -# Example use-cases: -# -# # Only send/recv: -# make ONLY_FUNCS="SendRecv" -# -# # Only AllReduce and Reduce -# make ONLY_FUNCS="AllReduce|Reduce" -# -# # Only non-reductions: -# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" -# -# # Only AllReduce Sum int32_t (but all algos, protos) -# make ONLY_FUNCS="AllReduce * * Sum int32_t" -# -# # Only AllReduce RING Max float (but all protos) -# make ONLY_FUNCS="AllReduce RING * Max float" -# -# # AllReduce TREE LL128 Prod rccl_bfloat16 -# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" -# -# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) -# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" -# --- or --- -# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" -# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t" - -set(AllGather_Params "RING" "*" "Sum" "int8_t") -set(AllReduce_Params "*" "*" "*" "*") -set(AllToAllPivot_Params "RING" "SIMPLE" "Sum" "int8_t") -set(Broadcast_Params "RING" "*" "Sum" "int8_t") -set(Reduce_Params "RING" "*" "*" "*") -set(ReduceScatter_Params "RING" "*" "*" "*") -set(SendRecv_Params "RING" "SIMPLE" "Sum" "int8_t") - -############################################################################################################# -## Helper function to check if the conditions for the collective is being met -############################################################################################################# -function(validate_func ITEM_LIST) - set(paramIdx 1) - ## Extract coll/redop/type - list(GET ITEM_LIST 0 coll) - list(GET ITEM_LIST 3 redop) - list(GET ITEM_LIST 4 type) - - ## First check if redop 'SumPostDiv' has no type float - if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST) - set(is_valid FALSE PARENT_SCOPE) - return() - endif() - foreach(parameter IN LISTS "${coll}_Params") - if(NOT parameter STREQUAL "*") - list(GET ITEM_LIST "${paramIdx}" item) - string(FIND "${parameter}" "${item}" is_found) - if(is_found EQUAL -1) - set(is_valid FALSE PARENT_SCOPE) - return() - endif() - endif() - math(EXPR paramIdx "${paramIdx} + 1") - endforeach() - set(is_valid TRUE PARENT_SCOPE) -endfunction() - -############################################################################################################# -## A recursive helper macro to generate functions and kernels based on the input given -############################################################################################################# -macro(filter_functions FUNCTION_PARAMS current_idx) - ## Check if the current_idx does not exceed the max depth - if(${current_idx} LESS 5) - ## current_element is the config parameter - list(GET FUNCTION_PARAMS ${current_idx} current_element) - - ## If the parameter is equal to '*', include all the possible cases for it - if(${current_element} STREQUAL "*") - if(${current_idx} EQUAL 0) - message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.") - endif() - ## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> - ## Find the respective parameter list from ALL_PARAMS list - list(GET ALL_PARAMS ${current_idx} current_list) - - ## Iterate over the items int the current_list - foreach(item IN LISTS ${current_list}) - ## Add item to ITEM_LIST which will be used in the inner most loop - list(APPEND ITEM_LIST ${item}) - math(EXPR new_idx "${current_idx} + 1") - filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN}) - - ## For each loop layer remove the last element in ITEM_LIST - list(REMOVE_AT ITEM_LIST -1) - endforeach() - else() - ## Check if the current element is recognized - list(GET ALL_PARAMS ${current_idx} current_param) - string(REPLACE "/" ";" elements ${current_element}) - ## Iterate over the elements int the ELEMENTS_LIST - foreach(item IN LISTS elements) - list(FIND ${current_param} ${item} is_valid) - if(${is_valid} EQUAL -1) - message(FATAL_ERROR "Error: ${item} is unrecognized or does not belong to this category ${current_param}.") - endif() - endforeach() - foreach(item IN LISTS elements) - ## Add item to ITEM_LIST which will be used in the inner most loop - list(APPEND ITEM_LIST ${item}) - math(EXPR new_idx "${current_idx} + 1") - filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN}) - - ## For each loop layer remove the last element in ITEM_LIST - list(REMOVE_AT ITEM_LIST -1) - endforeach() - endif() - else() - ## This is the inner most loop where the file is generated - ## Unzip ITEM_LIST - list(GET ITEM_LIST 0 COLL) - list(GET ITEM_LIST 1 ALGO) - list(GET ITEM_LIST 2 PROTO) - list(GET ITEM_LIST 3 REDOP) - list(GET ITEM_LIST 4 TYPE) - - validate_func("${ITEM_LIST}") - if (NOT is_valid) - continue() - endif() - - list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}") - set(COLL_LIST ${COLL_LIST} PARENT_SCOPE) - - ## Append the newly formed function/kernel to list - list(APPEND FUNC_LIST "ncclDevFunc_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}") - list(APPEND KERN_LIST "ncclDevKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}") - set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE) - set(KERN_LIST ${KERN_LIST} PARENT_SCOPE) - endif() -endmacro() - -##################################################################################################### -## Function to generate device table -##################################################################################################### -function(gen_device_table) - ## Generate device table and list all the functions - set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h") - message(STATUS "Generating ${DEVICE_TABLE_H_FILE}") - - if(ENABLE_IFC) - set(func_declaration "__device__ void") - else() - set(func_declaration "__device__ __attribute__((noinline)) void") - endif() - - ## Declaration of device functions - foreach(func IN LISTS FUNC_LIST) - string(FIND "${func}" "LL128" IS_LL128) - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n${func_declaration} ${func}_4();\n#else\n") - string(REPLACE "LL128" "LL" func "${func}") - file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n${func_declaration} ${func}_4();\n#endif\n") - else() - file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n${func_declaration} ${func}_4();\n") - endif() - endforeach() - file(APPEND ${DEVICE_TABLE_H_FILE} "\n") - - ## Undirect function call - file(APPEND ${DEVICE_TABLE_H_FILE} "typedef void(*ncclDevFuncPtr_t)();\n\n") - file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n") - foreach(func ${FUNC_LIST}) - string(FIND "${func}" "LL128" IS_LL128) - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#else\n") - string(REPLACE "LL128" "LL" func "${func}") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#endif\n") - else() - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n") - endif() - endforeach() - file(APPEND ${DEVICE_TABLE_H_FILE} "nullptr};\n\n") - file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") - foreach(func ${FUNC_LIST}) - string(FIND "${func}" "LL128" IS_LL128) - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func}_4,\n#else\n") - string(REPLACE "LL128" "LL" func "${func}") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func}_4,\n#endif\n") - else() - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func}_4,\n") - endif() - endforeach() - file(APPEND ${DEVICE_TABLE_H_FILE} "nullptr};\n\n") - - if(NOT ENABLE_IFC) - ## Direct functions calls - file(APPEND ${DEVICE_TABLE_H_FILE} - "template\n" - "struct Caller {\n" - " static __forceinline__ __device__ __host__\n" - " void call(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller{\n" - " static __forceinline__ __device__ __host__\n" - " void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n" - "};\n" - ) - file(APPEND ${DEVICE_TABLE_H_FILE} "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n") - list(LENGTH FUNC_LIST max_index) - file(APPEND ${DEVICE_TABLE_H_FILE} " Caller<0, ${max_index}>::call(funcIndex);\n}\n\n") - - file(APPEND ${DEVICE_TABLE_H_FILE} - "template\n" - "struct Caller4 {\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept\n" - " {\n" - " constexpr unsigned short m = f + (l - f) / 2;\n" - " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" - " }\n" - "};\n" - "\n" - "template\n" - "struct Caller4{\n" - " static __forceinline__ __device__ __host__\n" - " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" - "};\n" - ) - file(APPEND ${DEVICE_TABLE_H_FILE} "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") - list(LENGTH FUNC_LIST max_index) - file(APPEND ${DEVICE_TABLE_H_FILE} " Caller4<0, ${max_index}>::call4(funcIndex);\n}\n\n") - endif() - - ## Function name table for collective trace - if(COLLTRACE) - set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/device/device_table.cpp") - message(STATUS "Generating ${DEVICE_TABLE_FILE}") - - file(APPEND ${DEVICE_TABLE_FILE} "#include \"nccl_common.h\"\n#include \"device.h\"\n\n const char* funcNames[FUNC_INDEX_TOTAL] = {\n") - foreach(func ${FUNC_LIST}) - file(APPEND ${DEVICE_TABLE_FILE} " \"${func}\",\n") - endforeach() - foreach(type IN LISTS ALL_TYPES) - file(APPEND ${DEVICE_TABLE_FILE} " \"ncclDevFunc_OneRankReduce_PreMulSum_${type}\",\n") - endforeach() - file(APPEND ${DEVICE_TABLE_FILE} "};\n") - endif() - - ## Add the device_table files to HIP_SOURCES - list(APPEND HIP_SOURCES ${DEVICE_TABLE_H_FILE}) - list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE}) - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() - -###################################################################################################### -## Function to generate host-side table -###################################################################################################### -function(gen_host_table) - set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/device/host_table.cpp") - message(STATUS "Generating ${HOST_TABLE_FILE}") - - file(WRITE ${HOST_TABLE_FILE} "#include \"device.h\"\n\n") - - ## The mapping from function rows to valid function ids - file(APPEND ${HOST_TABLE_FILE} "extern int const ncclDevFuncRowToId[] = {\n") - set(idx 0) - foreach(coll IN LISTS ALL_COLLS) - foreach(algo IN LISTS ALL_ALGOS) - foreach(proto IN LISTS ALL_PROTOS) - foreach(redop IN LISTS ALL_REDOPS) - foreach(type IN LISTS ALL_TYPES) - ## Create a list from the combination of curr parameters - set(ITEM_LIST ${coll} ${algo} ${proto} ${redop} ${type}) - validate_func("${ITEM_LIST}") - if (NOT is_valid) - continue() - endif() - ## Try to find the combination in the generated func list - list(FIND FUNC_LIST "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}" fn_id) - if(NOT ${fn_id} EQUAL -1) - set(last_valid_fn_id ${fn_id}) - file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}\n") - else() - file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n") - endif() - math(EXPR idx "${idx} + 1") - endforeach() - endforeach() - endforeach() - endforeach() - endforeach() - math(EXPR last_valid_fn_id "${last_valid_fn_id} + 1") - file(APPEND ${HOST_TABLE_FILE} "${last_valid_fn_id}};\n\n") - - ## Add the host_table file to HIP_SOURCES - list(APPEND HIP_SOURCES ${HOST_TABLE_FILE}) - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() - -########################################################################################################### -## Function to generate MSCCL Kernels -########################################################################################################### -function(gen_msccl_kernels) - set(MSCCL_REDOP Sum Prod MinMax) - foreach(REDOP_CURRENT IN LISTS MSCCL_REDOP) - foreach(DATA_TYPE ${ALL_TYPES}) - set(FILE_NAME "${HIPIFY_DIR}/src/device/msccl_kernel_${REDOP_CURRENT}_${DATA_TYPE}.cpp") - message(STATUS "Generating ${FILE_NAME}") - file(WRITE ${FILE_NAME} - "#include \"msccl_kernel_impl.h\" - #include \"nccl_common.h\" - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${REDOP_CURRENT}, ${DATA_TYPE}, false);") - list(APPEND HIP_SOURCES ${FILE_NAME}) - endforeach() - endforeach() - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() - -########################################################################################################### -## Function to generate collectives -########################################################################################################### -function(gen_collectives) - # Iterate over each item in the original list - foreach(item ${COLL_LIST}) - # Split the string into components - string(REPLACE "-" ";" item_components ${item}) - - # Extract COLL, ALGO, and PROTO components - list(GET item_components 0 coll_prefix) - list(GET item_components 1 algo_prefix) - list(GET item_components 2 proto_prefix) - list(GET item_components 3 redop_prefix) - - # Create a list name using COLL, ALGO, and PROTO - set(list_name "${coll_prefix}_${algo_prefix}_${proto_prefix}_${redop_prefix}") - - # Add the item to the corresponding list - list(APPEND ${list_name} ${item}) - - # Add the list name to the map if it doesn't exist - if(NOT list_name IN_LIST divided_lists) - list(APPEND divided_lists ${list_name}) - endif() - endforeach() - - set(index 0) - foreach(list_name IN LISTS divided_lists) - foreach(item IN LISTS ${list_name}) - # Convert to a list - string(REPLACE "-" ";" components ${item}) - - list(GET components 0 coll) - list(GET components 1 algo) - list(GET components 2 proto) - list(GET components 3 redop) - list(GET components 4 type) - - list(APPEND IMPL_LIST "DEFINE_ncclDevFunc(${coll}_${algo}_${proto}_${redop}_${type}, ncclFunc${coll}, Func${redop}, ${type}, NCCL_ALGO_${algo}, NCCL_PROTO_${proto})\n") - - # Increment the function id - math(EXPR index "${index} + 1") - endforeach() - ## Store lower-case version of COLL - string(TOLOWER ${coll} COLL_LOWER) - string(REPLACE "scatter" "_scatter" COLL_LOWER ${COLL_LOWER}) - if(NOT ${coll} STREQUAL "AllToAllPivot") - string(REPLACE "all" "all_" COLL_LOWER ${COLL_LOWER}) - else() - string(REPLACE "pivot" "_pivot" COLL_LOWER ${COLL_LOWER}) - endif() - - ## Set name/path of the file - set(FILE_PATH "${HIPIFY_DIR}/src/device/${list_name}.cpp") - message(STATUS "Generating ${FILE_PATH}") - - ## Construct the file - file(WRITE ${FILE_PATH} "#include \"${COLL_LOWER}.h\"\n#include \"common.h\"\n\n") - string(FIND "${list_name}" "LL128" IS_LL128) - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${FILE_PATH} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - endif() - foreach(IMPL IN LISTS IMPL_LIST) - file(APPEND ${FILE_PATH} "${IMPL}") - endforeach() - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${FILE_PATH} "#endif\n") - endif() - - ## Append the file to HIP sources list which will be added to source list - list(APPEND HIP_SOURCES ${FILE_PATH}) - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) - - # Clear the IMPL list for the next iteration - set(IMPL_LIST) - endforeach() -endfunction() - -################################################################################################################### -## Function to generate all the functions that are going to be in librccl.so -################################################################################################################### -function(gen_functions CONFIG_INPUT) - string(REPLACE "|" ";" INPUT_LIST ${CONFIG_INPUT}) - ## Sort the input so that it matches ALL_COLLS - list(SORT INPUT_LIST) - - foreach(INPUT ${INPUT_LIST}) - # Parse the the config string and make it a list - string(REPLACE " " ";" FUNCTION_PARAMS ${INPUT}) - - # Get the number of parameters in the input - list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH) - - # Assume all if a parameter is missing - while(${PARAMS_LENGTH} LESS 5) - list(APPEND FUNCTION_PARAMS "*") - list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH) - endwhile() - - ## Filter functions/kernels based on input - filter_functions(FUNCTION_PARAMS 0) - endforeach() - - gen_collectives() ## Generate collective files - if(ENABLE_MSCCL_KERNEL) - gen_msccl_kernels() ## Generate msccl files (not configurable) - endif() - gen_device_table() ## Generate device_table.cpp - gen_host_table() ## Generate host_table.cpp - - set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() diff --git a/cmake/scripts/AddUnroll.cmake b/cmake/scripts/AddUnroll.cmake deleted file mode 100644 index a53dc3c64d..0000000000 --- a/cmake/scripts/AddUnroll.cmake +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -## This only applies to collective header files -if(HIP_FILE MATCHES ".*/src/device/.*\\.h$") - execute_process(COMMAND sed -i "s/template/template/g" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/template/template/g" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/\\(runRing]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/runTreeUpDown>/runTreeUpDown, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/\\(runTreeSplit]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/\\(struct RunWorkElement]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE}) - execute_process(COMMAND sed -i "s/\\(struct RunWork]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE}) - - message(STATUS "Added COLL_UNROLL template argument to ${HIP_FILE}") -endif() diff --git a/cmake/scripts/add_unroll.sh b/cmake/scripts/add_unroll.sh new file mode 100644 index 0000000000..e13c2e6bea --- /dev/null +++ b/cmake/scripts/add_unroll.sh @@ -0,0 +1,36 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +HIP_FILE=$1 + +if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then + sed -i "s/template/template/g" "$HIP_FILE" + sed -i "s/template/template/g" "$HIP_FILE" + sed -i "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(runRing]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/runTreeUpDown>/runTreeUpDown, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(runTreeSplit]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(struct RunWorkElement]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + sed -i "s/\\(struct RunWork]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE" + + echo "Added COLL_UNROLL template argument to $HIP_FILE" +fi \ No newline at end of file diff --git a/src/device/generate.py b/src/device/generate.py index 43de85d616..294a449535 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -3,11 +3,13 @@ import os import sys # Order of redops, tys, protos, algos must match src/include/device.h -all_colls = ["Broadcast","Reduce","AllGather","ReduceScatter","AllReduce","SendRecv"] +all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"] all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"] -all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16"] +all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16", "f8", "bf8"] all_protos = ["LL","LL128","SIMPLE"] -all_algos = ["TREE","RING","COLLNET_DIRECT","COLLNET_CHAIN","NVLS","NVLS_TREE"] +all_algos = ["TREE","RING"] + +all_params = [all_colls, all_algos, all_protos, all_redops, all_tys] ################################################################################ # The first command line argument is the path to the directory to generate and @@ -20,71 +22,105 @@ if os.path.exists(gensrc): os.remove(os.path.join(gensrc, name)) #os.truncate(os.path.join(gensrc, name), 0) else: - os.mkdir(gensrc) + os.makedirs(gensrc) ################################################################################ -# The second command line argument is used as a regex to filter the functions -# which make it into libnccl. This is helpful for reducing the binary when +# The command line argument is used as a regex to filter the functions +# which make it into librccl. This is helpful for reducing the binary when # developing device code. The regex supports non-space containing globs '*', -# parentheses '(x)', and union 'a|b'. The string representing the function has -# one of the forms: +# and union 'a|b'. The string representing the function has the form: # -# SendRecv -# (AllGather|Broadcast) -# (AlLReduce|Reduce|ReduceScatter) +# # # The possible values for redop, type, algo, proto can be found in the all_ # lists at the top of this file. # -# Since the Makefile forwards this from the ONLY_FUNCS variable, useful command -# line examples are given: -""" -# Only send/recv: -make ONLY_FUNCS="SendRecv" - -# Only non-reductions: -make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" - -# Only AllReduce sum f32 (but all algos, protos) -make ONLY_FUNCS="AllReduce Sum f32 * *" - -# Only AllReduce minmax i32 NVLS (but all protos) -make ONLY_FUNCS="AllReduce MinMax i32 NVLS *" - -# AllReduce sum RING LL128 -make ONLY_FUNCS="AllReduce Sum f32 RING LL128" -""" +# Example use-cases: +# +# # Only send/recv: +# make ONLY_FUNCS="SendRecv" +# +# # Only AllReduce and Reduce +# make ONLY_FUNCS="AllReduce|Reduce" +# +# # Only non-reductions: +# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv" +# +# # Only AllReduce Sum int32_t (but all algos, protos) +# make ONLY_FUNCS="AllReduce * * Sum int32_t" +# +# # Only AllReduce RING Max float (but all protos) +# make ONLY_FUNCS="AllReduce RING * Max float" +# +# # AllReduce TREE LL128 Prod rccl_bfloat16 +# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16" +# +# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter) +# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float" +# --- or --- +# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" +# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t" # Paste all non-None arguments together with `sep`. def paste(sep, *args): return sep.join(x for x in args if x is not None) -func_pattern = sys.argv[2:3] +is_ifc = 1 if sys.argv[2] == "ON" else 0 +is_colltrace = 1 if sys.argv[3] == "ON" else 0 +is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0 + +func_pattern = sys.argv[5:6] if func_pattern and func_pattern[0]: - import re func_pattern = func_pattern[0] - func_pattern = func_pattern.replace("*", "[^ ]*") - func_pattern += "$" - def func_filter(*fn): - return None is not re.match(func_pattern, paste(" ", *fn), flags=re.IGNORECASE) else: - def func_filter(coll, redop, ty, algo, proto): - return True + func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv" ################################################################################ algos_of_coll = { - "AllGather": ["RING","COLLNET_DIRECT","NVLS"], + "AllGather": ["RING"], "AllReduce": all_algos, + "AllToAllPivot": ["RING"], "Broadcast": ["RING"], "Reduce": ["RING"], - "ReduceScatter": ["RING","COLLNET_DIRECT","NVLS"], - "SendRecv": [None] + "ReduceScatter": ["RING"], + "SendRecv": ["RING"] +} + +protos_of_coll = { + "AllGather": all_protos, + "AllReduce": all_protos, + "AllToAllPivot": ["SIMPLE"], + "Broadcast": all_protos, + "Reduce": all_protos, + "ReduceScatter": all_protos, + "SendRecv": ["SIMPLE"] +} + +redops_of_coll = { + "AllGather": ["Sum"], + "AllReduce": all_redops, + "AllToAllPivot": ["Sum"], + "Broadcast": ["Sum"], + "Reduce": all_redops, + "ReduceScatter": all_redops, + "SendRecv": ["Sum"] +} + +tys_of_coll = { + "AllGather": ["i8"], + "AllReduce": all_tys, + "AllToAllPivot": ["i8"], + "Broadcast": ["i8"], + "Reduce": all_tys, + "ReduceScatter": all_tys, + "SendRecv": ["i8"] } coll_camel_to_lower = { "AllGather": "all_gather", "AllReduce": "all_reduce", + "AllToAllPivot": "alltoall_pivot", "Broadcast": "broadcast", "Reduce": "reduce", "ReduceScatter": "reduce_scatter", @@ -94,141 +130,237 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower} ################################################################################ -# Returns pair of minimum required values for (CUDART_VERSION, __CUDA_ARCH__) -# or None if function is never supported. Note that (0, 0) encodes universal -# support. -def required_cuda(coll, redop, ty, algo, proto): - cudart, arch = 0, 0 - # kernels mapped to by coll="Nop" functions have coll="Generic" - if coll in ("SendRecv", "Generic", "Nop"): return (cudart, arch) +# Helper function to check if the conditions for the collective is being met +def func_validate(coll, algo, proto, redop, ty): + if redop == "SumPostDiv" and ty[0] not in ("i","u"): + return False + if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]: + return False + return True - if proto!="SIMPLE" and algo not in ("RING","TREE"): return None +# A recursive helper to generate collective functions based on the input given +def func_filter(function_params, current_idx, item_list=None): + if item_list is None: + item_list = [] - if coll in ("AllReduce","Reduce","ReduceScatter"): - if redop=="SumPostDiv" and ty[0] not in ("i","u"): return None - if ty=="bf16": cudart = max(cudart, 11000) + # Check if current_idx exceeds the max depth + if current_idx < len(all_params): + # Current element is the config parameter + current_element = function_params[current_idx] - if "NVLS" in algo: - if coll in ("AllReduce","Reduce","ReduceScatter"): - # Must match ncclNvlsSupported() in src/include/device.h - nvls_ok = ((ty in ("i32","u32","i64","u64") and redop in ("Sum","MinMax")) or - (ty in ("f32","f64") and redop=="Sum") or - (ty in ("f16","bf16") and redop in ("Sum","MinMax"))) - if not nvls_ok: return None - cudart = max(cudart, 12010) - arch = max(arch, 900) + # If the paramter is equal to '*', include all possible cases for it + if current_element == "*": + if current_idx == 0: + raise ValueError("Error: Paramter 'COLL' can not be type all '*'.") + + # all_params list must be in the same order as function_params --> + # Get the current list from all_params + current_list = all_params[current_idx] - return (cudart, arch) + # Iterate over the items int the current_list + for item in current_list: + # Add item to item_list which will be used in the inner most loop + item_list.append(item) + yield from func_filter(function_params, current_idx+1, item_list) + + # For each loop layer remove the last element in item_list + item_list.pop() + else: + # Check if the current element is recognized + elements = current_element.split("/") + current_param = all_params[current_idx] + + # Iterate over the elements in the elements list + for item in elements: + if item not in current_param: + raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.") + + for item in elements: + item_list.append(item) + yield from func_filter(function_params, current_idx+1, item_list) + + # For each loop layer remove the last element in item_list + item_list.pop() + else: + coll, algo, proto, redop, ty = item_list + + if func_validate(*item_list): + yield(coll, algo, proto, redop, ty) + +# Parse ONLY_FUNCS input and feed it to func_filter +def parse_input(func_pattern): + input_list = sorted(func_pattern.split("|")) + + for input in input_list: + function_params = input.split() + params_length = len(function_params) + + # If a parameter is missing, append '*' + while params_length < len(all_params): + function_params.append("*") + params_length += 1 + + # Filter functions/kernels based on input + yield from func_filter(function_params, 0) # Maps functions to the chosen representative for the equivalence class it # belongs to. For instance (sum, signed int) maps to (sum, unsigned int). -def equivalent_primary(coll, redop, ty, algo, proto): +def equivalent_primary(coll, algo, proto, redop, ty): if coll in ("AllReduce", "Reduce", "ReduceScatter"): # map signed integer sum/prod to unsigned if redop in ("Sum","Prod","PreMulSum") and ty[0]=="i": - return (coll, redop, "u"+ty[1:], algo, proto) + ty = "u"+ty[1:] # map signed integer min/max to unsigned for non-NVLS - if redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo): - return (coll, redop, "u"+ty[1:], algo, proto) - return (coll, redop, ty, algo, proto) - -# Map to another func representing the best kernel to use. Every distinct value -# returned will instantiate a ncclDevKernel specialized to run this func -# without function call overhead. -def best_kernel(coll, redop, ty, algo, proto): - def best(coll, redop, ty, algo, proto): - # Modify this logic to control how many kernels are specialized. - if coll=="Nop": return ("Generic", None, None, None, None) - if coll=="SendRecv": return ("SendRecv", None, None, None, None) - if coll in ("AllGather","Broadcast"): return (coll, None, None, "RING", "LL") - return (coll, "Sum", ty, ("TREE" if algo=="TREE" else "RING"), "LL") - # Need to ensure kernel is specialize for a primary function - kfn = equivalent_primary(*best(coll, redop, ty, algo, proto)) - # And isn't filtered out. - if not func_filter(*kfn): return ("Generic", None, None, None, None) - return kfn + elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo): + ty = "u"+ty[1:] + return (coll, algo, proto, redop, ty) # Order rows are enumerated must match formula of `ncclDevFuncId()`: def enumerate_func_rows(): - yield ("SendRecv", None, None, None, None) - for coll in ("AllGather", "Broadcast"): - algos = algos_of_coll[coll] - for algo in algos: + for coll in all_colls: + for algo in all_algos: for proto in all_protos: - yield (coll, None, None, algo, proto) - for coll in ("AllReduce", "Reduce", "ReduceScatter"): - algos = algos_of_coll[coll] - for redop in all_redops: - for ty in all_tys: - for algo in algos: - for proto in all_protos: - yield (coll, redop, ty, algo, proto) + for redop in all_redops: + for ty in all_tys: + if func_validate(coll, algo, proto, redop, ty): + yield (coll, algo, proto, redop, ty) + +# Sort the hashmap based on custom key +def custom_sort_key(fn): + coll, algo, proto, redop, ty = fn + + return ( + all_colls.index(coll), + all_algos.index(algo), + all_protos.index(proto), + all_redops.index(redop), + all_tys.index(ty) + ) ################################################################################ -def is_built(coll, redop, ty, algo, proto): - built = required_cuda(coll, redop, ty, algo, proto) - built = built and func_filter(coll, redop, ty, algo, proto) - return built - -# Returns None if required_cuda(...) is None. -# Returns the coll="Nop" function if developer has filtered it out. -# Otherwise just returns func it was given. -def validate(coll, redop, ty, algo, proto): - valid = required_cuda(coll, redop, ty, algo, proto) - built = valid and func_filter(coll, redop, ty, algo, proto) - if built: return (coll, redop, ty, algo, proto) - if valid: return ("Nop", None, None, None, None) - return None - # Corresponds to ncclDevFuncRowToId[] -func_rows = [validate(*fn) for fn in enumerate_func_rows()] +func_rows = [fn for fn in enumerate_func_rows()] # Corresponds to ncclDevFuncTable[] -primary_funcs = sorted(set(equivalent_primary(*fn) for fn in func_rows if fn is not None)) +primary_funcs = sorted(set(equivalent_primary(*fn) for fn in parse_input(func_pattern)), key=custom_sort_key) # primary_to_index[primary_funcs[i]] == i -primary_to_index = {fn: i for (i,fn) in zip(range(len(primary_funcs)), primary_funcs)} - -kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) +primary_to_index = {fn: primary_funcs.index(fn) if fn in primary_funcs else -1 for fn in func_rows} ################################################################################ -# Generate /device_table.cu -with open(os.path.join(gensrc, "device_table.cu"), "w") as f: +# Generate /device_table.h +with open(os.path.join(gensrc, "device_table.h"), "w") as f: + print("-- Generating %s" % os.path.join(gensrc, "device_table.h")) out = f.write - out('#include "common.h"\n') - out("\n") + + if is_ifc: func_declaration = "__device__ void" + else: func_declaration = "__device__ __attribute__((noinline)) void" for fn in primary_funcs: sym = paste("_", "ncclDevFunc", *fn) - cudart, arch = required_cuda(*fn) - if (cudart, arch) != (0, 0): - out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch)) - out("__device__ void %s();\n" % sym) - if (cudart, arch) != (0, 0): - out("#endif\n") + if fn[2] == "LL128": + out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") + out("%s %s();\n%s %s_4();\n#else\n" % (func_declaration, sym, func_declaration, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("%s %s();\n%s %s_4();\n#endif\n" % (func_declaration, sym_ll, func_declaration, sym_ll)) + else: + out("%s %s();\n%s %s_4();\n" % (func_declaration, sym, func_declaration, sym)) out("\n") - out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n"); + out("typedef void(*ncclDevFuncPtr_t)();\n\n") + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n") index = 0 for fn in primary_funcs: sym = paste("_", "ncclDevFunc", *fn) - cudart, arch = required_cuda(*fn) - if (cudart, arch) != (0, 0): - out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart ,arch)) - out("/*%4d*/ %s,\n" % (index, sym)) - if (cudart, arch) != (0, 0): - out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index) + if fn[2] == "LL128": + out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s,\n#else\n" % (index, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s,\n#endif\n" % (index, sym_ll)) + else: + out("/*%4d*/ %s,\n" % (index, sym)) index += 1 out("nullptr};\n") out("\n") - out("// Workaround for https://reviews.llvm.org/D55580\n" - "__device__ void ncclWorkaroundClangD55580() {}\n") + out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n") + index = 0 + for fn in primary_funcs: + sym = paste("_", "ncclDevFunc", *fn) + if fn[2] == "LL128": + out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") + out("/*%4d*/ %s_4,\n#else\n" % (index, sym)) + fn_ll = fn[:2] + ("LL",) + fn[3:] + sym_ll = paste("_", "ncclDevFunc", *fn_ll) + out("/*%4d*/ %s_4,\n#endif\n" % (index, sym_ll)) + else: + out("/*%4d*/ %s_4,\n" % (index, sym)) + index += 1 + out("nullptr};\n") + out("\n") + + if not is_ifc: + out("template\n" + "struct Caller {\n" + " static __forceinline__ __device__ __host__\n" + " void call(unsigned short funcIndex) noexcept\n" + " {\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + " return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex);\n" + " }\n" + "};\n" + "\n" + "template\n" + "struct Caller{\n" + " static __forceinline__ __device__ __host__\n" + " void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n") + out(f" Caller<0, {index}>::call(funcIndex);\n") + out("}\n\n") + out("template\n" + "struct Caller4 {\n" + " static __forceinline__ __device__ __host__\n" + " void call4(unsigned short funcIndex) noexcept\n" + " {\n" + " constexpr unsigned short m = f + (l - f) / 2;\n" + " return (funcIndex < m) ? Caller4::call4(funcIndex) : Caller4::call4(funcIndex);\n" + " }\n" + "};\n" + "\n" + "template\n" + "struct Caller4{\n" + " static __forceinline__ __device__ __host__\n" + " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n" + "};\n") + out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n") + out(f" Caller4<0, {index}>::call4(funcIndex);\n") + out("}\n\n") + +# Generate /device_table.cpp +if is_colltrace: + with open(os.path.join(gensrc, "device_table.cpp"), "w") as f: + print("-- Generating %s" % os.path.join(gensrc, "device_table.cpp")) + + out = f.write + out('#include "nccl_common.h"\n#include "device.h"\n') + out("\n") + + out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n") + for fn in primary_funcs: + out(' "%s",\n' % paste("_", "ncclDevFunc", *fn)) + for ty in all_tys: + out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n') + out("};\n") + +# Generate /host_table.cpp +with open(os.path.join(gensrc, "host_table.cpp"), "w") as f: + print("-- Generating %s" % os.path.join(gensrc, "host_table.cpp")) -# Generate /host_table.cc -with open(os.path.join(gensrc, "host_table.cc"), "w") as f: out = f.write out('#include "device.h"\n') out("\n") @@ -243,61 +375,14 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: comment = " // " + paste(" ", *fn) out("/*%4d*/ %d,%s\n" % (index, fn_id, comment)) index += 1 - out("-1};\n") - out("\n") - - # Forward declarations of kernels. - for kfn in kernel_funcs: - cudart, _ = required_cuda(*kfn) - sym = paste("_", "ncclDevKernel", *kfn) - if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart) - out("__global__ void %s(struct ncclDevComm*, uint64_t, struct ncclWork*);\n" % sym) - if cudart != 0: out("#endif\n") - out("\n") - - # List of all kernel function pointers. - out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs)) - out("extern void* const ncclDevKernelList[] = {\n") - index = 0 - for kfn in kernel_funcs: - cudart, _ = required_cuda(*kfn) - sym = paste("_", "ncclDevKernel", *kfn) - if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart) - out("/*%4d*/ (void*)%s,\n" % (index, sym)); - if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index) - index += 1 - out("nullptr};\n") - out("\n") - - # Maps primary id to kernel function pointer. - out("extern void* const ncclDevKernelForFunc[] = {\n") - index = 0 - for fn in primary_funcs: - kfn = best_kernel(*fn) - sym = paste("_", "ncclDevKernel", *kfn) - cudart, _ = required_cuda(*kfn) - if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart) - out("/*%4d*/ (void*)%s,\n" % (index, sym)) - if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index) - index += 1 - out("nullptr};\n") - out("\n") - - # Does the prior map use an explicitly specialized kernel. - out("extern bool const ncclDevKernelForFuncIsSpecialized[] = {\n") - index = 0 - for fn in primary_funcs: - kfn = best_kernel(*fn) - specialized = "1" if fn == kfn else "0" - out("/*%4d*/ %s,\n" % (index, specialized)) - index += 1 - out("0};\n") + out(f"{index}") + out("};\n") # Maps to .cu filename which implements this func. The only constraint is that # "coll" is reflected in the name: formally that no two funcs having different # coll's map to the same filename. -def impl_filename(coll, redop, ty, algo, proto): - return "%s.cu" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty) +def impl_filename(coll, algo, proto, redop, ty): + return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty) # Partition the functions and kernels to the .cu filenames. The partition is # a dictionary mapping filename to (coll, func-tuple list) @@ -312,33 +397,6 @@ def partition_by_name(fns): return ans name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop") -name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Generic") - -# Generate /rules.mk -with open(os.path.join(gensrc, "rules.mk"), "w") as f: - out = f.write - impl_names = sorted(name_to_funcs.keys()) - names = impl_names + ["host_table.cc", "device_table.cu"] - out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n" - .format(names=" ".join(names))) - out("\n") - - # For each __.cu compile to a .cu.o file. Notice the dependencies - # come from the suffix-erased file (e.g. 'gensrc/all_reduce.cu') - for name in impl_names: - coll = name_to_funcs[name][0] - out( - "$(OBJDIR)/genobj/{name}.o: $(OBJDIR)/gensrc $(OBJDIR)/genobj/{lower_coll}.cu.d\n" - "\t" "$(call COMPILE,$@,$(OBJDIR)/gensrc/{name})\n" - "\n" - .format(name=name, lower_coll=coll_camel_to_lower[coll]) - ) - -# Add the suffix-erased .cu's which are used only for dependency scraping. -for coll in set(coll for (coll,_,_,_,_) in primary_funcs if coll!="Nop"): - name = impl_filename(coll, None, None, None, None) - if name not in name_to_funcs: - name_to_funcs[name] = (coll, []) redop_to_cxx = { None: "FuncCopy", @@ -360,13 +418,17 @@ ty_to_cxx = { "f16": "half", "f32": "float", "f64": "double", - "bf16": "__nv_bfloat16" + "bf16": "hip_bfloat16", + "f8": "rccl_float8", + "bf8": "rccl_bfloat8", } -# Generate each /.cu: +# Generate each /.cpp: for name in name_to_funcs.keys(): (coll, fns) = name_to_funcs[name] with open(os.path.join(gensrc, name), "w") as f: + print("-- Generating %s" % os.path.join(gensrc, name)) + out = f.write out( '#include "common.h"\n' @@ -374,32 +436,30 @@ for name in name_to_funcs.keys(): .format(lower_coll=coll_camel_to_lower[coll]) ) - (_, kfns) = name_to_kernels.get(name) or (None, []) - for kfn in kfns: - (coll, redop, ty, algo, proto) = kfn - sym = paste("_", coll, redop, ty, algo, proto) - fn_id = primary_to_index[kfn] - cudart, arch = required_cuda(*kfn) - if (cudart, arch) != (0, 0): - out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch)) - out( - "DEFINE_ncclDevKernel({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {fn_id})\n" - .format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty], - algo=(algo or "RING"), proto=(proto or "SIMPLE"), fn_id=fn_id) - ) - if (cudart, arch) != (0, 0): - out("#endif\n") - for fn in fns: - (coll, redop, ty, algo, proto) = fn - sym = paste("_", coll, redop, ty, algo, proto) - cudart, arch = required_cuda(*fn) - if (cudart, arch) != (0, 0): - out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch)) + (coll, algo, proto, redop, ty) = fn + sym = paste("_", coll, algo, proto, redop, ty) + if proto == "LL128": + out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") out( "DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto})\n" .format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty], algo=(algo or "RING"), proto=(proto or "SIMPLE")) ) - if (cudart, arch) != (0, 0): + if proto == "LL128": out("#endif\n") + +# Generate each /.cpp +if is_msccl_kernels: + for redop in all_redops: + if redop in ("Sum", "Prod", "MinMax"): + for ty in all_tys: + with open(os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp"), "w") as f: + print("-- Generating %s" % os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp")) + + out = f.write + out('#include "msccl_kernel_impl.h"\n#include "nccl_common.h"\n') + out( + "MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE({redop}, {ty_cxx}, false);\n" + .format(redop=redop, ty_cxx=ty_to_cxx[ty]) + ) \ No newline at end of file diff --git a/src/include/device.h b/src/include/device.h index 28d2794d12..54bb7027a2 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -552,7 +552,7 @@ inline bool ncclNvlsSupported(int devRedOp, int type) { // Map the rowIdx to funcIdx extern int const ncclDevFuncRowToId[]; -// `ncclFuncIndex()` needs to be in sync with 'ALL_COLLS' in Generate.cmake +// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) { int row = 0; do { @@ -568,7 +568,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto); break; } - row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); + row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS); // RING / SIMPLE / Sum / int8_t if (coll == ncclFuncAllToAllPivot) break; diff --git a/src/include/nccl_common.h b/src/include/nccl_common.h index d13a60f49d..68e8d31e02 100644 --- a/src/include/nccl_common.h +++ b/src/include/nccl_common.h @@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); #define NCCL_NUM_ONERANK 12 -#define FUNC_INDEX_TOTAL 980 + NCCL_NUM_ONERANK +#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK #define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now typedef enum {