Files
rocm-systems/source/lib/python/rocpd/importer.py
T

167 خطوط
6.1 KiB
Python

2025-06-06 04:03:14 -05:00
###############################################################################
# MIT License
#
# Copyright (c) 2023 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
###############################################################################
#
# Utility classes to simplify generating rpd files
#
#
import sys
import sqlite3
from .schema import RocpdSchema
from . import libpyrocpd
__all__ = ["RocpdImportData", "execute_statement"]
class RocpdImportData(libpyrocpd.RocpdImportData):
def __init__(self, input):
if isinstance(input, RocpdImportData):
super(RocpdImportData, self).__init__(input)
self.table_info = input.table_info
else:
if isinstance(input, sqlite3.Connection):
raise ValueError(
"RocpdImportData does not accept existing sqlite3 connections"
)
elif isinstance(input, str):
_connection = libpyrocpd.connect(input)
_filenames = [input]
elif isinstance(input, list) and len(input) > 0 and isinstance(input[0], str):
_connection = libpyrocpd.connect(":memory:")
_filenames = input[:]
_connection.execute("PRAGMA foreign_keys = ON")
self.table_info = _create_temp_views(_connection, input)
_create_meta_views(_connection)
else:
raise ValueError(
f"input is unsupported type. Expected sqlite3.Connection, string, or (non-empty) list of strings. type={type(input).__name__}"
)
super(RocpdImportData, self).__init__(_connection, _filenames)
def __getattr__(self, name):
# any attribute or method not found in RocpdImportData will be looked up on self.connection
return getattr(self.connection, name)
def __enter__(self):
# support "with RocpdImportData(...) as db:":
return self
def __exit__(self, exc_type, exc, tb):
return self.connection.__exit__(exc_type, exc, tb)
def execute_statement(conn, statement, is_script=False):
if isinstance(conn, RocpdImportData):
_conn = conn.connection
else:
_conn = conn
assert isinstance(_conn, sqlite3.Connection)
try:
if is_script:
return _conn.executescript(statement)
return _conn.execute(f"{statement}")
except sqlite3.Error as err:
sys.stderr.write(f"SQLite3 error: {err}\nStatement:\n\t{statement}\n")
sys.stderr.flush()
raise err
def _create_temp_views(connection, input):
"""Create temporary unified views from multiple database files."""
assert isinstance(connection, sqlite3.Connection)
assert isinstance(input, list)
# Attach each database and extract the uuid from each database
dbinfo = []
uuids = []
for i, inp in enumerate(input):
execute_statement(connection, f"ATTACH DATABASE '{inp}' AS db{i}")
_uuids = [
itr[0]
for itr in execute_statement(
connection,
f"SELECT value FROM db{i}.rocpd_metadata WHERE tag='uuid'",
).fetchall()
]
dbinfo += [f"db{i}"]
uuids += [itr for itr in _uuids if itr not in uuids]
# unique set of universal process identifiers
uuids = list(set(uuids))
all_tables = {}
for ditr in dbinfo:
# get the tables for the given attached database
tables = [
itr[0]
for itr in execute_statement(
connection,
f"SELECT name FROM {ditr}.sqlite_master WHERE type='table' AND name LIKE 'rocpd_%'",
).fetchall()
]
# loop over the tables
for itr in tables:
# loop over the UUIDs
for uitr in uuids:
# skip the tables without the UUID suffix
if f"{uitr}" not in itr:
continue
# strip the UUID suffix to create a base table name, e.g. 'rocpd_string_03daf93' -> 'rocpd_string'
base = itr.replace(f"{uitr}", "")
# create a list of attached databases which have the base table name
if base not in all_tables.keys():
all_tables[base] = []
# create the SELECT statement from this database
select = f"SELECT * FROM {ditr}.{base}"
# make sure that we don't duplicate SELECT statements of same table from same attached database
if select in all_tables[base]:
continue
# add this to list
all_tables[base] += [select]
# create the temporary view that is a union of all the attached databases
for key, itr in all_tables.items():
stmt = "CREATE TEMPORARY VIEW {} AS {}".format(key, " UNION ALL ".join(itr))
execute_statement(connection, stmt)
return all_tables
def _create_meta_views(connection):
schema = RocpdSchema()
sql_script = schema.views.replace("CREATE VIEW", "CREATE TEMPORARY VIEW")
execute_statement(connection, sql_script, is_script=True)