Skip to content

Commit

Permalink
experiments: auto push experiments (#10323)
Browse files Browse the repository at this point in the history
* add auto_push for experiment on run and on save.
add config.exp.auto_push and config.exp.git_remote

* add func tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* separate auto_push behaviour out of commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Saugat Pachhai (सौगात) <[email protected]>
  • Loading branch information
3 people authored Mar 4, 2024
1 parent 3b2e031 commit cd1e228
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 28 deletions.
2 changes: 2 additions & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def __call__(self, data):
"params": str,
"plots": str,
"live": str,
"auto_push": Bool,
"git_remote": str,
},
"parsing": {
"bool": All(Lower, Choices("store_true", "boolean_optional")),
Expand Down
65 changes: 37 additions & 28 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,16 @@ def save(
exp_hash = cls.hash_exp(stages)
if include_untracked:
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
cls.commit(
dvc.scm, # type: ignore[arg-type]
exp_hash,
exp_name=info.name,
force=force,
message=message,
)

with cls.auto_push(dvc):
cls.commit(
dvc.scm, # type: ignore[arg-type]
exp_hash,
exp_name=info.name,
force=force,
message=message,
)

ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
untracked = dvc.scm.untracked_files()
Expand Down Expand Up @@ -460,9 +463,6 @@ def reproduce(
from dvc.repo.checkout import checkout as dvc_checkout
from dvc.ui import ui

auto_push = env2bool(DVC_EXP_AUTO_PUSH)
git_remote = os.getenv(DVC_EXP_GIT_REMOTE, None)

if queue is not None:
queue.put((rev, os.getpid()))
if log_errors and log_level is not None:
Expand All @@ -483,9 +483,6 @@ def reproduce(
message=message,
**kwargs,
) as dvc:
if auto_push:
cls._validate_remotes(dvc, git_remote)

args, kwargs = cls._repro_args(dvc)
if args:
targets: Optional[Union[list, str]] = args[0]
Expand Down Expand Up @@ -519,8 +516,6 @@ def reproduce(
dvc,
info,
exp_hash,
auto_push,
git_remote,
repro_force,
message=message,
)
Expand Down Expand Up @@ -550,20 +545,18 @@ def _repro_commit(
dvc,
info,
exp_hash,
auto_push,
git_remote,
repro_force,
message: Optional[str] = None,
) -> tuple[Optional[str], Optional["ExpRefInfo"], bool]:
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=repro_force,
message=message,
)
if auto_push:
cls._auto_push(dvc, dvc.scm, git_remote)
with cls.auto_push(dvc):
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=repro_force,
message=message,
)

ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref: Optional["ExpRefInfo"] = ExpRefInfo.from_ref(ref) if ref else None
if cls.WARN_UNTRACKED:
Expand Down Expand Up @@ -672,15 +665,30 @@ def _repro_args(cls, dvc):
kwargs = {}
return args, kwargs

@classmethod
@contextmanager
def auto_push(cls, dvc: "Repo") -> Iterator[None]:
exp_config = dvc.config.get("exp", {})
auto_push = env2bool(DVC_EXP_AUTO_PUSH, exp_config.get("auto_push", False))
if not auto_push:
yield
return

git_remote = os.getenv(
DVC_EXP_GIT_REMOTE, exp_config.get("git_remote", "origin")
)
cls._validate_remotes(dvc, git_remote)
yield
cls._auto_push(dvc, git_remote)

@staticmethod
def _auto_push(
dvc: "Repo",
scm: "Git",
git_remote: Optional[str],
push_cache=True,
run_cache=True,
):
branch = scm.get_ref(EXEC_BRANCH, follow=False)
branch = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
try:
dvc.experiments.push(
git_remote,
Expand Down Expand Up @@ -708,6 +716,7 @@ def commit(
message: Optional[str] = None,
):
"""Commit stages as an experiment and return the commit SHA."""

rev = scm.get_rev()
if not scm.is_dirty(untracked_files=False):
logger.debug("No changes to commit")
Expand Down
37 changes: 37 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,40 @@ def test_push_pull_invalid_workspace(
dvc.experiments.push(git_upstream.remote, push_cache=True)
dvc.experiments.pull(git_upstream.remote, pull_cache=True)
assert "failed to collect" not in caplog.text


@pytest.mark.parametrize(
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
)
def test_auto_push_on_run(
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
):
remote = git_upstream.remote

with dvc.config.edit() as conf:
conf["exp"]["auto_push"] = auto_push
conf["exp"]["git_remote"] = remote

exp_name = "foo"
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)

assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key


@pytest.mark.parametrize(
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
)
def test_auto_push_on_save(
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
):
remote = git_upstream.remote
exp_name = "foo"
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)

with dvc.config.edit() as conf:
conf["exp"]["auto_push"] = auto_push
conf["exp"]["git_remote"] = remote

dvc.experiments.save(name=exp_name, force=True)

assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key

0 comments on commit cd1e228

Please sign in to comment.