Skip to content

Commit f13a463

Browse files
committed
Refactor to support typevars, and more tests
1 parent bd33c7a commit f13a463

File tree

5 files changed

+69
-28
lines changed

5 files changed

+69
-28
lines changed

mypy/plugins/attrs.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterable, List, Optional, cast
5+
from typing import Iterable, List, cast
66
from typing_extensions import Final, Literal
77

88
import mypy.plugin # To avoid circular imports.
@@ -43,7 +43,7 @@
4343
Var,
4444
is_class_var,
4545
)
46-
from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface
46+
from mypy.plugin import SemanticAnalyzerPluginInterface
4747
from mypy.plugins.common import (
4848
_get_argument,
4949
_get_bool_argument,
@@ -990,27 +990,42 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
990990
)
991991

992992

993-
def _get_cls_from_init(t: Type) -> Optional[TypeInfo]:
994-
proper_type = get_proper_type(t)
995-
if isinstance(proper_type, CallableType):
996-
return proper_type.type_object()
997-
return None
993+
def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
994+
"""Provide the proper signature for `attrs.fields`."""
995+
if ctx.args and len(ctx.args) == 1 and ctx.args[0] and ctx.args[0][0]:
998996

997+
# <hack>
998+
assert isinstance(ctx.api, TypeChecker)
999+
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
1000+
# </hack>
1001+
proper_type = get_proper_type(inst_type)
1002+
1003+
if isinstance(proper_type, AnyType): # fields(Any) -> Any
1004+
return ctx.default_signature
1005+
1006+
cls = None
1007+
arg_types = ctx.default_signature.arg_types
1008+
1009+
if isinstance(proper_type, TypeVarType):
1010+
inner = get_proper_type(proper_type.upper_bound)
1011+
if isinstance(inner, Instance):
1012+
# We need to work arg_types to compensate for the attrs stubs.
1013+
arg_types = [inst_type]
1014+
cls = inner.type
1015+
elif isinstance(proper_type, CallableType):
1016+
cls = proper_type.type_object()
9991017

1000-
def fields_function_callback(ctx: FunctionContext) -> Type:
1001-
"""Provide the proper return value for `attrs.fields`."""
1002-
if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]:
1003-
first_arg_type = ctx.arg_types[0][0]
1004-
cls = _get_cls_from_init(first_arg_type)
10051018
if cls is not None:
10061019
if MAGIC_ATTR_NAME in cls.names:
10071020
# This is a proper attrs class.
10081021
ret_type = cls.names[MAGIC_ATTR_NAME].type
10091022
if ret_type is not None:
1010-
return ret_type
1011-
else:
1012-
ctx.api.fail(
1013-
f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class',
1014-
ctx.context,
1015-
)
1016-
return ctx.default_return_type
1023+
return ctx.default_signature.copy_modified(
1024+
arg_types=arg_types, ret_type=ret_type
1025+
)
1026+
1027+
ctx.api.fail(
1028+
f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type)}"; expected an attrs class',
1029+
ctx.context,
1030+
)
1031+
return ctx.default_signature

mypy/plugins/default.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,13 @@ class DefaultPlugin(Plugin):
3939
"""Type checker plugin that is enabled by default."""
4040

4141
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
42-
from mypy.plugins import attrs, ctypes, singledispatch
42+
from mypy.plugins import ctypes, singledispatch
4343

4444
if fullname == "ctypes.Array":
4545
return ctypes.array_constructor_callback
4646
elif fullname == "functools.singledispatch":
4747
return singledispatch.create_singledispatch_function_callback
48-
elif fullname in ("attr.fields", "attrs.fields"):
49-
return attrs.fields_function_callback
48+
5049
return None
5150

5251
def get_function_signature_hook(
@@ -56,6 +55,8 @@ def get_function_signature_hook(
5655

5756
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
5857
return attrs.evolve_function_sig_callback
58+
elif fullname in ("attr.fields", "attrs.fields"):
59+
return attrs.fields_function_sig_callback
5960
return None
6061

6162
def get_method_signature_hook(

test-data/unit/check-attr.test

+29-4
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,24 @@ takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompati
15491549
[builtins fixtures/attr.pyi]
15501550

15511551
[case testAttrsFields]
1552+
import attr
1553+
from attrs import fields as f # Common usage.
1554+
1555+
@attr.define
1556+
class A:
1557+
b: int
1558+
c: str
1559+
1560+
reveal_type(f(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561+
reveal_type(f(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562+
reveal_type(f(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563+
f(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1564+
1565+
[builtins fixtures/attr.pyi]
1566+
1567+
[case testAttrsGenericFields]
1568+
from typing import TypeVar
1569+
15521570
import attr
15531571
from attrs import fields
15541572

@@ -1557,21 +1575,28 @@ class A:
15571575
b: int
15581576
c: str
15591577

1560-
reveal_type(fields(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561-
reveal_type(fields(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562-
reveal_type(fields(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563-
fields(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1578+
TA = TypeVar('TA', bound=A)
1579+
1580+
def f(t: TA) -> None:
1581+
reveal_type(fields(t)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1582+
reveal_type(fields(t)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1583+
reveal_type(fields(t).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1584+
fields(t).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1585+
15641586

15651587
[builtins fixtures/attr.pyi]
15661588

15671589
[case testNonattrsFields]
1590+
from typing import Any, cast
15681591
from attrs import fields
15691592

15701593
class A:
15711594
b: int
15721595
c: str
15731596

15741597
fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
1598+
fields(None) # E: Argument 1 to "fields" has incompatible type "None"; expected an attrs class
1599+
fields(cast(Any, 42))
15751600

15761601
[builtins fixtures/attr.pyi]
15771602

test-data/unit/lib-stub/attr/__init__.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -248,4 +248,4 @@ def field(
248248
def evolve(inst: _T, **changes: Any) -> _T: ...
249249
def assoc(inst: _T, **changes: Any) -> _T: ...
250250

251-
def fields(cls: _C) -> Any: ...
251+
def fields(cls: type) -> Any: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,4 @@ def field(
130130
def evolve(inst: _T, **changes: Any) -> _T: ...
131131
def assoc(inst: _T, **changes: Any) -> _T: ...
132132

133-
def fields(cls: _C) -> Any: ...
133+
def fields(cls: type) -> Any: ...

0 commit comments

Comments
 (0)