Files
rocm-systems/.github/scripts/pr_category_label.py
T
2025-08-02 22:08:24 -04:00

110 γραμμές
4.6 KiB
Python

#!/usr/bin/env python3
"""
PR Category Label Script
--------------------
This script analyzes the file paths changed in a pull request and determines which
category labels should be added or removed based on the modified files.
It uses GitHub's cli to fetch the changed files and the existing labels on the pull request.
Then, it computes the desired labels based on file paths, compares them to the existing labels,
and applies the necessary additions and removals unless in dry-run mode.
Arguments:
--repo : Full repository name (e.g., org/repo)
--pr : Pull request number
--dry-run : If set, will only log actions without making changes.
--debug : If set, enables detailed debug logging.
Outputs:
Writes 'add' and 'remove' keys to the GitHub Actions $GITHUB_OUTPUT file, which
the workflow reads to apply label changes using the GitHub CLI.
Example Usage:
To run in debug mode and perform a dry-run (no changes made):
python pr_auto_label.py --repo ROCm/rocm-systems --pr <pr-number> --dry-run --debug
To run in debug mode and apply label changes:
python pr_auto_label.py --repo ROCm/rocm-systems --pr <pr-number> --debug
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import List, Optional
from github_cli_client import GitHubCLIClient
logger = logging.getLogger(__name__)
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(description="Apply labels based on PR's changed files.")
parser.add_argument("--repo", required=True, help="Full repository name (e.g., org/repo)")
parser.add_argument("--pr", required=True, type=int, help="Pull request number")
parser.add_argument("--dry-run", action="store_true", help="Print results without writing to GITHUB_OUTPUT.")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
return parser.parse_args(argv)
def compute_desired_labels(file_paths: list) -> set:
"""Determine the desired labels based on the changed files."""
desired_labels = set()
for path in file_paths:
parts = Path(path).parts
if len(parts) >= 2:
if parts[0] == "projects":
desired_labels.add(f"project: {parts[1]}")
elif parts[0] == "shared":
desired_labels.add(f"shared: {parts[1]}")
logger.debug(f"Desired labels based on changes: {desired_labels}")
return desired_labels
def output_labels(existing_labels: List[str], desired_labels: List[str], dry_run: bool) -> None:
"""Output the labels to add/remove to GITHUB_OUTPUT or log them in dry-run mode."""
to_add = sorted(desired_labels - set(existing_labels))
logger.debug(f"Labels to add: {to_add}")
if dry_run:
logger.info("Dry run enabled. Labels will not be applied.")
else:
output_file = os.environ.get("GITHUB_OUTPUT")
if output_file:
with open(output_file, 'a') as f:
print(f"label_add={','.join(to_add)}", file=f)
logger.info(f"Wrote to GITHUB_OUTPUT: add={','.join(to_add)}")
else:
print("GITHUB_OUTPUT environment variable not set. Outputs cannot be written.")
sys.exit(1)
def main(argv=None) -> None:
"""Main function to execute the PR auto label logic."""
args = parse_arguments(argv)
logging.basicConfig(
level=logging.DEBUG if args.debug else logging.INFO
)
client = GitHubCLIClient()
changed_files = [file for file in client.get_changed_files(args.repo, int(args.pr))]
if not changed_files:
logger.warning("REST API failed or returned no changed files. Falling back to SHA-based Git diff...")
try:
pr_data = os.popen(f"gh api repos/{args.repo}/pulls/{args.pr}").read()
pr = json.loads(pr_data)
base_sha = pr["base"]["sha"]
head_sha = pr["head"]["sha"]
logger.debug(f"Base SHA: {base_sha}, Head SHA: {head_sha}")
os.system(f"git fetch origin {base_sha} {head_sha}")
result = os.popen(f"git diff --name-only {base_sha} {head_sha}").read()
changed_files = result.strip().splitlines()
logger.info(f"Fallback changed files (SHA-based): {changed_files}")
except Exception as e:
logger.error(f"SHA-based Git CLI fallback failed: {e}")
sys.exit(1)
existing_labels = client.get_existing_labels_on_pr(args.repo, int(args.pr))
desired_labels = compute_desired_labels(changed_files)
output_labels(existing_labels, desired_labels, args.dry_run)
if __name__ == "__main__":
main()