Skip to content

Commit

Permalink
Fix script
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla committed Oct 9, 2024
1 parent 5a02337 commit 63dba8f
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions scripts/update_ragbits_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,15 @@ def _version_to_list(version_string):
return [int(part) for part in version_string.split(".")]


def _check_update_type(version: str, new_version: str) -> Optional[UpdateType]:
def _check_update_type(version: str, new_version: str) -> UpdateType:
version_list = _version_to_list(version)
new_version_list = _version_to_list(new_version)

if version_list[0] != new_version_list[0]:
return UpdateType.MAJOR
if version_list[1] != new_version_list[1]:
return UpdateType.MINOR
if version_list[2] != new_version_list[2]:
return UpdateType.PATCH
return None
return UpdateType.PATCH


def _get_updated_version(version: str, update_type: UpdateType) -> str:
Expand All @@ -78,6 +76,7 @@ def _update_pkg_version(
pkg_pyproject: Optional[tomlkit.TOMLDocument] = None,
new_version: Optional[str] = None,
update_type: Optional[UpdateType] = None,
sync_ragbits_version: bool = False,
) -> tuple[str, str]:
if not pkg_pyproject:
pkg_pyproject = tomlkit.parse((PACKAGES_DIR / pkg_name / "pyproject.toml").read_text())
Expand All @@ -97,27 +96,34 @@ def _update_pkg_version(
assert isinstance(new_version, str)
pprint(f"[green]The {pkg_name} package was successfully updated from {version} to {new_version}.[/green]")

if pkg_name != "ragbits":
_sync_ragbits_deps(pkg_name, version, new_version, sync_ragbits_version)

return version, new_version


def _sync_ragbits_deps(pkg_name: str, pkg_version: str, pkg_new_version: str, update_type: UpdateType):
def _sync_ragbits_deps(pkg_name: str, pkg_version: str, pkg_new_version: str, update_version: bool = True):
ragbits_pkg_project = tomlkit.parse((PACKAGES_DIR / "ragbits/pyproject.toml").read_text())
ragbits_deps: list[str] = [dep.split("==")[0] for dep in ragbits_pkg_project["project"]["dependencies"]]

update_type = _check_update_type(pkg_version, pkg_new_version)

if pkg_name in ragbits_deps:
idx = ragbits_pkg_project["project"]["dependencies"].index(f"{pkg_name}=={pkg_version}")
del ragbits_pkg_project["project"]["dependencies"][idx]
ragbits_pkg_project["project"]["dependencies"].insert(idx, f"{pkg_name}=={pkg_new_version}")

ragbits_old_version = ragbits_pkg_project["project"]["version"]
ragbits_new_version = _get_updated_version(ragbits_old_version, update_type=update_type)
ragbits_pkg_project["project"]["version"] = ragbits_new_version
if update_version:
ragbits_old_version = ragbits_pkg_project["project"]["version"]
ragbits_new_version = _get_updated_version(ragbits_old_version, update_type=update_type)
ragbits_pkg_project["project"]["version"] = ragbits_new_version

pprint(
"[green]The ragbits package was successfully updated "
f"from {ragbits_old_version} to {ragbits_new_version}.[/green]"
)

(PACKAGES_DIR / "ragbits" / "pyproject.toml").write_text(tomlkit.dumps(ragbits_pkg_project))
pprint(
"[green]The ragbits package was successfully updated "
f"from {ragbits_old_version} to {ragbits_new_version}.[/green]"
)


def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[str] = typer.Argument(None)) -> None:
Expand Down Expand Up @@ -150,7 +156,6 @@ def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[st
pkg_name = list_input("Enter the package name", choices=packages)

casted_update_type = _update_type_to_enum(update_type)

user_prompt_required = pkg_name is None or casted_update_type is None

if pkg_name == "ragbits":
Expand Down Expand Up @@ -179,8 +184,7 @@ def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[st
pprint("[red]The ragbits-core package was not successfully updated.[/red]")

else:
version, new_version = _update_pkg_version(pkg_name, update_type=casted_update_type)
_sync_ragbits_deps(pkg_name, version, new_version, update_type)
_update_pkg_version(pkg_name, update_type=casted_update_type, sync_ragbits_version=True)


if __name__ == "__main__":
Expand Down

0 comments on commit 63dba8f

Please sign in to comment.