From 310fcd312b7a28f0f9ae2f50567fe38e238558f6 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 11 Dec 2024 11:01:30 +0100 Subject: [PATCH 01/22] feat: add data.utils.check_dataset() to get or install junifer-data dataset --- junifer/data/utils.py | 55 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/junifer/data/utils.py b/junifer/data/utils.py index a1c5f5d3c5..46227301b9 100644 --- a/junifer/data/utils.py +++ b/junifer/data/utils.py @@ -5,14 +5,17 @@ # License: AGPL from collections.abc import MutableMapping +from pathlib import Path from typing import Optional, Union +import datalad.api as dl import numpy as np +from datalad.support.exceptions import IncompleteResultsError -from ..utils import logger, raise_error +from ..utils import config, logger, raise_error -__all__ = ["closest_resolution", "get_native_warper"] +__all__ = ["check_dataset", "closest_resolution", "get_native_warper"] def closest_resolution( @@ -114,3 +117,51 @@ def get_native_warper( ) return possible_warpers[0] + + +def check_dataset() -> dl.Dataset: + """Get or install junifer-data dataset. + + Returns + ------- + datalad.api.Dataset + The junifer-data dataset. + + Raises + ------ + RuntimeError + If there is a problem cloning the dataset. 
+ + """ + # Check config and set default if not passed + data_dir = config.get("data.location") + if data_dir is not None: + data_dir = Path(data_dir) + else: + data_dir = Path().home() / "junifer-data" + + # Check if the dataset is installed at storage path; + # else clone a fresh copy + if dl.Dataset(data_dir).is_installed(): + logger.debug(f"Found existing junifer-data at: {data_dir.resolve()}") + return dl.Dataset(data_dir) + else: + logger.debug(f"Cloning junifer-data to: {data_dir.resolve()}") + # Clone dataset + try: + dataset = dl.clone( + "https://github.com/juaml/junifer-data.git", + path=data_dir, + result_renderer="disabled", + ) + except IncompleteResultsError as e: + raise_error( + msg=f"Failed to clone junifer-data: {e.failed}", + klass=RuntimeError, + ) + else: + logger.debug( + f"Successfully cloned junifer-data to: " + f"{data_dir.resolve()}" + ) + return dataset From 3dd8623798ad35fa2df1f2ea9b818f3dc25c90e5 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 11 Dec 2024 11:02:53 +0100 Subject: [PATCH 02/22] update: use data.utils.check_dataset() in get_xfm() --- junifer/data/template_spaces.py | 57 +++++---------------------------- 1 file changed, 8 insertions(+), 49 deletions(-) diff --git a/junifer/data/template_spaces.py b/junifer/data/template_spaces.py index 714d6e29ca..0dac737b18 100644 --- a/junifer/data/template_spaces.py +++ b/junifer/data/template_spaces.py @@ -6,22 +6,19 @@ from pathlib import Path from typing import Any, Optional, Union -import datalad.api as dl import nibabel as nib import numpy as np from datalad.support.exceptions import IncompleteResultsError from templateflow import api as tflow from ..utils import logger, raise_error -from .utils import closest_resolution +from .utils import check_dataset, closest_resolution __all__ = ["get_template", "get_xfm"] -def get_xfm( - src: str, dst: str, xfms_dir: Union[str, Path, None] = None -) -> Path: # pragma: no cover +def get_xfm(src: str, dst: str) -> Path: # pragma: no 
cover """Fetch warp files to convert from ``src`` to ``dst``. Parameters @@ -30,9 +27,6 @@ def get_xfm( The template space to transform from. dst : str The template space to transform to. - xfms_dir : str or pathlib.Path, optional - Path where the retrieved transformation files are stored. - The default location is "$HOME/junifer/data/xfms" (default None). Returns ------- @@ -42,51 +36,16 @@ def get_xfm( Raises ------ RuntimeError - If there is a problem cloning the xfm dataset or - if there is a problem fetching the xfm file. + If there is a problem fetching the xfm file. """ - # Set default path for storage - if xfms_dir is None: - xfms_dir = Path().home() / "junifer" / "data" / "xfms" - - # Convert str to Path - if not isinstance(xfms_dir, Path): - xfms_dir = Path(xfms_dir) - - # Check if the template xfms dataset is installed at storage path - is_installed = dl.Dataset(xfms_dir).is_installed() - # Use existing dataset - if is_installed: - logger.debug( - f"Found existing template xfms dataset at: {xfms_dir.resolve()}" - ) - # Set dataset - dataset = dl.Dataset(xfms_dir) - # Clone a fresh copy - else: - logger.debug(f"Cloning template xfms dataset to: {xfms_dir.resolve()}") - # Clone dataset - try: - dataset = dl.clone( - "https://github.com/juaml/human-template-xfms.git", - path=xfms_dir, - result_renderer="disabled", - ) - except IncompleteResultsError as e: - raise_error( - msg=f"Failed to clone dataset: {e.failed}", - klass=RuntimeError, - ) - else: - logger.debug( - f"Successfully cloned template xfms dataset to: " - f"{xfms_dir.resolve()}" - ) - + # Get dataset + dataset = check_dataset() + # Set xfms dir + xfms_dir = dataset.pathobj / "xfms" # Set file path to retrieve xfm_file_path = ( - xfms_dir / "xfms" / f"{src}_to_{dst}" / f"{src}_to_{dst}_Composite.h5" + xfms_dir / f"{src}_to_{dst}" / f"{src}_to_{dst}_Composite.h5" ) # Retrieve file From 6b7e31a905a61680f71cd93074ee4f1f2b59e0be Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 12 Dec 2024 
16:59:18 +0100 Subject: [PATCH 03/22] chore: update junifer-data clone path --- junifer/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/utils.py b/junifer/data/utils.py index 46227301b9..f312ac57f3 100644 --- a/junifer/data/utils.py +++ b/junifer/data/utils.py @@ -138,7 +138,7 @@ def check_dataset() -> dl.Dataset: if data_dir is not None: data_dir = Path(data_dir) else: - data_dir = Path().home() / "junifer-data" + data_dir = Path().home() / "junifer_data" # Check if the dataset is installed at storage path; # else clone a fresh copy From 429c9627cac4ab2ea78c7ec1ceac870a15b7aa00 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 12 Dec 2024 17:00:07 +0100 Subject: [PATCH 04/22] feat: add data.utils.fetch_file_via_datalad() to get files via datalad --- junifer/data/utils.py | 52 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/junifer/data/utils.py b/junifer/data/utils.py index f312ac57f3..27c05f6400 100644 --- a/junifer/data/utils.py +++ b/junifer/data/utils.py @@ -15,7 +15,12 @@ from ..utils import config, logger, raise_error -__all__ = ["check_dataset", "closest_resolution", "get_native_warper"] +__all__ = [ + "check_dataset", + "closest_resolution", + "fetch_file_via_datalad", + "get_native_warper", +] def closest_resolution( @@ -165,3 +170,48 @@ def check_dataset() -> dl.Dataset: f"{data_dir.resolve()}" ) return dataset + + +def fetch_file_via_datalad(dataset: dl.Dataset, file_path: Path) -> Path: + """Fetch `file_path` from `dataset` via datalad. + + Parameters + ---------- + dataset : datalad.api.Dataset + The datalad dataset to fetch files from. + file_path : pathlib.Path + The file path to fetch. + + Returns + ------- + pathlib.Path + Resolved fetched file path. + + Raises + ------ + RuntimeError + If there is a problem fetching the file. 
+ + """ + try: + got = dataset.get(file_path, result_renderer="disabled") + except IncompleteResultsError as e: + raise_error( + msg=f"Failed to get file from dataset: {e.failed}", + klass=RuntimeError, + ) + else: + got_path = Path(got[0]["path"]) + # Conditional logging based on file fetch + status = got[0]["status"] + if status == "ok": + logger.info(f"Successfully fetched file: {got_path.resolve()}") + return got_path + elif status == "notneeded": + logger.info(f"Found existing file: {got_path.resolve()}") + return got_path + else: + raise_error( + msg=f"Failed to fetch file: {got_path.resolve()}", + klass=RuntimeError, + ) From 24a1797912fe4ae3bcd2124970ca02a561fbdf5c Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 12 Dec 2024 17:06:56 +0100 Subject: [PATCH 05/22] update: use data.utils.fetch_file_via_datalad() in get_xfm() --- junifer/data/template_spaces.py | 45 +++++---------------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/junifer/data/template_spaces.py b/junifer/data/template_spaces.py index 0dac737b18..12d45add22 100644 --- a/junifer/data/template_spaces.py +++ b/junifer/data/template_spaces.py @@ -8,11 +8,10 @@ import nibabel as nib import numpy as np -from datalad.support.exceptions import IncompleteResultsError from templateflow import api as tflow from ..utils import logger, raise_error -from .utils import check_dataset, closest_resolution +from .utils import check_dataset, closest_resolution, fetch_file_via_datalad __all__ = ["get_template", "get_xfm"] @@ -33,50 +32,18 @@ def get_xfm(src: str, dst: str) -> Path: # pragma: no cover pathlib.Path The path to the transformation file. - Raises - ------ - RuntimeError - If there is a problem fetching the xfm file. 
- """ # Get dataset dataset = check_dataset() - # Set xfms dir - xfms_dir = dataset.pathobj / "xfms" # Set file path to retrieve xfm_file_path = ( - xfms_dir / f"{src}_to_{dst}" / f"{src}_to_{dst}_Composite.h5" + dataset.pathobj + / "xfms" + / f"{src}_to_{dst}" + / f"{src}_to_{dst}_Composite.h5" ) - # Retrieve file - try: - got = dataset.get(xfm_file_path, result_renderer="disabled") - except IncompleteResultsError as e: - raise_error( - msg=f"Failed to get file from dataset: {e.failed}", - klass=RuntimeError, - ) - else: - file_path = Path(got[0]["path"]) - # Conditional logging based on file fetch - status = got[0]["status"] - if status == "ok": - logger.info( - f"Successfully fetched xfm file for {src} to {dst} at " - f"{file_path.resolve()}" - ) - return file_path - elif status == "notneeded": - logger.info( - f"Found existing xfm file for {src} to {dst} at " - f"{file_path.resolve()}" - ) - return file_path - else: - raise_error( - f"Failed to fetch xfm file for {src} to {dst} at " - f"{file_path.resolve()}" - ) + return fetch_file_via_datalad(dataset=dataset, file_path=xfm_file_path) def get_template( From 9b087ecbd42d8402a5cfc43c6979adadbc7d1f5e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 12 Dec 2024 17:09:18 +0100 Subject: [PATCH 06/22] update: use junifer-data in CoordinatesRegistry.load() --- junifer/data/coordinates/_coordinates.py | 207 +++++++++++++---------- 1 file changed, 114 insertions(+), 93 deletions(-) diff --git a/junifer/data/coordinates/_coordinates.py b/junifer/data/coordinates/_coordinates.py index 1b5a5e8a11..8fb4e1a49b 100644 --- a/junifer/data/coordinates/_coordinates.py +++ b/junifer/data/coordinates/_coordinates.py @@ -14,7 +14,7 @@ from ...utils import logger, raise_error from ...utils.singleton import Singleton from ..pipeline_data_registry_base import BasePipelineDataRegistry -from ..utils import get_native_warper +from ..utils import check_dataset, fetch_file_via_datalad, get_native_warper from 
._ants_coordinates_warper import ANTsCoordinatesWarper from ._fsl_coordinates_warper import FSLCoordinatesWarper @@ -35,98 +35,97 @@ def __init__(self) -> None: # Each entry in registry is a dictionary that must contain at least # the following keys: # * 'space': the coordinates' space (e.g., 'MNI') - # The built-in coordinates are files that are shipped with the package - # in the data/VOIs directory. The user can also register their own + # The built-in coordinates are files that are shipped with the + # junifer-data dataset. The user can also register their own # coordinates, which will be stored as numpy arrays in the dictionary. # Make built-in and external dictionaries for validation later self._builtin = {} self._external = {} - # Path to the metadata of the VOIs - _vois_meta_path = Path(__file__).parent / "VOIs" / "meta" - - self._builtin = { - "CogAC": { - "path": _vois_meta_path / "CogAC_VOIs.txt", - "space": "MNI", - }, - "CogAR": { - "path": _vois_meta_path / "CogAR_VOIs.txt", - "space": "MNI", - }, - "DMNBuckner": { - "path": _vois_meta_path / "DMNBuckner_VOIs.txt", - "space": "MNI", - }, - "eMDN": { - "path": _vois_meta_path / "eMDN_VOIs.txt", - "space": "MNI", - }, - "Empathy": { - "path": _vois_meta_path / "Empathy_VOIs.txt", - "space": "MNI", - }, - "eSAD": { - "path": _vois_meta_path / "eSAD_VOIs.txt", - "space": "MNI", - }, - "extDMN": { - "path": _vois_meta_path / "extDMN_VOIs.txt", - "space": "MNI", - }, - "Motor": { - "path": _vois_meta_path / "Motor_VOIs.txt", - "space": "MNI", - }, - "MultiTask": { - "path": _vois_meta_path / "MultiTask_VOIs.txt", - "space": "MNI", - }, - "PhysioStress": { - "path": _vois_meta_path / "PhysioStress_VOIs.txt", - "space": "MNI", - }, - "Rew": { - "path": _vois_meta_path / "Rew_VOIs.txt", - "space": "MNI", - }, - "Somatosensory": { - "path": _vois_meta_path / "Somatosensory_VOIs.txt", - "space": "MNI", - }, - "ToM": { - "path": _vois_meta_path / "ToM_VOIs.txt", - "space": "MNI", - }, - "VigAtt": { - "path": 
_vois_meta_path / "VigAtt_VOIs.txt", - "space": "MNI", - }, - "WM": { - "path": _vois_meta_path / "WM_VOIs.txt", - "space": "MNI", - }, - "Power": { - "path": _vois_meta_path / "Power2011_MNI_VOIs.txt", - "space": "MNI", - }, - "Power2011": { - "path": _vois_meta_path / "Power2011_MNI_VOIs.txt", - "space": "MNI", - }, - "Dosenbach": { - "path": _vois_meta_path / "Dosenbach2010_MNI_VOIs.txt", - "space": "MNI", - }, - "Power2013": { - "path": _vois_meta_path / "Power2013_MNI_VOIs.tsv", - "space": "MNI", - }, - "AutobiographicalMemory": { - "path": _vois_meta_path / "AutobiographicalMemory_VOIs.txt", - "space": "MNI", - }, - } + self._builtin.update( + { + "CogAC": { + "file_path_suffix": "CogAC_VOIs.txt", + "space": "MNI", + }, + "CogAR": { + "file_path_suffix": "CogAR_VOIs.txt", + "space": "MNI", + }, + "DMNBuckner": { + "file_path_suffix": "DMNBuckner_VOIs.txt", + "space": "MNI", + }, + "eMDN": { + "file_path_suffix": "eMDN_VOIs.txt", + "space": "MNI", + }, + "Empathy": { + "file_path_suffix": "Empathy_VOIs.txt", + "space": "MNI", + }, + "eSAD": { + "file_path_suffix": "eSAD_VOIs.txt", + "space": "MNI", + }, + "extDMN": { + "file_path_suffix": "extDMN_VOIs.txt", + "space": "MNI", + }, + "Motor": { + "file_path_suffix": "Motor_VOIs.txt", + "space": "MNI", + }, + "MultiTask": { + "file_path_suffix": "MultiTask_VOIs.txt", + "space": "MNI", + }, + "PhysioStress": { + "file_path_suffix": "PhysioStress_VOIs.txt", + "space": "MNI", + }, + "Rew": { + "file_path_suffix": "Rew_VOIs.txt", + "space": "MNI", + }, + "Somatosensory": { + "file_path_suffix": "Somatosensory_VOIs.txt", + "space": "MNI", + }, + "ToM": { + "file_path_suffix": "ToM_VOIs.txt", + "space": "MNI", + }, + "VigAtt": { + "file_path_suffix": "VigAtt_VOIs.txt", + "space": "MNI", + }, + "WM": { + "file_path_suffix": "WM_VOIs.txt", + "space": "MNI", + }, + "Power": { + "file_path_suffix": "Power2011_MNI_VOIs.txt", + "space": "MNI", + }, + "Power2011": { + "file_path_suffix": "Power2011_MNI_VOIs.txt", + "space": 
"MNI", + }, + "Dosenbach": { + "file_path_suffix": "Dosenbach2010_MNI_VOIs.txt", + "space": "MNI", + }, + "Power2013": { + "file_path_suffix": "Power2013_MNI_VOIs.tsv", + "space": "MNI", + }, + "AutobiographicalMemory": { + "file_path_suffix": "AutobiographicalMemory_VOIs.txt", + "space": "MNI", + }, + } + ) # Set built-in to registry self._registry = self._builtin @@ -257,6 +256,8 @@ def load(self, name: str) -> tuple[ArrayLike, list[str], str]: ------ ValueError If ``name`` is invalid. + RuntimeError + If there is a problem fetching the coordinates file. """ # Check for valid coordinates name @@ -265,17 +266,37 @@ def load(self, name: str) -> tuple[ArrayLike, list[str], str]: f"Coordinates: {name} not found. " f"Valid options are: {self.list}" ) - # Load coordinates + # Load coordinates info t_coord = self._registry[name] - # Load data - if isinstance(t_coord.get("path"), Path): - logger.debug(f"Loading coordinates {t_coord['path'].absolute()!s}") + + # Load data for in-built ones + if t_coord.get("file_path_suffix") is not None: + # Get dataset + dataset = check_dataset() + # Set file path to retrieve + coords_file_path = ( + dataset.pathobj + / "coordinates" + / name + / t_coord["file_path_suffix"] + ) + logger.debug( + f"Loading coordinates `{name}` from: " + f"{coords_file_path.absolute()!s}" + ) # Load via pandas - df_coords = pd.read_csv(t_coord["path"], sep="\t", header=None) + df_coords = pd.read_csv( + fetch_file_via_datalad( + dataset=dataset, file_path=coords_file_path + ), + sep="\t", + header=None, + ) # Convert dataframe to numpy ndarray coords = df_coords.iloc[:, [0, 1, 2]].to_numpy() # Get label names names = list(df_coords.iloc[:, [3]].values[:, 0]) + # Load data for external ones else: coords = t_coord["coords"] names = t_coord["voi_names"] From c2a367d3ff095e696c895e05b94d169e3b2465f0 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 17:37:45 +0100 Subject: [PATCH 07/22] update: use junifer-data in MaskRegistry.load() --- 
junifer/data/masks/_masks.py | 51 ++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index b0064dfbda..44da273be2 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -26,12 +26,18 @@ from ...utils.singleton import Singleton from ..pipeline_data_registry_base import BasePipelineDataRegistry from ..template_spaces import get_template -from ..utils import closest_resolution, get_native_warper +from ..utils import ( + check_dataset, + closest_resolution, + fetch_file_via_datalad, + get_native_warper, +) from ._ants_mask_warper import ANTsMaskWarper from ._fsl_mask_warper import FSLMaskWarper if TYPE_CHECKING: + from datalad.api import Dataset from nibabel.nifti1 import Nifti1Image @@ -397,13 +403,21 @@ def load( mask_img = None if t_family == "CustomUserMask": mask_fname = Path(mask_definition["path"]) - elif t_family == "Vickery-Patil": - mask_fname = _load_vickery_patil_mask(name, resolution) elif t_family == "Callable": mask_img = mask_definition["func"] mask_fname = None - elif t_family == "UKB": - mask_fname = _load_ukb_mask(name) + elif t_family in ["Vickery-Patil", "UKB"]: + # Get dataset + dataset = check_dataset() + # Load mask + if t_family == "Vickery-Patil": + mask_fname = _load_vickery_patil_mask( + dataset=dataset, + name=name, + resolution=resolution, + ) + elif t_family == "UKB": + mask_fname = _load_ukb_mask(dataset=dataset, name=name) else: raise_error(f"Unknown mask family: {t_family}") @@ -685,6 +699,7 @@ def get( # noqa: C901 def _load_vickery_patil_mask( + dataset: "Dataset", name: str, resolution: Optional[float] = None, ) -> Path: @@ -692,6 +707,8 @@ def _load_vickery_patil_mask( Parameters ---------- + dataset : datalad.api.Dataset + The datalad dataset to fetch mask from. name : {"GM_prob0.2", "GM_prob0.2_cortex"} The name of the mask. 
resolution : float, optional @@ -712,6 +729,7 @@ def _load_vickery_patil_mask( ``name = "GM_prob0.2"``. """ + # Check name if name == "GM_prob0.2": available_resolutions = [1.5, 3.0] to_load = closest_resolution(resolution, available_resolutions) @@ -730,17 +748,20 @@ def _load_vickery_patil_mask( else: raise_error(f"Cannot find a Vickery-Patil mask called {name}") - # Set path for masks - mask_fname = _masks_path / "vickery-patil" / mask_fname + # Fetch file + return fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj / "masks" / "Vickery-Patil" / mask_fname, + ) - return mask_fname - -def _load_ukb_mask(name: str) -> Path: +def _load_ukb_mask(dataset: "Dataset", name: str) -> Path: """Load UKB mask. Parameters ---------- + dataset : datalad.api.Dataset + The datalad dataset to fetch mask from. name : {"UKB_15K_GM"} The name of the mask. @@ -755,15 +776,17 @@ def _load_ukb_mask(name: str) -> Path: If ``name`` is invalid. """ + # Check name if name == "UKB_15K_GM": mask_fname = "UKB_15K_GM_template.nii.gz" else: raise_error(f"Cannot find a UKB mask called {name}") - # Set path for masks - mask_fname = _masks_path / "ukb" / mask_fname - - return mask_fname + # Fetch file + return fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj / "masks" / "UKB" / mask_fname, + ) def _get_interpolation_method(img: "Nifti1Image") -> str: From d9a3717bbf9735fd2d01ed227b7bbcaa44ccbc8a Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:03:33 +0100 Subject: [PATCH 08/22] fix: update init and registration logic for CoordinatesRegistry --- junifer/data/coordinates/_coordinates.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/junifer/data/coordinates/_coordinates.py b/junifer/data/coordinates/_coordinates.py index 8fb4e1a49b..6499ebc47d 100644 --- a/junifer/data/coordinates/_coordinates.py +++ b/junifer/data/coordinates/_coordinates.py @@ -4,7 +4,6 @@ # Synchon Mandal # License: AGPL 
-from pathlib import Path from typing import Any, Optional import numpy as np @@ -32,6 +31,7 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton): def __init__(self) -> None: """Initialize the class.""" + super().__init__() # Each entry in registry is a dictionary that must contain at least # the following keys: # * 'space': the coordinates' space (e.g., 'MNI') @@ -127,8 +127,8 @@ def __init__(self) -> None: } ) - # Set built-in to registry - self._registry = self._builtin + # Update registry with built-in ones + self._registry.update(self._builtin) def register( self, @@ -160,9 +160,9 @@ def register( Raises ------ ValueError - If the coordinates ``name`` is already registered and + If the coordinates ``name`` is a built-in coordinates or + if the coordinates ``name`` is already registered and ``overwrite=False`` or - if the coordinates ``name`` is a built-in coordinates or if the ``coordinates`` is not a 2D array or if coordinate value does not have 3 components or if the ``voi_names`` shape does not match the @@ -173,11 +173,12 @@ def register( """ # Check for attempt of overwriting built-in coordinates if name in self._builtin: - if isinstance(self._registry[name].get("path"), Path): - raise_error( - f"Coordinates: {name} already registered as built-in " - "coordinates." - ) + raise_error( + f"Coordinates: {name} already registered as built-in " + "coordinates." + ) + # Check for attempt of overwriting external coordinates + if name in self._external: if overwrite: logger.info(f"Overwriting coordinates: {name}") else: @@ -185,7 +186,7 @@ def register( f"Coordinates: {name} already registered. " "Set `overwrite=True` to update its value." ) - + # Further checks if not isinstance(coordinates, np.ndarray): raise_error( "Coordinates must be a `numpy.ndarray`, " @@ -206,6 +207,7 @@ def register( f"Length of `voi_names` ({len(voi_names)}) does not match the " f"number of `coordinates` ({coordinates.shape[0]})." 
) + # Registration logger.info(f"Registering coordinates: {name}") # Add coordinates info self._external[name] = { From feb0e133bb7f48018b7f5d88b08bd997675561a9 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:04:00 +0100 Subject: [PATCH 09/22] chore: update CoordinatesRegistry tests --- junifer/data/coordinates/tests/test_coordinates.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/junifer/data/coordinates/tests/test_coordinates.py b/junifer/data/coordinates/tests/test_coordinates.py index e981bac339..eefd8bffbc 100644 --- a/junifer/data/coordinates/tests/test_coordinates.py +++ b/junifer/data/coordinates/tests/test_coordinates.py @@ -21,7 +21,6 @@ def test_register_built_in_check() -> None: coordinates=np.zeros(2), voi_names=["1", "2"], space="MNI", - overwrite=True, ) @@ -32,7 +31,6 @@ def test_register_overwrite() -> None: coordinates=np.zeros((2, 3)), voi_names=["roi1", "roi2"], space="MNI", - overwrite=True, ) with pytest.raises(ValueError, match=r"already registered"): CoordinatesRegistry().register( @@ -40,6 +38,7 @@ def test_register_overwrite() -> None: coordinates=np.ones((2, 3)), voi_names=["roi2", "roi3"], space="MNI", + overwrite=False, ) CoordinatesRegistry().register( From 5f372465581f304d3b9fab6504f0654eb883e42d Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:13:06 +0100 Subject: [PATCH 10/22] fix: update init and registration logic for MaskRegistry --- junifer/data/masks/_masks.py | 79 +++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 38 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 44da273be2..917cb93bea 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -230,6 +230,7 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton): def __init__(self) -> None: """Initialize the class.""" + super().__init__() # Each entry in registry is a dictionary that must contain at least # the 
following keys: # * 'family': the mask's family name @@ -246,38 +247,40 @@ def __init__(self) -> None: self._builtin = {} self._external = {} - self._builtin = { - "GM_prob0.2": { - "family": "Vickery-Patil", - "space": "IXI549Space", - }, - "GM_prob0.2_cortex": { - "family": "Vickery-Patil", - "space": "IXI549Space", - }, - "compute_brain_mask": { - "family": "Callable", - "func": compute_brain_mask, - "space": "inherit", - }, - "compute_background_mask": { - "family": "Callable", - "func": compute_background_mask, - "space": "inherit", - }, - "compute_epi_mask": { - "family": "Callable", - "func": compute_epi_mask, - "space": "inherit", - }, - "UKB_15K_GM": { - "family": "UKB", - "space": "MNI152NLin6Asym", - }, - } + self._builtin.update( + { + "GM_prob0.2": { + "family": "Vickery-Patil", + "space": "IXI549Space", + }, + "GM_prob0.2_cortex": { + "family": "Vickery-Patil", + "space": "IXI549Space", + }, + "compute_brain_mask": { + "family": "Callable", + "func": compute_brain_mask, + "space": "inherit", + }, + "compute_background_mask": { + "family": "Callable", + "func": compute_background_mask, + "space": "inherit", + }, + "compute_epi_mask": { + "family": "Callable", + "func": compute_epi_mask, + "space": "inherit", + }, + "UKB_15K_GM": { + "family": "UKB", + "space": "MNI152NLin6Asym", + }, + } + ) - # Set built-in to registry - self._registry = self._builtin + # Update registry with built-in ones + self._registry.update(self._builtin) def register( self, @@ -303,19 +306,18 @@ def register( Raises ------ ValueError - If the mask ``name`` is already registered and - ``overwrite=False`` or - if the mask ``name`` is a built-in mask. + If the mask ``name`` is a built-in mask or + if the mask ``name`` is already registered and + ``overwrite=False``. 
""" # Check for attempt of overwriting built-in mask if name in self._builtin: + raise_error(f"Mask: {name} already registered as built-in mask.") + # Check for attempt of overwriting external coordinates + if name in self._external: if overwrite: logger.info(f"Overwriting mask: {name}") - if self._registry[name]["family"] != "CustomUserMask": - raise_error( - f"Mask: {name} already registered as built-in mask." - ) else: raise_error( f"Mask: {name} already registered. Set `overwrite=True` " @@ -324,6 +326,7 @@ def register( # Convert str to Path if not isinstance(mask_path, Path): mask_path = Path(mask_path) + # Registration logger.info(f"Registering mask: {name}") # Add mask info self._external[name] = { From a3f29713dd91585fef65becab3bfbde164d3310d Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:13:23 +0100 Subject: [PATCH 11/22] chore: update MaskRegistry tests --- junifer/data/masks/tests/test_masks.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/junifer/data/masks/tests/test_masks.py b/junifer/data/masks/tests/test_masks.py index a51380062a..0a8a2dc4ec 100644 --- a/junifer/data/masks/tests/test_masks.py +++ b/junifer/data/masks/tests/test_masks.py @@ -26,6 +26,7 @@ _load_ukb_mask, _load_vickery_patil_mask, ) +from junifer.data.utils import check_dataset from junifer.datagrabber import DMCC13Benchmark from junifer.datareader import DefaultDataReader from junifer.testing.datagrabbers import ( @@ -282,7 +283,9 @@ def test_vickery_patil( def test_vickery_patil_error() -> None: """Test error for Vickery-Patil mask.""" with pytest.raises(ValueError, match=r"find a Vickery-Patil mask "): - _load_vickery_patil_mask(name="wrong", resolution=2.0) + _load_vickery_patil_mask( + dataset=check_dataset(), name="wrong", resolution=2.0 + ) def test_ukb() -> None: @@ -297,7 +300,7 @@ def test_ukb() -> None: def test_ukb_error() -> None: """Test error for UKB mask.""" with pytest.raises(ValueError, match=r"find a UKB mask "): 
- _load_ukb_mask(name="wrong") + _load_ukb_mask(dataset=check_dataset(), name="wrong") def test_get() -> None: From 1344a81cf5e07a8a235e52674b3ed374b2e8871f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:14:49 +0100 Subject: [PATCH 12/22] chore: lint --- junifer/data/masks/_masks.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 917cb93bea..0f9bb8dae0 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -44,10 +44,6 @@ __all__ = ["MaskRegistry", "compute_brain_mask"] -# Path to the masks -_masks_path = Path(__file__).parent - - def compute_brain_mask( target_data: dict[str, Any], warp_data: Optional[dict[str, Any]] = None, From 99e0e0d4061954d8ea1554c30d6fdcbbd15b83f3 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:20:38 +0100 Subject: [PATCH 13/22] chore: update commentary for MaskRegistry --- junifer/data/masks/_masks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 0f9bb8dae0..b2ad54ba1c 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -310,7 +310,7 @@ def register( # Check for attempt of overwriting built-in mask if name in self._builtin: raise_error(f"Mask: {name} already registered as built-in mask.") - # Check for attempt of overwriting external coordinates + # Check for attempt of overwriting external masks if name in self._external: if overwrite: logger.info(f"Overwriting mask: {name}") From 88b624b2ed0faef5f0adb5596e4338ff540f460f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 16 Dec 2024 18:20:53 +0100 Subject: [PATCH 14/22] update: store Path instead of str for external entries in MaskRegistry --- junifer/data/masks/_masks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index b2ad54ba1c..97539005fb 
100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -326,13 +326,13 @@ def register( logger.info(f"Registering mask: {name}") # Add mask info self._external[name] = { - "path": str(mask_path.absolute()), + "path": mask_path, "family": "CustomUserMask", "space": space, } # Update registry self._registry[name] = { - "path": str(mask_path.absolute()), + "path": mask_path, "family": "CustomUserMask", "space": space, } From e5d6b00d8c9707347005ada331a1997d45baead2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:22:43 +0100 Subject: [PATCH 15/22] update: use junifer-data in ParcellationRegistry.load(); fix: update init and registration logic for ParcellationRegistry --- junifer/data/parcellations/_parcellations.py | 975 ++++++------------- 1 file changed, 297 insertions(+), 678 deletions(-) diff --git a/junifer/data/parcellations/_parcellations.py b/junifer/data/parcellations/_parcellations.py index 354c1ca8cd..1c4a8adffa 100644 --- a/junifer/data/parcellations/_parcellations.py +++ b/junifer/data/parcellations/_parcellations.py @@ -5,31 +5,30 @@ # Synchon Mandal # License: AGPL -import io -import shutil -import tarfile -import tempfile -import zipfile from itertools import product from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union -import httpx import nibabel as nib import nilearn.image as nimg import numpy as np import pandas as pd -from nilearn import datasets from ...utils import logger, raise_error, warn_with_log from ...utils.singleton import Singleton from ..pipeline_data_registry_base import BasePipelineDataRegistry -from ..utils import closest_resolution, get_native_warper +from ..utils import ( + check_dataset, + closest_resolution, + fetch_file_via_datalad, + get_native_warper, +) from ._ants_parcellation_warper import ANTsParcellationWarper from ._fsl_parcellation_warper import FSLParcellationWarper if TYPE_CHECKING: + from datalad.api import Dataset from nibabel.nifti1 import 
Nifti1Image @@ -49,6 +48,7 @@ class ParcellationRegistry(BasePipelineDataRegistry, metaclass=Singleton): def __init__(self) -> None: """Initialize the class.""" + super().__init__() # Each entry in registry is a dictionary that must contain at least # the following keys: # * 'family': the parcellation's family name (e.g., 'Schaefer', 'SUIT') @@ -56,6 +56,8 @@ def __init__(self) -> None: # and can also have optional key(s): # * 'valid_resolutions': a list of valid resolutions for the # parcellation (e.g., [1, 2]) + # The built-in parcellations are files that are shipped with the + # junifer-data dataset. # Make built-in and external dictionaries for validation later self._builtin = {} self._external = {} @@ -63,8 +65,14 @@ def __init__(self) -> None: # Add SUIT self._builtin.update( { - "SUITxSUIT": {"family": "SUIT", "space": "SUIT"}, - "SUITxMNI": {"family": "SUIT", "space": "MNI152NLin6Asym"}, + "SUITxSUIT": { + "family": "SUIT", + "space": "SUIT", + }, + "SUITxMNI": { + "family": "SUIT", + "space": "MNI152NLin6Asym", + }, } ) # Add Schaefer @@ -72,7 +80,7 @@ def __init__(self) -> None: self._builtin.update( { f"Schaefer{n_rois}x{t_net}": { - "family": "Schaefer", + "family": "Schaefer2018", "n_rois": n_rois, "yeo_networks": t_net, "space": "MNI152NLin6Asym", @@ -84,19 +92,19 @@ def __init__(self) -> None: self._builtin.update( { f"TianxS{scale}x7TxMNI6thgeneration": { - "family": "Tian", + "family": "Melbourne", "scale": scale, "magneticfield": "7T", "space": "MNI152NLin6Asym", }, f"TianxS{scale}x3TxMNI6thgeneration": { - "family": "Tian", + "family": "Melbourne", "scale": scale, "magneticfield": "3T", "space": "MNI152NLin6Asym", }, f"TianxS{scale}x3TxMNInonlinear2009cAsym": { - "family": "Tian", + "family": "Melbourne", "scale": scale, "magneticfield": "3T", "space": "MNI152NLin2009cAsym", @@ -155,7 +163,7 @@ def __init__(self) -> None: self._builtin.update( { f"Yan{n_rois}xYeo{yeo_network}": { - "family": "Yan", + "family": "Yan2023", "n_rois": n_rois,
"yeo_networks": yeo_network, "space": "MNI152NLin6Asym", @@ -165,7 +173,7 @@ def __init__(self) -> None: self._builtin.update( { f"Yan{n_rois}xKong17": { - "family": "Yan", + "family": "Yan2023", "n_rois": n_rois, "kong_networks": 17, "space": "MNI152NLin6Asym", @@ -184,8 +192,8 @@ def __init__(self) -> None: } ) - # Set built-in to registry - self._registry = self._builtin + # Update registry with built-in ones + self._registry.update(self._builtin) def register( self, @@ -214,20 +222,21 @@ def register( Raises ------ ValueError - If the parcellation ``name`` is already registered and - ``overwrite=False`` or - if the parcellation ``name`` is a built-in parcellation. + If the parcellation ``name`` is a built-in parcellation or + if the parcellation ``name`` is already registered and + ``overwrite=False``. """ # Check for attempt of overwriting built-in parcellations if name in self._builtin: + raise_error( + f"Parcellation: {name} already registered as " + "built-in parcellation." + ) + # Check for attempt of overwriting external parcellations + if name in self._external: if overwrite: logger.info(f"Overwriting parcellation: {name}") - if self._registry[name]["family"] != "CustomUserParcellation": - raise_error( - f"Parcellation: {name} already registered as " - "built-in parcellation." - ) else: raise_error( f"Parcellation: {name} already registered. Set " @@ -236,6 +245,7 @@ def register( # Convert str to Path if not isinstance(parcellation_path, Path): parcellation_path = Path(parcellation_path) + # Registration logger.info(f"Registering parcellation: {name}") # Add user parcellation info self._external[name] = { @@ -271,24 +281,17 @@ def load( self, name: str, target_space: str, - parcellations_dir: Union[str, Path, None] = None, resolution: Optional[float] = None, path_only: bool = False, ) -> tuple[Optional["Nifti1Image"], list[str], Path, str]: """Load parcellation and labels. 
- If it is a built-in parcellation and the file is not present in the - ``parcellations_dir`` directory, it will be downloaded. - Parameters ---------- name : str The name of the parcellation. target_space : str The desired space of the parcellation. - parcellations_dir : str or pathlib.Path, optional - Path where the parcellations files are stored. The default location - is "$HOME/junifer/data/parcellations" (default None). resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -312,6 +315,7 @@ def load( ------ ValueError If ``name`` is invalid or + if the parcellation family is invalid or if the parcellation values and labels don't have equal dimension or if the value range is invalid. @@ -327,7 +331,7 @@ def load( parcellation_definition = self._registry[name].copy() t_family = parcellation_definition.pop("family") # Remove space conditionally - if t_family not in ["SUIT", "Tian"]: + if t_family not in ["SUIT", "Melbourne"]: space = parcellation_definition.pop("space") else: space = parcellation_definition["space"] @@ -342,15 +346,66 @@ def load( # Check if the parcellation family is custom or built-in if t_family == "CustomUserParcellation": - parcellation_fname = Path(parcellation_definition["path"]) + parcellation_fname = parcellation_definition["path"] parcellation_labels = parcellation_definition["labels"] + elif t_family in [ + "Schaefer2018", + "SUIT", + "Melbourne", + "AICHA", + "Shen", + "Yan2023", + "Brainnetome", + ]: + # Get dataset + dataset = check_dataset() + # Load parcellation and labels + if t_family == "Schaefer2018": + parcellation_fname, parcellation_labels = _retrieve_schaefer( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == "SUIT": + parcellation_fname, parcellation_labels = _retrieve_suit( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == 
"Melbourne": + parcellation_fname, parcellation_labels = _retrieve_tian( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == "AICHA": + parcellation_fname, parcellation_labels = _retrieve_aicha( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == "Shen": + parcellation_fname, parcellation_labels = _retrieve_shen( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == "Yan2023": + parcellation_fname, parcellation_labels = _retrieve_yan( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + elif t_family == "Brainnetome": + parcellation_fname, parcellation_labels = ( + _retrieve_brainnetome( + dataset=dataset, + resolution=resolution, + **parcellation_definition, + ) + ) else: - parcellation_fname, parcellation_labels = _retrieve_parcellation( - family=t_family, - parcellations_dir=parcellations_dir, - resolution=resolution, - **parcellation_definition, - ) + raise_error(f"Unknown parcellation family: {t_family}") # Load parcellation image and values logger.info(f"Loading parcellation: {parcellation_fname.absolute()!s}") @@ -529,152 +584,8 @@ def get( return resampled_parcellation_img, labels -def _retrieve_parcellation( - family: str, - parcellations_dir: Union[str, Path, None] = None, - resolution: Optional[float] = None, - **kwargs, -) -> tuple[Path, list[str]]: - """Retrieve a brain parcellation object from nilearn or online source. - - Only returns one parcellation per call. Call function multiple times for - different parameter specifications. Only retrieves parcellation if it is - not yet in parcellations_dir. - - Parameters - ---------- - family : {"Schaefer", "SUIT", "Tian", "AICHA", "Shen", "Yan"} - The name of the parcellation family. - parcellations_dir : str or pathlib.Path, optional - Path where the retrieved parcellations file are stored. 
The default - location is "$HOME/junifer/data/parcellations" (default None). - resolution : float, optional - The desired resolution of the parcellation to load. If it is not - available, the closest resolution will be loaded. Preferably, use a - resolution higher than the desired one. By default, will load the - highest one (default None). - **kwargs - Use to specify parcellation-specific keyword arguments found in the - following section. - - Other Parameters - ---------------- - * Schaefer : - ``n_rois`` : {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} - Granularity of parcellation to be used. - ``yeo_network`` : {7, 17}, optional - Number of Yeo networks to use (default 7). - * Tian : - ``scale`` : {1, 2, 3, 4} - Scale of parcellation (defines granularity). - ``space`` : {"MNI152NLin6Asym", "MNI152NLin2009cAsym"}, optional - Space of parcellation (default "MNI152NLin6Asym"). (For more - information see https://github.com/yetianmed/subcortex) - ``magneticfield`` : {"3T", "7T"}, optional - Magnetic field (default "3T"). - * SUIT : - ``space`` : {"MNI152NLin6Asym", "SUIT"}, optional - Space of parcellation (default "MNI"). (For more information - see http://www.diedrichsenlab.org/imaging/suit.htm). - * AICHA : - ``version`` : {1, 2}, optional - Version of parcellation (default 2). - * Shen : - ``year`` : {2013, 2015, 2019}, optional - Year of the parcellation to use (default 2015). - ``n_rois`` : int, optional - Number of ROIs to use. Can be ``50, 100, or 150`` for - ``year = 2013`` but is fixed at ``268`` for ``year = 2015`` and at - ``368`` for ``year = 2019``. - * Yan : - ``n_rois`` : {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} - Granularity of the parcellation to be used. - ``yeo_networks`` : {7, 17}, optional - Number of Yeo networks to use (default None). - ``kong_networks`` : {17}, optional - Number of Kong networks to use (default None). - * Brainnetome : - ``threshold`` : {0, 25, 50} - Threshold for the probabilistic maps of subregion. 
- - Returns - ------- - pathlib.Path - File path to the parcellation image. - list of str - Parcellation labels. - - Raises - ------ - ValueError - If the parcellation's name is invalid. - - """ - if parcellations_dir is None: - parcellations_dir = ( - Path().home() / "junifer" / "data" / "parcellations" - ) - # Create default junifer data directory if not present - parcellations_dir.mkdir(exist_ok=True, parents=True) - # Convert str to Path - elif not isinstance(parcellations_dir, Path): - parcellations_dir = Path(parcellations_dir) - - logger.info(f"Fetching one of {family} parcellations.") - - # Retrieval details per family - if family == "Schaefer": - parcellation_fname, parcellation_labels = _retrieve_schaefer( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "SUIT": - parcellation_fname, parcellation_labels = _retrieve_suit( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "Tian": - parcellation_fname, parcellation_labels = _retrieve_tian( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "AICHA": - parcellation_fname, parcellation_labels = _retrieve_aicha( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "Shen": - parcellation_fname, parcellation_labels = _retrieve_shen( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "Yan": - parcellation_fname, parcellation_labels = _retrieve_yan( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - elif family == "Brainnetome": - parcellation_fname, parcellation_labels = _retrieve_brainnetome( - parcellations_dir=parcellations_dir, - resolution=resolution, - **kwargs, - ) - else: - raise_error( - f"The provided parcellation name {family} cannot be retrieved." 
- ) - - return parcellation_fname, parcellation_labels - - def _retrieve_schaefer( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float] = None, n_rois: Optional[int] = None, yeo_networks: int = 7, @@ -683,8 +594,8 @@ def _retrieve_schaefer( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -706,8 +617,7 @@ def _retrieve_schaefer( Raises ------ ValueError - If invalid value is provided for ``n_rois`` or ``yeo_networks`` or if - there is a problem fetching the parcellation. + If invalid value is provided for ``n_rois`` or ``yeo_networks``. """ logger.info("Parcellation parameters:") @@ -735,47 +645,40 @@ def _retrieve_schaefer( _valid_resolutions = [1, 2] resolution = closest_resolution(resolution, _valid_resolutions) - # Define parcellation and label file names - parcellation_fname = ( - parcellations_dir - / "schaefer_2018" + # Fetch file paths + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Schaefer2018" + / "Yeo2011" / ( f"Schaefer2018_{n_rois}Parcels_{yeo_networks}Networks_order_" f"FSLMNI152_{resolution}mm.nii.gz" - ) + ), ) - parcellation_lname = ( - parcellations_dir - / "schaefer_2018" - / (f"Schaefer2018_{n_rois}Parcels_{yeo_networks}Networks_order.txt") + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Schaefer2018" + / "Yeo2011" + / (f"Schaefer2018_{n_rois}Parcels_{yeo_networks}Networks_order.txt"), ) - # Check existence of parcellation - if not (parcellation_fname.exists() and parcellation_lname.exists()): - logger.info( - "At least one of the parcellation files are missing. 
" - "Fetching using nilearn." - ) - datasets.fetch_atlas_schaefer_2018( - n_rois=n_rois, - yeo_networks=yeo_networks, - resolution_mm=resolution, # type: ignore we know it's 1 or 2 - data_dir=parcellations_dir.resolve(), - ) - # Load labels labels = [ "_".join(x.split("_")[1:]) - for x in pd.read_csv(parcellation_lname, sep="\t", header=None) + for x in pd.read_csv(parcellation_label_path, sep="\t", header=None) .iloc[:, 1] .to_list() ] - return parcellation_fname, labels + return parcellation_img_path, labels def _retrieve_tian( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float] = None, scale: Optional[int] = None, space: str = "MNI152NLin6Asym", @@ -785,8 +688,8 @@ def _retrieve_tian( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -810,8 +713,6 @@ def _retrieve_tian( Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value is provided for ``scale`` or ``magneticfield`` or ``space``. 
@@ -832,13 +733,10 @@ def _retrieve_tian( ) # Check resolution - _valid_resolutions = [] # avoid pylance error if magneticfield == "3T": _valid_spaces = ["MNI152NLin6Asym", "MNI152NLin2009cAsym"] - if space == "MNI152NLin6Asym": + if space in _valid_spaces: _valid_resolutions = [1, 2] - elif space == "MNI152NLin2009cAsym": - _valid_resolutions = [2] else: raise_error( f"The parameter `space` ({space}) for 3T needs to be one of " @@ -858,100 +756,76 @@ def _retrieve_tian( ) resolution = closest_resolution(resolution, _valid_resolutions) - # Define parcellation and label file names + # Fetch file paths if magneticfield == "3T": parcellation_fname_base_3T = ( - parcellations_dir / "Tian2020MSA_v1.1" / "3T" / "Subcortex-Only" - ) - parcellation_lname = parcellation_fname_base_3T / ( - f"Tian_Subcortex_S{scale}_3T_label.txt" + dataset.pathobj + / "parcellations" + / "Melbourne" + / "v1.4" + / "3T" + / "Subcortex-Only" ) if space == "MNI152NLin6Asym": - parcellation_fname = parcellation_fname_base_3T / ( - f"Tian_Subcortex_S{scale}_{magneticfield}.nii.gz" - ) if resolution == 1: parcellation_fname = ( parcellation_fname_base_3T / f"Tian_Subcortex_S{scale}_{magneticfield}_1mm.nii.gz" ) + else: + parcellation_fname = parcellation_fname_base_3T / ( + f"Tian_Subcortex_S{scale}_{magneticfield}.nii.gz" + ) elif space == "MNI152NLin2009cAsym": space = "2009cAsym" - parcellation_fname = parcellation_fname_base_3T / ( - f"Tian_Subcortex_S{scale}_{magneticfield}_{space}.nii.gz" - ) - elif magneticfield == "7T": - parcellation_fname_base_7T = ( - parcellations_dir / "Tian2020MSA_v1.1" / "7T" + if resolution == 1: + parcellation_fname = parcellation_fname_base_3T / ( + f"Tian_Subcortex_S{scale}_{magneticfield}_{space}_1mm.nii.gz" + ) + else: + parcellation_fname = parcellation_fname_base_3T / ( + f"Tian_Subcortex_S{scale}_{magneticfield}_{space}.nii.gz" + ) + + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=parcellation_fname, ) - 
parcellation_fname_base_7T.mkdir(exist_ok=True, parents=True) - parcellation_fname = ( - parcellations_dir - / "Tian2020MSA_v1.1" - / f"{magneticfield}" - / (f"Tian_Subcortex_S{scale}_{magneticfield}.nii.gz") + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=parcellation_fname_base_3T + / f"Tian_Subcortex_S{scale}_3T_label.txt", + ) + # Load labels + labels = pd.read_csv(parcellation_label_path, sep=" ", header=None)[ + 0 + ].to_list() + elif magneticfield == "7T": + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Melbourne" + / "v1.4" + / "7T" + / f"Tian_Subcortex_S{scale}_{magneticfield}.nii.gz", ) # define 7T labels (b/c currently no labels file available for 7T) scale7Trois = {1: 16, 2: 34, 3: 54, 4: 62} labels = [ ("parcel_" + str(x)) for x in np.arange(1, scale7Trois[scale] + 1) ] - parcellation_lname = parcellation_fname_base_7T / ( - f"Tian_Subcortex_S{scale}_7T_labelnumbering.txt" - ) - with open(parcellation_lname, "w") as filehandle: - for listitem in labels: - filehandle.write(f"{listitem}\n") logger.info( "Currently there are no labels provided for the 7T Tian " "parcellation. A simple numbering scheme for distinction was " "therefore used." ) - # Check existence of parcellation - if not (parcellation_fname.exists() and parcellation_lname.exists()): - logger.info( - "At least one of the parcellation files are missing, fetching." 
- ) - # Set URL - url = ( - "https://www.nitrc.org/frs/download.php/12012/Tian2020MSA_v1.1.zip" - ) - - logger.info(f"Downloading TIAN from {url}") - # Store initial download in a tempdir - with tempfile.TemporaryDirectory() as tmpdir: - # Make HTTP request - try: - resp = httpx.get(url) - resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - # Set tempfile for storing initial content and unzipping - zip_fname = Path(tmpdir) / "Tian2020MSA_v1.1.zip" - # Open tempfile and write content - with open(zip_fname, "wb") as f: - f.write(resp.content) - # Unzip tempfile - with zipfile.ZipFile(zip_fname, "r") as zip_ref: - zip_ref.extractall(parcellations_dir.as_posix()) - # Clean after unzipping - if (parcellations_dir / "__MACOSX").exists(): - shutil.rmtree((parcellations_dir / "__MACOSX").as_posix()) - - # Load labels - labels = pd.read_csv(parcellation_lname, sep=" ", header=None)[0].to_list() - - return parcellation_fname, labels + return parcellation_img_path, labels def _retrieve_suit( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float], space: str = "MNI152NLin6Asym", ) -> tuple[Path, list[str]]: @@ -959,8 +833,8 @@ def _retrieve_suit( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -980,8 +854,6 @@ def _retrieve_suit( Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value is provided for ``space``. 
@@ -1006,78 +878,32 @@ def _retrieve_suit( if space == "MNI152NLin6Asym": space = "MNI" - # Define parcellation and label file names - parcellation_fname = ( - parcellations_dir / "SUIT" / (f"SUIT_{space}Space_{resolution}mm.nii") + # Fetch file paths + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "SUIT" + / f"SUIT_{space}Space_{resolution}mm.nii", ) - parcellation_lname = ( - parcellations_dir / "SUIT" / (f"SUIT_{space}Space_{resolution}mm.tsv") + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "SUIT" + / f"SUIT_{space}Space_{resolution}mm.tsv", ) - # Check existence of parcellation - if not (parcellation_fname.exists() and parcellation_lname.exists()): - logger.info( - "At least one of the parcellation files is missing, fetching." - ) - # Create local directory if not present - parcellation_fname.parent.mkdir(exist_ok=True, parents=True) - # Set URL - url_basis = ( - "https://github.com/DiedrichsenLab/cerebellar_atlases/raw" - "/master/Diedrichsen_2009" - ) - if space == "MNI": - url = f"{url_basis}/atl-Anatom_space-MNI_dseg.nii" - else: # if not MNI, then SUIT - url = f"{url_basis}/atl-Anatom_space-SUIT_dseg.nii" - url_labels = f"{url_basis}/atl-Anatom.tsv" - - # Make HTTP requests - with httpx.Client(follow_redirects=True) as client: - # Download parcellation file - logger.info(f"Downloading SUIT parcellation from {url}") - try: - img_resp = client.get(url) - img_resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - with open(parcellation_fname, "wb") as f: - f.write(img_resp.content) - # Download label file - logger.info(f"Downloading SUIT labels from {url_labels}") - try: - label_resp = client.get(url_labels) - label_resp.raise_for_status() - except httpx.HTTPError as exc: 
- raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - # Load labels - labels = pd.read_csv( - io.StringIO(label_resp.content.decode("utf-8")), - sep="\t", - usecols=["name"], - ) - labels.to_csv(parcellation_lname, sep="\t", index=False) - # Load labels - labels = pd.read_csv(parcellation_lname, sep="\t", usecols=["name"])[ + labels = pd.read_csv(parcellation_label_path, sep="\t", usecols=["name"])[ "name" ].to_list() - return parcellation_fname, labels + return parcellation_img_path, labels def _retrieve_aicha( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float] = None, version: int = 2, ) -> tuple[Path, list[str]]: @@ -1085,8 +911,8 @@ def _retrieve_aicha( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -1105,8 +931,6 @@ def _retrieve_aicha( Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value is provided for ``version``. 
@@ -1143,99 +967,48 @@ def _retrieve_aicha( _valid_resolutions = [1] resolution = closest_resolution(resolution, _valid_resolutions) - # Define parcellation and label file names - parcellation_fname = ( - parcellations_dir / f"AICHA_v{version}" / "AICHA" / "AICHA.nii" + # Fetch file paths + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "AICHA" + / f"v{version}" + / "AICHA.nii", ) - parcellation_lname = Path() + # Conditional label file fetch if version == 1: - parcellation_lname = ( - parcellations_dir - / f"AICHA_v{version}" + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" / "AICHA" - / "AICHA_vol1.txt" + / f"v{version}" + / "AICHA_vol1.txt", ) elif version == 2: - parcellation_lname = ( - parcellations_dir - / f"AICHA_v{version}" + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" / "AICHA" - / "AICHA_vol3.txt" + / f"v{version}" + / "AICHA_vol3.txt", ) - # Check existence of parcellation - if not (parcellation_fname.exists() and parcellation_lname.exists()): - logger.info( - "At least one of the parcellation files are missing, fetching." 
- ) - # Set file name on server according to version - server_filename = "" - if version == 1: - server_filename = "aicha_v1.zip" - elif version == 2: - server_filename = "AICHA_v2.tar.zip" - # Set URL - url = f"http://www.gin.cnrs.fr/wp-content/uploads/{server_filename}" - - logger.info(f"Downloading AICHA v{version} from {url}") - # Store initial download in a tempdir - with tempfile.TemporaryDirectory() as tmpdir: - # Make HTTP request - try: - resp = httpx.get(url, follow_redirects=True) - resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - # Set tempfile for storing initial content and unzipping - parcellation_zip_path = Path(tmpdir) / server_filename - # Open tempfile and write content - with open(parcellation_zip_path, "wb") as f: - f.write(resp.content) - # Unzip tempfile - with zipfile.ZipFile(parcellation_zip_path, "r") as zip_ref: - if version == 1: - zip_ref.extractall( - (parcellations_dir / "AICHA_v1").as_posix() - ) - elif version == 2: - zip_ref.extractall(Path(tmpdir).as_posix()) - # Extract tarfile for v2 - with tarfile.TarFile( - Path(tmpdir) / "aicha_v2.tar", "r" - ) as tar_ref: - tar_ref.extractall( - (parcellations_dir / "AICHA_v2").as_posix() - ) - # Cleanup after unzipping - if ( - parcellations_dir / f"AICHA_v{version}" / "__MACOSX" - ).exists(): - shutil.rmtree( - ( - parcellations_dir - / f"AICHA_v{version}" - / "__MACOSX" - ).as_posix() - ) - # Load labels labels = pd.read_csv( - parcellation_lname, + parcellation_label_path, sep="\t", header=None, - skiprows=[0], # type: ignore + skiprows=[0], )[0].to_list() - return parcellation_fname, labels + return parcellation_img_path, labels -def _retrieve_shen( # noqa: C901 - parcellations_dir: Path, +def _retrieve_shen( + dataset: "Dataset", resolution: Optional[float] = None, year: int = 2015, n_rois: int = 268, @@ -1244,8 +1017,8 @@ def 
_retrieve_shen( # noqa: C901 Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -1269,8 +1042,6 @@ def _retrieve_shen( # noqa: C901 Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value or combination is provided for ``year`` and ``n_rois``. @@ -1323,123 +1094,60 @@ def _retrieve_shen( # noqa: C901 f"`year = {year}` is invalid" ) - # Define parcellation and label file names + # Fetch file paths based on year if year == 2013: - parcellation_fname = ( - parcellations_dir - / "Shen_2013" - / "shenetal_neuroimage2013" - / f"fconn_atlas_{n_rois}_{resolution}mm.nii" - ) - parcellation_lname = ( - parcellations_dir - / "Shen_2013" - / "shenetal_neuroimage2013" - / f"Group_seg{n_rois}_BAindexing_setA.txt" - ) - elif year == 2015: - parcellation_fname = ( - parcellations_dir - / "Shen_2015" - / f"shen_{resolution}mm_268_parcellation.nii.gz" - ) - elif year == 2019: - parcellation_fname = ( - parcellations_dir - / "Shen_2019" - / "Shen_1mm_368_parcellation.nii.gz" + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Shen" + / "2013" + / f"fconn_atlas_{n_rois}_{resolution}mm.nii", ) - - # Check existence of parcellation - if not parcellation_fname.exists(): - logger.info( - "At least one of the parcellation files are missing, fetching." 
+ parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Shen" + / "2013" + / f"Group_seg{n_rois}_BAindexing_setA.txt", ) - - # Set URL based on year - url = "" - if year == 2013: - url = "https://www.nitrc.org/frs/download.php/5785/shenetal_neuroimage2013_funcatlas.zip" - elif year == 2015: - # Set URL based on resolution - if resolution == 1: - url = "https://www.nitrc.org/frs/download.php/7976/shen_1mm_268_parcellation.nii.gz" - elif resolution == 2: - url = "https://www.nitrc.org/frs/download.php/7977/shen_2mm_268_parcellation.nii.gz" - elif year == 2019: - url = "https://www.nitrc.org/frs/download.php/11629/shen_368.zip" - - logger.info(f"Downloading Shen {year} from {url}") - # Store initial download in a tempdir - with tempfile.TemporaryDirectory() as tmpdir: - # Make HTTP request - try: - resp = httpx.get(url) - resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - if year in (2013, 2019): - parcellation_zip_path = Path(tmpdir) / f"Shen{year}.zip" - # Open tempfile and write content - with open(parcellation_zip_path, "wb") as f: - f.write(resp.content) - # Unzip tempfile - with zipfile.ZipFile( - parcellation_zip_path, "r" - ) as zip_ref: - zip_ref.extractall( - (parcellations_dir / f"Shen_{year}").as_posix() - ) - # Cleanup after unzipping - if ( - parcellations_dir / f"Shen_{year}" / "__MACOSX" - ).exists(): - shutil.rmtree( - ( - parcellations_dir / f"Shen_{year}" / "__MACOSX" - ).as_posix() - ) - elif year == 2015: - img_dir_path = parcellations_dir / "Shen_2015" - # Create local directory if not present - img_dir_path.mkdir(parents=True, exist_ok=True) - img_path = ( - img_dir_path - / f"shen_{resolution}mm_268_parcellation.nii.gz" - ) - # Create local file if not present - img_path.touch(exist_ok=True) - # Open tempfile and write content - 
with open(img_path, "wb") as f: - f.write(resp.content) - - # Load labels based on year - if year == 2013: labels = ( pd.read_csv( - parcellation_lname, # type: ignore - sep=",", # type: ignore - header=None, # type: ignore - skiprows=[0], # type: ignore + parcellation_label_path, + sep=",", + header=None, + skiprows=[0], )[1] .map(lambda x: x.strip()) # fix formatting .to_list() ) elif year == 2015: + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Shen" + / "2015" + / f"shen_{resolution}mm_268_parcellation.nii.gz", + ) labels = list(range(1, 269)) elif year == 2019: + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Shen" + / "2019" + / "Shen_1mm_368_parcellation.nii.gz", + ) labels = list(range(1, 369)) - return parcellation_fname, labels + return parcellation_img_path, labels def _retrieve_yan( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float] = None, n_rois: Optional[int] = None, yeo_networks: Optional[int] = None, @@ -1449,8 +1157,8 @@ def _retrieve_yan( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : float, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -1473,8 +1181,6 @@ def _retrieve_yan( Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value is provided for ``n_rois``, ``yeo_networks`` or ``kong_networks``. 
@@ -1507,8 +1213,7 @@ def _retrieve_yan( f"following: {_valid_n_rois}" ) - parcellation_fname = Path() - parcellation_lname = Path() + # Fetch file paths based on networks if yeo_networks: # Check yeo_networks value _valid_yeo_networks = [7, 17] @@ -1517,19 +1222,25 @@ def _retrieve_yan( f"The parameter `yeo_networks` ({yeo_networks}) needs to be " f"one of the following: {_valid_yeo_networks}" ) - # Define image and label file according to network - parcellation_fname = ( - parcellations_dir - / "Yan_2023" + + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Yan2023" + / "Yeo2011" / ( f"{n_rois}Parcels_Yeo2011_{yeo_networks}Networks_FSLMNI152_" f"{resolution}mm.nii.gz" - ) + ), ) - parcellation_lname = ( - parcellations_dir - / "Yan_2023" - / f"{n_rois}Parcels_Yeo2011_{yeo_networks}Networks_LUT.txt" + parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Yan2023" + / "Yeo2011" + / f"{n_rois}Parcels_Yeo2011_{yeo_networks}Networks_LUT.txt", ) elif kong_networks: # Check kong_networks value @@ -1539,106 +1250,37 @@ def _retrieve_yan( f"The parameter `kong_networks` ({kong_networks}) needs to be " f"one of the following: {_valid_kong_networks}" ) - # Define image and label file according to network - parcellation_fname = ( - parcellations_dir - / "Yan_2023" + + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Yan2023" + / "Kong2022" / ( f"{n_rois}Parcels_Kong2022_{kong_networks}Networks_FSLMNI152_" f"{resolution}mm.nii.gz" - ) - ) - parcellation_lname = ( - parcellations_dir - / "Yan_2023" - / f"{n_rois}Parcels_Kong2022_{kong_networks}Networks_LUT.txt" + ), ) - - # Check for existence of parcellation: - if not parcellation_fname.exists() and not parcellation_lname.exists(): - logger.info( - "At least one of the parcellation files are missing, fetching." 
+ parcellation_label_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Yan2023" + / "Kong2022" + / f"{n_rois}Parcels_Kong2022_{kong_networks}Networks_LUT.txt", ) - # Set URL based on network - img_url = "" - label_url = "" - if yeo_networks: - img_url = ( - "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/" - "master/stable_projects/brain_parcellation/Yan2023_homotopic/" - f"parcellations/MNI/yeo{yeo_networks}/{n_rois}Parcels_Yeo2011" - f"_{yeo_networks}Networks_FSLMNI152_{resolution}mm.nii.gz" - ) - label_url = ( - "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/" - "master/stable_projects/brain_parcellation/Yan2023_homotopic/" - f"parcellations/MNI/yeo{yeo_networks}/freeview_lut/{n_rois}" - f"Parcels_Yeo2011_{yeo_networks}Networks_LUT.txt" - ) - elif kong_networks: - img_url = ( - "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/" - "master/stable_projects/brain_parcellation/Yan2023_homotopic/" - f"parcellations/MNI/kong17/{n_rois}Parcels_Kong2022" - f"_17Networks_FSLMNI152_{resolution}mm.nii.gz" - ) - label_url = ( - "https://raw.githubusercontent.com/ThomasYeoLab/CBIG/" - "master/stable_projects/brain_parcellation/Yan2023_homotopic/" - f"parcellations/MNI/kong17/freeview_lut/{n_rois}Parcels_" - "Kong2022_17Networks_LUT.txt" - ) - - # Make HTTP requests - with httpx.Client() as client: - # Download parcellation file - logger.info(f"Downloading Yan 2023 parcellation from {img_url}") - try: - img_resp = client.get(img_url) - img_resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - parcellation_img_path = Path(parcellation_fname) - # Create local directory if not present - parcellation_img_path.parent.mkdir(parents=True, exist_ok=True) - # Create local file if not present - parcellation_img_path.touch(exist_ok=True) - # Open file and write content - with 
open(parcellation_img_path, "wb") as f: - f.write(img_resp.content) - # Download label file - logger.info(f"Downloading Yan 2023 labels from {label_url}") - try: - label_resp = client.get(label_url) - label_resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - parcellation_labels_path = Path(parcellation_lname) - # Create local file if not present - parcellation_labels_path.touch(exist_ok=True) - # Open file and write content - with open(parcellation_labels_path, "wb") as f: - f.write(label_resp.content) - # Load label file - labels = pd.read_csv(parcellation_lname, sep=" ", header=None)[1].to_list() + labels = pd.read_csv(parcellation_label_path, sep=" ", header=None)[ + 1 + ].to_list() - return parcellation_fname, labels + return parcellation_img_path, labels def _retrieve_brainnetome( - parcellations_dir: Path, + dataset: "Dataset", resolution: Optional[float] = None, threshold: Optional[int] = None, ) -> tuple[Path, list[str]]: @@ -1646,8 +1288,8 @@ def _retrieve_brainnetome( Parameters ---------- - parcellations_dir : pathlib.Path - The path to the parcellation data directory. + dataset : datalad.api.Dataset + The datalad dataset to fetch parcellation from. resolution : {1.0, 1.25, 2.0}, optional The desired resolution of the parcellation to load. If it is not available, the closest resolution will be loaded. Preferably, use a @@ -1666,8 +1308,6 @@ def _retrieve_brainnetome( Raises ------ - RuntimeError - If there is a problem fetching files. ValueError If invalid value is provided for ``threshold``. 
@@ -1691,36 +1331,15 @@ def _retrieve_brainnetome( if resolution in [1.0, 2.0]: resolution = int(resolution) - parcellation_fname = ( - parcellations_dir - / "BNA246" - / f"BNA-maxprob-thr{threshold}-{resolution}mm.nii.gz" + # Fetch file path + parcellation_img_path = fetch_file_via_datalad( + dataset=dataset, + file_path=dataset.pathobj + / "parcellations" + / "Brainnetome" + / f"BNA-maxprob-thr{threshold}-{resolution}mm.nii.gz", ) - # Check for existence of parcellation - if not parcellation_fname.exists(): - # Set URL - url = f"http://neurovault.org/media/images/1625/BNA-maxprob-thr{threshold}-{resolution}mm.nii.gz" - - logger.info(f"Downloading Brainnetome from {url}") - # Make HTTP request - try: - resp = httpx.get(url, follow_redirects=True) - resp.raise_for_status() - except httpx.HTTPError as exc: - raise_error( - f"Error response {exc.response.status_code} while " - f"requesting {exc.request.url!r}", - klass=RuntimeError, - ) - else: - # Create local directory if not present - parcellation_fname.parent.mkdir(parents=True, exist_ok=True) - # Create file if not present - parcellation_fname.touch(exist_ok=True) - # Open file and write bytes - parcellation_fname.write_bytes(resp.content) - # Load labels labels = ( sorted([f"SFG_L(R)_7_{i}" for i in range(1, 8)] * 2) @@ -1750,7 +1369,7 @@ def _retrieve_brainnetome( + sorted([f"Tha_L(R)_8_{i}" for i in range(1, 9)] * 2) ) - return parcellation_fname, labels + return parcellation_img_path, labels def merge_parcellations( From a986bc99b0dc270e29ebbc6bdf92712eadc4cbbb Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:24:39 +0100 Subject: [PATCH 16/22] chore: update ParcellationRegistry tests --- .../parcellations/tests/test_parcellations.py | 293 +++++------------- 1 file changed, 82 insertions(+), 211 deletions(-) diff --git a/junifer/data/parcellations/tests/test_parcellations.py b/junifer/data/parcellations/tests/test_parcellations.py index 23a5aceff2..247956622e 100644 --- 
a/junifer/data/parcellations/tests/test_parcellations.py +++ b/junifer/data/parcellations/tests/test_parcellations.py @@ -18,13 +18,13 @@ from junifer.data.parcellations._parcellations import ( _retrieve_aicha, _retrieve_brainnetome, - _retrieve_parcellation, _retrieve_schaefer, _retrieve_shen, _retrieve_suit, _retrieve_tian, _retrieve_yan, ) +from junifer.data.utils import check_dataset from junifer.datareader import DefaultDataReader from junifer.pipeline.utils import _check_ants from junifer.testing.datagrabbers import ( @@ -247,12 +247,6 @@ def test_load_incorrect() -> None: ParcellationRegistry().load("wrongparcellation", "MNI152NLin6Asym") -def test_retrieve_parcellation_incorrect() -> None: - """Test retrieval of invalid parcellations.""" - with pytest.raises(ValueError, match=r"provided parcellation name"): - _retrieve_parcellation("wrongparcellation") - - @pytest.mark.parametrize( "resolution, n_rois, yeo_networks", [ @@ -299,7 +293,6 @@ def test_retrieve_parcellation_incorrect() -> None: ], ) def test_schaefer( - tmp_path: Path, resolution: float, n_rois: int, yeo_networks: int, @@ -308,8 +301,6 @@ def test_schaefer( Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. resolution : float The parametrized resolution values. n_rois : int @@ -329,7 +320,6 @@ def test_schaefer( img, label, img_path, space = ParcellationRegistry().load( name=parcellation_name, target_space="MNI152NLin6Asym", - parcellations_dir=tmp_path, resolution=resolution, ) assert img is not None @@ -341,36 +331,22 @@ def test_schaefer( ) -def test_retrieve_schaefer_incorrect_n_rois(tmp_path: Path) -> None: - """Test retrieve Schaefer with incorrect ROIs. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_schaefer_incorrect_n_rois() -> None: + """Test retrieve Schaefer with incorrect ROIs.""" with pytest.raises(ValueError, match=r"The parameter `n_rois`"): _retrieve_schaefer( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=1, n_rois=101, yeo_networks=7, ) -def test_retrieve_schaefer_incorrect_yeo_networks(tmp_path: Path) -> None: - """Test retrieve Schaefer with incorrect Yeo networks. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ +def test_retrieve_schaefer_incorrect_yeo_networks() -> None: + """Test retrieve Schaefer with incorrect Yeo networks.""" with pytest.raises(ValueError, match=r"The parameter `yeo_networks`"): _retrieve_schaefer( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=1, n_rois=100, yeo_networks=8, @@ -381,13 +357,11 @@ def test_retrieve_schaefer_incorrect_yeo_networks(tmp_path: Path) -> None: "space_key, space", [("SUIT", "SUIT"), ("MNI", "MNI152NLin6Asym")], ) -def test_suit(tmp_path: Path, space_key: str, space: str) -> None: +def test_suit(space_key: str, space: str) -> None: """Test SUIT parcellation. Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. space_key : str The parametrized space values for the key. space : str @@ -399,7 +373,6 @@ def test_suit(tmp_path: Path, space_key: str, space: str) -> None: img, label, img_path, parcellation_space = ParcellationRegistry().load( name=f"SUITx{space_key}", target_space=space, - parcellations_dir=tmp_path, ) assert img is not None assert img_path.name == f"SUIT_{space_key}Space_1mm.nii" @@ -408,33 +381,20 @@ def test_suit(tmp_path: Path, space_key: str, space: str) -> None: assert_array_equal(img.header["pixdim"][1:4], [1, 1, 1]) # type: ignore -def test_retrieve_suit_incorrect_space(tmp_path: Path) -> None: - """Test retrieve SUIT with incorrect space. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_suit_incorrect_space() -> None: + """Test retrieve SUIT with incorrect space.""" with pytest.raises(ValueError, match=r"The parameter `space`"): - _retrieve_suit( - parcellations_dir=tmp_path, resolution=1.0, space="wrong" - ) + _retrieve_suit(dataset=check_dataset(), resolution=1.0, space="wrong") @pytest.mark.parametrize( "scale, n_label", [(1, 16), (2, 32), (3, 50), (4, 54)] ) -def test_tian_3T_6thgeneration( - tmp_path: Path, scale: int, n_label: int -) -> None: +def test_tian_3T_6thgeneration(scale: int, n_label: int) -> None: """Test Tian parcellation. Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. scale : int The parametrized scale values. n_label : int @@ -447,44 +407,38 @@ def test_tian_3T_6thgeneration( assert "TianxS3x3TxMNI6thgeneration" in parcellations assert "TianxS4x3TxMNI6thgeneration" in parcellations # Load parcellation - img, lbl, fname, parcellation_space_1 = ParcellationRegistry().load( + img, lbl, fname, space = ParcellationRegistry().load( name=f"TianxS{scale}x3TxMNI6thgeneration", - parcellations_dir=tmp_path, - target_space="MNI152NLin2009cAsym", + target_space="MNI152NLin2009cAsym", # force highest resolution ) - fname1 = f"Tian_Subcortex_S{scale}_3T_1mm.nii.gz" + expected_fname = f"Tian_Subcortex_S{scale}_3T_1mm.nii.gz" assert img is not None - assert fname.name == fname1 - assert parcellation_space_1 == "MNI152NLin6Asym" + assert fname.name == expected_fname + assert space == "MNI152NLin6Asym" assert len(lbl) == n_label - assert_array_equal(img.header["pixdim"][1:4], [1, 1, 1]) # type: ignore + assert_array_equal(img.header["pixdim"][1:4], [1, 1, 1]) # Load parcellation - img, lbl, fname, parcellation_space_2 = ParcellationRegistry().load( + img, lbl, fname, space = ParcellationRegistry().load( name=f"TianxS{scale}x3TxMNI6thgeneration", target_space="MNI152NLin6Asym", - parcellations_dir=tmp_path, resolution=2, ) - fname1 = f"Tian_Subcortex_S{scale}_3T.nii.gz" + 
expected_fname = f"Tian_Subcortex_S{scale}_3T.nii.gz" assert img is not None - assert fname.name == fname1 - assert parcellation_space_2 == "MNI152NLin6Asym" + assert fname.name == expected_fname + assert space == "MNI152NLin6Asym" assert len(lbl) == n_label - assert_array_equal(img.header["pixdim"][1:4], [2, 2, 2]) # type: ignore + assert_array_equal(img.header["pixdim"][1:4], [2, 2, 2]) @pytest.mark.parametrize( "scale, n_label", [(1, 16), (2, 32), (3, 50), (4, 54)] ) -def test_tian_3T_nonlinear2009cAsym( - tmp_path: Path, scale: int, n_label: int -) -> None: +def test_tian_3T_nonlinear2009cAsym(scale: int, n_label: int) -> None: """Test Tian parcellation. Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. scale : int The parametrized scale values. n_label : int @@ -497,31 +451,38 @@ def test_tian_3T_nonlinear2009cAsym( assert "TianxS3x3TxMNInonlinear2009cAsym" in parcellations assert "TianxS4x3TxMNInonlinear2009cAsym" in parcellations # Load parcellation + img, lbl, fname, space = ParcellationRegistry().load( + name=f"TianxS{scale}x3TxMNInonlinear2009cAsym", + target_space="MNI152NLin6Asym", # force highest resolution + ) + expected_fname = f"Tian_Subcortex_S{scale}_3T_2009cAsym_1mm.nii.gz" + assert img is not None + assert fname.name == expected_fname + assert space == "MNI152NLin2009cAsym" + assert len(lbl) == n_label + assert_array_equal(img.header["pixdim"][1:4], [1, 1, 1]) + # Load parcellation img, lbl, fname, space = ParcellationRegistry().load( name=f"TianxS{scale}x3TxMNInonlinear2009cAsym", target_space="MNI152NLin2009cAsym", - parcellations_dir=tmp_path, + resolution=2, ) - fname1 = f"Tian_Subcortex_S{scale}_3T_2009cAsym.nii.gz" + expected_fname = f"Tian_Subcortex_S{scale}_3T_2009cAsym.nii.gz" assert img is not None - assert fname.name == fname1 + assert fname.name == expected_fname assert space == "MNI152NLin2009cAsym" assert len(lbl) == n_label - assert_array_equal(img.header["pixdim"][1:4], [2, 2, 2]) # type: ignore + 
assert_array_equal(img.header["pixdim"][1:4], [2, 2, 2]) @pytest.mark.parametrize( "scale, n_label", [(1, 16), (2, 34), (3, 54), (4, 62)] ) -def test_tian_7T_6thgeneration( - tmp_path: Path, scale: int, n_label: int -) -> None: +def test_tian_7T_6thgeneration(scale: int, n_label: int) -> None: """Test Tian parcellation. Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. scale : int The parametrized scale values. n_label : int @@ -537,7 +498,6 @@ def test_tian_7T_6thgeneration( img, lbl, fname, space = ParcellationRegistry().load( name=f"TianxS{scale}x7TxMNI6thgeneration", target_space="MNI152NLin6Asym", - parcellations_dir=tmp_path, ) fname1 = f"Tian_Subcortex_S{scale}_7T.nii.gz" assert img is not None @@ -549,23 +509,16 @@ def test_tian_7T_6thgeneration( ) -def test_retrieve_tian_incorrect_space(tmp_path: Path) -> None: - """Test retrieve tian with incorrect space. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ +def test_retrieve_tian_incorrect_space() -> None: + """Test retrieve tian with incorrect space.""" with pytest.raises(ValueError, match=r"The parameter `space`"): _retrieve_tian( - parcellations_dir=tmp_path, resolution=1, scale=1, space="wrong" + dataset=check_dataset(), resolution=1, scale=1, space="wrong" ) with pytest.raises(ValueError, match=r"MNI152NLin6Asym"): _retrieve_tian( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=1, scale=1, magneticfield="7T", @@ -573,18 +526,11 @@ def test_retrieve_tian_incorrect_space(tmp_path: Path) -> None: ) -def test_retrieve_tian_incorrect_magneticfield(tmp_path: Path) -> None: - """Test retrieve tian with incorrect magneticfield. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_tian_incorrect_magneticfield() -> None: + """Test retrieve tian with incorrect magneticfield.""" with pytest.raises(ValueError, match=r"The parameter `magneticfield`"): _retrieve_tian( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=1, scale=1, magneticfield="wrong", @@ -592,17 +538,10 @@ def test_retrieve_tian_incorrect_magneticfield(tmp_path: Path) -> None: def test_retrieve_tian_incorrect_scale(tmp_path: Path) -> None: - """Test retrieve tian with incorrect scale. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ + """Test retrieve tian with incorrect scale.""" with pytest.raises(ValueError, match=r"The parameter `scale`"): _retrieve_tian( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=1, scale=5, space="MNI152NLin6Asym", @@ -610,7 +549,7 @@ def test_retrieve_tian_incorrect_scale(tmp_path: Path) -> None: @pytest.mark.parametrize("version", [1, 2]) -def test_aicha(tmp_path: Path, version: int) -> None: +def test_aicha(version: int) -> None: """Test AICHA parcellation. Parameters @@ -626,7 +565,6 @@ def test_aicha(tmp_path: Path, version: int) -> None: img, label, img_path, space = ParcellationRegistry().load( name=f"AICHA_v{version}", target_space="IXI549Space", - parcellations_dir=tmp_path, ) assert img is not None assert img_path.name == "AICHA.nii" @@ -635,18 +573,11 @@ def test_aicha(tmp_path: Path, version: int) -> None: assert_array_equal(img.header["pixdim"][1:4], [2, 2, 2]) # type: ignore -def test_retrieve_aicha_incorrect_version(tmp_path: Path) -> None: - """Test retrieve AICHA with incorrect version. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_aicha_incorrect_version() -> None: + """Test retrieve AICHA with incorrect version.""" with pytest.raises(ValueError, match="The parameter `version`"): _retrieve_aicha( - parcellations_dir=tmp_path, + dataset=check_dataset(), version=100, ) @@ -666,7 +597,6 @@ def test_retrieve_aicha_incorrect_version(tmp_path: Path) -> None: ], ) def test_shen( - tmp_path: Path, resolution: float, year: int, n_rois: int, @@ -677,8 +607,6 @@ def test_shen( Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. resolution : float The parametrized resolution values. year : int @@ -696,7 +624,6 @@ def test_shen( img, label, img_path, space = ParcellationRegistry().load( name=f"Shen_{year}_{n_rois}", target_space="MNI152NLin2009cAsym", - parcellations_dir=tmp_path, resolution=resolution, ) assert img is not None @@ -708,34 +635,20 @@ def test_shen( ) -def test_retrieve_shen_incorrect_year(tmp_path: Path) -> None: - """Test retrieve Shen with incorrect year. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ +def test_retrieve_shen_incorrect_year() -> None: + """Test retrieve Shen with incorrect year.""" with pytest.raises(ValueError, match="The parameter `year`"): _retrieve_shen( - parcellations_dir=tmp_path, + dataset=check_dataset(), year=1969, ) -def test_retrieve_shen_incorrect_n_rois(tmp_path: Path) -> None: - """Test retrieve Shen with incorrect ROIs. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_shen_incorrect_n_rois() -> None: + """Test retrieve Shen with incorrect ROIs.""" with pytest.raises(ValueError, match="The parameter `n_rois`"): _retrieve_shen( - parcellations_dir=tmp_path, + dataset=check_dataset(), year=2015, n_rois=10, ) @@ -758,7 +671,6 @@ def test_retrieve_shen_incorrect_n_rois(tmp_path: Path) -> None: ], ) def test_retrieve_shen_incorrect_param_combo( - tmp_path: Path, resolution: float, year: int, n_rois: int, @@ -779,7 +691,7 @@ def test_retrieve_shen_incorrect_param_combo( """ with pytest.raises(ValueError, match="The parameter combination"): _retrieve_shen( - parcellations_dir=tmp_path, + dataset=check_dataset(), resolution=resolution, year=year, n_rois=n_rois, @@ -852,7 +764,6 @@ def test_retrieve_shen_incorrect_param_combo( ], ) def test_yan( - tmp_path: Path, resolution: float, n_rois: int, yeo_networks: int, @@ -862,8 +773,6 @@ def test_yan( Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. resolution : float The parametrized resolution values. n_rois : int @@ -893,7 +802,6 @@ def test_yan( img, label, img_path, space = ParcellationRegistry().load( name=parcellation_name, target_space="MNI152NLin6Asym", - parcellations_dir=tmp_path, resolution=resolution, ) assert img is not None @@ -905,20 +813,13 @@ def test_yan( ) -def test_retrieve_yan_incorrect_networks(tmp_path: Path) -> None: - """Test retrieve Yan with incorrect networks. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_yan_incorrect_networks() -> None: + """Test retrieve Yan with incorrect networks.""" with pytest.raises( ValueError, match="Either one of `yeo_networks` or `kong_networks`" ): _retrieve_yan( - parcellations_dir=tmp_path, + dataset=check_dataset(), n_rois=31418, yeo_networks=100, kong_networks=100, @@ -928,59 +829,38 @@ def test_retrieve_yan_incorrect_networks(tmp_path: Path) -> None: ValueError, match="Either one of `yeo_networks` or `kong_networks`" ): _retrieve_yan( - parcellations_dir=tmp_path, + dataset=check_dataset(), n_rois=31418, yeo_networks=None, kong_networks=None, ) -def test_retrieve_yan_incorrect_n_rois(tmp_path: Path) -> None: - """Test retrieve Yan with incorrect ROIs. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ +def test_retrieve_yan_incorrect_n_rois() -> None: + """Test retrieve Yan with incorrect ROIs.""" with pytest.raises(ValueError, match="The parameter `n_rois`"): _retrieve_yan( - parcellations_dir=tmp_path, + dataset=check_dataset(), n_rois=31418, yeo_networks=7, ) -def test_retrieve_yan_incorrect_yeo_networks(tmp_path: Path) -> None: - """Test retrieve Yan with incorrect Yeo networks. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. - - """ +def test_retrieve_yan_incorrect_yeo_networks() -> None: + """Test retrieve Yan with incorrect Yeo networks.""" with pytest.raises(ValueError, match="The parameter `yeo_networks`"): _retrieve_yan( - parcellations_dir=tmp_path, + dataset=check_dataset(), n_rois=100, yeo_networks=27, ) -def test_retrieve_yan_incorrect_kong_networks(tmp_path: Path) -> None: - """Test retrieve Yan with incorrect Kong networks. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_yan_incorrect_kong_networks() -> None: + """Test retrieve Yan with incorrect Kong networks.""" with pytest.raises(ValueError, match="The parameter `kong_networks`"): _retrieve_yan( - parcellations_dir=tmp_path, + dataset=check_dataset(), n_rois=100, kong_networks=27, ) @@ -1001,7 +881,6 @@ def test_retrieve_yan_incorrect_kong_networks(tmp_path: Path) -> None: ], ) def test_brainnetome( - tmp_path: Path, resolution: float, threshold: int, ) -> None: @@ -1009,8 +888,6 @@ def test_brainnetome( Parameters ---------- - tmp_path : pathlib.Path - The path to the test directory. resolution : float The parametrized resolution values. threshold : int @@ -1030,7 +907,6 @@ def test_brainnetome( img, label, img_path, space = ParcellationRegistry().load( name=parcellation_name, target_space="MNI152NLin6Asym", - parcellations_dir=tmp_path, resolution=resolution, ) assert img is not None @@ -1042,18 +918,11 @@ def test_brainnetome( ) -def test_retrieve_brainnetome_incorrect_threshold(tmp_path: Path) -> None: - """Test retrieve Brainnetome with incorrect threshold. - - Parameters - ---------- - tmp_path : pathlib.Path - The path to the test directory. 
- - """ +def test_retrieve_brainnetome_incorrect_threshold() -> None: + """Test retrieve Brainnetome with incorrect threshold.""" with pytest.raises(ValueError, match="The parameter `threshold`"): _retrieve_brainnetome( - parcellations_dir=tmp_path, + dataset=check_dataset(), threshold=100, ) @@ -1214,7 +1083,7 @@ def test_get_single() -> None: bold_img = bold["data"] # Get tailored parcellation tailored_parcellation, tailored_labels = ParcellationRegistry().get( - parcellations=["TianxS1x3TxMNInonlinear2009cAsym"], + parcellations=["Shen_2015_268"], target_data=bold, ) # Check shape and affine with original element data @@ -1222,9 +1091,9 @@ def test_get_single() -> None: assert_array_equal(tailored_parcellation.affine, bold_img.affine) # Get raw parcellation raw_parcellation, raw_labels, _, _ = ParcellationRegistry().load( - "TianxS1x3TxMNInonlinear2009cAsym", + name="Shen_2015_268", target_space="MNI152NLin2009cAsym", - resolution=1.5, + resolution=4, ) resampled_raw_parcellation = resample_to_img( source_img=raw_parcellation, @@ -1266,7 +1135,9 @@ def test_get_multi_same_space() -> None: ] for name in parcellations_names: img, labels, _, _ = ParcellationRegistry().load( - name=name, target_space="MNI152NLin2009cAsym", resolution=1.5 + name=name, + target_space="MNI152NLin2009cAsym", + resolution=4, ) # Resample raw parcellations resampled_img = resample_to_img( From 91b59d7922a35a57a83f046650a07d4a85c975fb Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:25:14 +0100 Subject: [PATCH 17/22] chore: remove unnecessary Path conversion in MaskRegistry --- junifer/data/masks/_masks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 97539005fb..d68a76ac45 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -401,7 +401,7 @@ def load( # Check if the mask family is custom or built-in mask_img = None if t_family == "CustomUserMask": - 
mask_fname = Path(mask_definition["path"]) + mask_fname = mask_definition["path"] elif t_family == "Callable": mask_img = mask_definition["func"] mask_fname = None From a47b713e3f3e021556e13bd92e972f6c7aece876 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:25:54 +0100 Subject: [PATCH 18/22] chore: improve log messages in template_spaces.py --- junifer/data/template_spaces.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/junifer/data/template_spaces.py b/junifer/data/template_spaces.py index 12d45add22..c51e1719cc 100644 --- a/junifer/data/template_spaces.py +++ b/junifer/data/template_spaces.py @@ -116,7 +116,7 @@ def get_template( logger.info( f"Downloading template {space} ({template_type} in " - f"resolution {resolution}" + f"resolution {resolution})" ) # Retrieve template try: @@ -150,8 +150,10 @@ def get_template( ) except Exception: # noqa: BLE001 raise_error( - f"Template {space} ({template_type}) with resolution {resolution} " - "not found", + msg=( + f"Template {space} ({template_type}) with resolution " + f"{resolution}) not found" + ), klass=RuntimeError, ) else: From 8b98211eec79c5175d4a1e4bb566f13a88c864a2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:26:14 +0100 Subject: [PATCH 19/22] chore: remove unnecessary type check stop --- junifer/data/template_spaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/template_spaces.py b/junifer/data/template_spaces.py index c51e1719cc..f0705a090d 100644 --- a/junifer/data/template_spaces.py +++ b/junifer/data/template_spaces.py @@ -157,4 +157,4 @@ def get_template( klass=RuntimeError, ) else: - return nib.load(template_path) # type: ignore + return nib.load(template_path) From 42c0d17c1e698d3428647eac446ae9b0a493ceca Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Dec 2024 14:28:08 +0100 Subject: [PATCH 20/22] chore: remove httpx from deps --- junifer/cli/tests/test_cli_utils.py | 2 -- 
pyproject.toml | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/junifer/cli/tests/test_cli_utils.py b/junifer/cli/tests/test_cli_utils.py index e08033405a..d8a18da8f4 100644 --- a/junifer/cli/tests/test_cli_utils.py +++ b/junifer/cli/tests/test_cli_utils.py @@ -43,7 +43,6 @@ def test_get_dependency_information_short() -> None: "nilearn", "sqlalchemy", "ruamel.yaml", - "httpx", "tqdm", "templateflow", "lapy", @@ -73,7 +72,6 @@ def test_get_dependency_information_long() -> None: "nilearn", "sqlalchemy", "ruamel.yaml", - "httpx", "tqdm", "templateflow", "lapy", diff --git a/pyproject.toml b/pyproject.toml index 4e9b5fffc6..d0b4f6274e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ "sqlalchemy>=2.0.25,<=2.1.0", "ruamel.yaml>=0.17,<0.19", "h5py>=3.10", - "httpx[http2]>=0.26.0,<0.28.0", "tqdm>=4.66.1,<4.67.0", "templateflow>=23.0.0", "lapy>=1.0.0,<2.0.0", @@ -195,7 +194,7 @@ ignore = [ [tool.ruff.lint.isort] lines-after-imports = 2 known-first-party = ["junifer"] -known-third-party =[ +known-third-party = [ "click", "numpy", "scipy", @@ -206,7 +205,6 @@ known-third-party =[ "sqlalchemy", "yaml", "importlib_metadata", - "httpx", "tqdm", "templateflow", "bct", From f49524153695fd27a3a782c3e66c5f2dcaa9bcf1 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 8 Jan 2025 16:22:35 +0100 Subject: [PATCH 21/22] chore: add changelogs 418.{enh,misc} --- docs/changes/newsfragments/418.enh | 1 + docs/changes/newsfragments/418.misc | 1 + 2 files changed, 2 insertions(+) create mode 100644 docs/changes/newsfragments/418.enh create mode 100644 docs/changes/newsfragments/418.misc diff --git a/docs/changes/newsfragments/418.enh b/docs/changes/newsfragments/418.enh new file mode 100644 index 0000000000..ec0af2631f --- /dev/null +++ b/docs/changes/newsfragments/418.enh @@ -0,0 +1 @@ +Adapt usage of ``junifer-data`` DataLad dataset to fetch parcellations, masks, coordinates and xfms by `Synchon Mandal`_ diff --git 
a/docs/changes/newsfragments/418.misc b/docs/changes/newsfragments/418.misc new file mode 100644 index 0000000000..2e20d03513 --- /dev/null +++ b/docs/changes/newsfragments/418.misc @@ -0,0 +1 @@ +Remove ``httpx`` as a dependency by `Synchon Mandal`_ From ac76b08a48021134ce644b30028f87c6fb9bbd51 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 10 Jan 2025 14:04:26 +0100 Subject: [PATCH 22/22] update: adjust log level in data.utils.fetch_file_via_datalad --- junifer/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/utils.py b/junifer/data/utils.py index 27c05f6400..29f3d0dacd 100644 --- a/junifer/data/utils.py +++ b/junifer/data/utils.py @@ -208,7 +208,7 @@ def fetch_file_via_datalad(dataset: dl.Dataset, file_path: Path) -> Path: logger.info(f"Successfully fetched file: {got_path.resolve()}") return got_path elif status == "notneeded": - logger.info(f"Found existing file: {got_path.resolve()}") + logger.debug(f"Found existing file: {got_path.resolve()}") return got_path else: raise_error(