diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index e4328d764be6..4c96003ca326 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -69,6 +69,7 @@ TupleType, Type, TypeOfAny, + TypeType, TypeVarType, UninhabitedType, UnionType, @@ -935,7 +936,7 @@ def add_method( def _get_attrs_init_type(typ: Instance) -> CallableType | None: """ - If `typ` refers to an attrs class, gets the type of its initializer method. + If `typ` refers to an attrs class, get the type of its initializer method. """ magic_attr = typ.type.get(MAGIC_ATTR_NAME) if magic_attr is None or not magic_attr.plugin_generated: @@ -1009,7 +1010,7 @@ def _get_expanded_attr_types( def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]: """ - "Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound. + "Meet" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound. """ field_to_types = defaultdict(list) for fields in types: @@ -1026,7 +1027,7 @@ def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]: def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType: """ - Generates a signature for the 'attr.evolve' function that's specific to the call site + Generate a signature for the 'attr.evolve' function that's specific to the call site and dependent on the type of the first argument. """ if len(ctx.args) != 2: @@ -1060,3 +1061,48 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl fallback=ctx.default_signature.fallback, name=f"{ctx.default_signature.name} of {inst_type_str}", ) + + +def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType: + """Provide the signature for `attrs.fields`.""" + if not ctx.args or len(ctx.args) != 1 or not ctx.args[0] or not ctx.args[0][0]: + return ctx.default_signature + + # + assert isinstance(ctx.api, TypeChecker) + inst_type = ctx.api.expr_checker.accept(ctx.args[0][0]) + # + proper_type = get_proper_type(inst_type) + + # fields(Any) -> Any, fields(type[Any]) -> Any + if ( + isinstance(proper_type, AnyType) + or isinstance(proper_type, TypeType) + and isinstance(proper_type.item, AnyType) + ): + return ctx.default_signature + + cls = None + arg_types = ctx.default_signature.arg_types + + if isinstance(proper_type, TypeVarType): + inner = get_proper_type(proper_type.upper_bound) + if isinstance(inner, Instance): + # We need to work arg_types to compensate for the attrs stubs. + arg_types = [inst_type] + cls = inner.type + elif isinstance(proper_type, CallableType): + cls = proper_type.type_object() + + if cls is not None and MAGIC_ATTR_NAME in cls.names: + # This is a proper attrs class. + ret_type = cls.names[MAGIC_ATTR_NAME].type + assert ret_type is not None + return ctx.default_signature.copy_modified(arg_types=arg_types, ret_type=ret_type) + + ctx.api.fail( + f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type, ctx.api.options)}"; expected an attrs class', + ctx.context, + ) + + return ctx.default_signature diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 1edc91a1183c..500eef76a9d9 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -45,6 +45,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] return ctypes.array_constructor_callback elif fullname == "functools.singledispatch": return singledispatch.create_singledispatch_function_callback + return None def get_function_signature_hook( @@ -54,6 +55,8 @@ def get_function_signature_hook( if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"): return attrs.evolve_function_sig_callback + elif fullname in ("attr.fields", "attrs.fields"): + return attrs.fields_function_sig_callback return None def get_method_signature_hook( diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index ce1d670431c7..e34408454a83 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -1548,6 +1548,59 @@ takes_attrs_cls(A(1, "")) # E: Argument 1 to "takes_attrs_cls" has incompatible takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompatible type "Type[A]"; expected "AttrsInstance" # N: ClassVar protocol member AttrsInstance.__attrs_attrs__ can never be matched by a class object [builtins fixtures/plugin_attrs.pyi] +[case testAttrsFields] +import attr +from attrs import fields as f # Common usage. + +@attr.define +class A: + b: int + c: str + +reveal_type(f(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]" +reveal_type(f(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]" +reveal_type(f(A).b) # N: Revealed type is "attr.Attribute[builtins.int]" +f(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x" + +[builtins fixtures/plugin_attrs.pyi] + +[case testAttrsGenericFields] +from typing import TypeVar + +import attr +from attrs import fields + +@attr.define +class A: + b: int + c: str + +TA = TypeVar('TA', bound=A) + +def f(t: TA) -> None: + reveal_type(fields(t)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]" + reveal_type(fields(t)[0]) # N: Revealed type is "attr.Attribute[builtins.int]" + reveal_type(fields(t).b) # N: Revealed type is "attr.Attribute[builtins.int]" + fields(t).x # E: "____main___A_AttrsAttributes__" has no attribute "x" + + +[builtins fixtures/plugin_attrs.pyi] + +[case testNonattrsFields] +from typing import Any, cast, Type +from attrs import fields + +class A: + b: int + c: str + +fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class +fields(None) # E: Argument 1 to "fields" has incompatible type "None"; expected an attrs class +fields(cast(Any, 42)) +fields(cast(Type[Any], 43)) + +[builtins fixtures/plugin_attrs.pyi] + [case testAttrsInitMethodAlwaysGenerates] from typing import Tuple import attr diff --git a/test-data/unit/lib-stub/attr/__init__.pyi b/test-data/unit/lib-stub/attr/__init__.pyi index 1a3838aa3ab1..24ffc0f3f275 100644 --- a/test-data/unit/lib-stub/attr/__init__.pyi +++ b/test-data/unit/lib-stub/attr/__init__.pyi @@ -247,3 +247,5 @@ def field( def evolve(inst: _T, **changes: Any) -> _T: ... def assoc(inst: _T, **changes: Any) -> _T: ... + +def fields(cls: type) -> Any: ... diff --git a/test-data/unit/lib-stub/attrs/__init__.pyi b/test-data/unit/lib-stub/attrs/__init__.pyi index 8e9aa1fdced5..cc09ce9b0b49 100644 --- a/test-data/unit/lib-stub/attrs/__init__.pyi +++ b/test-data/unit/lib-stub/attrs/__init__.pyi @@ -129,3 +129,5 @@ def field( def evolve(inst: _T, **changes: Any) -> _T: ... def assoc(inst: _T, **changes: Any) -> _T: ... + +def fields(cls: type) -> Any: ...