Skip to content

Commit

Permalink
add verbosity and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Jan 2, 2025
1 parent c06ded8 commit 1477ee3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 18 deletions.
33 changes: 23 additions & 10 deletions src/pymovements/utils/archives.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from pathlib import Path
from typing import IO

from tqdm import tqdm

from pymovements.utils.paths import get_filepaths


Expand Down Expand Up @@ -93,7 +95,7 @@ def extract_archive(
print(f'Extracting {source_path.name} to {destination_path}')

# Extract file and remove archive if desired.
extractor(source_path, destination_path, compression_type, resume)
extractor(source_path, destination_path, compression_type, resume, verbose)
if remove_finished:
source_path.unlink()

Expand Down Expand Up @@ -142,7 +144,8 @@ def _extract_tar(
source_path: Path,
destination_path: Path,
compression: str | None,
resume: bool = True,
resume: bool,
verbose: int,
) -> None:
"""Extract a tar archive.
Expand All @@ -155,19 +158,23 @@ def _extract_tar(
compression: str | None
Compression filename suffix.
resume: bool
Resume previous extraction. (default: True)
Resume if archive was already previous extracted.
verbose: int
Print messages for resuming each dataset resource.
"""
with tarfile.open(source_path, f'r:{compression[1:]}' if compression else 'r') as archive:
for member in archive.getmembers():
for member in tqdm(archive.getmembers()):
member_name = member.name
member_size = member.size
member_dest_path = os.path.join(destination_path, member_name)
if resume:
if (
os.path.exists(member_dest_path) and
member_name[-4:] not in _ARCHIVE_EXTRACTORS and
member_name[-4:] in _ARCHIVE_EXTRACTORS and
member_size == os.path.getsize(member_dest_path)
):
if verbose:
print(f'Skipping {member_name} due to previous extraction')
continue
if sys.version_info < (3, 12): # pragma: <3.12 cover
archive.extract(member_name, destination_path)
Expand All @@ -180,6 +187,7 @@ def _extract_zip(
destination_path: Path,
compression: str | None,
resume: bool,
verbose: int,
) -> None:
"""Extract a zip archive.
Expand All @@ -191,27 +199,32 @@ def _extract_zip(
Path to the directory the file will be extracted to.
compression: str | None
Compression filename suffix.
resume: bool
Resume if archive was already previous extracted.
verbose: int
Print messages for resuming each dataset resource.
"""
compression_id = _ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
with zipfile.ZipFile(source_path, 'r', compression=compression_id) as archive:
for member in archive.filelist:
for member in tqdm(archive.filelist):
member_filename = member.filename
member_dest_path = os.path.join(destination_path, member_filename)
if resume:
member_size = member.file_size
if (
os.path.exists(member_dest_path) and
member_filename[-4:] not in _ARCHIVE_EXTRACTORS and
member_filename[-4:] in _ARCHIVE_EXTRACTORS and
member_size == os.path.getsize(member_dest_path)
):
if verbose:
print(f'Skipping {member_filename} due to previous extraction')
continue
else:
archive.extract(member_filename, destination_path)
archive.extract(member_filename, destination_path)
else:
archive.extract(member_filename, destination_path)


_ARCHIVE_EXTRACTORS: dict[str, Callable[[Path, Path, str | None], None]] = {
_ARCHIVE_EXTRACTORS: dict[str, Callable[[Path, Path, str | None, bool, int], None]] = {
'.tar': _extract_tar,
'.zip': _extract_zip,
}
Expand Down
51 changes: 43 additions & 8 deletions tests/unit/utils/archives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ def fixture_unsupported_archive(request, tmp_path):
],
)
def test_extract_archive_destination_path_None(
recursive, remove_finished, remove_top_level, expected_files, archive,
recursive,
remove_finished,
remove_top_level,
expected_files,
archive,
):
extract_archive(
source_path=archive,
Expand Down Expand Up @@ -409,7 +413,12 @@ def test_extract_unsupported_archive_destination_path_None(
],
)
def test_extract_archive_destination_path_not_None(
recursive, remove_finished, remove_top_level, archive, tmp_path, expected_files,
recursive,
remove_finished,
remove_top_level,
archive,
tmp_path,
expected_files,
):
destination_path = tmp_path / pathlib.Path('tmpfoo')
extract_archive(
Expand Down Expand Up @@ -439,7 +448,10 @@ def test_extract_archive_destination_path_not_None(
],
)
def test_extract_compressed_file_destination_path_not_None(
recursive, remove_finished, compressed_file, tmp_path,
recursive,
remove_finished,
compressed_file,
tmp_path,
):
destination_filename = 'tmpfoo'
destination_path = tmp_path / pathlib.Path(destination_filename)
Expand Down Expand Up @@ -500,7 +512,14 @@ def test_decompress_unknown_compression_suffix():


@pytest.mark.parametrize(
('recursive', 'remove_top_level', 'expected_files', 'resume'),
('resume'),
[
pytest.param(True, id='resume_True'),
pytest.param(False, id='resume_False'),
],
)
@pytest.mark.parametrize(
('recursive', 'remove_top_level', 'expected_files'),
[
pytest.param(
False,
Expand All @@ -509,7 +528,6 @@ def test_decompress_unknown_compression_suffix():
'toplevel',
os.path.join('toplevel', 'recursive.zip'),
),
True,
id='recursive_false_remove_finished_false',
),
pytest.param(
Expand All @@ -522,22 +540,36 @@ def test_decompress_unknown_compression_suffix():
os.path.join('toplevel', 'recursive', 'singlechild'),
os.path.join('toplevel', 'recursive', 'singlechild', 'test.file'),
),
False,
id='recursive_true_remove_finished_false',
),
],
)
@pytest.mark.parametrize(
('verbose'),
[
pytest.param(True, id='verbose_True'),
pytest.param(False, id='verbose_False'),
],
)
def test_extract_archive_destination_path_not_None_no_remove_top_level_no_remove_finished_twice(
recursive, remove_top_level, archive, tmp_path, expected_files, resume,
verbose,
recursive,
remove_top_level,
archive,
tmp_path,
resume,
expected_files,
capsys,
):
destination_path = tmp_path / pathlib.Path('tmpfoo')
destination_path = tmp_path / pathlib.Path('tmp')
extract_archive(
source_path=archive,
destination_path=destination_path,
recursive=recursive,
remove_finished=False,
remove_top_level=remove_top_level,
resume=resume,
verbose=verbose,
)
extract_archive(
source_path=archive,
Expand All @@ -546,7 +578,10 @@ def test_extract_archive_destination_path_not_None_no_remove_top_level_no_remove
remove_finished=False,
remove_top_level=remove_top_level,
resume=resume,
verbose=verbose,
)
if resume and verbose:
assert 'Skipping' in capsys.readouterr().out

if destination_path.is_file():
destination_path = destination_path.parent
Expand Down

0 comments on commit 1477ee3

Please sign in to comment.