diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 538f689f5e07..6e0e22272356 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -79,6 +79,8 @@ # The set of decorators that generate dataclasses. dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"} +# Default field specifiers for dataclasses +DATACLASS_FIELD_SPECIFIERS: Final = ("dataclasses.Field", "dataclasses.field") SELF_TVAR_NAME: Final = "_DT" @@ -87,7 +89,7 @@ order_default=False, kw_only_default=False, frozen_default=False, - field_specifiers=("dataclasses.Field", "dataclasses.field"), + field_specifiers=DATACLASS_FIELD_SPECIFIERS, ) _INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace" _INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init" diff --git a/mypy/stubgen.py b/mypy/stubgen.py index c74e9f700861..1f8a1a4740f1 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -95,6 +95,7 @@ ImportFrom, IndexExpr, IntExpr, + LambdaExpr, ListExpr, MemberExpr, MypyFile, @@ -113,6 +114,7 @@ Var, ) from mypy.options import Options as MypyOptions +from mypy.plugins.dataclasses import DATACLASS_FIELD_SPECIFIERS from mypy.semanal_shared import find_dataclass_transform_spec from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY from mypy.stubdoc import ArgSig, FunctionSig @@ -342,11 +344,12 @@ def visit_index_expr(self, node: IndexExpr) -> str: base = node.base.accept(self) index = node.index.accept(self) if len(index) > 2 and index.startswith("(") and index.endswith(")"): - index = index[1:-1] + index = index[1:-1].rstrip(",") return f"{base}[{index}]" def visit_tuple_expr(self, node: TupleExpr) -> str: - return f"({', '.join(n.accept(self) for n in node.items)})" + suffix = "," if len(node.items) == 1 else "" + return f"({', '.join(n.accept(self) for n in node.items)}{suffix})" def visit_list_expr(self, node: ListExpr) -> str: return f"[{', '.join(n.accept(self) for n in node.items)}]" @@ -368,6 +371,10 @@ def visit_op_expr(self, o: OpExpr) -> str: def visit_star_expr(self, o: StarExpr) -> str: return f"*{o.expr.accept(self)}" + def visit_lambda_expr(self, o: LambdaExpr) -> str: + # TODO: Required for among other things dataclass.field default_factory + return self.stubgen.add_name("_typeshed.Incomplete") + def find_defined_names(file: MypyFile) -> set[str]: finder = DefinitionFinder() @@ -482,6 +489,7 @@ def __init__( self.method_names: set[str] = set() self.processing_enum = False self.processing_dataclass = False + self.dataclass_field_specifier: tuple[str, ...] = () @property def _current_class(self) -> ClassDef | None: @@ -636,8 +644,8 @@ def visit_func_def(self, o: FuncDef) -> None: is_dataclass_generated = ( self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated ) - if is_dataclass_generated and o.name != "__init__": - # Skip methods generated by the @dataclass decorator (except for __init__) + if is_dataclass_generated: + # Skip methods generated by the @dataclass decorator return if ( self.is_private_name(o.name, o.fullname) @@ -793,8 +801,9 @@ def visit_class_def(self, o: ClassDef) -> None: self.add(f"{self._indent}{docstring}\n") n = len(self._output) self._vars.append([]) - if self.analyzed and find_dataclass_transform_spec(o): + if self.analyzed and (spec := find_dataclass_transform_spec(o)): self.processing_dataclass = True + self.dataclass_field_specifier = spec.field_specifiers super().visit_class_def(o) self.dedent() self._vars.pop() @@ -809,6 +818,7 @@ def visit_class_def(self, o: ClassDef) -> None: self._state = CLASS self.method_names = set() self.processing_dataclass = False + self.dataclass_field_specifier = () self._class_stack.pop(-1) self.processing_enum = False @@ -879,8 +889,9 @@ def is_dataclass_transform(self, expr: Expression) -> bool: expr = expr.callee if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES: return True - if find_dataclass_transform_spec(expr) is not None: + if (spec := find_dataclass_transform_spec(expr)) is not None: self.processing_dataclass = True + self.dataclass_field_specifier = spec.field_specifiers return True return False @@ -1259,8 +1270,14 @@ def get_assign_initializer(self, rvalue: Expression) -> str: and not isinstance(rvalue, TempNode) ): return " = ..." - if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs): - return " = ..." + if self.processing_dataclass: + if isinstance(rvalue, CallExpr): + fullname = self.get_fullname(rvalue.callee) + if fullname in (self.dataclass_field_specifier or DATACLASS_FIELD_SPECIFIERS): + p = AliasPrinter(self) + return f" = {rvalue.accept(p)}" + if not (isinstance(rvalue, TempNode) and rvalue.no_rhs): + return " = ..." # TODO: support other possible cases, where initializer is important # By default, no initializer is required: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index fa462dc23a9a..7700f04c6797 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3101,15 +3101,14 @@ import attrs @attrs.define class C: - x = attrs.field() + x: int = attrs.field() [out] import attrs @attrs.define class C: - x = ... - def __init__(self, x) -> None: ... + x: int = attrs.field() [case testNamedTupleInClass] from collections import namedtuple @@ -4050,8 +4049,9 @@ def i(x=..., y=..., z=...) -> None: ... [case testDataclass] import dataclasses import dataclasses as dcs -from dataclasses import dataclass, InitVar, KW_ONLY +from dataclasses import dataclass, field, Field, InitVar, KW_ONLY from dataclasses import dataclass as dc +from datetime import datetime from typing import ClassVar @dataclasses.dataclass @@ -4066,6 +4066,10 @@ class X: h: int = 1 i: InitVar[str] j: InitVar = 100 + # Lambda not supported yet -> marked as Incomplete instead + k: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds") + ) non_field = None @dcs.dataclass @@ -4083,7 +4087,8 @@ class V: ... [out] import dataclasses import dataclasses as dcs -from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc +from _typeshed import Incomplete +from dataclasses import Field, InitVar, KW_ONLY, dataclass, dataclass as dc, field from typing import ClassVar @dataclasses.dataclass @@ -4092,12 +4097,13 @@ class X: b: str = ... c: ClassVar d: ClassVar = ... - f: list[int] = ... - g: int = ... + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) _: KW_ONLY h: int = ... i: InitVar[str] j: InitVar = ... + k: str = Field(default_factory=Incomplete) non_field = ... @dcs.dataclass @@ -4110,8 +4116,9 @@ class W: ... class V: ... [case testDataclass_semanal] -from dataclasses import InitVar, dataclass, field +from dataclasses import Field, InitVar, dataclass, field from typing import ClassVar +from datetime import datetime @dataclass class X: @@ -4125,13 +4132,18 @@ class X: h: int = 1 i: InitVar = 100 j: list[int] = field(default_factory=list) + # Lambda not supported yet -> marked as Incomplete instead + k: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds") + ) non_field = None @dataclass(init=False, repr=False, frozen=True) class Y: ... [out] -from dataclasses import InitVar, dataclass +from _typeshed import Incomplete +from dataclasses import Field, InitVar, dataclass, field from typing import ClassVar @dataclass @@ -4141,13 +4153,13 @@ class X: c: str = ... d: ClassVar e: ClassVar = ... - f: list[int] = ... - g: int = ... + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) h: int = ... i: InitVar = ... - j: list[int] = ... + j: list[int] = field(default_factory=list) + k: str = Field(default_factory=Incomplete) non_field = ... - def __init__(self, a, b, c=..., *, g=..., h=..., i=..., j=...) -> None: ... @dataclass(init=False, repr=False, frozen=True) class Y: ... @@ -4175,7 +4187,7 @@ class X: class Y: ... [out] -from dataclasses import InitVar, KW_ONLY, dataclass +from dataclasses import InitVar, KW_ONLY, dataclass, field from typing import ClassVar @dataclass @@ -4184,14 +4196,13 @@ class X: b: str = ... c: ClassVar d: ClassVar = ... - f: list[int] = ... - g: int = ... + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) _: KW_ONLY h: int = ... i: InitVar[str] j: InitVar = ... non_field = ... - def __init__(self, a, b=..., *, g=..., h=..., i, j=...) -> None: ... @dataclass(init=False, repr=False, frozen=True) class Y: ... @@ -4236,7 +4247,6 @@ from dataclasses import dataclass @dataclass class X(missing.Base): a: int - def __init__(self, *generated_args, a, **generated_kwargs) -> None: ... @dataclass class Y(missing.Base): @@ -4244,7 +4254,6 @@ class Y(missing.Base): generated_args_: str generated_kwargs: float generated_kwargs_: float - def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ... [case testDataclassTransform] # dataclass_transform detection only works with sementic analysis. @@ -4298,6 +4307,7 @@ class Z(metaclass=DCMeta): [case testDataclassTransformDecorator_semanal] import typing_extensions +from dataclasses import field @typing_extensions.dataclass_transform(kw_only_default=True) def create_model(cls): @@ -4307,9 +4317,11 @@ def create_model(cls): class X: a: int b: str = "hello" + c: bool = field(default=True) [out] import typing_extensions +from dataclasses import field @typing_extensions.dataclass_transform(kw_only_default=True) def create_model(cls): ... @@ -4318,9 +4330,10 @@ def create_model(cls): ... class X: a: int b: str = ... - def __init__(self, *, a, b=...) -> None: ... + c: bool = field(default=True) [case testDataclassTransformClass_semanal] +from dataclasses import field from typing_extensions import dataclass_transform @dataclass_transform(kw_only_default=True) @@ -4329,8 +4342,10 @@ class ModelBase: ... class X(ModelBase): a: int b: str = "hello" + c: bool = field(default=True) [out] +from dataclasses import field from typing_extensions import dataclass_transform @dataclass_transform(kw_only_default=True) @@ -4339,28 +4354,42 @@ class ModelBase: ... class X(ModelBase): a: int b: str = ... - def __init__(self, *, a, b=...) -> None: ... + c: bool = field(default=True) [case testDataclassTransformMetaclass_semanal] +from dataclasses import field +from typing import Any from typing_extensions import dataclass_transform -@dataclass_transform(kw_only_default=True) +def custom_field(*, default: bool, kw_only: bool) -> Any: ... + +@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,)) class DCMeta(type): ... class X(metaclass=DCMeta): a: int b: str = "hello" + c: bool = field(default=True) # should be ignored, not field_specifier here + +class Y(X): + d: str = custom_field(default="Hello") [out] +from typing import Any from typing_extensions import dataclass_transform -@dataclass_transform(kw_only_default=True) +def custom_field(*, default: bool, kw_only: bool) -> Any: ... + +@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,)) class DCMeta(type): ... class X(metaclass=DCMeta): a: int b: str = ... - def __init__(self, *, a, b=...) -> None: ... + c: bool = ... + +class Y(X): + d: str = custom_field(default='Hello') [case testAlwaysUsePEP604Union] import typing @@ -4662,4 +4691,3 @@ class DCMeta(type): ... class DC(metaclass=DCMeta): x: str - def __init__(self, x) -> None: ...