Reset state of repo with updated README
Bu işleme şunda yer alıyor:
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import requests
|
||||
|
||||
def get_existing_labels(repo, token):
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
labels = {}
|
||||
page = 1
|
||||
while True:
|
||||
url = f"https://api.github.com/repos/{repo}/labels?page={page}&per_page=100"
|
||||
resp = requests.get(url, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Failed to fetch existing labels: {resp.text}")
|
||||
data = resp.json()
|
||||
if not data:
|
||||
break
|
||||
for label in data:
|
||||
labels[label["name"]] = {
|
||||
"color": label["color"],
|
||||
"description": label.get("description", "")
|
||||
}
|
||||
page += 1
|
||||
return labels
|
||||
|
||||
def create_or_update_label(repo, token, label, existing):
|
||||
headers = {
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json"
|
||||
}
|
||||
|
||||
if label["name"] not in existing:
|
||||
# Create label
|
||||
print(f"Creating label: {label['name']}")
|
||||
url = f"https://api.github.com/repos/{repo}/labels"
|
||||
resp = requests.post(url, json=label, headers=headers)
|
||||
else:
|
||||
# Update if different
|
||||
current = existing[label["name"]]
|
||||
if (label["color"].lower() != current["color"].lower() or
|
||||
label.get("description", "") != current.get("description", "")):
|
||||
print(f"Updating label: {label['name']}")
|
||||
url = f"https://api.github.com/repos/{repo}/labels/{label['name']}"
|
||||
resp = requests.patch(url, json=label, headers=headers)
|
||||
else:
|
||||
print(f"Label '{label['name']}' already up to date. Skipping.")
|
||||
return
|
||||
|
||||
if not resp.ok:
|
||||
print(f"Failed to apply label {label['name']}: {resp.status_code} {resp.text}")
|
||||
|
||||
def main(label_file):
|
||||
token = os.environ["GH_TOKEN"]
|
||||
repo = os.environ["GITHUB_REPO"]
|
||||
existing = get_existing_labels(repo, token)
|
||||
|
||||
with open(label_file, "r") as f:
|
||||
labels = yaml.safe_load(f)
|
||||
|
||||
for label in labels:
|
||||
create_or_update_label(repo, token, label, existing)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1])
|
||||
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Azure Pipeline Resolver Script
|
||||
------------------------------
|
||||
This script determines which Azure pipelines to run based on changed subtrees.
|
||||
Using a predefined dependency map, the script resolves which projects need to be processed,
|
||||
skipping those that will be covered by their dependencies.
|
||||
|
||||
Steps:
|
||||
1. Load a list of changed projects from a file.
|
||||
2. Consult a dependency map to determine transitive and direct dependencies.
|
||||
3. Identify projects that should be processed, excluding those handled by dependencies.
|
||||
4. Output the list of projects to be run, along with their Azure pipeline IDs.
|
||||
|
||||
Arguments:
|
||||
--subtree-file : Path to the file containing a newline-separated list of changed subtrees.
|
||||
|
||||
Outputs:
|
||||
Prints a newline-separated list of "project_name=definition_id" for the projects that need
|
||||
to be processed, where `definition_id` is the Azure pipeline ID associated with the project.
|
||||
|
||||
Example Usage:
|
||||
To determine which pipelines to run given the changed subtrees listed in a file:
|
||||
python azure_pipeline_resolver.py --subtree-file changed_subtrees.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Given a list of changed subtrees, determine which Azure pipelines to run.")
|
||||
parser.add_argument("--subtree-file", required=True,
|
||||
help="Path to the file containing changed subtrees")
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
def read_file_into_set(file_path):
|
||||
"""Reads the project names from the file into a set."""
|
||||
with open(file_path, 'r') as file:
|
||||
return {line.strip() for line in file}
|
||||
|
||||
|
||||
def resolve_dependencies(projects, dependencies):
|
||||
"""Resolves projects to be run by checking all levels of dependencies."""
|
||||
def has_dependency(project, projects_set):
|
||||
"""Recursively checks if a project has any dependencies in the projects_set."""
|
||||
if project not in dependencies:
|
||||
return False
|
||||
for dependency in dependencies[project]:
|
||||
if dependency in projects_set or has_dependency(dependency, projects_set):
|
||||
return True
|
||||
return False
|
||||
|
||||
projects_to_run = set(projects)
|
||||
|
||||
for project in projects:
|
||||
if has_dependency(project, projects_to_run):
|
||||
projects_to_run.discard(project)
|
||||
|
||||
return projects_to_run
|
||||
|
||||
|
||||
def main(argv=None) -> None:
|
||||
"""Main function to process the projects and output those to be run."""
|
||||
# Mathlib build+test dependency tree as defined in Azure CI and TheRock
|
||||
math_dependencies = {
|
||||
"shared/tensile": {},
|
||||
"projects/rocrand": {},
|
||||
"projects/hiprand": {"projects/rocrand"},
|
||||
"projects/rocfft": {"projects/hiprand"},
|
||||
"projects/hipfft": {"projects/rocfft"},
|
||||
"projects/rocprim": {},
|
||||
"projects/hipcub": {"projects/rocprim"},
|
||||
"projects/rocthrust": {"projects/rocprim"},
|
||||
"projects/hipblas-common": {},
|
||||
"projects/hipblaslt": {"projects/hipblas-common"},
|
||||
"projects/rocblas": {"projects/hipblaslt"},
|
||||
"projects/rocsolver": {"projects/rocprim", "projects/rocblas"},
|
||||
"projects/rocsparse": {"projects/rocprim", "projects/rocblas"},
|
||||
"projects/hipblas": {"projects/rocsolver"},
|
||||
"projects/hipsolver": {"projects/rocsolver", "projects/rocsparse"},
|
||||
"projects/hipsparse": {"projects/rocsparse"},
|
||||
"projects/hipsparselt": {"projects/hipsparse"},
|
||||
"projects/miopen": {"projects/rocrand", "projects/hipblas"}
|
||||
}
|
||||
# Azure pipeline IDs for each project, to be populated as projects are enabled
|
||||
definition_ids = {
|
||||
"shared/tensile": 305,
|
||||
"projects/rocrand": 274,
|
||||
"projects/hiprand": 275,
|
||||
"projects/rocfft": 282,
|
||||
"projects/hipfft": 283,
|
||||
"projects/rocprim": 273,
|
||||
"projects/hipcub": 277,
|
||||
"projects/rocthrust": 276,
|
||||
"projects/hipblas-common": 300,
|
||||
"projects/hipblaslt": 301,
|
||||
"projects/hipsparselt": 309,
|
||||
"projects/rocblas": 302,
|
||||
"projects/rocsolver": 303,
|
||||
}
|
||||
|
||||
args = parse_arguments(argv)
|
||||
projects = read_file_into_set(args.subtree_file)
|
||||
projects_to_run = resolve_dependencies(projects, math_dependencies)
|
||||
|
||||
for project in projects_to_run:
|
||||
if project in definition_ids:
|
||||
print(f"{project}={definition_ids[project]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,48 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
def get_labels(repo, token):
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
labels = []
|
||||
page = 1
|
||||
while True:
|
||||
url = f"https://api.github.com/repos/{repo}/labels?page={page}&per_page=100"
|
||||
resp = requests.get(url, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Failed to fetch labels from {repo}: {resp.text}")
|
||||
data = resp.json()
|
||||
if not data:
|
||||
break
|
||||
labels.extend(data)
|
||||
page += 1
|
||||
return labels
|
||||
|
||||
def main(file_path):
|
||||
with open(file_path, "r") as f:
|
||||
repos_data = json.load(f)["repositories"]
|
||||
|
||||
token = os.environ["GH_TOKEN"]
|
||||
all_labels = {}
|
||||
|
||||
for repo_entry in repos_data:
|
||||
repo_url = repo_entry["url"]
|
||||
print(f"Collecting labels from {repo_url}")
|
||||
for label in get_labels(repo_url, token):
|
||||
name = label["name"]
|
||||
if name not in all_labels:
|
||||
all_labels[name] = {
|
||||
"name": name,
|
||||
"color": label["color"],
|
||||
"description": label.get("description", "")
|
||||
}
|
||||
|
||||
sorted_labels = sorted(all_labels.values(), key=lambda l: l["name"].lower())
|
||||
os.makedirs(".github", exist_ok=True) # Ensure the .github directory exists
|
||||
with open(".github/labels.yml", "w") as out:
|
||||
yaml.dump(sorted_labels, out, sort_keys=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1])
|
||||
@@ -0,0 +1,18 @@
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
from typing import List
|
||||
from repo_config_model import RepoConfig, RepoEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_repo_config(config_path: str) -> List[RepoEntry]:
|
||||
"""Load and validate repository config from JSON using Pydantic."""
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
config = RepoConfig(**data)
|
||||
return config.repositories
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load or validate config file '{config_path}': {e}")
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
GitHub Client Utility
|
||||
---------------------
|
||||
This utility provides a GitHubClient class that wraps GitHub REST API operations
|
||||
used across automation scripts, such as retrieving pull request file changes and labels.
|
||||
|
||||
When doing manual testing, you can run the same REST API calls through curl in the terminal.
|
||||
These REST API URLs, without the authentication header, will be output by the debug logging.
|
||||
|
||||
This includes:
|
||||
- Fetching PR details
|
||||
- Creating PRs
|
||||
- Closing PRs
|
||||
|
||||
Requirements:
|
||||
- NOTE: GH_TOKEN environment variable hands authentication token to this script in a runner.
|
||||
- The token is created by the GitHub App and is passed to the script via the environment variable.
|
||||
|
||||
Manual curl testing:
|
||||
|
||||
To fetch PR details:
|
||||
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
|
||||
https://api.github.com/repos/OWNER/REPO/pulls/NUMBER
|
||||
|
||||
To list PRs by head branch:
|
||||
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
|
||||
"https://api.github.com/repos/OWNER/REPO/pulls?head=OWNER:branch-name&state=open"
|
||||
|
||||
To fetch changed files in a PR:
|
||||
curl -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
|
||||
https://api.github.com/repos/OWNER/REPO/pulls/NUMBER/files
|
||||
|
||||
To create a PR:
|
||||
curl -X POST -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
|
||||
https://api.github.com/repos/OWNER/REPO/pulls \
|
||||
-d '{"title":"Title","body":"Description","head":"branch-name","base":"main"}'
|
||||
|
||||
To apply labels:
|
||||
curl -X POST -H "Authorization: Bearer $GH_TOKEN" -H "Accept: application/vnd.github+json" \
|
||||
https://api.github.com/repos/OWNER/REPO/issues/NUMBER/labels \
|
||||
-d '{"labels": ["bug", "needs-review"]}'
|
||||
"""
|
||||
|
||||
import os
|
||||
import requests
|
||||
import time
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GitHubCLIClient:
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the GitHub API client using GitHub App authentication."""
|
||||
self.api_url = "https://api.github.com"
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
"Authorization": f"Bearer {self._get_token()}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
})
|
||||
|
||||
def _get_token(self) -> str:
|
||||
"""Helper method to retrieve the GitHub token from environment variable."""
|
||||
token = os.getenv("GH_TOKEN")
|
||||
if not token:
|
||||
raise EnvironmentError("GH_TOKEN environment variable is not set")
|
||||
return token
|
||||
|
||||
def _get_with_retries(self, url: str, error_msg: str, retries: int = 3,
|
||||
backoff: int = 2, timeout: int = 10) -> Optional[requests.Response]:
|
||||
"""Internal helper to retry a GET request with exponential backoff."""
|
||||
# no logging the actual request to avoid leaking sensitive information
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
response = self.session.get(url, timeout=timeout)
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
# for api rate limiting, we check the headers for remaining requests and reset time
|
||||
elif response.status_code == 403 and response.headers.get("X-RateLimit-Remaining") == "0":
|
||||
reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
|
||||
sleep_seconds = max(1, reset_time - int(time.time()) + 1)
|
||||
logger.warning(f"Rate limited. Sleeping for {sleep_seconds} seconds...")
|
||||
time.sleep(sleep_seconds)
|
||||
continue
|
||||
# other errors will use exponential backoff timeout
|
||||
elif response.status_code in {403, 429, 500, 502, 503, 504}:
|
||||
logger.warning(f"Retryable error {response.status_code} on attempt {attempt}.")
|
||||
else:
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
logger.warning(f"Request failed on attempt {attempt}: {e}")
|
||||
logger.error(f"{error_msg} for {url} (Attempt {attempt + 1}/{retries})")
|
||||
if attempt < retries - 1:
|
||||
time.sleep(backoff ** attempt) # Exponential backoff
|
||||
else:
|
||||
logger.error(f"Max retries reached for GET at {url}. Giving up.")
|
||||
return None
|
||||
|
||||
def _get_json(self, url: str, error_msg: str) -> dict:
|
||||
"""Helper method to perform a simple GET request and return a single JSON object."""
|
||||
response = self._get_with_retries(url, error_msg)
|
||||
return response.json() if response else {}
|
||||
|
||||
def _get_paginated_json(self, url: str, error_msg: str) -> List[dict]:
|
||||
"""Helper method to perform a sequence of GET requests with pagination."""
|
||||
results = []
|
||||
while url:
|
||||
response = self._get_with_retries(url, error_msg)
|
||||
if not response:
|
||||
return results
|
||||
results.extend(response.json())
|
||||
url = response.links.get("next", {}).get("url")
|
||||
return results
|
||||
|
||||
def _request_json(self, method: str, url: str, json: Optional[dict] = None,
|
||||
error_msg: str = "", retries: int = 3, backoff: int = 2) -> dict:
|
||||
"""Helper method to perform a request with retries and return JSON response."""
|
||||
# no logging the actual request to avoid leaking sensitive information
|
||||
for attempt in range(retries):
|
||||
response = self.session.request(method, url, json=json)
|
||||
if response.ok:
|
||||
if response.status_code == 204 or not response.text.strip():
|
||||
return {} # DELETE requests have no json content
|
||||
else:
|
||||
return response.json()
|
||||
else:
|
||||
# for api rate limiting, we check the headers for remaining requests and reset time
|
||||
if response.status_code == 403 and response.headers.get("X-RateLimit-Remaining") == "0":
|
||||
reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
|
||||
sleep_seconds = max(1, reset_time - int(time.time()) + 1)
|
||||
logger.warning(f"Rate limited. Sleeping for {sleep_seconds} seconds...")
|
||||
time.sleep(sleep_seconds)
|
||||
# other errors will use exponential backoff timeout
|
||||
else:
|
||||
logger.error(f"{error_msg} for method {method} at {url} (Attempt {attempt + 1}/{retries})")
|
||||
if attempt < retries - 1:
|
||||
time.sleep(backoff ** attempt) # Exponential backoff
|
||||
else:
|
||||
logger.error(f"Max retries reached for method {method} at {url}. Giving up.")
|
||||
return {}
|
||||
|
||||
def get_changed_files(self, repo: str, pr: int) -> List[str]:
|
||||
"""Fetch the changed files in a pull request using GitHub API."""
|
||||
url = f"{self.api_url}/repos/{repo}/pulls/{pr}/files?per_page=50"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
files_data = self._get_paginated_json(url, f"Failed to fetch files for PR #{pr} in {repo}")
|
||||
files = [file["filename"] for file in files_data]
|
||||
logger.debug(f"Changed files in PR #{pr}: {files}")
|
||||
return files
|
||||
|
||||
def get_defined_labels(self, repo: str) -> List[str]:
|
||||
"""Get all labels defined in the given repository."""
|
||||
url = f"{self.api_url}/repos/{repo}/labels?per_page=100"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
labels_data = self._get_paginated_json(url, f"Failed to fetch labels from {repo}")
|
||||
labels = [label["name"] for label in labels_data]
|
||||
logger.debug(f"Defined labels in {repo}: {labels}")
|
||||
return labels
|
||||
|
||||
def get_existing_labels_on_pr(self, repo: str, pr: int) -> List[str]:
|
||||
"""Fetch current labels on a PR."""
|
||||
url = f"{self.api_url}/repos/{repo}/issues/{pr}/labels?per_page=100"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
labels_data = self._get_paginated_json(url, f"Failed to fetch labels for PR #{pr} in {repo}")
|
||||
labels = [label["name"] for label in labels_data]
|
||||
logger.debug(f"Existing labels on PR #{pr}: {labels}")
|
||||
return labels
|
||||
|
||||
def pr_view(self, repo: str, head: str) -> Optional[int]:
|
||||
"""Check if a PR exists for the given repo and branch."""
|
||||
# This is similar to get_pr_by_head_branch but returns only the PR number directly
|
||||
url = f"{self.api_url}/repos/{repo}/pulls?head={repo.split('/')[0]}:{head}&per_page=100"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
result = self._get_paginated_json(url, f"Failed to retrieve PR for head branch {head} in repo {repo}")
|
||||
return result[0]["number"] if result else None
|
||||
|
||||
def get_pr_by_head_branch(self, repo: str, head: str) -> Optional[dict]:
|
||||
"""Fetch the PR object for a given head branch in a repository, if it exists."""
|
||||
# This is similar to pr_view but returns the full PR object
|
||||
url = f"{self.api_url}/repos/{repo}/pulls?head={repo.split('/')[0]}:{head}&state=open&per_page=100"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
data = self._get_paginated_json(url, f"Failed to get PRs for {repo} with head {head}")
|
||||
return data[0] if data else None
|
||||
|
||||
def get_pr_by_number(self, repo: str, pr_number: int) -> Optional[dict]:
|
||||
"""Fetch the PR object for a given PR number in a repository."""
|
||||
url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
|
||||
logger.debug(f"Fetching PR #{pr_number} from {repo}")
|
||||
response = self._get_json(url, f"Failed to get PR #{pr_number} from {repo}")
|
||||
return response
|
||||
|
||||
def pr_create(self, repo: str, base: str, head: str, title: str, body: str, dry_run: bool = False) -> None:
|
||||
"""Create a new pull request."""
|
||||
url = f"{self.api_url}/repos/{repo}/pulls"
|
||||
payload = {
|
||||
"title": title,
|
||||
"body": body,
|
||||
"head": head,
|
||||
"base": base
|
||||
}
|
||||
logger.debug(f"Request URL: {url}")
|
||||
logger.debug(f"Request Payload: {payload}")
|
||||
if dry_run:
|
||||
logger.info(f"Dry run: The pull request would be created from {head} to {base} in {repo}")
|
||||
return
|
||||
self._request_json("POST", url, payload, f"Failed to create PR from {head} to {base} in {repo}")
|
||||
logger.info(f"Created PR from {head} to {base} in {repo}.")
|
||||
|
||||
def close_pr_and_delete_branch(self, repo: str, pr_number: int, dry_run: bool = False) -> None:
|
||||
"""Close a pull request and delete the associated branch using the GitHub API."""
|
||||
pr_url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
|
||||
logger.debug(f"Request URL: {pr_url}")
|
||||
pr_data = self._get_json(pr_url, f"Failed to fetch PR #{pr_number} in {repo}")
|
||||
head_ref = pr_data.get("head", {}).get("ref")
|
||||
if not head_ref:
|
||||
logger.error(f"Could not determine head branch for PR #{pr_number} in {repo}")
|
||||
return
|
||||
logger.debug(f"PR #{pr_number} head branch: {head_ref}")
|
||||
close_payload = {"state": "closed"}
|
||||
logger.debug(f"Request Payload: {close_payload}")
|
||||
if dry_run:
|
||||
logger.info(f"Dry run: The pull request #{pr_number} would be closed and the branch '{head_ref}' would be deleted in repo '{repo}'")
|
||||
return
|
||||
self._request_json("PATCH", pr_url, close_payload, f"Failed to close PR #{pr_number} in {repo}")
|
||||
branch_url = f"{self.api_url}/repos/{repo}/git/refs/heads/{head_ref}"
|
||||
logger.debug(f"Branch DELETE URL: {branch_url}")
|
||||
self._request_json("DELETE", branch_url, None, f"Failed to delete branch '{head_ref}' for PR #{pr_number}")
|
||||
logger.info(f"Closed pull request #{pr_number} and deleted the branch '{head_ref}' in {repo}.")
|
||||
|
||||
def sync_labels(self, target_repo: str, pr_number: int, labels: List[str], dry_run: bool = False) -> None:
|
||||
"""Sync labels from the source repo to the target repo (only apply existing labels)."""
|
||||
url = f"{self.api_url}/repos/{target_repo}/labels?per_page=100"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
target_repo_labels = {label["name"] for label in self._get_paginated_json(url, f"Failed to fetch labels for {target_repo}")}
|
||||
labels_set = set(labels)
|
||||
labels_to_apply = labels_set & target_repo_labels
|
||||
labels_for_logging = ",".join(labels_to_apply)
|
||||
if labels_to_apply:
|
||||
# note: using issues endpoint for labels as PRs are a subset of issues
|
||||
url = f"{self.api_url}/repos/{target_repo}/issues/{pr_number}/labels"
|
||||
payload = {"labels": list(labels_to_apply)}
|
||||
logger.debug(f"Request URL: {url}")
|
||||
logger.debug(f"Request Payload: {payload}")
|
||||
if not dry_run:
|
||||
self._request_json("POST", url, payload, f"Failed to apply labels to PR #{pr_number} in {target_repo}")
|
||||
logger.info(f"Applied labels '{labels_for_logging}' to PR #{pr_number} in {target_repo}.")
|
||||
else:
|
||||
logger.info(f"Dry run: Labels '{labels_for_logging}' would be applied to PR #{pr_number} in {target_repo}.")
|
||||
else:
|
||||
logger.info(f"No valid labels to apply to PR #{pr_number} in {target_repo}.")
|
||||
|
||||
def get_squash_merge_commit(self, repo: str, pr_number: int) -> Optional[str]:
|
||||
"""Get the squash merge commit SHA of a merged pull request."""
|
||||
url = f"{self.api_url}/repos/{repo}/pulls/{pr_number}"
|
||||
logger.debug(f"Request URL: {url}")
|
||||
data = self._get_json(url, f"Failed to fetch PR #{pr_number} from {repo}")
|
||||
if not data:
|
||||
logger.error(f"No data returned for PR #{pr_number}")
|
||||
return None
|
||||
if data.get("merged") and data.get("merge_commit_sha"):
|
||||
logger.debug(f"PR #{pr_number} merged commit: {data['merge_commit_sha']}")
|
||||
return data["merge_commit_sha"]
|
||||
logger.warning(f"PR #{pr_number} is not merged or missing merge commit SHA.")
|
||||
return None
|
||||
|
||||
def get_user(self, username: str) -> tuple[str, str]:
|
||||
"""Fetch the name and email of a GitHub user. Falls back to login and no-reply email."""
|
||||
url = f"{self.api_url}/users/{username}"
|
||||
logger.debug(f"Fetching user profile for @{username}")
|
||||
data = self._get_json(url, f"Failed to fetch user profile for @{username}")
|
||||
name = data.get("name") or username
|
||||
email = data.get("email")
|
||||
if not email:
|
||||
user_id = data.get("id")
|
||||
if user_id:
|
||||
email = f"{user_id}+{username}@users.noreply.github.com"
|
||||
else:
|
||||
email = f"{username}@users.noreply.github.com"
|
||||
return name, email
|
||||
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Determine monorepo root and output CODEOWNERS path
|
||||
monorepo_root = Path(__file__).resolve().parents[2]
|
||||
output_path = monorepo_root / ".github" / "CODEOWNERS"
|
||||
|
||||
merged_entries = []
|
||||
|
||||
# Walk top-level directories (excluding .github/.git/etc.)
|
||||
for subdir in monorepo_root.iterdir():
|
||||
if subdir.name.startswith(".") or not subdir.is_dir():
|
||||
continue
|
||||
|
||||
# Look for CODEOWNERS in root or .github directory of the submodule
|
||||
candidates = [subdir / "CODEOWNERS", subdir / ".github" / "CODEOWNERS"]
|
||||
|
||||
for codeowners_file in candidates:
|
||||
if codeowners_file.is_file():
|
||||
with codeowners_file.open("r") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip empty lines or comments
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = stripped.split()
|
||||
if not parts:
|
||||
continue
|
||||
|
||||
original_path = parts[0]
|
||||
owners = " ".join(parts[1:])
|
||||
|
||||
# Ensure prefixed path starts with a single slash
|
||||
prefixed_path = (
|
||||
f"/{subdir.name.rstrip('/')}{original_path}"
|
||||
if original_path.startswith("/")
|
||||
else f"/{subdir.name}/{original_path}"
|
||||
)
|
||||
|
||||
merged_entries.append(f"{prefixed_path} {owners}")
|
||||
|
||||
# Sort for consistency
|
||||
merged_entries.sort()
|
||||
|
||||
# Write merged CODEOWNERS file
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with output_path.open("w") as out:
|
||||
out.write("# Auto-generated CODEOWNERS file\n\n")
|
||||
out.write("\n".join(merged_entries))
|
||||
|
||||
print(f"✅ Merged CODEOWNERS written to {output_path}")
|
||||
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[2] # Assuming script is in .github/scripts/
|
||||
OUTPUT_FILE = ROOT_DIR / ".gitmodules"
|
||||
MODULE_FILES = list(ROOT_DIR.glob("*/.gitmodules")) + list(ROOT_DIR.glob("*/.github/.gitmodules"))
|
||||
|
||||
combined = configparser.ConfigParser()
|
||||
combined.optionxform = str # Preserve case sensitivity
|
||||
|
||||
for module_file in MODULE_FILES:
|
||||
subdir = module_file.parent.name
|
||||
local_config = configparser.ConfigParser()
|
||||
local_config.optionxform = str
|
||||
local_config.read(module_file)
|
||||
|
||||
for section in local_config.sections():
|
||||
if section.startswith("submodule "):
|
||||
name = section.split('"')[1]
|
||||
new_name = f"{subdir}/{name}"
|
||||
new_section = f'submodule "{new_name}"'
|
||||
|
||||
combined[new_section] = {}
|
||||
for key, value in local_config[section].items():
|
||||
if key == "path":
|
||||
value = f"{subdir}/{value}"
|
||||
combined[new_section][key] = value
|
||||
|
||||
# Write combined .gitmodules
|
||||
with OUTPUT_FILE.open("w") as f:
|
||||
for section in combined.sections():
|
||||
f.write(f"[{section}]\n")
|
||||
for key, value in combined[section].items():
|
||||
f.write(f"\t{key} = {value}\n")
|
||||
f.write("\n")
|
||||
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
PR Category Label Script
|
||||
--------------------
|
||||
This script analyzes the file paths changed in a pull request and determines which
|
||||
category labels should be added or removed based on the modified files.
|
||||
|
||||
It uses GitHub's cli to fetch the changed files and the existing labels on the pull request.
|
||||
Then, it computes the desired labels based on file paths, compares them to the existing labels,
|
||||
and applies the necessary additions and removals unless in dry-run mode.
|
||||
|
||||
Arguments:
|
||||
--repo : Full repository name (e.g., org/repo)
|
||||
--pr : Pull request number
|
||||
--dry-run : If set, will only log actions without making changes.
|
||||
--debug : If set, enables detailed debug logging.
|
||||
|
||||
Outputs:
|
||||
Writes 'add' and 'remove' keys to the GitHub Actions $GITHUB_OUTPUT file, which
|
||||
the workflow reads to apply label changes using the GitHub CLI.
|
||||
|
||||
Example Usage:
|
||||
To run in debug mode and perform a dry-run (no changes made):
|
||||
python pr_auto_label.py --repo ROCm/rocm-libraries --pr <pr-number> --dry-run --debug
|
||||
To run in debug mode and apply label changes:
|
||||
python pr_auto_label.py --repo ROCm/rocm-libraries --pr <pr-number> --debug
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from github_cli_client import GitHubCLIClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Apply labels based on PR's changed files.")
|
||||
parser.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
|
||||
parser.add_argument("--pr", required=True, type=int, help="Pull request number")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Print results without writing to GITHUB_OUTPUT.")
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||
return parser.parse_args(argv)
|
||||
|
||||
def compute_desired_labels(file_paths: list) -> set:
|
||||
"""Determine the desired labels based on the changed files."""
|
||||
desired_labels = set()
|
||||
for path in file_paths:
|
||||
parts = Path(path).parts
|
||||
if len(parts) >= 2:
|
||||
if parts[0] == "projects":
|
||||
desired_labels.add(f"project: {parts[1]}")
|
||||
elif parts[0] == "shared":
|
||||
desired_labels.add(f"shared: {parts[1]}")
|
||||
logger.debug(f"Desired labels based on changes: {desired_labels}")
|
||||
return desired_labels
|
||||
|
||||
def output_labels(existing_labels: List[str], desired_labels: List[str], dry_run: bool) -> None:
|
||||
"""Output the labels to add/remove to GITHUB_OUTPUT or log them in dry-run mode."""
|
||||
existing_auto_labels = {
|
||||
label for label in existing_labels
|
||||
if label.startswith("project: ") or label.startswith("shared: ")
|
||||
}
|
||||
to_add = sorted(desired_labels - set(existing_labels))
|
||||
to_remove = sorted(existing_auto_labels - desired_labels)
|
||||
logger.debug(f"Labels to add: {to_add}")
|
||||
logger.debug(f"Labels to remove: {to_remove}")
|
||||
if dry_run:
|
||||
logger.info("Dry run enabled. Labels will not be applied.")
|
||||
else:
|
||||
output_file = os.environ.get("GITHUB_OUTPUT")
|
||||
if output_file:
|
||||
with open(output_file, 'a') as f:
|
||||
print(f"label_add={','.join(to_add)}", file=f)
|
||||
print(f"label_remove={','.join(to_remove)}", file=f)
|
||||
logger.info(f"Wrote to GITHUB_OUTPUT: add={','.join(to_add)}")
|
||||
logger.info(f"Wrote to GITHUB_OUTPUT: remove={','.join(to_remove)}")
|
||||
else:
|
||||
print("GITHUB_OUTPUT environment variable not set. Outputs cannot be written.")
|
||||
sys.exit(1)
|
||||
|
||||
def main(argv=None) -> None:
|
||||
"""Main function to execute the PR auto label logic."""
|
||||
args = parse_arguments(argv)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.debug else logging.INFO
|
||||
)
|
||||
client = GitHubCLIClient()
|
||||
changed_files = [file for file in client.get_changed_files(args.repo, int(args.pr))]
|
||||
|
||||
if not changed_files:
|
||||
logger.warning("REST API failed or returned no changed files. Falling back to Git CLI...")
|
||||
try:
|
||||
# Ensure fetch is safe
|
||||
os.system("git fetch origin +refs/pull/*/merge:refs/remotes/origin/pr/*")
|
||||
# Get merge commit ref for this PR
|
||||
base_ref = f"origin/{os.getenv('GITHUB_BASE_REF', 'main')}"
|
||||
head_ref = "HEAD" # Assumes checkout to PR merge ref
|
||||
result = os.popen(f"git diff --name-only {base_ref}...{head_ref}").read()
|
||||
changed_files = result.strip().splitlines()
|
||||
logger.info(f"Fallback changed files: {changed_files}")
|
||||
except Exception as e:
|
||||
logger.error(f"Git CLI fallback failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
existing_labels = client.get_existing_labels_on_pr(args.repo, int(args.pr))
|
||||
desired_labels = compute_desired_labels(changed_files)
|
||||
output_labels(existing_labels, desired_labels, args.dry_run)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
PR Detect Changed Subtrees Script
|
||||
---------------------------------
|
||||
This script analyzes a pull request's changed files and determines which subtrees
|
||||
(defined in .github/repos-config.json by category/name) were affected.
|
||||
|
||||
Steps:
|
||||
1. Fetch the changed files in the PR using the GitHub API.
|
||||
2. Load the subtree mapping from repos-config.json.
|
||||
3. Match changed paths against known category/name prefixes.
|
||||
4. Emit a new-line separated list of changed subtrees to GITHUB_OUTPUT as 'subtrees'.
|
||||
|
||||
Arguments:
|
||||
--repo : Full repository name (e.g., org/repo)
|
||||
--pr : Pull request number
|
||||
--config : OPTIONAL, path to the repos-config.json file.
|
||||
--require-auto-pull : If set, only include entries with auto_subtree_pull=true.
|
||||
--require-auto-push : If set, only include entries with auto_subtree_push=true.
|
||||
--dry-run : If set, will only log actions without making changes.
|
||||
--debug : If set, enables detailed debug logging.
|
||||
|
||||
Outputs:
|
||||
Writes 'subtrees' key to the GitHub Actions $GITHUB_OUTPUT file, which
|
||||
the workflow reads to pass paths to the checkout stages.
|
||||
The output is a new-line separated list of subtrees in `category/name` format.
|
||||
|
||||
Example Usage:
|
||||
To run in auto-push situations in debug mode and perform a dry-run (no changes made):
|
||||
python pr_detect_changed_subtrees.py --repo ROCm/rocm-libraries --pr 123 --require-auto-push --debug --dry-run
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
from github_cli_client import GitHubCLIClient
|
||||
from repo_config_model import RepoEntry
|
||||
from config_loader import load_repo_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Detect changed subtrees in a PR.")
|
||||
parser.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
|
||||
parser.add_argument("--pr", required=True, type=int, help="Pull request number")
|
||||
parser.add_argument("--config", required=False, default=".github/repos-config.json", help="Path to the repos-config.json file")
|
||||
parser.add_argument("--require-auto-pull", action="store_true", help="Only include entries with auto_subtree_pull=true")
|
||||
parser.add_argument("--require-auto-push", action="store_true", help="Only include entries with auto_subtree_push=true")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Print results without writing to GITHUB_OUTPUT.")
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||
return parser.parse_args(argv)
|
||||
|
||||
def get_valid_prefixes(config: List[RepoEntry]) -> Set[str]:
|
||||
"""Extract valid subtree prefixes from the configuration."""
|
||||
valid_prefixes = {
|
||||
f"{entry.category}/{entry.name}"
|
||||
for entry in config
|
||||
}
|
||||
logger.debug("Valid subtrees:\n" + "\n".join(sorted(valid_prefixes)))
|
||||
return valid_prefixes
|
||||
|
||||
def get_valid_prefixes(config: List[RepoEntry], require_auto_pull: bool = False, require_auto_push: bool = False) -> Set[str]:
|
||||
"""Extract valid subtree prefixes from the configuration based on filters."""
|
||||
valid_prefixes = set()
|
||||
for entry in config:
|
||||
if require_auto_pull and not getattr(entry, "auto_subtree_pull", False):
|
||||
continue
|
||||
if require_auto_push and not getattr(entry, "auto_subtree_push", False):
|
||||
continue
|
||||
valid_prefixes.add(f"{entry.category}/{entry.name}")
|
||||
logger.debug("Valid subtrees:\n" + "\n".join(sorted(valid_prefixes)))
|
||||
return valid_prefixes
|
||||
|
||||
def find_matched_subtrees(changed_files: List[str], valid_prefixes: Set[str]) -> List[str]:
|
||||
"""Find subtrees that match the changed files."""
|
||||
changed_subtrees = {
|
||||
"/".join(path.split("/", 2)[:2])
|
||||
for path in changed_files
|
||||
if len(path.split("/")) >= 2
|
||||
}
|
||||
matched = sorted(changed_subtrees & valid_prefixes)
|
||||
skipped = sorted(changed_subtrees - valid_prefixes)
|
||||
if skipped:
|
||||
logger.debug(f"Skipped subtrees: {skipped}")
|
||||
logger.debug(f"Matched subtrees: {matched}")
|
||||
return matched
|
||||
|
||||
def output_subtrees(matched_subtrees: List[str], dry_run: bool) -> None:
|
||||
"""Output the matched subtrees to GITHUB_OUTPUT or log them in dry-run mode."""
|
||||
newline_separated = "\n".join(matched_subtrees)
|
||||
if dry_run:
|
||||
logger.info(f"[Dry-run] Would output:\n{newline_separated}")
|
||||
else:
|
||||
output_file = os.environ.get('GITHUB_OUTPUT')
|
||||
if output_file:
|
||||
with open(output_file, 'a') as f:
|
||||
print(f"subtrees<<EOF\n{newline_separated}\nEOF", file=f)
|
||||
logger.info("Wrote matched subtrees to GITHUB_OUTPUT.")
|
||||
else:
|
||||
logger.error("GITHUB_OUTPUT environment variable not set. Outputs cannot be written.")
|
||||
sys.exit(1)
|
||||
|
||||
def main(argv=None) -> None:
|
||||
"""Main function to determine changed subtrees in PR."""
|
||||
args = parse_arguments(argv)
|
||||
logging.basicConfig(
|
||||
level = logging.DEBUG if args.debug else logging.INFO
|
||||
)
|
||||
client = GitHubCLIClient()
|
||||
config = load_repo_config(args.config)
|
||||
changed_files = client.get_changed_files(args.repo, int(args.pr))
|
||||
|
||||
if not changed_files:
|
||||
logger.warning("REST API failed or returned no changed files. Falling back to Git CLI...")
|
||||
try:
|
||||
# Ensure fetch is safe
|
||||
os.system("git fetch origin +refs/pull/*/merge:refs/remotes/origin/pr/*")
|
||||
# Get merge commit ref for this PR
|
||||
base_ref = f"origin/{os.getenv('GITHUB_BASE_REF', 'main')}"
|
||||
head_ref = "HEAD" # Assumes checkout to PR merge ref
|
||||
result = os.popen(f"git diff --name-only {base_ref}...{head_ref}").read()
|
||||
changed_files = result.strip().splitlines()
|
||||
logger.info(f"Fallback changed files: {changed_files}")
|
||||
except Exception as e:
|
||||
logger.error(f"Git CLI fallback failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
valid_prefixes = get_valid_prefixes(config, args.require_auto_pull, args.require_auto_push)
|
||||
matched_subtrees = find_matched_subtrees(changed_files, valid_prefixes)
|
||||
output_subtrees(matched_subtrees, args.dry_run)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Sync Patches to Subrepositories
|
||||
-------------------------------
|
||||
|
||||
This script is part of the monorepo synchronization system. It runs after a monorepo pull request
|
||||
is merged and applies relevant changes to the corresponding sub-repositories using Git patches.
|
||||
|
||||
- Uses the merge commit of the monorepo PR to extract subtree changes.
|
||||
- Detects file-level changes including adds, deletes, and renames.
|
||||
- Applies changes directly using file copy/move/delete as needed.
|
||||
- Squashes all commits per subtree into one before pushing.
|
||||
- Uses the repos-config.json file to map subtrees to sub-repos.
|
||||
- Assumes this script is run from the root of the monorepo.
|
||||
|
||||
Arguments:
|
||||
--repo : Full repository name (e.g., org/repo)
|
||||
--pr : Pull request number
|
||||
--subtrees : A newline-separated list of subtree paths in category/name format (e.g., projects/rocBLAS)
|
||||
--config : OPTIONAL, path to the repos-config.json file
|
||||
--dry-run : If set, will only log actions without making changes.
|
||||
--debug : If set, enables detailed debug logging.
|
||||
|
||||
Example Usage:
|
||||
python pr_merge_sync_patches.py --repo ROCm/rocm-libraries --pr 123 --subtrees "$(printf 'projects/rocBLAS\nprojects/hipBLASLt\nshared/rocSPARSE')" --dry-run --debug
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Optional, List, Tuple
|
||||
from pathlib import Path
|
||||
from github_cli_client import GitHubCLIClient
|
||||
from config_loader import load_repo_config
|
||||
from repo_config_model import RepoEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Apply subtree patches to sub-repositories.")
|
||||
parser.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
|
||||
parser.add_argument("--pr", required=True, type=int, help="Pull request number")
|
||||
parser.add_argument("--subtrees", required=True, help="Newline-separated list of changed subtrees (category/name)")
|
||||
parser.add_argument("--config", required=False, default=".github/repos-config.json", help="Path to the repos-config.json file")
|
||||
parser.add_argument("--dry-run", action="store_true", help="If set, only logs actions without making changes.")
|
||||
parser.add_argument("--debug", action="store_true", help="If set, enables detailed debug logging.")
|
||||
return parser.parse_args(argv)
|
||||
|
||||
def get_subtree_info(config: List[RepoEntry], subtrees: List[str]) -> List[RepoEntry]:
|
||||
"""Return config entries matching the given subtrees in category/name format."""
|
||||
requested = set(subtrees)
|
||||
matched = [
|
||||
entry for entry in config
|
||||
if f"{entry.category}/{entry.name}" in requested
|
||||
]
|
||||
missing = requested - {f"{e.category}/{e.name}" for e in matched}
|
||||
if missing:
|
||||
logger.warning(f"Some subtrees not found in config: {', '.join(sorted(missing))}")
|
||||
return matched
|
||||
|
||||
def _run_git(args: List[str], cwd: Optional[Path] = None) -> str:
|
||||
"""Run a git command and return stdout."""
|
||||
cmd = ["git"] + args
|
||||
logger.debug(f"Running git command: {' '.join(cmd)} (cwd={cwd})")
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Git command failed: {' '.join(cmd)}\n{result.stderr}")
|
||||
raise RuntimeError(f"Git command failed: {' '.join(cmd)}\n{result.stderr}")
|
||||
return result.stdout.strip()
|
||||
|
||||
def _clone_subrepo(repo_url: str, branch: str, destination: Path) -> None:
|
||||
"""Clone a specific branch from the given GitHub repository into the destination path."""
|
||||
_run_git([
|
||||
"clone",
|
||||
"--branch", branch,
|
||||
"--single-branch",
|
||||
f"https://github.com/{repo_url}",
|
||||
str(destination)
|
||||
])
|
||||
logger.debug(f"Cloned {repo_url} into {destination}")
|
||||
|
||||
def _configure_git_user(repo_path: Path) -> None:
|
||||
"""Configure git user.name and user.email for the given repository directory."""
|
||||
_run_git(["config", "user.name", "assistant-librarian[bot]"], cwd=repo_path)
|
||||
_run_git(["config", "user.email", "assistant-librarian[bot]@users.noreply.github.com"], cwd=repo_path)
|
||||
|
||||
def _apply_patch(repo_path: Path, patch_path: Path, rel_file_path: Path, monorepo_path: Path, prefix: str) -> None:
|
||||
"""Try to apply a patch; if it fails, fallback to full file replacement."""
|
||||
try:
|
||||
_run_git(["am", str(patch_path)], cwd=repo_path)
|
||||
logger.info(f"Applied patch {patch_path.name} successfully")
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Patch {patch_path.name} failed to apply; falling back to full file copy")
|
||||
|
||||
# Construct source and destination
|
||||
monorepo_file = monorepo_path / prefix / rel_file_path
|
||||
subrepo_file = repo_path / rel_file_path
|
||||
subrepo_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not monorepo_file.exists():
|
||||
raise RuntimeError(f"Fallback failed: {monorepo_file} does not exist")
|
||||
|
||||
shutil.copyfile(monorepo_file, subrepo_file)
|
||||
_run_git(["add", str(rel_file_path)], cwd=repo_path)
|
||||
logger.info(f"Copied {monorepo_file} -> {subrepo_file}")
|
||||
|
||||
def _set_authenticated_remote(repo_path: Path, repo_url: str) -> None:
|
||||
"""Set the push URL to use the GitHub App token from GH_TOKEN env."""
|
||||
token = os.environ.get("GH_TOKEN")
|
||||
if not token:
|
||||
raise RuntimeError("GH_TOKEN environment variable is not set")
|
||||
remote_url = f"https://x-access-token:{token}@github.com/{repo_url}.git"
|
||||
_run_git(["remote", "set-url", "origin", remote_url], cwd=repo_path)
|
||||
|
||||
def _push_changes(repo_path: Path, branch: str) -> None:
|
||||
"""Push the commit to origin of branch."""
|
||||
_run_git(["push", "origin", branch], cwd=repo_path)
|
||||
logger.debug(f"Pushed changes from {repo_path} to origin")
|
||||
|
||||
def generate_file_level_patches(prefix: str, merge_sha: str, output_dir: Path) -> tuple[list[str], list[str], list[tuple[str, str]], list[str], list[Path]]:
|
||||
"""Generate one patch per modified file, and collect adds, deletes, and renames."""
|
||||
diff_output = _run_git([
|
||||
"diff", "--name-status", "-M", f"{merge_sha}^!", "--", prefix
|
||||
])
|
||||
|
||||
added_files = []
|
||||
deleted_files = []
|
||||
renamed_files = []
|
||||
modified_files = []
|
||||
patch_files = []
|
||||
|
||||
for line in diff_output.splitlines():
|
||||
parts = line.split('\t')
|
||||
status = parts[0]
|
||||
if status == 'A':
|
||||
added_files.append(parts[1])
|
||||
elif status == 'M':
|
||||
file_path = parts[1]
|
||||
patch_path = output_dir / (file_path.replace("/", "_") + ".patch")
|
||||
_run_git([
|
||||
"format-patch",
|
||||
"-1", merge_sha,
|
||||
f"--relative={prefix}",
|
||||
"--output", str(patch_path),
|
||||
"--", file_path
|
||||
])
|
||||
patch_files.append(patch_path)
|
||||
modified_files.append(file_path)
|
||||
elif status == 'D':
|
||||
deleted_files.append(parts[1])
|
||||
elif status.startswith('R'):
|
||||
renamed_files.append((parts[1], parts[2]))
|
||||
|
||||
logger.debug(f"Generated {len(patch_files)} modified file patches, "
|
||||
f"{len(added_files)} added, {len(deleted_files)} deleted, "
|
||||
f"{len(renamed_files)} renamed under {prefix}")
|
||||
return added_files, deleted_files, renamed_files, modified_files, patch_files
|
||||
|
||||
def resolve_patch_author(client: GitHubCLIClient, repo: str, pr: int) -> tuple[str, str]:
|
||||
"""Determine the appropriate author for the patch"""
|
||||
pr_data = client.get_pr_by_number(repo, pr)
|
||||
body = pr_data.get("body", "") or ""
|
||||
match = re.search(r"Originally authored by @([A-Za-z0-9_-]+)", body)
|
||||
if match:
|
||||
username = match.group(1)
|
||||
logger.debug(f"Found originally authored username in PR body: @{username}")
|
||||
else:
|
||||
username = pr_data["user"]["login"]
|
||||
logger.debug(f"No explicit original author, using PR author: @{username}")
|
||||
name, email = client.get_user(username)
|
||||
return name or username, email
|
||||
|
||||
def apply_patches_and_squash(entry: RepoEntry, monorepo_url: str, monorepo_pr: int,
|
||||
added_files: list[str], deleted_files: list[str], renamed_files: list[tuple[str, str]],
|
||||
modified_files: list[str], modified_patch_paths: list[Path],
|
||||
author_name: str, author_email: str, merge_sha: str, dry_run: bool = False) -> None:
|
||||
"""
|
||||
Clone the subrepo, apply file-level patches each as a commit,
|
||||
delete files with git rm, copy added files, rename files,
|
||||
then squash all new commits into one before pushing.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
prefix = f"{entry.category}/{entry.name}"
|
||||
if dry_run:
|
||||
logger.info(f"[Dry-run] Sync for {entry.name}:")
|
||||
prefix_path = Path(prefix)
|
||||
|
||||
if added_files:
|
||||
logger.info(" Added files:")
|
||||
for f in added_files:
|
||||
short_path = Path(f).relative_to(prefix_path)
|
||||
logger.info(f" {short_path}")
|
||||
if deleted_files:
|
||||
logger.info(" Deleted files:")
|
||||
for f in deleted_files:
|
||||
short_path = Path(f).relative_to(prefix_path)
|
||||
logger.info(f" {short_path}")
|
||||
if renamed_files:
|
||||
logger.info(" Renamed files:")
|
||||
for old, new in renamed_files:
|
||||
old_rel = Path(old).relative_to(prefix_path)
|
||||
new_rel = Path(new).relative_to(prefix_path)
|
||||
logger.info(f" {old_rel} -> {new_rel}")
|
||||
if modified_files:
|
||||
logger.info(" Modified files (via patch):")
|
||||
for f in modified_files:
|
||||
short_path = Path(f).relative_to(prefix_path)
|
||||
logger.info(f" {short_path}")
|
||||
if not (added_files or deleted_files or renamed_files or modified_files or modified_patch_paths):
|
||||
logger.info(" No changes detected.")
|
||||
return
|
||||
|
||||
subrepo_path = Path(tmpdir) / entry.name
|
||||
_clone_subrepo(entry.url, entry.branch, subrepo_path)
|
||||
|
||||
_configure_git_user(subrepo_path)
|
||||
|
||||
# Get current HEAD commit (before applying patches)
|
||||
base_commit = _run_git(["rev-parse", "HEAD"], cwd=subrepo_path)
|
||||
|
||||
# Handle deletes
|
||||
for file_path in deleted_files:
|
||||
rel_path = file_path[len(prefix)+1:] if file_path.startswith(prefix + "/") else file_path
|
||||
_run_git(["rm", rel_path], cwd=subrepo_path)
|
||||
|
||||
# Handle renames
|
||||
for old, new in renamed_files:
|
||||
old_rel = old[len(prefix)+1:] if old.startswith(prefix + "/") else old
|
||||
new_rel = new[len(prefix)+1:] if new.startswith(prefix + "/") else new
|
||||
_run_git(["mv", old_rel, new_rel], cwd=subrepo_path)
|
||||
|
||||
# Handle adds
|
||||
for file_path in added_files:
|
||||
rel_path = file_path[len(prefix)+1:] if file_path.startswith(prefix + "/") else file_path
|
||||
src = Path(prefix) / rel_path
|
||||
dst = subrepo_path / rel_path
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(src, dst)
|
||||
|
||||
# Handle modified files (apply patches one by one)
|
||||
for patch_path, full_file_path in zip(modified_patch_paths, modified_files):
|
||||
rel_path = full_file_path[len(prefix)+1:] if full_file_path.startswith(prefix + "/") else full_file_path
|
||||
logger.debug(f"Applying patch {patch_path.name} to {entry.name} at {rel_path}")
|
||||
_apply_patch(subrepo_path, patch_path, Path(rel_path), Path.cwd(), prefix)
|
||||
|
||||
# Final squash
|
||||
commit_msg = f"[rocm-libraries] {monorepo_url}#{monorepo_pr} (commit {merge_sha[:7]})\n\n" + \
|
||||
_run_git(["log", "-1", "--pretty=%B", merge_sha])
|
||||
_run_git(["reset", "--soft", base_commit], cwd=subrepo_path)
|
||||
_run_git(["commit", "-m", commit_msg, "--author", f"{author_name} <{author_email}>"], cwd=subrepo_path)
|
||||
|
||||
_set_authenticated_remote(subrepo_path, entry.url)
|
||||
_push_changes(subrepo_path, entry.branch)
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> None:
|
||||
"""Main function to apply patches to sub-repositories."""
|
||||
args = parse_arguments(argv)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.debug else logging.INFO
|
||||
)
|
||||
client = GitHubCLIClient()
|
||||
config = load_repo_config(args.config)
|
||||
subtrees = [line.strip() for line in args.subtrees.splitlines() if line.strip()]
|
||||
relevant_subtrees = get_subtree_info(config, subtrees)
|
||||
merge_sha = client.get_squash_merge_commit(args.repo, args.pr)
|
||||
logger.debug(f"Merge commit for PR #{args.pr} in {args.repo}: {merge_sha}")
|
||||
_run_git(["checkout", merge_sha])
|
||||
logger.info(f"Checked out merge commit {merge_sha} for patch operations")
|
||||
for entry in relevant_subtrees:
|
||||
prefix = f"{entry.category}/{entry.name}"
|
||||
logger.debug(f"Processing subtree {prefix}")
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
patch_dir = Path(tmpdir)
|
||||
# Generate patches and lists of adds/deletes/renames
|
||||
added_files, deleted_files, renamed_files, modified_files, modified_patch_paths, = generate_file_level_patches(prefix, merge_sha, patch_dir)
|
||||
if not (added_files or deleted_files or renamed_files or modified_files or modified_patch_paths):
|
||||
logger.info(f"No changes to apply for {prefix}")
|
||||
continue
|
||||
author_name, author_email = resolve_patch_author(client, args.repo, args.pr)
|
||||
apply_patches_and_squash(entry, args.repo, args.pr,
|
||||
added_files, deleted_files, renamed_files, modified_files, modified_patch_paths,
|
||||
author_name, author_email, merge_sha,
|
||||
args.dry_run)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Repository Config Model
|
||||
------------------------
|
||||
|
||||
This module defines Pydantic data models for validating and parsing the repos-config.json file.
|
||||
|
||||
Structure of the expected JSON:
|
||||
|
||||
{
|
||||
"repositories": [
|
||||
{
|
||||
"name": "rocblas",
|
||||
"url": "ROCm/rocBLAS",
|
||||
"branch": "develop",
|
||||
"category": "projects"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
class RepoEntry(BaseModel):
|
||||
"""
|
||||
Represents a single repository entry in the repos-config.json file.
|
||||
|
||||
Fields:
|
||||
name : Name of the project matching packaging file names. Lower-cased and no underscores. (e.g., "rocblas")
|
||||
url : Individual GitHub org plus repo names in matching case and punctuation. (e.g., "ROCm/rocBLAS")
|
||||
branch : The base branch of the sub-repo to target (e.g., "develop").
|
||||
category : Directory category in the monorepo (e.g., "projects" or "shared").
|
||||
"""
|
||||
name: str
|
||||
url: str
|
||||
branch: str
|
||||
category: str
|
||||
auto_subtree_pull: bool
|
||||
auto_subtree_push: bool
|
||||
|
||||
class RepoConfig(BaseModel):
|
||||
"""
|
||||
Represents the full config file structure.
|
||||
|
||||
Fields:
|
||||
repositories : List of RepoEntry items.
|
||||
"""
|
||||
repositories: List[RepoEntry]
|
||||
@@ -0,0 +1,73 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
|
||||
import therock_configure_ci
|
||||
|
||||
class ConfigureCITest(unittest.TestCase):
|
||||
def test_pull_request(self):
|
||||
args = {
|
||||
"is_pull_request": True,
|
||||
"input_subtrees": "projects/rocprim\nprojects/hipcub"
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertEqual(len(project_to_run), 1)
|
||||
|
||||
def test_pull_request_empty(self):
|
||||
args = {
|
||||
"is_pull_request": True,
|
||||
"input_subtrees": ""
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertEqual(len(project_to_run), 0)
|
||||
|
||||
def test_workflow_dispatch(self):
|
||||
args = {
|
||||
"is_workflow_dispatch": True,
|
||||
"input_projects": "projects/rocprim projects/hipcub"
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertEqual(len(project_to_run), 1)
|
||||
|
||||
def test_workflow_dispatch_bad_input(self):
|
||||
args = {
|
||||
"is_workflow_dispatch": True,
|
||||
"input_projects": "projects/rocprim$$projects/hipcub"
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertEqual(len(project_to_run), 0)
|
||||
|
||||
def test_workflow_dispatch_all(self):
|
||||
args = {
|
||||
"is_workflow_dispatch": True,
|
||||
"input_projects": "all"
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertGreaterEqual(len(project_to_run), 1)
|
||||
|
||||
def test_workflow_dispatch_empty(self):
|
||||
args = {
|
||||
"is_workflow_dispatch": True,
|
||||
"input_projects": ""
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertEqual(len(project_to_run), 0)
|
||||
|
||||
def test_is_push(self):
|
||||
args = {
|
||||
"is_push": True,
|
||||
}
|
||||
|
||||
project_to_run = therock_configure_ci.retrieve_projects(args)
|
||||
self.assertGreaterEqual(len(project_to_run), 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
This script determines which build flag and tests to run based on SUBTREES
|
||||
|
||||
Required environment variables:
|
||||
- SUBTREES
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from therock_matrix import subtree_to_project_map, project_map
|
||||
from typing import Mapping
|
||||
import os
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def set_github_output(d: Mapping[str, str]):
|
||||
"""Sets GITHUB_OUTPUT values.
|
||||
See https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
|
||||
"""
|
||||
logging.info(f"Setting github output:\n{d}")
|
||||
step_output_file = os.environ.get("GITHUB_OUTPUT", "")
|
||||
if not step_output_file:
|
||||
logging.warning("Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
|
||||
return
|
||||
with open(step_output_file, "a") as f:
|
||||
f.writelines(f"{k}={v}" + "\n" for k, v in d.items())
|
||||
|
||||
|
||||
def retrieve_projects(args):
|
||||
# TODO(geomin12): #590 Enable TheRock CI for forked PRs
|
||||
if args.get("is_forked_pr"):
|
||||
logging.info("Warning: not enabling any projects due to is_forked_pr. Builds/tests for forked PRs are disabled pending: https://github.com/ROCm/rocm-libraries/issues/590")
|
||||
return []
|
||||
|
||||
if args.get("is_pull_request"):
|
||||
subtrees = args.get("input_subtrees").split("\n")
|
||||
|
||||
if args.get("is_workflow_dispatch"):
|
||||
if args.get("input_projects") == "all":
|
||||
subtrees = list(subtree_to_project_map.keys())
|
||||
else:
|
||||
subtrees = args.get("input_projects").split()
|
||||
|
||||
# If a push event to develop happens, we run tests on all subtrees
|
||||
if args.get("is_push"):
|
||||
subtrees = list(subtree_to_project_map.keys())
|
||||
|
||||
projects = set()
|
||||
# collect the associated subtree to project
|
||||
for subtree in subtrees:
|
||||
if subtree in subtree_to_project_map:
|
||||
projects.add(subtree_to_project_map.get(subtree))
|
||||
|
||||
|
||||
# retrieve the subtrees to checkout, cmake options to build, and projects to test
|
||||
project_to_run = []
|
||||
for project in projects:
|
||||
if project in project_map:
|
||||
project_to_run.append(project_map.get(project))
|
||||
|
||||
return project_to_run
|
||||
|
||||
|
||||
def run(args):
|
||||
project_to_run = retrieve_projects(args)
|
||||
set_github_output({"projects": json.dumps(project_to_run)})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {}
|
||||
github_event_name = os.getenv("GITHUB_EVENT_NAME")
|
||||
args["is_pull_request"] = github_event_name == "pull_request"
|
||||
args["is_push"] = github_event_name == "push"
|
||||
args["is_workflow_dispatch"] = github_event_name == "workflow_dispatch"
|
||||
|
||||
is_forked_pr = os.getenv("IS_FORKED_PR")
|
||||
args["is_forked_pr"] = is_forked_pr == "true"
|
||||
|
||||
input_subtrees = os.getenv("SUBTREES", "")
|
||||
args["input_subtrees"] = input_subtrees
|
||||
|
||||
input_projects = os.getenv("PROJECTS", "")
|
||||
args["input_projects"] = input_projects
|
||||
|
||||
logging.info(f"Retrieved arguments {args}")
|
||||
|
||||
run(args)
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
This dictionary is used to map specific file directory changes to the corresponding build flag and tests
|
||||
"""
|
||||
subtree_to_project_map = {
|
||||
"projects/rocprim": "prim",
|
||||
"projects/rocthrust": "prim",
|
||||
"projects/hipcub": "prim",
|
||||
"projects/rocrand": "rand",
|
||||
"projects/hiprand": "rand"
|
||||
}
|
||||
|
||||
project_map = {
|
||||
"prim": {
|
||||
"cmake_options": "-DTHEROCK_ENABLE_PRIM=ON -DTHEROCK_ENABLE_ALL=OFF",
|
||||
"project_to_test": "rocprim, rocthrust, hipcub",
|
||||
"subtree_checkout": "projects/rocprim\nprojects/hipcub\nprojects/rocthrust",
|
||||
},
|
||||
"rand": {
|
||||
"cmake_options": "-DTHEROCK_ENABLE_RAND=ON -DTHEROCK_ENABLE_ALL=OFF",
|
||||
"project_to_test": "rocrand, hiprand",
|
||||
"subtree_checkout": "projects/rocrand\nprojects/hiprand",
|
||||
},
|
||||
}
|
||||
Yeni konuda referans
Bir kullanıcı engelle