Skip to content

Commit

Permalink
Update structure to match refactor branch
Browse files Browse the repository at this point in the history
  • Loading branch information
philnewm committed Oct 29, 2024
1 parent 862b5e0 commit 4479a62
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 197 deletions.
9 changes: 3 additions & 6 deletions action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ runs:
shell: bash
run: |
cd $GITHUB_ACTION_PATH
label_list=$(python -c 'import github_query; print(github_query.get_labels(github_query.parse_args()))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
if [[ "$label_list" == '[]' ]]; then
echo "label_list=''" >> $GITHUB_OUTPUT
exit 0
fi
label_list=$(python github_query.py pr-labels "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
echo "label_list=$label_list" >> $GITHUB_OUTPUT
- name: Get version increment
Expand All @@ -72,6 +68,7 @@ runs:
run: |
cd $GITHUB_ACTION_PATH
increment=$(python -c 'import github_query; pr_labels = github_query.get_labels(github_query.parse_args()); patch_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="PATCH_BUMP_LABEL"); minor_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="MINOR_BUMP_LABEL"); print(github_query.get_version_increment(patch_bump_list=patch_repo_var, minor_bump_list=minor_repo_var, pr_label_list=pr_labels))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
echo "increment=$increment" >> $GITHUB_OUTPUT
# INFO multiline strings need to be precessed like this according to
Expand Down
212 changes: 35 additions & 177 deletions github_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,197 +3,55 @@
Additionally it's test suite relies mainly on pytest and therefore the functions need to be importable to the pytest script.
"""

import argparse
import click
import json
import logging
import re
import subprocess

from collections import namedtuple

from src import conversion_logic, queries


logger: logging.Logger = logging.getLogger(__name__)

Changelog: type[Changelog] = namedtuple("Changelog", "labels title number url id")


def parse_args() -> dict:
"""Parse command-line arguments and store them in a global variable."""

parser = argparse.ArgumentParser(description="A python script to convert GitHub PR information to a more simple format.")
parser.add_argument("repo", type=str, help="Repository name consisting of 'repo-owner/repo-name'")
parser.add_argument("query_parameters", type=str, help="Keys to query for.")
parser.add_argument("date", type=str, default="2024-07-08T09:48:33Z", help="Latest release date.")
parsed_args = parser.parse_args()

repo_name = parsed_args.repo
query_tags = parsed_args.query_parameters.split(',')
latest_release_date = parsed_args.date

command: str = f"gh pr list --state merged --search 'merged:>={latest_release_date}' --json {','.join(query_tags)} --repo {repo_name}"
pr_json = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

return json.loads(pr_json.stdout)


def get_changelog(pr_data, changelog_start="## Changelog", heading="##"):
"""Get list of changes from a PRs changelog.
Args:
pr_body (list(str)): List of PR body contents.
changelog_start (str, optional): Indicates markdown changelog section. Defaults to "## Changes".
heading (str, optional): Markdown heading. Defaults to "##".
Returns:
list(str): List of changes found.
"""

lines = pr_data.splitlines()
changelog_section = None
changelog_lines = []

for line in lines:
if line.startswith(changelog_start):
changelog_section = True
continue

if changelog_section and line.startswith(heading):
break

if changelog_section and line.startswith("* "):
changelog_lines.append(line.strip("* ").strip())

return changelog_lines


def filter_changes_per_label(pr_data: list[dict[str, str]], changelog_label_list: list[str]) -> list[Changelog]:

changes_list: list[Changelog] = []

for pull_request in pr_data:

# TODO refactor this to become more readable
label_list: list[str] = [label["name"] for label in pull_request["labels"] if label["name"] in changelog_label_list]
if label_list:
changes_list.append(Changelog(label_list, pull_request["title"], pull_request["number"], pull_request["url"], pull_request["id"]))

return changes_list


def sort_changes(changes_list: list[Changelog], changelog_label_list: list[str]) -> list[Changelog]:

# TODO implement this logic in a more clever way
sorted_changes: list[Changelog] = []

for order_label in changelog_label_list:
for change in changes_list:
if any(label == order_label for label in change.labels):
sorted_changes.append(change)

return sorted_changes


def build_changelog_markdown(changes: list[Changelog]) -> str:
changelog = "# Changelog"
previous_labels: list[str] = []

# TODO implement this logic in a more clever way, checkout `from itertools import groupby`
for change in changes:
current_labels: list[str] = change.labels

if not any(label in previous_labels for label in current_labels):
label: str = change.labels[0].removeprefix("type: ")
changelog += f"\n\n## {label.capitalize()}\n\n"

changelog += f"* {change.title} - [{change.number}]({change.url})\n"

previous_labels = current_labels

return changelog


def get_labels(pr_data: list[dict[str, str]]) -> list[str]:
"""Filter all unique labels from dictionary.
Args:
pr_data (dict): Github PR query result
Returns:
[str]: Liste of unique labels strings found or `None`.
"""

labels = set()

for item in pr_data:
if not item.get("labels"):
return []
for label in item["labels"]:
if not label.get("name"):
return []

labels.add(label["name"])

return list(labels)

@click.group()
def cli() -> None:
pass

def get_repo_var(repo: str, var_name: str) -> list[str]:
"""Query labels from repository variables.

Args:
repo (str): Repository name `owner/repo-name`
var_name (str): Repo variable name
@cli.command()
@click.argument('repo_name', type=click.STRING)
@click.argument('query_tags', type=click.STRING)
@click.argument('latest_release_date', type=click.STRING)
def pr_labels(latest_release_date, query_tags, repo_name):
"""Get a list of all version relevant PR labels.
Returns:
str: Comma separated value string.
latest_release_date (str): datatime string\n
query_tags (str): csv string\n
repo_name (str): repo name as <owner><repo>\n
"""
labels = subprocess.run(
["gh", "variable", "get", var_name, "--repo", repo],
capture_output=True,
text=True,
check=True
)

return csv_string_to_list(labels.stdout)

pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name)
pr_labels = conversion_logic.get_labels(pr_data=pr_result)

def csv_string_to_list(input: str) -> list[str]:
if input:
return re.split(r',\s*', input.strip())
if not pr_labels:
click.echo("")

return []
click.echo(pr_labels)


def get_version_increment(patch_bump_list: list[str], minor_bump_list: list[str], pr_label_list: list[str]):
"""Figure out version increment based on PR labels.
Args:
patch_bump_list ([str]): Labels for bumping patch version
minor_bump_list ([str]): Labels for bumping minor version
label_list([str]): Labels found in PRs
@cli.command()
@click.argument('repo_name', type=click.STRING)
@click.argument('query_tags', type=click.STRING)
@click.argument('latest_release_date', type=click.STRING)
def version_increment(latest_release_date, query_tags, repo_name):
"""Output a calculated version increment suggestion.
Returns:
str: version increment
latest_release_date (str): datetime string\n
query_tags (str): csv string\n
repo_name (str): repo name as <owner><repo>\n
"""

if not pr_label_list:
return ""

# TODO add major bump option
if any(label in pr_label_list for label in minor_bump_list):
return "minor"

if any(label in pr_label_list for label in patch_bump_list):
return "patch"
pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name)
pr_labels = conversion_logic.get_labels(pr_data=pr_result)
patch_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="PATCH_BUMP_LABEL"))
minor_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="MINOR_BUMP_LABEL"))
increment = conversion_logic.get_version_increment(patch_bump_list=patch_repo_var_list, minor_bump_list=minor_repo_var_list, pr_label_list=pr_labels)

return ""


@click.group()
def cli():
pass
click.echo(increment)


@cli.command()
Expand All @@ -213,9 +71,9 @@ def generate_release_changelog(latest_release_date: str, query_tags: str, repo_n
pr_result: list[dict[str, str]] = queries.query_merged_prs(latest_release_date, query_tags_list, repo_name)
changelog_labels_result: list[str] = conversion_logic.csv_string_to_list(changelog_labels)

pr_filtered: list[Changelog] = filter_changes_per_label(pr_data=pr_result, changelog_label_list=changelog_labels_result)
sorted_changes: list[Changelog] = sort_changes(changes_list=pr_filtered, changelog_label_list=changelog_labels_result)
markdown_changelog: str = build_changelog_markdown(sorted_changes)
pr_filtered: list[Changelog] = conversion_logic.filter_changes_per_label(pr_data=pr_result, changelog_label_list=changelog_labels_result)
sorted_changes: list[Changelog] = conversion_logic.sort_changes(changes_list=pr_filtered, changelog_label_list=changelog_labels_result)
markdown_changelog: str = conversion_logic.build_changelog_markdown(sorted_changes)

click.echo(markdown_changelog)

Expand Down
Loading

0 comments on commit 4479a62

Please sign in to comment.