diff --git a/one/alf/path.py b/one/alf/path.py
index 0317d683..8b13bc0d 100644
--- a/one/alf/path.py
+++ b/one/alf/path.py
@@ -379,14 +379,12 @@ def get_alf_path(path: Union[str, pathlib.Path]) -> str:
     path = path.strip('/')
 
     # Check if session path
-    match_session = spec.regex(SESSION_SPEC).search(path)
-    if match_session:
+    if match_session := spec.regex(SESSION_SPEC).search(path):
         return path[match_session.start():]
 
     # Check if filename / relative path (i.e. collection + filename)
     parts = path.rsplit('/', 1)
-    match_filename = spec.regex(FILE_SPEC).match(parts[-1])
-    if match_filename:
+    if spec.regex(FILE_SPEC).match(parts[-1]):
         return path if spec.regex(f'{COLLECTION_SPEC}{FILE_SPEC}').match(path) else parts[-1]
diff --git a/one/api.py b/one/api.py
index 28a1c09a..c97196a2 100644
--- a/one/api.py
+++ b/one/api.py
@@ -14,6 +14,7 @@
 import time
 import threading
 import os
+import re
 
 import pandas as pd
 import numpy as np
@@ -2422,8 +2423,13 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[
                 self._cache['_meta']['modified_time'] = datetime.now()
                 out_files.append(None)
                 continue
-            assert record['relative_path'].endswith(dset['rel_path']), \
-                f'Relative path for dataset {uuid} does not match Alyx record'
+            if 'relation' in dset:
+                # For non-session datasets the pandas record rel path is the full path
+                matches = dset['rel_path'].endswith(record['relative_path'])
+            else:
+                # For session datasets the pandas record rel path is relative to the session
+                matches = record['relative_path'].endswith(dset['rel_path'])
+            assert matches, f'Relative path for dataset {uuid} does not match Alyx record'
             source_path = PurePosixPath(record['data_repository_path'], record['relative_path'])
             local_path = self.cache_dir.joinpath(alfiles.get_alf_path(source_path))
             # Add UUIDs to filenames, if required
@@ -2532,6 +2538,9 @@ def _download_dataset(
         target_dir = []
         for x in valid_urls:
             _path = urllib.parse.urlsplit(x, allow_fragments=False).path.strip('/')
+            # Since rel_path for public FI file records starts with 'public/aggregates'
+            # instead of 'aggregates', discard any path parts before 'aggregates' (if present)
+            _path = re.sub(r'^[\w\/]+(?=aggregates\/)', '', _path, count=1)
             target_dir.append(str(Path(cache_dir, alfiles.get_alf_path(_path)).parent))
         files = self._download_file(valid_urls, target_dir, **kwargs)
         # Return list of file paths or None if we failed to extract URL from dataset
diff --git a/one/converters.py b/one/converters.py
index 6879d0c3..3c040802 100644
--- a/one/converters.py
+++ b/one/converters.py
@@ -319,10 +319,15 @@ def record2url(self, record):
         else:
             raise TypeError(
                 f'record must be pandas.DataFrame or pandas.Series, got {type(record)} instead')
-        assert isinstance(record.name, tuple) and len(record.name) == 2
-        eid, uuid = record.name  # must be (eid, did)
-        session_path = self.eid2path(eid)
-        url = PurePosixALFPath(get_alf_path(session_path), record['rel_path'])
+        if 'session_path' in record:
+            # Check for session_path field (aggregate datasets have no eid in name)
+            session_path = record['session_path']
+            uuid = record.name if isinstance(record.name, str) else record.name[-1]
+        else:
+            assert isinstance(record.name, tuple) and len(record.name) == 2
+            eid, uuid = record.name  # must be (eid, did)
+            session_path = get_alf_path(self.eid2path(eid))
+        url = PurePosixALFPath(session_path, record['rel_path'])
         return webclient.rel_path2url(url.with_uuid(uuid).as_posix())
 
     def record2path(self, dataset) -> Optional[ALFPath]:
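Aside: the behaviour of the re.sub call added to _download_dataset above is easiest to see in isolation. Below is a minimal sketch; the example paths are hypothetical and only the regex itself comes from the patch.

    import re

    paths = [
        'public/aggregates/Subjects/mainenlab/ZM_1085/_ibl_subjectTraining.table.pqt',
        'aggregates/Subjects/mainenlab/ZM_1085/_ibl_subjectTraining.table.pqt',
    ]
    for _path in paths:
        # Drop anything before 'aggregates/' (e.g. the leading 'public/'), at most once
        print(re.sub(r'^[\w\/]+(?=aggregates\/)', '', _path, count=1))
    # Both iterations print the path starting at 'aggregates/...'

The lookahead leaves 'aggregates/' itself in place while '[\w\/]+' consumes only the prefix; paths that already start with 'aggregates/' produce no match and pass through unchanged.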
diff --git a/one/tests/test_one.py b/one/tests/test_one.py
index 661d8c01..c21862ed 100644
--- a/one/tests/test_one.py
+++ b/one/tests/test_one.py
@@ -29,7 +29,7 @@
 import datetime
 import logging
 import time
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 from itertools import permutations, combinations_with_replacement
 from functools import partial
 import unittest
@@ -1423,6 +1423,30 @@ def test_load_aggregate(self):
         with self.assertRaises(alferr.ALFObjectNotFound):
             self.one.load_aggregate('subjects', 'ZM_1085', 'foo.bar')
 
+        # Test download file from HTTP dataserver
+        expected.unlink()
+        self.one.mode = 'remote'  # Can't download in local mode
+        with mock.patch.object(self.one, '_download_file', return_value=[expected]) as m, \
+                mock.patch.object(self.one, '_download_aws', side_effect=AssertionError):
+            file = self.one.load_aggregate('subjects', 'ZM_1085', dset, download_only=True)
+        # Check correct URL passed to _download_file
+        self.assertEqual(expected, file)
+        expected_src = (self.one.alyx._par.HTTP_DATA_SERVER +
+                        '/aggregates/Subjects/mainenlab/ZM_1085/' +
+                        '_ibl_subjectTraining.table.74dfb745-a7dc-4672-ace6-b556876c80cb.pqt')
+        expected_dst = str(file.parent)  # should be without 'public' part
+        m.assert_called_once_with([expected_src], [expected_dst], keep_uuid=False)
+        # Test download file from AWS
+        with mock.patch('one.remote.aws.s3_download_file', return_value=expected) as m, \
+                mock.patch.object(self.one, '_download_file', side_effect=AssertionError):
+            file = self.one.load_aggregate('subjects', 'ZM_1085', dset, download_only=True)
+        # Check correct path passed to s3_download_file
+        self.assertEqual(expected, file)
+        expected_src = PurePosixPath(
+            expected_src[len(self.one.alyx._par.HTTP_DATA_SERVER) + 1:])
+        m.assert_called_once_with(
+            expected_src, expected, s3=mock.ANY, bucket_name='s3_bucket', overwrite=True)
+
     @classmethod
     def tearDownClass(cls) -> None:
         cls.tempdir.cleanup()
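For reference, a minimal sketch of the reversed endswith comparison in _download_aws. The path values below are made up purely to show the directionality; they are not real Alyx records.

    # Session dataset: the cache table rel_path is relative to the session folder,
    # so the longer Alyx record path must end with it.
    record_path = 'mainenlab/Subjects/ZM_1085/2019-01-01/001/alf/trials.table.pqt'
    dset_rel_path = 'alf/trials.table.pqt'
    assert record_path.endswith(dset_rel_path)

    # Non-session (aggregate) dataset: the cache table rel_path is the full path,
    # so the comparison is reversed and the cache path must end with the record's.
    dset_rel_path = 'aggregates/Subjects/mainenlab/ZM_1085/_ibl_subjectTraining.table.pqt'
    record_path = 'Subjects/mainenlab/ZM_1085/_ibl_subjectTraining.table.pqt'
    assert dset_rel_path.endswith(record_path)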