Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor CLI Calls #5

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,15 @@ runs:
shell: bash
run: |
cd $GITHUB_ACTION_PATH
label_list=$(python -c 'import github_query; print(github_query.get_labels(github_query.parse_args()))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
if [[ "$label_list" == '[]' ]]; then
echo "label_list=''" >> $GITHUB_OUTPUT
exit 0
fi

label_list=$(python github_query.py pr-labels "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
echo "label_list=$label_list" >> $GITHUB_OUTPUT

- name: Get version increment
id: bump-increment
shell: bash
run: |
cd $GITHUB_ACTION_PATH
increment=$(python -c 'import github_query; pr_labels = github_query.get_labels(github_query.parse_args()); patch_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="PATCH_BUMP_LABEL"); minor_repo_var = github_query.get_repo_var(repo="${{ inputs.repo }}", var_name="MINOR_BUMP_LABEL"); print(github_query.get_version_increment(patch_bump_list=patch_repo_var, minor_bump_list=minor_repo_var, pr_label_list=pr_labels))' "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
increment=$(python github_query.py version-increment "${{ inputs.repo }}" "${{ inputs.query_parameters }}" "${{ inputs.date }}")
echo "increment=$increment" >> $GITHUB_OUTPUT

# disabled until fixed
Expand Down
183 changes: 39 additions & 144 deletions github_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,159 +4,54 @@
Additonally it's test suite relies mainly on putest and therefore the functions need to be importable to the pytest script.
"""

import argparse
import json
import logging
import re
import subprocess


logger = logging.getLogger(__name__)

def parse_args() -> dict:
"""Parse command-line arguments and store them in a global variable."""

parser = argparse.ArgumentParser(description="A python script to convert GitHub PR information to a more simple format.")
parser.add_argument("repo", type=str, help="Repository name consisting of 'repo-owner/repo-name'")
parser.add_argument("query_parameters", type=str, help="Keys to query for.")
parser.add_argument("date", type=str, default="2024-07-08T09:48:33Z", help="Latest release date.")
parsed_args = parser.parse_args()

repo_name = parsed_args.repo
query_tags = parsed_args.query_parameters.split(',')
latest_release_date = parsed_args.date

command = f"gh pr list --state merged --search 'merged:>={latest_release_date}' --json {','.join(query_tags)} --repo {repo_name}"
pr_json = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

return json.loads(pr_json.stdout)

def get_changelog(pr_data, changelog_start="## Changelog", heading="##"):
"""Get list of changes from a PRs changelog.

Args:
pr_body (list(str)): List of PR body contents.
changelog_start (str, optional): Indicates markdown changelog section. Defaults to "## Changes".
heading (str, optional): Markdown heading. Defaults to "##".

Returns:
list(str): List of changes found.
"""

lines = pr_data.splitlines()
changelog_section = None
changelog_lines = []

for line in lines:
if line.startswith(changelog_start):
changelog_section = True
continue

if changelog_section and line.startswith(heading):
break

if changelog_section and line.startswith("* "):
changelog_lines.append(line.strip("* ").strip())

return changelog_lines

def changelog_per_label(json_dict):
# TODO replace with labels fetched from repo variables
changelog_labels = ["bugfix", "enhancement", "feature"]
labels = []
for item in json_dict:
labels.append(item["labels"])
if any(item in changelog_labels for item in labels):
pass

def prepare_changelog_markdown(pr_query, minor_bump_list, patch_bump_list):
# ? should version bump labels also be filter for changelog ?
label_list = minor_bump_list + patch_bump_list
changelog = ""

for pr in pr_query:
# get all label names in a list
pr_label_list = [label["name"] for label in pr["labels"]]
fitlered_label = list(set(label_list).intersection(pr_label_list))[0]

if fitlered_label:
change_list = get_changelog(pr_data=pr["body"])

changelog += f"## {fitlered_label.capitalize()}\n"
changelog += "".join([f"* {change}\n" for change in change_list])
changelog += "\n"

return changelog


def get_labels(pr_data: dict) -> list:
"""Filter all unique labels from dictionary.

Args:
pr_data (dict): Github PR query result

Returns:
[str]: Liste of unique labels strings found or `None`.
"""

labels = set()

for item in pr_data:
if not item.get("labels"):
return []
for label in item["labels"]:
if not label.get("name"):
return []

labels.add(label["name"])

return list(labels)

def get_repo_var(repo: str, var_name: str) -> list:
"""Query labels from repository variables.

Args:
repo (str): Repository name `owner/repo-name`
var_name (str): Repo variable name

Returns:
str: Comma separated value string.
import click
from src import conversion_logic, queries

@click.group()
def cli():
pass

@cli.command()
@click.argument('repo_name', type=click.STRING)
@click.argument('query_tags', type=click.STRING)
@click.argument('latest_release_date', type=click.STRING)
def pr_labels(latest_release_date, query_tags, repo_name):
"""Get a list of all version relevant PR labels.

latest_release_date (str): datatime string\n
query_tags (str): csv string\n
repo_name (str): repo name as <owner><repo>\n
"""
labels = subprocess.run(
["gh", "variable", "get", var_name, "--repo", repo],
capture_output=True,
text=True,
check=True
)

return csv_string_to_list(labels.stdout)
pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name)
pr_labels = conversion_logic.get_labels(pr_data=pr_result)

def csv_string_to_list(input: str) -> list:
if input:
return re.split(r',\s*', input.strip())
if not pr_labels:
click.echo("")

return []
click.echo(pr_labels)

def get_version_increment(patch_bump_list: list, minor_bump_list: list, pr_label_list: list):
"""Figure out version increment based on PR labels.

Args:
patch_bump_list ([str]): Labels for bumping patch version
minor_bump_list ([str]): Labels for bumping minor version
label_list([str]): Labels found in PRs
@cli.command()
@click.argument('repo_name', type=click.STRING)
@click.argument('query_tags', type=click.STRING)
@click.argument('latest_release_date', type=click.STRING)
def version_increment(latest_release_date, query_tags, repo_name):
"""Output a calculated version increment suggestion.

Returns:
str: version increment
latest_release_date (str): datetime string\n
query_tags (str): csv string\n
repo_name (str): repo name as <owner><repo>\n
"""

if not pr_label_list:
return ""
pr_result = queries.query_merged_prs(latest_release_date, query_tags, repo_name)
pr_labels = conversion_logic.get_labels(pr_data=pr_result)
patch_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="PATCH_BUMP_LABEL"))
minor_repo_var_list = conversion_logic.csv_string_to_list(queries.get_repo_var(repo=repo_name, var_name="MINOR_BUMP_LABEL"))
increment = conversion_logic.get_version_increment(patch_bump_list=patch_repo_var_list, minor_bump_list=minor_repo_var_list, pr_label_list=pr_labels)

# TODO add major bump option
if any(label in pr_label_list for label in minor_bump_list):
return "minor"
click.echo(increment)

if any(label in pr_label_list for label in patch_bump_list):
return "patch"

return ""
if __name__ == '__main__':
cli()
Empty file added src/__init__.py
Empty file.
139 changes: 139 additions & 0 deletions src/conversion_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import logging
import re


logger = logging.getLogger(__name__)

# INFO not in use
def get_changelog(pr_data, changelog_start="## Changelog", heading="##"):
"""Get list of changes from a PRs changelog.

Args:
pr_body (list(str)): List of PR body contents.
changelog_start (str, optional): Indicates markdown changelog section. Defaults to "## Changes".
heading (str, optional): Markdown heading. Defaults to "##".

Returns:
list(str): List of changes found.
"""

lines = pr_data.splitlines()
changelog_section = None
changelog_lines = []

for line in lines:
if line.startswith(changelog_start):
changelog_section = True
continue

if changelog_section and line.startswith(heading):
break

if changelog_section and line.startswith("* "):
changelog_lines.append(line.strip("* ").strip())

return changelog_lines

# INFO not in use
def changelog_per_label(json_dict):
# TODO replace with labels fetched from repo variables
changelog_labels = ["bugfix", "enhancement", "feature"]
labels = []
for item in json_dict:
labels.append(item["labels"])
if any(item in changelog_labels for item in labels):
pass

# INFO not in use
def prepare_changelog_markdown(pr_query, minor_bump_list, patch_bump_list):
# ? should version bump labels also be filter for changelog ?
label_list = minor_bump_list + patch_bump_list
changelog = ""

for pr in pr_query:
# get all label names in a list
pr_label_list = [label["name"] for label in pr["labels"]]
filtered_label = list(set(label_list).intersection(pr_label_list))[0]

if filtered_label:
change_list = get_changelog(pr_data=pr["body"])

changelog += f"## {filtered_label.capitalize()}\n"
changelog += "".join([f"* {change}\n" for change in change_list])
changelog += "\n"

return changelog


def get_labels(pr_data: dict) -> list:
"""Filter all unique labels from dictionary.

Args:
pr_data (dict): Github PR query result

Returns:
list: List of unique labels strings found or `None`.
"""

labels = set()

for item in pr_data:
if not item.get("labels"):
logger.warning("No PR label data found.")
return []
for label in item["labels"]:
if not label.get("name"):
logger.warning("No PR label names found.")
return []

labels.add(label["name"])
logger.debug("PR labels found.")

return list(labels)

def csv_string_to_list(input: str) -> list:
"""Convert string to list.

Args:
input (str): Expected csv string.

Returns:
list: List of strings.
"""

if input:
return re.split(r',\s*', input.strip())

return []

def get_version_increment(pr_label_list: list, patch_bump_list: list=[], minor_bump_list: list=[], major_bump_list: list=[]):
"""Figure out version increment based on PR labels.

Args:
patch_bump_list ([str]): Labels for bumping patch version
minor_bump_list ([str]): Labels for bumping minor version
label_list ([str]): Labels found in PRs

Returns:
str: version increment
"""

if not pr_label_list:
logger.warning("PR label list was empty")
return ""

for name, param in locals().items():
if not isinstance(param, list):
raise ValueError(f"{name} must be a list.")

if any(label in pr_label_list for label in major_bump_list):
return "major"

if any(label in pr_label_list for label in minor_bump_list):
return "minor"

if any(label in pr_label_list for label in patch_bump_list):
return "patch"

logger.warning("No relevant labels found for version increment.")
return ""
Loading