diff --git a/tosfs/core.py b/tosfs/core.py index 7f9621e..2eb8273 100644 --- a/tosfs/core.py +++ b/tosfs/core.py @@ -18,7 +18,7 @@ import mimetypes import os import time -from typing import Any, BinaryIO, List, Optional, Tuple, Union +from typing import Any, BinaryIO, Generator, List, Optional, Tuple, Union import tos from fsspec import AbstractFileSystem @@ -65,6 +65,7 @@ class TosFileSystem(AbstractFileSystem): abstract super-class for pythonic file-systems. """ + protocol = ("tos", "tosfs") retries = 5 default_block_size = 5 * 2**20 @@ -676,6 +677,42 @@ def _read_chunks(body: BinaryIO, f: BinaryIO) -> None: e, ) + def walk( + self, + path: str, + maxdepth: Optional[int] = None, + topdown: bool = True, + on_error: str = "omit", + **kwargs: Any, + ) -> Generator[str, List[str], List[str]]: + """List objects under the given path. + + Parameters + ---------- + path : str + The path to list. + maxdepth : int, optional + The maximum depth to walk to (default is None). + topdown : bool, optional + Whether to walk top-down or bottom-up (default is True). + on_error : str, optional + How to handle errors (default is 'omit'). + **kwargs : Any, optional + Additional arguments. + + Raises + ------ + ValueError + If the path is an invalid path. + + """ + if path in ["", "*"] + ["{}://".format(p) for p in self.protocol]: + raise ValueError("Cannot access all of TOS via path {}.".format(path)) + + return super().walk( + path, maxdepth=maxdepth, topdown=topdown, on_error=on_error, **kwargs + ) + def _open_remote_file( self, bucket: str, diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py index 3a87e05..5b1cfd7 100644 --- a/tosfs/tests/test_tosfs.py +++ b/tosfs/tests/test_tosfs.py @@ -343,6 +343,79 @@ def test_get_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) - tosfs.rm_file(rpath) +def test_walk(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None: + with pytest.raises(ValueError, match="Cannot access all of TOS via path ."): + tosfs.walk(path="") + + with pytest.raises(ValueError, match="Cannot access all of TOS via path *."): + tosfs.walk(path="*") + + with pytest.raises(ValueError, match="Cannot access all of TOS via path tos://."): + tosfs.walk("tos://") + + for root, dirs, files in list(tosfs.walk("/", maxdepth=1)): + assert root == "" + assert len(dirs) > 0 + assert files == [] + + for root, dirs, files in tosfs.walk(bucket, maxdepth=1): + assert root == bucket + assert len(dirs) > 0 + assert len(files) > 0 + + dir_name = random_str() + sub_dir_name = random_str() + file_name = random_str() + sub_file_name = random_str() + + tosfs.makedirs(f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}") + tosfs.touch(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}") + tosfs.touch( + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}" + ) + + walk_results = list(tosfs.walk(f"{bucket}/{temporary_workspace}")) + + assert walk_results[0][0] == f"{bucket}/{temporary_workspace}" + assert dir_name in walk_results[0][1] + assert walk_results[0][2] == [] + + assert walk_results[1][0] == f"{bucket}/{temporary_workspace}/{dir_name}" + assert sub_dir_name in walk_results[1][1] + assert file_name in walk_results[1][2] + + assert ( + walk_results[2][0] + == f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}" + ) + assert walk_results[2][1] == [] + assert sub_file_name in walk_results[2][2] + + walk_results = list(tosfs.walk(f"{bucket}/{temporary_workspace}", topdown=False)) + assert ( + walk_results[0][0] + == f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}" + ) + assert walk_results[0][1] == [] + assert sub_file_name in walk_results[0][2] + + assert walk_results[1][0] == f"{bucket}/{temporary_workspace}/{dir_name}" + assert sub_dir_name in walk_results[1][1] + assert file_name in walk_results[1][2] + + assert walk_results[2][0] == f"{bucket}/{temporary_workspace}" + assert dir_name in walk_results[2][1] + assert walk_results[2][2] == [] + + tosfs.rm_file( + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}" + ) + tosfs.rm_file(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}") + tosfs.rmdir(f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}") + tosfs.rmdir(f"{bucket}/{temporary_workspace}/{dir_name}") + tosfs.rmdir(f"{bucket}/{temporary_workspace}") + + ########################################################### # File operation tests # ###########################################################