Skip to content

Commit

Permalink
fixup! Add functionality for setting custom names
Browse files Browse the repository at this point in the history
  • Loading branch information
marcellevstek committed Oct 17, 2024
1 parent 8c27660 commit b9661da
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 46 deletions.
29 changes: 5 additions & 24 deletions src/resdk/resolwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ def _download_files(
files: List[Union[str, Path]],
download_dir=None,
show_progress=True,
custom_file_names: Optional[List[str]] = None,
):
"""Download files.
Expand All @@ -475,7 +474,6 @@ def _download_files(
:param files: files to download
:param download_dir: download directory
:param custom_file_names: list of file names to save the downloaded files as
"""
if not download_dir:
Expand All @@ -486,14 +484,6 @@ def _download_files(
"Download directory does not exist: {}".format(download_dir)
)

if not custom_file_names:
custom_file_names = len(files) * [None]
else:
if not len(files) == len(custom_file_names):
raise ValueError(
"Number of files and their corresponding custom file names must be equal."
)

if not files:
self.logger.info("No files to download.")
else:
Expand All @@ -503,7 +493,7 @@ def _download_files(
sizes: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
checksums: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

for file_uri, custom_file_name in zip(files, custom_file_names):
for file_uri in files:
file_name = os.path.basename(file_uri)
file_path = os.path.dirname(file_uri)
file_url = urljoin(self.url, "data/{}".format(file_uri))
Expand All @@ -527,19 +517,12 @@ def _download_files(

file_size = sizes[file_directory][file_name]

if custom_file_name:
desc = f"Downloading file {file_name} as {custom_file_name}"
actual_file_name = custom_file_name
else:
desc = f"Downloading file {file_name}"
actual_file_name = file_name

with tqdm.tqdm(
total=file_size,
disable=not show_progress,
desc=desc,
desc=f"Downloading file {file_name}",
) as progress_bar, open(
os.path.join(download_dir, file_path, actual_file_name), "wb"
os.path.join(download_dir, file_path, file_name), "wb"
) as file_handle:
response = self.session.get(file_url, stream=True, auth=self.auth)

Expand All @@ -556,12 +539,10 @@ def _download_files(
# checksums that are difficult to reproduce here.
return
expected_md5 = checksums[file_directory][file_name]
computed_md5 = md5(
os.path.join(download_dir, file_path, actual_file_name)
)
computed_md5 = md5(os.path.join(download_dir, file_path, file_name))
if expected_md5 != computed_md5:
raise ValueError(
f"Checksum ({computed_md5}) of downloaded file {actual_file_name} does not match the expected value of {expected_md5}."
f"Checksum ({computed_md5}) of downloaded file {file_name} does not match the expected value of {expected_md5}."
)

def data_usage(self, **query_params):
Expand Down
2 changes: 1 addition & 1 deletion src/resdk/resources/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def download(self, file_name=None, field_name=None, download_dir=None):
data_files = data.files(file_name, field_name)
files.extend("{}/{}".format(data.id, file_name) for file_name in data_files)

self.resolwe._download_files(files=files, download_dir=download_dir)
self.resolwe._download_files(files, download_dir)


class Collection(CollectionRelationsMixin, BaseCollection):
Expand Down
63 changes: 47 additions & 16 deletions src/resdk/resources/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
import os
from typing import Optional
from urllib.parse import urljoin

Expand All @@ -16,6 +17,12 @@
from .sample import Sample
from .utils import flatten_field, parse_resolwe_datetime

# Configure logging via ``logging.basicConfig``: INFO level, records formatted
# as "<timestamp> - <level name> - <message>", timestamps as day-month-year
# followed by hours:minutes:seconds (e.g. "17-Oct-24 13:05:59").
# NOTE(review): this executes at import time and targets the root logger, so
# it may affect the logging setup of any application that imports this
# module -- confirm this is intended for library code.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
)


class Data(BaseResolweResource):
"""Resolwe Data resource.
Expand Down Expand Up @@ -329,7 +336,6 @@ def download(
file_name: Optional[str] = None,
field_name: Optional[str] = None,
download_dir: Optional[str] = None,
custom_file_name: Optional[str] = None,
):
"""Download Data object's files and directories.
Expand All @@ -343,9 +349,6 @@ def download(
* re.data.get(42).download(file_name='alignment7.bam')
* re.data.get(42).download(field_name='bam')
If custom_file_name is provided, the file will be saved with that name,
provided that either field_name or file_name is also specified.
"""
if file_name and field_name:
raise ValueError("Only one of file_name or field_name may be given.")
Expand All @@ -354,20 +357,48 @@ def download(
"{}/{}".format(self.id, fname)
for fname in self.files(file_name, field_name)
]
file_names = [fname for fname in self.files(file_name, field_name)]

# Only applies if downloading a single file
custom_file_names = None
if custom_file_name:
if file_name or field_name:
custom_file_names = [custom_file_name] * len(files)
else:
raise ValueError(
"Setting a custom file name is not supported "
"without specifying file name or field name."
)
self.resolwe._download_files(files=files, download_dir=download_dir)

return file_names

def download_and_rename(
    self,
    custom_file_name: str,
    field_name: Optional[str] = None,
    file_name: Optional[str] = None,
    download_dir: Optional[str] = None,
):
    """Download a single file from the data object and rename it.

    The file is fetched with ``self.download`` into ``download_dir`` and
    then renamed to ``custom_file_name`` inside that same directory. If the
    target path already exists, nothing is downloaded and a warning is
    logged instead.

    :param custom_file_name: name to save the downloaded file as
    :param field_name: passed through to ``self.download``
    :param file_name: passed through to ``self.download``
    :param download_dir: download directory; defaults to the current
        working directory
    :raises ValueError: if ``custom_file_name`` is empty, or if the
        download did not yield exactly one file. In the latter case the
        files have already been downloaded when the error is raised.
    """
    # An empty name would make the target path resolve to the download
    # directory itself, which exists, so the call would silently "skip".
    if not custom_file_name:
        raise ValueError("Custom file name must not be empty.")

    if download_dir is None:
        download_dir = os.getcwd()

    new_file_path = os.path.join(download_dir, custom_file_name)

    # Never overwrite an existing file; treat it as already downloaded.
    if os.path.exists(new_file_path):
        logging.warning(
            "File with path '%s' already exists. Skipping download.",
            new_file_path,
        )
        return

    file_names = self.download(
        file_name=file_name,
        field_name=field_name,
        download_dir=download_dir,
    )
    # Renaming is only well-defined for a single downloaded file.
    if len(file_names) != 1:
        raise ValueError(
            f"Expected one file to be downloaded, but got {len(file_names)}"
        )
    og_file_name = file_names[0]
    og_file_path = os.path.join(download_dir, og_file_name)

    logging.info("Renaming file '%s' to '%s'.", og_file_name, custom_file_name)
    os.rename(og_file_path, new_file_path)

def stdout(self):
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,14 @@ def test_download_ok(self, data_mock):

data_mock.reset_mock()
data_mock.files.return_value = ["file1.txt"]
Data.download(
Data.download_and_rename(
data_mock,
custom_file_name="text_file1.txt",
field_name="txt",
download_dir="/some/path/",
)
data_mock.resolwe._download_files.assert_called_once_with(
files=["123/file1.txt"],
download_dir="/some/path/",
custom_file_names=["text_file1.txt"],
data_mock.download.assert_called_once_with(
file_name=None, field_name="txt", download_dir="/some/path/"
)

@patch("resdk.resolwe.Resolwe")
Expand Down

0 comments on commit b9661da

Please sign in to comment.