Skip to content

Commit

Permalink
checkout: use index checkout
Browse files Browse the repository at this point in the history
Needed for iterative#9424
  • Loading branch information
efiop committed May 11, 2023
1 parent 06d8e3c commit 911438d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 51 deletions.
109 changes: 60 additions & 49 deletions dvc/repo/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,6 @@ def _remove_unused_links(repo):
return ret


def get_all_files_numbers(pairs):
return sum(stage.get_all_files_number(filter_info) for stage, filter_info in pairs)


def _collect_pairs(
self: "Repo", targets, with_deps: bool, recursive: bool
) -> Set["StageInfo"]:
from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError

pairs: Set["StageInfo"] = set()
for target in targets:
try:
pairs.update(
self.stage.collect_granular(
target, with_deps=with_deps, recursive=recursive
)
)
except (
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
) as exc:
if not target:
raise
raise CheckoutErrorSuggestGit(target) from exc

return pairs


@locked
def checkout(
self,
Expand All @@ -70,13 +41,17 @@ def checkout(
allow_missing=False,
**kwargs,
):
from dvc import prompt
from dvc.fs.callbacks import Callback
from dvc.repo.index import build_data_index
from dvc_data.index.checkout import ADD, DELETE, MODIFY
from dvc_data.index.checkout import checkout as icheckout
from dvc_data.index.diff import diff as idiff

stats: Dict[str, List[str]] = {
"added": [],
"deleted": [],
"modified": [],
"failed": [],
}
if not targets:
targets = [None]
Expand All @@ -85,28 +60,64 @@ def checkout(
if isinstance(targets, str):
targets = [targets]

pairs = _collect_pairs(self, targets, with_deps, recursive)
total = get_all_files_numbers(pairs)
view = self.index.targets_view(
targets,
recursive=recursive,
with_deps=True,
)

old = build_data_index(view, self.root_dir, self.fs, compute_hash=True)

new = view.data["repo"]
# failed = [entry for _, entry in new.iteritems() if not entry.hash_info and not (entry.meta and entry.meta.isdir)]

total = len(new)
with Callback.as_tqdm_callback(
unit="file",
desc="Checkout",
disable=total == 0,
) as cb:
cb.set_size(total)
for stage, filter_info in pairs:
result = stage.checkout(
force=force,
progress_callback=cb,
relink=relink,
filter_info=filter_info,
allow_missing=allow_missing,
**kwargs,
)
for key, items in result.items():
stats[key].extend(_fspath_dir(path) for path in items)

if stats.get("failed"):
raise CheckoutError(stats["failed"], stats)

del stats["failed"]
changes = icheckout(
new,
self.root_dir,
self.fs,
old=old,
callback=cb,
hash_only=True,
delete=True,
prompt=prompt.confirm,
update_meta=False, # FIXME should be off by defeault
relink=relink,
**kwargs,
)

def _adapt_path(entry):
parts = list(entry.key)
if entry.meta and entry.meta.isdir:
parts.append("")
return self.fs.path.join(*parts)

stats["added"].extend(_adapt_path(change.new) for change in changes[ADD])
stats["deleted"].extend(_adapt_path(change.old) for change in changes[DELETE])
stats["modified"].extend(_adapt_path(change.new) for change in changes[MODIFY])

from itertools import chain

top_keys = {key for key, _ in new.iteritems(shallow=True)}
changed_keys = {
change.key for change in chain(changes[ADD], changes[DELETE], changes[MODIFY])
}
for key in top_keys:
for changed_key in changed_keys:
if len(changed_key) >= len(key) and changed_key[: len(key)] == key:
self.state.save_link(self.fs.path.join(self.root_dir, *key), self.fs)

failed = [
entry
for _, entry in new.iteritems()
if not entry.hash_info and not (entry.meta and entry.meta.isdir)
]
if failed:
raise CheckoutError([_adapt_path(entry) for entry in failed], stats)

return stats
5 changes: 5 additions & 0 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]:
lambda: _DataPrefixes(set(), set())
)
for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue
workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand All @@ -550,6 +552,9 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]:
ret: Dict[str, Set["DataIndexKey"]] = defaultdict(set)

for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue

workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand Down
14 changes: 12 additions & 2 deletions tests/func/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,13 @@ def test_stats_on_checkout(tmp_dir, dvc, scm):

scm.checkout("-")
stats = dvc.checkout()
assert set(stats["added"]) == {"bar", "dir" + os.sep, "foo"}
assert set(stats["added"]) == {
"bar",
"dir" + os.sep,
os.path.join("dir", "subdir", "file"),
os.path.join("dir", "other"),
"foo",
}

tmp_dir.gen({"lorem": "lorem", "bar": "new bar", "dir2": {"file": "file"}})
(tmp_dir / "foo").unlink()
Expand All @@ -479,7 +485,11 @@ def test_stats_on_checkout(tmp_dir, dvc, scm):
scm.checkout("-")
stats = dvc.checkout()
assert set(stats["modified"]) == {"bar"}
assert set(stats["added"]) == {"dir2" + os.sep, "lorem"}
assert set(stats["added"]) == {
"dir2" + os.sep,
os.path.join("dir2", "file"),
"lorem",
}
assert set(stats["deleted"]) == {"foo"}


Expand Down

0 comments on commit 911438d

Please sign in to comment.