diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 3471f0a3..18547a7e 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -3705,6 +3705,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... self.assertEqual(Y.__parameters__, ()) self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview)) + # Regression test; fixing #126 might cause an error here + with self.assertRaisesRegex(TypeError, "not a generic class"): + Y[int] + def test_protocol_generic_over_typevartuple(self): Ts = TypeVarTuple("Ts") T = TypeVar("T") @@ -5259,6 +5263,7 @@ class X(Generic[T, P]): class Y(Protocol[T, P]): pass + things = "arguments" if sys.version_info >= (3, 10) else "parameters" for klass in X, Y: with self.subTest(klass=klass.__name__): G1 = klass[int, P_2] @@ -5273,13 +5278,59 @@ class Y(Protocol[T, P]): self.assertEqual(G3.__args__, (int, Concatenate[int, ...])) self.assertEqual(G3.__parameters__, ()) + with self.assertRaisesRegex( + TypeError, + f"Too few {things} for {klass}" + ): + klass[int] + # The following are some valid uses cases in PEP 612 that don't work: # These do not work in 3.9, _type_check blocks the list and ellipsis. # G3 = X[int, [int, bool]] # G4 = X[int, ...] # G5 = Z[[int, str, bool]] - # Not working because this is special-cased in 3.10. - # G6 = Z[int, str, bool] + + def test_single_argument_generic(self): + P = ParamSpec("P") + T = TypeVar("T") + P_2 = ParamSpec("P_2") + + class Z(Generic[P]): + pass + + class ProtoZ(Protocol[P]): + pass + + for klass in Z, ProtoZ: + with self.subTest(klass=klass.__name__): + # Note: For 3.10+ __args__ are nested tuples here ((int, ),) instead of (int, ) + G6 = klass[int, str, T] + G6args = G6.__args__[0] if sys.version_info >= (3, 10) else G6.__args__ + self.assertEqual(G6args, (int, str, T)) + self.assertEqual(G6.__parameters__, (T,)) + + # P = [int] + G7 = klass[int] + G7args = G7.__args__[0] if sys.version_info >= (3, 10) else G7.__args__ + self.assertEqual(G7args, (int,)) + self.assertEqual(G7.__parameters__, ()) + + G8 = klass[Concatenate[T, ...]] + self.assertEqual(G8.__args__, (Concatenate[T, ...], )) + self.assertEqual(G8.__parameters__, (T,)) + + G9 = klass[Concatenate[T, P_2]] + self.assertEqual(G9.__args__, (Concatenate[T, P_2], )) + + # This is an invalid form but useful for testing correct subsitution + G10 = klass[int, Concatenate[str, P]] + G10args = G10.__args__[0] if sys.version_info >= (3, 10) else G10.__args__ + self.assertEqual(G10args, (int, Concatenate[str, P], )) + + def test_single_argument_generic_with_parameter_expressions(self): + P = ParamSpec("P") + T = TypeVar("T") + P_2 = ParamSpec("P_2") class Z(Generic[P]): pass @@ -5287,6 +5338,58 @@ class Z(Generic[P]): class ProtoZ(Protocol[P]): pass + things = "arguments" if sys.version_info >= (3, 10) else "parameters" + for klass in Z, ProtoZ: + with self.subTest(klass=klass.__name__): + G8 = klass[Concatenate[T, ...]] + + H8_1 = G8[int] + self.assertEqual(H8_1.__parameters__, ()) + with self.assertRaisesRegex(TypeError, "not a generic class"): + H8_1[str] + + H8_2 = G8[T][int] + self.assertEqual(H8_2.__parameters__, ()) + with self.assertRaisesRegex(TypeError, "not a generic class"): + H8_2[str] + + G9 = klass[Concatenate[T, P_2]] + self.assertEqual(G9.__parameters__, (T, P_2)) + + with self.assertRaisesRegex(TypeError, + "The last parameter to Concatenate should be a ParamSpec variable or ellipsis." + if sys.version_info < (3, 10) else + # from __typing_subst__ + "Expected a list of types, an ellipsis, ParamSpec, or Concatenate" + ): + G9[int, int] + + with self.assertRaisesRegex(TypeError, f"Too few {things}"): + G9[int] + + with self.subTest("Check list as parameter expression", klass=klass.__name__): + if sys.version_info < (3, 10): + self.skipTest("Cannot pass non-types") + G5 = klass[[int, str, T]] + self.assertEqual(G5.__parameters__, (T,)) + self.assertEqual(G5.__args__, ((int, str, T),)) + + H9 = G9[int, [T]] + self.assertEqual(H9.__parameters__, (T,)) + + # This is an invalid parameter expression but useful for testing correct subsitution + G10 = klass[int, Concatenate[str, P]] + with self.subTest("Check invalid form substitution"): + self.assertEqual(G10.__parameters__, (P, )) + if sys.version_info < (3, 9): + self.skipTest("3.8 typing._type_subst does not support this substitution process") + H10 = G10[int] + if (3, 10) <= sys.version_info < (3, 11, 3): + self.skipTest("3.10-3.11.2 does not substitute Concatenate here") + self.assertEqual(H10.__parameters__, ()) + H10args = H10.__args__[0] if sys.version_info >= (3, 10) else H10.__args__ + self.assertEqual(H10args, (int, (str, int))) + def test_pickle(self): global P, P_co, P_contra, P_default P = ParamSpec('P') @@ -5468,6 +5571,32 @@ def test_eq(self): self.assertEqual(hash(C4), hash(C5)) self.assertNotEqual(C4, C6) + def test_substitution(self): + T = TypeVar('T') + P = ParamSpec('P') + Ts = TypeVarTuple("Ts") + + C1 = Concatenate[str, T, ...] + self.assertEqual(C1[int], Concatenate[str, int, ...]) + + C2 = Concatenate[str, P] + self.assertEqual(C2[...], Concatenate[str, ...]) + self.assertEqual(C2[int], (str, int)) + U1 = Unpack[Tuple[int, str]] + U2 = Unpack[Ts] + self.assertEqual(C2[U1], (str, int, str)) + self.assertEqual(C2[U2], (str, Unpack[Ts])) + self.assertEqual(C2["U2"], (str, typing.ForwardRef("U2"))) + + if (3, 12, 0) <= sys.version_info < (3, 12, 4): + with self.assertRaises(AssertionError): + C2[Unpack[U2]] + else: + with self.assertRaisesRegex(TypeError, "must be used with a tuple type"): + C2[Unpack[U2]] + + C3 = Concatenate[str, T, P] + self.assertEqual(C3[int, [bool]], (str, int, bool)) class TypeGuardTests(BaseTestCase): def test_basics(self): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 5cdafb70..99795488 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1765,6 +1765,23 @@ def __call__(self, *args, **kwargs): # 3.8-3.9 if not hasattr(typing, 'Concatenate'): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + + # 3.9.0-1 + if not hasattr(typing, '_type_convert'): + def _type_convert(arg, module=None, *, allow_special_forms=False): + """For converting None to type(None), and strings to ForwardRef.""" + if arg is None: + return type(None) + if isinstance(arg, str): + if sys.version_info <= (3, 9, 6): + return ForwardRef(arg) + if sys.version_info <= (3, 9, 7): + return ForwardRef(arg, module=module) + return ForwardRef(arg, module=module, is_class=allow_special_forms) + return arg + else: + _type_convert = typing._type_convert + class _ConcatenateGenericAlias(list): # Trick Generic into looking into this for __parameters__. @@ -1795,6 +1812,96 @@ def __parameters__(self): return tuple( tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) ) + + # 3.8; needed for typing._subs_tvars + # 3.9 used by __getitem__ below + def copy_with(self, params): + if isinstance(params[-1], _ConcatenateGenericAlias): + params = (*params[:-1], *params[-1].__args__) + elif isinstance(params[-1], (list, tuple)): + return (*params[:-1], *params[-1]) + elif (not (params[-1] is ... or isinstance(params[-1], ParamSpec))): + raise TypeError("The last parameter to Concatenate should be a " + "ParamSpec variable or ellipsis.") + return self.__class__(self.__origin__, params) + + # 3.9; accessed during GenericAlias.__getitem__ when substituting + def __getitem__(self, args): + if self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. + raise TypeError(f"Cannot subscript already-subscripted {self}") + if not self.__parameters__: + raise TypeError(f"{self} is not a generic class") + + if not isinstance(args, tuple): + args = (args,) + args = _unpack_args(*(_type_convert(p) for p in args)) + params = self.__parameters__ + for param in params: + prepare = getattr(param, "__typing_prepare_subst__", None) + if prepare is not None: + args = prepare(self, args) + # 3.8 - 3.9 & typing.ParamSpec + elif isinstance(param, ParamSpec): + i = params.index(param) + if ( + i == len(args) + and getattr(param, '__default__', NoDefault) is not NoDefault + ): + args = [*args, param.__default__] + if i >= len(args): + raise TypeError(f"Too few arguments for {self}") + # Special case for Z[[int, str, bool]] == Z[int, str, bool] + if len(params) == 1 and not _is_param_expr(args[0]): + assert i == 0 + args = (args,) + elif ( + isinstance(args[i], list) + # 3.8 - 3.9 + # This class inherits from list do not convert + and not isinstance(args[i], _ConcatenateGenericAlias) + ): + args = (*args[:i], tuple(args[i]), *args[i + 1:]) + + alen = len(args) + plen = len(params) + if alen != plen: + raise TypeError( + f"Too {'many' if alen > plen else 'few'} arguments for {self};" + f" actual {alen}, expected {plen}" + ) + + subst = dict(zip(self.__parameters__, args)) + # determine new args + new_args = [] + for arg in self.__args__: + if isinstance(arg, type): + new_args.append(arg) + continue + if isinstance(arg, TypeVar): + arg = subst[arg] + if ( + (isinstance(arg, typing._GenericAlias) and _is_unpack(arg)) + or ( + hasattr(_types, "GenericAlias") + and isinstance(arg, _types.GenericAlias) + and getattr(arg, "__unpacked__", False) + ) + ): + raise TypeError(f"{arg} is not valid as type argument") + + elif isinstance(arg, + typing._GenericAlias + if not hasattr(_types, "GenericAlias") else + (typing._GenericAlias, _types.GenericAlias) + ): + subparams = arg.__parameters__ + if subparams: + subargs = tuple(subst[x] for x in subparams) + arg = arg[subargs] + new_args.append(arg) + return self.copy_with(tuple(new_args)) + # 3.10+ else: _ConcatenateGenericAlias = typing._ConcatenateGenericAlias @@ -1817,6 +1924,12 @@ def copy_with(self, params): "ParamSpec variable or ellipsis.") return super(_typing_ConcatenateGenericAlias, self).copy_with(params) + def __getitem__(self, args): + value = super().__getitem__(args) + if isinstance(value, tuple) and any(_is_unpack(t) for t in value): + return tuple(_unpack_args(*(n for n in value))) + return value + # 3.8-3.9.2 class _EllipsisDummy: ... @@ -2496,6 +2609,21 @@ def _is_unpack(obj): class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar + def _typing_unpacked_tuple_args(self): + assert self.__origin__ is Unpack + assert len(self.__args__) == 1 + arg, = self.__args__ + if isinstance(arg, typing._GenericAlias): + if arg.__origin__ is not tuple: + raise TypeError("Unpack[...] must be used with a tuple type") + return arg.__args__ + return None + + def __getattr__(self, attr): + if attr == '__typing_unpacked_tuple_args__': + return self._typing_unpacked_tuple_args() + return super().__getattr__(attr) + @property def __typing_is_unpacked_typevartuple__(self): assert self.__origin__ is Unpack @@ -2519,21 +2647,22 @@ def _is_unpack(obj): return isinstance(obj, _UnpackAlias) +def _unpack_args(*args): + newargs = [] + for arg in args: + subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + if subargs is not None and (not (subargs and subargs[-1] is ...)): + newargs.extend(subargs) + else: + newargs.append(arg) + return newargs + + if _PEP_696_IMPLEMENTED: from typing import TypeVarTuple elif hasattr(typing, "TypeVarTuple"): # 3.11+ - def _unpack_args(*args): - newargs = [] - for arg in args: - subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) - if subargs is not None and not (subargs and subargs[-1] is ...): - newargs.extend(subargs) - else: - newargs.append(arg) - return newargs - # Add default parameter - PEP 696 class TypeVarTuple(metaclass=_TypeVarLikeMeta): """Type variable tuple.""" @@ -3007,6 +3136,12 @@ def wrapper(*args, **kwargs): ) +def _is_param_expr(arg): + return arg is ... or isinstance( + arg, (tuple, list, ParamSpec, _ConcatenateGenericAlias) + ) + + # We have to do some monkey patching to deal with the dual nature of # Unpack/TypeVarTuple: # - We want Unpack to be a kind of TypeVar so it gets accepted in @@ -3020,6 +3155,17 @@ def _check_generic(cls, parameters, elen=_marker): This gives a nice error message in case of count mismatch. """ + # If substituting a single ParamSpec with multiple arguments + # we do not check the count + if (inspect.isclass(cls) and issubclass(cls, typing.Generic) + and len(cls.__parameters__) == 1 + and isinstance(cls.__parameters__[0], ParamSpec) + and parameters + and not _is_param_expr(parameters[0]) + ): + # Generic modifies parameters variable, but here we cannot do this + return + if not elen: raise TypeError(f"{cls} is not a generic class") if elen is _marker: @@ -3171,6 +3317,13 @@ def _collect_type_vars(types, typevar_types=None): tvars.append(t) if _should_collect_from_parameters(t): tvars.extend([t for t in t.__parameters__ if t not in tvars]) + elif isinstance(t, tuple): + # Collect nested type_vars + # tuple wrapped by _prepare_paramspec_params(cls, params) + for x in t: + for collected in _collect_type_vars([x]): + if collected not in tvars: + tvars.append(collected) return tuple(tvars) typing._collect_type_vars = _collect_type_vars