diff --git a/pyproject.toml b/pyproject.toml index 8a5be29..2344cd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.19.4" +version = "0.19.5" description = """\ SMASHED is a toolkit designed to apply transformations to samples in \ datasets, such as fields extraction, tokenization, prompting, batching, \ diff --git a/src/smashed/utils/io_utils.py b/src/smashed/utils/io_utils.py index 07c3190..6174e5e 100644 --- a/src/smashed/utils/io_utils.py +++ b/src/smashed/utils/io_utils.py @@ -34,6 +34,9 @@ __all__ = [ "copy_directory", + "exists", + "is_dir", + "is_file", "open_file_for_read", "open_file_for_write", "recursively_list_files", @@ -274,6 +277,66 @@ def open_file_for_read( remove_local_file(str(path)) +def is_dir( + path: PathType, + client: Optional[ClientType] = None, + raise_if_not_exists: bool = False, +) -> bool: + """Check if a path is a directory.""" + + path = MultiPath.parse(path) + client = client or get_client_if_needed(path) + + if path.is_local: + if not (e := path.as_path.exists()) and raise_if_not_exists: + raise FileNotFoundError(f"Path does not exist: {path}") + elif not e: + return False + return path.as_path.is_dir() + elif path.is_s3: + assert client is not None, "Could not get S3 client" + resp = client.list_objects_v2( + Bucket=path.bucket, Prefix=path.key.lstrip("/"), Delimiter="/" + ) + if "CommonPrefixes" in resp: + return True + elif "Contents" in resp: + return False + elif raise_if_not_exists: + raise FileNotFoundError(f"Path does not exist: {path}") + return False + else: + raise FileNotFoundError(f"Unsupported protocol: {path.prot}") + + +def is_file( + path: PathType, + client: Optional[ClientType] = None, + raise_if_not_exists: bool = False, +) -> bool: + """Check if a path is a file.""" + + try: + return not is_dir(path=path, client=client, raise_if_not_exists=True) + except FileNotFoundError as e: + if raise_if_not_exists: + raise FileNotFoundError(f"Path does not exist: {path}") from e + return False + + +def exists( + path: PathType, + client: Optional[ClientType] = None, +) -> bool: + """Check if a path exists""" + + try: + is_dir(path=path, client=client, raise_if_not_exists=True) + return True + except FileNotFoundError: + return False + + @contextmanager def open_file_for_write( path: PathType, @@ -350,7 +413,9 @@ def open_file_for_write( def recursively_list_files( path: PathType, - ignore_hidden_files: bool = True, + ignore_hidden: bool = True, + include_dirs: bool = False, + include_files: bool = True, client: Optional[ClientType] = None, ) -> Iterable[str]: """Recursively list all files in the given directory for a given @@ -358,8 +423,14 @@ def recursively_list_files( Args: path (Union[str, Path, MultiPath]): The path to list content at. - ignore_hidden_files (bool, optional): Whether to ignore hidden files - (i.e. files that start with a dot) when listing. Defaults to True. + ignore_hidden (bool, optional): Whether to ignore hidden files and + directories when listing. Defaults to True. + include_dirs (bool, optional): Whether to include directories in the + listing. Defaults to False. + include_files (bool, optional): Whether to include files in the + listing. Defaults to True. + client (boto3.client, optional): The boto3 client to use. If not + provided, one will be created if necessary. """ path = MultiPath.parse(path) @@ -377,17 +448,27 @@ def recursively_list_files( for page in pages: for obj in page["Contents"]: key = obj["Key"] - if key[-1] == "/": # last char is a slash + path = MultiPath(prot="s3", root=path.root, path=key) + if key[-1] == "/" and key != prefix: + # last char is a slash, so it's a directory + # we don't want to re-include the prefix though, so we + # check that it's not the same prefixes.append(key) + if include_dirs: + yield str(path) else: - p = MultiPath(prot="s3", root=path.root, path=key) - yield str(p) + if include_files: + yield str(path) if path.is_local: - for _root, _, files in local_walk(path.as_str): + for _root, dirnames, filenames in local_walk(path.as_str): root = Path(_root) - for f in files: - if ignore_hidden_files and f.startswith("."): + to_list = [ + *(dirnames if include_dirs else []), + *(filenames if include_files else []), + ] + for f in to_list: + if ignore_hidden and f.startswith("."): continue yield str(MultiPath.parse(root / f)) @@ -423,7 +504,7 @@ def copy_directory( client = client or get_client_if_needed(src) or get_client_if_needed(dst) for sp in recursively_list_files( - path=src, ignore_hidden_files=ignore_hidden_files + path=src, ignore_hidden=ignore_hidden_files ): # parse the source path source_path = MultiPath.parse(sp) @@ -475,7 +556,7 @@ def remove_directory(path: PathType, client: Optional[ClientType] = None): assert client is not None, "Could not get S3 client" for fn in recursively_list_files( - path=path, ignore_hidden_files=False, client=client + path=path, ignore_hidden=False, client=client ): remove_file(fn, client=client)