#!/usr/bin/env python3
###############################################################################
# MIT License
#
# Copyright (c) 2025 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
###############################################################################
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import os
|
||
|
|
import sqlite3
|
||
|
|
import time
|
||
|
|
|
||
|
|
from typing import List, Dict, Iterable, Optional, Callable, Any
|
||
|
|
|
||
|
|
|
||
|
|
def _fetch_master(conn, alias: str, kind: str):
    """Return ``(name, sql)`` rows of one object type from an attached database's sqlite_master."""
    cur = conn.execute(
        f"""
        SELECT name, sql
        FROM {alias}.sqlite_master
        WHERE type = ? AND name NOT LIKE 'sqlite_%'
        ORDER BY name
        """,
        (kind,),
    )
    return cur.fetchall()


def _dest_object_names(conn, kind: str) -> set:
    """Return the names of non-internal objects of one type already present in the destination."""
    return {
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type = ? AND name NOT LIKE 'sqlite_%'",
            (kind,),
        )
    }


def _inject_if_not_exists_in_index_sql(sql: str) -> str:
    """Rewrite a ``CREATE [UNIQUE] INDEX`` statement so that it uses ``IF NOT EXISTS``."""
    stripped = sql.strip()
    upper = stripped.upper()
    # Longest prefix first so that UNIQUE indexes are not mistaken for plain ones.
    for prefix in ("CREATE UNIQUE INDEX", "CREATE INDEX"):
        if upper.startswith(prefix):
            remainder = stripped[len(prefix):]
            if remainder.lstrip().upper().startswith("IF NOT EXISTS"):
                return stripped  # already idempotent; injecting twice is a syntax error
            # Slice instead of str.replace so a lower-case prefix is handled too.
            return f"{prefix} IF NOT EXISTS{remainder}"
    return sql


def merge_sqlite_dbs(
    sources: Iterable[str],
    dest_path: str,
    on_log: Optional[Callable[[str], None]] = None,
) -> None:
    """
    Merge multiple SQLite databases into a single destination database.

    Every source is attached in turn; its tables are re-created in the
    destination and their rows copied, followed by its indexes and triggers.
    Views are not copied directly: each ``rocpd_``-prefixed view is rebuilt at
    the end as a ``UNION ALL`` over every destination table whose name starts
    with ``<view name>_``, and all remaining views are re-created afterwards
    from the first definition seen for each name.

    An existing file at ``dest_path`` is removed first, and the parent
    directory is created if needed. Progress lines ("Adding ...",
    "Tables found: ...") are always printed to stdout.

    Parameters
    ----------
    sources : Iterable[str]
        Paths to source databases.
    dest_path : str
        Path to destination database.
    on_log : Optional[Callable[[str], None]]
        Logger function; defaults to None. Pass `print` to generate logs.

    Raises
    ------
    ValueError
        If ``sources`` is empty.
    AssertionError
        If a source table name already exists in the destination.
    RuntimeError
        If the sources do not share exactly one ``schema_version`` value.
    sqlite3.Error
        Propagated from any failing SQL statement.
    """
    # Local import: the module-level import block is left untouched.
    from contextlib import closing

    def log(msg: str) -> None:
        if on_log:
            on_log(f" {msg}")

    sources = list(sources)
    if not sources:
        raise ValueError("No source databases provided")

    # Prepare output directory
    output_dir = os.path.dirname(os.path.abspath(dest_path)) or os.getcwd()
    os.makedirs(output_dir, exist_ok=True)

    # Remove existing file
    if os.path.isfile(dest_path):
        os.remove(dest_path)

    views = set()  # names of rocpd_* views to rebuild as UNION ALL views
    data_views = {}  # other views: name -> CREATE VIEW sql (first definition wins)
    schema_versions = set()

    # sqlite3's own context manager only commits/rolls back; closing() makes
    # sure the connection is released on both the success and the error path.
    with closing(sqlite3.connect(str(dest_path))) as conn:
        conn.execute("PRAGMA journal_mode = WAL;")
        conn.execute("PRAGMA synchronous = NORMAL;")
        conn.execute("PRAGMA foreign_keys = OFF;")  # defer FK checks until end

        with conn:
            # Attach sources one by one
            for i, src in enumerate(sources, 1):
                alias = f"src{i}"
                conn.execute(f"ATTACH DATABASE ? AS {alias}", (src,))
                print(f"Adding {src}")
                log(f"Attached {src} AS {alias}")

                # Schema version(s) recorded in this source
                schema_versions.update(
                    row[0]
                    for row in conn.execute(
                        f"SELECT value FROM {alias}.rocpd_metadata WHERE tag='schema_version'",
                    ).fetchall()
                )

                # 1) Create tables
                existing_tables = _dest_object_names(conn, "table")
                source_tables = _fetch_master(conn, alias, "table")
                created_tables = []
                for name, create_sql in source_tables:
                    if name in existing_tables:
                        raise AssertionError(
                            f"Table name collision for '{name}' from {alias}; "
                            "assumption of globally-unique table names violated."
                        )
                    if not create_sql:
                        continue
                    log(f"Creating table {name}")
                    conn.execute(create_sql)
                    existing_tables.add(name)
                    created_tables.append(name)

                # 2) Copy table data inside SQLite rather than materializing
                #    every row in Python; the column order matches because the
                #    destination table was created from the source's own DDL.
                print(f"Tables found: {len(source_tables)}")
                for name in created_tables:
                    log(f"Inserting rows into {name} from {alias}.{name}")
                    conn.execute(f'INSERT INTO "{name}" SELECT * FROM {alias}."{name}"')

                # 3) Recreate indexes (made idempotent with IF NOT EXISTS)
                existing_indexes = _dest_object_names(conn, "index")
                for name, create_sql in _fetch_master(conn, alias, "index"):
                    if not create_sql:
                        continue  # auto-indexes have no SQL text
                    if name in existing_indexes:
                        log(f"Index {name} exists; skipping or using IF NOT EXISTS")
                    conn.execute(_inject_if_not_exists_in_index_sql(create_sql))
                    existing_indexes.add(name)

                # 4) Recreate triggers (skip on name conflict)
                existing_triggers = _dest_object_names(conn, "trigger")
                for name, create_sql in _fetch_master(conn, alias, "trigger"):
                    if not create_sql:
                        continue
                    if name in existing_triggers:
                        log(f"Trigger {name} exists; skipping")
                        continue
                    log(f"Creating trigger {name}")
                    conn.execute(create_sql)
                    existing_triggers.add(name)

                # 5) Collect views; they are re-created after all sources are merged
                existing_views = _dest_object_names(conn, "view")
                for name, create_sql in _fetch_master(conn, alias, "view"):
                    if not create_sql:
                        continue
                    if name in existing_views:
                        log(f"View {name} exists; skipping")
                        continue
                    # Views not starting with "rocpd_" are kept verbatim for later recreation
                    if not name.startswith("rocpd_"):
                        data_views.setdefault(name, create_sql)
                    existing_views.add(name)

                views.update(v for v in existing_views if v.startswith("rocpd_"))

                # DETACH is not allowed inside an open transaction, so commit first.
                conn.commit()
                conn.execute(f"DETACH DATABASE {alias}")
                log(f"Detached {alias}")

            # Merge only succeeds if all the DBs report the same schema version.
            unique_versions = list(schema_versions)
            if not unique_versions:
                raise RuntimeError("No schema version found in the source databases")
            if len(unique_versions) != 1:
                raise RuntimeError(f"Multiple schema versions found: {unique_versions}")

            # Re-enable FKs
            conn.execute("PRAGMA foreign_keys = ON;")

            # Create UNION views by listing all tables
            existing_tables = _dest_object_names(conn, "table")

            # NOTE(review): matching is by name prefix only; this presumably relies
            # on per-database suffixes in table names -- confirm against the schema.
            for vitr in sorted(views):
                # Sorted so the generated view definition is reproducible.
                matching_tables = sorted(
                    titr for titr in existing_tables if titr.startswith(f"{vitr}_")
                )
                tables_union = " UNION ALL ".join(
                    f"SELECT * FROM {titr}" for titr in matching_tables
                )
                conn.execute(f"CREATE VIEW {vitr} AS {tables_union}")
            conn.commit()

            # Now that the rocpd_ views exist, re-create the data views on top of them
            for sql_view in data_views.values():
                conn.execute(sql_view)
            conn.commit()
|
||
|
|
|
||
|
|
|
||
|
|
#
# Command-line interface functions
#
|
||
|
|
def add_args(parser):
    """Add arguments for merger.

    Registers ``-o/--output-file`` and ``-d/--output-path`` in an
    "I/O options" group of ``parser``. Their defaults are read from the
    ``ROCPD_OUTPUT_NAME`` and ``ROCPD_OUTPUT_PATH`` environment variables.

    Returns a ``process_args(input, args)`` callable that extracts the
    merger options from a parsed namespace as a dict, leaving out options
    that are absent or ``None``. The ``input`` argument is not used.
    """
    group = parser.add_argument_group("I/O options")

    output_file_default = os.environ.get("ROCPD_OUTPUT_NAME", "merged")
    group.add_argument(
        "-o",
        "--output-file",
        help="Sets the base output file name",
        default=output_file_default,
        type=str,
        required=False,
    )

    output_path_default = os.environ.get("ROCPD_OUTPUT_PATH", "./rocpd-output-data")
    group.add_argument(
        "-d",
        "--output-path",
        help="Sets the output path where the output files will be saved (default path: `./rocpd-output-data`)",
        default=output_path_default,
        type=str,
        required=False,
    )

    def process_args(input, args):
        # Pair each known option with its value; a missing attribute reads as None.
        candidates = (
            (key, getattr(args, key, None)) for key in ("output_file", "output_path")
        )
        return {key: value for key, value in candidates if value is not None}

    return process_args
|
||
|
|
|
||
|
|
|
||
|
|
def execute(inputs: List[str], **kwargs: Any) -> str:
    """Merge the given rocPD databases into one output database.

    Parameters
    ----------
    inputs : List[str]
        Input paths. They are first passed through
        ``package.flatten_rocpd_yaml_input_file``; if that import or call
        fails for any reason, ``inputs`` is used unchanged.
    **kwargs
        ``output_path`` (directory) and ``output_file`` (base name; ``.db``
        is appended when missing). An option that is absent or ``None``
        falls back to the ``ROCPD_OUTPUT_PATH`` / ``ROCPD_OUTPUT_NAME``
        environment variables, then to ``./rocpd-output-data`` / ``merged``.

    Returns
    -------
    str
        Path of the merged database.
    """
    # Monotonic clock: the elapsed time is immune to wall-clock adjustments.
    start_time = time.perf_counter()

    input_files = inputs
    try:
        from . import package

        input_files = package.flatten_rocpd_yaml_input_file(inputs, skip_auto_merge=True)
    except Exception as e:
        # Deliberate best-effort: continue with the raw input list.
        print(f"Import error trying to use package, fallback to use inputs: {e}")

    # Missing options previously crashed on None; use the same defaults as the CLI.
    output_path = kwargs.get("output_path")
    if output_path is None:
        output_path = os.environ.get("ROCPD_OUTPUT_PATH", "./rocpd-output-data")

    output_filename = kwargs.get("output_file")
    if output_filename is None:
        output_filename = os.environ.get("ROCPD_OUTPUT_NAME", "merged")
    if not output_filename.endswith(".db"):
        output_filename += ".db"

    output = os.path.join(output_path, output_filename)

    merge_sqlite_dbs(input_files, output)

    elapsed_time = time.perf_counter() - start_time

    print(f"Merge completed successfully! Output saved to: {output}")
    print(f"Time: {elapsed_time:.2f} sec")
    return str(output)
|
||
|
|
|
||
|
|
|
||
|
|
def main(argv=None) -> int:
    """Main entry point for command line execution.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments, handed to ``ArgumentParser.parse_args``.

    Returns
    -------
    int
        0 once the merge has finished; failures propagate as exceptions.
    """
    from . import output_config

    parser = argparse.ArgumentParser(
        description="Generate merged database from rocPD databases"
    )

    required_params = parser.add_argument_group("Required options")
    required_params.add_argument(
        "-i",
        "--input",
        required=True,
        type=output_config.check_file_exists,
        nargs="+",
        help="Path to the input ROCpd database files",
    )

    process_args = add_args(parser)

    args = parser.parse_args(argv)

    # The callable returned by add_args() takes (input, args); passing only
    # `args` raised TypeError on every command-line run.
    merge_args = process_args(args.input, args)

    execute(args.input, **merge_args)
    return 0
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
|