diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 51da4fdb92..0bf02c076a 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -254,15 +254,6 @@ def _from_stash_entry( **kwargs, ) - @classmethod - def _get_stage_files(cls, stages: List["Stage"]) -> List[str]: - from dvc.stage.utils import _get_stage_files - - ret: List[str] = [] - for stage in stages: - ret.extend(_get_stage_files(stage)) - return ret - @classmethod def _get_top_level_paths(cls, repo: "Repo") -> List["str"]: return list( @@ -518,10 +509,8 @@ def reproduce( recursive=kwargs.get("recursive", False), ) + kwargs["repro_fn"] = cls._repro_and_track stages = dvc.reproduce(*args, **kwargs) - if paths := cls._get_stage_files(stages): - logger.debug("Staging stage-related files: %s", paths) - dvc.scm_context.add(paths) if paths := cls._get_top_level_paths(dvc): logger.debug("Staging top-level files: %s", paths) dvc.scm_context.add(paths) @@ -546,6 +535,17 @@ def reproduce( # multiprocessing calls return ExecutorResult(exp_hash, exp_ref, repro_force) + @staticmethod + def _repro_and_track(stage: "Stage", **kwargs) -> Optional["Stage"]: + from dvc.repo.reproduce import _reproduce_stage + from dvc.stage.utils import _get_stage_files + + ret = _reproduce_stage(stage, **kwargs) + if not kwargs.get("dry") and (paths := _get_stage_files(stage)): + logger.debug("Staging stage-related files: %s", paths) + stage.repo.scm_context.add(paths) + return ret + @classmethod def _repro_commit( cls,