[BUILD] Move code generation to python from CMake (#1360)
* Use generate.py for func generation * Convert AddUnroll.cmake to bash
Bu işleme şunda yer alıyor:
işlemeyi yapan:
GitHub
ebeveyn
038517b169
işleme
2dd10c8f17
+30
-14
@@ -55,7 +55,6 @@ set(DEFAULT_GPUS
|
||||
include(CheckIncludeFiles)
|
||||
include(CheckSymbolExists)
|
||||
include(cmake/Dependencies.cmake) # GTest, rocm-cmake, rocm_local_targets
|
||||
include(cmake/Generator.cmake) # Configure functions that goes into RCCL
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
@@ -207,6 +206,16 @@ if(ENABLE_IFC)
|
||||
set(IFC_ENABLED OFF)
|
||||
message(WARNING "Indirect function call disabled - requires HIP version >= 5.5.30201")
|
||||
endif()
|
||||
else()
|
||||
set(IFC_ENABLED OFF)
|
||||
endif()
|
||||
|
||||
## Check for LL128 support
# LL128_ENABLED is assigned on both branches (same shape as the IFC check above)
# so that a value already present in the cache or given on the command line
# cannot survive when the HIP version is too old.
# The version string is quoted so an empty/undefined hip_version_string takes
# the "disabled" branch instead of producing an if() syntax error.
if("${hip_version_string}" VERSION_GREATER_EQUAL "6.1.33591")
  set(LL128_ENABLED ON)
  message(STATUS "RCCL LL128 protocol enabled")
else()
  set(LL128_ENABLED OFF)
  message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591")
endif()
|
||||
|
||||
## Check for hsa-runtime64
|
||||
@@ -574,7 +583,7 @@ foreach(SRC_FILE ${SRC_FILES})
|
||||
OUTPUT ${HIP_FILE}
|
||||
COMMAND mkdir -p ${HIP_FILE_DIR}
|
||||
&& ${hipify-perl_executable} -quiet-warnings ${CMAKE_SOURCE_DIR}/${SRC_FILE} -o ${HIP_FILE}
|
||||
&& ${CMAKE_COMMAND} -DHIP_FILE=${HIP_FILE} -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/AddUnroll.cmake
|
||||
&& ${CMAKE_COMMAND} -E env bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/add_unroll.sh ${HIP_FILE}
|
||||
MAIN_DEPENDENCY ${SRC_FILE}
|
||||
COMMENT "Hipifying ${SRC_FILE} -> ${HIP_FILE}"
|
||||
)
|
||||
@@ -582,13 +591,23 @@ endforeach()
|
||||
|
||||
# Generate device/host tables and all the collective functions that are going to be in librccl.so
|
||||
#==================================================================================================
|
||||
if(ONLY_FUNCS)
|
||||
## Generate only the specified functions
|
||||
gen_functions(${ONLY_FUNCS})
|
||||
else()
|
||||
# Generate all the functions
|
||||
gen_functions("AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv")
|
||||
endif()
|
||||
find_package(PythonInterp REQUIRED)
set(GEN_DIR "${HIPIFY_DIR}/gensrc")

# Execute the python script to generate required files.
# NOTE(review): the arguments are positional and unquoted, so an empty or
# undefined COLLTRACE / ENABLE_MSCCL_KERNEL expands to nothing and shifts the
# following arguments - confirm these are always defined before this point.
execute_process(
  COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS}
  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
  RESULT_VARIABLE result
)

# Stop the configure step if the generator failed; otherwise the build would
# continue with missing or stale generated sources.
if(NOT result EQUAL 0)
  message(FATAL_ERROR "src/device/generate.py failed (result: ${result})")
endif()

# Find the generated files in the output directory
file(GLOB GENERATED_FILES "${GEN_DIR}/*")

# Append all found generated files to the list
list(APPEND HIP_SOURCES ${GENERATED_FILES})
|
||||
|
||||
# Create an initial git_version.cpp file (that will be updated with latest git version)
|
||||
#==================================================================================================
|
||||
@@ -616,6 +635,7 @@ target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src) #
|
||||
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/include)
|
||||
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device)
|
||||
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/src/device/network/unpack)
|
||||
target_include_directories(rccl PRIVATE ${HIPIFY_DIR}/gensrc)
|
||||
target_include_directories(rccl PRIVATE ${HSA_INCLUDE_PATH})
|
||||
target_include_directories(rccl PRIVATE ${ROCM_SMI_INCLUDE_DIR})
|
||||
if(DEMANGLE_DIR)
|
||||
@@ -692,12 +712,8 @@ if(DEMANGLE_DIR)
|
||||
target_compile_definitions(rccl PRIVATE "HAVE_CPLUS_DEMANGLE=1")
|
||||
target_compile_definitions(rccl PRIVATE "HAVE_DECL_BASENAME=1")
|
||||
endif()
|
||||
if(${hip_version_string} VERSION_GREATER_EQUAL "6.1.33591")
|
||||
set(LL128_ENABLED ON)
|
||||
if(LL128_ENABLED)
|
||||
target_compile_definitions(rccl PRIVATE ENABLE_LL128)
|
||||
message(STATUS "RCCL LL128 protocol enabled")
|
||||
else()
|
||||
message(STATUS "RCCL LL128 protocol disabled - requires HIP version >= 6.1.33591")
|
||||
endif()
|
||||
|
||||
## Set RCCL compile options
|
||||
|
||||
@@ -1,479 +0,0 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
## Names of the per-field value lists, in field order:
## <coll> <algo> <proto> <redop> <type>
set(ALL_PARAMS
  ALL_COLLS
  ALL_ALGOS
  ALL_PROTOS
  ALL_REDOPS
  ALL_TYPES
)

## Accepted values for each field
set(ALL_COLLS
  AllGather
  AllReduce
  AllToAllPivot
  Broadcast
  Reduce
  ReduceScatter
  SendRecv
)
set(ALL_ALGOS
  TREE
  RING
  COLLNET_DIRECT
  COLLNET_CHAIN
)
set(ALL_PROTOS
  LL
  LL128
  SIMPLE
)
set(ALL_REDOPS
  Sum
  Prod
  MinMax
  PreMulSum
  SumPostDiv
)
set(ALL_TYPES
  int8_t
  uint8_t
  int32_t
  uint32_t
  int64_t
  uint64_t
  half
  float
  double
  hip_bfloat16
  rccl_float8
  rccl_bfloat8
)

## Subset of ALL_TYPES that validate_func rejects for the 'SumPostDiv' redop
set(FLOATS_LIST
  half
  float
  double
  hip_bfloat16
  rccl_float8
  rccl_bfloat8
)
|
||||
|
||||
################################################################################
|
||||
# The command line argument is a filter selecting the functions
|
||||
# which make it into librccl. This is helpful for reducing the binary when
|
||||
# developing device code. The filter supports the wildcard '*' for a whole field,
|
||||
# per-field alternatives 'a/b', and union 'a|b'. The string representing the function has the form:
|
||||
#
|
||||
# <coll> <algo> <proto> <redop> <type>
|
||||
#
|
||||
# The possible values for coll, algo, proto, redop, type can be found in the ALL_<FOO>
|
||||
# lists at the top of this file.
|
||||
#
|
||||
# Example use-cases:
|
||||
#
|
||||
# # Only send/recv:
|
||||
# make ONLY_FUNCS="SendRecv"
|
||||
#
|
||||
# # Only AllReduce and Reduce
|
||||
# make ONLY_FUNCS="AllReduce|Reduce"
|
||||
#
|
||||
# # Only non-reductions:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum int32_t (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
|
||||
#
|
||||
# # Only AllReduce RING MinMax float (but all protos)
|
||||
# make ONLY_FUNCS="AllReduce RING * MinMax float"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod hip_bfloat16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod hip_bfloat16"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
|
||||
|
||||
## Per-collective constraints consumed by validate_func.
## Entries are positional: <algo> <proto> <redop> <type>; "*" accepts any value.
set(AllGather_Params     RING "*"    Sum int8_t)
set(AllReduce_Params     "*"  "*"    "*" "*")
set(AllToAllPivot_Params RING SIMPLE Sum int8_t)
set(Broadcast_Params     RING "*"    Sum int8_t)
set(Reduce_Params        RING "*"    "*" "*")
set(ReduceScatter_Params RING "*"    "*" "*")
set(SendRecv_Params      RING SIMPLE Sum int8_t)
|
||||
|
||||
#############################################################################################################
|
||||
## Helper function to check if the conditions for the collective is being met
|
||||
#############################################################################################################
|
||||
function(validate_func ITEM_LIST)
  ## Decides whether one <coll> <algo> <proto> <redop> <type> combination may be generated.
  ##
  ## Arguments:
  ##   ITEM_LIST - list of five elements in the order coll;algo;proto;redop;type
  ## Reads:
  ##   FLOATS_LIST    - types rejected for the 'SumPostDiv' redop
  ##   <coll>_Params  - positional constraints (<algo> <proto> <redop> <type>) for the collective
  ## Result:
  ##   is_valid is set to TRUE or FALSE in the caller's scope.
  set(paramIdx 1)
  ## Extract coll/redop/type
  list(GET ITEM_LIST 0 coll)
  list(GET ITEM_LIST 3 redop)
  list(GET ITEM_LIST 4 type)

  ## Reject redop 'SumPostDiv' combined with a floating-point type.
  ## The variable name is passed bare so if() dereferences it exactly once; an
  ## unquoted ${redop} would be dereferenced again whenever a variable named like
  ## the redop (e.g. "Sum") happens to exist.
  if(redop STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
    set(is_valid FALSE PARENT_SCOPE)
    return()
  endif()

  ## Compare algo/proto/redop/type (indices 1..4) against the collective's constraints
  foreach(parameter IN LISTS "${coll}_Params")
    if(NOT parameter STREQUAL "*")
      list(GET ITEM_LIST "${paramIdx}" item)
      ## The item is accepted when it occurs anywhere inside the constraint text.
      ## NOTE(review): this is a substring test, so e.g. "LL" would be accepted by a
      ## constraint of "LL128" - confirm whether exact matching is intended.
      string(FIND "${parameter}" "${item}" is_found)
      if(is_found EQUAL -1)
        set(is_valid FALSE PARENT_SCOPE)
        return()
      endif()
    endif()
    math(EXPR paramIdx "${paramIdx} + 1")
  endforeach()
  set(is_valid TRUE PARENT_SCOPE)
endfunction()
|
||||
|
||||
#############################################################################################################
|
||||
## A recursive helper macro to generate functions and kernels based on the input given
|
||||
#############################################################################################################
|
||||
macro(filter_functions FUNCTION_PARAMS current_idx)
  ## Recursively expands one five-field specification (<coll> <algo> <proto> <redop> <type>)
  ## into every concrete combination, appending the valid ones to COLL_LIST,
  ## FUNC_LIST and KERN_LIST.
  ##
  ## This is a macro, so ${FUNCTION_PARAMS} and ${current_idx} are textual
  ## substitutions and the body runs in the caller's scope:
  ##   - list(GET FUNCTION_PARAMS ...) reads the caller's variable of that exact name,
  ##     so the macro must be invoked as filter_functions(FUNCTION_PARAMS <idx>);
  ##   - ITEM_LIST holds the combination built so far, one element per recursion level;
  ##   - PARENT_SCOPE below refers to the parent of the invoking function.

  ## Check if the current_idx does not exceed the max depth
  if(${current_idx} LESS 5)
    ## current_element is the config parameter
    list(GET FUNCTION_PARAMS ${current_idx} current_element)

    ## An empty field would otherwise be skipped silently (or break if()); report it
    if(current_element STREQUAL "")
      message(FATAL_ERROR "Error: Empty parameter at position ${current_idx}.")
    endif()

    ## If the parameter is equal to '*', include all the possible cases for it.
    ## Variable names are passed bare to if() so that the value is never
    ## dereferenced a second time as another variable name.
    if(current_element STREQUAL "*")
      if(${current_idx} EQUAL 0)
        message(FATAL_ERROR "Error: Parameter 'COLL' can not be type all '*'.")
      endif()
      ## ALL_PARAMS list must be in the same order as FUNCTION_PARAMS ---> <coll> <algo> <proto> <redop> <type>
      ## Find the respective parameter list from ALL_PARAMS list
      list(GET ALL_PARAMS ${current_idx} current_list)

      ## Iterate over the items in the current_list
      foreach(item IN LISTS ${current_list})
        ## Add item to ITEM_LIST which will be used in the inner most loop
        list(APPEND ITEM_LIST ${item})
        math(EXPR new_idx "${current_idx} + 1")
        filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})

        ## For each loop layer remove the last element in ITEM_LIST
        list(REMOVE_AT ITEM_LIST -1)
      endforeach()
    else()
      ## Check if the current element is recognized
      list(GET ALL_PARAMS ${current_idx} current_param)
      ## 'a/b' lists several alternatives for one field
      string(REPLACE "/" ";" elements "${current_element}")
      ## Iterate over the elements and reject unknown values before generating anything
      foreach(item IN LISTS elements)
        list(FIND ${current_param} ${item} is_valid)
        if(is_valid EQUAL -1)
          message(FATAL_ERROR "Error: ${item} is unrecognized or does not belong to this category ${current_param}.")
        endif()
      endforeach()
      foreach(item IN LISTS elements)
        ## Add item to ITEM_LIST which will be used in the inner most loop
        list(APPEND ITEM_LIST ${item})
        math(EXPR new_idx "${current_idx} + 1")
        filter_functions(${FUNCTION_PARAMS} ${new_idx} ${ARGN})

        ## For each loop layer remove the last element in ITEM_LIST
        list(REMOVE_AT ITEM_LIST -1)
      endforeach()
    endif()
  else()
    ## This is the inner most level where the function/kernel names are recorded
    ## Unzip ITEM_LIST
    list(GET ITEM_LIST 0 COLL)
    list(GET ITEM_LIST 1 ALGO)
    list(GET ITEM_LIST 2 PROTO)
    list(GET ITEM_LIST 3 REDOP)
    list(GET ITEM_LIST 4 TYPE)

    validate_func("${ITEM_LIST}")
    ## NOTE(review): continue() is issued from inside a macro and is expected to act
    ## on the foreach() of the enclosing expansion; confirm that the
    ## list(REMOVE_AT ITEM_LIST -1) following the recursive call still runs for
    ## combinations skipped here.
    if (NOT is_valid)
      continue()
    endif()

    list(APPEND COLL_LIST "${COLL}-${ALGO}-${PROTO}-${REDOP}-${TYPE}")
    set(COLL_LIST ${COLL_LIST} PARENT_SCOPE)

    ## Append the newly formed function/kernel to list
    list(APPEND FUNC_LIST "ncclDevFunc_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
    list(APPEND KERN_LIST "ncclDevKernel_${COLL}_${ALGO}_${PROTO}_${REDOP}_${TYPE}")
    set(FUNC_LIST ${FUNC_LIST} PARENT_SCOPE)
    set(KERN_LIST ${KERN_LIST} PARENT_SCOPE)
  endif()
endmacro()
|
||||
|
||||
#####################################################################################################
|
||||
## Function to generate device table
|
||||
#####################################################################################################
|
||||
function(gen_device_table)
  ## Writes ${HIPIFY_DIR}/src/device/device_table.h from FUNC_LIST:
  ##   - a declaration of every device function and of its "_4" variant,
  ##   - the function pointer tables ncclDevFuncTable and ncclDevFuncTable_4,
  ##   - when ENABLE_IFC is off, the Caller/Caller4 templates and the
  ##     NCCL_CALL_FUNCTIONS / NCCL_CALL_FUNCTIONS_4 wrappers.
  ## When COLLTRACE is on, device_table.cpp with the funcNames table is written too.
  ## Every written file is appended to HIP_SOURCES in the caller's scope.
  ##
  ## NOTE(review): both files are only ever appended to; if HIPIFY_DIR is not wiped
  ## between configure runs the content accumulates - confirm a cleanup happens elsewhere.
  set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h")
  message(STATUS "Generating ${DEVICE_TABLE_H_FILE}")

  ## Qualifiers placed in front of every device function declaration
  if(NOT ENABLE_IFC)
    set(decl_prefix "__device__ __attribute__((noinline)) void")
  else()
    set(decl_prefix "__device__ void")
  endif()

  ## Entries whose name contains LL128 are wrapped in this guard and fall back
  ## to the name with "LL128" replaced by "LL" in the #else branch
  set(ll128_guard "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")

  ## Build the declarations and the rows of both tables in a single pass
  set(declarations "")
  set(table_rows "")
  set(table4_rows "")
  foreach(fn IN LISTS FUNC_LIST)
    string(FIND "${fn}" "LL128" ll128_at)
    if(ll128_at EQUAL -1)
      string(APPEND declarations "${decl_prefix} ${fn}();\n${decl_prefix} ${fn}_4();\n")
      string(APPEND table_rows " ${fn},\n")
      string(APPEND table4_rows " ${fn}_4,\n")
    else()
      string(REPLACE "LL128" "LL" ll_fn "${fn}")
      string(APPEND declarations
        "${ll128_guard}"
        "${decl_prefix} ${fn}();\n${decl_prefix} ${fn}_4();\n#else\n"
        "${decl_prefix} ${ll_fn}();\n${decl_prefix} ${ll_fn}_4();\n#endif\n")
      string(APPEND table_rows "${ll128_guard} ${fn},\n#else\n ${ll_fn},\n#endif\n")
      string(APPEND table4_rows "${ll128_guard} ${fn}_4,\n#else\n ${ll_fn}_4,\n#endif\n")
    endif()
  endforeach()

  ## Declarations followed by the indirect function call tables
  file(APPEND ${DEVICE_TABLE_H_FILE}
    "${declarations}"
    "\n"
    "typedef void(*ncclDevFuncPtr_t)();\n\n"
    "__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n"
    "${table_rows}"
    "nullptr};\n\n"
    "__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n"
    "${table4_rows}"
    "nullptr};\n\n"
  )

  if(NOT ENABLE_IFC)
    ## Direct function calls: the generated templates bisect the index range [0, number of functions)
    list(LENGTH FUNC_LIST func_count)
    file(APPEND ${DEVICE_TABLE_H_FILE}
      "template<unsigned short f, unsigned short l>\n"
      "struct Caller {\n"
      " static __forceinline__ __device__ __host__\n"
      " void call(unsigned short funcIndex) noexcept\n"
      " {\n"
      " constexpr unsigned short m = f + (l - f) / 2;\n"
      " return (funcIndex < m) ? Caller<f, m>::call(funcIndex) : Caller<m, l>::call(funcIndex);\n"
      " }\n"
      "};\n"
      "\n"
      "template<unsigned short f>\n"
      "struct Caller<f, f + 1>{\n"
      " static __forceinline__ __device__ __host__\n"
      " void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n"
      "};\n"
      "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n"
      " Caller<0, ${func_count}>::call(funcIndex);\n}\n\n"
      "template<unsigned short f, unsigned short l>\n"
      "struct Caller4 {\n"
      " static __forceinline__ __device__ __host__\n"
      " void call4(unsigned short funcIndex) noexcept\n"
      " {\n"
      " constexpr unsigned short m = f + (l - f) / 2;\n"
      " return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
      " }\n"
      "};\n"
      "\n"
      "template<unsigned short f>\n"
      "struct Caller4<f, f + 1>{\n"
      " static __forceinline__ __device__ __host__\n"
      " void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
      "};\n"
      "__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n"
      " Caller4<0, ${func_count}>::call4(funcIndex);\n}\n\n"
    )
  endif()

  ## Function name table for collective trace
  if(COLLTRACE)
    set(DEVICE_TABLE_FILE "${HIPIFY_DIR}/src/device/device_table.cpp")
    message(STATUS "Generating ${DEVICE_TABLE_FILE}")

    set(name_rows "")
    foreach(fn IN LISTS FUNC_LIST)
      string(APPEND name_rows " \"${fn}\",\n")
    endforeach()
    foreach(elem_type IN LISTS ALL_TYPES)
      string(APPEND name_rows " \"ncclDevFunc_OneRankReduce_PreMulSum_${elem_type}\",\n")
    endforeach()

    file(APPEND ${DEVICE_TABLE_FILE}
      "#include \"nccl_common.h\"\n#include \"device.h\"\n\n const char* funcNames[FUNC_INDEX_TOTAL] = {\n"
      "${name_rows}"
      "};\n"
    )
  endif()

  ## Add the device_table files to HIP_SOURCES
  ## (DEVICE_TABLE_FILE expands to nothing when it was not set above)
  list(APPEND HIP_SOURCES ${DEVICE_TABLE_H_FILE} ${DEVICE_TABLE_FILE})
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
######################################################################################################
|
||||
## Function to generate host-side table
|
||||
######################################################################################################
|
||||
function(gen_host_table)
  ## Writes ${HIPIFY_DIR}/src/device/host_table.cpp containing ncclDevFuncRowToId:
  ## one row per combination of ALL_COLLS x ALL_ALGOS x ALL_PROTOS x ALL_REDOPS x ALL_TYPES
  ## accepted by validate_func. A row holds the index of the function in FUNC_LIST,
  ## or -1 when it was not generated. The file is appended to HIP_SOURCES in the
  ## caller's scope.
  set(HOST_TABLE_FILE "${HIPIFY_DIR}/src/device/host_table.cpp")
  message(STATUS "Generating ${HOST_TABLE_FILE}")

  ## The whole file is assembled in memory and written once at the end
  set(table_text "#include \"device.h\"\n\n")

  ## The mapping from function rows to valid function ids
  string(APPEND table_text "extern int const ncclDevFuncRowToId[] = {\n")
  set(row 0)
  foreach(coll IN LISTS ALL_COLLS)
    foreach(algo IN LISTS ALL_ALGOS)
      foreach(proto IN LISTS ALL_PROTOS)
        foreach(redop IN LISTS ALL_REDOPS)
          foreach(type IN LISTS ALL_TYPES)
            ## Rows exist (and are counted) only for combinations that pass validation
            set(combination ${coll} ${algo} ${proto} ${redop} ${type})
            validate_func("${combination}")
            if(is_valid)
              ## Try to find the combination in the generated func list
              set(fn_name "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}")
              list(FIND FUNC_LIST "${fn_name}" fn_id)
              if(fn_id EQUAL -1)
                string(APPEND table_text " /*${row}*/ ${fn_id},\n")
              else()
                set(last_valid_fn_id ${fn_id})
                string(APPEND table_text " /*${row}*/ ${fn_id}, // ${fn_name}\n")
              endif()
              math(EXPR row "${row} + 1")
            endif()
          endforeach()
        endforeach()
      endforeach()
    endforeach()
  endforeach()

  ## The table is terminated by the id following the last function found above.
  ## NOTE(review): this is the last id encountered in iteration order, not the
  ## largest one, and it is never initialised when no function is found -
  ## confirm that FUNC_LIST is always non-empty and ordered like the loops above.
  math(EXPR last_valid_fn_id "${last_valid_fn_id} + 1")
  string(APPEND table_text "${last_valid_fn_id}};\n\n")

  file(WRITE ${HOST_TABLE_FILE} "${table_text}")

  ## Add the host_table file to HIP_SOURCES
  list(APPEND HIP_SOURCES ${HOST_TABLE_FILE})
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
###########################################################################################################
|
||||
## Function to generate MSCCL Kernels
|
||||
###########################################################################################################
|
||||
function(gen_msccl_kernels)
  ## Writes one msccl_kernel_<redop>_<type>.cpp under ${HIPIFY_DIR}/src/device for
  ## every pairing of the redops Sum/Prod/MinMax with the types in ALL_TYPES and
  ## appends each file to HIP_SOURCES in the caller's scope.
  foreach(redop IN ITEMS Sum Prod MinMax)
    foreach(data_type IN LISTS ALL_TYPES)
      set(kernel_src "${HIPIFY_DIR}/src/device/msccl_kernel_${redop}_${data_type}.cpp")
      message(STATUS "Generating ${kernel_src}")
      ## Each file only instantiates the kernel entry point for its redop/type pair
      file(WRITE ${kernel_src}
        "#include \"msccl_kernel_impl.h\"\n"
        "#include \"nccl_common.h\"\n"
        "MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(${redop}, ${data_type}, false);")
      list(APPEND HIP_SOURCES ${kernel_src})
    endforeach()
  endforeach()
  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
###########################################################################################################
|
||||
## Function to generate collectives
|
||||
###########################################################################################################
|
||||
function(gen_collectives)
  ## Writes one <coll>_<algo>_<proto>_<redop>.cpp under ${HIPIFY_DIR}/src/device for
  ## every group of COLL_LIST entries ("coll-algo-proto-redop-type") sharing the first
  ## four fields. Each file holds one DEFINE_ncclDevFunc line per type and is appended
  ## to HIP_SOURCES in the caller's scope.

  ## Start from empty accumulators: a function inherits the caller's variables, so
  ## lists left uninitialised here could carry foreign entries into the output.
  set(divided_lists "")

  # Group the entries by their coll/algo/proto/redop prefix
  foreach(item IN LISTS COLL_LIST)
    # Split the string into components
    string(REPLACE "-" ";" item_components "${item}")

    # Extract COLL, ALGO, PROTO and REDOP components
    list(GET item_components 0 coll_prefix)
    list(GET item_components 1 algo_prefix)
    list(GET item_components 2 proto_prefix)
    list(GET item_components 3 redop_prefix)

    # The group name doubles as the name of the list variable holding its items
    set(list_name "${coll_prefix}_${algo_prefix}_${proto_prefix}_${redop_prefix}")

    # Register (and reset) the group the first time it is seen
    if(NOT list_name IN_LIST divided_lists)
      list(APPEND divided_lists "${list_name}")
      set(${list_name} "")
    endif()

    # Add the item to the corresponding list
    list(APPEND ${list_name} "${item}")
  endforeach()

  foreach(list_name IN LISTS divided_lists)
    ## One DEFINE_ncclDevFunc line per item of the group
    set(impl_text "")
    foreach(item IN LISTS ${list_name})
      # Convert to a list
      string(REPLACE "-" ";" components "${item}")

      list(GET components 0 coll)
      list(GET components 1 algo)
      list(GET components 2 proto)
      list(GET components 3 redop)
      list(GET components 4 type)

      string(APPEND impl_text "DEFINE_ncclDevFunc(${coll}_${algo}_${proto}_${redop}_${type}, ncclFunc${coll}, Func${redop}, ${type}, NCCL_ALGO_${algo}, NCCL_PROTO_${proto})\n")
    endforeach()

    ## Derive the header name from the collective, e.g. AllReduce -> all_reduce,
    ## ReduceScatter -> reduce_scatter, AllToAllPivot -> alltoall_pivot.
    ## coll still holds the collective of the group: all items of a group share it.
    string(TOLOWER "${coll}" COLL_LOWER)
    string(REPLACE "scatter" "_scatter" COLL_LOWER "${COLL_LOWER}")
    if(coll STREQUAL "AllToAllPivot")
      string(REPLACE "pivot" "_pivot" COLL_LOWER "${COLL_LOWER}")
    else()
      string(REPLACE "all" "all_" COLL_LOWER "${COLL_LOWER}")
    endif()

    ## Set name/path of the file
    set(FILE_PATH "${HIPIFY_DIR}/src/device/${list_name}.cpp")
    message(STATUS "Generating ${FILE_PATH}")

    ## LL128 groups are wrapped in a preprocessor guard
    string(FIND "${list_name}" "LL128" IS_LL128)
    if(IS_LL128 EQUAL -1)
      set(guard_open "")
      set(guard_close "")
    else()
      set(guard_open "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
      set(guard_close "#endif\n")
    endif()

    ## Construct the file
    file(WRITE ${FILE_PATH}
      "#include \"${COLL_LOWER}.h\"\n#include \"common.h\"\n\n"
      "${guard_open}"
      "${impl_text}"
      "${guard_close}"
    )

    ## Append the file to HIP sources list which will be added to source list
    list(APPEND HIP_SOURCES ${FILE_PATH})
  endforeach()

  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
###################################################################################################################
|
||||
## Function to generate all the functions that are going to be in librccl.so
|
||||
###################################################################################################################
|
||||
function(gen_functions CONFIG_INPUT)
  ## Entry point of the generator.
  ##
  ## Arguments:
  ##   CONFIG_INPUT - function filter: specifications separated by '|', each one being
  ##                  up to five space separated fields <coll> <algo> <proto> <redop> <type>;
  ##                  missing trailing fields are treated as '*'.
  ## Result:
  ##   Generates the collective files, the optional MSCCL kernels and the device/host
  ##   tables, and publishes the extended HIP_SOURCES list in the caller's scope.
  if("${CONFIG_INPUT}" STREQUAL "")
    message(FATAL_ERROR "Error: gen_functions requires a non-empty function filter.")
  endif()

  ## Quoted so that the filter is always handed over as a single argument
  string(REPLACE "|" ";" INPUT_LIST "${CONFIG_INPUT}")
  ## Sort the input so that it matches ALL_COLLS
  list(SORT INPUT_LIST)

  foreach(INPUT ${INPUT_LIST})
    # Parse the config string and make it a list
    string(REPLACE " " ";" FUNCTION_PARAMS "${INPUT}")
    # Leading, trailing or repeated spaces produce empty fields; drop them
    list(REMOVE_ITEM FUNCTION_PARAMS "")

    # Get the number of parameters in the input
    list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)

    # Assume all if a parameter is missing
    while(PARAMS_LENGTH LESS 5)
      list(APPEND FUNCTION_PARAMS "*")
      list(LENGTH FUNCTION_PARAMS PARAMS_LENGTH)
    endwhile()

    ## Filter functions/kernels based on input.
    ## filter_functions is a macro reading the variable named FUNCTION_PARAMS from
    ## this scope, so the variable name (not its value) is passed.
    filter_functions(FUNCTION_PARAMS 0)
  endforeach()

  gen_collectives() ## Generate collective files
  if(ENABLE_MSCCL_KERNEL)
    gen_msccl_kernels() ## Generate msccl files (not configurable)
  endif()
  gen_device_table() ## Generate device_table.h (and device_table.cpp for COLLTRACE)
  gen_host_table() ## Generate host_table.cpp

  set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
|
||||
@@ -1,35 +0,0 @@
|
||||
# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
## Adds the COLL_UNROLL template argument to a hipified collective header.
## Expects HIP_FILE to be defined by the invoking command line (-DHIP_FILE=...).
## This only applies to collective header files
if(HIP_FILE MATCHES ".*/src/device/.*\\.h$")
  ## All substitutions work on single lines, so applying them in order within one
  ## sed invocation gives the same result as running one sed process per expression.
  execute_process(
    COMMAND sed -i
      -e "s/template<typename T, typename RedOp, typename Proto>/template<typename T, typename RedOp, typename Proto, int COLL_UNROLL>/g"
      -e "s/template<typename T, typename RedOp>/template<typename T, typename RedOp, int COLL_UNROLL>/g"
      -e "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/"
      -e "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/"
      -e "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/"
      -e "s/\\(runRing<T[^>]*\\)>*/\\1, COLL_UNROLL>/"
      -e "s/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>>/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>, COLL_UNROLL>/"
      -e "s/\\(runTreeSplit<T[^>]*\\)>*/\\1, COLL_UNROLL>/"
      -e "s/\\(struct RunWorkElement<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/"
      -e "s/\\(struct RunWork<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/"
      "${HIP_FILE}"
    RESULT_VARIABLE sed_result
  )

  ## Report a failed rewrite instead of claiming success
  if(NOT sed_result EQUAL 0)
    message(FATAL_ERROR "Failed to add COLL_UNROLL template argument to ${HIP_FILE} (sed result: ${sed_result})")
  endif()

  message(STATUS "Added COLL_UNROLL template argument to ${HIP_FILE}")
endif()
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# Adds the COLL_UNROLL template argument to a hipified collective header.
# Usage: add_unroll.sh <hip-file>
# Files that are not headers below src/device/ are left untouched.

# Abort on the first failing command so a failed sed is not reported as success.
set -eu

HIP_FILE=${1:-}

# Anchored at the end so that only names ending in ".h" match
# (not e.g. ".hpp" or ".h.bak").
collective_header_re='.*/src/device/.*\.h$'

if [[ "$HIP_FILE" =~ $collective_header_re ]]; then
  # All substitutions work on single lines, so applying them in order within one
  # sed invocation gives the same result as one sed pass per expression.
  sed -i \
    -e "s/template<typename T, typename RedOp, typename Proto>/template<typename T, typename RedOp, typename Proto, int COLL_UNROLL>/g" \
    -e "s/template<typename T, typename RedOp>/template<typename T, typename RedOp, int COLL_UNROLL>/g" \
    -e "s/ProtoSimple<1, 1>/ProtoSimple<1, 1, COLL_UNROLL>/" \
    -e "s/ProtoSimple<1,1>/ProtoSimple<1,1,COLL_UNROLL>/" \
    -e "s/\\(using Proto = ProtoSimple<[^1][^>]*\\)>*/\\1, COLL_UNROLL>/" \
    -e "s/\\(runRing<T[^>]*\\)>*/\\1, COLL_UNROLL>/" \
    -e "s/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>>/runTreeUpDown<T, RedOp, ProtoSimple<1, 1, COLL_UNROLL>, COLL_UNROLL>/" \
    -e "s/\\(runTreeSplit<T[^>]*\\)>*/\\1, COLL_UNROLL>/" \
    -e "s/\\(struct RunWorkElement<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" \
    -e "s/\\(struct RunWork<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" \
    "$HIP_FILE"

  echo "Added COLL_UNROLL template argument to $HIP_FILE"
fi
|
||||
+297
-237
@@ -3,11 +3,13 @@ import os
|
||||
import sys
|
||||
|
||||
# Order of redops, tys, protos, algos must match src/include/device.h
|
||||
all_colls = ["Broadcast","Reduce","AllGather","ReduceScatter","AllReduce","SendRecv"]
|
||||
all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16", "f8", "bf8"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
all_algos = ["TREE","RING","COLLNET_DIRECT","COLLNET_CHAIN","NVLS","NVLS_TREE"]
|
||||
all_algos = ["TREE","RING"]
|
||||
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys]
|
||||
|
||||
################################################################################
|
||||
# The first command line argument is the path to the directory to generate and
|
||||
@@ -20,71 +22,105 @@ if os.path.exists(gensrc):
|
||||
os.remove(os.path.join(gensrc, name))
|
||||
#os.truncate(os.path.join(gensrc, name), 0)
|
||||
else:
|
||||
os.mkdir(gensrc)
|
||||
os.makedirs(gensrc)
|
||||
|
||||
################################################################################
|
||||
# The second command line argument is used as a regex to filter the functions
|
||||
# which make it into libnccl. This is helpful for reducing the binary when
|
||||
# The command line argument is used as a regex to filter the functions
|
||||
# which make it into librccl. This is helpful for reducing the binary when
|
||||
# developing device code. The regex supports non-space containing globs '*',
|
||||
# parentheses '(x)', and union 'a|b'. The string representing the function has
|
||||
# one of the forms:
|
||||
# and union 'a|b'. The string representing the function has the form:
|
||||
#
|
||||
# SendRecv
|
||||
# (AllGather|Broadcast) <algo> <proto>
|
||||
# (AlLReduce|Reduce|ReduceScatter) <redop> <type> <algo> <proto>
|
||||
# <coll> <algo> <proto> <redop> <type>
|
||||
#
|
||||
# The possible values for redop, type, algo, proto can be found in the all_<foo>
|
||||
# lists at the top of this file.
|
||||
#
|
||||
# Since the Makefile forwards this from the ONLY_FUNCS variable, useful command
|
||||
# line examples are given:
|
||||
"""
|
||||
# Only send/recv:
|
||||
make ONLY_FUNCS="SendRecv"
|
||||
|
||||
# Only non-reductions:
|
||||
make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
|
||||
# Only AllReduce sum f32 (but all algos, protos)
|
||||
make ONLY_FUNCS="AllReduce Sum f32 * *"
|
||||
|
||||
# Only AllReduce minmax i32 NVLS (but all protos)
|
||||
make ONLY_FUNCS="AllReduce MinMax i32 NVLS *"
|
||||
|
||||
# AllReduce sum <all floats> RING LL128
|
||||
make ONLY_FUNCS="AllReduce Sum f32 RING LL128"
|
||||
"""
|
||||
# Example use-cases:
|
||||
#
|
||||
# # Only send/recv:
|
||||
# make ONLY_FUNCS="SendRecv"
|
||||
#
|
||||
# # Only AllReduce and Reduce
|
||||
# make ONLY_FUNCS="AllReduce|Reduce"
|
||||
#
|
||||
# # Only non-reductions:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum i32 (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum i32"
|
||||
#
|
||||
# # Only AllReduce RING MinMax f32 (but all protos)
|
||||
# make ONLY_FUNCS="AllReduce RING * MinMax f32"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod bf16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL f32 (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8/bf8|AllGather RING LL/SIMPLE Sum i8|AllToAllPivot RING SIMPLE Sum i8|Broadcast RING LL/SIMPLE Sum i8|Reduce RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8/bf8|ReduceScatter RING LL/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8/bf8|SendRecv RING SIMPLE Sum i8"
|
||||
|
||||
# Paste all non-None arguments together with `sep`.
|
||||
def paste(sep, *args):
|
||||
return sep.join(x for x in args if x is not None)
|
||||
|
||||
func_pattern = sys.argv[2:3]
|
||||
is_ifc = 1 if sys.argv[2] == "ON" else 0
|
||||
is_colltrace = 1 if sys.argv[3] == "ON" else 0
|
||||
is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0
|
||||
|
||||
func_pattern = sys.argv[5:6]
|
||||
if func_pattern and func_pattern[0]:
|
||||
import re
|
||||
func_pattern = func_pattern[0]
|
||||
func_pattern = func_pattern.replace("*", "[^ ]*")
|
||||
func_pattern += "$"
|
||||
def func_filter(*fn):
|
||||
return None is not re.match(func_pattern, paste(" ", *fn), flags=re.IGNORECASE)
|
||||
else:
|
||||
def func_filter(coll, redop, ty, algo, proto):
|
||||
return True
|
||||
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
|
||||
################################################################################
|
||||
|
||||
algos_of_coll = {
|
||||
"AllGather": ["RING","COLLNET_DIRECT","NVLS"],
|
||||
"AllGather": ["RING"],
|
||||
"AllReduce": all_algos,
|
||||
"AllToAllPivot": ["RING"],
|
||||
"Broadcast": ["RING"],
|
||||
"Reduce": ["RING"],
|
||||
"ReduceScatter": ["RING","COLLNET_DIRECT","NVLS"],
|
||||
"SendRecv": [None]
|
||||
"ReduceScatter": ["RING"],
|
||||
"SendRecv": ["RING"]
|
||||
}
|
||||
|
||||
protos_of_coll = {
|
||||
"AllGather": all_protos,
|
||||
"AllReduce": all_protos,
|
||||
"AllToAllPivot": ["SIMPLE"],
|
||||
"Broadcast": all_protos,
|
||||
"Reduce": all_protos,
|
||||
"ReduceScatter": all_protos,
|
||||
"SendRecv": ["SIMPLE"]
|
||||
}
|
||||
|
||||
redops_of_coll = {
|
||||
"AllGather": ["Sum"],
|
||||
"AllReduce": all_redops,
|
||||
"AllToAllPivot": ["Sum"],
|
||||
"Broadcast": ["Sum"],
|
||||
"Reduce": all_redops,
|
||||
"ReduceScatter": all_redops,
|
||||
"SendRecv": ["Sum"]
|
||||
}
|
||||
|
||||
tys_of_coll = {
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllToAllPivot": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
"ReduceScatter": all_tys,
|
||||
"SendRecv": ["i8"]
|
||||
}
|
||||
|
||||
coll_camel_to_lower = {
|
||||
"AllGather": "all_gather",
|
||||
"AllReduce": "all_reduce",
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
"ReduceScatter": "reduce_scatter",
|
||||
@@ -94,141 +130,237 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
||||
|
||||
################################################################################
|
||||
|
||||
# Returns pair of minimum required values for (CUDART_VERSION, __CUDA_ARCH__)
|
||||
# or None if function is never supported. Note that (0, 0) encodes universal
|
||||
# support.
|
||||
def required_cuda(coll, redop, ty, algo, proto):
|
||||
cudart, arch = 0, 0
|
||||
# kernels mapped to by coll="Nop" functions have coll="Generic"
|
||||
if coll in ("SendRecv", "Generic", "Nop"): return (cudart, arch)
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty):
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
|
||||
return False
|
||||
return True
|
||||
|
||||
if proto!="SIMPLE" and algo not in ("RING","TREE"): return None
|
||||
# A recursive helper to generate collective functions based on the input given
|
||||
def func_filter(function_params, current_idx, item_list=None):
|
||||
if item_list is None:
|
||||
item_list = []
|
||||
|
||||
if coll in ("AllReduce","Reduce","ReduceScatter"):
|
||||
if redop=="SumPostDiv" and ty[0] not in ("i","u"): return None
|
||||
if ty=="bf16": cudart = max(cudart, 11000)
|
||||
# Check if current_idx exceeds the max depth
|
||||
if current_idx < len(all_params):
|
||||
# Current element is the config parameter
|
||||
current_element = function_params[current_idx]
|
||||
|
||||
if "NVLS" in algo:
|
||||
if coll in ("AllReduce","Reduce","ReduceScatter"):
|
||||
# Must match ncclNvlsSupported() in src/include/device.h
|
||||
nvls_ok = ((ty in ("i32","u32","i64","u64") and redop in ("Sum","MinMax")) or
|
||||
(ty in ("f32","f64") and redop=="Sum") or
|
||||
(ty in ("f16","bf16") and redop in ("Sum","MinMax")))
|
||||
if not nvls_ok: return None
|
||||
cudart = max(cudart, 12010)
|
||||
arch = max(arch, 900)
|
||||
# If the parameter is equal to '*', include all possible cases for it
|
||||
if current_element == "*":
|
||||
if current_idx == 0:
|
||||
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
|
||||
# Get the current list from all_params
|
||||
current_list = all_params[current_idx]
|
||||
|
||||
return (cudart, arch)
|
||||
# Iterate over the items in the current_list
|
||||
for item in current_list:
|
||||
# Add item to item_list which will be used in the inner most loop
|
||||
item_list.append(item)
|
||||
yield from func_filter(function_params, current_idx+1, item_list)
|
||||
|
||||
# For each loop layer remove the last element in item_list
|
||||
item_list.pop()
|
||||
else:
|
||||
# Check if the current element is recognized
|
||||
elements = current_element.split("/")
|
||||
current_param = all_params[current_idx]
|
||||
|
||||
# Iterate over the elements in the elements list
|
||||
for item in elements:
|
||||
if item not in current_param:
|
||||
raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.")
|
||||
|
||||
for item in elements:
|
||||
item_list.append(item)
|
||||
yield from func_filter(function_params, current_idx+1, item_list)
|
||||
|
||||
# For each loop layer remove the last element in item_list
|
||||
item_list.pop()
|
||||
else:
|
||||
coll, algo, proto, redop, ty = item_list
|
||||
|
||||
if func_validate(*item_list):
|
||||
yield(coll, algo, proto, redop, ty)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
def parse_input(func_pattern):
|
||||
input_list = sorted(func_pattern.split("|"))
|
||||
|
||||
for input in input_list:
|
||||
function_params = input.split()
|
||||
params_length = len(function_params)
|
||||
|
||||
# If a parameter is missing, append '*'
|
||||
while params_length < len(all_params):
|
||||
function_params.append("*")
|
||||
params_length += 1
|
||||
|
||||
# Filter functions/kernels based on input
|
||||
yield from func_filter(function_params, 0)
|
||||
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, redop, ty, algo, proto):
|
||||
def equivalent_primary(coll, algo, proto, redop, ty):
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum") and ty[0]=="i":
|
||||
return (coll, redop, "u"+ty[1:], algo, proto)
|
||||
ty = "u"+ty[1:]
|
||||
# map signed integer min/max to unsigned for non-NVLS
|
||||
if redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
|
||||
return (coll, redop, "u"+ty[1:], algo, proto)
|
||||
return (coll, redop, ty, algo, proto)
|
||||
|
||||
# Map to another func representing the best kernel to use. Every distinct value
|
||||
# returned will instantiate a ncclDevKernel specialized to run this func
|
||||
# without function call overhead.
|
||||
def best_kernel(coll, redop, ty, algo, proto):
|
||||
def best(coll, redop, ty, algo, proto):
|
||||
# Modify this logic to control how many kernels are specialized.
|
||||
if coll=="Nop": return ("Generic", None, None, None, None)
|
||||
if coll=="SendRecv": return ("SendRecv", None, None, None, None)
|
||||
if coll in ("AllGather","Broadcast"): return (coll, None, None, "RING", "LL")
|
||||
return (coll, "Sum", ty, ("TREE" if algo=="TREE" else "RING"), "LL")
|
||||
# Need to ensure kernel is specialize for a primary function
|
||||
kfn = equivalent_primary(*best(coll, redop, ty, algo, proto))
|
||||
# And isn't filtered out.
|
||||
if not func_filter(*kfn): return ("Generic", None, None, None, None)
|
||||
return kfn
|
||||
elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
|
||||
ty = "u"+ty[1:]
|
||||
return (coll, algo, proto, redop, ty)
|
||||
|
||||
# Order rows are enumerated must match formula of `ncclDevFuncId()`:
|
||||
def enumerate_func_rows():
|
||||
yield ("SendRecv", None, None, None, None)
|
||||
for coll in ("AllGather", "Broadcast"):
|
||||
algos = algos_of_coll[coll]
|
||||
for algo in algos:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
yield (coll, None, None, algo, proto)
|
||||
for coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
algos = algos_of_coll[coll]
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
for algo in algos:
|
||||
for proto in all_protos:
|
||||
yield (coll, redop, ty, algo, proto)
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
yield (coll, algo, proto, redop, ty)
|
||||
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty = fn
|
||||
|
||||
return (
|
||||
all_colls.index(coll),
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
all_redops.index(redop),
|
||||
all_tys.index(ty)
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
||||
def is_built(coll, redop, ty, algo, proto):
|
||||
built = required_cuda(coll, redop, ty, algo, proto)
|
||||
built = built and func_filter(coll, redop, ty, algo, proto)
|
||||
return built
|
||||
|
||||
# Returns None if required_cuda(...) is None.
|
||||
# Returns the coll="Nop" function if developer has filtered it out.
|
||||
# Otherwise just returns func it was given.
|
||||
def validate(coll, redop, ty, algo, proto):
|
||||
valid = required_cuda(coll, redop, ty, algo, proto)
|
||||
built = valid and func_filter(coll, redop, ty, algo, proto)
|
||||
if built: return (coll, redop, ty, algo, proto)
|
||||
if valid: return ("Nop", None, None, None, None)
|
||||
return None
|
||||
|
||||
# Corresponds to ncclDevFuncRowToId[]
|
||||
func_rows = [validate(*fn) for fn in enumerate_func_rows()]
|
||||
func_rows = [fn for fn in enumerate_func_rows()]
|
||||
|
||||
# Corresponds to ncclDevFuncTable[]
|
||||
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in func_rows if fn is not None))
|
||||
primary_funcs = sorted(set(equivalent_primary(*fn) for fn in parse_input(func_pattern)), key=custom_sort_key)
|
||||
|
||||
# primary_to_index[primary_funcs[i]] == i
|
||||
primary_to_index = {fn: i for (i,fn) in zip(range(len(primary_funcs)), primary_funcs)}
|
||||
|
||||
kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs))
|
||||
primary_to_index = {fn: primary_funcs.index(fn) if fn in primary_funcs else -1 for fn in func_rows}
|
||||
|
||||
################################################################################
|
||||
|
||||
# Generate <gensrc>/device_table.cu
|
||||
with open(os.path.join(gensrc, "device_table.cu"), "w") as f:
|
||||
# Generate <gensrc>/device_table.h
|
||||
with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "device_table.h"))
|
||||
out = f.write
|
||||
out('#include "common.h"\n')
|
||||
out("\n")
|
||||
|
||||
if is_ifc: func_declaration = "__device__ void"
|
||||
else: func_declaration = "__device__ __attribute__((noinline)) void"
|
||||
|
||||
for fn in primary_funcs:
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
cudart, arch = required_cuda(*fn)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
|
||||
out("__device__ void %s();\n" % sym)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#endif\n")
|
||||
if fn[2] == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
out("%s %s();\n%s %s_4();\n#else\n" % (func_declaration, sym, func_declaration, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("%s %s();\n%s %s_4();\n#endif\n" % (func_declaration, sym_ll, func_declaration, sym_ll))
|
||||
else:
|
||||
out("%s %s();\n%s %s_4();\n" % (func_declaration, sym, func_declaration, sym))
|
||||
out("\n")
|
||||
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n");
|
||||
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable[] = {\n")
|
||||
index = 0
|
||||
for fn in primary_funcs:
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
cudart, arch = required_cuda(*fn)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart ,arch))
|
||||
out("/*%4d*/ %s,\n" % (index, sym))
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
|
||||
if fn[2] == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index, sym))
|
||||
index += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
out("// Workaround for https://reviews.llvm.org/D55580\n"
|
||||
"__device__ void ncclWorkaroundClangD55580() {}\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index = 0
|
||||
for fn in primary_funcs:
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s_4,\n#else\n" % (index, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s_4,\n#endif\n" % (index, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s_4,\n" % (index, sym))
|
||||
index += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
if not is_ifc:
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller<f, m>::call(funcIndex) : Caller<m, l>::call(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller<0, {index}>::call(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller4 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller4<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller4<0, {index}>::call4(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Generate <gensrc>/device_table.cpp
|
||||
if is_colltrace:
|
||||
with open(os.path.join(gensrc, "device_table.cpp"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "device_table.cpp"))
|
||||
|
||||
out = f.write
|
||||
out('#include "nccl_common.h"\n#include "device.h"\n')
|
||||
out("\n")
|
||||
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
for fn in primary_funcs:
|
||||
out(' "%s",\n' % paste("_", "ncclDevFunc", *fn))
|
||||
for ty in all_tys:
|
||||
out(f' "ncclDevFunc_OneRankReduce_PreMulSum_{ty}",\n')
|
||||
out("};\n")
|
||||
|
||||
# Generate <gensrc>/host_table.cpp
|
||||
with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "host_table.cpp"))
|
||||
|
||||
# Generate <gensrc>/host_table.cc
|
||||
with open(os.path.join(gensrc, "host_table.cc"), "w") as f:
|
||||
out = f.write
|
||||
out('#include "device.h"\n')
|
||||
out("\n")
|
||||
@@ -243,61 +375,14 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f:
|
||||
comment = " // " + paste(" ", *fn)
|
||||
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
|
||||
index += 1
|
||||
out("-1};\n")
|
||||
out("\n")
|
||||
|
||||
# Forward declarations of kernels.
|
||||
for kfn in kernel_funcs:
|
||||
cudart, _ = required_cuda(*kfn)
|
||||
sym = paste("_", "ncclDevKernel", *kfn)
|
||||
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
|
||||
out("__global__ void %s(struct ncclDevComm*, uint64_t, struct ncclWork*);\n" % sym)
|
||||
if cudart != 0: out("#endif\n")
|
||||
out("\n")
|
||||
|
||||
# List of all kernel function pointers.
|
||||
out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs))
|
||||
out("extern void* const ncclDevKernelList[] = {\n")
|
||||
index = 0
|
||||
for kfn in kernel_funcs:
|
||||
cudart, _ = required_cuda(*kfn)
|
||||
sym = paste("_", "ncclDevKernel", *kfn)
|
||||
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
|
||||
out("/*%4d*/ (void*)%s,\n" % (index, sym));
|
||||
if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
|
||||
index += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
# Maps primary id to kernel function pointer.
|
||||
out("extern void* const ncclDevKernelForFunc[] = {\n")
|
||||
index = 0
|
||||
for fn in primary_funcs:
|
||||
kfn = best_kernel(*fn)
|
||||
sym = paste("_", "ncclDevKernel", *kfn)
|
||||
cudart, _ = required_cuda(*kfn)
|
||||
if cudart != 0: out("#if CUDART_VERSION >= %d\n" % cudart)
|
||||
out("/*%4d*/ (void*)%s,\n" % (index, sym))
|
||||
if cudart != 0: out("#else\n" "/*%4d*/ nullptr,\n" "#endif\n" % index)
|
||||
index += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
# Does the prior map use an explicitly specialized kernel.
|
||||
out("extern bool const ncclDevKernelForFuncIsSpecialized[] = {\n")
|
||||
index = 0
|
||||
for fn in primary_funcs:
|
||||
kfn = best_kernel(*fn)
|
||||
specialized = "1" if fn == kfn else "0"
|
||||
out("/*%4d*/ %s,\n" % (index, specialized))
|
||||
index += 1
|
||||
out("0};\n")
|
||||
out(f"{index}")
|
||||
out("};\n")
|
||||
|
||||
# Maps to .cu filename which implements this func. The only constraint is that
|
||||
# "coll" is reflected in the name: formally that no two funcs having different
|
||||
# coll's map to the same filename.
|
||||
def impl_filename(coll, redop, ty, algo, proto):
|
||||
return "%s.cu" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
|
||||
def impl_filename(coll, algo, proto, redop, ty):
|
||||
return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
|
||||
|
||||
# Partition the functions and kernels to the .cu filenames. The partition is
|
||||
# a dictionary mapping filename to (coll, func-tuple list)
|
||||
@@ -312,33 +397,6 @@ def partition_by_name(fns):
|
||||
return ans
|
||||
|
||||
name_to_funcs = partition_by_name(fn for fn in primary_funcs if fn[0]!="Nop")
|
||||
name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Generic")
|
||||
|
||||
# Generate <gensrc>/rules.mk
|
||||
with open(os.path.join(gensrc, "rules.mk"), "w") as f:
|
||||
out = f.write
|
||||
impl_names = sorted(name_to_funcs.keys())
|
||||
names = impl_names + ["host_table.cc", "device_table.cu"]
|
||||
out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n"
|
||||
.format(names=" ".join(names)))
|
||||
out("\n")
|
||||
|
||||
# For each <coll>_<op>_<ty>.cu compile to a .cu.o file. Notice the dependencies
|
||||
# come from the suffix-erased file (e.g. 'gensrc/all_reduce.cu')
|
||||
for name in impl_names:
|
||||
coll = name_to_funcs[name][0]
|
||||
out(
|
||||
"$(OBJDIR)/genobj/{name}.o: $(OBJDIR)/gensrc $(OBJDIR)/genobj/{lower_coll}.cu.d\n"
|
||||
"\t" "$(call COMPILE,$@,$(OBJDIR)/gensrc/{name})\n"
|
||||
"\n"
|
||||
.format(name=name, lower_coll=coll_camel_to_lower[coll])
|
||||
)
|
||||
|
||||
# Add the suffix-erased .cu's which are used only for dependency scraping.
|
||||
for coll in set(coll for (coll,_,_,_,_) in primary_funcs if coll!="Nop"):
|
||||
name = impl_filename(coll, None, None, None, None)
|
||||
if name not in name_to_funcs:
|
||||
name_to_funcs[name] = (coll, [])
|
||||
|
||||
redop_to_cxx = {
|
||||
None: "FuncCopy",
|
||||
@@ -360,13 +418,17 @@ ty_to_cxx = {
|
||||
"f16": "half",
|
||||
"f32": "float",
|
||||
"f64": "double",
|
||||
"bf16": "__nv_bfloat16"
|
||||
"bf16": "hip_bfloat16",
|
||||
"f8": "rccl_float8",
|
||||
"bf8": "rccl_bfloat8",
|
||||
}
|
||||
|
||||
# Generate each <gensrc>/<impl>.cu:
|
||||
# Generate each <gensrc>/<impl>.cpp:
|
||||
for name in name_to_funcs.keys():
|
||||
(coll, fns) = name_to_funcs[name]
|
||||
with open(os.path.join(gensrc, name), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, name))
|
||||
|
||||
out = f.write
|
||||
out(
|
||||
'#include "common.h"\n'
|
||||
@@ -374,32 +436,30 @@ for name in name_to_funcs.keys():
|
||||
.format(lower_coll=coll_camel_to_lower[coll])
|
||||
)
|
||||
|
||||
(_, kfns) = name_to_kernels.get(name) or (None, [])
|
||||
for kfn in kfns:
|
||||
(coll, redop, ty, algo, proto) = kfn
|
||||
sym = paste("_", coll, redop, ty, algo, proto)
|
||||
fn_id = primary_to_index[kfn]
|
||||
cudart, arch = required_cuda(*kfn)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
|
||||
out(
|
||||
"DEFINE_ncclDevKernel({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {fn_id})\n"
|
||||
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), fn_id=fn_id)
|
||||
)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#endif\n")
|
||||
|
||||
for fn in fns:
|
||||
(coll, redop, ty, algo, proto) = fn
|
||||
sym = paste("_", coll, redop, ty, algo, proto)
|
||||
cudart, arch = required_cuda(*fn)
|
||||
if (cudart, arch) != (0, 0):
|
||||
out("#if CUDART_VERSION >= %d && __CUDA_ARCH__ >= %d\n" % (cudart, arch))
|
||||
(coll, algo, proto, redop, ty) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty)
|
||||
if proto == "LL128":
|
||||
out("#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
out(
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto})\n"
|
||||
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"))
|
||||
)
|
||||
if (cudart, arch) != (0, 0):
|
||||
if proto == "LL128":
|
||||
out("#endif\n")
|
||||
|
||||
# Generate each <gensrc>/<msccl_impl>.cpp
|
||||
if is_msccl_kernels:
|
||||
for redop in all_redops:
|
||||
if redop in ("Sum", "Prod", "MinMax"):
|
||||
for ty in all_tys:
|
||||
with open(os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, f"msccl_kernel_{redop}_{ty}.cpp"))
|
||||
|
||||
out = f.write
|
||||
out('#include "msccl_kernel_impl.h"\n#include "nccl_common.h"\n')
|
||||
out(
|
||||
"MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE({redop}, {ty_cxx}, false);\n"
|
||||
.format(redop=redop, ty_cxx=ty_to_cxx[ty])
|
||||
)
|
||||
@@ -552,7 +552,7 @@ inline bool ncclNvlsSupported(int devRedOp, int type) {
|
||||
// Map the rowIdx to funcIdx
|
||||
extern int const ncclDevFuncRowToId[];
|
||||
|
||||
// `ncclFuncIndex()` needs to be in sync with 'ALL_COLLS' in Generate.cmake
|
||||
// `ncclDevFuncId()` needs to be in sync with 'ALL_COLLS' in generate.py
|
||||
inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto) {
|
||||
int row = 0;
|
||||
do {
|
||||
@@ -568,7 +568,7 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
|
||||
row += (((algo * NCCL_NUM_PROTOCOLS + proto) * ncclNumDevRedOps + devRedOp) * ncclNumTypes + type) - NCCL_NUM_FLOATS * (algo * NCCL_NUM_PROTOCOLS + proto);
|
||||
break;
|
||||
}
|
||||
row += (NCCL_NUM_ALGORITHMS - 2) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
row += (NCCL_NUM_ALGORITHMS - 4) * NCCL_NUM_PROTOCOLS * (ncclNumDevRedOps * ncclNumTypes - NCCL_NUM_FLOATS);
|
||||
|
||||
// RING / SIMPLE / Sum / int8_t
|
||||
if (coll == ncclFuncAllToAllPivot) break;
|
||||
|
||||
@@ -13,7 +13,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
|
||||
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
|
||||
|
||||
#define NCCL_NUM_ONERANK 12
|
||||
#define FUNC_INDEX_TOTAL 980 + NCCL_NUM_ONERANK
|
||||
#define FUNC_INDEX_TOTAL 656 + NCCL_NUM_ONERANK
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
|
||||
typedef enum {
|
||||
|
||||
Yeni konuda referans
Bir kullanıcı engelle