Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Aug 24, 2024
1 parent a146910 commit 071e0bc
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/pymovements/utils/archives.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def _extract_tar(
"""
with tarfile.open(source_path, f'r:{compression[1:]}' if compression else 'r') as archive:
for member in archive.getnames():
if os.path.exists(os.path.join(destination_path, member)):
if (
os.path.exists(os.path.join(destination_path, member)) and
member[-4:] not in _ARCHIVE_EXTRACTORS
):
continue
if sys.version_info < (3, 12): # pragma: <3.12 cover
archive.extract(member, destination_path)
Expand All @@ -179,7 +182,10 @@ def _extract_zip(
compression_id = _ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
with zipfile.ZipFile(source_path, 'r', compression=compression_id) as archive:
for member in archive.namelist():
if os.path.exists(os.path.join(destination_path, member)):
if (
os.path.exists(os.path.join(destination_path, member)) and
member[-4:] not in _ARCHIVE_EXTRACTORS
):
continue
archive.extract(member, destination_path)

Expand Down
113 changes: 113 additions & 0 deletions tests/unit/utils/archives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def fixture_archive(request, tmp_path):
archive_path = rootpath / f'test.{compression}'
elif compression is not None and extension is not None:
archive_path = rootpath / f'test.{extension}.{compression}'
else:
raise ValueError(f'{request.param} not supported for archive fixture')

if compression is None and extension == 'zip':
with zipfile.ZipFile(archive_path, 'w') as zip_open:
Expand Down Expand Up @@ -495,3 +497,114 @@ def test_decompress_unknown_compression_suffix():
_decompress(pathlib.Path('test.zip.zip'))
msg, = excinfo.value.args
assert msg == "Couldn't detect a compression from suffix .zip."


# @pytest.mark.parametrize(
# ('recursive', 'remove_top_level', 'expected_files'),
# [
# pytest.param(
# False, False,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive.zip'),
# ),
# id='recursive_false_remove_finished_false',
# ),
# pytest.param(
# True, False,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive'),
# os.path.join('toplevel', 'recursive.zip'),
# os.path.join('toplevel', 'recursive', 'singlechild'),
# os.path.join('toplevel', 'recursive', 'singlechild', 'test.file'),
# ),
# id='recursive_true_remove_finished_false',
# ),
# pytest.param(
# False, True,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive.zip'),
# ),
# id='recursive_false_remove_top_level_true',
# ),
# pytest.param(
# True, True,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive'),
# os.path.join('toplevel', 'recursive.zip'),
# os.path.join('toplevel', 'recursive', 'test.file'),
# ),
# id='recursive_true_remove_top_level_true',
# ),
# ],
# )
@pytest.mark.parametrize(
('recursive', 'remove_top_level', 'expected_files'),
[
pytest.param(
False, False,
(
'toplevel',
os.path.join('toplevel', 'recursive.zip'),
),
id='recursive_false_remove_finished_false',
),
pytest.param(
True, False,
(
'toplevel',
os.path.join('toplevel', 'recursive.zip'),
os.path.join('toplevel', 'recursive'),
os.path.join('toplevel', 'recursive', 'singlechild'),
os.path.join('toplevel', 'recursive', 'singlechild', 'test.file'),
),
id='recursive_true_remove_finished_false',
),
# pytest.param(
# False, True,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive.zip'),
# ),
# id='recursive_false_remove_top_level_true',
# ),
# pytest.param(
# True, True,
# (
# 'toplevel',
# os.path.join('toplevel', 'recursive.zip'),
# os.path.join('toplevel', 'recursive'),
# os.path.join('toplevel', 'recursive', 'test.file'),
# ),
# id='recursive_true_remove_top_level_true',
# ),
],
)
def test_extract_archive_destination_path_not_None_no_remove_top_level_no_remove_finished_twice(
recursive, remove_top_level, archive, tmp_path, expected_files,
):
destination_path = tmp_path / pathlib.Path('tmpfoo')
extract_archive(
source_path=archive,
destination_path=destination_path,
recursive=recursive,
remove_finished=False,
remove_top_level=remove_top_level,
)
extract_archive(
source_path=archive,
destination_path=destination_path,
recursive=recursive,
remove_finished=False,
remove_top_level=remove_top_level,
)

if destination_path.is_file():
destination_path = destination_path.parent

result_files = {str(file.relative_to(destination_path)) for file in destination_path.rglob('*')}

assert result_files == set(expected_files)

0 comments on commit 071e0bc

Please sign in to comment.