From 81111a023c561c5f8eac1a83003e6bc3eab38187 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 8 Apr 2024 11:19:57 -0700 Subject: [PATCH 1/2] Get wandb path from checkpoint dir instead of run dir --- scripts/storage_cleaner.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 2d9e2abc3..83a9d8a2d 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -1066,8 +1066,8 @@ def _get_wandb_config(wandb_run) -> TrainConfig: return wandb_config -def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List: - config_path = os.path.join(training_run_dir, CONFIG_YAML) +def _get_matching_wandb_runs(wandb_runs, checkpoint_dir: str) -> List: + config_path = os.path.join(checkpoint_dir, CONFIG_YAML) local_config_path = cached_path(config_path) train_config = TrainConfig.load(local_config_path) @@ -1076,18 +1076,18 @@ def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List: ] -def _get_wandb_path(run_dir: str) -> str: +def _get_wandb_path(checkpoint_dir: str, run_dir: str) -> str: run_dir_storage = _get_storage_adapter_for_path(run_dir) - config_path = os.path.join(run_dir, CONFIG_YAML) + config_path = os.path.join(checkpoint_dir, CONFIG_YAML) if not run_dir_storage.is_file(config_path): - raise FileNotFoundError("No config file found in run dir, cannot get wandb path") + raise FileNotFoundError(f"No config file found in checkpoint dir {checkpoint_dir}, cannot get wandb path") local_config_path = cached_path(config_path) config = TrainConfig.load(local_config_path, validate_paths=False) if config.wandb is None or config.wandb.entity is None or config.wandb.project is None: - raise ValueError(f"Run at {run_dir} has missing wandb config, cannot get wandb run path") + raise ValueError(f"Checkpoint at {checkpoint_dir} has missing wandb config, cannot get wandb run path") wandb_runs = [] @@ -1096,18 +1096,17 @@ def _get_wandb_path(run_dir: str) -> str: wandb_runs += _get_wandb_runs_from_wandb_dir(run_dir_storage, wandb_dir, config) wandb_runs += _get_wandb_runs_from_train_config(config) - - # Remove duplicate wandb runs based on run path, and wandb runs that do not match our run. + # Remove duplicate wandb runs based on run path, and wandb runs that do not match our checkpoint. wandb_runs = list({_get_wandb_path_from_run(wandb_run): wandb_run for wandb_run in wandb_runs}.values()) - wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, run_dir) + wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, checkpoint_dir) if len(wandb_matching_runs) == 0: - raise RuntimeError(f"Failed to find any wandb runs for {run_dir}. Run might no longer exist") + raise RuntimeError(f"Failed to find any wandb runs for {checkpoint_dir}. Run might no longer exist") if len(wandb_matching_runs) > 1: wandb_run_urls = [wandb_run.url for wandb_run in wandb_matching_runs] raise RuntimeError( - f"Found {len(wandb_matching_runs)} runs matching run dir {run_dir}, cannot determine correct run: {wandb_run_urls}" + f"Found {len(wandb_matching_runs)} runs matching checkpoint dir {checkpoint_dir}, cannot determine correct run: {wandb_run_urls}" ) return _get_wandb_path_from_run(wandb_matching_runs[0]) @@ -1203,14 +1202,13 @@ def _get_src_dest_pairs_for_copy( assert config.append_wandb_path and not is_archive_file checkpoint_to_wandb_path: Dict[str, str] - # TODO: Update _get_wandb_path to get the wandb path for a checkpoint rather than a run directory. - # A run directory could correspond to multiple wandb runs. + # No need to consider other checkpoints if we are filtering for a specific checkpoint if config.entry is not None and _is_checkpoint_dir(entry_path := os.path.join(run_dir, config.entry)): # No need to consider other checkpoints if we are filtering for a specific checkpoint - checkpoint_to_wandb_path = {entry_path: _get_wandb_path(run_dir)} + checkpoint_to_wandb_path = {entry_path: _get_wandb_path(entry_path, run_dir)} else: checkpoint_dirs = _get_checkpoint_dirs(run_dir, run_dir_storage) - checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(run_dir) for checkpoint_dir in checkpoint_dirs} + checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(checkpoint_dir, run_dir) for checkpoint_dir in checkpoint_dirs} src_dest_pairs: List[Tuple[str, str]] = [] # Mappings of source checkpoint directories to destination checkpoint directories From 5121db6cad3ed1c6bcfdb96d59a3e62d969d12b2 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 8 Apr 2024 13:25:04 -0700 Subject: [PATCH 2/2] Run ruff --- scripts/storage_cleaner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 83a9d8a2d..025ae4b2c 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -1208,7 +1208,9 @@ def _get_src_dest_pairs_for_copy( checkpoint_to_wandb_path = {entry_path: _get_wandb_path(entry_path, run_dir)} else: checkpoint_dirs = _get_checkpoint_dirs(run_dir, run_dir_storage) - checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(checkpoint_dir, run_dir) for checkpoint_dir in checkpoint_dirs} + checkpoint_to_wandb_path = { + checkpoint_dir: _get_wandb_path(checkpoint_dir, run_dir) for checkpoint_dir in checkpoint_dirs + } src_dest_pairs: List[Tuple[str, str]] = [] # Mappings of source checkpoint directories to destination checkpoint directories