From e4d027191ca527b40b7972e23b2f50d2a1cb4a2a Mon Sep 17 00:00:00 2001 From: zichguan-amd Date: Tue, 4 Mar 2025 11:38:51 -0500 Subject: [PATCH] rocr: Allow 0/NULL/invalid signal handles for wait operations to be no-op Remove hard assertions for signal validation on hsa_amd_signal_wait_* operations, instead ignore 0/NULL/invalid signals in the dependency condition evaluation to align with HSA specs for barrier-AND and barrier-OR packets. Signed-off-by: zichguan-amd --- .../hsa-runtime/core/runtime/hsa_ext_amd.cpp | 66 ++++++++++++++----- runtime/hsa-runtime/inc/hsa_ext_amd.h | 14 ++-- 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp b/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp index acd492a984..a7acbdcbec 100644 --- a/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp +++ b/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp @@ -578,22 +578,42 @@ uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* hsa_signal TRY; if (!core::Runtime::runtime_singleton_->IsOpen()) { assert(false && "hsa_amd_signal_wait_all called while not initialized."); - return 0; + return uint32_t(0); } - // Do not check for signal invalidation. Invalidation may occur during async - // signal handler loop and is not an error. - for (int i = 0; i < signal_count; ++i) - assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() && - "Invalid signal."); - std::vector satisfying_values_vec; - satisfying_values_vec.resize(signal_count); + // Treat NULL and invalid signals as already satisfied their condition and skip them + std::vector valid_signals; + std::vector valid_signal_ids; + for (uint32_t i = 0; i < signal_count; i++){ + if (hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid()){ + valid_signals.emplace_back(hsa_signals[i]); + valid_signal_ids.emplace_back(i); + } + } + + // Return if there's no valid signal to wait on + if (valid_signals.empty()){ + if (satisfying_values) { + // Set 0 as satisfying value for NULL and invalid signals + std::fill(satisfying_values, satisfying_values + signal_count, 0); + } + return uint32_t(0); + } + + uint32_t valid_signal_count = valid_signals.size(); + + std::vector satisfying_values_vec(valid_signal_count); uint32_t first_satysifying_signal_idx = - core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint, + core::Signal::WaitMultiple(valid_signal_count, valid_signals.data(), conds, values, timeout_hint, wait_hint, satisfying_values_vec, true); if (satisfying_values) { - std::copy(satisfying_values_vec.begin(), satisfying_values_vec.end(), satisfying_values); + // Set 0 as satisfying value for NULL and invalid signals + std::vector satisfying_values_vec_result(signal_count, 0); + for (uint32_t i = 0; i < valid_signal_count; i++){ + satisfying_values_vec_result[valid_signal_ids[i]] = satisfying_values_vec[i]; + } + std::copy(satisfying_values_vec_result.begin(), satisfying_values_vec_result.end(), satisfying_values); } return first_satysifying_signal_idx; @@ -609,16 +629,30 @@ uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signal assert(false && "hsa_amd_signal_wait_any called while not initialized."); return uint32_t(0); } - // Do not check for signal invalidation. Invalidation may occur during async - // signal handler loop and is not an error. - for (uint i = 0; i < signal_count; i++) - assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() && - "Invalid signal."); + + // Ignore NULL and invalid signals + std::vector valid_signals; + std::vector valid_signal_ids; + for (uint32_t i = 0; i < signal_count; i++){ + if (hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid()){ + valid_signals.emplace_back(hsa_signals[i]); + valid_signal_ids.emplace_back(i); + } + } + + // Return if there's no valid signal to wait on + // satisfying_value is ignored + if (valid_signals.empty()){ + return std::numeric_limits::max(); + } std::vector satisfying_value_vec(1); uint32_t satisfying_signal_idx = - core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint, + core::Signal::WaitMultiple(valid_signals.size(), valid_signals.data(), conds, values, timeout_hint, wait_hint, satisfying_value_vec, false); + + // Map back the index + satisfying_signal_idx = valid_signal_ids[satisfying_signal_idx]; if (satisfying_value) *satisfying_value = satisfying_value_vec.at(0); diff --git a/runtime/hsa-runtime/inc/hsa_ext_amd.h b/runtime/hsa-runtime/inc/hsa_ext_amd.h index 772896e80b..222287f108 100644 --- a/runtime/hsa-runtime/inc/hsa_ext_amd.h +++ b/runtime/hsa-runtime/inc/hsa_ext_amd.h @@ -1209,8 +1209,9 @@ hsa_status_t HSA_API * @details Allows waiting for all of several signal and condition pairs to be * satisfied. The function returns 0 if all signals met their conditions and -1 * on a timeout. The value of each signal's satisfying value is returned in - * satisfying_value unless satisfying_value is nullptr. This function provides - * only relaxed memory semantics. + * satisfying_value unless satisfying_value is nullptr. NULL and invalid signals + * are considered to have value 0 and their conditions already satisfied. This + * function provides only relaxed memory semantics. */ uint32_t HSA_API hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* signals, hsa_signal_condition_t* conds, hsa_signal_value_t* values, @@ -1222,9 +1223,12 @@ uint32_t HSA_API hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* si * * @details Allows waiting for any of several signal and conditions pairs to be * satisfied. The function returns the index into the list of signals of the - * first satisfying signal-condition pair. The value of the satisfying signal's - * value is returned in satisfying_value unless satisfying_value is NULL. This - * function provides only relaxed memory semantics. + * first satisfying signal-condition pair. The function returns + * std::numeric_limits::max() if no valid signal is provided. The value + * of the satisfying signal's value is returned in satisfying_value, unless + * satisfying_value is nullptr or there's no valid signal in the signal-condition + * pairs. NULL and invalid signals are ignored. This function provides only + * relaxed memory semantics. */ uint32_t HSA_API hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals,