diff --git a/scripts/update_ragbits_package.py b/scripts/update_ragbits_package.py index 730f42604..0e9d59afa 100644 --- a/scripts/update_ragbits_package.py +++ b/scripts/update_ragbits_package.py @@ -46,7 +46,7 @@ def _version_to_list(version_string): return [int(part) for part in version_string.split(".")] -def _check_update_type(version: str, new_version: str) -> Optional[UpdateType]: +def _check_update_type(version: str, new_version: str) -> UpdateType: version_list = _version_to_list(version) new_version_list = _version_to_list(new_version) @@ -54,9 +54,7 @@ def _check_update_type(version: str, new_version: str) -> Optional[UpdateType]: return UpdateType.MAJOR if version_list[1] != new_version_list[1]: return UpdateType.MINOR - if version_list[2] != new_version_list[2]: - return UpdateType.PATCH - return None + return UpdateType.PATCH def _get_updated_version(version: str, update_type: UpdateType) -> str: @@ -78,6 +76,7 @@ def _update_pkg_version( pkg_pyproject: Optional[tomlkit.TOMLDocument] = None, new_version: Optional[str] = None, update_type: Optional[UpdateType] = None, + sync_ragbits_version: bool = False, ) -> tuple[str, str]: if not pkg_pyproject: pkg_pyproject = tomlkit.parse((PACKAGES_DIR / pkg_name / "pyproject.toml").read_text()) @@ -97,27 +96,34 @@ def _update_pkg_version( assert isinstance(new_version, str) pprint(f"[green]The {pkg_name} package was successfully updated from {version} to {new_version}.[/green]") + if pkg_name != "ragbits": + _sync_ragbits_deps(pkg_name, version, new_version, sync_ragbits_version) + return version, new_version -def _sync_ragbits_deps(pkg_name: str, pkg_version: str, pkg_new_version: str, update_type: UpdateType): +def _sync_ragbits_deps(pkg_name: str, pkg_version: str, pkg_new_version: str, update_version: bool = True): ragbits_pkg_project = tomlkit.parse((PACKAGES_DIR / "ragbits/pyproject.toml").read_text()) ragbits_deps: list[str] = [dep.split("==")[0] for dep in ragbits_pkg_project["project"]["dependencies"]] + update_type = _check_update_type(pkg_version, pkg_new_version) + if pkg_name in ragbits_deps: idx = ragbits_pkg_project["project"]["dependencies"].index(f"{pkg_name}=={pkg_version}") del ragbits_pkg_project["project"]["dependencies"][idx] ragbits_pkg_project["project"]["dependencies"].insert(idx, f"{pkg_name}=={pkg_new_version}") - ragbits_old_version = ragbits_pkg_project["project"]["version"] - ragbits_new_version = _get_updated_version(ragbits_old_version, update_type=update_type) - ragbits_pkg_project["project"]["version"] = ragbits_new_version + if update_version: + ragbits_old_version = ragbits_pkg_project["project"]["version"] + ragbits_new_version = _get_updated_version(ragbits_old_version, update_type=update_type) + ragbits_pkg_project["project"]["version"] = ragbits_new_version + + pprint( + "[green]The ragbits package was successfully updated " + f"from {ragbits_old_version} to {ragbits_new_version}.[/green]" + ) (PACKAGES_DIR / "ragbits" / "pyproject.toml").write_text(tomlkit.dumps(ragbits_pkg_project)) - pprint( - "[green]The ragbits package was successfully updated " - f"from {ragbits_old_version} to {ragbits_new_version}.[/green]" - ) def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[str] = typer.Argument(None)) -> None: @@ -150,7 +156,6 @@ def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[st pkg_name = list_input("Enter the package name", choices=packages) casted_update_type = _update_type_to_enum(update_type) - user_prompt_required = pkg_name is None or casted_update_type is None if pkg_name == "ragbits": @@ -179,8 +184,7 @@ def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[st pprint("[red]The ragbits-core package was not successfully updated.[/red]") else: - version, new_version = _update_pkg_version(pkg_name, update_type=casted_update_type) - _sync_ragbits_deps(pkg_name, version, new_version, update_type) + _update_pkg_version(pkg_name, update_type=casted_update_type, sync_ragbits_version=True) if __name__ == "__main__":