From 83dbe1984ba7582a19d6e44075050573bbbeb870 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 23 Oct 2024 03:05:02 +0200 Subject: [PATCH] More cases covered - nested concatenate - Unpack subscription --- src/test_typing_extensions.py | 22 ++++++++++++++++++++-- src/typing_extensions.py | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 09e987a1..5728bde5 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -5520,6 +5520,8 @@ def test_substitution(self): T = TypeVar('T') P = ParamSpec('P') Ts = TypeVarTuple("Ts") + U1 = Unpack[Tuple[int, str]] + U2 = Unpack[Ts] C1 = Concatenate[str, T, ...] self.assertEqual(C1[int], Concatenate[str, int, ...]) @@ -5527,10 +5529,11 @@ def test_substitution(self): C2 = Concatenate[str, P] self.assertEqual(C2[...], Concatenate[str, ...]) self.assertEqual(C2[int], (str, int)) - U1 = Unpack[Tuple[int, str]] - U2 = Unpack[Ts] + self.assertEqual(C2[int, ...], (str, int, ...)) + self.assertEqual(C2[U1], (str, int, str)) self.assertEqual(C2[U2], (str, Unpack[Ts])) + self.assertEqual(C2["U1"], (str, typing.ForwardRef("U1"))) if (3, 12, 0) <= sys.version_info < (3, 12, 4): with self.assertRaises(AssertionError): @@ -5541,7 +5544,22 @@ def test_substitution(self): C3 = Concatenate[str, T, P] self.assertEqual(C3[int, [bool]], (str, int, bool)) + self.assertEqual(C3[int, ...], Concatenate[str, int, ...]) + self.assertEqual(C3[int, Concatenate[str, P]], Concatenate[str, int, str, P]) + + @skipIf((3, 10) <= sys.version_info < (3, 12), reason="no backport yet") + def test_invalid_substitution(self): + T = TypeVar('T') + Ts = TypeVarTuple("Ts") + U1 = Unpack[Tuple[int, str]] + U2 = Unpack[Ts] + + C1 = Concatenate[str, T, ...] + with self.assertRaisesRegex(TypeError, "Too many arguments"): + C1[U1] + with self.assertRaisesRegex(TypeError, r"Unpack\[Ts\] is not valid as type argument"): + C1[U2] class TypeGuardTests(BaseTestCase): def test_basics(self): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index db4de049..13a4b3f1 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1766,13 +1766,17 @@ def __call__(self, *args, **kwargs): if not hasattr(typing, 'Concatenate'): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. - #3.9.0-1 + # 3.9.0-1 if not hasattr(typing, '_type_convert'): def _type_convert(arg, module=None, *, allow_special_forms=False): """For converting None to type(None), and strings to ForwardRef.""" if arg is None: return type(None) if isinstance(arg, str): + if sys.version_info <= (3, 9, 6): + return ForwardRef(arg) + if sys.version_info <= (3, 9, 7): + return ForwardRef(arg, module=module) return ForwardRef(arg, module=module, is_class=allow_special_forms) return arg else: @@ -1812,10 +1816,10 @@ def __parameters__(self): # 3.8; needed for typing._subs_tvars # 3.9 used by __getitem__ below def copy_with(self, params): - if isinstance(params[-1], (list, tuple)): - return (*params[:-1], *params[-1]) if isinstance(params[-1], _ConcatenateGenericAlias): params = (*params[:-1], *params[-1].__args__) + elif isinstance(params[-1], (list, tuple)): + return (*params[:-1], *params[-1]) elif (not(params[-1] is ... or isinstance(params[-1], ParamSpec))): raise TypeError("The last parameter to Concatenate should be a " "ParamSpec variable or ellipsis.") @@ -1847,10 +1851,21 @@ def __getitem__(self, args): if len(params) == 1 and not _is_param_expr(args[0]): assert i == 0 args = (args,) - # Convert lists to tuples to help other libraries cache the results. - elif isinstance(args[i], list): + # This class inherits from list do not convert + elif ( + isinstance(args[i], list) + and not isinstance(args[i], _ConcatenateGenericAlias) + ): args = (*args[:i], tuple(args[i]), *args[i+1:]) + alen = len(args) + plen = len(params) + if alen != plen: + raise TypeError( + f"Too {'many' if alen > plen else 'few'} arguments for {self};" + f" actual {alen}, expected {plen}" + ) + subst = dict(zip(self.__parameters__, args)) # determine new args new_args = [] @@ -1860,6 +1875,16 @@ def __getitem__(self, args): continue if isinstance(arg, TypeVar): arg = subst[arg] + if ( + (isinstance(arg, typing._GenericAlias) and _is_unpack(arg)) + or ( + hasattr(_types, "GenericAlias") + and isinstance(arg, _types.GenericAlias) + and getattr(arg, "__unpacked__", False) + ) + ): + raise TypeError(f"{arg} is not valid as type argument") + elif isinstance(arg, typing._GenericAlias if not hasattr(_types, "GenericAlias") else