Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 23, 2024
1 parent 3aaebbe commit 8155b3c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
3 changes: 3 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ def add_parser(subparsers, parent_parser):
cmd.add_parser(experiments_subparsers, parent_parser)
hide_subparsers_from_help(experiments_subparsers)


def add_keep_selection_flag(experiments_subcmd_parser):
experiments_subcmd_parser.add_argument(
"--keep",
action="store_true",
default=False,
help="Keep the selected experiments instead of removing them (use it with `--rev` and `--num` or with experiment names).",
)


def add_rev_selection_flags(
experiments_subcmd_parser, command: str, default: bool = True
):
Expand Down
4 changes: 2 additions & 2 deletions dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run(self):
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
keep_selected=self.args.keep
keep_selected=self.args.keep,
)
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
Expand All @@ -45,7 +45,7 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags, add_keep_selection_flag
from . import add_keep_selection_flag, add_rev_selection_flags

EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
Expand Down
14 changes: 10 additions & 4 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def remove( # noqa: C901, PLR0912
num: int = 1,
queue: bool = False,
git_remote: Optional[str] = None,
keep_selected: bool = False, # keep the experiments instead of removing them
keep_selected: bool = False, # keep the experiments instead of removing them
) -> list[str]:
removed: list[str] = []
if not any([exp_names, queue, all_commits, rev]):
Expand All @@ -50,16 +50,22 @@ def remove( # noqa: C901, PLR0912
all_exp_refs = exp_refs(repo.scm, git_remote)

if exp_names:
selected_exp_names = set(exp_names) if isinstance(exp_names, list) else {exp_names}
selected_exp_names = (
set(exp_names) if isinstance(exp_names, list) else {exp_names}
)
elif rev:
selected_exp_names = set(
_resolve_exp_by_baseline(repo, [rev] if isinstance(rev, str) else rev, num, git_remote).keys()
_resolve_exp_by_baseline(
repo, [rev] if isinstance(rev, str) else rev, num, git_remote
).keys()
)
else:
selected_exp_names = set()

# Identify experiments to remove: all experiments - selected experiments
unselected_exp_refs = [ref for ref in all_exp_refs if ref.name not in selected_exp_names]
unselected_exp_refs = [
ref for ref in all_exp_refs if ref.name not in selected_exp_names
]
removed = [ref.name for ref in unselected_exp_refs]

# Remove the unselected experiments
Expand Down
20 changes: 14 additions & 6 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None

def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage):

def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
exp1_ref = first(exp_refs_by_rev(scm, first(results)))
Expand All @@ -206,8 +206,8 @@ def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(exp2_ref)) is not None
assert scm.get_ref(str(exp3_ref)) is None

def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage):

def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
exp1_ref = first(exp_refs_by_rev(scm, first(results)))
Expand All @@ -224,14 +224,15 @@ def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(exp3_ref)) is not None

# Keep "exp1" and "exp2" and remove "exp3"
removed = dvc.experiments.remove(exp_names=["exp1","exp2"], keep_selected=True)
removed = dvc.experiments.remove(exp_names=["exp1", "exp2"], keep_selected=True)
assert removed == ["exp3"]

# Check remaining experiments
assert scm.get_ref(str(exp1_ref)) is not None
assert scm.get_ref(str(exp2_ref)) is not None
assert scm.get_ref(str(exp3_ref)) is None


def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
Expand All @@ -249,14 +250,17 @@ def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(exp3_ref)) is not None

# Keep "exp1" and "exp2" and remove "exp3"
removed = dvc.experiments.remove(exp_names=["exp1","exp2", "exp3"], keep_selected=True)
removed = dvc.experiments.remove(
exp_names=["exp1", "exp2", "exp3"], keep_selected=True
)
assert removed == []

# Check remaining experiments
assert scm.get_ref(str(exp1_ref)) is not None
assert scm.get_ref(str(exp2_ref)) is not None
assert scm.get_ref(str(exp3_ref)) is not None


def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
Expand All @@ -282,14 +286,17 @@ def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(exp2_ref)) is None
assert scm.get_ref(str(exp3_ref)) is None


def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments and commit
baseline = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
exp1_ref = first(exp_refs_by_rev(scm, first(results)))
scm.commit("commit1")

new_results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name="exp2")
new_results = dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], name="exp2"
)
exp2_ref = first(exp_refs_by_rev(scm, first(new_results)))
new_rev = scm.get_rev()

Expand All @@ -305,6 +312,7 @@ def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage):
assert scm.get_ref(str(exp2_ref)) is not None
assert scm.get_ref(str(exp1_ref)) is None


def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments and commit
baseline = scm.get_rev()
Expand Down Expand Up @@ -335,4 +343,4 @@ def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage):
# Check remaining experiments
assert scm.get_ref(str(exp3_ref)) is not None
assert scm.get_ref(str(exp2_ref)) is not None
assert scm.get_ref(str(exp1_ref)) is None
assert scm.get_ref(str(exp1_ref)) is None

0 comments on commit 8155b3c

Please sign in to comment.