diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 51c3dc5f7a5..65a3fb96c2f 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -30,35 +30,6 @@ def _remove_unused_links(repo): return ret -def get_all_files_numbers(pairs): - return sum(stage.get_all_files_number(filter_info) for stage, filter_info in pairs) - - -def _collect_pairs( - self: "Repo", targets, with_deps: bool, recursive: bool -) -> Set["StageInfo"]: - from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError - - pairs: Set["StageInfo"] = set() - for target in targets: - try: - pairs.update( - self.stage.collect_granular( - target, with_deps=with_deps, recursive=recursive - ) - ) - except ( - StageFileDoesNotExistError, - StageFileBadNameError, - NoOutputOrStageError, - ) as exc: - if not target: - raise - raise CheckoutErrorSuggestGit(target) from exc - - return pairs - - @locked def checkout( self, @@ -70,13 +41,18 @@ def checkout( allow_missing=False, **kwargs, ): + from dvc import prompt from dvc.fs.callbacks import Callback + from dvc.repo.index import build_data_index + from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError + from dvc_data.index.checkout import ADD, DELETE, MODIFY + from dvc_data.index.checkout import checkout as icheckout + from dvc_data.index.diff import diff as idiff stats: Dict[str, List[str]] = { "added": [], "deleted": [], "modified": [], - "failed": [], } if not targets: targets = [None] @@ -85,28 +61,76 @@ def checkout( if isinstance(targets, str): targets = [targets] - pairs = _collect_pairs(self, targets, with_deps, recursive) - total = get_all_files_numbers(pairs) + def onerror(target, exc): + if target and isinstance( + exc, + ( + StageFileDoesNotExistError, + StageFileBadNameError, + NoOutputOrStageError, + ), + ): + raise CheckoutErrorSuggestGit(target) from exc + raise + + view = self.index.targets_view( + targets, + recursive=recursive, + with_deps=with_deps, + onerror=onerror, + ) + + old = build_data_index(view, self.root_dir, self.fs, compute_hash=True) + new = view.data["repo"] + + total = len(new) with Callback.as_tqdm_callback( unit="file", desc="Checkout", disable=total == 0, ) as cb: - cb.set_size(total) - for stage, filter_info in pairs: - result = stage.checkout( - force=force, - progress_callback=cb, - relink=relink, - filter_info=filter_info, - allow_missing=allow_missing, - **kwargs, - ) - for key, items in result.items(): - stats[key].extend(_fspath_dir(path) for path in items) - - if stats.get("failed"): - raise CheckoutError(stats["failed"], stats) - - del stats["failed"] + changes = icheckout( + new, + self.root_dir, + self.fs, + old=old, + callback=cb, + hash_only=True, + delete=True, + prompt=prompt.confirm, + update_meta=False, + relink=relink, + force=force, + **kwargs, + ) + + def _adapt_path(entry): + parts = list(entry.key) + if entry.meta and entry.meta.isdir: + parts.append("") + return self.fs.path.join(*parts) + + stats["added"].extend(_adapt_path(change.new) for change in changes[ADD]) + stats["deleted"].extend(_adapt_path(change.old) for change in changes[DELETE]) + stats["modified"].extend(_adapt_path(change.new) for change in changes[MODIFY]) + + from itertools import chain + + top_keys = {key for key, _ in new.iteritems(shallow=True)} + changed_keys = { + change.key for change in chain(changes[ADD], changes[DELETE], changes[MODIFY]) + } + for key in top_keys: + for changed_key in changed_keys: + if len(changed_key) >= len(key) and changed_key[: len(key)] == key: + self.state.save_link(self.fs.path.join(self.root_dir, *key), self.fs) + + failed = [ + entry + for _, entry in new.iteritems() + if not entry.hash_info and not (entry.meta and entry.meta.isdir) + ] + if failed: + raise CheckoutError([_adapt_path(entry) for entry in failed], stats) + return stats diff --git a/dvc/repo/index.py b/dvc/repo/index.py index 1a3227c991c..dbf89d0bd33 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -534,6 +534,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]: lambda: _DataPrefixes(set(), set()) ) for out, filter_info in self._filtered_outs: + if not out.use_cache: + continue workspace, key = out.index_key if filter_info and out.fs.path.isin(filter_info, out.fs_path): key = key + out.fs.path.relparts(filter_info, out.fs_path) @@ -550,6 +552,9 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]: ret: Dict[str, Set["DataIndexKey"]] = defaultdict(set) for out, filter_info in self._filtered_outs: + if not out.use_cache: + continue + workspace, key = out.index_key if filter_info and out.fs.path.isin(filter_info, out.fs_path): key = key + out.fs.path.relparts(filter_info, out.fs_path) diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 1f003678455..b7e8d6e1083 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -728,11 +728,6 @@ def outs_cached(self) -> bool: for out in self.outs ) - def get_all_files_number(self, filter_info=None) -> int: - return sum( - out.get_files_number(filter_info) for out in self.filter_outs(filter_info) - ) - def get_used_objs( self, *args, **kwargs ) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]: diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index d308b633fc3..c1957df2a2d 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -294,19 +294,6 @@ def test_checkout_directory(tmp_dir, dvc): assert os.path.exists("data") -def test_checkout_hook(mocker, tmp_dir, dvc): - """Test that dvc checkout handles EOFError gracefully, which is what - it will experience when running in a git hook. - """ - tmp_dir.dvc_gen({"data": {"foo": "foo"}}) - mocker.patch("sys.stdout.isatty", return_value=True) - mocker.patch("dvc.prompt.input", side_effect=EOFError) - - (tmp_dir / "data").gen("test", "test") - with pytest.raises(ConfirmRemoveError): - dvc.checkout() - - def test_checkout_suggest_git(tmp_dir, dvc, scm): with pytest.raises(CheckoutErrorSuggestGit) as e: dvc.checkout(targets="gitbranch") @@ -432,7 +419,7 @@ def test_partial_checkout(tmp_dir, dvc, target): tmp_dir.dvc_gen({"dir": {"subdir": {"file": "file"}, "other": "other"}}) shutil.rmtree("dir") stats = dvc.checkout([target]) - assert stats["added"] == ["dir" + os.sep] + assert set(stats["added"]) == {"dir" + os.sep, os.path.join("dir", "subdir", "file")} assert list(walk_files("dir")) == [os.path.join("dir", "subdir", "file")] @@ -463,7 +450,13 @@ def test_stats_on_checkout(tmp_dir, dvc, scm): scm.checkout("-") stats = dvc.checkout() - assert set(stats["added"]) == {"bar", "dir" + os.sep, "foo"} + assert set(stats["added"]) == { + "bar", + "dir" + os.sep, + os.path.join("dir", "subdir", "file"), + os.path.join("dir", "other"), + "foo", + } tmp_dir.gen({"lorem": "lorem", "bar": "new bar", "dir2": {"file": "file"}}) (tmp_dir / "foo").unlink() @@ -479,7 +472,11 @@ def test_stats_on_checkout(tmp_dir, dvc, scm): scm.checkout("-") stats = dvc.checkout() assert set(stats["modified"]) == {"bar"} - assert set(stats["added"]) == {"dir2" + os.sep, "lorem"} + assert set(stats["added"]) == { + "dir2" + os.sep, + os.path.join("dir2", "file"), + "lorem", + } assert set(stats["deleted"]) == {"foo"} @@ -518,11 +515,11 @@ def test_stats_on_added_file_from_tracked_dir(tmp_dir, dvc, scm): tmp_dir.gen("dir/subdir/newfile", "newfile") tmp_dir.dvc_add("dir", commit="add newfile") scm.checkout("HEAD~") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "deleted": [os.path.join("dir", "subdir", "newfile")]} assert dvc.checkout() == empty_checkout scm.checkout("-") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "added": [os.path.join("dir", "subdir", "newfile")]} assert dvc.checkout() == empty_checkout @@ -535,11 +532,11 @@ def test_stats_on_updated_file_from_tracked_dir(tmp_dir, dvc, scm): tmp_dir.gen("dir/subdir/file", "what file?") tmp_dir.dvc_add("dir", commit="update file") scm.checkout("HEAD~") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep, os.path.join("dir", "subdir", "file")]} assert dvc.checkout() == empty_checkout scm.checkout("-") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep, os.path.join("dir", "subdir", "file")]} assert dvc.checkout() == empty_checkout @@ -552,11 +549,11 @@ def test_stats_on_removed_file_from_tracked_dir(tmp_dir, dvc, scm): (tmp_dir / "dir" / "subdir" / "file").unlink() tmp_dir.dvc_add("dir", commit="removed file from subdir") scm.checkout("HEAD~") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "added": [os.path.join("dir", "subdir", "file")]} assert dvc.checkout() == empty_checkout scm.checkout("-") - assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]} + assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "deleted": [os.path.join("dir", "subdir", "file")]} assert dvc.checkout() == empty_checkout