Skip to content

Commit

Permalink
[stubgen] Improve dataclass init signatures (#18430)
Browse files Browse the repository at this point in the history
Remove generated incomplete `__init__` signatures for dataclasses. Keep
the field specifiers instead.
  • Loading branch information
cdce8p authored Jan 19, 2025
1 parent c4e2eb7 commit 68cffa7
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 34 deletions.
4 changes: 3 additions & 1 deletion mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
# Default field specifiers for dataclasses
DATACLASS_FIELD_SPECIFIERS: Final = ("dataclasses.Field", "dataclasses.field")


SELF_TVAR_NAME: Final = "_DT"
Expand All @@ -87,7 +89,7 @@
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
field_specifiers=DATACLASS_FIELD_SPECIFIERS,
)
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"
Expand Down
33 changes: 25 additions & 8 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
ImportFrom,
IndexExpr,
IntExpr,
LambdaExpr,
ListExpr,
MemberExpr,
MypyFile,
Expand All @@ -113,6 +114,7 @@
Var,
)
from mypy.options import Options as MypyOptions
from mypy.plugins.dataclasses import DATACLASS_FIELD_SPECIFIERS
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
from mypy.stubdoc import ArgSig, FunctionSig
Expand Down Expand Up @@ -342,11 +344,12 @@ def visit_index_expr(self, node: IndexExpr) -> str:
base = node.base.accept(self)
index = node.index.accept(self)
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
index = index[1:-1]
index = index[1:-1].rstrip(",")
return f"{base}[{index}]"

def visit_tuple_expr(self, node: TupleExpr) -> str:
return f"({', '.join(n.accept(self) for n in node.items)})"
suffix = "," if len(node.items) == 1 else ""
return f"({', '.join(n.accept(self) for n in node.items)}{suffix})"

def visit_list_expr(self, node: ListExpr) -> str:
return f"[{', '.join(n.accept(self) for n in node.items)}]"
Expand All @@ -368,6 +371,10 @@ def visit_op_expr(self, o: OpExpr) -> str:
def visit_star_expr(self, o: StarExpr) -> str:
return f"*{o.expr.accept(self)}"

def visit_lambda_expr(self, o: LambdaExpr) -> str:
# TODO: Required for, among other things, the default_factory argument of dataclasses.field
return self.stubgen.add_name("_typeshed.Incomplete")


def find_defined_names(file: MypyFile) -> set[str]:
finder = DefinitionFinder()
Expand Down Expand Up @@ -482,6 +489,7 @@ def __init__(
self.method_names: set[str] = set()
self.processing_enum = False
self.processing_dataclass = False
self.dataclass_field_specifier: tuple[str, ...] = ()

@property
def _current_class(self) -> ClassDef | None:
Expand Down Expand Up @@ -636,8 +644,8 @@ def visit_func_def(self, o: FuncDef) -> None:
is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if is_dataclass_generated and o.name != "__init__":
# Skip methods generated by the @dataclass decorator (except for __init__)
if is_dataclass_generated:
# Skip methods generated by the @dataclass decorator
return
if (
self.is_private_name(o.name, o.fullname)
Expand Down Expand Up @@ -793,8 +801,9 @@ def visit_class_def(self, o: ClassDef) -> None:
self.add(f"{self._indent}{docstring}\n")
n = len(self._output)
self._vars.append([])
if self.analyzed and find_dataclass_transform_spec(o):
if self.analyzed and (spec := find_dataclass_transform_spec(o)):
self.processing_dataclass = True
self.dataclass_field_specifier = spec.field_specifiers
super().visit_class_def(o)
self.dedent()
self._vars.pop()
Expand All @@ -809,6 +818,7 @@ def visit_class_def(self, o: ClassDef) -> None:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self.dataclass_field_specifier = ()
self._class_stack.pop(-1)
self.processing_enum = False

Expand Down Expand Up @@ -879,8 +889,9 @@ def is_dataclass_transform(self, expr: Expression) -> bool:
expr = expr.callee
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
return True
if find_dataclass_transform_spec(expr) is not None:
if (spec := find_dataclass_transform_spec(expr)) is not None:
self.processing_dataclass = True
self.dataclass_field_specifier = spec.field_specifiers
return True
return False

Expand Down Expand Up @@ -1259,8 +1270,14 @@ def get_assign_initializer(self, rvalue: Expression) -> str:
and not isinstance(rvalue, TempNode)
):
return " = ..."
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
if self.processing_dataclass:
if isinstance(rvalue, CallExpr):
fullname = self.get_fullname(rvalue.callee)
if fullname in (self.dataclass_field_specifier or DATACLASS_FIELD_SPECIFIERS):
p = AliasPrinter(self)
return f" = {rvalue.accept(p)}"
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important

# By default, no initializer is required:
Expand Down
78 changes: 53 additions & 25 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3101,15 +3101,14 @@ import attrs

@attrs.define
class C:
x = attrs.field()
x: int = attrs.field()

[out]
import attrs

@attrs.define
class C:
x = ...
def __init__(self, x) -> None: ...
x: int = attrs.field()

[case testNamedTupleInClass]
from collections import namedtuple
Expand Down Expand Up @@ -4050,8 +4049,9 @@ def i(x=..., y=..., z=...) -> None: ...
[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, InitVar, KW_ONLY
from dataclasses import dataclass, field, Field, InitVar, KW_ONLY
from dataclasses import dataclass as dc
from datetime import datetime
from typing import ClassVar

@dataclasses.dataclass
Expand All @@ -4066,6 +4066,10 @@ class X:
h: int = 1
i: InitVar[str]
j: InitVar = 100
# Lambda not supported yet -> marked as Incomplete instead
k: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
)
non_field = None

@dcs.dataclass
Expand All @@ -4083,7 +4087,8 @@ class V: ...
[out]
import dataclasses
import dataclasses as dcs
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
from _typeshed import Incomplete
from dataclasses import Field, InitVar, KW_ONLY, dataclass, dataclass as dc, field
from typing import ClassVar

@dataclasses.dataclass
Expand All @@ -4092,12 +4097,13 @@ class X:
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
k: str = Field(default_factory=Incomplete)
non_field = ...

@dcs.dataclass
Expand All @@ -4110,8 +4116,9 @@ class W: ...
class V: ...

[case testDataclass_semanal]
from dataclasses import InitVar, dataclass, field
from dataclasses import Field, InitVar, dataclass, field
from typing import ClassVar
from datetime import datetime

@dataclass
class X:
Expand All @@ -4125,13 +4132,18 @@ class X:
h: int = 1
i: InitVar = 100
j: list[int] = field(default_factory=list)
# Lambda not supported yet -> marked as Incomplete instead
k: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(" ", timespec="seconds")
)
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, dataclass
from _typeshed import Incomplete
from dataclasses import Field, InitVar, dataclass, field
from typing import ClassVar

@dataclass
Expand All @@ -4141,13 +4153,13 @@ class X:
c: str = ...
d: ClassVar
e: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
h: int = ...
i: InitVar = ...
j: list[int] = ...
j: list[int] = field(default_factory=list)
k: str = Field(default_factory=Incomplete)
non_field = ...
def __init__(self, a, b, c=..., *, g=..., h=..., i=..., j=...) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...
Expand Down Expand Up @@ -4175,7 +4187,7 @@ class X:
class Y: ...

[out]
from dataclasses import InitVar, KW_ONLY, dataclass
from dataclasses import InitVar, KW_ONLY, dataclass, field
from typing import ClassVar

@dataclass
Expand All @@ -4184,14 +4196,13 @@ class X:
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b=..., *, g=..., h=..., i, j=...) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...
Expand Down Expand Up @@ -4236,15 +4247,13 @@ from dataclasses import dataclass
@dataclass
class X(missing.Base):
a: int
def __init__(self, *generated_args, a, **generated_kwargs) -> None: ...

@dataclass
class Y(missing.Base):
generated_args: str
generated_args_: str
generated_kwargs: float
generated_kwargs_: float
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...

[case testDataclassTransform]
# dataclass_transform detection only works with semantic analysis.
Expand Down Expand Up @@ -4298,6 +4307,7 @@ class Z(metaclass=DCMeta):

[case testDataclassTransformDecorator_semanal]
import typing_extensions
from dataclasses import field

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls):
Expand All @@ -4307,9 +4317,11 @@ def create_model(cls):
class X:
a: int
b: str = "hello"
c: bool = field(default=True)

[out]
import typing_extensions
from dataclasses import field

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls): ...
Expand All @@ -4318,9 +4330,10 @@ def create_model(cls): ...
class X:
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = field(default=True)

[case testDataclassTransformClass_semanal]
from dataclasses import field
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
Expand All @@ -4329,8 +4342,10 @@ class ModelBase: ...
class X(ModelBase):
a: int
b: str = "hello"
c: bool = field(default=True)

[out]
from dataclasses import field
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
Expand All @@ -4339,28 +4354,42 @@ class ModelBase: ...
class X(ModelBase):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = field(default=True)

[case testDataclassTransformMetaclass_semanal]
from dataclasses import field
from typing import Any
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
def custom_field(*, default: bool, kw_only: bool) -> Any: ...

@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = "hello"
c: bool = field(default=True) # should be ignored: `field` is not one of the field_specifiers here

class Y(X):
d: str = custom_field(default="Hello")

[out]
from typing import Any
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
def custom_field(*, default: bool, kw_only: bool) -> Any: ...

@dataclass_transform(kw_only_default=True, field_specifiers=(custom_field,))
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...
c: bool = ...

class Y(X):
d: str = custom_field(default='Hello')

[case testAlwaysUsePEP604Union]
import typing
Expand Down Expand Up @@ -4662,4 +4691,3 @@ class DCMeta(type): ...

class DC(metaclass=DCMeta):
x: str
def __init__(self, x) -> None: ...

0 comments on commit 68cffa7

Please sign in to comment.