Removed extra marker when syncing graph streams back to the launch stream (#2823)
This commit is contained in:
committed by
GitHub
parent
7f5e443e44
commit
a66c6ca156
@@ -1417,6 +1417,11 @@ amd::Command* GraphExec::EnqueueSegmentedGraph(hip::Stream* launch_stream,
|
||||
// Map to track the last enqueued command for each segment for dependency tracking
|
||||
// This is critical for handling cross-level dependencies with stream reuse
|
||||
std::unordered_map<int, amd::Command*> segment_last_command;
|
||||
// Set of segment IDs that have already been explicitly synchronized to the
|
||||
// launch_stream via an earlier cross-stream wait marker. These segments can be
|
||||
// safely excluded from the final "sync all streams to launch_stream" step to
|
||||
// avoid inserting redundant markers.
|
||||
std::unordered_set<int> segments_synced_to_launch;
|
||||
|
||||
// Process segments level by level using the pre-calculated max_dependency_level_
|
||||
for (int level = 0; level <= max_dependency_level_; ++level) {
|
||||
@@ -1453,6 +1458,9 @@ amd::Command* GraphExec::EnqueueSegmentedGraph(hip::Stream* launch_stream,
|
||||
// Retain command before adding to wait list for proper lifetime management
|
||||
cmd_it->second->retain();
|
||||
wait_list.push_back(cmd_it->second);
|
||||
if (current_stream == launch_stream) {
|
||||
segments_synced_to_launch.insert(dep_segment_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1503,21 +1511,24 @@ amd::Command* GraphExec::EnqueueSegmentedGraph(hip::Stream* launch_stream,
|
||||
|
||||
for (const auto& pair : segment_last_command) {
|
||||
int seg_id = pair.first;
|
||||
amd::Command* cmd = pair.second;
|
||||
auto stream_it = segment_to_stream.find(seg_id);
|
||||
if (stream_it != segment_to_stream.end()) {
|
||||
hip::Stream* stream = stream_it->second;
|
||||
int seg_dependency_level = segments_[seg_id].dependency_level;
|
||||
|
||||
// Only update if this segment is at a strictly higher level
|
||||
// Using strict > ensures deterministic behavior when multiple segments
|
||||
// are at the same level on the same stream
|
||||
auto level_it = stream_max_level.find(stream);
|
||||
if (level_it == stream_max_level.end() ||
|
||||
seg_dependency_level > level_it->second) {
|
||||
stream_max_level[stream] = seg_dependency_level;
|
||||
stream_last_command_map[stream] = cmd;
|
||||
}
|
||||
auto stream_it = segment_to_stream.find(seg_id);
|
||||
if (segments_synced_to_launch.find(seg_id) != segments_synced_to_launch.end() ||
|
||||
stream_it == segment_to_stream.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
amd::Command* cmd = pair.second;
|
||||
hip::Stream* stream = stream_it->second;
|
||||
int seg_dependency_level = segments_[seg_id].dependency_level;
|
||||
|
||||
// Only update if this segment is at a strictly higher level
|
||||
// Using strict > ensures deterministic behavior when multiple segments
|
||||
// are at the same level on the same stream
|
||||
auto level_it = stream_max_level.find(stream);
|
||||
if (level_it == stream_max_level.end() || seg_dependency_level > level_it->second) {
|
||||
stream_max_level[stream] = seg_dependency_level;
|
||||
stream_last_command_map[stream] = cmd;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in a new issue
Block a user