rocr/aie: Avoiding XdnaDriver class in queue API

Dieser Commit ist enthalten in:
Yiannis Papadopoulos
2025-03-21 11:46:10 -04:00
committet von Papadopoulos, Yiannis
Ursprung 8dcbbf31c7
Commit f4e1c9b0ba
2 geänderte Dateien mit 25 neuen und 19 gelöschten Zeilen
@@ -3,7 +3,7 @@
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
//
// Developed by:
//
@@ -53,15 +53,13 @@
namespace rocr {
namespace AMD {
class XdnaDriver;
/// @brief Encapsulates HW AIE AQL Command Processor functionality. It
/// provides the interface for things such as doorbells, queue read and
/// write pointers, and a buffer.
class AieAqlQueue : public core::Queue,
private core::LocalSignal,
core::DoorbellSignal {
public:
public:
static __forceinline bool IsType(core::Signal *signal) {
return signal->IsType(&rtti_id());
}
@@ -70,7 +68,6 @@ public:
return queue->IsType(&rtti_id());
}
AieAqlQueue() = delete;
AieAqlQueue(AieAgent *agent, size_t req_size_pkts, uint32_t node_id);
~AieAqlQueue();
@@ -101,13 +98,20 @@ public:
void *value) override;
// AIE-specific API
AieAgent &GetAgent() { return agent_; }
/// @brief Returns the agent associated with this queue.
AieAgent& GetAgent() { return agent_; }
/// @brief Sets the hardware context.
void SetHwCtxHandle(uint32_t hw_ctx_handle) {
hw_ctx_handle_ = hw_ctx_handle;
}
/// @brief Returns the hardware context.
uint32_t GetHwCtxHandle() const { return hw_ctx_handle_; }
// GPU-specific queue functions are unsupported.
hsa_status_t GetCUMasking(uint32_t num_cu_mask_count,
uint32_t *cu_mask) override;
hsa_status_t SetCUMasking(uint32_t num_cu_mask_count,
@@ -117,26 +121,26 @@ public:
hsa_fence_scope_t releaseFence = HSA_FENCE_SCOPE_NONE,
hsa_signal_t *signal = NULL) override;
private:
HSA_QUEUEID queue_id_ = INVALID_QUEUEID;
/// @brief ID of AIE device on which this queue has been mapped.
uint32_t node_id_ = std::numeric_limits<uint32_t>::max();
/// @brief Queue size in bytes.
uint32_t queue_size_bytes_ = std::numeric_limits<uint32_t>::max();
protected:
protected:
bool _IsA(Queue::rtti_t id) const override { return id == &rtti_id(); }
private:
private:
AieAgent &agent_;
/// @brief Base of the queue's ring buffer storage.
void *ring_buf_ = nullptr;
/// @brief Called when the doorbell is rung to iterate over
/// all packets and submit them. Submissions is done by
// calling into the XdnaDriver.
hsa_status_t SubmitCmd(XdnaDriver& driver, void* queue_base, uint64_t read_dispatch_id,
uint64_t write_dispatch_id);
/// all packets and submit them. Submission is done by
/// calling into the XdnaDriver.
hsa_status_t SubmitCmd(void* queue_base, uint64_t read_dispatch_id, uint64_t write_dispatch_id);
/// @brief Handle for an application context on the AIE device.
///
@@ -3,7 +3,7 @@
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
//
// Developed by:
//
@@ -196,13 +196,13 @@ uint64_t AieAqlQueue::AddWriteIndexAcqRel(uint64_t value) {
}
void AieAqlQueue::StoreRelaxed(hsa_signal_value_t value) {
auto& driver = static_cast<XdnaDriver&>(agent_.driver());
SubmitCmd(driver, amd_queue_.hsa_queue.base_address, amd_queue_.read_dispatch_id,
SubmitCmd(amd_queue_.hsa_queue.base_address, amd_queue_.read_dispatch_id,
amd_queue_.write_dispatch_id);
}
hsa_status_t AieAqlQueue::SubmitCmd(XdnaDriver& driver, void* queue_base, uint64_t read_dispatch_id,
hsa_status_t AieAqlQueue::SubmitCmd(void* queue_base, uint64_t read_dispatch_id,
uint64_t write_dispatch_id) {
auto& driver = static_cast<XdnaDriver&>(agent_.driver());
uint64_t cur_id = read_dispatch_id;
while (cur_id < write_dispatch_id) {
hsa_amd_aie_ert_packet_t* pkt = static_cast<hsa_amd_aie_ert_packet_t*>(queue_base) + cur_id;
@@ -230,9 +230,11 @@ hsa_status_t AieAqlQueue::SubmitCmd(XdnaDriver& driver, void* queue_base, uint64
}
// Call into the driver to submit from cur_id to write_dispatch_id
if (driver.SubmitCmdChain(pkt, num_cont_start_cu_pkts, num_operands, hw_ctx_handle_) !=
HSA_STATUS_SUCCESS)
return HSA_STATUS_ERROR;
hsa_status_t status =
driver.SubmitCmdChain(pkt, num_cont_start_cu_pkts, num_operands, hw_ctx_handle_);
if (status != HSA_STATUS_SUCCESS) {
return status;
}
cur_id += num_cont_start_cu_pkts;
break;