[BUILD] Move code generation to python from CMake (#1360)

* Use generate.py for func generation

* Convert AddUnroll.cmake to bash
This commit is contained in:
Bertan Dogancay
2024-10-03 10:21:19 -04:00
committato da GitHub
parent 038517b169
commit 2dd10c8f17
7 ha cambiato i file con 366 aggiunte e 768 eliminazioni
+30 -14
Vedi File
@@ -55,7 +55,6 @@ set(DEFAULT_GPUS
include(CheckIncludeFiles)
include(CheckSymbolExists)
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
include(cmake/Generator.cmake) # Configure functions that goes into RCCL
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
@@ -207,6 +206,16 @@ if(ENABLE_IFC)
set(IFC_ENABLED OFF)
message(WARNING "Indirect function call disabled - requires HIP version >= 5.5.30201")
endif()
else()
set(IFC_ENABLED OFF)
endif()
## Check for LL128 support
## LL128_ENABLED is assigned in both branches so that a stale or user-supplied
## value can never enable the protocol on a HIP runtime that lacks support.
## The version string is quoted so an empty value evaluates to "unsupported"
## instead of producing a malformed if() expression.
if("${hip_version_string}" VERSION_GREATER_EQUAL "6.1.33591")
  set(LL128_ENABLED ON)
  message(STATUS "RCCL LL128 protocol enabled")
else()
  set(LL128_ENABLED OFF)
  message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591")
endif()
## Check for hsa-runtime64
@@ -574,7 +583,7 @@ foreach(SRC_FILE ${SRC_FILES})
OUTPUT ${HIP_FILE}
COMMAND mkdir -p ${HIP_FILE_DIR}
&& ${hipify-perl_executable} -quiet-warnings ${CMAKE_SOURCE_DIR}/${SRC_FILE} -o ${HIP_FILE}
&& ${CMAKE_COMMAND} -DHIP_FILE=${HIP_FILE} -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/AddUnroll.cmake
&& ${CMAKE_COMMAND} -E env bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/add_unroll.sh ${HIP_FILE}
MAIN_DEPENDENCY ${SRC_FILE}
COMMENT "Hipifying ${SRC_FILE} -> ${HIP_FILE}"
)
@@ -582,13 +591,23 @@ endforeach()
# Generate device/host tables and all the collective functions that are going to be in librccl.so
#==================================================================================================
if(ONLY_FUNCS)
## Generate only the specified functions
gen_functions(${ONLY_FUNCS})
else()
# Generate all the functions
gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv")
endif()
find_package(PythonInterp REQUIRED)
set(GEN_DIR "${HIPIFY_DIR}/gensrc")

# Execute the python script to generate required files.
# Every argument is quoted: generate.py reads its options by position, so an
# empty or undefined option must still occupy its slot instead of silently
# disappearing from the command line and shifting the remaining arguments.
execute_process(
  COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_SOURCE_DIR}/src/device/generate.py"
          "${GEN_DIR}" "${IFC_ENABLED}" "${COLLTRACE}" "${ENABLE_MSCCL_KERNEL}" "${ONLY_FUNCS}"
  WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
  RESULT_VARIABLE result
)

# Stop the configure step when generation fails; otherwise the glob below
# would quietly pick up an empty or partially written output directory.
if(NOT "${result}" STREQUAL "0")
  message(FATAL_ERROR "src/device/generate.py failed (result: ${result})")
endif()

# Find the generated files in the output directory and add them to the sources
file(GLOB GENERATED_FILES "${GEN_DIR}/*")
list(APPEND HIP_SOURCES ${GENERATED_FILES})
# Create an initial git_version.cpp file (that will be updated with latest git version)
#==================================================================================================
@@ -616,6 +635,7 @@ target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src) #
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include)
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device)
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device/network/unpack)
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/gensrc)
target_include_directories(rccl PRIVATE ${HSA_INCLUDE_PATH})
target_include_directories(rccl PRIVATE ${ROCM_SMI_INCLUDE_DIR})
if(DEMANGLE_DIR)
@@ -692,12 +712,8 @@ if(DEMANGLE_DIR)
target_compile_definitions(rccl PRIVATE "HAVE_CPLUS_DEMANGLE=1")
target_compile_definitions(rccl PRIVATE "HAVE_DECL_BASENAME=1")
endif()
if(${hip_version_string} VERSION_GREATER_EQUAL "6.1.33591")
set(LL128_ENABLED ON)
if(LL128_ENABLED)
target_compile_definitions(rccl PRIVATE ENABLE_LL128)
message(STATUS "RCCL LL128 protocol enabled")
else()
message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591")
endif()
## Set RCCL compile options
-479
Vedi File
@@ -1,479 +0,0 @@
# MIT License
#
# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
set(ALL_PARAMS "ALL_COLLS" "ALL_ALGOS" "ALL_PROTOS" "ALL_REDOPS" "ALL_TYPES")
set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "ReduceScatter" "SendRecv")
set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN")
set(ALL_PROTOS "LL" "LL128" "SIMPLE")
set(ALL_REDOPS "Sum" "Prod" "MinMax" "PreMulSum" "SumPostDiv")
set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8")
set(FLOATS_LIST "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8")
################################################################################
# The command line argument is used as a regex to filter the functions
# which make it into librccl. This is helpful for reducing the binary when
# developing device code. The regex supports non-space containing globs '*',
# and union 'a|b'. The string representing the function has the form:
#
# <coll> <algo> <proto> <redop> <type>
#
# The possible values for redop, type, algo, proto can be found in the all_<foo>
# lists at the top of this file.
#
# Example use-cases:
#
# # Only send/recv:
# make ONLY_FUNCS="SendRecv"
#
# # Only AllReduce and Reduce
# make ONLY_FUNCS="AllReduce|Reduce"
#
# # Only non-reductions:
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
#
# # Only AllReduce Sum int32_t (but all algos, protos)
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
#
# # Only AllReduce RING MinMax float (but all protos)
# make ONLY_FUNCS="AllReduce RING * MinMax float"
#
# # AllReduce TREE LL128 Prod hip_bfloat16
# make ONLY_FUNCS="AllReduce TREE LL128 Prod hip_bfloat16"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
set(AllGather_Params "RING" "*" "Sum" "int8_t")
set(AllReduce_Params "*" "*" "*" "*")
set(AllToAllPivot_Params "RING" "SIMPLE" "Sum" "int8_t")
set(Broadcast_Params "RING" "*" "Sum" "int8_t")
set(Reduce_Params "RING" "*" "*" "*")
set(ReduceScatter_Params "RING" "*" "*" "*")
set(SendRecv_Params "RING" "SIMPLE" "Sum" "int8_t")
#############################################################################################################
## Helper function to check if the conditions for the collective is being met
#############################################################################################################
function(validate_func ITEM_LIST)
  ## Decides whether one <coll> <algo> <proto> <redop> <type> combination may be
  ## generated. The verdict is returned to the caller through the variable
  ## `is_valid` (TRUE/FALSE), set with PARENT_SCOPE.
  ##
  ## ITEM_LIST : list of exactly five entries in the order
  ##             <coll> <algo> <proto> <redop> <type>.

  ## Extract coll/redop/type
  list(GET ITEM_LIST 0 coll)
  list(GET ITEM_LIST 3 redop)
  list(GET ITEM_LIST 4 type)

  ## 'SumPostDiv' is rejected for every type listed in FLOATS_LIST.
  ## redop is quoted so its value is compared as a plain string.
  if("${redop}" STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
    set(is_valid FALSE PARENT_SCOPE)
    return()
  endif()

  ## Walk the per-collective restrictions (<coll>_Params). Entry N of that list
  ## constrains entry N+1 of ITEM_LIST; "*" means "no restriction".
  set(paramIdx 1)
  foreach(parameter IN LISTS "${coll}_Params")
    if(NOT "${parameter}" STREQUAL "*")
      list(GET ITEM_LIST "${paramIdx}" item)
      ## Compare whole names rather than substrings, so that a restriction such
      ## as "LL128" can never accept "LL". A restriction may name several
      ## alternatives separated by '/', mirroring the ONLY_FUNCS syntax.
      string(REPLACE "/" ";" allowed_values "${parameter}")
      if(NOT item IN_LIST allowed_values)
        set(is_valid FALSE PARENT_SCOPE)
        return()
      endif()
    endif()
    math(EXPR paramIdx "${paramIdx} + 1")
  endforeach()

  set(is_valid TRUE PARENT_SCOPE)
endfunction()
#############################################################################################################
## A recursive helper macro to generate functions and kernels based on the input given
#############################################################################################################
macro(filter_functions FUNCTION_PARAMS current_idx)
## Recursively expands one function specification
## (<coll> <algo> <proto> <redop> <type>) into every concrete combination and
## records the valid ones in COLL_LIST, FUNC_LIST and KERN_LIST.
##
## Because this is a macro, it runs in the scope of whoever invokes it:
## - the specification is read from the caller's variable literally named
##   FUNCTION_PARAMS (see the list(GET FUNCTION_PARAMS ...) calls); the first
##   macro argument is only forwarded to the recursive invocations,
## - ITEM_LIST is a shared work list that holds the values chosen so far,
## - current_idx is the position (0..4) of the field being expanded.
## Check if the current_idx does not exceed the max depth
if(${current_idx} LESS 5)
## current_element is the config parameter
list(GET FUNCTION_PARAMS ${current_idx} current_element)
## If the parameter is equal to '*', include all the possible cases for it
if(${current_element} STREQUAL "*")
## The collective itself (index 0) must always be named explicitly
if(${current_idx} EQUAL 0)
message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.")
endif()
## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> <coll> <algo> <proto> <redop> <type>
## Find the respective parameter list from ALL_PARAMS list
list(GET ALL_PARAMS ${current_idx} current_list)
## Iterate over the items in the current_list
foreach(item IN LISTS ${current_list})
## Add item to ITEM_LIST which will be used in the inner most loop
list(APPEND ITEM_LIST ${item})
## new_idx is recomputed on every iteration because nested invocations overwrite it
math(EXPR new_idx "${current_idx} + 1")
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
## For each loop layer remove the last element in ITEM_LIST
list(REMOVE_AT ITEM_LIST -1)
endforeach()
else()
## Check if the current element is recognized
list(GET ALL_PARAMS ${current_idx} current_param)
## A field may name several alternatives separated by '/'
string(REPLACE "/" ";" elements ${current_element})
## Validate every alternative before expanding any of them
foreach(item IN LISTS elements)
list(FIND ${current_param} ${item} is_valid)
if(${is_valid} EQUAL -1)
message(FATAL_ERROR "Error: ${item} is unrecognized or does not belong to this category ${current_param}.")
endif()
endforeach()
foreach(item IN LISTS elements)
## Add item to ITEM_LIST which will be used in the inner most loop
list(APPEND ITEM_LIST ${item})
math(EXPR new_idx "${current_idx} + 1")
filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})
## For each loop layer remove the last element in ITEM_LIST
list(REMOVE_AT ITEM_LIST -1)
endforeach()
endif()
else()
## This is the inner most loop where the file is generated
## Unzip ITEM_LIST
list(GET ITEM_LIST 0 COLL)
list(GET ITEM_LIST 1 ALGO)
list(GET ITEM_LIST 2 PROTO)
list(GET ITEM_LIST 3 REDOP)
list(GET ITEM_LIST 4 TYPE)
## validate_func reports its verdict through the variable is_valid
validate_func("${ITEM_LIST}")
if (NOT is_valid)
## NOTE(review): this continue() is expanded inside the foreach that invoked
## the macro, so it presumably skips to that loop's next iteration without
## running the list(REMOVE_AT ITEM_LIST -1) that follows the invocation - confirm.
continue()
endif()
list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}")
set(COLL_LIST ${COLL_LIST} PARENT_SCOPE)
## Append the newly formed function/kernel to list
list(APPEND FUNC_LIST "ncclDevFunc_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
list(APPEND KERN_LIST "ncclDevKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
## The lists are also published one scope above the scope the macro runs in
set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE)
set(KERN_LIST ${KERN_LIST} PARENT_SCOPE)
endif()
endmacro()
#####################################################################################################
## Function to generate device table
#####################################################################################################
function(gen_device_table)
  ## Writes the device-side dispatch tables for every entry of FUNC_LIST into
  ## device_table.h and, when COLLTRACE is on, the function-name table into
  ## device_table.cpp. Both files are added to HIP_SOURCES in the caller's scope.
  ##
  ## Reads : FUNC_LIST, ALL_TYPES, HIPIFY_DIR, IFC_ENABLED, COLLTRACE, HIP_SOURCES
  ## Writes: HIP_SOURCES (PARENT_SCOPE)
  ##
  ## NOTE(review): this function only ever appends to the two files; nothing here
  ## truncates them first, so it presumably relies on them being fresh at
  ## configure time - confirm against the code that prepares HIPIFY_DIR.
  set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h")
  message(STATUS "Generating ${DEVICE_TABLE_H_FILE}")

  ## Use the resolved IFC_ENABLED switch rather than the raw ENABLE_IFC option:
  ## the top-level CMakeLists turns IFC_ENABLED off when the HIP version is too
  ## old, and the generated table has to follow that final decision.
  if(IFC_ENABLED)
    set(func_declaration "__device__ void")
  else()
    set(func_declaration "__device__ __attribute__((noinline)) void")
  endif()

  ## LL128 entries are only compiled for gfx90a with ENABLE_LL128 defined;
  ## everywhere else the matching LL function takes their slot.
  set(ll128_guard "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")

  ## Declaration of device functions
  foreach(func IN LISTS FUNC_LIST)
    string(FIND "${func}" "LL128" IS_LL128)
    if(NOT IS_LL128 EQUAL -1)
      string(REPLACE "LL128" "LL" fallback_func "${func}")
      file(APPEND ${DEVICE_TABLE_H_FILE} "${ll128_guard}")
      file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n${func_declaration} ${func}_4();\n#else\n")
      file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${fallback_func}();\n${func_declaration} ${fallback_func}_4();\n#endif\n")
    else()
      file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n${func_declaration} ${func}_4();\n")
    endif()
  endforeach()
  file(APPEND ${DEVICE_TABLE_H_FILE} "\n")

  ## Indirect function call tables
  file(APPEND ${DEVICE_TABLE_H_FILE} "typedef void(*ncclDevFuncPtr_t)();\n\n")
  file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
  foreach(func IN LISTS FUNC_LIST)
    string(FIND "${func}" "LL128" IS_LL128)
    if(NOT IS_LL128 EQUAL -1)
      string(REPLACE "LL128" "LL" fallback_func "${func}")
      file(APPEND ${DEVICE_TABLE_H_FILE} "${ll128_guard}")
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#else\n")
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${fallback_func},\n#endif\n")
    else()
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
    endif()
  endforeach()
  file(APPEND ${DEVICE_TABLE_H_FILE} "nullptr};\n\n")

  ## Same table for the "_4" variants of every function
  file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
  foreach(func IN LISTS FUNC_LIST)
    string(FIND "${func}" "LL128" IS_LL128)
    if(NOT IS_LL128 EQUAL -1)
      string(REPLACE "LL128" "LL" fallback_func "${func}")
      file(APPEND ${DEVICE_TABLE_H_FILE} "${ll128_guard}")
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${func}_4,\n#else\n")
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${fallback_func}_4,\n#endif\n")
    else()
      file(APPEND ${DEVICE_TABLE_H_FILE} " ${func}_4,\n")
    endif()
  endforeach()
  file(APPEND ${DEVICE_TABLE_H_FILE} "nullptr};\n\n")

  if(NOT IFC_ENABLED)
    ## Direct function calls: a compile-time binary search over the function
    ## index that ends in a call through the matching table slot.
    list(LENGTH FUNC_LIST max_index)
    file(APPEND ${DEVICE_TABLE_H_FILE}
      "template<unsigned short f, unsigned short l>\n"
      "struct Caller {\n"
      " static __forceinline__ __device__ __host__\n"
      " void call(unsigned short funcIndex) noexcept\n"
      " {\n"
      " constexpr unsigned short m = f + (l - f) / 2;\n"
      " return (funcIndex < m) ? Caller<f, m>::call(funcIndex) : Caller<m, l>::call(funcIndex);\n"
      " }\n"
      "};\n"
      "\n"
      "template<unsigned short f>\n"
      "struct Caller<f, f + 1>{\n"
      " static __forceinline__ __device__ __host__\n"
      " void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n"
      "};\n"
    )
    file(APPEND ${DEVICE_TABLE_H_FILE} "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
    file(APPEND ${DEVICE_TABLE_H_FILE} " Caller<0, ${max_index}>::call(funcIndex);\n}\n\n")
    file(APPEND ${DEVICE_TABLE_H_FILE}
      "template<unsigned short f, unsigned short l>\n"
      "struct Caller4 {\n"
      " static __forceinline__ __device__ __host__\n"
      " void call4(unsigned short funcIndex) noexcept\n"
      " {\n"
      " constexpr unsigned short m = f + (l - f) / 2;\n"
      " return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
      " }\n"
      "};\n"
      "\n"
      "template<unsigned short f>\n"
      "struct Caller4<f, f + 1>{\n"
      " static __forceinline__ __device__ __host__\n"
      " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
      "};\n"
    )
    file(APPEND ${DEVICE_TABLE_H_FILE} "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
    file(APPEND ${DEVICE_TABLE_H_FILE} " Caller4<0, ${max_index}>::call4(funcIndex);\n}\n\n")
  endif()

  ## Function name table for collective trace
  if(COLLTRACE)
    set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/device/device_table.cpp")
    message(STATUS "Generating ${DEVICE_TABLE_FILE}")
    file(APPEND ${DEVICE_TABLE_FILE} "#include \"nccl_common.h\"\n#include \"device.h\"\n\n const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
    foreach(func IN LISTS FUNC_LIST)
      file(APPEND ${DEVICE_TABLE_FILE} " \"${func}\",\n")
    endforeach()
    ## One extra name per data type for the OneRankReduce functions
    foreach(type IN LISTS ALL_TYPES)
      file(APPEND ${DEVICE_TABLE_FILE} " \"ncclDevFunc_OneRankReduce_PreMulSum_${type}\",\n")
    endforeach()
    file(APPEND ${DEVICE_TABLE_FILE} "};\n")
  endif()

  ## Add the device_table files to HIP_SOURCES (the .cpp only exists with COLLTRACE)
  list(APPEND HIP_SOURCES ${DEVICE_TABLE_H_FILE})
  if(COLLTRACE)
    list(APPEND HIP_SOURCES ${DEVICE_TABLE_FILE})
  endif()
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
######################################################################################################
## Function to generate host-side table
######################################################################################################
function(gen_host_table)
  ## Writes host_table.cpp, which maps every valid
  ## <coll> <algo> <proto> <redop> <type> row (in ALL_* enumeration order) to the
  ## index of the matching entry in FUNC_LIST, or -1 when that function was not
  ## generated. The array is terminated by (last generated index + 1).
  ##
  ## Reads : ALL_COLLS, ALL_ALGOS, ALL_PROTOS, ALL_REDOPS, ALL_TYPES, FUNC_LIST,
  ##         HIPIFY_DIR, HIP_SOURCES
  ## Writes: HIP_SOURCES (PARENT_SCOPE)
  set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/device/host_table.cpp")
  message(STATUS "Generating ${HOST_TABLE_FILE}")
  file(WRITE ${HOST_TABLE_FILE} "#include \"device.h\"\n\n")

  ## The mapping from function rows to valid function ids
  file(APPEND ${HOST_TABLE_FILE} "extern int const ncclDevFuncRowToId[] = {\n")
  set(idx 0)
  ## Start at -1 so the terminator below becomes 0 when no row matched at all,
  ## instead of feeding an empty operand into math(EXPR).
  set(last_valid_fn_id -1)
  foreach(coll IN LISTS ALL_COLLS)
    foreach(algo IN LISTS ALL_ALGOS)
      foreach(proto IN LISTS ALL_PROTOS)
        foreach(redop IN LISTS ALL_REDOPS)
          foreach(type IN LISTS ALL_TYPES)
            ## Create a list from the combination of curr parameters
            set(ITEM_LIST ${coll} ${algo} ${proto} ${redop} ${type})
            ## Invalid combinations get no row, so idx is not advanced for them
            validate_func("${ITEM_LIST}")
            if(NOT is_valid)
              continue()
            endif()
            ## Try to find the combination in the generated func list
            list(FIND FUNC_LIST "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
            if(NOT fn_id EQUAL -1)
              set(last_valid_fn_id ${fn_id})
              file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id}, // ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}\n")
            else()
              file(APPEND ${HOST_TABLE_FILE} " /*${idx}*/ ${fn_id},\n")
            endif()
            math(EXPR idx "${idx} + 1")
          endforeach()
        endforeach()
      endforeach()
    endforeach()
  endforeach()

  ## Terminating entry: one past the last function id that was emitted
  math(EXPR last_valid_fn_id "${last_valid_fn_id} + 1")
  file(APPEND ${HOST_TABLE_FILE} "${last_valid_fn_id}};\n\n")

  ## Add the host_table file to HIP_SOURCES
  list(APPEND HIP_SOURCES ${HOST_TABLE_FILE})
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
###########################################################################################################
## Function to generate MSCCL Kernels
###########################################################################################################
function(gen_msccl_kernels)
  ## Emits one msccl_kernel_<redop>_<type>.cpp per (reduction operator, data type)
  ## pair and registers each file in HIP_SOURCES of the calling scope.
  ## Only the three operators below are covered; the types come from ALL_TYPES.
  set(msccl_redops Sum Prod MinMax)
  foreach(redop IN LISTS msccl_redops)
    foreach(dtype IN LISTS ALL_TYPES)
      set(kernel_src "${HIPIFY_DIR}/src/device/msccl_kernel_${redop}_${dtype}.cpp")
      message(STATUS "Generating ${kernel_src}")
      ## The content pieces are concatenated by file(WRITE); the last line is
      ## intentionally written without a trailing newline.
      file(WRITE ${kernel_src}
        "#include \"msccl_kernel_impl.h\"\n"
        "#include \"nccl_common.h\"\n"
        "MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${redop}, ${dtype}, false);")
      list(APPEND HIP_SOURCES ${kernel_src})
    endforeach()
  endforeach()
  ## Publish the extended source list to the caller
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
###########################################################################################################
## Function to generate collectives
###########################################################################################################
function(gen_collectives)
  ## Writes one <coll>_<algo>_<proto>_<redop>.cpp per group of COLL_LIST entries
  ## that share those four fields; each file holds a DEFINE_ncclDevFunc line for
  ## every data type of the group. The files are appended to HIP_SOURCES in the
  ## caller's scope.
  ##
  ## Reads : COLL_LIST (entries "<coll>-<algo>-<proto>-<redop>-<type>"),
  ##         HIPIFY_DIR, HIP_SOURCES
  ## Writes: HIP_SOURCES (PARENT_SCOPE)

  ## Pass 1: bucket the entries by everything except the data type.
  ## Each bucket is a list variable named after the output file.
  foreach(entry ${COLL_LIST})
    string(REGEX REPLACE "^([^-]+)-([^-]+)-([^-]+)-([^-]+)-.*$" "\\1_\\2_\\3_\\4" bucket "${entry}")
    list(APPEND ${bucket} ${entry})
    if(NOT bucket IN_LIST bucket_names)
      list(APPEND bucket_names ${bucket})
    endif()
  endforeach()

  ## Pass 2: emit one source file per bucket, in first-seen order.
  foreach(bucket IN LISTS bucket_names)
    ## Collect the DEFINE_ncclDevFunc lines of this bucket
    set(definitions "")
    foreach(entry IN LISTS ${bucket})
      string(REPLACE "-" ";" fields ${entry})
      list(GET fields 0 coll)
      list(GET fields 1 algo)
      list(GET fields 2 proto)
      list(GET fields 3 redop)
      list(GET fields 4 type)
      set(definitions "${definitions}DEFINE_ncclDevFunc(${coll}_${algo}_${proto}_${redop}_${type}, ncclFunc${coll}, Func${redop}, ${type}, NCCL_ALGO_${algo}, NCCL_PROTO_${proto})\n")
    endforeach()

    ## Derive the header name from the collective: lower-case it, then insert
    ## the underscores used by the header files (e.g. AllGather -> all_gather,
    ## ReduceScatter -> reduce_scatter, AllToAllPivot -> alltoall_pivot).
    string(TOLOWER ${coll} header_stem)
    string(REPLACE "scatter" "_scatter" header_stem ${header_stem})
    if("${coll}" STREQUAL "AllToAllPivot")
      string(REPLACE "pivot" "_pivot" header_stem ${header_stem})
    else()
      string(REPLACE "all" "all_" header_stem ${header_stem})
    endif()

    ## LL128 files are wrapped in the gfx90a / ENABLE_LL128 guard
    string(FIND "${bucket}" "LL128" ll128_pos)
    if(ll128_pos EQUAL -1)
      set(guard_open "")
      set(guard_close "")
    else()
      set(guard_open "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
      set(guard_close "#endif\n")
    endif()

    ## Write the file in one go
    set(FILE_PATH "${HIPIFY_DIR}/src/device/${bucket}.cpp")
    message(STATUS "Generating ${FILE_PATH}")
    file(WRITE ${FILE_PATH}
      "#include \"${header_stem}.h\"\n#include \"common.h\"\n\n"
      "${guard_open}"
      "${definitions}"
      "${guard_close}")

    ## Register the file and publish the updated source list to the caller
    list(APPEND HIP_SOURCES ${FILE_PATH})
    set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
  endforeach()
endfunction()
###################################################################################################################
## Function to generate all the functions that are going to be in librccl.so
###################################################################################################################
function(gen_functions CONFIG_INPUT)
  ## Entry point of the generator: expands CONFIG_INPUT into the function lists
  ## and then writes every generated source file.
  ##
  ## CONFIG_INPUT : '|'-separated specifications, each of the form
  ##                "<coll> [<algo> [<proto> [<redop> [<type>]]]]".
  ## Writes       : HIP_SOURCES (PARENT_SCOPE)

  ## One list element per specification, sorted so that it matches ALL_COLLS
  string(REPLACE "|" ";" spec_list ${CONFIG_INPUT})
  list(SORT spec_list)

  foreach(spec ${spec_list})
    ## filter_functions reads the variable named FUNCTION_PARAMS from this scope,
    ## so that name has to stay as it is.
    string(REPLACE " " ";" FUNCTION_PARAMS ${spec})

    ## Trailing fields that were left out mean "all" for that field
    list(LENGTH FUNCTION_PARAMS param_count)
    while(param_count LESS 5)
      list(APPEND FUNCTION_PARAMS "*")
      math(EXPR param_count "${param_count} + 1")
    endwhile()

    ## Filter functions/kernels based on input
    filter_functions(FUNCTION_PARAMS 0)
  endforeach()

  ## Emit the files; the order below determines the order inside HIP_SOURCES
  gen_collectives()       ## Generate collective files
  if(ENABLE_MSCCL_KERNEL)
    gen_msccl_kernels()   ## Generate msccl files (not configurable)
  endif()
  gen_device_table()      ## Generate device_table.cpp
  gen_host_table()        ## Generate host_table.cpp

  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
-35
Vedi File
@@ -1,35 +0,0 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
## Script-mode helper: expects HIP_FILE to be defined by the invoker and rewrites
## that file in place with sed, adding a COLL_UNROLL template argument.
## This only applies to collective header files, i.e. paths that contain
## "/src/device/" and end in ".h"; every other file is left untouched.
## The substitutions run one after another on the same file, and later patterns
## match text produced by earlier ones, so their order must be kept.
## The exit status of the sed invocations is not checked.
if(HIP_FILE MATCHES ".*/src/device/.*\\.h$")
## Append "int COLL_UNROLL" to the two template parameter lists (all occurrences)
execute_process(COMMAND sed -i "s/template<typename T, typename RedOp, typename Proto>/template<typename T, typename RedOp, typename Proto, int COLL_UNROLL>/g" ${HIP_FILE})
execute_process(COMMAND sed -i "s/template<typename T, typename RedOp>/template<typename T, typename RedOp, int COLL_UNROLL>/g" ${HIP_FILE})
## Add COLL_UNROLL to ProtoSimple<1, 1> in both spacing variants (first match per line)
execute_process(COMMAND sed -i "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" ${HIP_FILE})
execute_process(COMMAND sed -i "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" ${HIP_FILE})
## "using Proto = ProtoSimple<...>" whose first argument does not start with '1'
execute_process(COMMAND sed -i "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE})
## Template argument list that starts with "runRing<T"
execute_process(COMMAND sed -i "s/\\(runRing<T[^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE})
## Matches the ProtoSimple<1, 1, COLL_UNROLL> text produced by the substitution above
execute_process(COMMAND sed -i "s/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>>/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>, COLL_UNROLL>/" ${HIP_FILE})
execute_process(COMMAND sed -i "s/\\(runTreeSplit<T[^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE})
## Struct heads of the form "struct RunWorkElement<ncclFunc...>" / "struct RunWork<ncclFunc...>"
execute_process(COMMAND sed -i "s/\\(struct RunWorkElement<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE})
execute_process(COMMAND sed -i "s/\\(struct RunWork<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" ${HIP_FILE})
message(STATUS "Added COLL_UNROLL template argument to ${HIP_FILE}")
endif()
+36
Vedi File
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Usage: add_unroll.sh <hipified-file>
# Adds a COLL_UNROLL template argument to the given file, editing it in place.
HIP_FILE=$1

# Only collective header files below src/device are rewritten. The pattern is
# anchored with '$' so that only paths ENDING in ".h" match; without the anchor
# the unanchored bash regex would also accept names such as "foo.hpp" or
# "foo.h.in".
if [[ "$HIP_FILE" =~ .*/src/device/.*\.h$ ]]; then
  # The substitutions run in sequence on the same file and later patterns match
  # text produced by earlier ones, so their order must be kept.

  # Append "int COLL_UNROLL" to the two template parameter lists (all occurrences)
  sed -i "s/template<typename T, typename RedOp, typename Proto>/template<typename T, typename RedOp, typename Proto, int COLL_UNROLL>/g" "$HIP_FILE"
  sed -i "s/template<typename T, typename RedOp>/template<typename T, typename RedOp, int COLL_UNROLL>/g" "$HIP_FILE"

  # Add COLL_UNROLL to ProtoSimple<1, 1> in both spacing variants
  sed -i "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" "$HIP_FILE"
  sed -i "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" "$HIP_FILE"

  # "using Proto = ProtoSimple<...>" whose first argument does not start with '1'
  sed -i "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"

  # Template argument lists of runRing / runTreeUpDown / runTreeSplit
  sed -i "s/\\(runRing<T[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
  sed -i "s/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>>/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>, COLL_UNROLL>/" "$HIP_FILE"
  sed -i "s/\\(runTreeSplit<T[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"

  # Struct heads "struct RunWorkElement<ncclFunc...>" and "struct RunWork<ncclFunc...>"
  sed -i "s/\\(struct RunWorkElement<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
  sed -i "s/\\(struct RunWork<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"

  echo "Added COLL_UNROLL template argument to $HIP_FILE"
fi
+297 -237
Vedi File
@@ -3,11 +3,13 @@ import os
import sys
# Order of redops, tys, protos, algos must match src/include/device.h
all_colls = ["Broadcast","Reduce","AllGather","ReduceScatter","AllReduce","SendRecv"]
all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16"]
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16", "f8", "bf8"]
all_protos = ["LL","LL128","SIMPLE"]
all_algos = ["TREE","RING","COLLNET_DIRECT","COLLNET_CHAIN","NVLS","NVLS_TREE"]
all_algos = ["TREE","RING"]
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys]
################################################################################
# The first command line argument is the path to the directory to generate and
@@ -20,71 +22,105 @@ if os.path.exists(gensrc):
os.remove(os.path.join(gensrc, name))
#os.truncate(os.path.join(gensrc, name), 0)
else:
os.mkdir(gensrc)
os.makedirs(gensrc)
################################################################################
# The second command line argument is used as a regex to filter the functions
# which make it into libnccl. This is helpful for reducing the binary when
# The command line argument is used as a regex to filter the functions
# which make it into librccl. This is helpful for reducing the binary when
# developing device code. The regex supports non-space containing globs '*',
# parentheses '(x)', and union 'a|b'. The string representing the function has
# one of the forms:
# and union 'a|b'. The string representing the function has the form:
#
# SendRecv
# (AllGather|Broadcast) <algo> <proto>
# (AlLReduce|Reduce|ReduceScatter) <redop> <type> <algo> <proto>
# <coll> <algo> <proto> <redop> <type>
#
# The possible values for redop, type, algo, proto can be found in the all_<foo>
# lists at the top of this file.
#
# Since the Makefile forwards this from the ONLY_FUNCS variable, useful command
# line examples are given:
"""
# Only send/recv:
make ONLY_FUNCS="SendRecv"
# Only non-reductions:
make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
# Only AllReduce sum f32 (but all algos, protos)
make ONLY_FUNCS="AllReduce Sum f32 * *"
# Only AllReduce minmax i32 NVLS (but all protos)
make ONLY_FUNCS="AllReduce MinMax i32 NVLS *"
# AllReduce sum <all floats> RING LL128
make ONLY_FUNCS="AllReduce Sum f32 RING LL128"
"""
# Example use-cases:
#
# # Only send/recv:
# make ONLY_FUNCS="SendRecv"
#
# # Only AllReduce and Reduce
# make ONLY_FUNCS="AllReduce|Reduce"
#
# # Only non-reductions:
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
#
# # Only AllReduce Sum int32_t (but all algos, protos)
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
#
# # Only AllReduce RING Max float (but all protos)
# make ONLY_FUNCS="AllReduce RING * Max float"
#
# # AllReduce TREE LL128 Prod rccl_bfloat16
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
#
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
# Paste all non-None arguments together with `sep`.
def paste(sep, *args):
return sep.join(x for x in args if x is not None)
func_pattern = sys.argv[2:3]
is_ifc = 1 if sys.argv[2] == "ON" else 0
is_colltrace = 1 if sys.argv[3] == "ON" else 0
is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0
func_pattern = sys.argv[5:6]
if func_pattern and func_pattern[0]:
import re
func_pattern = func_pattern[0]
func_pattern = func_pattern.replace("*", "[^ ]*")
func_pattern += "$"
def func_filter(*fn):
return None is not re.match(func_pattern, paste(" ", *fn), flags=re.IGNORECASE)
else:
def func_filter(coll, redop, ty, algo, proto):
return True
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
################################################################################
algos_of_coll = {
"AllGather": ["RING","COLLNET_DIRECT","NVLS"],
"AllGather": ["RING"],
"AllReduce": all_algos,
"AllToAllPivot": ["RING"],
"Broadcast": ["RING"],
"Reduce": ["RING"],
"ReduceScatter": ["RING","COLLNET_DIRECT","NVLS"],
"SendRecv": [None]
"ReduceScatter": ["RING"],
"SendRecv": ["RING"]
}
protos_of_coll = {
"AllGather": all_protos,
"AllReduce": all_protos,
"AllToAllPivot": ["SIMPLE"],
"Broadcast": all_protos,
"Reduce": all_protos,
"ReduceScatter": all_protos,
"SendRecv": ["SIMPLE"]
}
redops_of_coll = {
"AllGather": ["Sum"],
"AllReduce": all_redops,
"AllToAllPivot": ["Sum"],
"Broadcast": ["Sum"],
"Reduce": all_redops,
"ReduceScatter": all_redops,
"SendRecv": ["Sum"]
}
tys_of_coll = {
"AllGather": ["i8"],
"AllReduce": all_tys,
"AllToAllPivot": ["i8"],
"Broadcast": ["i8"],
"Reduce": all_tys,
"ReduceScatter": all_tys,
"SendRecv": ["i8"]
}
coll_camel_to_lower = {
"AllGather": "all_gather",
"AllReduce": "all_reduce",
"AllToAllPivot": "alltoall_pivot",
"Broadcast": "broadcast",
"Reduce": "reduce",
"ReduceScatter": "reduce_scatter",
@@ -94,141 +130,237 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
################################################################################
# Returns pair of minimum required values for (CUDART_VERSION, __CUDA_ARCH__)
# or None if function is never supported. Note that (0, 0) encodes universal
# support.
def required_cuda(coll, redop, ty, algo, proto):
cudart, arch = 0, 0
# kernels mapped to by coll="Nop" functions have coll="Generic"
if coll in ("SendRecv", "Generic", "Nop"): return (cudart, arch)
# Decide whether a (coll, algo, proto, redop, ty) combination is allowed by
# the per-collective tables.
def func_validate(coll, algo, proto, redop, ty):
    """Return True when the combination is permitted, False otherwise.

    A combination is rejected when:
      * `redop` is "SumPostDiv" and the first character of `ty` is neither
        "i" nor "u"; or
      * any of `algo`, `proto`, `redop`, `ty` is missing from the list that
        `algos_of_coll`, `protos_of_coll`, `redops_of_coll` or `tys_of_coll`
        holds for `coll`.
    """
    if redop == "SumPostDiv" and ty[0] not in ("i", "u"):
        return False

    # Checked in this order, stopping at the first value that is not listed.
    checks = (
        (algo, algos_of_coll),
        (proto, protos_of_coll),
        (redop, redops_of_coll),
        (ty, tys_of_coll),
    )
    return all(value in table[coll] for value, table in checks)
if proto!="SIMPLE" and algo not in ("RING","TREE"): return None
# A recursive helper to generate collective functions based on the input given
def func_filter(function_params, current_idx, item_list=None):
if item_list is None:
item_list = []
if coll in ("AllReduce","Reduce","ReduceScatter"):
if redop=="SumPostDiv" and ty[0] not in ("i","u"): return None
if ty=="bf16": cudart = max(cudart, 11000)
# Check if current_idx exceeds the max depth
if current_idx < len(all_params):
# Current element is the config parameter
current_element = function_params[current_idx]
if "NVLS" in algo:
if coll in ("AllReduce","Reduce","ReduceScatter"):
# Must match ncclNvlsSupported() in src/include/device.h
nvls_ok = ((ty in ("i32","u32","i64","u64") and redop in ("Sum","MinMax")) or
(ty in ("f32","f64") and redop=="Sum") or
(ty in ("f16","bf16") and redop in ("Sum","MinMax")))
if not nvls_ok: return None
cudart = max(cudart, 12010)
arch = max(arch, 900)
# If the parameter is equal to '*', include all possible cases for it
if current_element == "*":
if current_idx == 0:
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
# Get the current list from all_params
current_list = all_params[current_idx]
return (cudart, arch)
# Iterate over the items in the current_list
for item in current_list:
# Add item to item_list which will be used in the inner most loop
item_list.append(item)
yield from func_filter(function_params, current_idx+1, item_list)
# For each loop layer remove the last element in item_list
item_list.pop()
else:
# Check if the current element is recognized
elements = current_element.split("/")
current_param = all_params[current_idx]
# Iterate over the elements in the elements list
for item in elements:
if item not in current_param:
raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.")
for item in elements:
item_list.append(item)
yield from func_filter(function_params, current_idx+1, item_list)
# For each loop layer remove the last element in item_list
item_list.pop()
else:
coll, algo, proto, redop, ty = item_list
if func_validate(*item_list):
yield(coll, algo, proto, redop, ty)
# Split the ONLY_FUNCS pattern into its '|'-separated entries and expand each
# one through func_filter.
def parse_input(func_pattern):
    """Yield every function tuple produced by func_filter for `func_pattern`.

    The pattern is split on "|" and the resulting entries are processed in
    sorted order. Each entry is split on whitespace; if it has fewer fields
    than `all_params` has categories, it is padded on the right with "*"
    before being handed to `func_filter`, starting at index 0.
    """
    for entry in sorted(func_pattern.split("|")):
        fields = entry.split()
        # Pad omitted trailing parameters with the wildcard.
        shortfall = len(all_params) - len(fields)
        if shortfall > 0:
            fields.extend(["*"] * shortfall)
        yield from func_filter(fields, 0)
# Maps functions to the chosen representative for the equivalence class it
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
def equivalent_primary(coll, redop, ty, algo, proto):
def equivalent_primary(coll, algo, proto, redop, ty):
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
# map signed integer sum/prod to unsigned
if redop in ("Sum","Prod","PreMulSum") and ty[0]=="i":
return (coll, redop, "u"+ty[1:], algo, proto)
ty = "u"+ty[1:]
# map signed integer min/max to unsigned for non-NVLS
if redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
return (coll, redop, "u"+ty[1:], algo, proto)
return (coll, redop, ty, algo, proto)
# Map to another func representing the best kernel to use. Every distinct value
# returned will instantiate a ncclDevKernel specialized to run this func
# without function call overhead.
def best_kernel(coll, redop, ty, algo, proto):
def best(coll, redop, ty, algo, proto):
# Modify this logic to control how many kernels are specialized.
if coll=="Nop": return ("Generic", None, None, None, None)
if coll=="SendRecv": return ("SendRecv", None, None, None, None)
if coll in ("AllGather","Broadcast"): return (coll, None, None, "RING", "LL")
return (coll, "Sum", ty, ("TREE" if algo=="TREE" else "RING"), "LL")
# Need to ensure the kernel is specialized for a primary function
kfn = equivalent_primary(*best(coll, redop, ty, algo, proto))
# And isn't filtered out.
if not func_filter(*kfn): return ("Generic", None, None, None, None)
return kfn
elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
ty = "u"+ty[1:]
return (coll, algo, proto, redop, ty)
# The order in which rows are enumerated must match the formula in `ncclDevFuncId()`:
def enumerate_func_rows():
yield ("SendRecv", None, None, None, None)
for coll in ("AllGather", "Broadcast"):
algos = algos_of_coll[coll]
for algo in algos:
for coll in all_colls:
for algo in all_algos:
for proto in all_protos:
yield (coll, None, None, algo, proto)
for coll in ("AllReduce", "Reduce", "ReduceScatter"):
algos = algos_of_coll[coll]
for redop in all_redops:
for ty in all_tys:
for algo in algos:
for proto in all_protos:
yield (coll, redop, ty, algo, proto)
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty):
yield (coll, algo, proto, redop, ty)
# Sort key that orders function tuples by <coll> <algo> <proto> <redop> <ty>,
# using each field's position in its master list.
def custom_sort_key(fn):
    """Return a 5-tuple of list positions for the function tuple `fn`.

    `fn` is unpacked as (coll, algo, proto, redop, ty). Each field is
    replaced by its index in `all_colls`, `all_algos`, `all_protos`,
    `all_redops` and `all_tys` respectively.
    """
    coll, algo, proto, redop, ty = fn
    ranked = (
        (all_colls, coll),
        (all_algos, algo),
        (all_protos, proto),
        (all_redops, redop),
        (all_tys, ty),
    )
    return tuple(ordering.index(field) for ordering, field in ranked)
################################################################################
def is_built(coll, redop, ty, algo, proto):
built = required_cuda(coll, redop, ty, algo, proto)
built = built and func_filter(coll, redop, ty, algo, proto)
return built
# Returns None if required_cuda(...) is None.
# Returns the coll="Nop" function if developer has filtered it out.
# Otherwise just returns func it was given.
def validate(coll, redop, ty, algo, proto):
valid = required_cuda(coll, redop, ty, algo, proto)
built = valid and func_filter(coll, redop, ty, algo, proto)
if built: return (coll, redop, ty, algo, proto)
if valid: return ("Nop", None, None, None, None)
return None
# Corresponds to ncclDevFuncRowToId[]
func_rows = [validate(*fn) for fn in enumerate_func_rows()]
func_rows = [fn for fn in enumerate_func_rows()]
# Corresponds to ncclDevFuncTable[]
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in func_rows if fn is not None))
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in parse_input(func_pattern)), key=custom_sort_key)
# primary_to_index[primary_funcs[i]] == i
primary_to_index = {fn: i for (i,fn) in zip(range(len(primary_funcs)), primary_funcs)}
kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs))
primary_to_index = {fn: primary_funcs.index(fn) if fn in primary_funcs else -1 for fn in func_rows}
################################################################################
# Generate <gensrc>/device_table.cu
with open(os.path.join(gensrc, "device_table.cu"), "w") as f:
# Generate <gensrc>/device_table.h
with open(os.path.join(gensrc, "device_table.h"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, "device_table.h"))
out = f.write
out('#include "common.h"\n')
out("\n")
if is_ifc: func_declaration = "__device__ void"
else: func_declaration = "__device__ __attribute__((noinline)) void"
for fn in primary_funcs:
sym = paste("_", "ncclDevFunc", *fn)
cudart, arch = required_cuda(*fn)
if (cudart, arch) != (0, 0):
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
out("__device__ void %s();\n" % sym)
if (cudart, arch) != (0, 0):
out("#endif\n")
if fn[2] == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
out("%s %s();\n%s %s_4();\n#else\n" % (func_declaration, sym, func_declaration, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("%s %s();\n%s %s_4();\n#endif\n" % (func_declaration, sym_ll, func_declaration, sym_ll))
else:
out("%s %s();\n%s %s_4();\n" % (func_declaration, sym, func_declaration, sym))
out("\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n");
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
index = 0
for fn in primary_funcs:
sym = paste("_", "ncclDevFunc", *fn)
cudart, arch = required_cuda(*fn)
if (cudart, arch) != (0, 0):
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart ,arch))
out("/*%4d*/ %s,\n" % (index, sym))
if (cudart, arch) != (0, 0):
out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
if fn[2] == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s,\n#else\n" % (index, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s,\n#endif\n" % (index, sym_ll))
else:
out("/*%4d*/ %s,\n" % (index, sym))
index += 1
out("nullptr};\n")
out("\n")
out("// Workaround for https://reviews.llvm.org/D55580\n"
"__device__ void ncclWorkaroundClangD55580() {}\n")
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
index = 0
for fn in primary_funcs:
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
out("/*%4d*/ %s_4,\n#else\n" % (index, sym))
fn_ll = fn[:2] + ("LL",) + fn[3:]
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
out("/*%4d*/ %s_4,\n#endif\n" % (index, sym_ll))
else:
out("/*%4d*/ %s_4,\n" % (index, sym))
index += 1
out("nullptr};\n")
out("\n")
if not is_ifc:
out("template<unsigned short f, unsigned short l>\n"
"struct Caller {\n"
" static __forceinline__ __device__ __host__\n"
" void call(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller<f, m>::call(funcIndex) : Caller<m, l>::call(funcIndex);\n"
" }\n"
"};\n"
"\n"
"template<unsigned short f>\n"
"struct Caller<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
out(f" Caller<0, {index}>::call(funcIndex);\n")
out("}\n\n")
out("template<unsigned short f, unsigned short l>\n"
"struct Caller4 {\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept\n"
" {\n"
" constexpr unsigned short m = f + (l - f) / 2;\n"
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
" }\n"
"};\n"
"\n"
"template<unsigned short f>\n"
"struct Caller4<f, f + 1>{\n"
" static __forceinline__ __device__ __host__\n"
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
"};\n")
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
out(f" Caller4<0, {index}>::call4(funcIndex);\n")
out("}\n\n")
# Generate <gensrc>/device_table.cpp
if is_colltrace:
with open(os.path.join(gensrc, "device_table.cpp"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, "device_table.cpp"))
out = f.write
out('#include "nccl_common.h"\n#include "device.h"\n')
out("\n")
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
for fn in primary_funcs:
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn))
for ty in all_tys:
out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n')
out("};\n")
# Generate <gensrc>/host_table.cpp
with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, "host_table.cpp"))
# Generate <gensrc>/host_table.cc
with open(os.path.join(gensrc, "host_table.cc"), "w") as f:
out = f.write
out('#include "device.h"\n')
out("\n")
@@ -243,61 +375,14 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f:
comment = " // " + paste(" ", *fn)
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
index += 1
out("-1};\n")
out("\n")
# Forward declarations of kernels.
for kfn in kernel_funcs:
cudart, _ = required_cuda(*kfn)
sym = paste("_", "ncclDevKernel", *kfn)
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
out("__global__ void %s(struct ncclDevComm*, uint64_t, struct ncclWork*);\n" % sym)
if cudart != 0: out("#endif\n")
out("\n")
# List of all kernel function pointers.
out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs))
out("extern void* const ncclDevKernelList[] = {\n")
index = 0
for kfn in kernel_funcs:
cudart, _ = required_cuda(*kfn)
sym = paste("_", "ncclDevKernel", *kfn)
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
out("/*%4d*/ (void*)%s,\n" % (index, sym));
if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
index += 1
out("nullptr};\n")
out("\n")
# Maps primary id to kernel function pointer.
out("extern void* const ncclDevKernelForFunc[] = {\n")
index = 0
for fn in primary_funcs:
kfn = best_kernel(*fn)
sym = paste("_", "ncclDevKernel", *kfn)
cudart, _ = required_cuda(*kfn)
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
out("/*%4d*/ (void*)%s,\n" % (index, sym))
if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
index += 1
out("nullptr};\n")
out("\n")
# Does the prior map use an explicitly specialized kernel.
out("extern bool const ncclDevKernelForFuncIsSpecialized[] = {\n")
index = 0
for fn in primary_funcs:
kfn = best_kernel(*fn)
specialized = "1" if fn == kfn else "0"
out("/*%4d*/ %s,\n" % (index, specialized))
index += 1
out("0};\n")
out(f"{index}")
out("};\n")
# Maps to .cu filename which implements this func. The only constraint is that
# "coll" is reflected in the name: formally that no two funcs having different
# coll's map to the same filename.
def impl_filename(coll, redop, ty, algo, proto):
return "%s.cu" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
def impl_filename(coll, algo, proto, redop, ty):
return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
# Partition the functions and kernels to the .cu filenames. The partition is
# a dictionary mapping filename to (coll, func-tuple list)
@@ -312,33 +397,6 @@ def partition_by_name(fns):
return ans
name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop")
name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Generic")
# Generate <gensrc>/rules.mk
with open(os.path.join(gensrc, "rules.mk"), "w") as f:
out = f.write
impl_names = sorted(name_to_funcs.keys())
names = impl_names + ["host_table.cc", "device_table.cu"]
out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n"
.format(names=" ".join(names)))
out("\n")
# For each <coll>_<op>_<ty>.cu compile to a .cu.o file. Notice the dependencies
# come from the suffix-erased file (e.g. 'gensrc/all_reduce.cu')
for name in impl_names:
coll = name_to_funcs[name][0]
out(
"$(OBJDIR)/genobj/{name}.o: $(OBJDIR)/gensrc $(OBJDIR)/genobj/{lower_coll}.cu.d\n"
"\t" "$(call COMPILE,$@,$(OBJDIR)/gensrc/{name})\n"
"\n"
.format(name=name, lower_coll=coll_camel_to_lower[coll])
)
# Add the suffix-erased .cu's which are used only for dependency scraping.
for coll in set(coll for (coll,_,_,_,_) in primary_funcs if coll!="Nop"):
name = impl_filename(coll, None, None, None, None)
if name not in name_to_funcs:
name_to_funcs[name] = (coll, [])
redop_to_cxx = {
None: "FuncCopy",
@@ -360,13 +418,17 @@ ty_to_cxx = {
"f16": "half",
"f32": "float",
"f64": "double",
"bf16": "__nv_bfloat16"
"bf16": "hip_bfloat16",
"f8": "rccl_float8",
"bf8": "rccl_bfloat8",
}
# Generate each <gensrc>/<impl>.cu:
# Generate each <gensrc>/<impl>.cpp:
for name in name_to_funcs.keys():
(coll, fns) = name_to_funcs[name]
with open(os.path.join(gensrc, name), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, name))
out = f.write
out(
'#include "common.h"\n'
@@ -374,32 +436,30 @@ for name in name_to_funcs.keys():
.format(lower_coll=coll_camel_to_lower[coll])
)
(_, kfns) = name_to_kernels.get(name) or (None, [])
for kfn in kfns:
(coll, redop, ty, algo, proto) = kfn
sym = paste("_", coll, redop, ty, algo, proto)
fn_id = primary_to_index[kfn]
cudart, arch = required_cuda(*kfn)
if (cudart, arch) != (0, 0):
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
out(
"DEFINE_ncclDevKernel({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {fn_id})\n"
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
algo=(algo or "RING"), proto=(proto or "SIMPLE"), fn_id=fn_id)
)
if (cudart, arch) != (0, 0):
out("#endif\n")
for fn in fns:
(coll, redop, ty, algo, proto) = fn
sym = paste("_", coll, redop, ty, algo, proto)
cudart, arch = required_cuda(*fn)
if (cudart, arch) != (0, 0):
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
(coll, algo, proto, redop, ty) = fn
sym = paste("_", coll, algo, proto, redop, ty)
if proto == "LL128":
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
out(
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto})\n"
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
algo=(algo or "RING"), proto=(proto or "SIMPLE"))
)
if (cudart, arch) != (0, 0):
if proto == "LL128":
out("#endif\n")
# Generate each <gensrc>/<msccl_impl>.cpp
if is_msccl_kernels:
for redop in all_redops:
if redop in ("Sum", "Prod", "MinMax"):
for ty in all_tys:
with open(os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp"), "w") as f:
print("-- Generating %s" % os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp"))
out = f.write
out('#include "msccl_kernel_impl.h"\n#include "nccl_common.h"\n')
out(
"MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE({redop}, {ty_cxx}, false);\n"
.format(redop=redop, ty_cxx=ty_to_cxx[ty])
)
+2 -2
Vedi File
@@ -552,7 +552,7 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
// Map the rowIdx to funcIdx
extern int const ncclDevFuncRowToId[];
// `ncclFuncIndex()` needs to be in sync with 'ALL_COLLS' in Generate.cmake
// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) {
int row = 0;
do {
@@ -568,7 +568,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto);
break;
}
row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
// RING / SIMPLE / Sum / int8_t
if (coll == ncclFuncAllToAllPivot) break;
+1 -1
Vedi File
@@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
#define NCCL_NUM_ONERANK 12
#define FUNC_INDEX_TOTAL 980 + NCCL_NUM_ONERANK
#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
typedef enum {