From 39559cc7bee16e1d700374167ad67159450d0b8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Nov 2024 22:27:32 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dvc/commands/experiments/__init__.py | 5 ++++- dvc/commands/experiments/remove.py | 4 ++-- dvc/repo/experiments/remove.py | 18 +++++++++++++----- tests/func/experiments/test_remove.py | 20 ++++++++++++++------ 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/dvc/commands/experiments/__init__.py b/dvc/commands/experiments/__init__.py index 656546c513..780e6e18a8 100644 --- a/dvc/commands/experiments/__init__.py +++ b/dvc/commands/experiments/__init__.py @@ -57,14 +57,17 @@ def add_parser(subparsers, parent_parser): cmd.add_parser(experiments_subparsers, parent_parser) hide_subparsers_from_help(experiments_subparsers) + def add_keep_selection_flag(experiments_subcmd_parser): experiments_subcmd_parser.add_argument( "--keep", action="store_true", default=False, help="Keep the selected experiments instead of removing them " - "(use it with `--rev`' and `--num` or with experiment names)." + "(use it with `--rev`' and `--num` or with experiment names).", ) + + def add_rev_selection_flags( experiments_subcmd_parser, command: str, default: bool = True ): diff --git a/dvc/commands/experiments/remove.py b/dvc/commands/experiments/remove.py index c5525388ba..7f63ece2b0 100644 --- a/dvc/commands/experiments/remove.py +++ b/dvc/commands/experiments/remove.py @@ -34,7 +34,7 @@ def run(self): num=self.args.num, queue=self.args.queue, git_remote=self.args.git_remote, - keep_selected=self.args.keep + keep_selected=self.args.keep, ) if removed: ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}") @@ -45,7 +45,7 @@ def run(self): def add_parser(experiments_subparsers, parent_parser): - from . import add_rev_selection_flags, add_keep_selection_flag + from . import add_keep_selection_flag, add_rev_selection_flags EXPERIMENTS_REMOVE_HELP = "Remove experiments." experiments_remove_parser = experiments_subparsers.add_parser( diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 8ae86c12cc..1042b37dd8 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -30,7 +30,7 @@ def remove( # noqa: C901, PLR0912 num: int = 1, queue: bool = False, git_remote: Optional[str] = None, - keep_selected: bool = False, # keep the experiments instead of removing them + keep_selected: bool = False, # keep the experiments instead of removing them ) -> list[str]: removed: list[str] = [] if not any([exp_names, queue, all_commits, rev]): @@ -48,10 +48,14 @@ def remove( # noqa: C901, PLR0912 if keep_selected: # In keep_selected mode, identify all experiments and remove the unselected ones all_exp_refs = exp_refs(repo.scm, git_remote) - selected_exp_names = resolve_selected_exp_names(exp_names, git_remote, num, repo, rev) + selected_exp_names = resolve_selected_exp_names( + exp_names, git_remote, num, repo, rev + ) # Identify experiments to remove: all experiments - selected experiments - unselected_exp_refs = [ref for ref in all_exp_refs if ref.name not in selected_exp_names] + unselected_exp_refs = [ + ref for ref in all_exp_refs if ref.name not in selected_exp_names + ] removed = [ref.name for ref in unselected_exp_refs] # Remove the unselected experiments @@ -107,10 +111,14 @@ def remove( # noqa: C901, PLR0912 def resolve_selected_exp_names(exp_names, git_remote, num, repo, rev): selected_exp_names = set() if exp_names: - selected_exp_names = set(exp_names) if isinstance(exp_names, list) else {exp_names} + selected_exp_names = ( + set(exp_names) if isinstance(exp_names, list) else {exp_names} + ) elif rev: selected_exp_names = set( - _resolve_exp_by_baseline(repo, [rev] if isinstance(rev, str) else rev, num, git_remote).keys() + _resolve_exp_by_baseline( + repo, [rev] if isinstance(rev, str) else rev, num, git_remote + ).keys() ) return selected_exp_names diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 109ccf5830..06f0aa84f4 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -180,8 +180,8 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(baseline_exp_ref)) is None assert scm.get_ref(str(new_exp_ref)) is None -def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage): +def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") exp1_ref = first(exp_refs_by_rev(scm, first(results))) @@ -206,8 +206,8 @@ def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp2_ref)) is not None assert scm.get_ref(str(exp3_ref)) is None -def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage): +def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") exp1_ref = first(exp_refs_by_rev(scm, first(results))) @@ -224,7 +224,7 @@ def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp3_ref)) is not None # Keep "exp1" and "exp2" and remove "exp3" - removed = dvc.experiments.remove(exp_names=["exp1","exp2"], keep_selected=True) + removed = dvc.experiments.remove(exp_names=["exp1", "exp2"], keep_selected=True) assert removed == ["exp3"] # Check remaining experiments @@ -232,6 +232,7 @@ def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp2_ref)) is not None assert scm.get_ref(str(exp3_ref)) is None + def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") @@ -249,7 +250,9 @@ def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp3_ref)) is not None # Keep "exp1" and "exp2" and remove "exp3" - removed = dvc.experiments.remove(exp_names=["exp1","exp2", "exp3"], keep_selected=True) + removed = dvc.experiments.remove( + exp_names=["exp1", "exp2", "exp3"], keep_selected=True + ) assert removed == [] # Check remaining experiments @@ -257,6 +260,7 @@ def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp2_ref)) is not None assert scm.get_ref(str(exp3_ref)) is not None + def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") @@ -282,13 +286,16 @@ def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp2_ref)) is None assert scm.get_ref(str(exp3_ref)) is None + def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments and commit results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") exp1_ref = first(exp_refs_by_rev(scm, first(results))) scm.commit("commit1") - new_results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name="exp2") + new_results = dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], name="exp2" + ) exp2_ref = first(exp_refs_by_rev(scm, first(new_results))) new_rev = scm.get_rev() @@ -304,6 +311,7 @@ def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(exp2_ref)) is not None assert scm.get_ref(str(exp1_ref)) is None + def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage): # Setup: Run experiments and commit results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1") @@ -331,4 +339,4 @@ def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage): # Check remaining experiments assert scm.get_ref(str(exp3_ref)) is not None assert scm.get_ref(str(exp2_ref)) is not None - assert scm.get_ref(str(exp1_ref)) is None \ No newline at end of file + assert scm.get_ref(str(exp1_ref)) is None