Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
rmic committed Nov 23, 2024
1 parent 5960b62 commit 7427d92
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
3 changes: 2 additions & 1 deletion dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def add_keep_selection_flag(experiments_subcmd_parser):
"--keep",
action="store_true",
default=False,
help="Keep the selected experiments instead of removing them (use it with `--rev` and `--num` or with experiment names).",
help="Keep the selected experiments instead of removing them (use it with `--rev`'"
" and `--num` or with experiment names)."
)
def add_rev_selection_flags(
experiments_subcmd_parser, command: str, default: bool = True
Expand Down
21 changes: 12 additions & 9 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,7 @@ def remove( # noqa: C901, PLR0912
if keep_selected:
# In keep_selected mode, identify all experiments and remove the unselected ones
all_exp_refs = exp_refs(repo.scm, git_remote)

if exp_names:
selected_exp_names = set(exp_names) if isinstance(exp_names, list) else {exp_names}
elif rev:
selected_exp_names = set(
_resolve_exp_by_baseline(repo, [rev] if isinstance(rev, str) else rev, num, git_remote).keys()
)
else:
selected_exp_names = set()
selected_exp_names = resolve_selected_exp_names(exp_names, git_remote, num, repo, rev)

# Identify experiments to remove: all experiments - selected experiments
unselected_exp_refs = [ref for ref in all_exp_refs if ref.name not in selected_exp_names]
Expand Down Expand Up @@ -112,6 +104,17 @@ def remove( # noqa: C901, PLR0912
return removed


def resolve_selected_exp_names(exp_names, git_remote, num, repo, rev):
selected_exp_names = set()
if exp_names:
selected_exp_names = set(exp_names) if isinstance(exp_names, list) else {exp_names}
elif rev:
selected_exp_names = set(
_resolve_exp_by_baseline(repo, [rev] if isinstance(rev, str) else rev, num, git_remote).keys()
)
return selected_exp_names


def _resolve_exp_by_baseline(
repo: "Repo",
rev: list[str],
Expand Down
4 changes: 0 additions & 4 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):

def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments and commit
baseline = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
exp1_ref = first(exp_refs_by_rev(scm, first(results)))
scm.commit("commit1")
Expand All @@ -307,13 +306,10 @@ def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage):

def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage):
# Setup: Run experiments and commit
baseline = scm.get_rev()
exp1_rev = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
exp1_ref = first(exp_refs_by_rev(scm, first(results)))
scm.commit("commit1")

exp2_rev = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name="exp2")
exp2_ref = first(exp_refs_by_rev(scm, first(results)))
scm.commit("commit2")
Expand Down

0 comments on commit 7427d92

Please sign in to comment.