From 9a8cd5f50f7c84b5f3d144d0a03b084454a7cd1e Mon Sep 17 00:00:00 2001 From: KotlinIsland Date: Fri, 22 Dec 2023 04:25:27 +1000 Subject: [PATCH] generic function --- basedtyping/__init__.py | 67 ++++++++++++++++++++++++++++++++++++++- tests/test_generic.py | 35 ++++++++++++++++++++ tests/test_generic_312.py | 0 3 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 tests/test_generic.py create mode 100644 tests/test_generic_312.py diff --git a/basedtyping/__init__.py b/basedtyping/__init__.py index abc3146..600512c 100644 --- a/basedtyping/__init__.py +++ b/basedtyping/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations import sys +from dataclasses import dataclass from typing import ( # type: ignore[attr-defined] TYPE_CHECKING, Any, @@ -53,6 +54,7 @@ "Untyped", "Intersection", "TypeForm", + "generic", ) if TYPE_CHECKING: @@ -508,7 +510,9 @@ def __reduce__(self) -> (object, object): if sys.version_info > (3, 9): @_BasedSpecialForm - def Intersection(self: _BasedSpecialForm, parameters: object) -> object: # noqa: N802 + def Intersection( # noqa: N802 + self: _BasedSpecialForm, parameters: object + ) -> object: """Intersection type; Intersection[X, Y] means both X and Y. To define an intersection: @@ -574,3 +578,64 @@ def f[T](t: TypeForm[T]) -> T: ... reveal_type(f(int | str)) # int | str """ ) + + +@dataclass +class _BaseGenericFunction(Generic[P, T]): + fn: Callable[P, T] + + +@dataclass +class _GenericFunction(_BaseGenericFunction[P, T]): + # TODO: make this an TypeVarTuple when mypy supports it + # https://github.com/python/mypy/issues/16696 + __type_params__: tuple[object, ...] | None = None + """Generic type parameters. Currently unused""" + + def __getitem__(self, items: object) -> _ConcreteFunction[P, T]: + items = items if isinstance(items, tuple) else (items,) + return _ConcreteFunction(self.fn, items) + + +@dataclass +class _ConcreteFunction(_BaseGenericFunction[P, T]): + __type_args__: tuple[object, ...] | None = None + """Concrete type parameters. Currently unused""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self.fn(*args, **kwargs) + + +class _GenericFunctionFacilitator: + __type_params__: tuple[object, ...] | None = None + args: tuple[object, ...] + + def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]: + return _GenericFunction(fn, self.args) + + +class _GenericFunctionDecorator: + """Decorate a function to allow supplying type parameters on calls: + + @generic[T] + def f1(t: T): ... + + f1[int](1) + + @generic + def f2[T](t: T): ... + + f2[int](1) + """ + + def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]: + params = cast(Union[Tuple[object, ...], None], getattr(fn, "__type_params__", None)) + return _GenericFunction(fn, params) + + def __getitem__(self, items: object) -> _GenericFunctionFacilitator: + result = _GenericFunctionFacilitator() + result.args = items if isinstance(items, tuple) else (items,) + return result + + +generic = _GenericFunctionDecorator() diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000..34333db --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import sys +from typing import Callable, cast + +import pytest + +from basedtyping import T, generic + + +def test_generic_with_args(): + deco = generic[T] + + @deco # Python version 3.8 does not support arbitrary expressions as a decorator + def f(t: T) -> T: + return t + + assert f.__type_params__ == (T,) + assert f[int].__type_args__ == (int,) + assert f[object](1) == 1 + + +def test_generic_without_args(): + # not using a decorator because of mypy + if sys.version_info < (3, 12): + pytest.skip(reason="Needs generic syntax support") + local: dict[str, object] = {} + # Can't use the actual function because then <3.12 wouldn't load + exec("def f[T](t: T) -> T: return t", None, local) + _f = cast(Callable[[object], object], local["f"]) + f = generic(_f) + + assert f.__type_params__ == _f.__type_params__ # type: ignore[attr-defined, unused-ignore] + assert f[int].__type_args__ == (int,) + assert f[int](1) == 1 diff --git a/tests/test_generic_312.py b/tests/test_generic_312.py new file mode 100644 index 0000000..e69de29