diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index f5eb3e84a94..77ed0b1f0a9 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -43,6 +43,28 @@ # once to extract both keys and values, then passing keys to __new__, values to __init__ # within the same function call. class _Meta(type): + @overload + def __call__(cls: type[_T]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Mapping[_KT, _VT]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT]) -> _T: + ... + + @overload + def __call__( + cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT] + ) -> _T: + ... + def __call__( cls: type[_T], __value: Optional[Iterable] = None, @@ -130,6 +152,7 @@ class DictTuple(tuple[_VT, ...], Generic[_KT, _VT], metaclass=_Meta): __mapping: MappingProxyType[_KT, int] + # These overloads are copied from _Meta.__call__ just for IDE hints @overload def __init__(self) -> None: ... @@ -150,9 +173,7 @@ def __init__(self, __value: Iterable[_VT], __key: Collection[_KT]) -> None: def __init__(self, __value: Iterable[_VT], __key: Callable[[_VT], _KT]) -> None: ... - def __init__( - self, __value: Optional[Union[Mapping[_KT, int], Iterable[_KT]]] = None, / - ) -> None: + def __init__(self, __value=None, /): if isinstance(__value, MappingProxyType): self.__mapping = __value elif isinstance(__value, Mapping): @@ -189,7 +210,7 @@ def __getitem__(self, __key: slice) -> Self: def __getitem__(self, __key: SupportsIndex) -> _VT: ... - def __getitem__(self, __key): + def __getitem__(self, __key, /): if isinstance(__key, slice): return self.__class__( super().__getitem__(__key),