diff --git a/ml_experiment/Scheduler.py b/ml_experiment/Scheduler.py index 96afcc8..5015766 100644 --- a/ml_experiment/Scheduler.py +++ b/ml_experiment/Scheduler.py @@ -33,7 +33,7 @@ class LocalRunConfig(RunConfig): -Pred = Callable[[str, int, int, int], bool] +RunFilter = Callable[[RunSpec], bool] VersionSpec = int | dict[str, int | None] | None class Scheduler: @@ -71,11 +71,11 @@ def get_all_runs(self) -> Self: return self - def filter(self, already_exists: Pred) -> Scheduler: + def filter(self, already_exists: RunFilter) -> Scheduler: filtered = Scheduler(self.exp_name, self.seeds, self.entry, self.version, self.base_path) for r in self.all_runs: - if not already_exists(*r): + if not already_exists(r): filtered.all_runs.add(r) return filtered