Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Concatenate and Generic with ParamSpec substitution #489

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 171 additions & 5 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3705,6 +3705,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview))

# Regression test; fixing #126 might cause an error here
with self.assertRaisesRegex(TypeError, "not a generic class"):
Y[int]

def test_protocol_generic_over_typevartuple(self):
Ts = TypeVarTuple("Ts")
T = TypeVar("T")
Expand Down Expand Up @@ -5259,6 +5263,7 @@ class X(Generic[T, P]):
class Y(Protocol[T, P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in X, Y:
with self.subTest(klass=klass.__name__):
G1 = klass[int, P_2]
Expand All @@ -5273,20 +5278,146 @@ class Y(Protocol[T, P]):
self.assertEqual(G3.__args__, (int, Concatenate[int, ...]))
self.assertEqual(G3.__parameters__, ())

with self.assertRaisesRegex(
TypeError,
f"Too few {things} for {klass}"
):
klass[int]

# The following are some valid uses cases in PEP 612 that don't work:
# These do not work in 3.9, _type_check blocks the list and ellipsis.
# G3 = X[int, [int, bool]]
# G4 = X[int, ...]
# G5 = Z[[int, str, bool]]
# Not working because this is special-cased in 3.10.
# G6 = Z[int, str, bool]

def test_single_argument_generic(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
# Note: For 3.10+ __args__ are nested tuples here ((int, ),) instead of (int, )
G6 = klass[int, str, T]
G6args = G6.__args__[0] if sys.version_info >= (3, 10) else G6.__args__
self.assertEqual(G6args, (int, str, T))
self.assertEqual(G6.__parameters__, (T,))

# P = [int]
G7 = klass[int]
G7args = G7.__args__[0] if sys.version_info >= (3, 10) else G7.__args__
self.assertEqual(G7args, (int,))
self.assertEqual(G7.__parameters__, ())

G8 = klass[Concatenate[T, ...]]
self.assertEqual(G8.__args__, (Concatenate[T, ...], ))
self.assertEqual(G8.__parameters__, (T,))

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__args__, (Concatenate[T, P_2], ))

# This is an invalid form but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
G10args = G10.__args__[0] if sys.version_info >= (3, 10) else G10.__args__
self.assertEqual(G10args, (int, Concatenate[str, P], ))

@skipUnless(TYPING_3_10_0, "ParamSpec not present before 3.10")
def test_is_param_expr(self):
P = ParamSpec("P")
P_typing = typing.ParamSpec("P_typing")
self.assertTrue(typing_extensions._is_param_expr(P))
self.assertTrue(typing_extensions._is_param_expr(P_typing))
if hasattr(typing, "_is_param_expr"):
self.assertTrue(typing._is_param_expr(P))
self.assertTrue(typing._is_param_expr(P_typing))

def test_single_argument_generic_with_parameter_expressions(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
G8 = klass[Concatenate[T, ...]]

H8_1 = G8[int]
self.assertEqual(H8_1.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_1[str]

H8_2 = G8[T][int]
self.assertEqual(H8_2.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_2[str]

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__parameters__, (T, P_2))

with self.assertRaisesRegex(TypeError,
"The last parameter to Concatenate should be a ParamSpec variable or ellipsis."
if sys.version_info < (3, 10) else
# from __typing_subst__
"Expected a list of types, an ellipsis, ParamSpec, or Concatenate"
):
G9[int, int]

with self.assertRaisesRegex(TypeError, f"Too few {things}"):
G9[int]

with self.subTest("Check list as parameter expression", klass=klass.__name__):
if sys.version_info < (3, 10):
self.skipTest("Cannot pass non-types")
G5 = klass[[int, str, T]]
self.assertEqual(G5.__parameters__, (T,))
self.assertEqual(G5.__args__, ((int, str, T),))

H9 = G9[int, [T]]
self.assertEqual(H9.__parameters__, (T,))

# This is an invalid parameter expression but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
with self.subTest("Check invalid form substitution"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a valid parameter expression, it is nice for debugging though, should keep or remove it?

self.assertEqual(G10.__parameters__, (P, ))
if sys.version_info < (3, 9):
self.skipTest("3.8 typing._type_subst does not support this substitution process")
H10 = G10[int]
if (3, 10) <= sys.version_info < (3, 11, 3):
self.skipTest("3.10-3.11.2 does not substitute Concatenate here")
self.assertEqual(H10.__parameters__, ())
H10args = H10.__args__[0] if sys.version_info >= (3, 10) else H10.__args__
self.assertEqual(H10args, (int, (str, int)))

@skipUnless(TYPING_3_10_0, "ParamSpec not present before 3.10")
def test_substitution_with_typing_variants(self):
# verifies substitution and typing._check_generic working with typing variants
P = ParamSpec("P")
typing_P = typing.ParamSpec("typing_P")
typing_Concatenate = typing.Concatenate[int, P]

class Z(Generic[typing_P]):
pass

P1 = Z[typing_P]
self.assertEqual(P1.__parameters__, (typing_P,))
self.assertEqual(P1.__args__, (typing_P,))

C1 = Z[typing_Concatenate]
self.assertEqual(C1.__parameters__, (P,))
self.assertEqual(C1.__args__, (typing_Concatenate,))

def test_pickle(self):
global P, P_co, P_contra, P_default
P = ParamSpec('P')
Expand Down Expand Up @@ -5468,6 +5599,43 @@ def test_eq(self):
self.assertEqual(hash(C4), hash(C5))
self.assertNotEqual(C4, C6)

def test_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Ts = TypeVarTuple("Ts")

C1 = Concatenate[str, T, ...]
self.assertEqual(C1[int], Concatenate[str, int, ...])

C2 = Concatenate[str, P]
self.assertEqual(C2[...], Concatenate[str, ...])
self.assertEqual(C2[int], (str, int))
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]
self.assertEqual(C2[U1], (str, int, str))
self.assertEqual(C2[U2], (str, Unpack[Ts]))
self.assertEqual(C2["U2"], (str, typing.ForwardRef("U2")))

if (3, 12, 0) <= sys.version_info < (3, 12, 4):
with self.assertRaises(AssertionError):
C2[Unpack[U2]]
else:
with self.assertRaisesRegex(TypeError, "must be used with a tuple type"):
C2[Unpack[U2]]

C3 = Concatenate[str, T, P]
self.assertEqual(C3[int, [bool]], (str, int, bool))

@skipUnless(TYPING_3_10_0, "Concatenate not present before 3.10")
def test_is_param_expr(self):
P = ParamSpec('P')
concat = Concatenate[str, P]
typing_concat = typing.Concatenate[str, P]
self.assertTrue(typing_extensions._is_param_expr(concat))
self.assertTrue(typing_extensions._is_param_expr(typing_concat))
if hasattr(typing, "_is_param_expr"):
self.assertTrue(typing._is_param_expr(concat))
self.assertTrue(typing._is_param_expr(typing_concat))

class TypeGuardTests(BaseTestCase):
def test_basics(self):
Expand Down Expand Up @@ -7465,11 +7633,9 @@ def test_callable_with_concatenate(self):
self.assertEqual(callable_concat.__parameters__, (P2,))
concat_usage = callable_concat[str]
with self.subTest("get_args of Concatenate in TypeAliasType"):
if not TYPING_3_9_0:
if not TYPING_3_10_0:
# args are: ([<class 'int'>, ~P2],)
self.skipTest("Nested ParamSpec is not substituted")
if sys.version_info < (3, 10, 2):
self.skipTest("GenericAlias keeps Concatenate in __args__ prior to 3.10.2")
self.assertEqual(get_args(concat_usage), ((int, str),))
with self.subTest("Equality of parameter_expression without []"):
if not TYPING_3_10_0:
Expand Down
Loading
Loading