-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9331 from OpenMined/rasswanth/remove-name-input
[WIP] Notebooks State - Misc Improvements
- Loading branch information
Showing 3 changed files with 30 additions and 90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,14 @@ | ||
# stdlib | ||
import datetime | ||
import json | ||
import os | ||
from pathlib import Path | ||
|
||
# third party | ||
import ipykernel | ||
|
||
# syft absolute | ||
from syft import SyftError | ||
from syft import SyftException | ||
from syft.client.client import SyftClient | ||
from syft.service.user.user_roles import ServiceRole | ||
from syft.util.util import get_root_data_path | ||
from syft.util.util import is_interpreter_jupyter | ||
|
||
# relative | ||
from ...server.env import get_default_root_email | ||
|
@@ -24,53 +19,25 @@ | |
CHECKPOINT_DIR_PREFIX = "chkpt" | ||
|
||
|
||
def get_notebook_name_from_pytest_env() -> str | None: | ||
""" | ||
Returns the notebook file name from the test environment variable 'PYTEST_CURRENT_TEST'. | ||
If not available, returns None. | ||
""" | ||
pytest_current_test = os.environ.get("PYTEST_CURRENT_TEST", "") | ||
# Split by "::" and return the first part, which is the file path | ||
return os.path.basename(pytest_current_test.split("::")[0]) | ||
|
||
|
||
def current_nbname() -> Path: | ||
"""Retrieve the current Jupyter notebook name.""" | ||
curr_kernel_file = Path(ipykernel.get_connection_file()) | ||
kernel_file = json.loads(curr_kernel_file.read_text()) | ||
nb_name = kernel_file.get("jupyter_session", "") | ||
if not nb_name: | ||
nb_name = get_notebook_name_from_pytest_env() | ||
return Path(nb_name) | ||
|
||
|
||
def root_checkpoint_path() -> Path:
    """Return the ``CHECKPOINT_ROOT`` directory under the root data path."""
    data_root = get_root_data_path()
    return data_root / CHECKPOINT_ROOT
|
||
|
||
def checkpoint_parent_dir(server_uid: str, nb_name: str | None = None) -> Path: | ||
"""Return the checkpoint directory for the current notebook and server.""" | ||
if is_interpreter_jupyter: | ||
nb_name = nb_name if nb_name else current_nbname().stem | ||
return Path(f"{nb_name}/{server_uid}") if nb_name else Path(server_uid) | ||
return Path(server_uid) | ||
|
||
|
||
def get_checkpoints_dir(server_uid: str, nb_name: str) -> Path:
    """Return the checkpoint root joined with the notebook/server directory."""
    checkpoint_root = root_checkpoint_path()
    return checkpoint_root / checkpoint_parent_dir(server_uid, nb_name)
|
||
def get_checkpoint_parent_dir(server_uid: str, chkpt_name: str) -> Path:
    """Return ``<checkpoint root>/<chkpt_name>/<server_uid>``."""
    checkpoint_root = root_checkpoint_path()
    return checkpoint_root.joinpath(chkpt_name, server_uid)
|
||
def get_checkpoint_dir(
    server_uid: str, checkpoint_dir: str, nb_name: str | None = None
) -> Path:
    """Return the path of one named checkpoint directory.

    The result is ``checkpoint_dir`` appended to the directory returned by
    ``get_checkpoints_dir`` for ``server_uid`` and ``nb_name``.
    """
    checkpoints_root = get_checkpoints_dir(server_uid, nb_name)
    return checkpoints_root / checkpoint_dir
|
||
|
||
def create_checkpoint_dir(server_uid: str, chkpt_name: str) -> Path:
    """Create a checkpoint directory by chkpt_name and server_uid.

    Args:
        server_uid: Identifier of the server the checkpoint belongs to.
        chkpt_name: Name the checkpoint is grouped under.

    Returns:
        Path of the created directory, named
        ``<CHECKPOINT_DIR_PREFIX>_<timestamp>`` with a second-resolution
        local timestamp.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = f"{CHECKPOINT_DIR_PREFIX}_{timestamp}"
    # Named ``parent_dir`` so the local does not shadow the module-level
    # ``checkpoint_parent_dir`` function.
    parent_dir = get_checkpoint_parent_dir(
        server_uid=server_uid, chkpt_name=chkpt_name
    )
    checkpoint_full_path = parent_dir / checkpoint_dir

    # Format of Checkpoint Directory:
    # <root_syft_dir>/checkpoints/chkpt_name/<server_uid>/chkpt_<timestamp>

    # exist_ok=True: a second call within the same second reuses the directory.
    checkpoint_full_path.mkdir(parents=True, exist_ok=True)
    return checkpoint_full_path
|
@@ -81,6 +48,7 @@ def is_admin(client: SyftClient) -> bool: | |
|
||
|
||
def create_checkpoint( | ||
name: str, # Name of the checkpoint | ||
client: SyftClient, | ||
root_email: str | None = None, | ||
root_pwd: str | None = None, | ||
|
@@ -103,31 +71,22 @@ def create_checkpoint( | |
if isinstance(migration_data, SyftError): | ||
raise SyftException(message=migration_data.message) | ||
|
||
if not is_interpreter_jupyter(): | ||
raise SyftException( | ||
message="Checkpoint can only be created in Jupyter Notebook." | ||
) | ||
|
||
checkpoint_dir = create_checkpoint_dir(server_uid=client.id.to_string()) | ||
checkpoint_dir = create_checkpoint_dir( | ||
server_uid=client.id.to_string(), chkpt_name=name | ||
) | ||
migration_data.save( | ||
path=checkpoint_dir / "migration.blob", | ||
yaml_path=checkpoint_dir / "migration.yaml", | ||
) | ||
print(f"Checkpoint saved at: \n {checkpoint_dir}") | ||
|
||
|
||
def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path | None: | ||
"""Return the directory of the latest checkpoint for the given notebook.""" | ||
nb_name = nb_name if nb_name else current_nbname().stem | ||
checkpoint_dir = None | ||
if len(nb_name.split("/")) > 1: | ||
nb_name, checkpoint_dir = nb_name.split("/") | ||
def last_checkpoint_path_for(server_uid: str, chkpt_name: str) -> Path | None: | ||
"""Return the directory of the latest checkpoint for the given name.""" | ||
|
||
filename = nb_name.split(".ipynb")[0] | ||
checkpoint_parent_dir = get_checkpoints_dir(server_uid, filename) | ||
|
||
if checkpoint_dir: | ||
return checkpoint_parent_dir / checkpoint_dir | ||
checkpoint_parent_dir = get_checkpoint_parent_dir( | ||
server_uid=server_uid, chkpt_name=chkpt_name | ||
) | ||
|
||
checkpoint_dirs = [ | ||
d | ||
|
@@ -139,7 +98,7 @@ def last_checkpoint_path_for_nb(server_uid: str, nb_name: str = None) -> Path | | |
] | ||
|
||
if checkpoints_dirs_with_blob_entry: | ||
print("Loading from the last checkpoint of the current notebook.") | ||
print(f"Loading from the last checkpoint for: {chkpt_name}") | ||
return max(checkpoints_dirs_with_blob_entry, key=lambda d: d.stat().st_mtime) | ||
|
||
return None | ||
|
@@ -153,17 +112,13 @@ def get_registry_credentials() -> tuple[str, str]: | |
|
||
def load_from_checkpoint( | ||
client: SyftClient, | ||
prev_nb_filename: str | None = None, | ||
name: str, | ||
root_email: str | None = None, | ||
root_password: str | None = None, | ||
registry_username: str | None = None, | ||
registry_password: str | None = None, | ||
checkpoint_name: str | None = None, | ||
) -> None: | ||
"""Load the last saved checkpoint for the given notebook state.""" | ||
if prev_nb_filename is None: | ||
print("Loading from the last checkpoint of the current notebook.") | ||
prev_nb_filename = current_nbname().stem | ||
"""Load the last saved checkpoint for the given checkpoint state.""" | ||
|
||
root_email = "[email protected]" if root_email is None else root_email | ||
root_password = "changethis" if root_password is None else root_password | ||
|
@@ -173,12 +128,12 @@ def load_from_checkpoint( | |
if is_admin(client) | ||
else client.login(email=root_email, password=root_password) | ||
) | ||
latest_checkpoint_dir = last_checkpoint_path_for_nb( | ||
client.id.to_string(), prev_nb_filename | ||
latest_checkpoint_dir = last_checkpoint_path_for( | ||
server_uid=client.id.to_string(), chkpt_name=name | ||
) | ||
|
||
if latest_checkpoint_dir is None: | ||
print(f"No last checkpoint found for notebook: {prev_nb_filename}") | ||
print(f"No last checkpoint found for : {name}") | ||
return | ||
|
||
print(f"Loading from checkpoint: {latest_checkpoint_dir}") | ||
|