rocr: Add logic to track the age of events

Some KFD versions can return from hsaKmtWaitOnMultipleEvents_Ext without
waiting at all; in that case a second call is required, and it must reuse
the event age array from the first call instead of re-initializing it.

Change-Id: I8358c33080084d47c273c2a2827085d0570c8201
此提交包含在:
German Andryeyev
2024-11-20 10:45:43 -05:00
父節點 6f6ee9679c
當前提交 816af44b05
+30 -18
查看文件
@@ -1565,12 +1565,15 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
uint32_t unique_evts = 0;
auto hsa_signals = reinterpret_cast<hsa_signal_handle*>(&async_events_.signal_[0]);
auto processEvent = [&](size_t index, hsa_signal_value_t value) {
auto processEvent = [&](size_t index, hsa_signal_value_t value, bool wait_any) {
// No error or timeout occurred, process the handlers
// Call handler for the known satisfied signal.
assert(async_events_.handler_[index] != nullptr);
bool keep = async_events_.handler_[index](value, async_events_.arg_[index]);
if (!keep) {
if (!wait_any) {
hsa_signals[index]->WaitingDec();
}
hsa_signal_handle(async_events_.signal_[index])->Release();
async_events_.CopyIndex(index, async_events_.Size() - 1);
async_events_.PopBack();
@@ -1590,24 +1593,21 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
};
// Prepares a list of events for a wait inside KFD
auto PrepareInterrupt = [&](size_t idx) {
auto PrepareInterrupt = [&](size_t idx, bool init_age) {
HsaEvent* hsa_event = hsa_signals[idx]->EopEvent();
// If any signal doesn't have an interrupt, then switch to polling
if (hsa_event == nullptr) {
// Remove decrement from all previous events
for (int e = 0; e < idx; ++e) {
hsa_signals[e]->WaitingDec();
}
unique_evts = 0;
return false;
} else {
hsa_signals[idx]->WaitingInc();
if (hsa_events.size() <= unique_evts) {
hsa_events.resize(unique_evts + 10);
event_age.resize(unique_evts + 10);
}
hsa_events[unique_evts] = hsa_event;
event_age[unique_evts] = runtime_singleton_->KfdVersion().supports_event_age ? 1 : 0;
if (init_age) {
event_age[unique_evts] = runtime_singleton_->KfdVersion().supports_event_age ? 1 : 0;
}
unique_evts++;
return true;
}
@@ -1620,9 +1620,6 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
HsaEvent** end = std::unique(&hsa_events[0], &hsa_events[0] + unique_evts);
unique_evts = uint32_t(end - &hsa_events[0]);
hsaKmtWaitOnMultipleEvents_Ext(&hsa_events[0], unique_evts, false, wait_ms, &event_age[0]);
for (size_t i = 0; i < async_events_.Size(); i++) {
hsa_signals[i]->WaitingDec();
}
};
while (!async_events_control_.exit) {
@@ -1661,13 +1658,23 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0);
} else if (index != -1) {
if (wait_any) {
processEvent(index, value);
processEvent(index, value, wait_any);
} else {
index = 0;
}
// Process all signals on the CPU first
bool finish = false;
bool polling = false;
bool init_age = true;
// Mark all signals with a waiting tag
// @note: Waiting tag must be marked before the signal state check on CPU to
// avoid a possible race condition between KFD sleep and rocr's awake call
if (!wait_any) {
for (size_t e = 0; e < async_events_.Size(); e++) {
hsa_signals[e]->WaitingInc();
}
}
while (!finish) {
// If exception or WaitAny(), then finish with just one iteration
if (wait_any) {
@@ -1684,28 +1691,25 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
if (i == 0) {
hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0);
} else {
if (!processEvent(i, value)) {
if (!processEvent(i, value, wait_any)) {
i--;
}
}
if (!wait_any) {
finish = true;
init_age = true;
}
}
// If the current signal isn't complete and polling is disabled, then prepare KFD wait for an interrupt
if (!finish && !polling) {
interrupt_wait = PrepareInterrupt(i);
interrupt_wait = PrepareInterrupt(i, init_age);
// If the interrupt was disabled, then force polling
if (!interrupt_wait) {
polling = true;
finish = false;
}
} else if (unique_evts > 0) {
// Remove the waiting tag from events if we found a complete event
for (int e = 0; e < i; ++e) {
hsa_signals[e]->WaitingDec();
}
unique_evts = 0;
interrupt_wait = false;
}
@@ -1713,10 +1717,18 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
// If nothing was complete and an interrupt wait was requested, then call KFD
if (interrupt_wait) {
WaitForInterrupt();
init_age = false;
}
}
}
if (!wait_any) {
// Remove the waiting tags from events
for (size_t e = 0; e < async_events_.Size(); e++) {
hsa_signals[e]->WaitingDec();
}
}
// Insert new signals and find plain functions
typedef std::pair<void (*)(void*), void*> func_arg_t;
std::vector<func_arg_t> functions;