From 816af44b05a00d3063d8a7745865b359a5fba238 Mon Sep 17 00:00:00 2001 From: German Andryeyev Date: Wed, 20 Nov 2024 10:45:43 -0500 Subject: [PATCH] rocr: Add logic to track the age of events Some KFD versions can return from hsaKmtWaitOnMultipleEvents_Ext without any wait and require the second call without age array init. Change-Id: I8358c33080084d47c273c2a2827085d0570c8201 --- runtime/hsa-runtime/core/runtime/runtime.cpp | 48 ++++++++++++-------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/runtime/hsa-runtime/core/runtime/runtime.cpp b/runtime/hsa-runtime/core/runtime/runtime.cpp index 7f5e0b5ed4..83d303471a 100644 --- a/runtime/hsa-runtime/core/runtime/runtime.cpp +++ b/runtime/hsa-runtime/core/runtime/runtime.cpp @@ -1565,12 +1565,15 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { uint32_t unique_evts = 0; auto hsa_signals = reinterpret_cast(&async_events_.signal_[0]); - auto processEvent = [&](size_t index, hsa_signal_value_t value) { + auto processEvent = [&](size_t index, hsa_signal_value_t value, bool wait_any) { // No error or timeout occured, process the handlers // Call handler for the known satisfied signal. assert(async_events_.handler_[index] != nullptr); bool keep = async_events_.handler_[index](value, async_events_.arg_[index]); if (!keep) { + if (!wait_any) { + hsa_signals[index]->WaitingDec(); + } hsa_signal_handle(async_events_.signal_[index])->Release(); async_events_.CopyIndex(index, async_events_.Size() - 1); async_events_.PopBack(); @@ -1590,24 +1593,21 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { }; // Prepares a list of events for a wait inside KFD - auto PrepareInterrupt = [&](size_t idx) { + auto PrepareInterrupt = [&](size_t idx, bool init_age) { HsaEvent* hsa_event = hsa_signals[idx]->EopEvent(); // If any signal doesn't have an interrupt, then switch to polling if (hsa_event == nullptr) { - // Remove decrement from all previous events - for (int e = 0; e < idx; ++e) { - hsa_signals[e]->WaitingDec(); - } unique_evts = 0; return false; } else { - hsa_signals[idx]->WaitingInc(); if (hsa_events.size() <= unique_evts) { hsa_events.resize(unique_evts + 10); event_age.resize(unique_evts + 10); } hsa_events[unique_evts] = hsa_event; - event_age[unique_evts] = runtime_singleton_->KfdVersion().supports_event_age ? 1 : 0; + if (init_age) { + event_age[unique_evts] = runtime_singleton_->KfdVersion().supports_event_age ? 1 : 0; + } unique_evts++; return true; } @@ -1620,9 +1620,6 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { HsaEvent** end = std::unique(&hsa_events[0], &hsa_events[0] + unique_evts); unique_evts = uint32_t(end - &hsa_events[0]); hsaKmtWaitOnMultipleEvents_Ext(&hsa_events[0], unique_evts, false, wait_ms, &event_age[0]); - for (size_t i = 0; i < async_events_.Size(); i++) { - hsa_signals[i]->WaitingDec(); - } }; while (!async_events_control_.exit) { @@ -1661,13 +1658,23 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0); } else if (index != -1) { if (wait_any) { - processEvent(index, value); + processEvent(index, value, wait_any); } else { index = 0; } // Process all signals on the CPU first bool finish = false; bool polling = false; + bool init_age = true; + + // Mark all signals with a waiting tag + // @note: Waiting tag must be marked before the signal state check on CPU to + // avoid a possible race condition between KFD sleep and rocr's awake call + if (!wait_any) { + for (size_t e = 0; e < async_events_.Size(); e++) { + hsa_signals[e]->WaitingInc(); + } + } while (!finish) { // If exception or WaitAny(), then finish with just one iterration if (wait_any) { @@ -1684,28 +1691,25 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { if (i == 0) { hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0); } else { - if (!processEvent(i, value)) { + if (!processEvent(i, value, wait_any)) { i--; } } if (!wait_any) { finish = true; + init_age = true; } } // If the current signal isn't complete and polling is disabled, then prepare KFD wait for an interrupt if (!finish && !polling) { - interrupt_wait = PrepareInterrupt(i); + interrupt_wait = PrepareInterrupt(i, init_age); // If the interrupt was disabled, then force polling if (!interrupt_wait) { polling = true; finish = false; } } else if (unique_evts > 0) { - // Remove the waiting tag from events if we found a complete event - for (int e = 0; e < i; ++e) { - hsa_signals[e]->WaitingDec(); - } unique_evts = 0; interrupt_wait = false; } @@ -1713,10 +1717,18 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) { // If nothing was complete and an interrupt wait was requested, then call KFD if (interrupt_wait) { WaitForInterrupt(); + init_age = false; } } } + if (!wait_any) { + // Remove the waiting tags from events + for (size_t e = 0; e < async_events_.Size(); e++) { + hsa_signals[e]->WaitingDec(); + } + } + // Insert new signals and find plain functions typedef std::pair func_arg_t; std::vector functions;