Skip to content

Commit

Permalink
checkout: use index checkout
Browse files Browse the repository at this point in the history
Needed for iterative#9424
  • Loading branch information
efiop committed May 15, 2023
1 parent d08c85f commit 4ac8a5a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 93 deletions.
122 changes: 73 additions & 49 deletions dvc/repo/checkout.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import logging
import os
from typing import TYPE_CHECKING, Dict, List, Set
from itertools import chain
from typing import Dict, List

from dvc.exceptions import CheckoutError, CheckoutErrorSuggestGit, NoOutputOrStageError
from dvc.utils import relpath

from . import locked

if TYPE_CHECKING:
from . import Repo
from .stage import StageInfo

logger = logging.getLogger(__name__)


Expand All @@ -30,37 +27,8 @@ def _remove_unused_links(repo):
return ret


def get_all_files_numbers(pairs):
return sum(stage.get_all_files_number(filter_info) for stage, filter_info in pairs)


def _collect_pairs(
self: "Repo", targets, with_deps: bool, recursive: bool
) -> Set["StageInfo"]:
from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError

pairs: Set["StageInfo"] = set()
for target in targets:
try:
pairs.update(
self.stage.collect_granular(
target, with_deps=with_deps, recursive=recursive
)
)
except (
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
) as exc:
if not target:
raise
raise CheckoutErrorSuggestGit(target) from exc

return pairs


@locked
def checkout(
def checkout( # noqa: C901
self,
targets=None,
with_deps=False,
Expand All @@ -70,13 +38,18 @@ def checkout(
allow_missing=False,
**kwargs,
):
from dvc import prompt
from dvc.fs.callbacks import Callback
from dvc.repo.index import build_data_index
from dvc.stage.exceptions import StageFileBadNameError, StageFileDoesNotExistError
from dvc_data.index.checkout import ADD, DELETE, MODIFY
from dvc_data.index.checkout import CheckoutError as IndexCheckoutError
from dvc_data.index.checkout import checkout as icheckout

stats: Dict[str, List[str]] = {
"added": [],
"deleted": [],
"modified": [],
"failed": [],
}
if not targets:
targets = [None]
Expand All @@ -85,28 +58,79 @@ def checkout(
if isinstance(targets, str):
targets = [targets]

pairs = _collect_pairs(self, targets, with_deps, recursive)
total = get_all_files_numbers(pairs)
def onerror(target, exc):
if target and isinstance(
exc,
(
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
),
):
raise CheckoutErrorSuggestGit(target) from exc
raise # pylint: disable=misplaced-bare-raise

view = self.index.targets_view(
targets,
recursive=recursive,
with_deps=with_deps,
onerror=onerror,
)

old = build_data_index(view, self.root_dir, self.fs, compute_hash=True)
new = view.data["repo"]

total = len(new)
with Callback.as_tqdm_callback(
unit="file",
desc="Checkout",
disable=total == 0,
) as cb:
cb.set_size(total)
for stage, filter_info in pairs:
result = stage.checkout(
force=force,
progress_callback=cb,
try:
changes = icheckout(
new,
self.root_dir,
self.fs,
old=old,
callback=cb,
delete=True,
prompt=prompt.confirm,
update_meta=False,
relink=relink,
filter_info=filter_info,
force=force,
allow_missing=allow_missing,
**kwargs,
)
for key, items in result.items():
stats[key].extend(_fspath_dir(path) for path in items)
except IndexCheckoutError as exc:
raise CheckoutError([], {}) from exc

def _adapt_path(entry):
ret = _fspath_dir(self.fs.path.join(self.root_dir, *entry.key))
ret = relpath(ret)
if entry.meta and entry.meta.isdir:
return self.fs.path.join(ret, "")
return ret

top_keys = {key for key, _ in new.iteritems(shallow=True)}

if stats.get("failed"):
raise CheckoutError(stats["failed"], stats)
stats["added"].extend(_adapt_path(change.new) for change in changes[ADD])
stats["deleted"].extend(_adapt_path(change.old) for change in changes[DELETE])
stats["modified"].extend(_adapt_path(change.new) for change in changes[MODIFY])

changed_keys = {
change.key for change in chain(changes[ADD], changes[DELETE], changes[MODIFY])
}
for key in top_keys:
for changed_key in changed_keys:
if len(changed_key) >= len(key) and changed_key[: len(key)] == key:
self.state.save_link(self.fs.path.join(self.root_dir, *key), self.fs)

failed = [
entry
for _, entry in new.iteritems()
if not entry.hash_info and not (entry.meta and entry.meta.isdir)
]
if not allow_missing and failed:
raise CheckoutError([_adapt_path(entry) for entry in failed], stats)

del stats["failed"]
return stats
2 changes: 1 addition & 1 deletion dvc/repo/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def save_imports(
desc="Downloading imports from source",
unit="files",
) as cb:
checkout(data_view, tmpdir, cache.fs, callback=cb, storage="data")
checkout(data_view, tmpdir, cache.fs, callback=cb, storage="remote")
md5(data_view)
save(data_view, odb=cache, hardlink=True)

Expand Down
36 changes: 31 additions & 5 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,18 @@ def is_out_or_ignored(root, directory):


def _load_data_from_outs(index, prefix, outs):
from dvc_data.index import DataIndexEntry
from dvc_data.index import DataIndexEntry, Meta

parents = set()
for out in outs:
if not out.use_cache:
continue

ws, key = out.index_key

for key_len in range(1, len(key)):
parents.add((ws, key[:key_len]))

entry = DataIndexEntry(
key=key,
meta=out.meta,
Expand All @@ -137,6 +141,8 @@ def _load_data_from_outs(index, prefix, outs):
# index.view() work correctly.
index[(*prefix, ws, *key)] = entry

for ws, key in parents:
index[(*prefix, ws, *key)] = DataIndexEntry(key=key, meta=Meta(isdir=True), loaded=True)

def _load_storage_from_out(storage_map, key, out):
from dvc.config import NoRemoteError
Expand All @@ -163,7 +169,7 @@ def _load_storage_from_out(storage_map, key, out):

if out.stage.is_import:
dep = out.stage.deps[0]
storage_map.add_data(FileStorage(key, dep.fs, dep.fs_path))
storage_map.add_remote(FileStorage(key, dep.fs, dep.fs_path))


class Index:
Expand Down Expand Up @@ -534,6 +540,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]:
lambda: _DataPrefixes(set(), set())
)
for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue
workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand All @@ -550,6 +558,9 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]:
ret: Dict[str, Set["DataIndexKey"]] = defaultdict(set)

for out, filter_info in self._filtered_outs:
if not out.use_cache:
continue

workspace, key = out.index_key
if filter_info and out.fs.path.isin(filter_info, out.fs_path):
key = key + out.fs.path.relparts(filter_info, out.fs_path)
Expand Down Expand Up @@ -579,14 +590,14 @@ def key_filter(workspace: str, key: "DataIndexKey"):
return data


def build_data_index(
def build_data_index( # noqa: C901
index: Union["Index", "IndexView"],
path: str,
fs: "FileSystem",
workspace: str = "repo",
compute_hash: Optional[bool] = False,
) -> "DataIndex":
from dvc_data.index import DataIndex, DataIndexEntry
from dvc_data.index import DataIndex, DataIndexEntry, Meta
from dvc_data.index.build import build_entries, build_entry
from dvc_data.index.save import build_tree

Expand All @@ -595,9 +606,16 @@ def build_data_index(
ignore = index.repo.dvcignore

data = DataIndex()
parents = set()
for key in index.data_keys.get(workspace, set()):
out_path = fs.path.join(path, *key)

for key_len in range(1, len(key)):
parents.add(key[:key_len])

if not fs.exists(out_path):
continue

try:
out_entry = build_entry(
out_path,
Expand All @@ -606,7 +624,8 @@ def build_data_index(
state=index.repo.state,
)
except FileNotFoundError:
out_entry = DataIndexEntry()
continue
# out_entry = DataIndexEntry()

out_entry.key = key
data.add(out_entry)
Expand Down Expand Up @@ -636,4 +655,11 @@ def build_data_index(
out_entry.loaded = True
data.add(out_entry)

for key in parents:
parent_path = fs.path.join(path, *key)
if not fs.exists(parent_path):
continue
direntry = DataIndexEntry(key=key, meta=Meta(isdir=True), loaded=True)
data.add(direntry)

return data
5 changes: 0 additions & 5 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,11 +728,6 @@ def outs_cached(self) -> bool:
for out in self.outs
)

def get_all_files_number(self, filter_info=None) -> int:
return sum(
out.get_files_number(filter_info) for out in self.filter_outs(filter_info)
)

def get_used_objs(
self, *args, **kwargs
) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"configobj>=5.0.6",
"distro>=1.3",
"dpath<3,>=2.1.0",
"dvc-data>=0.49.1,<0.50",
"dvc-data>=0.50.0,<0.51",
"dvc-http>=2.29.0",
"dvc-render>=0.3.1,<1",
"dvc-studio-client>=0.9.0,<1",
Expand Down
Loading

0 comments on commit 4ac8a5a

Please sign in to comment.