diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index 1ef238bdd5c..f5eb3e84a94 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -174,7 +174,7 @@ def __init__( if any(isinstance(k, SupportsIndex) for k in self.__mapping.keys()): raise ValueError( "values of `__keys` should not have type `int`, " - "or implements `__index__()`" + "or implement `__index__()`" ) @overload diff --git a/packages/syft/tests/syft/types/dicttuple_test.py b/packages/syft/tests/syft/types/dicttuple_test.py index 8317e608a62..c6adf40f328 100644 --- a/packages/syft/tests/syft/types/dicttuple_test.py +++ b/packages/syft/tests/syft/types/dicttuple_test.py @@ -2,6 +2,7 @@ from collections.abc import Collection from collections.abc import Iterable from collections.abc import Mapping +from functools import cached_property from itertools import chain from itertools import combinations from typing import Any @@ -89,7 +90,6 @@ class Case(Generic[_KT, _VT]): key_fn: Optional[Callable[[_VT], _KT]] value_generator: Callable[[], Generator[_VT, Any, None]] key_generator: Callable[[], Generator[_KT, Any, None]] - mapping: Mapping[_KT, _VT] def __init__( self, @@ -114,19 +114,27 @@ def key_generator() -> Generator[_KT, Any, None]: self.value_generator = value_generator self.key_generator = key_generator - self.mapping = dict(zip(self.keys, self.values)) - def kv(self) -> Iterable[tuple[_KT, _VT]]: return zip(self.keys, self.values) - def constructor_args(self) -> list[Callable[[], tuple]]: + @cached_property + def mapping(self) -> Mapping[_KT, _VT]: + return dict(self.kv()) + + def constructor_args(self, mapping: bool = True) -> list[Callable[[], tuple]]: return [ lambda: (self.values, self.keys), lambda: (self.value_generator(), self.key_generator()), lambda: (self.values, self.key_generator()), lambda: (self.value_generator(), self.keys), - lambda: (self.mapping,), - lambda: (self.kv(),), + *( + [ + lambda: (self.mapping,), + lambda: (self.kv(),), + ] + if mapping + else [] + ), *( [ lambda: (self.values, self.key_fn), @@ -207,3 +215,20 @@ def test_dicttuple_is_not_a_mapping( def test_keys_should_not_be_int(args: Callable[[], tuple]) -> None: with pytest.raises(ValueError, match="int"): DictTuple(*args()) + + +LENGTH_MISMACTH_TEST_CASES = [ + Case(values=[1, 2, 3], keys=["x", "y"]), + Case(values=[1, 2], keys=["x", "y", "z"]), +] + + +@pytest.mark.parametrize( + "args", + chain.from_iterable( + c.constructor_args(mapping=False) for c in LENGTH_MISMACTH_TEST_CASES + ), +) +def test_keys_and_values_should_have_same_length(args: Callable[[], tuple]) -> None: + with pytest.raises(ValueError, match="length"): + DictTuple(*args())