diff --git a/dvc/repo/brancher.py b/dvc/repo/brancher.py index 3499cf0f08..38609677de 100644 --- a/dvc/repo/brancher.py +++ b/dvc/repo/brancher.py @@ -20,6 +20,7 @@ def brancher( all_tags=False, all_commits=False, all_experiments=False, + workspace=True, commit_date: Optional[str] = None, sha_only=False, num=1, @@ -31,6 +32,7 @@ def brancher( all_branches (bool): iterate over all available branches. all_commits (bool): iterate over all commits. all_tags (bool): iterate over all available tags. + workspace (bool): include workspace. commit_date (str): Keep experiments from the commits after(include) a certain date. Date must match the extended ISO 8601 format (YYYY-MM-DD). @@ -73,7 +75,8 @@ def brancher( logger.trace("switching fs to workspace") self.fs = LocalFileSystem(url=self.root_dir) - yield "workspace" + if workspace: + yield "workspace" revs = revs.copy() if revs else [] if "workspace" in revs: diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index f643e38cbf..699a54a422 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -112,4 +112,6 @@ def _pull_cache( refs = [refs] revs = list(exp_commits(repo.scm, refs)) logger.debug("dvc fetch experiment '%s'", refs) - repo.fetch(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) + repo.fetch( + jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False + ) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index f62b1b2634..184bf21406 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -188,4 +188,6 @@ def _push_cache( assert isinstance(repo.scm, Git) revs = list(exp_commits(repo.scm, refs)) logger.debug("dvc push experiment '%s'", refs) - return repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) + return repo.push( + jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False + ) diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 5bdfc9acdb..3a3c795091 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -32,6 +32,7 @@ def _collect_indexes( # noqa: PLR0913 recursive=False, all_commits=False, revs=None, + workspace=True, max_size=None, types=None, config=None, @@ -62,6 +63,7 @@ def outs_filter(out: "Output") -> bool: all_branches=all_branches, all_tags=all_tags, all_commits=all_commits, + workspace=workspace, ): try: repo.config.merge(config) @@ -79,11 +81,11 @@ def outs_filter(out: "Output") -> bool: idx.data["repo"].onerror = _make_index_onerror(onerror, rev) indexes[rev or "workspace"] = idx - except Exception as exc: + except Exception as exc: # noqa: BLE001 if onerror: onerror(rev, None, exc) collection_exc = exc - logger.exception("failed to collect '%s'", rev or "workspace") + logger.warning("failed to collect '%s', skipping", rev or "workspace") if not indexes and collection_exc: raise collection_exc @@ -104,6 +106,7 @@ def fetch( # noqa: PLR0913 all_commits=False, run_cache=False, revs=None, + workspace=True, max_size=None, types=None, config=None, @@ -148,6 +151,7 @@ def fetch( # noqa: PLR0913 recursive=recursive, all_commits=all_commits, revs=revs, + workspace=workspace, max_size=max_size, types=types, config=config, diff --git a/dvc/repo/push.py b/dvc/repo/push.py index 5a14181577..ae42610d46 100644 --- a/dvc/repo/push.py +++ b/dvc/repo/push.py @@ -46,6 +46,7 @@ def push( # noqa: PLR0913 all_commits=False, run_cache=False, revs=None, + workspace=True, glob=False, ): from fsspec.utils import tokenize @@ -87,6 +88,7 @@ def push( # noqa: PLR0913 recursive=recursive, all_commits=all_commits, revs=revs, + workspace=workspace, push=True, ) diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 1fe9d80933..150279797d 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -1,3 +1,5 @@ +import logging + import pytest from funcy import first @@ -360,3 +362,17 @@ def test_get(tmp_dir, scm, dvc, exp_stage, erepo_dir, use_ref): rev=exp_ref.name if use_ref else exp_rev, ) assert (erepo_dir / "params.yaml").read_text().strip() == "foo: 2" + + +def test_push_pull_invalid_workspace( + tmp_dir, scm, dvc, git_upstream, exp_stage, local_remote, caplog +): + dvc.experiments.run() + + with open("dvc.yaml", mode="a") as f: + f.write("\ninvalid") + + with caplog.at_level(logging.WARNING, logger="dvc"): + dvc.experiments.push(git_upstream.remote, push_cache=True) + dvc.experiments.pull(git_upstream.remote, pull_cache=True) + assert "failed to collect" not in caplog.text