From b9661da9bc5968300c690104a050d688f0951413 Mon Sep 17 00:00:00 2001 From: Marcel Levstek <62072754+marcellevstek@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:26:29 +0200 Subject: [PATCH] fixup! Add functionality for setting custom names --- src/resdk/resolwe.py | 29 +++----------- src/resdk/resources/collection.py | 2 +- src/resdk/resources/data.py | 63 +++++++++++++++++++++++-------- tests/unit/test_data.py | 8 ++-- 4 files changed, 56 insertions(+), 46 deletions(-) diff --git a/src/resdk/resolwe.py b/src/resdk/resolwe.py index 7aaa50a3..830f1042 100644 --- a/src/resdk/resolwe.py +++ b/src/resdk/resolwe.py @@ -466,7 +466,6 @@ def _download_files( files: List[Union[str, Path]], download_dir=None, show_progress=True, - custom_file_names: Optional[List[str]] = None, ): """Download files. @@ -475,7 +474,6 @@ def _download_files( :param files: files to download :param download_dir: download directory - :param custom_file_names: list of file names to save the downloaded files as """ if not download_dir: @@ -486,14 +484,6 @@ def _download_files( "Download directory does not exist: {}".format(download_dir) ) - if not custom_file_names: - custom_file_names = len(files) * [None] - else: - if not len(files) == len(custom_file_names): - raise ValueError( - "Number of files and their corresponding custom file names must be equal." - ) - if not files: self.logger.info("No files to download.") else: @@ -503,7 +493,7 @@ def _download_files( sizes: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) checksums: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) - for file_uri, custom_file_name in zip(files, custom_file_names): + for file_uri in files: file_name = os.path.basename(file_uri) file_path = os.path.dirname(file_uri) file_url = urljoin(self.url, "data/{}".format(file_uri)) @@ -527,19 +517,12 @@ def _download_files( file_size = sizes[file_directory][file_name] - if custom_file_name: - desc = f"Downloading file {file_name} as {custom_file_name}" - actual_file_name = custom_file_name - else: - desc = f"Downloading file {file_name}" - actual_file_name = file_name - with tqdm.tqdm( total=file_size, disable=not show_progress, - desc=desc, + desc=f"Downloading file {file_name}", ) as progress_bar, open( - os.path.join(download_dir, file_path, actual_file_name), "wb" + os.path.join(download_dir, file_path, file_name), "wb" ) as file_handle: response = self.session.get(file_url, stream=True, auth=self.auth) @@ -556,12 +539,10 @@ def _download_files( # checksums that are difficult to reproduce here. return expected_md5 = checksums[file_directory][file_name] - computed_md5 = md5( - os.path.join(download_dir, file_path, actual_file_name) - ) + computed_md5 = md5(os.path.join(download_dir, file_path, file_name)) if expected_md5 != computed_md5: raise ValueError( - f"Checksum ({computed_md5}) of downloaded file {actual_file_name} does not match the expected value of {expected_md5}." + f"Checksum ({computed_md5}) of downloaded file {file_name} does not match the expected value of {expected_md5}." ) def data_usage(self, **query_params): diff --git a/src/resdk/resources/collection.py b/src/resdk/resources/collection.py index a8e19598..e6fa19f7 100644 --- a/src/resdk/resources/collection.py +++ b/src/resdk/resources/collection.py @@ -140,7 +140,7 @@ def download(self, file_name=None, field_name=None, download_dir=None): data_files = data.files(file_name, field_name) files.extend("{}/{}".format(data.id, file_name) for file_name in data_files) - self.resolwe._download_files(files=files, download_dir=download_dir) + self.resolwe._download_files(files, download_dir) class Collection(CollectionRelationsMixin, BaseCollection): diff --git a/src/resdk/resources/data.py b/src/resdk/resources/data.py index a24c1eb5..ec5a9786 100644 --- a/src/resdk/resources/data.py +++ b/src/resdk/resources/data.py @@ -2,6 +2,7 @@ import json import logging +import os from typing import Optional from urllib.parse import urljoin @@ -16,6 +17,12 @@ from .sample import Sample from .utils import flatten_field, parse_resolwe_datetime +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%d-%b-%y %H:%M:%S", +) + class Data(BaseResolweResource): """Resolwe Data resource. @@ -329,7 +336,6 @@ def download( file_name: Optional[str] = None, field_name: Optional[str] = None, download_dir: Optional[str] = None, - custom_file_name: Optional[str] = None, ): """Download Data object's files and directories. @@ -343,9 +349,6 @@ def download( * re.data.get(42).download(file_name='alignment7.bam') * re.data.get(42).download(field_name='bam') - If custom_file_name is provided, the file will be saved with that name, - provided that either field_name or file_name is also specified. - """ if file_name and field_name: raise ValueError("Only one of file_name or field_name may be given.") @@ -354,20 +357,48 @@ def download( "{}/{}".format(self.id, fname) for fname in self.files(file_name, field_name) ] + file_names = [fname for fname in self.files(file_name, field_name)] - # Only applies if downloading a single file - custom_file_names = None - if custom_file_name: - if file_name or field_name: - custom_file_names = [custom_file_name] * len(files) - else: - raise ValueError( - "Setting a custom file name is not supported " - "without specifying file name or field name." - ) + self.resolwe._download_files(files=files, download_dir=download_dir) + + return file_names + + def download_and_rename( + self, + custom_file_name: str, + field_name: Optional[str] = None, + file_name: Optional[str] = None, + download_dir: Optional[str] = None, + ): + """Download and rename a single file from data object.""" + + if download_dir is None: + download_dir = os.getcwd() + + new_file_path = os.path.join(download_dir, custom_file_name) + + if os.path.exists(new_file_path): + logging.warning( + f"File with path '{new_file_path}' already exists. Skipping download." + ) + return + + file_names = self.download( + file_name=file_name, + field_name=field_name, + download_dir=download_dir, + ) + if len(file_names) != 1: + raise ValueError( + f"Expected one file to be downloaded, but got {len(file_names)}" + ) + og_file_name = file_names[0] + og_file_path = os.path.join(download_dir, og_file_name) - self.resolwe._download_files( - files=files, download_dir=download_dir, custom_file_names=custom_file_names + logging.info(f"Renaming file '{og_file_name}' to '{custom_file_name}'.") + os.rename( + og_file_path, + new_file_path, ) def stdout(self): diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 2cf17faa..6a8dedb4 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -193,16 +193,14 @@ def test_download_ok(self, data_mock): data_mock.reset_mock() data_mock.files.return_value = ["file1.txt"] - Data.download( + Data.download_and_rename( data_mock, custom_file_name="text_file1.txt", field_name="txt", download_dir="/some/path/", ) - data_mock.resolwe._download_files.assert_called_once_with( - files=["123/file1.txt"], - download_dir="/some/path/", - custom_file_names=["text_file1.txt"], + data_mock.download.assert_called_once_with( + file_name=None, field_name="txt", download_dir="/some/path/" ) @patch("resdk.resolwe.Resolwe")