Skip to content

Commit

Permalink
checkout: use index checkout
Browse files Browse the repository at this point in the history
Needed for iterative#9424
  • Loading branch information
efiop committed May 12, 2023
1 parent 2342099 commit d484a14
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 76 deletions.
122 changes: 73 additions & 49 deletions dvc/repo/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,6 @@ def _remove_unused_links(repo):
return ret


def get_all_files_numbers(pairs):
return sum(stage.get_all_files_number(filter_info) for stage, filter_info in pairs)


def _collect_pairs(
self: "Repo", targets, with_deps: bool, recursive: bool
) -> Set["StageInfo"]:
from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError

pairs: Set["StageInfo"] = set()
for target in targets:
try:
pairs.update(
self.stage.collect_granular(
target, with_deps=with_deps, recursive=recursive
)
)
except (
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
) as exc:
if not target:
raise
raise CheckoutErrorSuggestGit(target) from exc

return pairs


@locked
def checkout(
self,
Expand All @@ -70,13 +41,18 @@ def checkout(
allow_missing=False,
**kwargs,
):
from dvc import prompt
from dvc.fs.callbacks import Callback
from dvc.repo.index import build_data_index
from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError
from dvc_data.index.checkout import ADD, DELETE, MODIFY
from dvc_data.index.checkout import checkout as icheckout
from dvc_data.index.diff import diff as idiff

stats: Dict[str, List[str]] = {
"added": [],
"deleted": [],
"modified": [],
"failed": [],
}
if not targets:
targets = [None]
Expand All @@ -85,28 +61,76 @@ def checkout(
if isinstance(targets, str):
targets = [targets]

pairs = _collect_pairs(self, targets, with_deps, recursive)
total = get_all_files_numbers(pairs)
def onerror(target, exc):
if target and isinstance(
exc,
(
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
),
):
raise CheckoutErrorSuggestGit(target) from exc
raise

view = self.index.targets_view(
targets,
recursive=recursive,
with_deps=with_deps,
onerror=onerror,
)

old = build_data_index(view, self.root_dir, self.fs, compute_hash=True)
new = view.data["repo"]

total = len(new)
with Callback.as_tqdm_callback(
unit="file",
desc="Checkout",
disable=total == 0,
) as cb:
cb.set_size(total)
for stage, filter_info in pairs:
result = stage.checkout(
force=force,
progress_callback=cb,
relink=relink,
filter_info=filter_info,
allow_missing=allow_missing,
**kwargs,
)
for key, items in result.items():
stats[key].extend(_fspath_dir(path) for path in items)

if stats.get("failed"):
raise CheckoutError(stats["failed"], stats)

del stats["failed"]
changes = icheckout(
new,
self.root_dir,
self.fs,
old=old,
callback=cb,
hash_only=True,
delete=True,
prompt=prompt.confirm,
update_meta=False,
relink=relink,
force=force,
**kwargs,
)

def _adapt_path(entry):
parts = list(entry.key)
if entry.meta and entry.meta.isdir:
parts.append("")
return self.fs.path.join(*parts)

stats["added"].extend(_adapt_path(change.new) for change in changes[ADD])
stats["deleted"].extend(_adapt_path(change.old) for change in changes[DELETE])
stats["modified"].extend(_adapt_path(change.new) for change in changes[MODIFY])

from itertools import chain

top_keys = {key for key, _ in new.iteritems(shallow=True)}
changed_keys = {
change.key for change in chain(changes[ADD], changes[DELETE], changes[MODIFY])
}
for key in top_keys:
for changed_key in changed_keys:
if len(changed_key) >= len(key) and changed_key[: len(key)] == key:
self.state.save_link(self.fs.path.join(self.root_dir, *key), self.fs)

failed = [
entry
for _, entry in new.iteritems()
if not entry.hash_info and not (entry.meta and entry.meta.isdir)
]
if failed:
raise CheckoutError([_adapt_path(entry) for entry in failed], stats)

return stats
5 changes: 5 additions & 0 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]:
lambda: _DataPrefixes(set(), set())
)
for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue
workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand All @@ -550,6 +552,9 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]:
ret: Dict[str, Set["DataIndexKey"]] = defaultdict(set)

for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue

workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand Down
5 changes: 0 additions & 5 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,11 +728,6 @@ def outs_cached(self) -> bool:
for out in self.outs
)

def get_all_files_number(self, filter_info=None) -> int:
return sum(
out.get_files_number(filter_info) for out in self.filter_outs(filter_info)
)

def get_used_objs(
self, *args, **kwargs
) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
Expand Down
41 changes: 19 additions & 22 deletions tests/func/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,19 +294,6 @@ def test_checkout_directory(tmp_dir, dvc):
assert os.path.exists("data")


def test_checkout_hook(mocker, tmp_dir, dvc):
"""Test that dvc checkout handles EOFError gracefully, which is what
it will experience when running in a git hook.
"""
tmp_dir.dvc_gen({"data": {"foo": "foo"}})
mocker.patch("sys.stdout.isatty", return_value=True)
mocker.patch("dvc.prompt.input", side_effect=EOFError)

(tmp_dir / "data").gen("test", "test")
with pytest.raises(ConfirmRemoveError):
dvc.checkout()


def test_checkout_suggest_git(tmp_dir, dvc, scm):
with pytest.raises(CheckoutErrorSuggestGit) as e:
dvc.checkout(targets="gitbranch")
Expand Down Expand Up @@ -432,7 +419,7 @@ def test_partial_checkout(tmp_dir, dvc, target):
tmp_dir.dvc_gen({"dir": {"subdir": {"file": "file"}, "other": "other"}})
shutil.rmtree("dir")
stats = dvc.checkout([target])
assert stats["added"] == ["dir" + os.sep]
assert set(stats["added"]) == {"dir" + os.sep, os.path.join("dir", "subdir", "file")}
assert list(walk_files("dir")) == [os.path.join("dir", "subdir", "file")]


Expand Down Expand Up @@ -463,7 +450,13 @@ def test_stats_on_checkout(tmp_dir, dvc, scm):

scm.checkout("-")
stats = dvc.checkout()
assert set(stats["added"]) == {"bar", "dir" + os.sep, "foo"}
assert set(stats["added"]) == {
"bar",
"dir" + os.sep,
os.path.join("dir", "subdir", "file"),
os.path.join("dir", "other"),
"foo",
}

tmp_dir.gen({"lorem": "lorem", "bar": "new bar", "dir2": {"file": "file"}})
(tmp_dir / "foo").unlink()
Expand All @@ -479,7 +472,11 @@ def test_stats_on_checkout(tmp_dir, dvc, scm):
scm.checkout("-")
stats = dvc.checkout()
assert set(stats["modified"]) == {"bar"}
assert set(stats["added"]) == {"dir2" + os.sep, "lorem"}
assert set(stats["added"]) == {
"dir2" + os.sep,
os.path.join("dir2", "file"),
"lorem",
}
assert set(stats["deleted"]) == {"foo"}


Expand Down Expand Up @@ -518,11 +515,11 @@ def test_stats_on_added_file_from_tracked_dir(tmp_dir, dvc, scm):
tmp_dir.gen("dir/subdir/newfile", "newfile")
tmp_dir.dvc_add("dir", commit="add newfile")
scm.checkout("HEAD~")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "deleted": [os.path.join("dir", "subdir", "newfile")]}
assert dvc.checkout() == empty_checkout

scm.checkout("-")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "added": [os.path.join("dir", "subdir", "newfile")]}
assert dvc.checkout() == empty_checkout


Expand All @@ -535,11 +532,11 @@ def test_stats_on_updated_file_from_tracked_dir(tmp_dir, dvc, scm):
tmp_dir.gen("dir/subdir/file", "what file?")
tmp_dir.dvc_add("dir", commit="update file")
scm.checkout("HEAD~")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep, os.path.join("dir", "subdir", "file")]}
assert dvc.checkout() == empty_checkout

scm.checkout("-")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep, os.path.join("dir", "subdir", "file")]}
assert dvc.checkout() == empty_checkout


Expand All @@ -552,11 +549,11 @@ def test_stats_on_removed_file_from_tracked_dir(tmp_dir, dvc, scm):
(tmp_dir / "dir" / "subdir" / "file").unlink()
tmp_dir.dvc_add("dir", commit="removed file from subdir")
scm.checkout("HEAD~")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "added": [os.path.join("dir", "subdir", "file")]}
assert dvc.checkout() == empty_checkout

scm.checkout("-")
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep]}
assert dvc.checkout() == {**empty_checkout, "modified": ["dir" + os.sep], "deleted": [os.path.join("dir", "subdir", "file")]}
assert dvc.checkout() == empty_checkout


Expand Down

0 comments on commit d484a14

Please sign in to comment.