From f220b34af4b771636b022b1ecebfbd6ff809bf95 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 3 Dec 2024 15:44:28 +0100 Subject: [PATCH] fix | in sequences --- src/gt4py/eve/datamodels/core.py | 10 ++++------ src/gt4py/eve/type_validation.py | 9 +++++++++ tests/eve_tests/unit_tests/test_datamodels.py | 1 + 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index a45fbb821e..af80973fcd 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -1041,11 +1041,6 @@ def _make_datamodel( for key in annotations: type_hint = annotations[key] = resolved_annotations[key] - if isinstance( - type_hint, types.UnionType - ): # see https://github.com/python/cpython/issues/105499 - type_hint = typing.Union[type_hint.__args__] - # Skip members annotated as class variables if type_hint is ClassVar or xtyping.get_origin(type_hint) is ClassVar: continue @@ -1260,8 +1255,11 @@ def _make_concrete_with_cache( if not is_generic_datamodel_class(datamodel_cls): raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: + _accepted_types: tuple[type, ...] = (type, type(None), xtyping.StdGenericAliasType) + if sys.version_info >= (3, 10): + _accepted_types = (*_accepted_types, types.UnionType) if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType, types.UnionType)) + isinstance(t, _accepted_types) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..7464db1b4e 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,6 +14,8 @@ import collections.abc import dataclasses import functools +import sys +import types import typing from . import exceptions, extended_typing as xtyping, utils @@ -193,6 +195,12 @@ def __call__( if type_annotation is None: type_annotation = type(None) + if sys.version_info >= (3, 10): + if isinstance( + type_annotation, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_annotation = typing.Union[type_annotation.__args__] + # Non-generic types if xtyping.is_actual_type(type_annotation): assert not xtyping.get_args(type_annotation) @@ -277,6 +285,7 @@ def __call__( if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)): assert len(type_args) == 1 + make_recursive(type_args[0]) if (member_validator := make_recursive(type_args[0])) is None: raise exceptions.EveValueError( f"{type_args[0]} type annotation is not supported." diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 7f523df6cf..d826d7a02f 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -556,6 +556,7 @@ class WrongModel: ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), ("int | float | str", [1, 3.0, "one"], [[1], [], 1j]), + ("typing.List[int|float]", [[1, 2.0], []], [1, 2.0, [1, "2.0"]]), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]",