|
7 | 7 |
|
8 | 8 | from mypy import errorcodes, message_registry
|
9 | 9 | from mypy.expandtype import expand_type, expand_type_by_instance
|
| 10 | +from mypy.meet import meet_types |
| 11 | +from mypy.messages import format_type_bare |
10 | 12 | from mypy.nodes import (
|
11 | 13 | ARG_NAMED,
|
12 | 14 | ARG_NAMED_OPT,
|
|
38 | 40 | TypeVarExpr,
|
39 | 41 | Var,
|
40 | 42 | )
|
41 |
| -from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface |
| 43 | +from mypy.plugin import ClassDefContext, FunctionSigContext, SemanticAnalyzerPluginInterface |
42 | 44 | from mypy.plugins.common import (
|
43 | 45 | _get_callee_type,
|
44 | 46 | _get_decorator_bool_argument,
|
|
56 | 58 | Instance,
|
57 | 59 | LiteralType,
|
58 | 60 | NoneType,
|
| 61 | + ProperType, |
59 | 62 | TupleType,
|
60 | 63 | Type,
|
61 | 64 | TypeOfAny,
|
62 | 65 | TypeVarType,
|
| 66 | + UninhabitedType, |
| 67 | + UnionType, |
63 | 68 | get_proper_type,
|
64 | 69 | )
|
65 | 70 | from mypy.typevars import fill_typevars
|
|
76 | 81 | frozen_default=False,
|
77 | 82 | field_specifiers=("dataclasses.Field", "dataclasses.field"),
|
78 | 83 | )
|
| 84 | +_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace" |
79 | 85 |
|
80 | 86 |
|
81 | 87 | class DataclassAttribute:
|
@@ -344,13 +350,47 @@ def transform(self) -> bool:
|
344 | 350 |
|
345 | 351 | self._add_dataclass_fields_magic_attribute()
|
346 | 352 |
|
| 353 | + if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES: |
| 354 | + self._add_internal_replace_method(attributes) |
| 355 | + |
347 | 356 | info.metadata["dataclass"] = {
|
348 | 357 | "attributes": [attr.serialize() for attr in attributes],
|
349 | 358 | "frozen": decorator_arguments["frozen"],
|
350 | 359 | }
|
351 | 360 |
|
352 | 361 | return True
|
353 | 362 |
|
| 363 | + def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> None: |
| 364 | + """ |
| 365 | + Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass |
| 366 | + to be used later whenever 'dataclasses.replace' is called for this dataclass. |
| 367 | + """ |
| 368 | + arg_types: list[Type] = [] |
| 369 | + arg_kinds = [] |
| 370 | + arg_names: list[str | None] = [] |
| 371 | + |
| 372 | + info = self._cls.info |
| 373 | + for attr in attributes: |
| 374 | + attr_type = attr.expand_type(info) |
| 375 | + assert attr_type is not None |
| 376 | + arg_types.append(attr_type) |
| 377 | + arg_kinds.append( |
| 378 | + ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT |
| 379 | + ) |
| 380 | + arg_names.append(attr.name) |
| 381 | + |
| 382 | + signature = CallableType( |
| 383 | + arg_types=arg_types, |
| 384 | + arg_kinds=arg_kinds, |
| 385 | + arg_names=arg_names, |
| 386 | + ret_type=NoneType(), |
| 387 | + fallback=self._api.named_type("builtins.function"), |
| 388 | + ) |
| 389 | + |
| 390 | + self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode( |
| 391 | + kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True |
| 392 | + ) |
| 393 | + |
354 | 394 | def add_slots(
|
355 | 395 | self, info: TypeInfo, attributes: list[DataclassAttribute], *, correct_version: bool
|
356 | 396 | ) -> None:
|
@@ -893,3 +933,124 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
|
893 | 933 | info.declared_metaclass is not None
|
894 | 934 | and info.declared_metaclass.type.dataclass_transform_spec is not None
|
895 | 935 | )
|
| 936 | + |
| 937 | + |
| 938 | +def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None: |
| 939 | + t_name = format_type_bare(t, ctx.api.options) |
| 940 | + if parent_t is t: |
| 941 | + msg = ( |
| 942 | + f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass' |
| 943 | + if isinstance(t, TypeVarType) |
| 944 | + else f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass' |
| 945 | + ) |
| 946 | + else: |
| 947 | + pt_name = format_type_bare(parent_t, ctx.api.options) |
| 948 | + msg = ( |
| 949 | + f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass' |
| 950 | + if isinstance(t, TypeVarType) |
| 951 | + else f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass' |
| 952 | + ) |
| 953 | + |
| 954 | + ctx.api.fail(msg, ctx.context) |
| 955 | + |
| 956 | + |
| 957 | +def _get_expanded_dataclasses_fields( |
| 958 | + ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType |
| 959 | +) -> list[CallableType] | None: |
| 960 | + """ |
| 961 | + For a given type, determine what dataclasses it can be: for each class, return the field types. |
| 962 | + For generic classes, the field types are expanded. |
| 963 | + If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error. |
| 964 | + """ |
| 965 | + if isinstance(typ, AnyType): |
| 966 | + return None |
| 967 | + elif isinstance(typ, UnionType): |
| 968 | + ret: list[CallableType] | None = [] |
| 969 | + for item in typ.relevant_items(): |
| 970 | + item = get_proper_type(item) |
| 971 | + item_types = _get_expanded_dataclasses_fields(ctx, item, item, parent_typ) |
| 972 | + if ret is not None and item_types is not None: |
| 973 | + ret += item_types |
| 974 | + else: |
| 975 | + ret = None # but keep iterating to emit all errors |
| 976 | + return ret |
| 977 | + elif isinstance(typ, TypeVarType): |
| 978 | + return _get_expanded_dataclasses_fields( |
| 979 | + ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ |
| 980 | + ) |
| 981 | + elif isinstance(typ, Instance): |
| 982 | + replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME) |
| 983 | + if replace_sym is None: |
| 984 | + _fail_not_dataclass(ctx, display_typ, parent_typ) |
| 985 | + return None |
| 986 | + replace_sig = replace_sym.type |
| 987 | + assert isinstance(replace_sig, ProperType) |
| 988 | + assert isinstance(replace_sig, CallableType) |
| 989 | + return [expand_type_by_instance(replace_sig, typ)] |
| 990 | + else: |
| 991 | + _fail_not_dataclass(ctx, display_typ, parent_typ) |
| 992 | + return None |
| 993 | + |
| 994 | + |
| 995 | +# TODO: we can potentially get the function signature hook to allow returning a union |
| 996 | +# and leave this to the regular machinery of resolving a union of callables |
| 997 | +# (https://github.com/python/mypy/issues/15457) |
| 998 | +def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType: |
| 999 | + """ |
| 1000 | + Produces the lowest bound of the 'replace' signatures of multiple dataclasses. |
| 1001 | + """ |
| 1002 | + args = { |
| 1003 | + name: (typ, kind) |
| 1004 | + for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds) |
| 1005 | + } |
| 1006 | + |
| 1007 | + for sig in sigs[1:]: |
| 1008 | + sig_args = { |
| 1009 | + name: (typ, kind) |
| 1010 | + for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds) |
| 1011 | + } |
| 1012 | + for name in (*args.keys(), *sig_args.keys()): |
| 1013 | + sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) |
| 1014 | + sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) |
| 1015 | + args[name] = ( |
| 1016 | + meet_types(sig_typ, sig2_typ), |
| 1017 | + ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED, |
| 1018 | + ) |
| 1019 | + |
| 1020 | + return sigs[0].copy_modified( |
| 1021 | + arg_names=list(args.keys()), |
| 1022 | + arg_types=[typ for typ, _ in args.values()], |
| 1023 | + arg_kinds=[kind for _, kind in args.values()], |
| 1024 | + ) |
| 1025 | + |
| 1026 | + |
| 1027 | +def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType: |
| 1028 | + """ |
| 1029 | + Returns a signature for the 'dataclasses.replace' function that's dependent on the type |
| 1030 | + of the first positional argument. |
| 1031 | + """ |
| 1032 | + if len(ctx.args) != 2: |
| 1033 | + # Ideally the name and context should be callee's, but we don't have it in FunctionSigContext. |
| 1034 | + ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context) |
| 1035 | + return ctx.default_signature |
| 1036 | + |
| 1037 | + if len(ctx.args[0]) != 1: |
| 1038 | + return ctx.default_signature # leave it to the type checker to complain |
| 1039 | + |
| 1040 | + obj_arg = ctx.args[0][0] |
| 1041 | + obj_type = get_proper_type(ctx.api.get_expression_type(obj_arg)) |
| 1042 | + inst_type_str = format_type_bare(obj_type, ctx.api.options) |
| 1043 | + |
| 1044 | + replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type) |
| 1045 | + if replace_sigs is None: |
| 1046 | + return ctx.default_signature |
| 1047 | + replace_sig = _meet_replace_sigs(replace_sigs) |
| 1048 | + |
| 1049 | + return replace_sig.copy_modified( |
| 1050 | + arg_names=[None, *replace_sig.arg_names], |
| 1051 | + arg_kinds=[ARG_POS, *replace_sig.arg_kinds], |
| 1052 | + arg_types=[obj_type, *replace_sig.arg_types], |
| 1053 | + ret_type=obj_type, |
| 1054 | + fallback=ctx.default_signature.fallback, |
| 1055 | + name=f"{ctx.default_signature.name} of {inst_type_str}", |
| 1056 | + ) |
0 commit comments