SWDEV-489084 - Avoid using queue colliding with the graph launch stream

Change-Id: I3ecaf8836c8e0883441275139041c702aba0937e


[ROCm/clr commit: 06e6561eb5]
Этот коммит содержится в:
Anusha GodavarthySurya
2024-11-12 13:14:23 +00:00
коммит произвёл Anusha Godavarthy Surya
родитель f421f02546
Коммит c34f55babb
6 изменённых файлов: 24 добавлений и 14 удалений
+13 -8
Просмотреть файл
@@ -413,7 +413,8 @@ hipError_t GraphExec::Init() {
}
status = CreateStreams(parallelLists_.size() - 1 + min_num_streams);
} else {
status = CreateStreams(clonedGraph_->max_streams_);
// create extra stream to avoid queue collision with the default execution stream
status = CreateStreams(clonedGraph_->max_streams_ + 1);
}
if (status != hipSuccess) {
return status;
@@ -638,20 +639,24 @@ hipError_t EnqueueGraphWithSingleList(std::vector<hip::Node>& topoOrder, hip::St
}
// ================================================================================================
// Rebuilds streams_ for a graph launch.
//
// Slot 0 always holds launch_stream. The remaining slots are filled from parallel_streams,
// preferring streams whose queue ID differs from the launch stream's queue ID, so parallel
// branches don't share a queue with the launch stream.
//
// launch_stream    - stream the graph is launched on
// parallel_streams - candidate streams for the parallel branches; one more stream than needed
//                    is expected, so a colliding stream can be skipped
void Graph::UpdateStreams(hip::Stream* launch_stream,
                          const std::vector<hip::Stream*>& parallel_streams) {
  // The array holds the launch stream plus all but one of the parallel streams. Keep at least
  // one slot, so the launch stream can be stored even if no parallel streams were provided
  const size_t num_parallel = parallel_streams.size();
  streams_.resize(num_parallel > 0 ? num_parallel : 1);
  // Current stream is the default in the assignment
  streams_[0] = launch_stream;
  if (streams_.size() == 1) {
    // No parallel slots to fill
    return;
  }
  const uint64_t launch_queue_id = launch_stream->getQueueID();
  size_t slot = 1;
  // First pass: assign only the streams that don't collide with the launch stream
  for (size_t j = 0; (j < num_parallel) && (slot < streams_.size()); ++j) {
    if (parallel_streams[j]->getQueueID() != launch_queue_id) {
      streams_[slot++] = parallel_streams[j];
    }
  }
  // Second pass: runs only if more than one candidate collided. Reuse the colliding streams for
  // the leftover slots instead of walking past the end of parallel_streams
  for (size_t j = 0; (j < num_parallel) && (slot < streams_.size()); ++j) {
    if (parallel_streams[j]->getQueueID() == launch_queue_id) {
      streams_[slot++] = parallel_streams[j];
    }
  }
}
// ================================================================================================
bool Graph::RunOneNode(Node node, bool wait) {
if (node->launch_id_ == -1) {
+4 -5
Просмотреть файл
@@ -31,7 +31,7 @@ namespace hip {
Stream::Stream(hip::Device* dev, Priority p, unsigned int f, bool null_stream,
const std::vector<uint32_t>& cuMask, hipStreamCaptureStatus captureStatus)
: amd::HostQueue(*dev->asContext(), *dev->devices()[0], 0, amd::CommandQueue::RealTimeDisabled,
convertToQueuePriority(p), cuMask),
convertToQueuePriority(p), cuMask),
lock_("Stream Callback lock"),
device_(dev),
priority_(p),
@@ -40,10 +40,9 @@ Stream::Stream(hip::Device* dev, Priority p, unsigned int f, bool null_stream,
cuMask_(cuMask),
captureStatus_(captureStatus),
originStream_(false),
captureID_(0)
{
device_->AddStream(this);
}
captureID_(0) {
device_->AddStream(this);
}
// ================================================================================================
hipError_t Stream::EndCapture() {
+1 -1
Просмотреть файл
@@ -1276,7 +1276,7 @@ class VirtualDevice : public amd::HeapObject {
//! Return the physical device for this virtual device.
const amd::Device& device() const { return device_(); }
//! Return the queue ID of this virtual device (pure virtual: supplied by each derived class).
virtual uint64_t getQueueID() = 0;
virtual void submitReadMemory(amd::ReadMemoryCommand& cmd) = 0;
virtual void submitWriteMemory(amd::WriteMemoryCommand& cmd) = 0;
virtual void submitCopyMemory(amd::CopyMemoryCommand& cmd) = 0;
+1
Просмотреть файл
@@ -311,6 +311,7 @@ class VirtualGPU : public device::VirtualDevice {
amd::CommandQueue::Priority priority = amd::CommandQueue::Priority::Normal);
~VirtualGPU();
//! Reports hwRing_ as the queue ID of this virtual GPU
uint64_t getQueueID() {
  return static_cast<uint64_t>(hwRing_);
}
void submitReadMemory(amd::ReadMemoryCommand& vcmd);
void submitWriteMemory(amd::WriteMemoryCommand& vcmd);
void submitCopyMemory(amd::CopyMemoryCommand& vcmd);
+1
Просмотреть файл
@@ -433,6 +433,7 @@ class VirtualGPU : public device::VirtualDevice {
void setLastUsedSdmaEngine(uint32_t mask) { lastUsedSdmaEngineMask_ = mask; }
uint32_t getLastUsedSdmaEngine() const { return lastUsedSdmaEngineMask_.load(); }
//! Reports the id field of gpu_queue_ as the queue ID of this virtual GPU
uint64_t getQueueID() {
  const uint64_t queue_id = gpu_queue_->id;
  return queue_id;
}
// } roc OpenCL integration
private:
+4
Просмотреть файл
@@ -294,6 +294,10 @@ class HostQueue : public CommandQueue {
//! Get queue status
bool GetQueueStatus() { return isActive_; }
//! Get the queue ID reported by the virtual device attached to this queue's thread
uint64_t getQueueID() {
  auto vdev = thread_.vdev();
  return vdev->getQueueID();
}
private:
Command* head_; //!< Head of the batch list
Command* tail_; //!< Tail of the batch list