Skip to content

Commit

Permalink
add enum to ls_files to simplify
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Sep 9, 2024
1 parent 6daa218 commit 5d45412
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 44 deletions.
8 changes: 1 addition & 7 deletions flytekit/image_spec/default_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ def create_docker_context(image_spec: ImageSpec, tmp_dir: Path):
# what about deref_symlink?
ignore = IgnoreGroup(image_spec.source_root, [GitIgnore, DockerIgnore, StandardIgnore])

if image_spec.copy == CopyFileDetection.LOADED_MODULES:
# This is the 'auto' semantic by default used for pyflyte run, it only copies loaded .py files.
sys_modules = list(sys.modules.values())
ls, _ = ls_files(str(image_spec.source_root), sys_modules, deref_symlinks=False, ignore_group=ignore)
else:
# This triggers listing of all files
ls, _ = ls_files(str(image_spec.source_root), [], deref_symlinks=False, ignore_group=ignore)
ls, _ = ls_files(str(image_spec.source_root), image_spec.copy, deref_symlinks=False, ignore_group=ignore)

for file_to_copy in ls:
rel_path = os.path.relpath(file_to_copy, start=str(image_spec.source_root))
Expand Down
10 changes: 2 additions & 8 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import pathlib
import re
import sys
import typing
from abc import abstractmethod
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -164,13 +163,8 @@ def tag(self) -> str:
# todo: we should pipe through ignores from the command line here at some point.
# what about deref_symlink?
ignore = IgnoreGroup(self.source_root, [GitIgnore, DockerIgnore, StandardIgnore])
if self.copy == CopyFileDetection.LOADED_MODULES:
# This is the 'auto' semantic by default used for pyflyte run, it only copies loaded .py files.
sys_modules = list(sys.modules.values())
_, ls_digest = ls_files(str(self.source_root), sys_modules, deref_symlinks=False, ignore_group=ignore)
else:
# This triggers listing of all files, mimicking the old way of creating the tar file.
_, ls_digest = ls_files(str(self.source_root), [], deref_symlinks=False, ignore_group=ignore)

_, ls_digest = ls_files(str(self.source_root), self.copy, deref_symlinks=False, ignore_group=ignore)

# Since the source root is supposed to represent the files, store the digest into the source root as a
# shortcut to represent all the files.
Expand Down
10 changes: 1 addition & 9 deletions flytekit/tools/fast_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pathlib
import posixpath
import subprocess
import sys
import tarfile
import tempfile
import typing
Expand Down Expand Up @@ -97,14 +96,7 @@ def fast_package(
if options and (
options.copy_style == CopyFileDetection.LOADED_MODULES or options.copy_style == CopyFileDetection.ALL
):
if options.copy_style == CopyFileDetection.LOADED_MODULES:
# This is the 'auto' semantic by default used for pyflyte run, it only copies loaded .py files.
sys_modules = list(sys.modules.values())
ls, ls_digest = ls_files(str(source), sys_modules, deref_symlinks, ignore)
else:
# This triggers listing of all files, mimicking the old way of creating the tar file.
ls, ls_digest = ls_files(str(source), [], deref_symlinks, ignore)

ls, ls_digest = ls_files(str(source), options.copy_style, deref_symlinks, ignore)
logger.debug(f"Hash digest: {ls_digest}", fg="green")

if options.show_files:
Expand Down
15 changes: 7 additions & 8 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from types import ModuleType
from typing import List, Optional, Tuple, Union

from flytekit.constants import CopyFileDetection
from flytekit.loggers import logger
from flytekit.tools.ignore import IgnoreGroup

Expand Down Expand Up @@ -86,7 +87,7 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:

def ls_files(
source_path: str,
modules: List[ModuleType],
copy_file_detection: CopyFileDetection,
deref_symlinks: bool = False,
ignore_group: Optional[IgnoreGroup] = None,
) -> Tuple[List[str], str]:
Expand All @@ -101,19 +102,17 @@ def ls_files(
Then the common root is just the folder a/. The modules list is filtered against this root. Only files
representing modules under this root are included.
If the modules list should be a list of all the
needs to compute digest as well.
If the copy enum is set to loaded_modules, then the loaded sys modules will be used.
"""

# Unlike the below, the ValueError here is useful and should be returned to the user, for example when
# absolute and relative paths are mixed.

# This is --copy auto
if modules:
all_files = list_imported_modules_as_files(source_path, modules)
# this is --copy all
if copy_file_detection == CopyFileDetection.LOADED_MODULES:
sys_modules = list(sys.modules.values())
all_files = list_imported_modules_as_files(source_path, sys_modules)
# this is --copy all (--copy none should never invoke this function)
else:
all_files = list_all_files(source_path, deref_symlinks, ignore_group)

Expand Down
9 changes: 1 addition & 8 deletions plugins/flytekit-envd/flytekitplugins/envd/image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pathlib
import shutil
import subprocess
import sys
from dataclasses import asdict
from importlib import metadata

Expand Down Expand Up @@ -161,13 +160,7 @@ def build():

dst = pathlib.Path(cfg_path).parent

if image_spec.copy == CopyFileDetection.LOADED_MODULES:
# This is the 'auto' semantic by default used for pyflyte run, it only copies loaded .py files.
sys_modules = list(sys.modules.values())
ls, _ = ls_files(str(image_spec.source_root), sys_modules, deref_symlinks=False, ignore_group=ignore)
else:
# This triggers listing of all files
ls, _ = ls_files(str(image_spec.source_root), [], deref_symlinks=False, ignore_group=ignore)
ls, _ = ls_files(str(image_spec.source_root), image_spec.copy, deref_symlinks=False, ignore_group=ignore)

for file_to_copy in ls:
rel_path = os.path.relpath(file_to_copy, start=str(image_spec.source_root))
Expand Down
9 changes: 5 additions & 4 deletions tests/flytekit/unit/cli/pyflyte/test_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile

from flytekit.tools.script_mode import ls_files

from flytekit.constants import CopyFileDetection

# a pytest fixture that creates a tmp directory and creates
# a small file structure in it
Expand Down Expand Up @@ -36,15 +36,16 @@ def dummy_dir_structure():


def test_list_dir(dummy_dir_structure):
files, d = ls_files(str(dummy_dir_structure), [])
files, d = ls_files(str(dummy_dir_structure), CopyFileDetection.ALL)
assert len(files) == 5
if os.name != "nt":
assert d == "c092f1b85f7c6b2a71881a946c00a855"


def test_list_filtered_on_modules(dummy_dir_structure):
import sys # any module will do
files, d = ls_files(str(dummy_dir_structure), [sys])
# any module will do
import sys # noqa
files, d = ls_files(str(dummy_dir_structure), CopyFileDetection.LOADED_MODULES)
# because none of the files are python modules, nothing should be returned
assert len(files) == 0
if os.name != "nt":
Expand Down

0 comments on commit 5d45412

Please sign in to comment.