Fix bug when configuring for only LL128 (#1097)

Este commit está contenido en:
Bertan Dogancay
2024-03-01 18:09:39 -07:00
cometido por GitHub
padre cbd955627e
commit a279e7f32d
+57 -32
Ver fichero
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -65,6 +65,43 @@ set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16")
# --- or ---
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
## Per-collective generation constraints, looked up by validate_func() as "<coll>_Params".
## The four positions line up with entries 1-4 of a function's item list:
##   algorithm, protocol, reduction op, data type.
## "*" places no restriction on that position; any other value is matched against
## the corresponding entry of the item list.
set(AllGather_Params "RING" "*" "Sum" "int8_t")
set(AllReduce_Params "*" "*" "*" "*")
set(AllToAllPivot_Params "RING" "SIMPLE" "Sum" "int8_t")
set(Broadcast_Params "RING" "*" "Sum" "int8_t")
set(Reduce_Params "RING" "*" "*" "*")
set(ReduceScatter_Params "RING" "*" "*" "*")
set(SendRecv_Params "RING" "SIMPLE" "Sum" "int8_t")
#############################################################################################################
## Helper function to check whether the conditions for the collective are met
#############################################################################################################
## validate_func(ITEM_LIST)
##
## Arguments:
##   ITEM_LIST - list of five entries, in order:
##               collective, algorithm, protocol, reduction op, data type.
##
## Result:
##   Sets `is_valid` in the caller's scope to TRUE when the combination passes
##   every check below, FALSE otherwise.
##
## Reads FLOATS_LIST and the list named "<collective>_Params" from the enclosing scope.
function(validate_func ITEM_LIST)
  ## Extract coll/redop/type
  list(GET ITEM_LIST 0 coll)
  list(GET ITEM_LIST 3 redop)
  list(GET ITEM_LIST 4 type)

  ## 'SumPostDiv' combined with a floating-point type is always rejected.
  ## The comparison goes through the variable name instead of `${redop}`: an
  ## unquoted expansion is evaluated again by if(), so an empty value would be a
  ## syntax error and a value that happens to name another variable would be
  ## dereferenced a second time.
  if(redop STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
    set(is_valid FALSE PARENT_SCOPE)
    return()
  endif()

  ## Entry k of "<coll>_Params" constrains entry k+1 of ITEM_LIST, hence the
  ## index starts at 1 (entry 0 is the collective itself).
  set(paramIdx 1)
  foreach(parameter IN LISTS "${coll}_Params")
    ## "*" is a wildcard: nothing to check for this position.
    if(NOT parameter STREQUAL "*")
      list(GET ITEM_LIST "${paramIdx}" item)
      ## NOTE(review): this is a substring test (item contained in parameter),
      ## not an exact comparison, so an item that is a substring of the
      ## constraint also passes -- confirm this is intended.
      string(FIND "${parameter}" "${item}" is_found)
      if(is_found EQUAL -1)
        set(is_valid FALSE PARENT_SCOPE)
        return()
      endif()
    endif()
    math(EXPR paramIdx "${paramIdx} + 1")
  endforeach()

  set(is_valid TRUE PARENT_SCOPE)
endfunction()
#############################################################################################################
## A recursive helper macro to generate functions and kernels based on the input given
#############################################################################################################
@@ -117,16 +154,8 @@ macro(filter_functions FUNCTION_PARAMS current_idx)
list(GET ITEM_LIST 3 REDOP)
list(GET ITEM_LIST 4 TYPE)
## Need to check if these conditions are met prior to file generation
if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING")
continue()
elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t"))
continue()
elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE")
continue()
endif()
if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST)
validate_func("${ITEM_LIST}")
if (NOT is_valid)
continue()
endif()
@@ -149,19 +178,22 @@ function(gen_device_table)
set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h")
message(STATUS "Generating ${DEVICE_TABLE_H_FILE}")
if(ENABLE_IFC)
set(func_declaration "__device__ void")
else()
set(func_declaration "__device__ __attribute__((noinline)) void")
endif()
## Declaration of device functions
foreach(func IN LISTS FUNC_LIST)
string(FIND "${func}" "LL128" IS_LL128)
if(NOT IS_LL128 EQUAL -1)
file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
endif()
if(ENABLE_IFC)
file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ void ${func}();\n")
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#else\n")
string(REPLACE "LL128" "LL" func "${func}")
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#endif\n")
else()
file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ __attribute__((noinline)) void ${func}();\n")
endif()
if(NOT IS_LL128 EQUAL -1)
file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n")
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n")
endif()
endforeach()
file(APPEND ${DEVICE_TABLE_H_FILE} "\n")
@@ -173,11 +205,9 @@ function(gen_device_table)
string(FIND "${func}" "LL128" IS_LL128)
if(NOT IS_LL128 EQUAL -1)
file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
file(APPEND ${DEVICE_TABLE_H_FILE} "#else\n")
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#else\n")
string(REPLACE "LL128" "LL" func "${func}")
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n")
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#endif\n")
else()
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
endif()
@@ -246,18 +276,13 @@ function(gen_host_table)
foreach(proto IN LISTS ALL_PROTOS)
foreach(redop IN LISTS ALL_REDOPS)
foreach(type IN LISTS ALL_TYPES)
if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING")
continue()
elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t"))
continue()
elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE")
## Create a list from the combination of curr parameters
set(ITEM_LIST ${coll} ${algo} ${proto} ${redop} ${type})
validate_func("${ITEM_LIST}")
if (NOT is_valid)
continue()
endif()
if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
continue()
endif()
## Try to find the combination in the generated func list
list(FIND FUNC_LIST "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
if(NOT ${fn_id} EQUAL -1)
set(last_valid_fn_id ${fn_id})