diff --git a/tests/test_type_view.py b/tests/test_type_view.py index 6e95059..a2e845d 100644 --- a/tests/test_type_view.py +++ b/tests/test_type_view.py @@ -278,3 +278,12 @@ def test_parsed_type_equality() -> None: assert TypeView(list[int]) != TypeView(list[str]) assert TypeView(list[str]) != TypeView(tuple[str]) assert TypeView(Optional[str]) == TypeView(Union[str, None]) + + +def test_unwrap_optional() -> None: + # Non-optionals should return the original input + assert TypeView(int).unwrap_optional() == TypeView(int) + + assert TypeView(Optional[int]).unwrap_optional() == TypeView(int) + assert TypeView(Optional[Union[str, int]]).unwrap_optional() == TypeView(Union[str, int]) + assert TypeView(Union[str, int, None]).unwrap_optional() == TypeView(Union[str, int]) diff --git a/type_lens/type_view.py b/type_lens/type_view.py index 868a93b..60ff6ec 100644 --- a/type_lens/type_view.py +++ b/type_lens/type_view.py @@ -2,7 +2,7 @@ from collections import abc from collections.abc import Collection, Mapping -from typing import Annotated, Any, AnyStr, Final, ForwardRef, TypeVar +from typing import Annotated, Any, AnyStr, Final, ForwardRef, TypeVar, Union from typing_extensions import NotRequired, Required, get_args, get_origin @@ -137,3 +137,14 @@ def has_inner_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: Whether any of the type's generic args are a subclass of the given type. """ return any(t.is_subclass_of(cl) for t in self.inner_types) + + def unwrap_optional(self) -> TypeView: + if not self.is_optional: + return self + + if len(self.args) == 2: + return TypeView(self.args[0]) + + args = tuple([a for a in self.args if a is not NoneType]) + non_optional = Union[args] + return TypeView(non_optional)