Skip to content

Commit

Permalink
generic function
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Jun 4, 2024
1 parent f3c2a9b commit 9a8cd5f
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
67 changes: 66 additions & 1 deletion basedtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from typing import ( # type: ignore[attr-defined]
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -53,6 +54,7 @@
"Untyped",
"Intersection",
"TypeForm",
"generic",
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -508,7 +510,9 @@ def __reduce__(self) -> (object, object):
if sys.version_info > (3, 9):

@_BasedSpecialForm
def Intersection(self: _BasedSpecialForm, parameters: object) -> object: # noqa: N802
def Intersection( # noqa: N802
self: _BasedSpecialForm, parameters: object
) -> object:
"""Intersection type; Intersection[X, Y] means both X and Y.
To define an intersection:
Expand Down Expand Up @@ -574,3 +578,64 @@ def f[T](t: TypeForm[T]) -> T: ...
reveal_type(f(int | str)) # int | str
"""
)


@dataclass
class _BaseGenericFunction(Generic[P, T]):
fn: Callable[P, T]


@dataclass
class _GenericFunction(_BaseGenericFunction[P, T]):
# TODO: make this an TypeVarTuple when mypy supports it
# https://github.com/python/mypy/issues/16696
__type_params__: tuple[object, ...] | None = None
"""Generic type parameters. Currently unused"""

def __getitem__(self, items: object) -> _ConcreteFunction[P, T]:
items = items if isinstance(items, tuple) else (items,)
return _ConcreteFunction(self.fn, items)


@dataclass
class _ConcreteFunction(_BaseGenericFunction[P, T]):
__type_args__: tuple[object, ...] | None = None
"""Concrete type parameters. Currently unused"""

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.fn(*args, **kwargs)


class _GenericFunctionFacilitator:
__type_params__: tuple[object, ...] | None = None
args: tuple[object, ...]

def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]:
return _GenericFunction(fn, self.args)


class _GenericFunctionDecorator:
"""Decorate a function to allow supplying type parameters on calls:
@generic[T]
def f1(t: T): ...
f1[int](1)
@generic
def f2[T](t: T): ...
f2[int](1)
"""

def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]:
params = cast(Union[Tuple[object, ...], None], getattr(fn, "__type_params__", None))
return _GenericFunction(fn, params)

def __getitem__(self, items: object) -> _GenericFunctionFacilitator:
result = _GenericFunctionFacilitator()
result.args = items if isinstance(items, tuple) else (items,)
return result


generic = _GenericFunctionDecorator()
35 changes: 35 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import sys
from typing import Callable, cast

import pytest

from basedtyping import T, generic


def test_generic_with_args():
deco = generic[T]

@deco # Python version 3.8 does not support arbitrary expressions as a decorator
def f(t: T) -> T:
return t

assert f.__type_params__ == (T,)
assert f[int].__type_args__ == (int,)
assert f[object](1) == 1


def test_generic_without_args():
# not using a decorator because of mypy
if sys.version_info < (3, 12):
pytest.skip(reason="Needs generic syntax support")
local: dict[str, object] = {}
# Can't use the actual function because then <3.12 wouldn't load
exec("def f[T](t: T) -> T: return t", None, local)
_f = cast(Callable[[object], object], local["f"])
f = generic(_f)

assert f.__type_params__ == _f.__type_params__ # type: ignore[attr-defined, unused-ignore]
assert f[int].__type_args__ == (int,)
assert f[int](1) == 1
Empty file added tests/test_generic_312.py
Empty file.

0 comments on commit 9a8cd5f

Please sign in to comment.