diff --git a/api/data_handler.py b/api/data_handler.py index f3c744b..f30d9dc 100644 --- a/api/data_handler.py +++ b/api/data_handler.py @@ -10,6 +10,7 @@ import hashlib from urllib.request import urlopen, urlretrieve, HTTPError from urllib.parse import urlparse, urlunparse +import zipfile from fastapi import HTTPException, status import neo.io import quantities as pq @@ -58,20 +59,21 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None): root_path, ext = os.path.splitext(main_file) io_mode = getattr(io_cls, "rawmode", None) if io_mode == "one-dir": - # In general, we don't know the names of the individual files - and have no way to get a directory listing from a URL - so we raise an exception - if io_cls.__name__ in ("PhyIO"): - # for the exceptions, resolved_url must represent a directory - raise NotImplementedError # todo: for these ios, the file names are known - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "Cannot download files from a URL representing a directory. " - "Please provide the URL of a zip or tar archive of the directory." + if not resolved_url.endswith(".zip"): + # In general, we don't know the names of the individual files + and have no way to get a directory listing from a URL + so we raise an exception + if io_cls.__name__ in ("PhyIO",): + # for the exceptions, resolved_url must represent a directory + raise NotImplementedError # todo: for these ios, the file names are known + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Cannot download files from a URL representing a directory. " + "Please provide the URL of a zip or tar archive of the directory." + ) ) - ) elif io_mode == "multi-file": # Here the resolved_url represents a single file, with or without the file extension. 
# By taking the base/root path and adding various extensions we get a list of files to download @@ -153,9 +155,23 @@ def download_neo_data(url, io_cls=None): main_path = files_to_download[0][1] else: main_path = os.path.join(cache_dir, main_file) + if main_path.endswith(".zip"): + main_path = get_archive_dir(main_path, cache_dir) return main_path +def get_archive_dir(archive_path, cache_dir): + with zipfile.ZipFile(archive_path) as zf: + contents = zf.infolist() + dir_name = contents[0].filename.split("/")[0] + main_path = os.path.join(cache_dir, dir_name) + if not os.path.exists(main_path): + zf.extractall(path=cache_dir) + # we are assuming the zipfile unpacks to a single directory + # todo: check this is the case, and if not either raise an Exception + # or create our own directory to unpack into + return main_path + extra_kwargs = { "NestIO": { diff --git a/api/test/test_data_handler.py b/api/test/test_data_handler.py index c04dbb9..9f185bb 100644 --- a/api/test/test_data_handler.py +++ b/api/test/test_data_handler.py @@ -3,8 +3,16 @@ """ import os.path +import shutil +import tempfile +from urllib.request import urlretrieve from neo.io import BrainVisionIO -from ..data_handler import get_base_url_and_path, get_cache_path, list_files_to_download +from ..data_handler import ( + get_base_url_and_path, + get_cache_path, + list_files_to_download, + get_archive_dir, +) def test_get_base_url_and_path(): @@ -37,12 +45,22 @@ def test_get_cache_path(): ) assert filename == "File_brainvision_1.vhdr" + def test_list_files_to_download(): url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/brainvision/File_brainvision_1.vhdr" files_to_download = list_files_to_download(url, "the_cache_dir", BrainVisionIO) expected = [ (url, "the_cache_dir/File_brainvision_1.vhdr", True), (url.replace(".vhdr", ".eeg"), "the_cache_dir/File_brainvision_1.eeg", True), - (url.replace(".vhdr", ".vmrk"), "the_cache_dir/File_brainvision_1.vmrk", True) + (url.replace(".vhdr", 
".vmrk"), "the_cache_dir/File_brainvision_1.vmrk", True), ] assert files_to_download == expected + + +def test_download_neo_data_zip(): + cache_dir = tempfile.mkdtemp() + file_url = "https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip" + archive_path, headers = urlretrieve(file_url, os.path.join(cache_dir, "ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip")) + main_path = get_archive_dir(archive_path, cache_dir) + assert main_path == os.path.join(cache_dir, "original_data") + shutil.rmtree(cache_dir) diff --git a/api/test/test_example_data.py b/api/test/test_example_data.py index 3898095..4facc52 100644 --- a/api/test/test_example_data.py +++ b/api/test/test_example_data.py @@ -14,9 +14,9 @@ test_client = TestClient(app) -base_data_url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/" +gin_data_url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/" -test_data = { +test_data_gin = { 200: { "AsciiSpikeTrainIO": ["asciispiketrain/File_ascii_spiketrain_1.txt"], "AxographIO": [ @@ -230,7 +230,8 @@ "neuralynx/Cheetah_v4.0.2/original_data", "neuralynx/Cheetah_v5.4.0/original_data", "neuralynx/Cheetah_v5.5.1/original_data", - "neuralynx/Cheetah_v5.6.3/original_data", + # "neuralynx/Cheetah_v5.6.3/original_data", + "https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip", "neuralynx/Cheetah_v5.7.4/original_data", "neuralynx/Cheetah_v6.3.2/incomplete_blocks", ], @@ -263,47 +264,59 @@ }, } +test_data_other = { + 200: { + "NeuralynxIO": [ + "https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip", + ] + } +} + expected_success = [ - (io_cls, test_file) - for io_cls, test_files in test_data[200].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in 
test_data_gin[200].items() for test_file in test_files +] + [ + (io_cls, test_file_url) + for io_cls, test_files in test_data_other[200].items() + for test_file_url in test_files ] expected_400_failure_block = [ - (io_cls, test_file) - for io_cls, test_files in test_data[400]["block"].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in test_data_gin[400]["block"].items() for test_file in test_files ] expected_400_failure_segment = [ - (io_cls, test_file) - for io_cls, test_files in test_data[400]["segment"].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in test_data_gin[400]["segment"].items() for test_file in test_files ] expected_400_failure_signal = [ - (io_cls, test_file) - for io_cls, test_files in test_data[400]["signal"].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in test_data_gin[400]["signal"].items() for test_file in test_files ] expected_415_failure = [ - (io_cls, test_file) - for io_cls, test_files in test_data[415].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in test_data_gin[415].items() for test_file in test_files ] expected_500_failure = [ - (io_cls, test_file) - for io_cls, test_files in test_data[500].items() + (io_cls, f"{gin_data_url}{test_file}") + for io_cls, test_files in test_data_gin[500].items() for test_file in test_files ] -@pytest.mark.parametrize("io_cls,test_file", expected_success) -def test_datasets_expected_success(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_success) +def test_datasets_expected_success(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") assert response.status_code == 200 @@ -323,10 +336,10 @@ def test_datasets_expected_success(io_cls, test_file): # todo: test irregularlysampledsignals - do 
we have any cases in the example data? -@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_block) -def test_datasets_expected_400_failure_blockdata(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_block) +def test_datasets_expected_400_failure_blockdata(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") if response.status_code != 400: @@ -336,10 +349,10 @@ def test_datasets_expected_400_failure_blockdata(io_cls, test_file): pytest.xfail(response.json()["detail"]) -@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_segment) -def test_datasets_expected_400_failure_segmentdata(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_segment) +def test_datasets_expected_400_failure_segmentdata(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") assert response.status_code == 200 @@ -351,10 +364,10 @@ def test_datasets_expected_400_failure_segmentdata(io_cls, test_file): pytest.xfail(response2.json()["detail"]) -@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_signal) -def test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_signal) +def test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") assert response.status_code == 200 @@ -372,10 +385,10 @@ def 
test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file): pytest.xfail(response3.json()["detail"]) -@pytest.mark.parametrize("io_cls,test_file", expected_415_failure) -def test_datasets_expected_415_failure(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_415_failure) +def test_datasets_expected_415_failure(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") if response.status_code != 415: raise Exception("error") @@ -383,10 +396,10 @@ def test_datasets_expected_415_failure(io_cls, test_file): pytest.xfail(response.json()["detail"]) -@pytest.mark.parametrize("io_cls,test_file", expected_500_failure) -def test_datasets_expected_500_failure(io_cls, test_file): +@pytest.mark.parametrize("io_cls,test_file_url", expected_500_failure) +def test_datasets_expected_500_failure(io_cls, test_file_url): encode = urllib.parse.urlencode - params = {"url": f"{base_data_url}{test_file}", "type": io_cls} + params = {"url": test_file_url, "type": io_cls} response = test_client.get(f"/api/blockdata/?{encode(params)}") assert response.status_code == 200