diff --git a/action.yml b/action.yml index d82dbbe..f640e52 100644 --- a/action.yml +++ b/action.yml @@ -58,12 +58,8 @@ runs: shell: bash run: | cd $GITHUB_ACTION_PATH - label_list=$(python -c 'import github_query; print(github_query.get_labels(github_query.parse_args()))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}") - if [[ "$label_list" == '[]' ]]; then - echo "label_list=''" >> $GITHUB_OUTPUT - exit 0 - fi - + label_list=$(python github_query.py pr-labels "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}") + echo "label_list=$label_list" >> $GITHUB_OUTPUT - name: Get version increment @@ -72,6 +68,7 @@ runs: run: | cd $GITHUB_ACTION_PATH increment=$(python -c 'import github_query; pr_labels = github_query.get_labels(github_query.parse_args()); patch_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="PATCH_BUMP_LABEL"); minor_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="MINOR_BUMP_LABEL"); print(github_query.get_version_increment(patch_bump_list=patch_repo_var, minor_bump_list=minor_repo_var, pr_label_list=pr_labels))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}") + echo "increment=$increment" >> $GITHUB_OUTPUT # INFO multiline strings need to be precessed like this according to diff --git a/github_query.py b/github_query.py index 5586f11..01d06ee 100644 --- a/github_query.py +++ b/github_query.py @@ -3,197 +3,55 @@ Additionally it's test suite relies mainly on pytest and therefore the functions need to be importable to the pytest script. """ -import argparse import click -import json -import logging -import re -import subprocess - -from collections import namedtuple - from src import conversion_logic, queries -logger: logging.Logger = logging.getLogger(__name__) - -Changelog: type[Changelog] = namedtuple("Changelog", "labels title number url id") - - -def parse_args() -> dict: - """Parse command-line arguments and store them in a global variable.""" - - parser = argparse.ArgumentParser(description="A python script to convert GitHub PR information to a more simple format.") - parser.add_argument("repo", type=str, help="Repository name consisting of 'repo-owner/repo-name'") - parser.add_argument("query_parameters", type=str, help="Keys to query for.") - parser.add_argument("date", type=str, default="2024-07-08T09:48:33Z", help="Latest release date.") - parsed_args = parser.parse_args() - - repo_name = parsed_args.repo - query_tags = parsed_args.query_parameters.split(',') - latest_release_date = parsed_args.date - - command: str = f"gh pr list --state merged --search 'merged:>={latest_release_date}' --json {','.join(query_tags)} --repo {repo_name}" - pr_json = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - - return json.loads(pr_json.stdout) - - -def get_changelog(pr_data, changelog_start="## Changelog", heading="##"): - """Get list of changes from a PRs changelog. - - Args: - pr_body (list(str)): List of PR body contents. - changelog_start (str, optional): Indicates markdown changelog section. Defaults to "## Changes". - heading (str, optional): Markdown heading. Defaults to "##". - - Returns: - list(str): List of changes found. - """ - - lines = pr_data.splitlines() - changelog_section = None - changelog_lines = [] - - for line in lines: - if line.startswith(changelog_start): - changelog_section = True - continue - - if changelog_section and line.startswith(heading): - break - - if changelog_section and line.startswith("* "): - changelog_lines.append(line.strip("* ").strip()) - - return changelog_lines - - -def filter_changes_per_label(pr_data: list[dict[str, str]], changelog_label_list: list[str]) -> list[Changelog]: - - changes_list: list[Changelog] = [] - - for pull_request in pr_data: - - # TODO refactor this to become more readable - label_list: list[str] = [label["name"] for label in pull_request["labels"] if label["name"] in changelog_label_list] - if label_list: - changes_list.append(Changelog(label_list, pull_request["title"], pull_request["number"], pull_request["url"], pull_request["id"])) - - return changes_list - - -def sort_changes(changes_list: list[Changelog], changelog_label_list: list[str]) -> list[Changelog]: - - # TODO implement this logic in a more clever way - sorted_changes: list[Changelog] = [] - - for order_label in changelog_label_list: - for change in changes_list: - if any(label == order_label for label in change.labels): - sorted_changes.append(change) - - return sorted_changes - - -def build_changelog_markdown(changes: list[Changelog]) -> str: - changelog = "# Changelog" - previous_labels: list[str] = [] - - # TODO implement this logic in a more clever way, checkout `from itertools import groupby` - for change in changes: - current_labels: list[str] = change.labels - - if not any(label in previous_labels for label in current_labels): - label: str = change.labels[0].removeprefix("type: ") - changelog += f"\n\n## {label.capitalize()}\n\n" - - changelog += f"* {change.title} - [{change.number}]({change.url})\n" - - previous_labels = current_labels - - return changelog - - -def get_labels(pr_data: list[dict[str, str]]) -> list[str]: - """Filter all unique labels from dictionary. - - Args: - pr_data (dict): Github PR query result - - Returns: - [str]: Liste of unique labels strings found or `None`. - """ - - labels = set() - - for item in pr_data: - if not item.get("labels"): - return [] - for label in item["labels"]: - if not label.get("name"): - return [] - - labels.add(label["name"]) - - return list(labels) - +@click.group() +def cli() -> None: + pass -def get_repo_var(repo: str, var_name: str) -> list[str]: - """Query labels from repository variables. - Args: - repo (str): Repository name `owner/repo-name` - var_name (str): Repo variable name +@cli.command() +@click.argument('repo_name', type=click.STRING) +@click.argument('query_tags', type=click.STRING) +@click.argument('latest_release_date', type=click.STRING) +def pr_labels(latest_release_date, query_tags, repo_name): + """Get a list of all version relevant PR labels. - Returns: - str: Comma separated value string. + latest_release_date (str): datatime string\n + query_tags (str): csv string\n + repo_name (str): repo name as \n """ - labels = subprocess.run( - ["gh", "variable", "get", var_name, "--repo", repo], - capture_output=True, - text=True, - check=True - ) - - return csv_string_to_list(labels.stdout) + pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name) + pr_labels = conversion_logic.get_labels(pr_data=pr_result) -def csv_string_to_list(input: str) -> list[str]: - if input: - return re.split(r',\s*', input.strip()) + if not pr_labels: + click.echo("") - return [] + click.echo(pr_labels) -def get_version_increment(patch_bump_list: list[str], minor_bump_list: list[str], pr_label_list: list[str]): - """Figure out version increment based on PR labels. - - Args: - patch_bump_list ([str]): Labels for bumping patch version - minor_bump_list ([str]): Labels for bumping minor version - label_list([str]): Labels found in PRs +@cli.command() +@click.argument('repo_name', type=click.STRING) +@click.argument('query_tags', type=click.STRING) +@click.argument('latest_release_date', type=click.STRING) +def version_increment(latest_release_date, query_tags, repo_name): + """Output a calculated version increment suggestion. - Returns: - str: version increment + latest_release_date (str): datetime string\n + query_tags (str): csv string\n + repo_name (str): repo name as \n """ - if not pr_label_list: - return "" - - # TODO add major bump option - if any(label in pr_label_list for label in minor_bump_list): - return "minor" - - if any(label in pr_label_list for label in patch_bump_list): - return "patch" + pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name) + pr_labels = conversion_logic.get_labels(pr_data=pr_result) + patch_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="PATCH_BUMP_LABEL")) + minor_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="MINOR_BUMP_LABEL")) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_repo_var_list, minor_bump_list=minor_repo_var_list, pr_label_list=pr_labels) - return "" - - -@click.group() -def cli(): - pass + click.echo(increment) @cli.command() @@ -213,9 +71,9 @@ def generate_release_changelog(latest_release_date: str, query_tags: str, repo_n pr_result: list[dict[str, str]] = queries.query_merged_prs(latest_release_date, query_tags_list, repo_name) changelog_labels_result: list[str] = conversion_logic.csv_string_to_list(changelog_labels) - pr_filtered: list[Changelog] = filter_changes_per_label(pr_data=pr_result, changelog_label_list=changelog_labels_result) - sorted_changes: list[Changelog] = sort_changes(changes_list=pr_filtered, changelog_label_list=changelog_labels_result) - markdown_changelog: str = build_changelog_markdown(sorted_changes) + pr_filtered: list[Changelog] = conversion_logic.filter_changes_per_label(pr_data=pr_result, changelog_label_list=changelog_labels_result) + sorted_changes: list[Changelog] = conversion_logic.sort_changes(changes_list=pr_filtered, changelog_label_list=changelog_labels_result) + markdown_changelog: str = conversion_logic.build_changelog_markdown(sorted_changes) click.echo(markdown_changelog) diff --git a/src/conversion_logic.py b/src/conversion_logic.py index 6de199a..ebb56b3 100644 --- a/src/conversion_logic.py +++ b/src/conversion_logic.py @@ -1,5 +1,117 @@ +import logging import re +from collections import namedtuple + + +logger: logging.Logger = logging.getLogger(__name__) +Changelog: type[Changelog] = namedtuple("Changelog", "labels title number url id") + + +def sort_changes(changes_list: list[Changelog], changelog_label_list: list[str]) -> list[Changelog]: + + # TODO implement this logic in a more clever way + sorted_changes: list[Changelog] = [] + + for order_label in changelog_label_list: + for change in changes_list: + if any(label == order_label for label in change.labels): + sorted_changes.append(change) + + return sorted_changes + + +# INFO currently not in use +def get_changelog(pr_data, changelog_start="## Changelog", heading="##"): + """Get list of changes from a PRs changelog. + + Args: + pr_body (list(str)): List of PR body contents. + changelog_start (str, optional): Indicates markdown changelog section. Defaults to "## Changes". + heading (str, optional): Markdown heading. Defaults to "##". + + Returns: + list(str): List of changes found. + """ + + lines: list[str] = pr_data.splitlines() + changelog_section = None + changelog_lines: list[str] = [] + + for line in lines: + if line.startswith(changelog_start): + changelog_section = True + continue + + if changelog_section and line.startswith(heading): + break + + if changelog_section and line.startswith("* "): + changelog_lines.append(line.strip("* ").strip()) + + return changelog_lines + + +def filter_changes_per_label(pr_data: list[dict[str, str]], changelog_label_list: list[str]) -> list[Changelog]: + + changes_list: list[Changelog] = [] + + for pull_request in pr_data: + + # TODO refactor this to become more readable + label_list: list[str] = [label["name"] for label in pull_request["labels"] if label["name"] in changelog_label_list] + if label_list: + changes_list.append(Changelog(label_list, pull_request["title"], pull_request["number"], pull_request["url"], pull_request["id"])) + + return changes_list + + +def build_changelog_markdown(changes: list[Changelog]) -> str: + changelog = "# Changelog" + previous_labels: list[str] = [] + + # TODO implement this logic in a more clever way, checkout `from itertools import groupby` + for change in changes: + current_labels: list[str] = change.labels + + if not any(label in previous_labels for label in current_labels): + label: str = change.labels[0].removeprefix("type: ") + changelog += f"\n\n## {label.capitalize()}\n\n" + + changelog += f"* {change.title} - [{change.number}]({change.url})\n" + + previous_labels = current_labels + + return changelog + + +def filter_unique_labels(pr_data: dict[dict[str, str]]) -> list[str]: + """Filter all unique labels from dictionary. + + Args: + pr_data (dict): Github PR query result + + Returns: + list: List of unique labels strings found or `None`. + """ + + labels = set() + + for item in pr_data: + if not item.get("labels"): + logger.warning("No PR label data found.") + return [] + for label in item["labels"]: + if not label.get("name"): + logger.warning("No PR label names found.") + return [] + + labels.add(label["name"]) + logger.debug("PR labels found.") + + return list(labels) + + def csv_string_to_list(input: str) -> list[str]: """Convert string to list. @@ -13,4 +125,37 @@ def csv_string_to_list(input: str) -> list[str]: if input: return re.split(r',\s*', input.strip()) - return [] \ No newline at end of file + return [] + + +def get_version_increment(pr_label_list: list, patch_bump_list: list=[], minor_bump_list: list=[], major_bump_list: list=[]): + """Figure out version increment based on PR labels. + + Args: + patch_bump_list ([str]): Labels for bumping patch version + minor_bump_list ([str]): Labels for bumping minor version + label_list ([str]): Labels found in PRs + + Returns: + str: version increment + """ + + if not pr_label_list: + logger.warning("PR label list was empty") + return "" + + for name, param in locals().items(): + if not isinstance(param, list): + raise ValueError(f"{name} must be a list.") + + if any(label in pr_label_list for label in major_bump_list): + return "major" + + if any(label in pr_label_list for label in minor_bump_list): + return "minor" + + if any(label in pr_label_list for label in patch_bump_list): + return "patch" + + logger.warning("No relevant labels found for version increment.") + return "" \ No newline at end of file diff --git a/tests/test_github_query.py b/tests/test_github_query.py index 4eb3382..c1cf136 100644 --- a/tests/test_github_query.py +++ b/tests/test_github_query.py @@ -2,7 +2,7 @@ import json import pytest -import github_query +from src import conversion_logic, queries @pytest.fixture def pr_api_output() -> list[dict[str, Any]]: @@ -115,13 +115,13 @@ def csv_string_empty() -> Literal['']: # Get PR Label test-cases def test_get_labels(pr_api_output): - labels = github_query.get_labels(pr_data=pr_api_output) + labels = conversion_logic.filter_unique_labels(pr_data=pr_api_output) assert isinstance(labels, list) assert set(labels) == {"bugfix", "enhancement"} def test_get_labels_missing_input(pr_api_output_missing_label): - labels = github_query.get_labels(pr_data=pr_api_output_missing_label) + labels = conversion_logic.filter_unique_labels(pr_data=pr_api_output_missing_label) assert labels == [] @@ -129,22 +129,22 @@ def test_get_labels_missing_input(pr_api_output_missing_label): # Convert repo label list def test_csv_string_to_list_spaces(csv_string_spaces): - string_list = github_query.csv_string_to_list(csv_string_spaces) + string_list = conversion_logic.csv_string_to_list(csv_string_spaces) assert string_list == ["bugfix", "enhancement", "feature"] def test_csv_string_to_list_no_spaces(csv_string_no_spaces): - string_list = github_query.csv_string_to_list(csv_string_no_spaces) + string_list = conversion_logic.csv_string_to_list(csv_string_no_spaces) assert string_list == ["bugfix", "enhancement", "feature"] def test_csv_string_to_list_no_comma(csv_string_no_comma): - string_list = github_query.csv_string_to_list(csv_string_no_comma) + string_list = conversion_logic.csv_string_to_list(csv_string_no_comma) assert string_list == ["bugfix"] def test_csv_string_to_list_empty(csv_string_empty): - string_list = github_query.csv_string_to_list(csv_string_empty) + string_list = conversion_logic.csv_string_to_list(csv_string_empty) assert string_list == [] @@ -152,27 +152,27 @@ def test_csv_string_to_list_empty(csv_string_empty): # Version Increment test-cases def test_get_version_increment_patch(minor_bump, patch_bump, pr_labels_bug): - increment = github_query.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_bug) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_bug) assert increment == "patch" def test_get_version_increment_minor(minor_bump, patch_bump, pr_labels_enhancement): - increment = github_query.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_enhancement) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_enhancement) assert increment == "minor" def test_get_version_increment_wrong_labels(minor_bump, patch_bump, pr_labels_wrong_labels): - increment = github_query.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_wrong_labels) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_wrong_labels) assert increment == "" def test_get_version_increment_none(minor_bump, patch_bump, pr_labels_none): - increment = github_query.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_none) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_none) assert increment == "" def test_get_version_increment_empty_list(minor_bump, patch_bump, pr_labels_empty_list): - increment = github_query.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_empty_list) + increment = conversion_logic.get_version_increment(patch_bump_list=patch_bump, minor_bump_list=minor_bump, pr_label_list=pr_labels_empty_list) assert increment == "" @@ -181,7 +181,7 @@ def test_get_version_increment_empty_list(minor_bump, patch_bump, pr_labels_empt def test_changer_pert_label(merged_pr_samples: dict[str, str]) -> None: changelog_labels: list[str] = ["type: bug", "type: enhancement", "type: maintenance"] - filtered_results: list[Changelog] = github_query.filter_changes_per_label(pr_data=merged_pr_samples, changelog_label_list=changelog_labels) + filtered_results: list[Changelog] = conversion_logic.filter_changes_per_label(pr_data=merged_pr_samples, changelog_label_list=changelog_labels) for result in filtered_results: for label in result.labels: