Reset state of repo with updated README

Bu işleme şunda yer alıyor:
amd-jmacaran
2025-07-15 12:29:35 -04:00
işleme c3f8f57c80
43 değiştirilmiş dosya ile 4718 ekleme ve 0 silme
+64
Dosyayı Görüntüle
@@ -0,0 +1,64 @@
import os
import sys
import yaml
import requests
def get_existing_labels(repo, token):
    """Fetch every label currently defined on a GitHub repository.

    Pages through ``GET /repos/{repo}/labels`` 100 labels at a time and stops
    at the first empty page.

    Args:
        repo: Repository identifier interpolated into the ``/repos/{repo}`` URL.
        token: GitHub token sent in the ``Authorization`` header.

    Returns:
        dict mapping label name to ``{"color": ..., "description": ...}``.
        A missing or null description is normalised to ``""``.

    Raises:
        Exception: if a page request does not return HTTP 200.
        requests.RequestException: on network failure or timeout.
    """
    headers = {"Authorization": f"token {token}"}
    labels = {}
    page = 1
    while True:
        url = f"https://api.github.com/repos/{repo}/labels?page={page}&per_page=100"
        # Without a timeout a stalled connection would hang the job forever.
        resp = requests.get(url, headers=headers, timeout=30)
        if resp.status_code != 200:
            raise Exception(f"Failed to fetch existing labels: {resp.text}")
        data = resp.json()
        if not data:
            break
        for label in data:
            labels[label["name"]] = {
                "color": label["color"],
                # The key can be present with a null value, in which case
                # .get(key, "") would still return None; collapse it to "".
                "description": label.get("description") or "",
            }
        page += 1
    return labels
def create_or_update_label(repo, token, label, existing):
    """Create a label, or update it when its color or description differs.

    Args:
        repo: Repository identifier interpolated into the ``/repos/{repo}`` URL.
        token: GitHub token sent in the ``Authorization`` header.
        label: dict with at least ``name`` and ``color``; ``description`` is
            optional. The dict is sent as the JSON request body.
        existing: mapping of label name to ``{"color", "description"}`` for the
            labels already present on the repository.

    Returns:
        None. A failed request is reported on stdout and not raised.
    """
    # Local import so this function does not depend on a file-level import.
    from urllib.parse import quote

    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json"
    }
    name = label["name"]
    if name not in existing:
        # Create label
        print(f"Creating label: {name}")
        url = f"https://api.github.com/repos/{repo}/labels"
        resp = requests.post(url, json=label, headers=headers, timeout=30)
    else:
        # Update if different
        current = existing[name]
        same_color = label["color"].lower() == current["color"].lower()
        # Treat a missing, null, or empty description as the same thing;
        # otherwise a null on one side triggers an update on every run.
        same_description = (label.get("description") or "") == (current.get("description") or "")
        if same_color and same_description:
            print(f"Label '{name}' already up to date. Skipping.")
            return
        print(f"Updating label: {name}")
        # Escape the name so characters such as '/', '?' or '#' cannot alter
        # the request path.
        url = f"https://api.github.com/repos/{repo}/labels/{quote(name, safe='')}"
        resp = requests.patch(url, json=label, headers=headers, timeout=30)
    if not resp.ok:
        print(f"Failed to apply label {name}: {resp.status_code} {resp.text}")
def main(label_file):
    """Apply the labels listed in a YAML file to the target repository.

    Reads the token from ``GH_TOKEN`` and the repository from ``GITHUB_REPO``,
    fetches the labels that already exist, then creates or updates each label
    found in ``label_file``.

    Args:
        label_file: Path to a YAML file holding a list of label mappings.

    Raises:
        KeyError: if ``GH_TOKEN`` or ``GITHUB_REPO`` is not set.
    """
    token = os.environ["GH_TOKEN"]
    repo = os.environ["GITHUB_REPO"]
    existing = get_existing_labels(repo, token)
    with open(label_file, "r", encoding="utf-8") as f:
        # An empty YAML document loads as None; treat that as "no labels"
        # instead of failing on iteration.
        labels = yaml.safe_load(f) or []
    for label in labels:
        create_or_update_label(repo, token, label, existing)


if __name__ == "__main__":
    main(sys.argv[1])
+117
Dosyayı Görüntüle
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""
Azure Pipeline Resolver Script
------------------------------
This script determines which Azure pipelines to run based on changed subtrees.
Using a predefined dependency map, the script resolves which projects need to be processed,
skipping those that will be covered by their dependencies.
Steps:
1. Load a list of changed projects from a file.
2. Consult a dependency map to determine transitive and direct dependencies.
3. Identify projects that should be processed, excluding those handled by dependencies.
4. Output the list of projects to be run, along with their Azure pipeline IDs.
Arguments:
--subtree-file : Path to the file containing a newline-separated list of changed subtrees.
Outputs:
Prints a newline-separated list of "project_name=definition_id" for the projects that need
to be processed, where `definition_id` is the Azure pipeline ID associated with the project.
Example Usage:
To determine which pipelines to run given the changed subtrees listed in a file:
python azure_pipeline_resolver.py --subtree-file changed_subtrees.txt
"""
import argparse
from typing import List, Optional
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    The only option is the required ``--subtree-file``, exposed on the result
    as ``subtree_file``.
    """
    cli = argparse.ArgumentParser(
        description="Given a list of changed subtrees, determine which Azure pipelines to run.",
    )
    cli.add_argument(
        "--subtree-file",
        required=True,
        help="Path to the file containing changed subtrees",
    )
    namespace = cli.parse_args(argv)
    return namespace
def read_file_into_set(file_path):
    """Read newline-separated project names from a file into a set.

    Surrounding whitespace is stripped from each line, and lines that are empty
    after stripping are ignored so a blank line never yields an ``""`` project.

    Args:
        file_path: Path of the text file to read.

    Returns:
        set of non-empty, stripped lines.
    """
    with open(file_path, 'r', encoding="utf-8") as file:
        stripped_lines = (line.strip() for line in file)
        return {name for name in stripped_lines if name}
def resolve_dependencies(projects, dependencies):
    """Drop projects that have a direct or transitive dependency still scheduled.

    A project is removed from the result when any project reachable through
    ``dependencies`` is itself in the set of projects to run.

    Args:
        projects: Iterable of changed project names.
        dependencies: Mapping of project name to an iterable of the projects it
            depends on. Projects missing from the mapping have no dependencies.

    Returns:
        set of project names that remain to be run.
    """
    projects_to_run = set(projects)

    def has_dependency(project):
        """Return True if any reachable dependency is currently in projects_to_run."""
        # Iterative walk with a visited set: a cyclic map cannot recurse
        # forever, and shared dependencies are expanded only once.
        visited = set()
        pending = list(dependencies.get(project, ()))
        while pending:
            dependency = pending.pop()
            if dependency in visited:
                continue
            visited.add(dependency)
            if dependency in projects_to_run:
                return True
            pending.extend(dependencies.get(dependency, ()))
        return False

    # Iterate the caller's collection while mutating only the copy.
    for project in projects:
        if has_dependency(project):
            projects_to_run.discard(project)
    return projects_to_run
def main(argv=None) -> None:
    """Print ``project=definition_id`` for each pipeline that must run.

    Reads the changed subtrees from ``--subtree-file``, removes those that have
    a changed dependency, and prints the remaining projects that have a known
    Azure pipeline ID. Lines are printed in sorted order so the output is
    reproducible between runs.

    Args:
        argv: Optional argument list forwarded to ``parse_arguments``.
    """
    # Mathlib build+test dependency tree as defined in Azure CI and TheRock
    math_dependencies = {
        "shared/tensile": {},
        "projects/rocrand": {},
        "projects/hiprand": {"projects/rocrand"},
        "projects/rocfft": {"projects/hiprand"},
        "projects/hipfft": {"projects/rocfft"},
        "projects/rocprim": {},
        "projects/hipcub": {"projects/rocprim"},
        "projects/rocthrust": {"projects/rocprim"},
        "projects/hipblas-common": {},
        "projects/hipblaslt": {"projects/hipblas-common"},
        "projects/rocblas": {"projects/hipblaslt"},
        "projects/rocsolver": {"projects/rocprim", "projects/rocblas"},
        "projects/rocsparse": {"projects/rocprim", "projects/rocblas"},
        "projects/hipblas": {"projects/rocsolver"},
        "projects/hipsolver": {"projects/rocsolver", "projects/rocsparse"},
        "projects/hipsparse": {"projects/rocsparse"},
        "projects/hipsparselt": {"projects/hipsparse"},
        "projects/miopen": {"projects/rocrand", "projects/hipblas"},
    }
    # Azure pipeline IDs for each project, to be populated as projects are enabled
    definition_ids = {
        "shared/tensile": 305,
        "projects/rocrand": 274,
        "projects/hiprand": 275,
        "projects/rocfft": 282,
        "projects/hipfft": 283,
        "projects/rocprim": 273,
        "projects/hipcub": 277,
        "projects/rocthrust": 276,
        "projects/hipblas-common": 300,
        "projects/hipblaslt": 301,
        "projects/hipsparselt": 309,
        "projects/rocblas": 302,
        "projects/rocsolver": 303,
    }
    args = parse_arguments(argv)
    projects = read_file_into_set(args.subtree_file)
    projects_to_run = resolve_dependencies(projects, math_dependencies)
    # NOTE(review): a project dropped in favour of a changed dependency gets no
    # pipeline at all when that dependency has no entry in definition_ids
    # (e.g. hipsparselt + hipsparse) - confirm this is intended.
    for project in sorted(projects_to_run):
        if project in definition_ids:
            print(f"{project}={definition_ids[project]}")


if __name__ == "__main__":
    main()
+48
Dosyayı Görüntüle
@@ -0,0 +1,48 @@
import json
import os
import sys
import requests
import yaml
def get_labels(repo, token):
    """Fetch the raw label objects defined on a GitHub repository.

    Pages through ``GET /repos/{repo}/labels`` 100 labels at a time and stops
    at the first empty page.

    Args:
        repo: Repository identifier interpolated into the ``/repos/{repo}`` URL.
        token: GitHub token sent in the ``Authorization`` header.

    Returns:
        list of label objects exactly as returned by the API.

    Raises:
        Exception: if a page request does not return HTTP 200.
        requests.RequestException: on network failure or timeout.
    """
    headers = {"Authorization": f"token {token}"}
    labels = []
    page = 1
    while True:
        url = f"https://api.github.com/repos/{repo}/labels?page={page}&per_page=100"
        # Without a timeout a stalled connection would hang the job forever.
        resp = requests.get(url, headers=headers, timeout=30)
        if resp.status_code != 200:
            raise Exception(f"Failed to fetch labels from {repo}: {resp.text}")
        data = resp.json()
        if not data:
            break
        labels.extend(data)
        page += 1
    return labels
def main(file_path):
    """Collect labels from every configured repository into ``.github/labels.yml``.

    Reads the ``repositories`` list from the JSON file at ``file_path``, fetches
    the labels of each entry's ``url``, keeps the first definition seen for each
    label name, and writes the result sorted case-insensitively by name.

    Args:
        file_path: Path to the JSON configuration file.

    Raises:
        KeyError: if ``GH_TOKEN`` is not set or the JSON lacks an expected key.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        repos_data = json.load(f)["repositories"]
    token = os.environ["GH_TOKEN"]
    all_labels = {}
    for repo_entry in repos_data:
        repo_url = repo_entry["url"]
        print(f"Collecting labels from {repo_url}")
        for label in get_labels(repo_url, token):
            name = label["name"]
            if name not in all_labels:
                all_labels[name] = {
                    "name": name,
                    "color": label["color"],
                    # A null description would be written as YAML null;
                    # store an empty string instead.
                    "description": label.get("description") or "",
                }
    sorted_labels = sorted(all_labels.values(), key=lambda l: l["name"].lower())
    os.makedirs(".github", exist_ok=True)  # Ensure the .github directory exists
    with open(".github/labels.yml", "w", encoding="utf-8") as out:
        yaml.dump(sorted_labels, out, sort_keys=False)


if __name__ == "__main__":
    main(sys.argv[1])
+18
Dosyayı Görüntüle
@@ -0,0 +1,18 @@
import json
import sys
import logging
from typing import List
from repo_config_model import RepoConfig, RepoEntry
logger = logging.getLogger(__name__)
def load_repo_config(config_path: str) -> List[RepoEntry]:
    """Read the JSON file at ``config_path`` and return its repository entries.

    The parsed JSON object is passed as keyword arguments to ``RepoConfig`` and
    the resulting ``repositories`` attribute is returned. Any exception raised
    while reading, parsing or building the model is logged, after which the
    process exits with status 1.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            raw = json.loads(handle.read())
        return RepoConfig(**raw).repositories
    except Exception as err:
        logger.error(f"Failed to load or validate config file '{config_path}': {err}")
        sys.exit(1)
+282
Dosyayı Görüntüle
@@ -0,0 +1,282 @@
#!/usr/bin/env python3
"""
GitHub Client Utility
---------------------
This utility provides a GitHubCLIClient class that wraps GitHub REST API operations
used across automation scripts, such as retrieving pull request file changes and labels.
When doing manual testing, you can run the same REST API calls through curl in the terminal.
These REST API URLs, without the authentication header, will be output by the debug logging.
This includes:
- Fetching PR details
- Creating PRs
- Closing PRs
Requirements:
- NOTE: GH_TOKEN environment variable hands authentication token to this script in a runner.
- The token is created by the GitHub App and is passed to the script via the environment variable.
Manual curl testing:
To fetch PR details:
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
https://api.github.com/repos/OWNER/REPO/pulls/NUMBER
To list PRs by head branch:
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
"https://api.github.com/repos/OWNER/REPO/pulls?head=OWNER:branch-name&state=open"
To fetch changed files in a PR:
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
https://api.github.com/repos/OWNER/REPO/pulls/NUMBER/files
To create a PR:
curl -X POST -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
https://api.github.com/repos/OWNER/REPO/pulls \
-d '{"title":"Title","body":"Description","head":"branch-name","base":"main"}'
To apply labels:
curl -X POST -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
https://api.github.com/repos/OWNER/REPO/issues/NUMBER/labels \
-d '{"labels": ["bug", "needs-review"]}'
"""
import os
import requests
import time
import logging
from typing import List, Optional
logger = logging.getLogger(__name__)
class GitHubCLIClient:
    """Wrapper around the GitHub REST API calls used by the automation scripts.

    Every request goes through one ``requests.Session`` whose headers carry the
    bearer token read from the ``GH_TOKEN`` environment variable.
    """

    def __init__(self) -> None:
        """Initialize the GitHub API client using GitHub App authentication."""
        self.api_url = "https://api.github.com"
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {self._get_token()}",
            "Accept": "application/vnd.github+json",
        })

    def _get_token(self) -> str:
        """Return the token from ``GH_TOKEN``; raise EnvironmentError if unset or empty."""
        token = os.getenv("GH_TOKEN")
        if not token:
            raise EnvironmentError("GH_TOKEN environment variable is not set")
        return token

    @staticmethod
    def _is_rate_limited(response) -> bool:
        """Return True for a 403 response whose X-RateLimit-Remaining header is "0"."""
        return response.status_code == 403 and response.headers.get("X-RateLimit-Remaining") == "0"

    @staticmethod
    def _sleep_until_rate_limit_reset(response) -> None:
        """Sleep until just after the time given by the X-RateLimit-Reset header."""
        reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
        # Always wait at least one second, even if the reset time has passed.
        sleep_seconds = max(1, reset_time - int(time.time()) + 1)
        logger.warning(f"Rate limited. Sleeping for {sleep_seconds} seconds...")
        time.sleep(sleep_seconds)

    def _get_with_retries(self, url: str, error_msg: str, retries: int = 3,
                          backoff: int = 2, timeout: int = 10) -> Optional[requests.Response]:
        """Retry a GET request with exponential backoff.

        Returns the response on HTTP 200, or None once ``retries`` attempts have
        been used. A rate-limited attempt sleeps until the limit resets and
        still counts as an attempt.
        """
        # no logging the actual request to avoid leaking sensitive information
        for attempt in range(retries):
            try:
                response = self.session.get(url, timeout=timeout)
                if response.status_code == 200:
                    return response
                # for api rate limiting, we check the headers for remaining requests and reset time
                if self._is_rate_limited(response):
                    self._sleep_until_rate_limit_reset(response)
                    continue
                # other errors will use exponential backoff timeout
                if response.status_code in {403, 429, 500, 502, 503, 504}:
                    logger.warning(f"Retryable error {response.status_code} on attempt {attempt + 1}.")
                else:
                    response.raise_for_status()
            except requests.RequestException as e:
                logger.warning(f"Request failed on attempt {attempt + 1}: {e}")
            logger.error(f"{error_msg} for {url} (Attempt {attempt + 1}/{retries})")
            if attempt < retries - 1:
                time.sleep(backoff ** attempt)  # Exponential backoff
            else:
                logger.error(f"Max retries reached for GET at {url}. Giving up.")
        return None

    def _get_json(self, url: str, error_msg: str) -> dict:
        """GET ``url`` and return the decoded JSON, or ``{}`` when the request failed."""
        response = self._get_with_retries(url, error_msg)
        return response.json() if response else {}

    def _get_paginated_json(self, url: str, error_msg: str) -> List[dict]:
        """GET ``url`` and every following page, concatenating the JSON arrays.

        Follows the ``next`` link of each response. If a page cannot be fetched
        the items collected so far are returned.
        """
        results = []
        while url:
            response = self._get_with_retries(url, error_msg)
            if not response:
                return results
            results.extend(response.json())
            url = response.links.get("next", {}).get("url")
        return results

    def _request(self, method: str, url: str, json: Optional[dict] = None,
                 error_msg: str = "", retries: int = 3, backoff: int = 2,
                 timeout: int = 10) -> Optional[dict]:
        """Send a request with retries and report failure explicitly.

        Returns the decoded JSON body on success, ``{}`` for a successful
        response without a body, and None when every attempt failed. Network
        errors are retried in the same way as HTTP errors.
        """
        # no logging the actual request to avoid leaking sensitive information
        for attempt in range(retries):
            try:
                response = self.session.request(method, url, json=json, timeout=timeout)
            except requests.RequestException as e:
                logger.warning(f"Request failed on attempt {attempt + 1}: {e}")
                response = None
            if response is not None:
                if response.ok:
                    if response.status_code == 204 or not response.text.strip():
                        return {}  # DELETE requests have no json content
                    return response.json()
                # for api rate limiting, we check the headers for remaining requests and reset time
                if self._is_rate_limited(response):
                    self._sleep_until_rate_limit_reset(response)
                    continue
            # other errors will use exponential backoff timeout
            logger.error(f"{error_msg} for method {method} at {url} (Attempt {attempt + 1}/{retries})")
            if attempt < retries - 1:
                time.sleep(backoff ** attempt)  # Exponential backoff
            else:
                logger.error(f"Max retries reached for method {method} at {url}. Giving up.")
        return None

    def _request_json(self, method: str, url: str, json: Optional[dict] = None,
                      error_msg: str = "", retries: int = 3, backoff: int = 2,
                      timeout: int = 10) -> dict:
        """Send a request with retries and return the JSON response.

        Returns ``{}`` both for a successful response without a body and when
        every attempt failed; use ``_request`` to tell the two apart.
        """
        result = self._request(method, url, json, error_msg, retries, backoff, timeout)
        return {} if result is None else result

    def get_changed_files(self, repo: str, pr: int) -> List[str]:
        """Fetch the changed files in a pull request using GitHub API."""
        url = f"{self.api_url}/repos/{repo}/pulls/{pr}/files?per_page=50"
        logger.debug(f"Request URL: {url}")
        files_data = self._get_paginated_json(url, f"Failed to fetch files for PR #{pr} in {repo}")
        files = [file["filename"] for file in files_data]
        logger.debug(f"Changed files in PR #{pr}: {files}")
        return files

    def get_defined_labels(self, repo: str) -> List[str]:
        """Get the names of all labels defined in the given repository."""
        url = f"{self.api_url}/repos/{repo}/labels?per_page=100"
        logger.debug(f"Request URL: {url}")
        labels_data = self._get_paginated_json(url, f"Failed to fetch labels from {repo}")
        labels = [label["name"] for label in labels_data]
        logger.debug(f"Defined labels in {repo}: {labels}")
        return labels

    def get_existing_labels_on_pr(self, repo: str, pr: int) -> List[str]:
        """Fetch the names of the labels currently on a PR."""
        url = f"{self.api_url}/repos/{repo}/issues/{pr}/labels?per_page=100"
        logger.debug(f"Request URL: {url}")
        labels_data = self._get_paginated_json(url, f"Failed to fetch labels for PR #{pr} in {repo}")
        labels = [label["name"] for label in labels_data]
        logger.debug(f"Existing labels on PR #{pr}: {labels}")
        return labels

    def pr_view(self, repo: str, head: str) -> Optional[int]:
        """Return the number of the first PR found for ``head`` in ``repo``, or None."""
        # This is similar to get_pr_by_head_branch but returns only the PR number directly
        url = f"{self.api_url}/repos/{repo}/pulls?head={repo.split('/')[0]}:{head}&per_page=100"
        logger.debug(f"Request URL: {url}")
        result = self._get_paginated_json(url, f"Failed to retrieve PR for head branch {head} in repo {repo}")
        return result[0]["number"] if result else None

    def get_pr_by_head_branch(self, repo: str, head: str) -> Optional[dict]:
        """Return the first open PR object for ``head`` in ``repo``, or None."""
        # This is similar to pr_view but returns the full PR object
        url = f"{self.api_url}/repos/{repo}/pulls?head={repo.split('/')[0]}:{head}&state=open&per_page=100"
        logger.debug(f"Request URL: {url}")
        data = self._get_paginated_json(url, f"Failed to get PRs for {repo} with head {head}")
        return data[0] if data else None

    def get_pr_by_number(self, repo: str, pr_number: int) -> Optional[dict]:
        """Fetch the PR object for a PR number; returns ``{}`` when the request failed."""
        url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
        logger.debug(f"Fetching PR #{pr_number} from {repo}")
        response = self._get_json(url, f"Failed to get PR #{pr_number} from {repo}")
        return response

    def pr_create(self, repo: str, base: str, head: str, title: str, body: str, dry_run: bool = False) -> None:
        """Create a new pull request, or only log the intent when ``dry_run`` is set."""
        url = f"{self.api_url}/repos/{repo}/pulls"
        payload = {
            "title": title,
            "body": body,
            "head": head,
            "base": base
        }
        logger.debug(f"Request URL: {url}")
        logger.debug(f"Request Payload: {payload}")
        if dry_run:
            logger.info(f"Dry run: The pull request would be created from {head} to {base} in {repo}")
            return
        created = self._request("POST", url, payload, f"Failed to create PR from {head} to {base} in {repo}")
        if created is None:
            # Do not claim success when every attempt failed.
            logger.error(f"PR from {head} to {base} in {repo} was not created.")
            return
        logger.info(f"Created PR from {head} to {base} in {repo}.")

    def close_pr_and_delete_branch(self, repo: str, pr_number: int, dry_run: bool = False) -> None:
        """Close a pull request and delete the associated branch using the GitHub API."""
        pr_url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
        logger.debug(f"Request URL: {pr_url}")
        pr_data = self._get_json(pr_url, f"Failed to fetch PR #{pr_number} in {repo}")
        head_ref = pr_data.get("head", {}).get("ref")
        if not head_ref:
            logger.error(f"Could not determine head branch for PR #{pr_number} in {repo}")
            return
        logger.debug(f"PR #{pr_number} head branch: {head_ref}")
        close_payload = {"state": "closed"}
        logger.debug(f"Request Payload: {close_payload}")
        if dry_run:
            logger.info(f"Dry run: The pull request #{pr_number} would be closed and the branch '{head_ref}' would be deleted in repo '{repo}'")
            return
        closed = self._request("PATCH", pr_url, close_payload, f"Failed to close PR #{pr_number} in {repo}") is not None
        branch_url = f"{self.api_url}/repos/{repo}/git/refs/heads/{head_ref}"
        logger.debug(f"Branch DELETE URL: {branch_url}")
        # The delete is attempted whether or not the close succeeded; only the
        # final log line depends on the two outcomes.
        deleted = self._request("DELETE", branch_url, None, f"Failed to delete branch '{head_ref}' for PR #{pr_number}") is not None
        if closed and deleted:
            logger.info(f"Closed pull request #{pr_number} and deleted the branch '{head_ref}' in {repo}.")
        else:
            logger.error(f"Cleanup of PR #{pr_number} in {repo} incomplete: closed={closed}, branch deleted={deleted}.")

    def sync_labels(self, target_repo: str, pr_number: int, labels: List[str], dry_run: bool = False) -> None:
        """Apply to a PR those of ``labels`` that are defined in the target repo."""
        url = f"{self.api_url}/repos/{target_repo}/labels?per_page=100"
        logger.debug(f"Request URL: {url}")
        target_repo_labels = {label["name"] for label in self._get_paginated_json(url, f"Failed to fetch labels for {target_repo}")}
        # Sorted so the payload and log lines are the same on every run.
        labels_to_apply = sorted(set(labels) & target_repo_labels)
        labels_for_logging = ",".join(labels_to_apply)
        if not labels_to_apply:
            logger.info(f"No valid labels to apply to PR #{pr_number} in {target_repo}.")
            return
        # note: using issues endpoint for labels as PRs are a subset of issues
        url = f"{self.api_url}/repos/{target_repo}/issues/{pr_number}/labels"
        payload = {"labels": labels_to_apply}
        logger.debug(f"Request URL: {url}")
        logger.debug(f"Request Payload: {payload}")
        if dry_run:
            logger.info(f"Dry run: Labels '{labels_for_logging}' would be applied to PR #{pr_number} in {target_repo}.")
            return
        applied = self._request("POST", url, payload, f"Failed to apply labels to PR #{pr_number} in {target_repo}")
        if applied is None:
            logger.error(f"Labels '{labels_for_logging}' were not applied to PR #{pr_number} in {target_repo}.")
            return
        logger.info(f"Applied labels '{labels_for_logging}' to PR #{pr_number} in {target_repo}.")

    def get_squash_merge_commit(self, repo: str, pr_number: int) -> Optional[str]:
        """Return the ``merge_commit_sha`` of a merged pull request, or None."""
        url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
        logger.debug(f"Request URL: {url}")
        data = self._get_json(url, f"Failed to fetch PR #{pr_number} from {repo}")
        if not data:
            logger.error(f"No data returned for PR #{pr_number}")
            return None
        if data.get("merged") and data.get("merge_commit_sha"):
            logger.debug(f"PR #{pr_number} merged commit: {data['merge_commit_sha']}")
            return data["merge_commit_sha"]
        logger.warning(f"PR #{pr_number} is not merged or missing merge commit SHA.")
        return None

    def get_user(self, username: str) -> tuple[str, str]:
        """Fetch the name and email of a GitHub user. Falls back to login and no-reply email."""
        url = f"{self.api_url}/users/{username}"
        logger.debug(f"Fetching user profile for @{username}")
        data = self._get_json(url, f"Failed to fetch user profile for @{username}")
        name = data.get("name") or username
        email = data.get("email")
        if not email:
            # Build a no-reply address, including the numeric id when known.
            user_id = data.get("id")
            if user_id:
                email = f"{user_id}+{username}@users.noreply.github.com"
            else:
                email = f"{username}@users.noreply.github.com"
        return name, email
+54
Dosyayı Görüntüle
@@ -0,0 +1,54 @@
import os
from pathlib import Path
# Determine monorepo root and output CODEOWNERS path
monorepo_root = Path(__file__).resolve().parents[2]
output_path = monorepo_root / ".github" / "CODEOWNERS"
merged_entries = []
# Walk top-level directories (excluding .github/.git/etc.) in name order so the
# generated file does not depend on filesystem listing order.
for subdir in sorted(monorepo_root.iterdir(), key=lambda entry: entry.name):
    if subdir.name.startswith(".") or not subdir.is_dir():
        continue
    # Look for CODEOWNERS in root or .github directory of the submodule
    candidates = [subdir / "CODEOWNERS", subdir / ".github" / "CODEOWNERS"]
    for codeowners_file in candidates:
        if not codeowners_file.is_file():
            continue
        with codeowners_file.open("r", encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                # Skip empty lines or comments
                if not stripped or stripped.startswith("#"):
                    continue
                original_path, *owner_list = stripped.split()
                owners = " ".join(owner_list)
                # Ensure prefixed path starts with a single slash
                # NOTE(review): un-anchored patterns such as "*" or "*.md"
                # become "/<subdir>/*", which may match fewer files than the
                # original rule did inside its own repository - confirm.
                prefixed_path = (
                    f"/{subdir.name.rstrip('/')}{original_path}"
                    if original_path.startswith("/")
                    else f"/{subdir.name}/{original_path}"
                )
                # rstrip avoids a trailing space for rules without owners.
                merged_entries.append(f"{prefixed_path} {owners}".rstrip())
# The entries are intentionally NOT sorted: in CODEOWNERS the last matching
# pattern wins, so each source file's rule order must be preserved.
# Write merged CODEOWNERS file
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as out:
    out.write("# Auto-generated CODEOWNERS file\n\n")
    if merged_entries:
        # End the file with a newline so the last rule is a complete line.
        out.write("\n".join(merged_entries) + "\n")
print(f"✅ Merged CODEOWNERS written to {output_path}")
+36
Dosyayı Görüntüle
@@ -0,0 +1,36 @@
import os
import configparser
from pathlib import Path
ROOT_DIR = Path(__file__).resolve().parents[2]  # Assuming script is in .github/scripts/
OUTPUT_FILE = ROOT_DIR / ".gitmodules"
# Sorted so the combined file is written in the same order on every run.
MODULE_FILES = sorted(
    list(ROOT_DIR.glob("*/.gitmodules")) + list(ROOT_DIR.glob("*/.github/.gitmodules"))
)
# interpolation=None: values such as URLs may contain '%', which the default
# interpolation would reject.
combined = configparser.ConfigParser(interpolation=None)
combined.optionxform = str  # Preserve case sensitivity
for module_file in MODULE_FILES:
    # Use the top-level directory under ROOT_DIR as the prefix. The immediate
    # parent would be ".github" for files found at */.github/.gitmodules.
    subdir = module_file.relative_to(ROOT_DIR).parts[0]
    local_config = configparser.ConfigParser(interpolation=None)
    local_config.optionxform = str
    local_config.read(module_file, encoding="utf-8")
    for section in local_config.sections():
        if section.startswith("submodule "):
            # Section headers look like: submodule "name"
            name = section.split('"')[1]
            new_name = f"{subdir}/{name}"
            new_section = f'submodule "{new_name}"'
            combined[new_section] = {}
            for key, value in local_config[section].items():
                if key == "path":
                    # Re-root the submodule path under its subdirectory.
                    value = f"{subdir}/{value}"
                combined[new_section][key] = value
# Write combined .gitmodules
with OUTPUT_FILE.open("w", encoding="utf-8") as f:
    for section in combined.sections():
        f.write(f"[{section}]\n")
        for key, value in combined[section].items():
            f.write(f"\t{key} = {value}\n")
        f.write("\n")
+115
Dosyayı Görüntüle
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
PR Category Label Script
--------------------
This script analyzes the file paths changed in a pull request and determines which
category labels should be added or removed based on the modified files.
It uses the GitHub REST API (via GitHubCLIClient) to fetch the changed files and the existing labels on the pull request.
Then, it computes the desired labels based on file paths, compares them to the existing labels,
and applies the necessary additions and removals unless in dry-run mode.
Arguments:
--repo : Full repository name (e.g., org/repo)
--pr : Pull request number
--dry-run : If set, will only log actions without making changes.
--debug : If set, enables detailed debug logging.
Outputs:
Writes 'label_add' and 'label_remove' keys to the GitHub Actions $GITHUB_OUTPUT file, which
the workflow reads to apply label changes using the GitHub CLI.
Example Usage:
To run in debug mode and perform a dry-run (no changes made):
python pr_auto_label.py --repo ROCm/rocm-libraries --pr <pr-number> --dry-run --debug
To run in debug mode and apply label changes:
python pr_auto_label.py --repo ROCm/rocm-libraries --pr <pr-number> --debug
"""
import argparse
import sys
import os
import logging
from pathlib import Path
from typing import List, Optional
from github_cli_client import GitHubCLIClient
logger = logging.getLogger(__name__)
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Options: required ``--repo`` and integer ``--pr``, plus the boolean flags
    ``--dry-run`` and ``--debug``.
    """
    cli = argparse.ArgumentParser(description="Apply labels based on PR's changed files.")
    # Required options first, then the boolean switches.
    cli.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
    cli.add_argument("--pr", required=True, type=int, help="Pull request number")
    switches = (
        ("--dry-run", "Print results without writing to GITHUB_OUTPUT."),
        ("--debug", "Enable debug logging"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    namespace = cli.parse_args(argv)
    return namespace
def compute_desired_labels(file_paths: list) -> set:
    """Derive category labels from the changed file paths.

    A path whose first component is ``projects`` yields ``project: <second>``,
    and one whose first component is ``shared`` yields ``shared: <second>``.
    Paths with fewer than two components or another first component are ignored.
    """
    label_prefix_by_root = {"projects": "project", "shared": "shared"}
    desired_labels = set()
    for changed_path in file_paths:
        segments = Path(changed_path).parts
        if len(segments) < 2:
            continue
        label_prefix = label_prefix_by_root.get(segments[0])
        if label_prefix is None:
            continue
        desired_labels.add(f"{label_prefix}: {segments[1]}")
    logger.debug(f"Desired labels based on changes: {desired_labels}")
    return desired_labels
def output_labels(existing_labels: List[str], desired_labels: List[str], dry_run: bool) -> None:
    """Output the labels to add/remove to GITHUB_OUTPUT or log them in dry-run mode.

    Labels to add are the desired labels not yet on the PR. Labels to remove are
    the existing ``project: `` / ``shared: `` labels that are no longer desired.
    Both lists are written comma-separated as ``label_add`` and ``label_remove``.

    Args:
        existing_labels: Labels currently on the pull request.
        desired_labels: Labels computed from the changed files. Any iterable is
            accepted, not only a set.
        dry_run: When True nothing is written.

    Exits with status 1 when ``GITHUB_OUTPUT`` is not set and ``dry_run`` is False.
    """
    auto_prefixes = ("project: ", "shared: ")
    # Coerce both inputs so a list, as the annotation suggests, works with the
    # set arithmetic below.
    existing = set(existing_labels)
    desired = set(desired_labels)
    existing_auto_labels = {label for label in existing if label.startswith(auto_prefixes)}
    to_add = sorted(desired - existing)
    to_remove = sorted(existing_auto_labels - desired)
    logger.debug(f"Labels to add: {to_add}")
    logger.debug(f"Labels to remove: {to_remove}")
    if dry_run:
        logger.info("Dry run enabled. Labels will not be applied.")
        return
    output_file = os.environ.get("GITHUB_OUTPUT")
    if not output_file:
        print("GITHUB_OUTPUT environment variable not set. Outputs cannot be written.")
        sys.exit(1)
    with open(output_file, 'a') as f:
        print(f"label_add={','.join(to_add)}", file=f)
        print(f"label_remove={','.join(to_remove)}", file=f)
    logger.info(f"Wrote to GITHUB_OUTPUT: add={','.join(to_add)}")
    logger.info(f"Wrote to GITHUB_OUTPUT: remove={','.join(to_remove)}")
def main(argv=None) -> None:
    """Main function to execute the PR auto label logic.

    Fetches the PR's changed files through the REST client. If that yields
    nothing, falls back to ``git diff`` against the base ref. Then computes the
    desired labels and emits the add/remove lists.

    Args:
        argv: Optional argument list forwarded to ``parse_arguments``.
    """
    # Local import: only the git fallback needs subprocess.
    import subprocess

    args = parse_arguments(argv)
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO
    )
    client = GitHubCLIClient()
    changed_files = client.get_changed_files(args.repo, int(args.pr))
    if not changed_files:
        logger.warning("REST API failed or returned no changed files. Falling back to Git CLI...")
        try:
            # Argument lists (no shell) keep the ref names and the '*' in the
            # refspec from being interpreted by a shell.
            fetch = subprocess.run(
                ["git", "fetch", "origin", "+refs/pull/*/merge:refs/remotes/origin/pr/*"],
                check=False,
            )
            if fetch.returncode != 0:
                logger.warning(f"git fetch exited with status {fetch.returncode}; continuing with local refs.")
            # Get merge commit ref for this PR
            base_ref = f"origin/{os.getenv('GITHUB_BASE_REF', 'main')}"
            head_ref = "HEAD"  # Assumes checkout to PR merge ref
            # check=True: a failed diff must not be mistaken for "no files
            # changed", which would remove every auto label.
            diff = subprocess.run(
                ["git", "diff", "--name-only", f"{base_ref}...{head_ref}"],
                check=True,
                capture_output=True,
                text=True,
            )
            changed_files = diff.stdout.strip().splitlines()
            logger.info(f"Fallback changed files: {changed_files}")
        except (OSError, subprocess.CalledProcessError) as e:
            logger.error(f"Git CLI fallback failed: {e}")
            sys.exit(1)
    existing_labels = client.get_existing_labels_on_pr(args.repo, int(args.pr))
    desired_labels = compute_desired_labels(changed_files)
    output_labels(existing_labels, desired_labels, args.dry_run)


if __name__ == "__main__":
    main()
+137
Dosyayı Görüntüle
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
PR Detect Changed Subtrees Script
---------------------------------
This script analyzes a pull request's changed files and determines which subtrees
(defined in .github/repos-config.json by category/name) were affected.
Steps:
1. Fetch the changed files in the PR using the GitHub API.
2. Load the subtree mapping from repos-config.json.
3. Match changed paths against known category/name prefixes.
4. Emit a new-line separated list of changed subtrees to GITHUB_OUTPUT as 'subtrees'.
Arguments:
--repo : Full repository name (e.g., org/repo)
--pr : Pull request number
--config : OPTIONAL, path to the repos-config.json file.
--require-auto-pull : If set, only include entries with auto_subtree_pull=true.
--require-auto-push : If set, only include entries with auto_subtree_push=true.
--dry-run : If set, will only log actions without making changes.
--debug : If set, enables detailed debug logging.
Outputs:
Writes 'subtrees' key to the GitHub Actions $GITHUB_OUTPUT file, which
the workflow reads to pass paths to the checkout stages.
The output is a new-line separated list of subtrees in `category/name` format.
Example Usage:
To run in auto-push situations in debug mode and perform a dry-run (no changes made):
python pr_detect_changed_subtrees.py --repo ROCm/rocm-libraries --pr 123 --require-auto-push --debug --dry-run
"""
import argparse
import sys
import os
import logging
from typing import List, Optional, Set
from github_cli_client import GitHubCLIClient
from repo_config_model import RepoEntry
from config_loader import load_repo_config
logger = logging.getLogger(__name__)
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Options: required ``--repo`` and integer ``--pr``, optional ``--config``
    (default ``.github/repos-config.json``), and the boolean flags
    ``--require-auto-pull``, ``--require-auto-push``, ``--dry-run``, ``--debug``.
    """
    cli = argparse.ArgumentParser(description="Detect changed subtrees in a PR.")
    cli.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
    cli.add_argument("--pr", required=True, type=int, help="Pull request number")
    cli.add_argument(
        "--config",
        required=False,
        default=".github/repos-config.json",
        help="Path to the repos-config.json file",
    )
    switches = (
        ("--require-auto-pull", "Only include entries with auto_subtree_pull=true"),
        ("--require-auto-push", "Only include entries with auto_subtree_push=true"),
        ("--dry-run", "Print results without writing to GITHUB_OUTPUT."),
        ("--debug", "Enable debug logging"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    namespace = cli.parse_args(argv)
    return namespace
def get_valid_prefixes(config: List[RepoEntry], require_auto_pull: bool = False, require_auto_push: bool = False) -> Set[str]:
    """Extract valid subtree prefixes from the configuration based on filters.

    This single definition replaces two same-named functions. The earlier,
    filter-less one was shadowed by the later one and never callable.

    Args:
        config: Repository entries exposing ``category`` and ``name``.
        require_auto_pull: When True, skip entries whose ``auto_subtree_pull``
            attribute is missing or falsy.
        require_auto_push: When True, skip entries whose ``auto_subtree_push``
            attribute is missing or falsy.

    Returns:
        set of ``category/name`` strings for the entries that pass the filters.
    """
    valid_prefixes = set()
    for entry in config:
        if require_auto_pull and not getattr(entry, "auto_subtree_pull", False):
            continue
        if require_auto_push and not getattr(entry, "auto_subtree_push", False):
            continue
        valid_prefixes.add(f"{entry.category}/{entry.name}")
    logger.debug("Valid subtrees:\n" + "\n".join(sorted(valid_prefixes)))
    return valid_prefixes
def find_matched_subtrees(changed_files: List[str], valid_prefixes: Set[str]) -> List[str]:
    """Return the sorted subtrees touched by the changed files.

    Each path with at least two ``/``-separated components contributes its
    first two components joined by ``/``. Only those present in
    ``valid_prefixes`` are returned; the rest are logged at debug level.
    """
    changed_subtrees = set()
    for changed_path in changed_files:
        segments = changed_path.split("/")
        if len(segments) < 2:
            continue
        changed_subtrees.add(f"{segments[0]}/{segments[1]}")
    matched = sorted(changed_subtrees & valid_prefixes)
    skipped = sorted(changed_subtrees - valid_prefixes)
    if skipped:
        logger.debug(f"Skipped subtrees: {skipped}")
    logger.debug(f"Matched subtrees: {matched}")
    return matched
def output_subtrees(matched_subtrees: List[str], dry_run: bool) -> None:
    """Emit the matched subtrees as the multi-line ``subtrees`` output.

    In dry-run mode the value is only logged. Otherwise it is appended to the
    file named by ``GITHUB_OUTPUT`` using the ``subtrees<<EOF ... EOF`` form.
    The process exits with status 1 if that variable is not set.
    """
    newline_separated = "\n".join(matched_subtrees)
    if dry_run:
        logger.info(f"[Dry-run] Would output:\n{newline_separated}")
        return
    output_file = os.environ.get('GITHUB_OUTPUT')
    if not output_file:
        logger.error("GITHUB_OUTPUT environment variable not set. Outputs cannot be written.")
        sys.exit(1)
    with open(output_file, 'a') as sink:
        print(f"subtrees<<EOF\n{newline_separated}\nEOF", file=sink)
    logger.info("Wrote matched subtrees to GITHUB_OUTPUT.")
def main(argv=None) -> None:
    """Determine which configured subtrees a pull request touches.

    The changed files come from the GitHub client; when that yields nothing,
    the list is recomputed locally with ``git diff`` against the PR base
    branch. The matched subtrees are then written via ``output_subtrees``.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Exits with status 1 if the git fallback cannot produce a diff.
    """
    # Imported here because only this fallback path needs it.
    import subprocess

    args = parse_arguments(argv)
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO
    )
    client = GitHubCLIClient()
    config = load_repo_config(args.config)
    changed_files = client.get_changed_files(args.repo, int(args.pr))
    if not changed_files:
        logger.warning("REST API failed or returned no changed files. Falling back to Git CLI...")
        base_ref = f"origin/{os.getenv('GITHUB_BASE_REF', 'main')}"
        head_ref = "HEAD"  # Assumes checkout to PR merge ref
        try:
            # Argument lists avoid the shell, so the refspec wildcards and the
            # environment-supplied base ref are passed to git verbatim.
            fetch = subprocess.run(
                ["git", "fetch", "origin", "+refs/pull/*/merge:refs/remotes/origin/pr/*"],
                check=False,
            )
            if fetch.returncode != 0:
                # The fetch stays best-effort, as before; the diff below decides.
                logger.warning(f"git fetch exited with status {fetch.returncode}; continuing with local refs")
            diff = subprocess.run(
                ["git", "diff", "--name-only", f"{base_ref}...{head_ref}"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except (OSError, subprocess.CalledProcessError) as e:
            # A failed diff must not be mistaken for "no files changed".
            detail = getattr(e, "stderr", None) or ""
            logger.error(f"Git CLI fallback failed: {e} {detail}".rstrip())
            sys.exit(1)
        changed_files = diff.stdout.strip().splitlines()
        logger.info(f"Fallback changed files: {changed_files}")
    valid_prefixes = get_valid_prefixes(config, args.require_auto_pull, args.require_auto_push)
    matched_subtrees = find_matched_subtrees(changed_files, valid_prefixes)
    output_subtrees(matched_subtrees, args.dry_run)
if __name__ == "__main__":
    # Script entry point: arguments are taken from sys.argv.
    main()
+298
Dosyayı Görüntüle
@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
Sync Patches to Subrepositories
-------------------------------
This script is part of the monorepo synchronization system. It runs after a monorepo pull request
is merged and applies relevant changes to the corresponding sub-repositories using Git patches.
- Uses the merge commit of the monorepo PR to extract subtree changes.
- Detects file-level changes including adds, deletes, and renames.
- Applies changes directly using file copy/move/delete as needed.
- Squashes all commits per subtree into one before pushing.
- Uses the repos-config.json file to map subtrees to sub-repos.
- Assumes this script is run from the root of the monorepo.
Arguments:
--repo : Full repository name (e.g., org/repo)
--pr : Pull request number
--subtrees : A newline-separated list of subtree paths in category/name format (e.g., projects/rocBLAS)
--config : OPTIONAL, path to the repos-config.json file
--dry-run : If set, will only log actions without making changes.
--debug : If set, enables detailed debug logging.
Example Usage:
python pr_merge_sync_patches.py --repo ROCm/rocm-libraries --pr 123 --subtrees "$(printf 'projects/rocBLAS\nprojects/hipBLASLt\nshared/rocSPARSE')" --dry-run --debug
"""
import argparse
import logging
import os
import re
import shutil
import subprocess
import tempfile
from typing import Optional, List, Tuple
from pathlib import Path
from github_cli_client import GitHubCLIClient
from config_loader import load_repo_config
from repo_config_model import RepoEntry
logger = logging.getLogger(__name__)
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for the patch sync and parse ``argv``.

    ``--repo``, ``--pr`` (integer) and ``--subtrees`` are mandatory;
    ``--config`` defaults to ``.github/repos-config.json``; ``--dry-run`` and
    ``--debug`` are boolean switches. Passing ``None`` parses ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Apply subtree patches to sub-repositories.")

    # Mandatory inputs.
    cli.add_argument("--repo", help="Full repository name (e.g., org/repo)", required=True)
    cli.add_argument("--pr", help="Pull request number", type=int, required=True)
    cli.add_argument(
        "--subtrees",
        help="Newline-separated list of changed subtrees (category/name)",
        required=True,
    )

    # Optional inputs and switches.
    cli.add_argument(
        "--config",
        help="Path to the repos-config.json file",
        default=".github/repos-config.json",
        required=False,
    )
    cli.add_argument("--dry-run", help="If set, only logs actions without making changes.", action="store_true")
    cli.add_argument("--debug", help="If set, enables detailed debug logging.", action="store_true")

    return cli.parse_args(argv)
def get_subtree_info(config: List[RepoEntry], subtrees: List[str]) -> List[RepoEntry]:
    """Select the config entries whose ``category/name`` was requested.

    Entries are returned in config order. Requested subtrees that have no
    matching entry are reported with a warning and otherwise ignored.
    """
    requested = set(subtrees)
    matched = []
    found = set()
    for entry in config:
        key = f"{entry.category}/{entry.name}"
        if key in requested:
            matched.append(entry)
            found.add(key)

    missing = requested - found
    if missing:
        logger.warning(f"Some subtrees not found in config: {', '.join(sorted(missing))}")
    return matched
def _run_git(args: List[str], cwd: Optional[Path] = None) -> str:
    """Run a git command and return its stripped stdout.

    Credentials embedded in URLs (``scheme://user:secret@host``) are masked
    before the command line or git's stderr reaches the log or the raised
    error, because callers pass token-bearing remote URLs through here.

    Args:
        args: Arguments placed after ``git``.
        cwd: Directory to run in; ``None`` uses the current directory.

    Returns:
        Standard output with surrounding whitespace removed.

    Raises:
        RuntimeError: If git exits with a non-zero status.
    """
    def _redact(text: str) -> str:
        # Keep the user part so the log still shows which remote was used.
        return re.sub(r"(://[^/\s:@]+):[^@\s/]+@", r"\1:***@", text)

    cmd = ["git"] + args
    printable_cmd = _redact(" ".join(cmd))
    logger.debug(f"Running git command: {printable_cmd} (cwd={cwd})")
    result = subprocess.run(
        cmd,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        message = f"Git command failed: {printable_cmd}\n{_redact(result.stderr)}"
        logger.error(message)
        raise RuntimeError(message)
    return result.stdout.strip()
def _clone_subrepo(repo_url: str, branch: str, destination: Path) -> None:
    """Clone one branch of ``github.com/<repo_url>`` into ``destination``.

    Only the requested branch is fetched (``--single-branch``).
    """
    clone_source = f"https://github.com/{repo_url}"
    clone_args = ["clone", "--branch", branch, "--single-branch", clone_source, str(destination)]
    _run_git(clone_args)
    logger.debug(f"Cloned {repo_url} into {destination}")
def _configure_git_user(repo_path: Path) -> None:
    """Set the bot's ``user.name`` and ``user.email`` in the repository at ``repo_path``."""
    identity = (
        ("user.name", "assistant-librarian[bot]"),
        ("user.email", "assistant-librarian[bot]@users.noreply.github.com"),
    )
    for setting, value in identity:
        _run_git(["config", setting, value], cwd=repo_path)
def _apply_patch(repo_path: Path, patch_path: Path, rel_file_path: Path, monorepo_path: Path, prefix: str) -> None:
    """Apply one file-level patch, falling back to copying the whole file.

    ``git am`` is tried first. If it fails, any am session it left behind is
    aborted so later ``git am`` calls in the same clone are not blocked, and
    the file is copied from the monorepo checkout and staged instead.

    Args:
        repo_path: Root of the sub-repository clone.
        patch_path: Patch file to apply.
        rel_file_path: Path of the patched file relative to the sub-repository.
        monorepo_path: Root of the monorepo checkout.
        prefix: Subtree prefix (``category/name``) inside the monorepo.

    Raises:
        RuntimeError: If the patch fails and the monorepo file does not exist,
            or if staging the copied file fails.
    """
    try:
        _run_git(["am", str(patch_path)], cwd=repo_path)
    except RuntimeError:
        logger.warning(f"Patch {patch_path.name} failed to apply; falling back to full file copy")
    else:
        logger.info(f"Applied patch {patch_path.name} successfully")
        return

    # git am keeps its in-progress state in .git/rebase-apply. It is absent
    # when git am refused to start, so only abort when there is a session.
    if (repo_path / ".git" / "rebase-apply").exists():
        try:
            _run_git(["am", "--abort"], cwd=repo_path)
        except RuntimeError:
            logger.warning("Could not abort the failed git am session; continuing with file copy")

    # Construct source and destination
    monorepo_file = monorepo_path / prefix / rel_file_path
    subrepo_file = repo_path / rel_file_path
    if not monorepo_file.exists():
        raise RuntimeError(f"Fallback failed: {monorepo_file} does not exist")
    subrepo_file.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(monorepo_file, subrepo_file)
    _run_git(["add", str(rel_file_path)], cwd=repo_path)
    logger.info(f"Copied {monorepo_file} -> {subrepo_file}")
def _set_authenticated_remote(repo_path: Path, repo_url: str) -> None:
    """Point ``origin`` at a URL that embeds the token from ``GH_TOKEN``.

    Raises:
        RuntimeError: If ``GH_TOKEN`` is unset or empty.
    """
    token = os.environ.get("GH_TOKEN")
    if token:
        authenticated_url = f"https://x-access-token:{token}@github.com/{repo_url}.git"
        _run_git(["remote", "set-url", "origin", authenticated_url], cwd=repo_path)
        return
    raise RuntimeError("GH_TOKEN environment variable is not set")
def _push_changes(repo_path: Path, branch: str) -> None:
    """Push ``branch`` from the repository at ``repo_path`` to ``origin``."""
    push_args = ["push", "origin", branch]
    _run_git(push_args, cwd=repo_path)
    logger.debug(f"Pushed changes from {repo_path} to origin")
def generate_file_level_patches(prefix: str, merge_sha: str, output_dir: Path) -> tuple[list[str], list[str], list[tuple[str, str]], list[str], list[Path]]:
    """Classify the files a merge commit changed under ``prefix`` and write one patch per modified file.

    Args:
        prefix: Subtree prefix (``category/name``) to restrict the diff to.
        merge_sha: Commit whose changes are inspected.
        output_dir: Directory that receives the generated ``.patch`` files.

    Returns:
        ``(added, deleted, renamed, modified, patch_files)`` where ``renamed``
        holds ``(old, new)`` pairs and ``patch_files[i]`` is the patch for
        ``modified[i]``. All paths are monorepo-relative.

    Raises:
        RuntimeError: If a git command fails.
    """
    diff_output = _run_git([
        "diff", "--name-status", "-M", f"{merge_sha}^!", "--", prefix
    ])
    added_files = []
    deleted_files = []
    renamed_files = []
    modified_files = []
    patch_files = []
    for line in diff_output.splitlines():
        if not line.strip():
            continue
        parts = line.split('\t')
        status = parts[0]
        if status == 'A':
            added_files.append(parts[1])
        elif status == 'M':
            file_path = parts[1]
            # The running index keeps names unique: flattening "/" to "_"
            # alone maps e.g. "a/b_c" and "a_b/c" to the same file name, and
            # the second patch would overwrite the first.
            patch_name = f"{len(patch_files):04d}_{file_path.replace('/', '_')}.patch"
            patch_path = output_dir / patch_name
            _run_git([
                "format-patch",
                "-1", merge_sha,
                f"--relative={prefix}",
                "--output", str(patch_path),
                "--", file_path
            ])
            patch_files.append(patch_path)
            modified_files.append(file_path)
        elif status == 'D':
            deleted_files.append(parts[1])
        elif status.startswith('R'):
            renamed_files.append((parts[1], parts[2]))
        else:
            # Make unsupported statuses visible instead of dropping them.
            logger.warning(f"Unhandled change status {status!r} for {parts[1:]} under {prefix}; not synced")
    logger.debug(f"Generated {len(patch_files)} modified file patches, "
                 f"{len(added_files)} added, {len(deleted_files)} deleted, "
                 f"{len(renamed_files)} renamed under {prefix}")
    return added_files, deleted_files, renamed_files, modified_files, patch_files
def resolve_patch_author(client: GitHubCLIClient, repo: str, pr: int) -> tuple[str, str]:
    """Pick the author to credit for the synced commit.

    A ``Originally authored by @<user>`` line in the PR body takes precedence
    over the PR's own author. The user's display name falls back to the
    username when the client returns none.

    Returns:
        ``(name, email)`` as reported by ``client.get_user``.
    """
    pr_data = client.get_pr_by_number(repo, pr)
    body = pr_data.get("body", "") or ""
    credited = re.search(r"Originally authored by @([A-Za-z0-9_-]+)", body)
    if credited is None:
        username = pr_data["user"]["login"]
        logger.debug(f"No explicit original author, using PR author: @{username}")
    else:
        username = credited.group(1)
        logger.debug(f"Found originally authored username in PR body: @{username}")
    name, email = client.get_user(username)
    display_name = name if name else username
    return display_name, email
def _strip_subtree_prefix(path: str, prefix: str) -> str:
    """Return ``path`` relative to ``prefix``, or unchanged if it is not under it."""
    return path[len(prefix) + 1:] if path.startswith(prefix + "/") else path


def _log_dry_run_plan(entry_name: str, prefix: str,
                      added_files: list[str], deleted_files: list[str],
                      renamed_files: list[tuple[str, str]], modified_files: list[str],
                      modified_patch_paths: list[Path]) -> None:
    """Log, without touching anything, what a sync of one subtree would do."""
    logger.info(f"[Dry-run] Sync for {entry_name}:")
    prefix_path = Path(prefix)
    if added_files:
        logger.info(" Added files:")
        for f in added_files:
            short_path = Path(f).relative_to(prefix_path)
            logger.info(f" {short_path}")
    if deleted_files:
        logger.info(" Deleted files:")
        for f in deleted_files:
            short_path = Path(f).relative_to(prefix_path)
            logger.info(f" {short_path}")
    if renamed_files:
        logger.info(" Renamed files:")
        for old, new in renamed_files:
            old_rel = Path(old).relative_to(prefix_path)
            new_rel = Path(new).relative_to(prefix_path)
            logger.info(f" {old_rel} -> {new_rel}")
    if modified_files:
        logger.info(" Modified files (via patch):")
        for f in modified_files:
            short_path = Path(f).relative_to(prefix_path)
            logger.info(f" {short_path}")
    if not (added_files or deleted_files or renamed_files or modified_files or modified_patch_paths):
        logger.info(" No changes detected.")


def apply_patches_and_squash(entry: RepoEntry, monorepo_url: str, monorepo_pr: int,
                             added_files: list[str], deleted_files: list[str], renamed_files: list[tuple[str, str]],
                             modified_files: list[str], modified_patch_paths: list[Path],
                             author_name: str, author_email: str, merge_sha: str, dry_run: bool = False) -> None:
    """Replay one subtree's changes onto its sub-repository as a single commit.

    The sub-repository is cloned into a temporary directory. Modified files
    are patched, deleted files removed, renamed files moved and refreshed,
    and added files copied and staged. Everything is then squashed into one
    commit on top of the original HEAD and pushed. Must be run from the
    monorepo root with the merge commit checked out, because file contents
    are copied from ``<prefix>/...`` relative to the current directory.

    Args:
        entry: Config entry describing the sub-repository.
        monorepo_url: Monorepo name used in the commit subject.
        monorepo_pr: Monorepo pull request number used in the commit subject.
        added_files, deleted_files, modified_files: Monorepo-relative paths.
        renamed_files: ``(old, new)`` monorepo-relative path pairs.
        modified_patch_paths: Patch files, parallel to ``modified_files``.
        author_name, author_email: Author recorded on the squashed commit.
        merge_sha: Monorepo merge commit whose message is reused.
        dry_run: When true, only log the plan; nothing is cloned or pushed.

    Raises:
        RuntimeError: If any git command fails.
    """
    prefix = f"{entry.category}/{entry.name}"
    if dry_run:
        _log_dry_run_plan(entry.name, prefix, added_files, deleted_files,
                          renamed_files, modified_files, modified_patch_paths)
        return

    with tempfile.TemporaryDirectory() as tmpdir:
        subrepo_path = Path(tmpdir) / entry.name
        _clone_subrepo(entry.url, entry.branch, subrepo_path)
        _configure_git_user(subrepo_path)
        # Get current HEAD commit (before applying patches)
        base_commit = _run_git(["rev-parse", "HEAD"], cwd=subrepo_path)

        # Patches go first: git am refuses to run on a dirty index, and the
        # delete/rename/add steps below all stage changes.
        for patch_path, full_file_path in zip(modified_patch_paths, modified_files):
            rel_path = _strip_subtree_prefix(full_file_path, prefix)
            logger.debug(f"Applying patch {patch_path.name} to {entry.name} at {rel_path}")
            _apply_patch(subrepo_path, patch_path, Path(rel_path), Path.cwd(), prefix)

        # Handle deletes
        for file_path in deleted_files:
            rel_path = _strip_subtree_prefix(file_path, prefix)
            _run_git(["rm", rel_path], cwd=subrepo_path)

        # Handle renames
        for old, new in renamed_files:
            old_rel = _strip_subtree_prefix(old, prefix)
            new_rel = _strip_subtree_prefix(new, prefix)
            moved_to = subrepo_path / new_rel
            # git mv fails when the destination directory does not exist yet.
            moved_to.parent.mkdir(parents=True, exist_ok=True)
            _run_git(["mv", old_rel, new_rel], cwd=subrepo_path)
            # A rename may also carry content edits; take the merged content
            # from the monorepo so those edits are not lost.
            shutil.copyfile(Path(prefix) / new_rel, moved_to)
            _run_git(["add", new_rel], cwd=subrepo_path)

        # Handle adds
        for file_path in added_files:
            rel_path = _strip_subtree_prefix(file_path, prefix)
            src = Path(prefix) / rel_path
            dst = subrepo_path / rel_path
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(src, dst)
            # New files are untracked until staged; without this the final
            # commit would silently leave them out.
            _run_git(["add", rel_path], cwd=subrepo_path)

        # Final squash: rewind HEAD but keep the index, then commit once.
        commit_msg = f"[rocm-libraries] {monorepo_url}#{monorepo_pr} (commit {merge_sha[:7]})\n\n" + \
            _run_git(["log", "-1", "--pretty=%B", merge_sha])
        _run_git(["reset", "--soft", base_commit], cwd=subrepo_path)
        _run_git(["commit", "-m", commit_msg, "--author", f"{author_name} <{author_email}>"], cwd=subrepo_path)
        _set_authenticated_remote(subrepo_path, entry.url)
        _push_changes(subrepo_path, entry.branch)
def main(argv: Optional[List[str]] = None) -> None:
    """Apply a merged monorepo PR's subtree changes to the matching sub-repositories.

    The PR's merge commit is checked out in the current repository. For each
    requested subtree found in the config, file-level patches are generated
    in a temporary directory and handed to ``apply_patches_and_squash``.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.
    """
    args = parse_arguments(argv)
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO
    )
    client = GitHubCLIClient()
    config = load_repo_config(args.config)
    subtrees = [line.strip() for line in args.subtrees.splitlines() if line.strip()]
    relevant_subtrees = get_subtree_info(config, subtrees)
    merge_sha = client.get_squash_merge_commit(args.repo, args.pr)
    logger.debug(f"Merge commit for PR #{args.pr} in {args.repo}: {merge_sha}")
    _run_git(["checkout", merge_sha])
    logger.info(f"Checked out merge commit {merge_sha} for patch operations")

    # The author depends only on the PR, so it is looked up at most once, and
    # only if some subtree actually has changes to apply.
    author = None
    for entry in relevant_subtrees:
        prefix = f"{entry.category}/{entry.name}"
        logger.debug(f"Processing subtree {prefix}")
        # The patch files must outlive the apply step, so it runs inside the
        # temporary directory's scope.
        with tempfile.TemporaryDirectory() as tmpdir:
            patch_dir = Path(tmpdir)
            # Generate patches and lists of adds/deletes/renames
            (added_files, deleted_files, renamed_files,
             modified_files, modified_patch_paths) = generate_file_level_patches(prefix, merge_sha, patch_dir)
            if not (added_files or deleted_files or renamed_files or modified_files or modified_patch_paths):
                logger.info(f"No changes to apply for {prefix}")
                continue
            if author is None:
                author = resolve_patch_author(client, args.repo, args.pr)
            author_name, author_email = author
            apply_patches_and_squash(entry, args.repo, args.pr,
                                     added_files, deleted_files, renamed_files, modified_files, modified_patch_paths,
                                     author_name, author_email, merge_sha,
                                     args.dry_run)
if __name__ == "__main__":
    # Script entry point: arguments are taken from sys.argv.
    main()
+51
Dosyayı Görüntüle
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
"""
Repository Config Model
------------------------
This module defines Pydantic data models for validating and parsing the repos-config.json file.
Structure of the expected JSON:
{
"repositories": [
{
"name": "rocblas",
"url": "ROCm/rocBLAS",
"branch": "develop",
"category": "projects",
"auto_subtree_pull": true,
"auto_subtree_push": true
},
...
]
}
"""
from typing import List
from pydantic import BaseModel
class RepoEntry(BaseModel):
    """
    One repository record from the repos-config.json file.

    Every field is declared without a default, so each one must be present
    in the record.

    Fields:
        name              : Name of the project matching packaging file names. Lower-cased and no underscores. (e.g., "rocblas")
        url               : Individual GitHub org plus repo names in matching case and punctuation. (e.g., "ROCm/rocBLAS")
        branch            : The base branch of the sub-repo to target (e.g., "develop").
        category          : Directory category in the monorepo (e.g., "projects" or "shared").
        auto_subtree_pull : Boolean flag; presumably opts the entry into automated subtree pulls (confirm against the workflows that read it).
        auto_subtree_push : Boolean flag; presumably opts the entry into automated subtree pushes (confirm against the workflows that read it).
    """

    # Identity and location of the sub-repository.
    name: str
    url: str
    branch: str
    category: str

    # Automation switches.
    auto_subtree_pull: bool
    auto_subtree_push: bool
class RepoConfig(BaseModel):
    """
    Top-level shape of the repos-config.json file.

    Fields:
        repositories : The list of RepoEntry records.
    """

    repositories: List[RepoEntry]
+73
Dosyayı Görüntüle
@@ -0,0 +1,73 @@
from pathlib import Path
import os
import sys
import unittest
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
import therock_configure_ci
class ConfigureCITest(unittest.TestCase):
    """Checks how many project entries ``retrieve_projects`` selects per event type."""

    def _select(self, **event_args):
        """Call ``retrieve_projects`` with the keyword arguments as its args dict."""
        return therock_configure_ci.retrieve_projects(event_args)

    def test_pull_request(self):
        # Two subtrees, newline-separated, expected to yield one project entry.
        selected = self._select(
            is_pull_request=True,
            input_subtrees="projects/rocprim\nprojects/hipcub",
        )
        self.assertEqual(len(selected), 1)

    def test_pull_request_empty(self):
        selected = self._select(is_pull_request=True, input_subtrees="")
        self.assertEqual(len(selected), 0)

    def test_workflow_dispatch(self):
        # Manual dispatch takes a whitespace-separated project list.
        selected = self._select(
            is_workflow_dispatch=True,
            input_projects="projects/rocprim projects/hipcub",
        )
        self.assertEqual(len(selected), 1)

    def test_workflow_dispatch_bad_input(self):
        # An unrecognised separator leaves one token that matches nothing.
        selected = self._select(
            is_workflow_dispatch=True,
            input_projects="projects/rocprim$$projects/hipcub",
        )
        self.assertEqual(len(selected), 0)

    def test_workflow_dispatch_all(self):
        selected = self._select(is_workflow_dispatch=True, input_projects="all")
        self.assertGreaterEqual(len(selected), 1)

    def test_workflow_dispatch_empty(self):
        selected = self._select(is_workflow_dispatch=True, input_projects="")
        self.assertEqual(len(selected), 0)

    def test_is_push(self):
        selected = self._select(is_push=True)
        self.assertGreaterEqual(len(selected), 1)
if __name__ == "__main__":
    # Run the test cases when this file is executed directly.
    unittest.main()
+87
Dosyayı Görüntüle
@@ -0,0 +1,87 @@
"""
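This script determines which build flags and tests to run based on SUBTREES
(pull_request events) or PROJECTS (workflow_dispatch events).
Required environment variables:
- SUBTREES
Other environment variables read:
- GITHUB_EVENT_NAME
- PROJECTS
- IS_FORKED_PR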
"""
import json
import logging
from therock_matrix import subtree_to_project_map, project_map
from typing import Mapping
import os
logging.basicConfig(level=logging.INFO)
def set_github_output(d: Mapping[str, str]):
    """Append each ``key=value`` pair of ``d`` to the GITHUB_OUTPUT file.

    If the GITHUB_OUTPUT environment variable is unset or empty, a warning is
    logged and nothing is written.
    See https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
    """
    logging.info(f"Setting github output:\n{d}")
    step_output_file = os.environ.get("GITHUB_OUTPUT", "")
    if step_output_file:
        rendered = "".join(f"{key}={value}\n" for key, value in d.items())
        with open(step_output_file, "a") as out:
            out.write(rendered)
        return
    logging.warning("Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
def retrieve_projects(args):
    """Resolve the triggering event into the list of project configs to build and test.

    Args:
        args: Mapping with the optional keys ``is_forked_pr``,
            ``is_pull_request``, ``is_workflow_dispatch``, ``is_push``,
            ``input_subtrees`` (newline-separated) and ``input_projects``
            (whitespace-separated, or the literal ``"all"``).

    Returns:
        The ``project_map`` entries for every project that owns at least one
        selected subtree, ordered by project key. Empty for forked PRs, for
        unrecognised events, and when nothing matches.
    """
    # TODO(geomin12): #590 Enable TheRock CI for forked PRs
    if args.get("is_forked_pr"):
        logging.info("Warning: not enabling any projects due to is_forked_pr. Builds/tests for forked PRs are disabled pending: https://github.com/ROCm/rocm-libraries/issues/590")
        return []

    # Start empty so an event that sets none of the flags selects nothing
    # instead of hitting an unbound local below.
    subtrees = []

    if args.get("is_pull_request"):
        # A missing key is treated like an empty list of subtrees.
        raw_subtrees = args.get("input_subtrees") or ""
        subtrees = [line.strip() for line in raw_subtrees.splitlines() if line.strip()]

    if args.get("is_workflow_dispatch"):
        raw_projects = args.get("input_projects") or ""
        if raw_projects == "all":
            subtrees = list(subtree_to_project_map.keys())
        else:
            subtrees = raw_projects.split()

    # If a push event to develop happens, we run tests on all subtrees
    if args.get("is_push"):
        subtrees = list(subtree_to_project_map.keys())

    # collect the associated subtree to project
    projects = {
        subtree_to_project_map[subtree]
        for subtree in subtrees
        if subtree in subtree_to_project_map
    }

    # retrieve the subtrees to checkout, cmake options to build, and projects to test;
    # sorted so the emitted list does not depend on set iteration order
    return [project_map[project] for project in sorted(projects) if project in project_map]
def run(args):
    """Select the projects for ``args`` and publish them as the JSON ``projects`` output."""
    outputs = {"projects": json.dumps(retrieve_projects(args))}
    set_github_output(outputs)
if __name__ == "__main__":
    # Translate the workflow environment into the args mapping run() expects.
    event_name = os.getenv("GITHUB_EVENT_NAME")
    args = {
        "is_pull_request": event_name == "pull_request",
        "is_push": event_name == "push",
        "is_workflow_dispatch": event_name == "workflow_dispatch",
        "is_forked_pr": os.getenv("IS_FORKED_PR") == "true",
        "input_subtrees": os.getenv("SUBTREES", ""),
        "input_projects": os.getenv("PROJECTS", ""),
    }

    logging.info(f"Retrieved arguments {args}")

    run(args)
+23
Dosyayı Görüntüle
@@ -0,0 +1,23 @@
"""
These dictionaries map changed subtree directories to a project key, and each project key to its build flags, the projects to test, and the subtrees to check out
"""
# Maps a monorepo subtree path to the key of its entry in project_map.
subtree_to_project_map = {
    # "prim" group
    "projects/rocprim": "prim",
    "projects/rocthrust": "prim",
    "projects/hipcub": "prim",
    # "rand" group
    "projects/rocrand": "rand",
    "projects/hiprand": "rand",
}
# Per-project settings: the cmake options to build with, the projects to
# test, and the newline-separated subtrees to check out.
project_map = {
    "prim": {
        "cmake_options": "-DTHEROCK_ENABLE_PRIM=ON -DTHEROCK_ENABLE_ALL=OFF",
        "project_to_test": "rocprim, rocthrust, hipcub",
        "subtree_checkout": "projects/rocprim\nprojects/hipcub\nprojects/rocthrust",
    },
    "rand": {
        "cmake_options": "-DTHEROCK_ENABLE_RAND=ON -DTHEROCK_ENABLE_ALL=OFF",
        "project_to_test": "rocrand, hiprand",
        "subtree_checkout": "projects/rocrand\nprojects/hiprand",
    },
}