diff --git a/mltb2/files.py b/mltb2/files.py index fbfa77f..22e628a 100644 --- a/mltb2/files.py +++ b/mltb2/files.py @@ -9,6 +9,7 @@ """ +import contextlib import os from typing import Optional @@ -45,5 +46,10 @@ def fetch_remote_file(dirname, filename, url, sha256_checksum) -> str: IOError: if the sha256 checksum is wrong """ remote = RemoteFileMetadata(filename=filename, url=url, checksum=sha256_checksum) - fetch_remote_file_path = _fetch_remote(remote, dirname=dirname) + try: + fetch_remote_file_path = _fetch_remote(remote, dirname=dirname) + except Exception: + with contextlib.suppress(FileNotFoundError): + os.remove(os.path.join(dirname, filename)) + raise return fetch_remote_file_path diff --git a/tests/test_files.py b/tests/test_files.py index b7bbe61..25fe578 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -18,6 +18,7 @@ def test_fetch_remote_file(tmpdir): sha256_checksum="8d834de97b095fbf4bf6075743827862be2c6c404594ae04606d9c56d8f1017b", ) assert remote_file == os.path.join(tmpdir, filename) + assert os.path.exists(os.path.join(tmpdir, filename)) def test_fetch_remote_file_wrong_checksum(tmpdir): @@ -29,6 +30,7 @@ def test_fetch_remote_file_wrong_checksum(tmpdir): url="https://raw.githubusercontent.com/telekom/mltb2/main/LICENSE", sha256_checksum="wrong", ) + assert not os.path.exists(os.path.join(tmpdir, filename)) def test_get_and_create_mltb2_data_dir(tmpdir):