Skip to content

Commit

Permalink
Merge pull request #85 from ArcanaFramework/typed-collection
Browse files Browse the repository at this point in the history
Added `TypedCollection` base class
  • Loading branch information
tclose authored Sep 20, 2024
2 parents f5bf5e6 + 233d079 commit 542c75c
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 109 deletions.
3 changes: 2 additions & 1 deletion fileformats/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._version import __version__
from .classifier import Classifier
from .datatype import DataType
from .fileset import FileSet, MockMixin
from .mock import MockMixin
from .fileset import FileSet
from .field import Field
from .identification import (
to_mime,
Expand Down
69 changes: 69 additions & 0 deletions fileformats/core/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import typing as ty
from pathlib import Path
from abc import ABCMeta, abstractproperty
from fileformats.core import FileSet, validated_property, mtime_cached_property
from fileformats.core.decorators import classproperty
from fileformats.core.exceptions import FormatMismatchError
from fileformats.core.utils import get_optional_type


class TypedCollection(FileSet, metaclass=ABCMeta):
"""Base class for collections of files-sets of specific types either in a directory
or a collection of file paths"""

content_types: ty.Tuple[
ty.Union[ty.Type[FileSet], ty.Type[ty.Optional[FileSet]]], ...
] = ()

@abstractproperty
def content_fspaths(self) -> ty.Iterable[Path]:
... # noqa: E704

@mtime_cached_property
def contents(self) -> ty.List[FileSet]:
contnts = []
for content_type in self.potential_content_types:
assert content_type
for p in self.content_fspaths:
try:
contnts.append(content_type([p], **self._load_kwargs))
except FormatMismatchError:
continue
return contnts

@validated_property
def _validate_required_content_types(self) -> None:
not_found = set(self.required_content_types)
if not not_found:
return
for fspath in self.content_fspaths:
for content_type in list(not_found):
if content_type.matches(fspath):
not_found.remove(content_type)
if not not_found:
return
assert not_found
raise FormatMismatchError(
f"Did not find the required content types, {not_found}, in {self}"
)

@classproperty
def potential_content_types(cls) -> ty.Tuple[ty.Type[FileSet], ...]:
content_types: ty.List[ty.Type[FileSet]] = []
for content_type in cls.content_types: # type: ignore[assignment]
content_types.append(get_optional_type(content_type)) # type: ignore[arg-type]
return tuple(content_types)

@classproperty
def required_content_types(cls) -> ty.Tuple[ty.Type[FileSet], ...]:
content_types: ty.List[ty.Type[FileSet]] = []
for content_type in cls.content_types: # type: ignore[assignment]
if ty.get_origin(content_type) is None:
content_types.append(content_type) # type: ignore[arg-type]
return tuple(content_types)

@classproperty
def unconstrained(cls) -> bool:
"""Whether the file-format is unconstrained by extension, magic number or another
constraint"""
return super().unconstrained and not cls.content_types
42 changes: 1 addition & 41 deletions fileformats/core/fileset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .datatype import DataType
from .extras import extra
from .fs_mount_identifier import FsMountIdentifier
from .mock import MockMixin

if ty.TYPE_CHECKING:
from pydra.engine.task import TaskBase
Expand Down Expand Up @@ -1742,44 +1743,3 @@ def _new_copy_path(
_formats_by_name: ty.Optional[ty.Dict[str, ty.Set[ty.Type["FileSet"]]]] = None
_required_props: ty.Optional[ty.Tuple[str, ...]] = None
_valid_class: ty.Optional[bool] = None


class MockMixin:
"""Strips out validation methods of a class, allowing it to be mocked in a way that
still satisfies type-checking"""

def __init__(
self,
fspaths: FspathsInputType,
metadata: ty.Union[ty.Dict[str, ty.Any], bool, None] = False,
):
self.fspaths = fspaths_converter(fspaths)
self._metadata = metadata

@classproperty
def type_name(cls) -> str:
return cls.mocked.type_name

def __bytes_repr__(self, cache: ty.Dict[str, ty.Any]) -> ty.Iterable[bytes]:
yield from (str(fspath).encode() for fspath in self.fspaths)

@classproperty
def mocked(cls) -> FileSet:
"""The "true" class that the mocked class is based on"""
return next(c for c in cls.__mro__ if not issubclass(c, MockMixin)) # type: ignore[no-any-return, attr-defined]

@classproperty
def namespace(cls) -> str:
"""The "namespace" the format belongs to under the "fileformats" umbrella
namespace"""
mro: ty.Tuple[ty.Type] = cls.__mro__ # type: ignore
for base in mro:
if issubclass(base, MockMixin):
continue
try:
return base.namespace # type: ignore
except FormatDefinitionError:
pass
raise FormatDefinitionError(
f"None of of the bases classes of {cls} ({mro}) have a valid namespace"
)
34 changes: 25 additions & 9 deletions fileformats/core/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from .datatype import DataType
import fileformats.core
from .utils import describe_task, matching_source
from .utils import describe_task, matching_source, get_optional_type
from .decorators import validated_property, classproperty
from .identification import to_mime_format_name
from .converter_helpers import SubtypeVar, ConverterSpec
Expand Down Expand Up @@ -292,6 +292,7 @@ def my_func(file: MyFormatWithClassifiers[Integer]):
# Default values for class attrs
multiple_classifiers = True
allowed_classifiers: ty.Optional[ty.Tuple[ty.Type[Classifier], ...]] = None
allow_optional_classifiers = False
exclusive_classifiers: ty.Tuple[ty.Type[Classifier], ...] = ()
ordered_classifiers = False
generically_classifiable = False
Expand Down Expand Up @@ -320,7 +321,9 @@ def wildcard_classifiers(
) -> ty.FrozenSet[ty.Type[SubtypeVar]]:
if classifiers is None:
classifiers = cls.classifiers if cls.is_classified else ()
return frozenset(t for t in classifiers if issubclass(t, SubtypeVar))
return frozenset(
t for t in classifiers if issubclass(get_optional_type(t), SubtypeVar) # type: ignore[misc]
)

@classmethod
def non_wildcard_classifiers(
Expand All @@ -329,7 +332,9 @@ def non_wildcard_classifiers(
if classifiers is None:
classifiers = cls.classifiers if cls.is_classified else ()
assert classifiers is not None
return frozenset(q for q in classifiers if not issubclass(q, SubtypeVar))
return frozenset(
q for q in classifiers if not issubclass(get_optional_type(q), SubtypeVar)
)

@classmethod
def __class_getitem__(
Expand All @@ -341,11 +346,15 @@ def __class_getitem__(
classifiers_tuple = tuple(classifiers)
else:
classifiers_tuple = (classifiers,)
classifiers_to_check = tuple(
get_optional_type(c, cls.allow_optional_classifiers)
for c in classifiers_tuple
)

if cls.allowed_classifiers:
not_allowed = [
q
for q in classifiers_tuple
for q in classifiers_to_check
if not any(issubclass(q, t) for t in cls.allowed_classifiers)
]
if not_allowed:
Expand All @@ -357,15 +366,17 @@ def __class_getitem__(
if cls.multiple_classifiers:
if not cls.ordered_classifiers:
# Check for duplicate classifiers in the multiple list
if len(classifiers_tuple) > 1:
if len(classifiers_to_check) > 1:
# Sort the classifiers into categories and ensure that there aren't more
# than one type for each category. Otherwise, if the classifier doesn't
# belong to a category, check to see that there aren't multiple sub-classes
# in the classifier set
repetitions: ty.Dict[
ty.Type[Classifier], ty.List[ty.Type[Classifier]]
] = {c: [] for c in cls.exclusive_classifiers + classifiers_tuple}
for classifier in classifiers_tuple:
] = {
c: [] for c in cls.exclusive_classifiers + classifiers_to_check
}
for classifier in classifiers_to_check:
for exc_classifier in repetitions:
if issubclass(classifier, exc_classifier):
repetitions[exc_classifier].append(classifier)
Expand All @@ -381,7 +392,10 @@ def __class_getitem__(
)
)
classifiers_tuple = tuple(
sorted(set(classifiers_tuple), key=lambda x: x.__name__)
sorted(
set(classifiers_tuple),
key=lambda x: get_optional_type(x).__name__,
)
)
else:
if len(classifiers_tuple) > 1:
Expand Down Expand Up @@ -428,7 +442,9 @@ def __class_getitem__(
class_attrs[cls.classifiers_attr_name] = (
classifiers_tuple if cls.multiple_classifiers else classifiers_tuple[0]
)
classifier_names = [t.__name__ for t in classifiers_tuple]
classifier_names = [
get_optional_type(t).__name__ for t in classifiers_tuple
]
if not cls.ordered_classifiers:
classifier_names.sort()
classified = type(
Expand Down
53 changes: 53 additions & 0 deletions fileformats/core/mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import typing as ty
from .utils import (
fspaths_converter,
)
from .decorators import classproperty
from .typing import FspathsInputType
from .exceptions import (
FormatDefinitionError,
)

if ty.TYPE_CHECKING:
from .fileset import FileSet


class MockMixin:
"""Strips out validation methods of a class, allowing it to be mocked in a way that
still satisfies type-checking"""

def __init__(
self,
fspaths: FspathsInputType,
metadata: ty.Union[ty.Dict[str, ty.Any], bool, None] = False,
):
self.fspaths = fspaths_converter(fspaths)
self._metadata = metadata

@classproperty
def type_name(cls) -> str:
return cls.mocked.type_name

def __bytes_repr__(self, cache: ty.Dict[str, ty.Any]) -> ty.Iterable[bytes]:
yield from (str(fspath).encode() for fspath in self.fspaths)

@classproperty
def mocked(cls) -> "FileSet":
"""The "true" class that the mocked class is based on"""
return next(c for c in cls.__mro__ if not issubclass(c, MockMixin)) # type: ignore[no-any-return, attr-defined]

@classproperty
def namespace(cls) -> str:
"""The "namespace" the format belongs to under the "fileformats" umbrella
namespace"""
mro: ty.Tuple[ty.Type] = cls.__mro__ # type: ignore
for base in mro:
if issubclass(base, MockMixin):
continue
try:
return base.namespace # type: ignore
except FormatDefinitionError:
pass
raise FormatDefinitionError(
f"None of of the bases classes of {cls} ({mro}) have a valid namespace"
)
37 changes: 37 additions & 0 deletions fileformats/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from contextlib import contextmanager
from .typing import FspathsInputType
import fileformats.core
from fileformats.core.exceptions import FormatDefinitionError

if ty.TYPE_CHECKING:
import pydra.engine.core
Expand Down Expand Up @@ -228,3 +229,39 @@ def import_extras_module(klass: ty.Type["fileformats.core.DataType"]) -> ExtrasM
else:
extras_imported = True
return ExtrasModule(extras_imported, extras_pkg, extras_pypi)


TypeType = ty.TypeVar("TypeType", bound=ty.Type[ty.Any])


def get_optional_type(
type_: ty.Union[TypeType, ty.Type[ty.Optional[TypeType]]], allowed: bool = True
) -> TypeType:
"""Checks if a type is an Optional type
Parameters
----------
type_ : ty.Type
the type to check
allowed : bool
whether Optional types are allowed or not
Returns
-------
bool
whether the type is an Optional type or not
"""
if ty.get_origin(type_) is None:
return type_ # type: ignore[return-value]
if not allowed:
raise FormatDefinitionError(
f"Optional types are not allowed in content_type definitions ({type_}) "
"in this context"
)
args = ty.get_args(type_)
if len(args) != 2 and None in ty.get_args(type_):
raise FormatDefinitionError(
"Only Optional types are allowed in content_type definitions, "
f"not {type_}"
)
return args[0] if args[0] is not None else args[1] # type: ignore[no-any-return]
Loading

0 comments on commit 542c75c

Please sign in to comment.