Skip to content

Commit

Permalink
test(frontend): more tests and coverage for tfhers
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 22, 2024
1 parent 24493a3 commit 8051a14
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 3 deletions.
6 changes: 3 additions & 3 deletions frontends/concrete-python/concrete/fhe/tfhers/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ def __init__(
self.pbs_base_log = pbs_base_log
self.pbs_level = pbs_level

def __str__(self) -> str:
def __str__(self) -> str: # pragma: no cover
return (
f"tfhers_params<lwe_dim={self.lwe_dimension}, glwe_dim={self.glwe_dimension}, "
f"poly_size={self.polynomial_size}, pbs_base_log={self.pbs_base_log}, "
f"pbs_level={self.pbs_level}>"
)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: Any) -> bool: # pragma: no cover
return (
isinstance(other, self.__class__)
and self.lwe_dimension == other.lwe_dimension
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
self.msg_width = msg_width
self.params = params

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: Any) -> bool: # pragma: no cover
return (
isinstance(other, self.__class__)
and super().__eq__(other)
Expand Down
3 changes: 3 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def to_native(value: Union[Tracer, TFHERSInteger]) -> Union[Tracer, int, np.ndar

if isinstance(value, Tracer) and isinstance(value.output.dtype, TFHERSIntegerType):
dtype = value.output.dtype
if not isinstance(dtype, TFHERSIntegerType): # pragma: no cover
msg = f"tracer didn't contain an output of TFHEInteger type. Type is: {dtype}"
raise TypeError(msg)
return _trace_to_native(value, dtype)

if isinstance(value, TFHERSInteger):
Expand Down
3 changes: 3 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(
except Exception as e: # pylint: disable=broad-except
msg = f"got error while trying to convert list value into a numpy array: {e}"
raise ValueError(msg) from e
if value.dtype == np.dtype("O"):
msg = "malformed value array"
raise ValueError(msg)

if isinstance(value, (int, np.integer)):
self._shape = ()
Expand Down
106 changes: 106 additions & 0 deletions frontends/concrete-python/tests/dtypes/test_tfhers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Tests of `TFHERSIntegerType` data type.
"""

import numpy as np
import pytest

from concrete.fhe import tfhers

DEFAULT_TFHERS_PARAM = tfhers.TFHERSParams(
909,
1,
4096,
15,
2,
)


def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType:
"""Create a tfhers type from a partial func missing tfhers params.
Args:
partial_dtype (callable): partial function to create dtype (missing params)
Returns:
tfhers.TFHERSIntegerType: tfhers type
"""

return partial_dtype(DEFAULT_TFHERS_PARAM)


def test_tfhers_encode_bad_type():
"""Test encoding of unsupported type"""
dtype = parameterize_partial_dtype(tfhers.uint16_2_2)
with pytest.raises(TypeError, match=r"can only encode int or ndarray, but got <class 'str'>"):
dtype.encode("bad type")


def test_tfhers_encode_ndarray():
"""Test ndarray encoding"""
dtype = parameterize_partial_dtype(tfhers.uint16_2_2)
shape = (4, 5)
value = np.random.randint(0, 2**10, size=shape)
encoded = dtype.encode(value)
decoded = dtype.decode(encoded)
assert (decoded == value).all()
assert encoded.shape == shape + (8,)


def test_tfhers_bad_decode():
"""Test decoding of bad values"""
dtype = parameterize_partial_dtype(tfhers.uint8_2_2)
shape = (2, 10)
bad_value = np.random.randint(0, 2**10, size=shape)
with pytest.raises(
ValueError,
match=r"bad encoding",
):
dtype.decode(bad_value)


def test_tfhers_integer_bad_values():
"""Test new integer with bad values"""
dtype = parameterize_partial_dtype(tfhers.uint8_2_2)
with pytest.raises(
ValueError,
):
tfhers.TFHERSInteger(
dtype,
[
[1, 2],
[
2,
],
],
)

with pytest.raises(
ValueError,
match=r"ndarray value has bigger elements than what the dtype can support",
):
tfhers.TFHERSInteger(
dtype,
[
[1, 2],
[2, 2**10],
],
)

with pytest.raises(
ValueError,
match=r"ndarray value has smaller elements than what the dtype can support",
):
tfhers.TFHERSInteger(
dtype,
[
[1, -2],
[2, 2],
],
)

with pytest.raises(
TypeError,
match=r"value can either be an int or ndarray, not a <class 'str'>",
):
tfhers.TFHERSInteger(dtype, "bad value")
10 changes: 10 additions & 0 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def is_input_and_output_tfhers(
tfhers_ins: List[int],
tfhers_outs: List[int],
) -> bool:
"""Check if inputs and outputs description match tfhers parameters"""
params = json.loads(circuit.client.specs.client_parameters.serialize())
main_circuit = params["circuits"][0]
# check all encrypted input/output have the correct lwe_dim
Expand Down Expand Up @@ -534,3 +535,12 @@ def test_tfhers_conversion_without_multi(function, parameters, parameter_strateg
]
with pytest.raises(RuntimeError, match=f"Can't use tfhers integers with {parameter_strategy}"):
compiler.compile(inputset, configuration)


def test_tfhers_circuit_eval():
"""Test evaluation of tfhers function."""
dtype = parameterize_partial_dtype(tfhers.uint16_2_2)
x = tfhers.TFHERSInteger(dtype, 1)
y = tfhers.TFHERSInteger(dtype, 2)
result = binary_tfhers(x, y, lambda x, y: x + y, dtype)
assert result == 3

0 comments on commit 8051a14

Please sign in to comment.