Skip to content

Commit

Permalink
test for actual size and not only > 0
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Dec 28, 2024
1 parent 71e3104 commit 49cc52d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 25 deletions.
58 changes: 37 additions & 21 deletions src/pymovements/utils/archives.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def extract_archive(
remove_finished: bool = False,
remove_top_level: bool = True,
verbose: int = 1,
resume: bool = False,
) -> Path:
"""Extract an archive.
Expand All @@ -65,6 +66,8 @@ def extract_archive(
Verbosity levels: (1) Print messages for extracting each dataset resource without printing
messages for recursive archives. (2) Print additional messages for each recursive archive
extract. (default: 1)
resume: bool
Resume previous extraction. (default: True)
Returns
-------
Expand All @@ -90,7 +93,7 @@ def extract_archive(
print(f'Extracting {source_path.name} to {destination_path}')

# Extract file and remove archive if desired.
extractor(source_path, destination_path, compression_type)
extractor(source_path, destination_path, compression_type, resume)
if remove_finished:
source_path.unlink()

Expand Down Expand Up @@ -129,6 +132,7 @@ def extract_archive(
remove_finished=remove_finished,
remove_top_level=remove_top_level,
verbose=0 if verbose < 2 else 2,
resume=resume,
)

return destination_path
Expand All @@ -138,7 +142,7 @@ def _extract_tar(
source_path: Path,
destination_path: Path,
compression: str | None,
skip: bool = True,
resume: bool = True,
) -> None:
"""Extract a tar archive.
Expand All @@ -150,28 +154,32 @@ def _extract_tar(
Path to the directory the file will be extracted to.
compression: str | None
Compression filename suffix.
skip: bool
Skip already extracted files. (default: True)
resume: bool
Resume previous extraction. (default: True)
"""
with tarfile.open(source_path, f'r:{compression[1:]}' if compression else 'r') as archive:
for member in archive.getnames():
if (
os.path.exists(os.path.join(destination_path, member)) and
member[-4:] not in _ARCHIVE_EXTRACTORS and
tarfile.TarInfo(os.path.join(destination_path, member)).size > 0 and
skip
):
continue
for member in archive.getmembers():
member_name = member.name
member_size = member.size
member_dest_path = os.path.join(destination_path, member_name)
if resume:
if (
os.path.exists(member_dest_path) and
member_name[-4:] not in _ARCHIVE_EXTRACTORS and
member_size == os.path.getsize(member_dest_path)
):
continue
if sys.version_info < (3, 12): # pragma: <3.12 cover
archive.extract(member, destination_path)
archive.extract(member_name, destination_path)
else: # pragma: >=3.12 cover
archive.extract(member, destination_path, filter='tar')
archive.extract(member_name, destination_path, filter='tar')


def _extract_zip(
source_path: Path,
destination_path: Path,
compression: str | None,
resume: bool,
) -> None:
"""Extract a zip archive.
Expand All @@ -186,13 +194,21 @@ def _extract_zip(
"""
compression_id = _ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
with zipfile.ZipFile(source_path, 'r', compression=compression_id) as archive:
for member in archive.namelist():
if (
os.path.exists(os.path.join(destination_path, member)) and
member[-4:] not in _ARCHIVE_EXTRACTORS
):
continue
archive.extract(member, destination_path)
for member in archive.filelist:
member_filename = member.filename
member_dest_path = os.path.join(destination_path, member_filename)
if resume:
member_size = member.file_size
if (
os.path.exists(member_dest_path) and
member_filename[-4:] not in _ARCHIVE_EXTRACTORS and
member_size == os.path.getsize(member_dest_path)
):
continue
else:
archive.extract(member_filename, destination_path)
else:
archive.extract(member_filename, destination_path)


_ARCHIVE_EXTRACTORS: dict[str, Callable[[Path, Path, str | None], None]] = {
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/utils/archives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,31 +500,35 @@ def test_decompress_unknown_compression_suffix():


@pytest.mark.parametrize(
('recursive', 'remove_top_level', 'expected_files'),
('recursive', 'remove_top_level', 'expected_files', 'resume'),
[
pytest.param(
False, False,
False,
False,
(
'toplevel',
os.path.join('toplevel', 'recursive.zip'),
),
True,
id='recursive_false_remove_finished_false',
),
pytest.param(
True, False,
True,
False,
(
'toplevel',
os.path.join('toplevel', 'recursive.zip'),
os.path.join('toplevel', 'recursive'),
os.path.join('toplevel', 'recursive', 'singlechild'),
os.path.join('toplevel', 'recursive', 'singlechild', 'test.file'),
),
False,
id='recursive_true_remove_finished_false',
),
],
)
def test_extract_archive_destination_path_not_None_no_remove_top_level_no_remove_finished_twice(
recursive, remove_top_level, archive, tmp_path, expected_files,
recursive, remove_top_level, archive, tmp_path, expected_files, resume,
):
destination_path = tmp_path / pathlib.Path('tmpfoo')
extract_archive(
Expand All @@ -533,13 +537,15 @@ def test_extract_archive_destination_path_not_None_no_remove_top_level_no_remove
recursive=recursive,
remove_finished=False,
remove_top_level=remove_top_level,
resume=resume,
)
extract_archive(
source_path=archive,
destination_path=destination_path,
recursive=recursive,
remove_finished=False,
remove_top_level=remove_top_level,
resume=resume,
)

if destination_path.is_file():
Expand Down

0 comments on commit 49cc52d

Please sign in to comment.