diff --git a/cmake/Generator.cmake b/cmake/Generator.cmake index fecf166b62..147d92f0aa 100644 --- a/cmake/Generator.cmake +++ b/cmake/Generator.cmake @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -65,6 +65,43 @@ set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16") # --- or --- # make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float" +set(AllGather_Params "RING" "*" "Sum" "int8_t") +set(AllReduce_Params "*" "*" "*" "*") +set(AllToAllPivot_Params "RING" "SIMPLE" "Sum" "int8_t") +set(Broadcast_Params "RING" "*" "Sum" "int8_t") +set(Reduce_Params "RING" "*" "*" "*") +set(ReduceScatter_Params "RING" "*" "*" "*") +set(SendRecv_Params "RING" "SIMPLE" "Sum" "int8_t") + +############################################################################################################# +## Helper function to check if the conditions for the collective is being met +############################################################################################################# +function(validate_func ITEM_LIST) + set(paramIdx 1) + ## Extract coll/redop/type + list(GET ITEM_LIST 0 coll) + list(GET ITEM_LIST 3 redop) + list(GET ITEM_LIST 4 type) + + ## First check if redop 'SumPostDiv' has no type float + if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST) + set(is_valid FALSE PARENT_SCOPE) + return() + endif() + foreach(parameter IN LISTS "${coll}_Params") + if(NOT parameter STREQUAL "*") + list(GET ITEM_LIST "${paramIdx}" item) + string(FIND "${parameter}" "${item}" is_found) + if(is_found EQUAL -1) + set(is_valid FALSE PARENT_SCOPE) + return() + endif() + endif() + math(EXPR paramIdx "${paramIdx} + 1") + endforeach() + set(is_valid TRUE PARENT_SCOPE) +endfunction() + ############################################################################################################# ## A recursive helper macro to generate functions and kernels based on the input given ############################################################################################################# @@ -117,16 +154,8 @@ macro(filter_functions FUNCTION_PARAMS current_idx) list(GET ITEM_LIST 3 REDOP) list(GET ITEM_LIST 4 TYPE) - ## Need to check if these conditions are met prior to file generation - if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING") - continue() - elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t")) - continue() - elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE") - continue() - endif() - - if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST) + validate_func("${ITEM_LIST}") + if (NOT is_valid) continue() endif() @@ -149,19 +178,22 @@ function(gen_device_table) set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h") message(STATUS "Generating ${DEVICE_TABLE_H_FILE}") + if(ENABLE_IFC) + set(func_declaration "__device__ void") + else() + set(func_declaration "__device__ __attribute__((noinline)) void") + endif() + ## Declaration of device functions foreach(func IN LISTS FUNC_LIST) string(FIND "${func}" "LL128" IS_LL128) if(NOT IS_LL128 EQUAL -1) file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - endif() - if(ENABLE_IFC) - file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ void ${func}();\n") + file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#else\n") + string(REPLACE "LL128" "LL" func "${func}") + file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#endif\n") else() - file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ __attribute__((noinline)) void ${func}();\n") - endif() - if(NOT IS_LL128 EQUAL -1) - file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n") + file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n") endif() endforeach() file(APPEND ${DEVICE_TABLE_H_FILE} "\n") @@ -173,11 +205,9 @@ function(gen_device_table) string(FIND "${func}" "LL128" IS_LL128) if(NOT IS_LL128 EQUAL -1) file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n") - file(APPEND ${DEVICE_TABLE_H_FILE} "#else\n") + file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#else\n") string(REPLACE "LL128" "LL" func "${func}") - file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n") - file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n") + file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#endif\n") else() file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n") endif() @@ -246,18 +276,13 @@ function(gen_host_table) foreach(proto IN LISTS ALL_PROTOS) foreach(redop IN LISTS ALL_REDOPS) foreach(type IN LISTS ALL_TYPES) - if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING") - continue() - elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t")) - continue() - elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE") + ## Create a list from the combination of curr parameters + set(ITEM_LIST ${coll} ${algo} ${proto} ${redop} ${type}) + validate_func("${ITEM_LIST}") + if (NOT is_valid) continue() endif() - - if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST) - continue() - endif() - + ## Try to find the combination in the generated func list list(FIND FUNC_LIST "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}" fn_id) if(NOT ${fn_id} EQUAL -1) set(last_valid_fn_id ${fn_id})