Skip to content

Commit

Permalink
Simplify implementation of on-disk resources (#35)
Browse files Browse the repository at this point in the history
I had factored out the repetitive pattern that connected
on-disk and in-memory implementations, but mypy was very
difficult to please. Now I have removed the decorator
and allowed the repetition.
This makes it easier to read and maintain so meh
  • Loading branch information
ebrahimebrahim committed Aug 17, 2024
1 parent f78088a commit c038b90
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 53 deletions.
70 changes: 18 additions & 52 deletions src/abcdmicro/io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, TypeVar, get_type_hints
from typing import Any

import itk
import numpy as np
from dipy.io.gradients import read_bvals_bvecs
from numpy.typing import NDArray

from abcdmicro.resource import (
BvalResource,
Expand All @@ -17,54 +18,9 @@
VolumeResource,
)

T = TypeVar("T", bound="LoadableResource")


class LoadableResource(ABC):
"""Base class for on-disk resources that have a load method that converts them into in-memory resources"""

@abstractmethod
def load(self) -> Any:
"""Load this resource to get an in-memory version of it."""


def implement_via_loading(method_names: list[str]) -> Callable[[type[T]], type[T]]:
"""Decorator that implements the listed abstract methods of a LoadableResource class by calling the
load() method and then using the loaded object's method of the same name."""

def implement_via_loading_decorator(cls: type[T]) -> type[T]:
for method_name in method_names:

def method(self, method_name=method_name): # type: ignore[no-untyped-def]
return getattr(self.load(), method_name)()

method.__name__ = method_name
method.__doc__ = f"Automatically implemented method that returns `self.load().{method_name}()`."

for parent_class in cls.__bases__:
if hasattr(parent_class, method_name):
parent_method = getattr(parent_class, method_name)
return_type = get_type_hints(parent_method).get("return", Any)
method.__annotations__ = {"return": return_type}
break

setattr(cls, method_name, method)

# If the automatically implemented methods were abstract methods, then remove them from the set
# to indicate that they have been implemented.
if hasattr(cls, "__abstractmethods__"):
cls.__abstractmethods__ = frozenset(
name for name in cls.__abstractmethods__ if name not in method_names
)

return cls

return implement_via_loading_decorator


@implement_via_loading(["get_array", "get_metadata"])
@dataclass
class NiftiVolumeResrouce(VolumeResource): # type: ignore[type-var]
class NiftiVolumeResrouce(VolumeResource):
"""A volume or volume stack that is saved to disk in the nifti file format."""

path: Path
Expand All @@ -73,10 +29,15 @@ class NiftiVolumeResrouce(VolumeResource): # type: ignore[type-var]
def load(self) -> InMemoryVolumeResource:
return InMemoryVolumeResource(itk.imread(self.path))

def get_array(self) -> NDArray[Any]:
return self.load().get_array()

def get_metadata(self) -> dict[Any, Any]:
return self.load().get_metadata()


@implement_via_loading(["get"])
@dataclass
class FslBvalResource(BvalResource, LoadableResource):
class FslBvalResource(BvalResource):
"""A b-value list that is saved to disk in the FSL text file format."""

path: Path
Expand All @@ -86,10 +47,12 @@ def load(self) -> InMemoryBvalResource:
bvals_array, _ = read_bvals_bvecs(self.path, None)
return InMemoryBvalResource(bvals_array)

def get(self) -> NDArray[np.floating]:
return self.load().get()


@implement_via_loading(["get"])
@dataclass
class FslBvecResource(BvecResource, LoadableResource):
class FslBvecResource(BvecResource):
"""A b-vector list that is saved to disk in the FSL text file format."""

path: Path
Expand All @@ -98,3 +61,6 @@ class FslBvecResource(BvecResource, LoadableResource):
def load(self) -> InMemoryBvecResource:
_, bvecs_array = read_bvals_bvecs(None, self.path)
return InMemoryBvecResource(bvecs_array)

def get(self) -> NDArray[np.floating]:
return self.load().get()
2 changes: 1 addition & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def test_nifti_volume_resource(volume_array):
]
),
)
volume_resource = NiftiVolumeResrouce(path=volume_file) # type: ignore[abstract]
volume_resource = NiftiVolumeResrouce(path=volume_file)
assert np.allclose(volume_resource.get_array(), volume_array)

0 comments on commit c038b90

Please sign in to comment.