diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 51c3dc5f7a5..36a89f945c1 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -30,35 +30,6 @@ def _remove_unused_links(repo): return ret -def get_all_files_numbers(pairs): - return sum(stage.get_all_files_number(filter_info) for stage, filter_info in pairs) - - -def _collect_pairs( - self: "Repo", targets, with_deps: bool, recursive: bool -) -> Set["StageInfo"]: - from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError - - pairs: Set["StageInfo"] = set() - for target in targets: - try: - pairs.update( - self.stage.collect_granular( - target, with_deps=with_deps, recursive=recursive - ) - ) - except ( - StageFileDoesNotExistError, - StageFileBadNameError, - NoOutputOrStageError, - ) as exc: - if not target: - raise - raise CheckoutErrorSuggestGit(target) from exc - - return pairs - - @locked def checkout( self, @@ -70,13 +41,17 @@ def checkout( allow_missing=False, **kwargs, ): + from dvc import prompt from dvc.fs.callbacks import Callback + from dvc.repo.index import build_data_index + from dvc_data.index.checkout import ADD, DELETE, MODIFY + from dvc_data.index.checkout import checkout as icheckout + from dvc_data.index.diff import diff as idiff stats: Dict[str, List[str]] = { "added": [], "deleted": [], "modified": [], - "failed": [], } if not targets: targets = [None] @@ -85,28 +60,64 @@ def checkout( if isinstance(targets, str): targets = [targets] - pairs = _collect_pairs(self, targets, with_deps, recursive) - total = get_all_files_numbers(pairs) + view = self.index.targets_view( + targets, + recursive=recursive, + with_deps=True, + ) + + old = build_data_index(view, self.root_dir, self.fs, compute_hash=True) + + new = view.data["repo"] + # failed = [entry for _, entry in new.iteritems() if not entry.hash_info and not (entry.meta and entry.meta.isdir)] + + total = len(new) with Callback.as_tqdm_callback( unit="file", desc="Checkout", disable=total == 0, ) as cb: - cb.set_size(total) - for stage, filter_info in pairs: - result = stage.checkout( - force=force, - progress_callback=cb, - relink=relink, - filter_info=filter_info, - allow_missing=allow_missing, - **kwargs, - ) - for key, items in result.items(): - stats[key].extend(_fspath_dir(path) for path in items) - - if stats.get("failed"): - raise CheckoutError(stats["failed"], stats) - - del stats["failed"] + changes = icheckout( + new, + self.root_dir, + self.fs, + old=old, + callback=cb, + hash_only=True, + delete=True, + prompt=prompt.confirm, + update_meta=False, # FIXME should be off by defeault + relink=relink, + **kwargs, + ) + + def _adapt_path(entry): + parts = list(entry.key) + if entry.meta and entry.meta.isdir: + parts.append("") + return self.fs.path.join(*parts) + + stats["added"].extend(_adapt_path(change.new) for change in changes[ADD]) + stats["deleted"].extend(_adapt_path(change.old) for change in changes[DELETE]) + stats["modified"].extend(_adapt_path(change.new) for change in changes[MODIFY]) + + from itertools import chain + + top_keys = {key for key, _ in new.iteritems(shallow=True)} + changed_keys = { + change.key for change in chain(changes[ADD], changes[DELETE], changes[MODIFY]) + } + for key in top_keys: + for changed_key in changed_keys: + if len(changed_key) >= len(key) and changed_key[: len(key)] == key: + self.state.save_link(self.fs.path.join(self.root_dir, *key), self.fs) + + failed = [ + entry + for _, entry in new.iteritems() + if not entry.hash_info and not (entry.meta and entry.meta.isdir) + ] + if failed: + raise CheckoutError([_adapt_path(entry) for entry in failed], stats) + return stats diff --git a/dvc/repo/index.py b/dvc/repo/index.py index 1a3227c991c..dbf89d0bd33 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -534,6 +534,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]: lambda: _DataPrefixes(set(), set()) ) for out, filter_info in self._filtered_outs: + if not out.use_cache: + continue workspace, key = out.index_key if filter_info and out.fs.path.isin(filter_info, out.fs_path): key = key + out.fs.path.relparts(filter_info, out.fs_path) @@ -550,6 +552,9 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]: ret: Dict[str, Set["DataIndexKey"]] = defaultdict(set) for out, filter_info in self._filtered_outs: + if not out.use_cache: + continue + workspace, key = out.index_key if filter_info and out.fs.path.isin(filter_info, out.fs_path): key = key + out.fs.path.relparts(filter_info, out.fs_path) diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index d308b633fc3..012caa99577 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -463,7 +463,13 @@ def test_stats_on_checkout(tmp_dir, dvc, scm): scm.checkout("-") stats = dvc.checkout() - assert set(stats["added"]) == {"bar", "dir" + os.sep, "foo"} + assert set(stats["added"]) == { + "bar", + "dir" + os.sep, + os.path.join("dir", "subdir", "file"), + os.path.join("dir", "other"), + "foo", + } tmp_dir.gen({"lorem": "lorem", "bar": "new bar", "dir2": {"file": "file"}}) (tmp_dir / "foo").unlink() @@ -479,7 +485,11 @@ def test_stats_on_checkout(tmp_dir, dvc, scm): scm.checkout("-") stats = dvc.checkout() assert set(stats["modified"]) == {"bar"} - assert set(stats["added"]) == {"dir2" + os.sep, "lorem"} + assert set(stats["added"]) == { + "dir2" + os.sep, + os.path.join("dir2", "file"), + "lorem", + } assert set(stats["deleted"]) == {"foo"}