Skip to content

Commit 6f2bfff

Browse files
ikonst and AlexWaygood authored
Add signature for dataclasses.replace (#14849)
Validate `dataclasses.replace` actual arguments to match the fields: - Unlike `__init__`, the arguments are always named. - All arguments are optional except for `InitVar`s without a default value. The tricks: - We're looking up the type of the first positional argument ("obj") through private API. See #10216, #14845. - We're preparing the signature of "replace" (for that specific dataclass) during the dataclass transformation and storing it in a "private" class attribute `__mypy-replace` (obviously not part of PEP-557 but contains a hyphen so should not conflict with any future valid identifier). Stashing the signature into the symbol table allows it to be passed across phases and cached across invocations. The stashed signature lacks the first argument, which we prepend at function signature hook time, since it depends on the type that `replace` is called on. Based on #14526 but actually simpler. Partially addresses #5152. # Remaining tasks - [x] handle generic dataclasses - [x] avoid data class transforms - [x] fine-grained mode tests --------- Co-authored-by: Alex Waygood <[email protected]>
1 parent 21cc1c7 commit 6f2bfff

8 files changed

+389
-3
lines changed

mypy/plugins/dataclasses.py

+162-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from mypy import errorcodes, message_registry
99
from mypy.expandtype import expand_type, expand_type_by_instance
10+
from mypy.meet import meet_types
11+
from mypy.messages import format_type_bare
1012
from mypy.nodes import (
1113
ARG_NAMED,
1214
ARG_NAMED_OPT,
@@ -38,7 +40,7 @@
3840
TypeVarExpr,
3941
Var,
4042
)
41-
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
43+
from mypy.plugin import ClassDefContext, FunctionSigContext, SemanticAnalyzerPluginInterface
4244
from mypy.plugins.common import (
4345
_get_callee_type,
4446
_get_decorator_bool_argument,
@@ -56,10 +58,13 @@
5658
Instance,
5759
LiteralType,
5860
NoneType,
61+
ProperType,
5962
TupleType,
6063
Type,
6164
TypeOfAny,
6265
TypeVarType,
66+
UninhabitedType,
67+
UnionType,
6368
get_proper_type,
6469
)
6570
from mypy.typevars import fill_typevars
@@ -76,6 +81,7 @@
7681
frozen_default=False,
7782
field_specifiers=("dataclasses.Field", "dataclasses.field"),
7883
)
84+
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace"
7985

8086

8187
class DataclassAttribute:
@@ -344,13 +350,47 @@ def transform(self) -> bool:
344350

345351
self._add_dataclass_fields_magic_attribute()
346352

353+
if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
354+
self._add_internal_replace_method(attributes)
355+
347356
info.metadata["dataclass"] = {
348357
"attributes": [attr.serialize() for attr in attributes],
349358
"frozen": decorator_arguments["frozen"],
350359
}
351360

352361
return True
353362

363+
def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> None:
    """
    Record the keyword signature that 'dataclasses.replace(...)' accepts for this class.

    One named parameter is generated per entry in ``attributes``. A parameter is
    required (ARG_NAMED) only when the attribute is an InitVar without a default;
    every other parameter is optional (ARG_NAMED_OPT). The resulting callable is
    wrapped in a plugin-generated FuncDef and stored in the class symbol table
    under ``_INTERNAL_REPLACE_SYM_NAME``.
    """
    info = self._cls.info

    arg_types: list[Type] = []
    arg_kinds = []
    arg_names: list[str | None] = []
    for attribute in attributes:
        expanded = attribute.expand_type(info)
        assert expanded is not None
        arg_types.append(expanded)
        if attribute.is_init_var and not attribute.has_default:
            arg_kinds.append(ARG_NAMED)
        else:
            arg_kinds.append(ARG_NAMED_OPT)
        arg_names.append(attribute.name)

    # Only the field parameters are described here; the return type is a
    # NoneType placeholder.
    signature = CallableType(
        arg_types=arg_types,
        arg_kinds=arg_kinds,
        arg_names=arg_names,
        ret_type=NoneType(),
        fallback=self._api.named_type("builtins.function"),
    )

    replace_node = SymbolTableNode(
        kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
    )
    info.names[_INTERNAL_REPLACE_SYM_NAME] = replace_node
393+
354394
def add_slots(
355395
self, info: TypeInfo, attributes: list[DataclassAttribute], *, correct_version: bool
356396
) -> None:
@@ -893,3 +933,124 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
893933
info.declared_metaclass is not None
894934
and info.declared_metaclass.type.dataclass_transform_spec is not None
895935
)
936+
937+
938+
def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
    """
    Report that the first argument of "replace" is not a dataclass.

    ``t`` is the offending type and ``parent_t`` is the type it was found in.
    When both are the same object the message talks about the argument itself;
    otherwise it names ``t`` as an item of ``parent_t``. Type variables get a
    "not bound to a dataclass" wording in either case.
    """
    is_type_var = isinstance(t, TypeVarType)
    t_name = format_type_bare(t, ctx.api.options)

    if parent_t is not t:
        pt_name = format_type_bare(parent_t, ctx.api.options)
        if is_type_var:
            msg = f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
        else:
            msg = f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'
    elif is_type_var:
        msg = f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
    else:
        msg = f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'

    ctx.api.fail(msg, ctx.context)
955+
956+
957+
def _get_expanded_dataclasses_fields(
    ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
) -> list[CallableType] | None:
    """
    Collect the stashed 'replace' signatures of every dataclass that ``typ`` may be.

    - Any: returns None without reporting anything.
    - Union: every relevant item is examined; the result is None as soon as one
      item yields None, but the remaining items are still visited so that each
      of them can report its own error.
    - Type variable: resolved through its upper bound, while ``display_typ`` is
      kept for error reporting.
    - Instance: the signature stored under ``_INTERNAL_REPLACE_SYM_NAME`` is
      expanded by the instance (which handles generic classes).
    - Anything else, or an Instance without the stored signature: an error is
      reported and None is returned.
    """
    if isinstance(typ, AnyType):
        return None

    if isinstance(typ, UnionType):
        collected: list[CallableType] | None = []
        for raw_item in typ.relevant_items():
            member = get_proper_type(raw_item)
            member_sigs = _get_expanded_dataclasses_fields(ctx, member, member, parent_typ)
            if collected is None or member_sigs is None:
                collected = None  # but keep iterating to emit all errors
            else:
                collected += member_sigs
        return collected

    if isinstance(typ, TypeVarType):
        bound = get_proper_type(typ.upper_bound)
        return _get_expanded_dataclasses_fields(ctx, bound, display_typ, parent_typ)

    if isinstance(typ, Instance):
        replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
        if replace_sym is not None:
            replace_sig = replace_sym.type
            assert isinstance(replace_sig, ProperType)
            assert isinstance(replace_sig, CallableType)
            return [expand_type_by_instance(replace_sig, typ)]

    # Either not an Instance at all, or an Instance with no stashed signature.
    _fail_not_dataclass(ctx, display_typ, parent_typ)
    return None
993+
994+
995+
# TODO: we can potentially get the function signature hook to allow returning a union
996+
# and leave this to the regular machinery of resolving a union of callables
997+
# (https://github.com/python/mypy/issues/15457)
998+
def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType:
    """
    Combine the 'replace' signatures of several dataclasses into a single one.

    Parameters are matched by name and their types are combined with
    ``meet_types``. A parameter missing from one of the signatures is treated
    as an optional parameter of type ``UninhabitedType``. The combined
    parameter is optional only if it is optional on both sides, and required
    otherwise. The result is a modified copy of the first signature.
    """

    def params_by_name(sig: CallableType) -> dict:
        # name -> (type, kind)
        return dict(zip(sig.arg_names, zip(sig.arg_types, sig.arg_kinds)))

    merged = params_by_name(sigs[0])

    for other in sigs[1:]:
        other_params = params_by_name(other)
        # The tuple is a snapshot, so 'merged' can be updated while looping.
        for name in (*merged, *other_params):
            left_typ, left_kind = merged.get(name, (UninhabitedType(), ARG_NAMED_OPT))
            right_typ, right_kind = other_params.get(name, (UninhabitedType(), ARG_NAMED_OPT))
            met_typ = meet_types(left_typ, right_typ)
            if left_kind == right_kind == ARG_NAMED_OPT:
                met_kind = ARG_NAMED_OPT
            else:
                met_kind = ARG_NAMED
            merged[name] = (met_typ, met_kind)

    merged_names = list(merged)
    merged_types = [typ for typ, _ in merged.values()]
    merged_kinds = [kind for _, kind in merged.values()]
    return sigs[0].copy_modified(
        arg_names=merged_names, arg_types=merged_types, arg_kinds=merged_kinds
    )
1025+
1026+
1027+
def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
    """
    Build a signature for 'dataclasses.replace' tailored to the type of its
    first positional argument.

    The default signature is returned unchanged when the call does not have
    exactly one expression in the first argument group, or when no dataclass
    signature can be derived for the argument's type. An error is also reported
    if the callee does not have exactly two formal argument groups.
    """
    default_sig = ctx.default_signature

    if len(ctx.args) != 2:
        # Ideally the name and context should be callee's, but we don't have it in FunctionSigContext.
        ctx.api.fail(f'"{default_sig.name}" has unexpected type annotation', ctx.context)
        return default_sig

    first_group = ctx.args[0]
    if len(first_group) != 1:
        return default_sig  # leave it to the type checker to complain

    obj_type = get_proper_type(ctx.api.get_expression_type(first_group[0]))
    inst_type_str = format_type_bare(obj_type, ctx.api.options)

    candidate_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type)
    if candidate_sigs is None:
        return default_sig

    fields_sig = _meet_replace_sigs(candidate_sigs)

    # Prepend the positional "obj" parameter in front of the field parameters
    # and make the call return the same type it was given.
    return fields_sig.copy_modified(
        arg_names=[None] + list(fields_sig.arg_names),
        arg_kinds=[ARG_POS] + list(fields_sig.arg_kinds),
        arg_types=[obj_type] + list(fields_sig.arg_types),
        ret_type=obj_type,
        fallback=default_sig.fallback,
        name=f"{default_sig.name} of {inst_type_str}",
    )

mypy/plugins/default.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,14 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
5151
def get_function_signature_hook(
    self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
    """
    Return the signature callback registered for ``fullname``, or None.

    Handles the attrs "evolve"/"assoc" and "fields" helpers as well as
    "dataclasses.replace".
    """
    # Imported at call time, as in the original implementation.
    from mypy.plugins import attrs, dataclasses

    if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
        return attrs.evolve_function_sig_callback

    if fullname in ("attr.fields", "attrs.fields"):
        return attrs.fields_function_sig_callback

    if fullname == "dataclasses.replace":
        return dataclasses.replace_function_sig_callback

    return None
6163

6264
def get_method_signature_hook(

test-data/unit/check-dataclass-transform.test

+19
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,24 @@ reveal_type(bar.base) # N: Revealed type is "builtins.int"
840840
[typing fixtures/typing-full.pyi]
841841
[builtins fixtures/dataclasses.pyi]
842842

843+
[case testDataclassTransformReplace]
844+
from dataclasses import replace
845+
from typing import dataclass_transform, Type
846+
847+
@dataclass_transform()
848+
def my_dataclass(cls: Type) -> Type:
849+
return cls
850+
851+
@my_dataclass
852+
class Person:
853+
name: str
854+
855+
p = Person('John')
856+
y = replace(p, name='Bob') # E: Argument 1 to "replace" has incompatible type "Person"; expected a dataclass
857+
858+
[typing fixtures/typing-full.pyi]
859+
[builtins fixtures/dataclasses.pyi]
860+
843861
[case testDataclassTransformSimpleDescriptor]
844862
# flags: --python-version 3.11
845863

@@ -1051,5 +1069,6 @@ class Desc2:
10511069
class C:
10521070
x: Desc # E: Unsupported signature for "__set__" in "Desc"
10531071
y: Desc2 # E: Unsupported "__set__" in "Desc2"
1072+
10541073
[typing fixtures/typing-full.pyi]
10551074
[builtins fixtures/dataclasses.pyi]

0 commit comments

Comments
 (0)