Skip to content

Commit

Permalink
More cases covered
Browse files Browse the repository at this point in the history
- nested concatenate
- Unpack subscription
  • Loading branch information
Daraan committed Oct 23, 2024
1 parent b89c272 commit 83dbe19
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
22 changes: 20 additions & 2 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5520,17 +5520,20 @@ def test_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Ts = TypeVarTuple("Ts")
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]

C1 = Concatenate[str, T, ...]
self.assertEqual(C1[int], Concatenate[str, int, ...])

C2 = Concatenate[str, P]
self.assertEqual(C2[...], Concatenate[str, ...])
self.assertEqual(C2[int], (str, int))
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]
self.assertEqual(C2[int, ...], (str, int, ...))

self.assertEqual(C2[U1], (str, int, str))
self.assertEqual(C2[U2], (str, Unpack[Ts]))
self.assertEqual(C2["U1"], (str, typing.ForwardRef("U1")))

if (3, 12, 0) <= sys.version_info < (3, 12, 4):
with self.assertRaises(AssertionError):
Expand All @@ -5541,7 +5544,22 @@ def test_substitution(self):

C3 = Concatenate[str, T, P]
self.assertEqual(C3[int, [bool]], (str, int, bool))
self.assertEqual(C3[int, ...], Concatenate[str, int, ...])
self.assertEqual(C3[int, Concatenate[str, P]], Concatenate[str, int, str, P])

@skipIf((3, 10) <= sys.version_info < (3, 12), reason="no backport yet")
def test_invalid_substitution(self):
T = TypeVar('T')
Ts = TypeVarTuple("Ts")
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]

C1 = Concatenate[str, T, ...]
with self.assertRaisesRegex(TypeError, "Too many arguments"):
C1[U1]

with self.assertRaisesRegex(TypeError, r"Unpack\[Ts\] is not valid as type argument"):
C1[U2]

class TypeGuardTests(BaseTestCase):
def test_basics(self):
Expand Down
35 changes: 30 additions & 5 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,13 +1766,17 @@ def __call__(self, *args, **kwargs):
if not hasattr(typing, 'Concatenate'):
# Inherits from list as a workaround for Callable checks in Python < 3.9.2.

#3.9.0-1
# 3.9.0-1
if not hasattr(typing, '_type_convert'):
def _type_convert(arg, module=None, *, allow_special_forms=False):
"""For converting None to type(None), and strings to ForwardRef."""
if arg is None:
return type(None)
if isinstance(arg, str):
if sys.version_info <= (3, 9, 6):
return ForwardRef(arg)
if sys.version_info <= (3, 9, 7):
return ForwardRef(arg, module=module)
return ForwardRef(arg, module=module, is_class=allow_special_forms)
return arg
else:
Expand Down Expand Up @@ -1812,10 +1816,10 @@ def __parameters__(self):
# 3.8; needed for typing._subs_tvars
# 3.9 used by __getitem__ below
def copy_with(self, params):
if isinstance(params[-1], (list, tuple)):
return (*params[:-1], *params[-1])
if isinstance(params[-1], _ConcatenateGenericAlias):
params = (*params[:-1], *params[-1].__args__)
elif isinstance(params[-1], (list, tuple)):
return (*params[:-1], *params[-1])
elif (not(params[-1] is ... or isinstance(params[-1], ParamSpec))):
raise TypeError("The last parameter to Concatenate should be a "
"ParamSpec variable or ellipsis.")
Expand Down Expand Up @@ -1847,10 +1851,21 @@ def __getitem__(self, args):
if len(params) == 1 and not _is_param_expr(args[0]):
assert i == 0
args = (args,)
# Convert lists to tuples to help other libraries cache the results.
elif isinstance(args[i], list):
# This class inherits from list do not convert
elif (
isinstance(args[i], list)
and not isinstance(args[i], _ConcatenateGenericAlias)
):
args = (*args[:i], tuple(args[i]), *args[i+1:])

alen = len(args)
plen = len(params)
if alen != plen:
raise TypeError(
f"Too {'many' if alen > plen else 'few'} arguments for {self};"
f" actual {alen}, expected {plen}"
)

subst = dict(zip(self.__parameters__, args))
# determine new args
new_args = []
Expand All @@ -1860,6 +1875,16 @@ def __getitem__(self, args):
continue
if isinstance(arg, TypeVar):
arg = subst[arg]
if (
(isinstance(arg, typing._GenericAlias) and _is_unpack(arg))
or (
hasattr(_types, "GenericAlias")
and isinstance(arg, _types.GenericAlias)
and getattr(arg, "__unpacked__", False)
)
):
raise TypeError(f"{arg} is not valid as type argument")

elif isinstance(arg,
typing._GenericAlias
if not hasattr(_types, "GenericAlias") else
Expand Down

0 comments on commit 83dbe19

Please sign in to comment.