rocr/aie: Increment write pointer upon packet submission
[ROCm/ROCR-Runtime commit: 2d2c47bdef]
This commit is contained in:
committed by
Papadopoulos, Yiannis
parent
13cdca7fb3
commit
96b7e42776
@@ -137,10 +137,8 @@ class AieAqlQueue : public core::Queue,
|
||||
/// @brief Base of the queue's ring buffer storage.
|
||||
void *ring_buf_ = nullptr;
|
||||
|
||||
/// @brief Called when the doorbell is rung to iterate over
|
||||
/// all packets and submit them. Submission is done by
|
||||
/// calling into the XdnaDriver.
|
||||
hsa_status_t SubmitCmd(void* queue_base, uint64_t read_dispatch_id, uint64_t write_dispatch_id);
|
||||
/// @brief Called when the doorbell is rung to submit all queued packets.
|
||||
void SubmitPackets();
|
||||
|
||||
/// @brief Handle for an application context on the AIE device.
|
||||
///
|
||||
|
||||
@@ -53,6 +53,7 @@
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
|
||||
#include "core/inc/amd_xdna_driver.h"
|
||||
@@ -195,32 +196,35 @@ uint64_t AieAqlQueue::AddWriteIndexAcqRel(uint64_t value) {
|
||||
std::memory_order_acq_rel);
|
||||
}
|
||||
|
||||
void AieAqlQueue::StoreRelaxed(hsa_signal_value_t value) {
|
||||
SubmitCmd(amd_queue_.hsa_queue.base_address, amd_queue_.read_dispatch_id,
|
||||
amd_queue_.write_dispatch_id);
|
||||
}
|
||||
void AieAqlQueue::StoreRelaxed(hsa_signal_value_t value) { SubmitPackets(); }
|
||||
|
||||
void AieAqlQueue::SubmitPackets() {
|
||||
if (!active_.load(std::memory_order_relaxed)) {
|
||||
return;
|
||||
}
|
||||
|
||||
hsa_status_t AieAqlQueue::SubmitCmd(void* queue_base, uint64_t read_dispatch_id,
|
||||
uint64_t write_dispatch_id) {
|
||||
auto& driver = static_cast<XdnaDriver&>(agent_.driver());
|
||||
uint64_t cur_id = read_dispatch_id;
|
||||
while (cur_id < write_dispatch_id) {
|
||||
hsa_amd_aie_ert_packet_t* pkt = static_cast<hsa_amd_aie_ert_packet_t*>(queue_base) + cur_id;
|
||||
void* queue_base = amd_queue_.hsa_queue.base_address;
|
||||
|
||||
uint64_t cur_id = LoadReadIndexRelaxed();
|
||||
const uint64_t end = LoadWriteIndexAcquire();
|
||||
while (cur_id < end) {
|
||||
auto* pkt = static_cast<hsa_amd_aie_ert_packet_t*>(queue_base) + cur_id;
|
||||
|
||||
// Get the packet header information
|
||||
if (pkt->header.header != HSA_PACKET_TYPE_VENDOR_SPECIFIC ||
|
||||
pkt->header.AmdFormat != HSA_AMD_PACKET_TYPE_AIE_ERT)
|
||||
return HSA_STATUS_ERROR;
|
||||
pkt->header.AmdFormat != HSA_AMD_PACKET_TYPE_AIE_ERT) {
|
||||
assert(false && "Invalid packet header");
|
||||
}
|
||||
|
||||
// Get the payload information
|
||||
switch (pkt->opcode) {
|
||||
case HSA_AMD_AIE_ERT_START_CU: {
|
||||
// Iterating over future packets and seeing how many contiguous HSA_AMD_AIE_ERT_START_CU
|
||||
// packets there are. All can be combined into a single chain.
|
||||
int num_cont_start_cu_pkts = 1;
|
||||
for (int peak_pkt_id = cur_id + 1; peak_pkt_id < write_dispatch_id; peak_pkt_id++) {
|
||||
hsa_amd_aie_ert_packet_t* peak_pkt =
|
||||
static_cast<hsa_amd_aie_ert_packet_t*>(queue_base) + peak_pkt_id;
|
||||
uint64_t num_cont_start_cu_pkts = 1;
|
||||
for (uint64_t peak_pkt_id = cur_id + 1; peak_pkt_id < end; peak_pkt_id++) {
|
||||
auto* peak_pkt = static_cast<hsa_amd_aie_ert_packet_t*>(queue_base) + peak_pkt_id;
|
||||
if (peak_pkt->opcode != HSA_AMD_AIE_ERT_START_CU) {
|
||||
break;
|
||||
}
|
||||
@@ -231,18 +235,18 @@ hsa_status_t AieAqlQueue::SubmitCmd(void* queue_base, uint64_t read_dispatch_id,
|
||||
// Submitting the command chain might create a new hardware context.
|
||||
hsa_status_t status = driver.SubmitCmdChain(pkt, num_cont_start_cu_pkts, *this);
|
||||
if (status != HSA_STATUS_SUCCESS) {
|
||||
return status;
|
||||
assert(false && "Could not submit packets");
|
||||
}
|
||||
|
||||
cur_id += num_cont_start_cu_pkts;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return HSA_STATUS_ERROR;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return HSA_STATUS_ERROR;
|
||||
atomic::Store(&amd_queue_.read_dispatch_id, cur_id, std::memory_order_release);
|
||||
}
|
||||
|
||||
void AieAqlQueue::StoreRelease(hsa_signal_value_t value) {
|
||||
|
||||
Reference in New Issue
Block a user