Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support recursive and batch yield for ls_iterate API #256

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,9 @@ def ls_iterate(
detail: bool = False,
versions: bool = False,
batch_size: int = LS_OPERATION_DEFAULT_MAX_ITEMS,
recursive: bool = False,
**kwargs: Union[str, bool, float, None],
) -> Generator[Union[dict, str], None, None]:
) -> Generator[Union[List[dict], List[str]], None, None]:
"""List objects under the given path in batches and return an iterator.

Parameters
Expand All @@ -415,6 +416,8 @@ def ls_iterate(
Whether to list object versions (default is False).
batch_size : int, optional
The number of items to fetch in each batch (default is 1000).
recursive : bool, optional
Whether to list objects recursively (default is False).
**kwargs : dict, optional
Additional arguments.

Expand Down Expand Up @@ -450,7 +453,7 @@ def _call_list_objects_type2(
bucket,
prefix,
start_after=prefix,
delimiter="/",
delimiter=None if recursive else "/",
max_keys=batch_size,
continuation_token=continuation_token,
)
Expand All @@ -464,6 +467,7 @@ def _call_list_objects_type2(
continuation_token = resp.next_continuation_token
results = resp.contents + resp.common_prefixes

batch = []
for obj in results:
if isinstance(obj, CommonPrefixInfo):
info = self._fill_dir_info(bucket, obj)
Expand All @@ -472,7 +476,9 @@ def _call_list_objects_type2(
else:
info = self._fill_file_info(obj, bucket, versions)

yield info if detail else info["name"]
batch.append(info if detail else info["name"])

yield batch

def info(
self,
Expand Down
45 changes: 35 additions & 10 deletions tosfs/tests/test_tosfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,55 @@ def test_ls_iterate(
)

# Test listing without detail
result = list(tosfs.ls_iterate(f"{bucket}/{temporary_workspace}"))
result = [
item
for batch in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}")
for item in batch
]
assert f"{bucket}/{temporary_workspace}/{dir_name}" in result

# Test listing with detail
result = list(tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", detail=True))
result = [
item
for batch in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", detail=True)
for item in batch
]
assert any(
item["name"] == f"{bucket}/{temporary_workspace}/{dir_name}" for item in result
)

# Test consuming the iterator directly in a for loop
for item in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", detail=True):
assert item["name"] in sorted(
[
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{another_dir_name}",
]
)
for batch in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", detail=True):
for item in batch:
assert item["name"] in sorted(
[
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{another_dir_name}",
]
)

# Test listing with batch_size=1 so that more than one batch is fetched
result = []
for batch in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", batch_size=1):
result.append(batch)
for item in batch:
result.append(item)
assert len(result) == len([dir_name, another_dir_name])

# Test list recursively
expected = [
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}",
f"{bucket}/{temporary_workspace}/{another_dir_name}",
]
result = [
item
for batch in tosfs.ls_iterate(f"{bucket}/{temporary_workspace}", recursive=True)
for item in batch
]
assert sorted(result) == sorted(expected)


def test_inner_rm(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
file_name = random_str()
Expand Down