Fix bug when configuring for only LL128 (#1097)
Este commit está contenido en:
+57
-32
@@ -1,6 +1,6 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
@@ -65,6 +65,43 @@ set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16")
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
|
||||
set(AllGather_Params "RING" "*" "Sum" "int8_t")
|
||||
set(AllReduce_Params "*" "*" "*" "*")
|
||||
set(AllToAllPivot_Params "RING" "SIMPLE" "Sum" "int8_t")
|
||||
set(Broadcast_Params "RING" "*" "Sum" "int8_t")
|
||||
set(Reduce_Params "RING" "*" "*" "*")
|
||||
set(ReduceScatter_Params "RING" "*" "*" "*")
|
||||
set(SendRecv_Params "RING" "SIMPLE" "Sum" "int8_t")
|
||||
|
||||
#############################################################################################################
|
||||
## Helper function to check if the conditions for the collective is being met
|
||||
#############################################################################################################
|
||||
function(validate_func ITEM_LIST)
|
||||
set(paramIdx 1)
|
||||
## Extract coll/redop/type
|
||||
list(GET ITEM_LIST 0 coll)
|
||||
list(GET ITEM_LIST 3 redop)
|
||||
list(GET ITEM_LIST 4 type)
|
||||
|
||||
## First check if redop 'SumPostDiv' has no type float
|
||||
if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
|
||||
set(is_valid FALSE PARENT_SCOPE)
|
||||
return()
|
||||
endif()
|
||||
foreach(parameter IN LISTS "${coll}_Params")
|
||||
if(NOT parameter STREQUAL "*")
|
||||
list(GET ITEM_LIST "${paramIdx}" item)
|
||||
string(FIND "${parameter}" "${item}" is_found)
|
||||
if(is_found EQUAL -1)
|
||||
set(is_valid FALSE PARENT_SCOPE)
|
||||
return()
|
||||
endif()
|
||||
endif()
|
||||
math(EXPR paramIdx "${paramIdx} + 1")
|
||||
endforeach()
|
||||
set(is_valid TRUE PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
#############################################################################################################
|
||||
## A recursive helper macro to generate functions and kernels based on the input given
|
||||
#############################################################################################################
|
||||
@@ -117,16 +154,8 @@ macro(filter_functions FUNCTION_PARAMS current_idx)
|
||||
list(GET ITEM_LIST 3 REDOP)
|
||||
list(GET ITEM_LIST 4 TYPE)
|
||||
|
||||
## Need to check if these conditions are met prior to file generation
|
||||
if(NOT ${COLL} STREQUAL "AllReduce" AND NOT ${ALGO} STREQUAL "RING")
|
||||
continue()
|
||||
elseif((${COLL} STREQUAL "AllGather" OR ${COLL} STREQUAL "Broadcast" OR ${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND (NOT ${REDOP} STREQUAL "Sum" OR NOT ${TYPE} STREQUAL "int8_t"))
|
||||
continue()
|
||||
elseif((${COLL} STREQUAL "SendRecv" OR ${COLL} STREQUAL "AllToAllPivot") AND NOT ${PROTO} STREQUAL "SIMPLE")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
if(${REDOP} STREQUAL "SumPostDiv" AND TYPE IN_LIST FLOATS_LIST)
|
||||
validate_func("${ITEM_LIST}")
|
||||
if (NOT is_valid)
|
||||
continue()
|
||||
endif()
|
||||
|
||||
@@ -149,19 +178,22 @@ function(gen_device_table)
|
||||
set(DEVICE_TABLE_H_FILE "${HIPIFY_DIR}/src/device/device_table.h")
|
||||
message(STATUS "Generating ${DEVICE_TABLE_H_FILE}")
|
||||
|
||||
if(ENABLE_IFC)
|
||||
set(func_declaration "__device__ void")
|
||||
else()
|
||||
set(func_declaration "__device__ __attribute__((noinline)) void")
|
||||
endif()
|
||||
|
||||
## Declaration of device functions
|
||||
foreach(func IN LISTS FUNC_LIST)
|
||||
string(FIND "${func}" "LL128" IS_LL128)
|
||||
if(NOT IS_LL128 EQUAL -1)
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
endif()
|
||||
if(ENABLE_IFC)
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ void ${func}();\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#else\n")
|
||||
string(REPLACE "LL128" "LL" func "${func}")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n#endif\n")
|
||||
else()
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "__device__ __attribute__((noinline)) void ${func}();\n")
|
||||
endif()
|
||||
if(NOT IS_LL128 EQUAL -1)
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "${func_declaration} ${func}();\n")
|
||||
endif()
|
||||
endforeach()
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "\n")
|
||||
@@ -173,11 +205,9 @@ function(gen_device_table)
|
||||
string(FIND "${func}" "LL128" IS_LL128)
|
||||
if(NOT IS_LL128 EQUAL -1)
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "#if defined(__gfx90a__) && defined(ENABLE_LL128)\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "#else\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#else\n")
|
||||
string(REPLACE "LL128" "LL" func "${func}")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} "#endif\n")
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n#endif\n")
|
||||
else()
|
||||
file(APPEND ${DEVICE_TABLE_H_FILE} " ${func},\n")
|
||||
endif()
|
||||
@@ -246,18 +276,13 @@ function(gen_host_table)
|
||||
foreach(proto IN LISTS ALL_PROTOS)
|
||||
foreach(redop IN LISTS ALL_REDOPS)
|
||||
foreach(type IN LISTS ALL_TYPES)
|
||||
if(NOT ${coll} STREQUAL "AllReduce" AND NOT ${algo} STREQUAL "RING")
|
||||
continue()
|
||||
elseif((${coll} STREQUAL "AllGather" OR ${coll} STREQUAL "Broadcast" OR ${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND (NOT ${redop} STREQUAL "Sum" OR NOT ${type} STREQUAL "int8_t"))
|
||||
continue()
|
||||
elseif((${coll} STREQUAL "SendRecv" OR ${coll} STREQUAL "AllToAllPivot") AND NOT ${proto} STREQUAL "SIMPLE")
|
||||
## Create a list from the combination of curr parameters
|
||||
set(ITEM_LIST ${coll} ${algo} ${proto} ${redop} ${type})
|
||||
validate_func("${ITEM_LIST}")
|
||||
if (NOT is_valid)
|
||||
continue()
|
||||
endif()
|
||||
|
||||
if(${redop} STREQUAL "SumPostDiv" AND type IN_LIST FLOATS_LIST)
|
||||
continue()
|
||||
endif()
|
||||
|
||||
## Try to find the combination in the generated func list
|
||||
list(FIND FUNC_LIST "ncclDevFunc_${coll}_${algo}_${proto}_${redop}_${type}" fn_id)
|
||||
if(NOT ${fn_id} EQUAL -1)
|
||||
set(last_valid_fn_id ${fn_id})
|
||||
|
||||
Referencia en una nueva incidencia
Block a user