Files
rocm-systems/projects/rocprofiler-sdk/source/lib/python/rocpd/time_window.py
T

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

360 sor
12 KiB
Python

2025-06-06 04:03:14 -05:00
#!/usr/bin/env python3
###############################################################################
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
###############################################################################
import argparse
import sqlite3
from argparse import ArgumentParser
from typing import Optional, Tuple, Dict, Any
2025-06-06 04:03:14 -05:00
from .importer import RocpdImportData, execute_statement
__all__ = ["apply_time_window", "execute", "add_args", "main"]
2025-06-06 04:03:14 -05:00
def get_marker_timestamp(
connection: sqlite3.Connection, marker_name: str, marker_type: str = "start"
) -> float:
"""Get timestamp for a specific marker."""
query = "SELECT start FROM markers WHERE name = ?"
result = connection.execute(query, (marker_name,)).fetchall()
if not result:
raise ValueError(
f'ERROR: {marker_type.capitalize()} marker "{marker_name}" not found'
)
if len(result) > 1:
raise ValueError(
f'ERROR: Ambiguous reference - multiple {marker_type} markers found with name "{marker_name}"'
)
return float(result[0][0])
def markers2timestamp(
connection: sqlite3.Connection, start_marker: str, end_marker: str
) -> Tuple[float, float]:
"""Convert marker names to timestamp values."""
start_time = get_marker_timestamp(connection, start_marker, "start")
end_time = get_marker_timestamp(connection, end_marker, "end")
return (start_time, end_time)
def get_min_max_time(connection):
min_max_query = """
SELECT
MIN(min_time) as min_time,
MAX(max_time) as max_time
FROM (
SELECT start as min_time, end as max_time FROM regions_and_samples
UNION ALL
SELECT start as min_time, end as max_time FROM rocpd_kernel_dispatch
UNION ALL
SELECT start as min_time, end as max_time FROM rocpd_memory_allocate
UNION ALL
SELECT start as min_time, end as max_time FROM rocpd_memory_copy
)"""
min_time, max_time = execute_statement(connection, min_max_query).fetchone()
return (min_time, max_time)
def percentages2timestamp(
connection: sqlite3.Connection, start_time: Optional[str], end_time: Optional[str]
) -> Tuple[float, float]:
"""Convert percentage strings or time values to timestamps."""
min_time, max_time = get_min_max_time(connection)
if min_time is None:
raise ValueError(
"ERROR: Cannot create time window - trace file contains no timing data"
)
def convert_time(time_str: Optional[str], is_start: bool = False) -> float:
if not time_str:
return min_time if is_start else max_time
if "%" in time_str:
percentage = float(time_str.replace("%", "")) / 100.0
if not 0 <= percentage <= 1:
raise ValueError(
f"ERROR: Invalid percentage '{time_str}' - must be between '0%' and '100%'"
)
return min_time + ((max_time - min_time) * percentage)
try:
return float(time_str)
except ValueError:
raise ValueError(
f"ERROR: Invalid time value '{time_str}' - must be percentage (e.g., '50%') or a number (nanoseconds since epoch) "
)
return (convert_time(start_time, True), convert_time(end_time, False))
def get_time_filter(inclusive: bool, start_time, end_time) -> str:
"""Create SQL filter for start/end time ranges."""
_beg = int(start_time)
_end = int(end_time)
if inclusive:
return f"start >= {_beg} AND end <= {_end}"
else:
return f"start <= {_end} AND end >= {_beg}"
def get_timestamp_filter(inclusive: bool, start_time, end_time) -> str:
"""Create SQL filter for timestamp columns."""
_beg = int(start_time)
_end = int(end_time)
if inclusive:
return f"timestamp >= {_beg} AND timestamp <= {_end}"
else:
return f"timestamp <= {_end} AND timestamp >= {_beg}"
def create_view(connection: sqlite3.Connection, view_name: str, query: str) -> None:
"""Create or replace a database view."""
execute_statement(connection, f"DROP VIEW IF EXISTS {view_name}")
# print(f"{query}")
execute_statement(connection, query)
connection.commit()
#
# Main processing functions
#
def is_using_markers(args: Dict[str, Any]) -> bool:
"""Check if filtering mode uses markers or time ranges."""
# Add improved null checks
if args.get("start") is not None or args.get("end") is not None:
return False
elif args.get("start_marker") is not None or args.get("end_marker") is not None:
return True
return None
def get_column_names(conn: RocpdImportData, table_name: str):
"""
Use SELECT on zero rows and read cursor.description.
"""
cursor = conn.execute(f"SELECT * FROM '{table_name}' LIMIT 0")
return [desc[0] for desc in cursor.description]
def apply_time_window(connection: RocpdImportData, **kwargs: Any) -> None:
"""Apply time window filtering to create filtered views."""
is_marker_mode = is_using_markers(kwargs)
if is_marker_mode is None:
return connection
inclusive = kwargs.get("inclusive", True)
def dump_min_max(label):
bounds_min, bounds_max = get_min_max_time(connection)
# bounds_min /= 1.0e9
# bounds_max /= 1.0e9
delta = bounds_max - bounds_min
print(
f"# {label:>8} time bounds: {bounds_min} : {bounds_max} nsec (delta={delta} nsec)"
)
return delta
orig_delta = dump_min_max("Initial")
# Get start and end times
if not is_marker_mode:
start_time = kwargs.get("start", None)
end_time = kwargs.get("end", None)
start_time, end_time = percentages2timestamp(connection, start_time, end_time)
else:
start_marker = kwargs.get("start_marker", None)
end_marker = kwargs.get("end_marker", None)
start_time, end_time = markers2timestamp(connection, start_marker, end_marker)
if not end_time > start_time:
raise ValueError(
f"ERROR: Invalid time range - end time ({end_time}) must be greater than start time ({start_time})"
)
# Create views for tables with start and end times
start_end_timed_tables = []
timestamp_timed_tables = []
for itr in connection.table_info.keys():
if itr.find("rocpd_info_") == 0:
continue
column_names = get_column_names(connection, itr)
if "start" in column_names and "end" in column_names:
start_end_timed_tables += [itr]
elif "timestamp" in column_names:
timestamp_timed_tables += [itr]
# Restrict the scope of the tables with start/end columns
for table_name in start_end_timed_tables:
dbs = [
f"{itr} WHERE {get_time_filter(inclusive, start_time, end_time)}"
for itr in connection.table_info[table_name]
]
table_union = " UNION ALL ".join(dbs)
create_view_query = f"""
CREATE TEMPORARY VIEW {table_name} AS
{table_union}
"""
create_view(connection, table_name, create_view_query)
# Restrict the scope of the tables with timestamp columns
for table_name in timestamp_timed_tables:
dbs = [
f"{itr} WHERE {get_timestamp_filter(inclusive, start_time, end_time)}"
for itr in connection.table_info[table_name]
]
table_union = " UNION ALL ".join(dbs)
create_view_query = f"""
CREATE TEMPORARY VIEW {table_name} AS
{table_union}
"""
create_view(connection, table_name, create_view_query)
# # Create node view
# create_view_query = """CREATE VIEW rocpd_node AS """
# selects = [
# f"SELECT rocpd_node.* FROM rocpd_node INNER JOIN {t} ON rocpd_node.id = {t}.node_id"
# for t in start_end_timed_tables
# ]
# create_view_query += " UNION ".join(selects)
# create_view(connection, "rocpd_node", create_view_query)
# # Create track view
# create_view_query = """
# CREATE VIEW rocpd_track AS
# SELECT rocpd_track.* FROM rocpd_track
# INNER JOIN rocpd_sample ON rocpd_sample.track_id = rocpd_track.id
# """
# create_view(connection, "rocpd_track", create_view_query)
upd_delta = dump_min_max("Windowed")
reduction = (1.0 - (upd_delta / orig_delta)) * 100.0
print(f"# Time windowing reduced the duration by {reduction:6.2f}%")
return connection
#
# Command-line interface functions
#
def add_args(parser: ArgumentParser):
2025-06-06 04:03:14 -05:00
"""Add time slice arguments to an existing parser."""
tw_options = parser.add_argument_group("Time window options")
# Start time mutually exclusive group
start_group = tw_options.add_mutually_exclusive_group(required=False)
start_group.add_argument(
"--start",
type=str,
help="Start time as percentage or in nanoseconds from trace file (e.g., '50%%' or '781470909013049')",
default=None,
)
start_group.add_argument(
"--start-marker",
type=str,
help="Named marker event to use as window start point",
default=None,
)
# End time mutually exclusive group
end_group = tw_options.add_mutually_exclusive_group(required=False)
end_group.add_argument(
"--end",
type=str,
help="End time in as percentage or nanoseconds from trace file (e.g., '75%%' or '3543724246381057')",
default=None,
)
end_group.add_argument(
"--end-marker",
type=str,
help="Named marker event to use as window end point",
default=None,
)
tw_options.add_argument(
"--inclusive",
type=lambda x: x.lower() in ("true", "t", "yes", "1"),
help="True: include events if START or END in window; False: only if BOTH in window (default: True)",
default=True,
)
def process_args(input, args):
valid_args = ["start", "end", "inclusive", "start_marker", "end_marker"]
ret = {}
for itr in valid_args:
if hasattr(args, itr):
val = getattr(args, itr)
if val is not None:
ret[itr] = val
2025-06-06 04:03:14 -05:00
if ret and input is not None:
apply_time_window(input, **ret)
2025-06-06 04:03:14 -05:00
return ret
2025-06-06 04:03:14 -05:00
return process_args
2025-06-06 04:03:14 -05:00
def execute(input_rpd: str, **kwargs: Any) -> RocpdImportData:
"""Execute time window filtering on database file."""
importData = RocpdImportData(input_rpd)
apply_time_window(importData, **kwargs)
return importData
def main(argv=None) -> int:
"""Main entry point for command line execution."""
parser = argparse.ArgumentParser(
description="Apply time window filtering to ROCpd database views"
)
parser.add_argument(
"-i",
"--input",
type=str,
required=True,
help="Path to the input ROCpd database file",
)
arg_names = add_args(parser)
args = parser.parse_args(argv)
execute(args.input, **{arg: getattr(args, arg) for arg in arg_names})
if __name__ == "__main__":
main()