355 líneas
12 KiB
Python
355 líneas
12 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
###############################################################################
|
||
|
|
# MIT License
|
||
|
|
#
|
||
|
|
# Copyright (c) 2025 Advanced Micro Devices, Inc.
|
||
|
|
#
|
||
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||
|
|
# of this software and associated documentation files (the "Software"), to deal
|
||
|
|
# in the Software without restriction, including without limitation the rights
|
||
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||
|
|
# copies of the Software, and to permit persons to whom the Software is
|
||
|
|
# furnished to do so, subject to the following conditions:
|
||
|
|
#
|
||
|
|
# The above copyright notice and this permission notice shall be included in
|
||
|
|
# all copies or substantial portions of the Software.
|
||
|
|
#
|
||
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||
|
|
# THE SOFTWARE.
|
||
|
|
###############################################################################
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import sqlite3
|
||
|
|
from argparse import ArgumentParser
|
||
|
|
from typing import Optional, Tuple, Dict, Any, List
|
||
|
|
|
||
|
|
from .importer import RocpdImportData, execute_statement
|
||
|
|
|
||
|
|
|
||
|
|
def get_marker_timestamp(
    connection: sqlite3.Connection, marker_name: str, marker_type: str = "start"
) -> float:
    """Look up the ``start`` timestamp of the marker named *marker_name*.

    Args:
        connection: Connection whose ``markers`` table/view is queried.
        marker_name: Exact marker name to match.
        marker_type: Label used only in error messages (``"start"``/``"end"``).

    Returns:
        The matching marker's ``start`` column converted to ``float``.

    Raises:
        ValueError: If no marker has that name, or if more than one does.
    """
    rows = connection.execute(
        "SELECT start FROM markers WHERE name = ?", (marker_name,)
    ).fetchall()

    match_count = len(rows)
    if match_count == 0:
        raise ValueError(
            f'ERROR: {marker_type.capitalize()} marker "{marker_name}" not found'
        )
    if match_count != 1:
        # A name used by several markers cannot identify a single point in time.
        raise ValueError(
            f'ERROR: Ambiguous reference - multiple {marker_type} markers found with name "{marker_name}"'
        )

    first_row = rows[0]
    return float(first_row[0])
|
||
|
|
|
||
|
|
|
||
|
|
def markers2timestamp(
    connection: sqlite3.Connection,
    start_marker: Optional[str],
    end_marker: Optional[str],
) -> Tuple[float, float]:
    """Convert marker names to a ``(start_time, end_time)`` timestamp pair.

    Each endpoint is the ``start`` timestamp of the uniquely named marker
    (via ``get_marker_timestamp``). An endpoint whose marker name is ``None``
    falls back to the matching trace bound from ``get_min_max_time``, the same
    way ``percentages2timestamp`` treats a missing start/end value.

    Args:
        connection: Database connection holding the trace.
        start_marker: Marker name for the window start, or ``None``.
        end_marker: Marker name for the window end, or ``None``.

    Returns:
        ``(start_time, end_time)``.

    Raises:
        ValueError: If a named marker is missing or ambiguous, or if a trace
            bound is needed but the trace contains no timing data.
    """
    min_time = max_time = None
    if start_marker is None or end_marker is None:
        # Only query the trace bounds when an endpoint actually needs them;
        # passing None as a marker name would match nothing and fail.
        min_time, max_time = get_min_max_time(connection)
        if min_time is None or max_time is None:
            raise ValueError(
                "ERROR: Cannot create time window - trace file contains no timing data"
            )

    if start_marker is not None:
        start_time = get_marker_timestamp(connection, start_marker, "start")
    else:
        start_time = float(min_time)

    if end_marker is not None:
        end_time = get_marker_timestamp(connection, end_marker, "end")
    else:
        end_time = float(max_time)

    return (start_time, end_time)
|
||
|
|
|
||
|
|
|
||
|
|
def get_min_max_time(connection):
    """Return ``(earliest_start, latest_end)`` across the timed event views.

    The views scanned are ``regions_and_samples``, ``rocpd_kernel_dispatch``,
    ``rocpd_memory_allocate`` and ``rocpd_memory_copy``. Because the query is
    an ungrouped ``MIN``/``MAX``, both values are ``None`` when none of those
    views has any rows.
    """
    bounds_query = """
        SELECT
            MIN(min_time) as min_time,
            MAX(max_time) as max_time
        FROM (
            SELECT start as min_time, end as max_time FROM regions_and_samples
            UNION ALL
            SELECT start as min_time, end as max_time FROM rocpd_kernel_dispatch
            UNION ALL
            SELECT start as min_time, end as max_time FROM rocpd_memory_allocate
            UNION ALL
            SELECT start as min_time, end as max_time FROM rocpd_memory_copy
        )"""

    cursor = execute_statement(connection, bounds_query)
    earliest, latest = cursor.fetchone()
    return earliest, latest
|
||
|
|
|
||
|
|
|
||
|
|
def percentages2timestamp(
    connection: sqlite3.Connection, start_time: Optional[str], end_time: Optional[str]
) -> Tuple[float, float]:
    """Convert percentage strings or raw time values to timestamps.

    Each value may be:

    * empty/``None`` -- use the trace's earliest start (for *start_time*) or
      latest end (for *end_time*);
    * a percentage such as ``"50%"`` -- interpolated linearly between the
      trace bounds, and required to lie in ``[0%, 100%]``;
    * a finite number -- used as-is (nanoseconds, per the CLI help).

    Args:
        connection: Database connection holding the trace.
        start_time: Window start specification.
        end_time: Window end specification.

    Returns:
        ``(start, end)`` timestamps.

    Raises:
        ValueError: If the trace has no timing data, or a value cannot be
            parsed, is out of range, or is not finite.
    """
    import math

    min_time, max_time = get_min_max_time(connection)

    # Either bound being NULL makes interpolation impossible.
    if min_time is None or max_time is None:
        raise ValueError(
            "ERROR: Cannot create time window - trace file contains no timing data"
        )

    def convert_time(time_str: Optional[str], is_start: bool = False) -> float:
        if not time_str:
            return min_time if is_start else max_time

        if "%" in time_str:
            percent_error = (
                f"ERROR: Invalid percentage '{time_str}' - must be between '0%' and '100%'"
            )
            try:
                percentage = float(time_str.replace("%", "")) / 100.0
            except ValueError:
                # e.g. "abc%": report it as a bad percentage, not a float() error.
                raise ValueError(percent_error) from None
            # NaN fails this chained comparison too, so it is rejected here.
            if not 0 <= percentage <= 1:
                raise ValueError(percent_error)
            return min_time + ((max_time - min_time) * percentage)

        value_error = f"ERROR: Invalid time value '{time_str}' - must be percentage (e.g., '50%') or a number (nanoseconds since epoch) "
        try:
            value = float(time_str)
        except ValueError:
            raise ValueError(value_error) from None
        # "inf"/"nan" parse as floats but cannot become integer SQL bounds.
        if not math.isfinite(value):
            raise ValueError(value_error)
        return value

    return (convert_time(start_time, True), convert_time(end_time, False))
|
||
|
|
|
||
|
|
|
||
|
|
def get_time_filter(inclusive: bool, start_time, end_time) -> str:
    """Build a SQL ``WHERE`` fragment that filters rows on ``start``/``end`` columns.

    Both bounds are converted with ``int()`` (truncation) before formatting.

    * ``inclusive`` truthy: the row must lie entirely inside the window
      (``start >= begin AND end <= end``).
    * ``inclusive`` falsy: the row only has to overlap the window (it starts
      no later than the window end and ends no earlier than the window begin).
    """
    window_begin, window_end = int(start_time), int(end_time)
    if not inclusive:
        return f"start <= {window_end} AND end >= {window_begin}"
    return f"start >= {window_begin} AND end <= {window_end}"
|
||
|
|
|
||
|
|
|
||
|
|
def get_timestamp_filter(inclusive: bool, start_time, end_time) -> str:
    """Build a SQL ``WHERE`` fragment that filters rows on a ``timestamp`` column.

    Both bounds are converted with ``int()`` before formatting. A single
    timestamp is either inside the window or not, so both modes select the
    same rows; only the order of the two comparisons differs.
    """
    lo = int(start_time)
    hi = int(end_time)
    template = (
        "timestamp >= {lo} AND timestamp <= {hi}"
        if inclusive
        else "timestamp <= {hi} AND timestamp >= {lo}"
    )
    return template.format(lo=lo, hi=hi)
|
||
|
|
|
||
|
|
|
||
|
|
def create_view(connection: sqlite3.Connection, view_name: str, query: str) -> None:
    """Replace the view *view_name* using *query*, then commit.

    Any existing view called *view_name* is dropped first. *query* is executed
    verbatim, so it must be the complete ``CREATE ... VIEW`` statement (callers
    in this module pass ``CREATE TEMPORARY VIEW <view_name> AS ...``).
    """
    for statement in (f"DROP VIEW IF EXISTS {view_name}", query):
        execute_statement(connection, statement)
    connection.commit()
|
||
|
|
|
||
|
|
|
||
|
|
#
|
||
|
|
# Main processing functions
|
||
|
|
#
|
||
|
|
def is_using_markers(args: Dict[str, Any]) -> Optional[bool]:
    """Report which time-window selection mode *args* requests.

    Args:
        args: Option mapping; keys ``start``, ``end``, ``start_marker`` and
            ``end_marker`` are consulted, and ``None`` values count as absent.

    Returns:
        ``False`` if an explicit ``start`` or ``end`` value is present (these
        take precedence over markers), ``True`` if only ``start_marker`` and/or
        ``end_marker`` are present, and ``None`` if no window was requested.
    """
    if args.get("start") is not None or args.get("end") is not None:
        return False
    if args.get("start_marker") is not None or args.get("end_marker") is not None:
        return True
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def get_column_names(conn: "RocpdImportData", table_name: str) -> List[str]:
    """Return the column names of *table_name* without reading any rows.

    Runs ``SELECT * ... LIMIT 0`` and reads ``cursor.description``. The name
    is emitted as a double-quoted SQL identifier with embedded ``"`` doubled,
    so names containing quotes or other special characters are handled.

    Args:
        conn: Object exposing ``execute(sql)`` that returns a DB-API cursor.
        table_name: Unqualified table or view name.
    """
    quoted_name = '"' + table_name.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted_name} LIMIT 0")
    return [desc[0] for desc in cursor.description]
|
||
|
|
|
||
|
|
|
||
|
|
def _report_time_bounds(connection: RocpdImportData, label: str) -> Optional[float]:
    """Print the current trace time bounds; return their span, or None if no events are visible."""
    bounds_min, bounds_max = get_min_max_time(connection)
    if bounds_min is None or bounds_max is None:
        # Empty trace, or nothing left inside the window: nothing to subtract.
        print(f"# {label:>8} time bounds: no timed events found")
        return None
    delta = bounds_max - bounds_min
    print(
        f"# {label:>8} time bounds: {bounds_min} : {bounds_max} nsec (delta={delta} nsec)"
    )
    return delta


def _resolve_window_bounds(
    connection: RocpdImportData, options: Dict[str, Any]
) -> Tuple[float, float]:
    """Resolve the window start and end independently from times or markers.

    For each endpoint, an explicit ``start``/``end`` value wins; otherwise the
    corresponding ``start_marker``/``end_marker`` is used; otherwise the trace
    bound applies (handled by ``percentages2timestamp``).
    """
    start, end = options.get("start"), options.get("end")
    start_marker, end_marker = options.get("start_marker"), options.get("end_marker")

    use_start_marker = start is None and start_marker is not None
    use_end_marker = end is None and end_marker is not None

    if use_start_marker and use_end_marker:
        # Both endpoints come from markers; the trace bounds are not needed.
        return markers2timestamp(connection, start_marker, end_marker)

    start_time, end_time = percentages2timestamp(connection, start, end)
    if use_start_marker:
        start_time = get_marker_timestamp(connection, start_marker, "start")
    if use_end_marker:
        end_time = get_marker_timestamp(connection, end_marker, "end")
    return start_time, end_time


def _classify_timed_tables(connection: RocpdImportData) -> Tuple[List[str], List[str]]:
    """Split table_info keys into (start/end tables, timestamp tables), skipping rocpd_info_* tables."""
    start_end_tables: List[str] = []
    timestamp_tables: List[str] = []
    for table_name in connection.table_info.keys():
        if table_name.startswith("rocpd_info_"):
            continue
        columns = get_column_names(connection, table_name)
        if "start" in columns and "end" in columns:
            start_end_tables.append(table_name)
        elif "timestamp" in columns:
            timestamp_tables.append(table_name)
    return start_end_tables, timestamp_tables


def _create_filtered_views(
    connection: RocpdImportData, table_names: List[str], where_clause: str
) -> None:
    """Replace each named table with a temporary view over its sources, restricted by *where_clause*."""
    for table_name in table_names:
        # Each entry of table_info[table_name] is used as a SELECT statement
        # that the WHERE clause is appended to (presumably one per source db).
        table_union = " UNION ALL ".join(
            f"{source} WHERE {where_clause}"
            for source in connection.table_info[table_name]
        )
        create_view_query = f"""
        CREATE TEMPORARY VIEW {table_name} AS
        {table_union}
        """
        create_view(connection, table_name, create_view_query)


def apply_time_window(connection: RocpdImportData, **kwargs: Any) -> RocpdImportData:
    """Apply time window filtering by replacing timed tables with filtered temporary views.

    Recognized keyword options:
        start, end: Percentage (``"50%"``) or numeric time for each endpoint.
        start_marker, end_marker: Marker names used for an endpoint when the
            corresponding ``start``/``end`` value is absent.
        inclusive: Passed to ``get_time_filter``/``get_timestamp_filter``
            (default ``True``).

    Any endpoint left unspecified defaults to the trace's own bound. When no
    window option is given at all, the connection is returned untouched.

    Returns:
        *connection*, whose timed tables are now windowed views.

    Raises:
        ValueError: If a time/marker cannot be resolved or the resulting
            window is empty (end not after start).
    """
    if is_using_markers(kwargs) is None:
        # Neither explicit times nor markers were requested: nothing to filter.
        return connection

    inclusive = kwargs.get("inclusive", True)

    orig_delta = _report_time_bounds(connection, "Initial")

    start_time, end_time = _resolve_window_bounds(connection, kwargs)

    # "not >" (rather than "<=") also rejects NaN bounds.
    if not end_time > start_time:
        raise ValueError(
            f"ERROR: Invalid time range - end time ({end_time}) must be greater than start time ({start_time})"
        )

    start_end_tables, timestamp_tables = _classify_timed_tables(connection)

    _create_filtered_views(
        connection, start_end_tables, get_time_filter(inclusive, start_time, end_time)
    )
    _create_filtered_views(
        connection,
        timestamp_tables,
        get_timestamp_filter(inclusive, start_time, end_time),
    )

    upd_delta = _report_time_bounds(connection, "Windowed")

    # The reduction is undefined for an empty or zero-length original trace.
    if orig_delta:
        reduction = (1.0 - ((upd_delta or 0) / orig_delta)) * 100.0
        print(f"# Time windowing reduced the duration by {reduction:6.2f}%")

    return connection
|
||
|
|
|
||
|
|
|
||
|
|
#
|
||
|
|
# Command-line interface functions
|
||
|
|
#
|
||
|
|
def add_args(parser: ArgumentParser) -> List[str]:
    """Add the time window options to an existing parser.

    ``--start``/``--start-marker`` are mutually exclusive, as are
    ``--end``/``--end-marker``. All options default to ``None`` except
    ``--inclusive``, which defaults to ``True``.

    Returns:
        The destination names of the added options.
    """

    def parse_bool(value: str) -> bool:
        """Parse a yes/no style flag value, rejecting unrecognized text."""
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1", "on"):
            return True
        if normalized in ("false", "f", "no", "n", "0", "off"):
            return False
        # Previously any typo silently meant False; make it an argparse error.
        raise argparse.ArgumentTypeError(
            f"invalid boolean value '{value}' (expected true/false, yes/no or 1/0)"
        )

    tw_options = parser.add_argument_group("Time window options")

    # Start time mutually exclusive group
    start_group = tw_options.add_mutually_exclusive_group(required=False)
    start_group.add_argument(
        "--start",
        type=str,
        help="Start time as percentage or in nanoseconds from trace file (e.g., '50%%' or '781470909013049')",
        default=None,
    )
    start_group.add_argument(
        "--start-marker",
        type=str,
        help="Named marker event to use as window start point",
        default=None,
    )

    # End time mutually exclusive group
    end_group = tw_options.add_mutually_exclusive_group(required=False)
    end_group.add_argument(
        "--end",
        type=str,
        help="End time as percentage or in nanoseconds from trace file (e.g., '75%%' or '3543724246381057')",
        default=None,
    )
    end_group.add_argument(
        "--end-marker",
        type=str,
        help="Named marker event to use as window end point",
        default=None,
    )

    # Help text mirrors get_time_filter: True keeps only fully-contained
    # events, False keeps every event overlapping the window.
    tw_options.add_argument(
        "--inclusive",
        type=parse_bool,
        help="True: include only events whose START and END are both in the window; False: include any event overlapping the window (default: True)",
        default=True,
    )

    return ["start", "end", "inclusive", "start_marker", "end_marker"]
|
||
|
|
|
||
|
|
|
||
|
|
def process_args(args, valid_args):
    """Collect the non-``None`` attributes of *args* named in *valid_args*.

    Names missing from *args* are skipped, as are attributes whose value is
    ``None``. The result preserves the order of *valid_args*.
    """
    candidates = ((name, getattr(args, name, None)) for name in valid_args)
    return {name: value for name, value in candidates if value is not None}
|
||
|
|
|
||
|
|
|
||
|
|
def execute(input_rpd: str, **kwargs: Any) -> RocpdImportData:
    """Open *input_rpd* and apply the requested time window to it.

    *kwargs* are forwarded unchanged to ``apply_time_window``.

    Returns:
        The ``RocpdImportData`` object built from *input_rpd*.
    """
    import_data = RocpdImportData(input_rpd)
    apply_time_window(import_data, **kwargs)
    return import_data
|
||
|
|
|
||
|
|
|
||
|
|
def main(argv=None) -> int:
    """Command-line entry point: parse *argv* and apply the time window.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        ``0`` on success; failures propagate as exceptions.
    """
    parser = argparse.ArgumentParser(
        description="Apply time window filtering to ROCpd database views"
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="Path to the input ROCpd database file",
    )

    arg_names = add_args(parser)
    args = parser.parse_args(argv)

    execute(args.input, **{name: getattr(args, name) for name in arg_names})
    # Honor the declared int return so callers can use it as an exit status.
    return 0
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|