diff --git a/notebooks/Experimental/1-checkpoint.ipynb b/notebooks/Experimental/1-checkpoint.ipynb new file mode 100644 index 00000000000..3dcc8e729da --- /dev/null +++ b/notebooks/Experimental/1-checkpoint.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "034a332d-517f-486f-8713-eeed53d35067", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "from syft.util.test_helpers.checkpoint import create_checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53cebce4-5e50-48f0-bd38-cf82cba9e71c", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client = sy.login(email=\"info@openmined.org\", password=\"changethis\", port=8080)" + ] + }, + { + "cell_type": "markdown", + "id": "957c101e-f048-4ba0-b5e8-22af97e6c9b1", + "metadata": {}, + "source": [ + "### Scaling Default Worker Pool" + ] + }, + { + "cell_type": "markdown", + "id": "e7217d39-fd2e-4a6c-b7c7-8013703b83ce", + "metadata": {}, + "source": [ + "We should see a default worker pool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7018e9bf-1639-4204-9bf3-19d296cc8572", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.worker_pools" + ] + }, + { + "cell_type": "markdown", + "id": "87cbd371-a23a-4229-b305-ed834c2eb3d1", + "metadata": {}, + "source": [ + "Scale up to 3 workers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acc5c6ba-cd25-420e-9491-2db0f67fb763", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.api.services.worker_pool.scale(number=3, pool_name=\"default-pool\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4696855-8848-4182-9513-0753893eaf5f", + "metadata": {}, + "outputs": [], + "source": [ + "result = datasite_client.api.services.worker_pool.get_by_name(pool_name=\"default-pool\")\n", + "assert len(result.workers) == 3, str(result.to_dict())\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d8b5827-464b-4954-acc1-52773829ca65", + "metadata": {}, + "outputs": [], + "source": [ + "create_checkpoint(name=\"1-checkpoint-v1\", client=datasite_client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27f441fc-f3c0-445c-aa68-85ab92feaa57", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "839af9a0-a76d-42bc-b8f9-7e67d9da7e7a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/Experimental/2-checkpoint.ipynb b/notebooks/Experimental/2-checkpoint.ipynb new file mode 100644 index 00000000000..4cf14b6a327 --- /dev/null +++ b/notebooks/Experimental/2-checkpoint.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bdea5a57-7ec9-45e0-8684-662ae52145cd", + "metadata": {}, + "source": [ + "# 2 - Checkpoint" + ] + }, + { + "cell_type": "markdown", + "id": "0b978f6c-6778-4880-934a-e751d55bec97", + "metadata": {}, + "source": [ + "- Create a fresh cluster and load the checkpoint from notebook 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b53b0926-8ee0-4bfa-876c-f66bce750dc7", + "metadata": {}, + "outputs": [], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "from syft.util.test_helpers.checkpoint import load_from_checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bbee80c-e1fa-4b47-a8d2-cd5099838a8c", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client = sy.login(email=\"info@openmined.org\", password=\"changethis\", port=8080)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "949f759d-e9ce-493a-ab2e-a5a6e817ed85", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1d9e708-43f9-4ba4-9e43-6eeb94b89800", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.worker_pools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7234d846-9c38-437a-a901-4ff67dfd92b7", + "metadata": {}, + "outputs": [], + "source": [ + "result = datasite_client.api.services.worker_pool.get_by_name(pool_name=\"default-pool\")\n", + "assert len(result.workers) == 3, str(result.to_dict())\n", + "result" + ] + }, + { + "cell_type": "markdown", + "id": "c964b66b-9992-42c4-8ce6-432ee3a50dc3", + "metadata": {}, + "source": [ + "Scale down to 1 worker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaf07b5b-da13-4c22-a0a1-8cb8ffcf9469", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.api.services.worker_pool.scale(number=1, pool_name=\"default-pool\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd269fbb-87af-44de-8bf9-2752ac8ef703", + "metadata": {}, + "outputs": [], + "source": [ + "result = datasite_client.api.services.worker_pool.get_by_name(pool_name=\"default-pool\")\n", + "assert len(result.workers) == 1, str(result.to_dict())\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9764f5b8-1449-4c8d-8458-b9add6fc938f", + "metadata": {}, + "outputs": [], + "source": [ + "default_worker_pool = datasite_client.api.services.worker_pool.get_by_name(\n", + " pool_name=\"default-pool\"\n", + ")\n", + "default_worker_pool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac2cfc1f-d817-436a-9304-ed5e15c3fad2", + "metadata": {}, + "outputs": [], + "source": [ + "load_from_checkpoint(name=\"1-checkpoint-v1\", client=datasite_client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3442b39a-0606-44f2-83ca-4fc91a584881", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.worker_pools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f586f08-9d2b-42c3-8a6d-80af7786ba55", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client.worker_pools[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00f1af3c-718e-45df-b548-bec02ee69645", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/syft/src/syft/util/test_helpers/checkpoint.py b/packages/syft/src/syft/util/test_helpers/checkpoint.py index 99b0a3b1e52..599e1d62c43 100644 --- a/packages/syft/src/syft/util/test_helpers/checkpoint.py +++ b/packages/syft/src/syft/util/test_helpers/checkpoint.py @@ -1,7 +1,8 @@ # stdlib -import datetime import os from pathlib import Path +import tempfile +import zipfile # syft absolute from syft import SyftError @@ -17,44 +18,46 @@ CHECKPOINT_ROOT = "checkpoints" CHECKPOINT_DIR_PREFIX = "chkpt" +DEFAULT_CHECKPOINT_DIR = get_root_data_path() / CHECKPOINT_ROOT +try: + # Ensure the default checkpoint path exists always + DEFAULT_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) +except Exception as e: + print(f"Error creating default checkpoint directory: {e}") -def root_checkpoint_path() -> Path: - return get_root_data_path() / CHECKPOINT_ROOT +def is_admin(client: SyftClient) -> bool: + return client._SyftClient__user_role == ServiceRole.ADMIN -def get_checkpoint_parent_dir(server_uid: str, chkpt_name: str) -> Path: - return root_checkpoint_path() / chkpt_name / server_uid - - -def create_checkpoint_dir(server_uid: str, chkpt_name: str) -> Path: - """Create a checkpoint directory by chkpt_name and server_uid.""" - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = f"{CHECKPOINT_DIR_PREFIX}_{timestamp}" - checkpoint_parent_dir = get_checkpoint_parent_dir( - server_uid=server_uid, chkpt_name=chkpt_name - ) - checkpoint_full_path = checkpoint_parent_dir / checkpoint_dir - - # Format of Checkpoint Directory: - # <root_syft_dir>/checkpoints/chkpt_name/<server_uid>/chkpt_<timestamp> - checkpoint_full_path.mkdir(parents=True, exist_ok=True) - return checkpoint_full_path +def is_valid_dir(path: Path | str) -> Path: + if isinstance(path, str): + path = Path(path) + if not path.is_dir(): + raise SyftException(f"Path {path} is not a directory.") + return path -def is_admin(client: SyftClient) -> bool: - return client._SyftClient__user_role == ServiceRole.ADMIN +def is_valid_file(path: Path | str) -> Path: + if isinstance(path, str): + path = Path(path) + if not path.is_file(): + raise SyftException(f"Path {path} is not a file.") + return path def create_checkpoint( name: str, # Name of the checkpoint client: SyftClient, + chkpt_dir: Path | str = DEFAULT_CHECKPOINT_DIR, root_email: str | None = None, root_pwd: str | None = None, ) -> None: """Save a checkpoint for the database.""" + is_valid_dir(chkpt_dir) + if root_email is None: root_email = get_default_root_email() @@ -71,37 +74,34 @@ def create_checkpoint( if isinstance(migration_data, SyftError): raise SyftException(message=migration_data.message) - checkpoint_dir = create_checkpoint_dir( - server_uid=client.id.to_string(), chkpt_name=name - ) + checkpoint_path = chkpt_dir / f"{name}.zip" + + # get a temporary directory to save the checkpoint + temp_dir = Path(tempfile.mkdtemp()) + checkpoint_blob = temp_dir / "checkpoint.blob" + checkpoint_yaml = temp_dir / "checkpoint.yaml" migration_data.save( - path=checkpoint_dir / "migration.blob", - yaml_path=checkpoint_dir / "migration.yaml", + path=checkpoint_blob, + yaml_path=checkpoint_yaml, ) - print(f"Checkpoint saved at: \n {checkpoint_dir}") + # Combine the files into a single zip file to checkpoint_path + with zipfile.ZipFile(checkpoint_path, "w") as zipf: + zipf.write(checkpoint_blob, "checkpoint.blob") + zipf.write(checkpoint_yaml, "checkpoint.yaml") -def last_checkpoint_path_for(server_uid: str, chkpt_name: str) -> Path | None: - """Return the directory of the latest checkpoint for the given name.""" - - checkpoint_parent_dir = get_checkpoint_parent_dir( - server_uid=server_uid, chkpt_name=chkpt_name - ) + print(f"Checkpoint saved at: \n {checkpoint_path}") - checkpoint_dirs = [ - d - for d in checkpoint_parent_dir.glob(f"{CHECKPOINT_DIR_PREFIX}_*") - if d.is_dir() - ] - checkpoints_dirs_with_blob_entry = [ - d for d in checkpoint_dirs if any(d.glob("*.blob")) - ] - if checkpoints_dirs_with_blob_entry: - print(f"Loading from the last checkpoint for: {chkpt_name}") - return max(checkpoints_dirs_with_blob_entry, key=lambda d: d.stat().st_mtime) +def get_checkpoint_for( + path: Path | str | None = None, chkpt_name: str | None = None +) -> Path | None: + # Path takes precedence over name + if path: + return is_valid_file(path) - return None + if chkpt_name: + return is_valid_file(DEFAULT_CHECKPOINT_DIR / f"{chkpt_name}.zip") def get_registry_credentials() -> tuple[str, str]: @@ -112,7 +112,8 @@ def get_registry_credentials() -> tuple[str, str]: def load_from_checkpoint( client: SyftClient, - name: str, + name: str | None = None, + path: Path | str | None = None, root_email: str | None = None, root_password: str | None = None, registry_username: str | None = None, @@ -128,17 +129,26 @@ def load_from_checkpoint( if is_admin(client) else client.login(email=root_email, password=root_password) ) - latest_checkpoint_dir = last_checkpoint_path_for( - server_uid=client.id.to_string(), chkpt_name=name - ) + if name is None and path is None: + raise SyftException("Please provide either a checkpoint name or a path.") + + checkpoint_zip_path = get_checkpoint_for(path=path, chkpt_name=name) - if latest_checkpoint_dir is None: - print(f"No last checkpoint found for : {name}") + if checkpoint_zip_path is None: + print(f"No last checkpoint found for : {name} or {path}") return - print(f"Loading from checkpoint: {latest_checkpoint_dir}") + # Unzip the checkpoint zip file + with zipfile.ZipFile(checkpoint_zip_path, "r") as zipf: + checkpoint_temp_dir = Path(tempfile.mkdtemp()) + zipf.extract("checkpoint.blob", checkpoint_temp_dir) + zipf.extract("checkpoint.yaml", checkpoint_temp_dir) + + checkpoint_blob = checkpoint_temp_dir / "checkpoint.blob" + + print(f"Loading from checkpoint: {checkpoint_zip_path}") result = root_client.load_migration_data( - path_or_data=latest_checkpoint_dir / "migration.blob", + path_or_data=checkpoint_blob, include_worker_pools=True, with_reset_db=True, )