Skip to content

Commit

Permalink
Merge pull request #536 from allenai/shanea/storage-cleaner-wandb-pat…
Browse files Browse the repository at this point in the history
…h-from-checkpoint

[Storage Cleaner] Get wandb path from checkpoint dir instead of run dir
  • Loading branch information
2015aroras authored Apr 8, 2024
2 parents 657a55e + 5121db6 commit 62c7954
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,8 +1066,8 @@ def _get_wandb_config(wandb_run) -> TrainConfig:
return wandb_config


def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List:
config_path = os.path.join(training_run_dir, CONFIG_YAML)
def _get_matching_wandb_runs(wandb_runs, checkpoint_dir: str) -> List:
config_path = os.path.join(checkpoint_dir, CONFIG_YAML)
local_config_path = cached_path(config_path)
train_config = TrainConfig.load(local_config_path)

Expand All @@ -1076,18 +1076,18 @@ def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List:
]


def _get_wandb_path(run_dir: str) -> str:
def _get_wandb_path(checkpoint_dir: str, run_dir: str) -> str:
run_dir_storage = _get_storage_adapter_for_path(run_dir)

config_path = os.path.join(run_dir, CONFIG_YAML)
config_path = os.path.join(checkpoint_dir, CONFIG_YAML)
if not run_dir_storage.is_file(config_path):
raise FileNotFoundError("No config file found in run dir, cannot get wandb path")
raise FileNotFoundError(f"No config file found in checkpoint dir {checkpoint_dir}, cannot get wandb path")

local_config_path = cached_path(config_path)
config = TrainConfig.load(local_config_path, validate_paths=False)

if config.wandb is None or config.wandb.entity is None or config.wandb.project is None:
raise ValueError(f"Run at {run_dir} has missing wandb config, cannot get wandb run path")
raise ValueError(f"Checkpoint at {checkpoint_dir} has missing wandb config, cannot get wandb run path")

wandb_runs = []

Expand All @@ -1096,18 +1096,17 @@ def _get_wandb_path(run_dir: str) -> str:
wandb_runs += _get_wandb_runs_from_wandb_dir(run_dir_storage, wandb_dir, config)

wandb_runs += _get_wandb_runs_from_train_config(config)

# Remove duplicate wandb runs based on run path, and wandb runs that do not match our run.
# Remove duplicate wandb runs based on run path, and wandb runs that do not match our checkpoint.
wandb_runs = list({_get_wandb_path_from_run(wandb_run): wandb_run for wandb_run in wandb_runs}.values())
wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, run_dir)
wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, checkpoint_dir)

if len(wandb_matching_runs) == 0:
raise RuntimeError(f"Failed to find any wandb runs for {run_dir}. Run might no longer exist")
raise RuntimeError(f"Failed to find any wandb runs for {checkpoint_dir}. Run might no longer exist")

if len(wandb_matching_runs) > 1:
wandb_run_urls = [wandb_run.url for wandb_run in wandb_matching_runs]
raise RuntimeError(
f"Found {len(wandb_matching_runs)} runs matching run dir {run_dir}, cannot determine correct run: {wandb_run_urls}"
f"Found {len(wandb_matching_runs)} runs matching checkpoint dir {checkpoint_dir}, cannot determine correct run: {wandb_run_urls}"
)

return _get_wandb_path_from_run(wandb_matching_runs[0])
Expand Down Expand Up @@ -1203,14 +1202,15 @@ def _get_src_dest_pairs_for_copy(

assert config.append_wandb_path and not is_archive_file
checkpoint_to_wandb_path: Dict[str, str]
# TODO: Update _get_wandb_path to get the wandb path for a checkpoint rather than a run directory.
# A run directory could correspond to multiple wandb runs.
# No need to consider other checkpoints if we are filtering for a specific checkpoint
if config.entry is not None and _is_checkpoint_dir(entry_path := os.path.join(run_dir, config.entry)):
# No need to consider other checkpoints if we are filtering for a specific checkpoint
checkpoint_to_wandb_path = {entry_path: _get_wandb_path(run_dir)}
checkpoint_to_wandb_path = {entry_path: _get_wandb_path(entry_path, run_dir)}
else:
checkpoint_dirs = _get_checkpoint_dirs(run_dir, run_dir_storage)
checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(run_dir) for checkpoint_dir in checkpoint_dirs}
checkpoint_to_wandb_path = {
checkpoint_dir: _get_wandb_path(checkpoint_dir, run_dir) for checkpoint_dir in checkpoint_dirs
}

src_dest_pairs: List[Tuple[str, str]] = []
# Mappings of source checkpoint directories to destination checkpoint directories
Expand Down

0 comments on commit 62c7954

Please sign in to comment.