Skip to content

Commit

Permalink
new io funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Apr 25, 2023
1 parent 5cf01e6 commit b1162e9
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.19.4"
version = "0.19.5"
description = """\
SMASHED is a toolkit designed to apply transformations to samples in \
datasets, such as fields extraction, tokenization, prompting, batching, \
Expand Down
103 changes: 92 additions & 11 deletions src/smashed/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@

__all__ = [
"copy_directory",
"exists",
"is_dir",
"is_file",
"open_file_for_read",
"open_file_for_write",
"recursively_list_files",
Expand Down Expand Up @@ -274,6 +277,66 @@ def open_file_for_read(
remove_local_file(str(path))


def is_dir(
path: PathType,
client: Optional[ClientType] = None,
raise_if_not_exists: bool = False,
) -> bool:
"""Check if a path is a directory."""

path = MultiPath.parse(path)
client = client or get_client_if_needed(path)

if path.is_local:
if not (e := path.as_path.exists()) and raise_if_not_exists:
raise FileNotFoundError(f"Path does not exist: {path}")
elif not e:
return False
return path.as_path.is_dir()
elif path.is_s3:
assert client is not None, "Could not get S3 client"
resp = client.list_objects_v2(
Bucket=path.bucket, Prefix=path.key.lstrip("/"), Delimiter="/"
)
if "CommonPrefixes" in resp:
return True
elif "Contents" in resp:
return False
elif raise_if_not_exists:
raise FileNotFoundError(f"Path does not exist: {path}")
return False
else:
raise FileNotFoundError(f"Unsupported protocol: {path.prot}")


def is_file(
path: PathType,
client: Optional[ClientType] = None,
raise_if_not_exists: bool = False,
) -> bool:
"""Check if a path is a file."""

try:
return not is_dir(path=path, client=client, raise_if_not_exists=True)
except FileNotFoundError as e:
if raise_if_not_exists:
raise FileNotFoundError(f"Path does not exist: {path}") from e
return False


def exists(
path: PathType,
client: Optional[ClientType] = None,
) -> bool:
"""Check if a path exists"""

try:
is_dir(path=path, client=client, raise_if_not_exists=True)
return True
except FileNotFoundError:
return False


@contextmanager
def open_file_for_write(
path: PathType,
Expand Down Expand Up @@ -350,16 +413,24 @@ def open_file_for_write(

def recursively_list_files(
path: PathType,
ignore_hidden_files: bool = True,
ignore_hidden: bool = True,
include_dirs: bool = False,
include_files: bool = True,
client: Optional[ClientType] = None,
) -> Iterable[str]:
"""Recursively list all files in the given directory for a given
path, local or remote.
Args:
path (Union[str, Path, MultiPath]): The path to list content at.
ignore_hidden_files (bool, optional): Whether to ignore hidden files
(i.e. files that start with a dot) when listing. Defaults to True.
ignore_hidden (bool, optional): Whether to ignore hidden files and
directories when listing. Defaults to True.
include_dirs (bool, optional): Whether to include directories in the
listing. Defaults to False.
include_files (bool, optional): Whether to include files in the
listing. Defaults to True.
client (boto3.client, optional): The boto3 client to use. If not
provided, one will be created if necessary.
"""

path = MultiPath.parse(path)
Expand All @@ -377,17 +448,27 @@ def recursively_list_files(
for page in pages:
for obj in page["Contents"]:
key = obj["Key"]
if key[-1] == "/": # last char is a slash
path = MultiPath(prot="s3", root=path.root, path=key)
if key[-1] == "/" and key != prefix:
# last char is a slash, so it's a directory
# we don't want to re-include the prefix though, so we
# check that it's not the same
prefixes.append(key)
if include_dirs:
yield str(path)
else:
p = MultiPath(prot="s3", root=path.root, path=key)
yield str(p)
if include_files:
yield str(path)

if path.is_local:
for _root, _, files in local_walk(path.as_str):
for _root, dirnames, filenames in local_walk(path.as_str):
root = Path(_root)
for f in files:
if ignore_hidden_files and f.startswith("."):
to_list = [
*(dirnames if include_dirs else []),
*(filenames if include_files else []),
]
for f in to_list:
if ignore_hidden and f.startswith("."):
continue
yield str(MultiPath.parse(root / f))

Expand Down Expand Up @@ -423,7 +504,7 @@ def copy_directory(
client = client or get_client_if_needed(src) or get_client_if_needed(dst)

for sp in recursively_list_files(
path=src, ignore_hidden_files=ignore_hidden_files
path=src, ignore_hidden=ignore_hidden_files
):
# parse the source path
source_path = MultiPath.parse(sp)
Expand Down Expand Up @@ -475,7 +556,7 @@ def remove_directory(path: PathType, client: Optional[ClientType] = None):
assert client is not None, "Could not get S3 client"

for fn in recursively_list_files(
path=path, ignore_hidden_files=False, client=client
path=path, ignore_hidden=False, client=client
):
remove_file(fn, client=client)

Expand Down

0 comments on commit b1162e9

Please sign in to comment.