diff --git a/frontends/concrete-python/tests/dtypes/test_tfhers.py b/frontends/concrete-python/tests/dtypes/test_tfhers.py index 9ca506705c..082a5c6706 100644 --- a/frontends/concrete-python/tests/dtypes/test_tfhers.py +++ b/frontends/concrete-python/tests/dtypes/test_tfhers.py @@ -32,7 +32,10 @@ def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType: def test_tfhers_encode_bad_type(): """Test encoding of unsupported type""" dtype = parameterize_partial_dtype(tfhers.uint16_2_2) - with pytest.raises(TypeError, match=r"can only encode int or ndarray, but got "): + with pytest.raises( + TypeError, + match=r"can only encode int, np.integer, list or ndarray, but got ", + ): dtype.encode("bad type") @@ -54,7 +57,7 @@ def test_tfhers_bad_decode(): bad_value = np.random.randint(0, 2**10, size=shape) with pytest.raises( ValueError, - match=r"bad encoding", + match=r"expected the last dimension of encoded value to be 4 but it's 10", ): dtype.decode(bad_value) diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 906de50790..ed019c6320 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -288,7 +288,7 @@ def test_tfhers_integer_eq(lhs, rhs, is_equal): """ Test TFHERSIntegerType equality. """ - assert is_equal == (lhs == rhs) + assert is_equal == (parameterize_partial_dtype(lhs) == parameterize_partial_dtype(rhs)) @pytest.mark.parametrize( @@ -309,7 +309,7 @@ def test_tfhers_integer_encode(dtype, value, encoded): """ Test TFHERSIntegerType encode. """ - + dtype = parameterize_partial_dtype(dtype) assert np.array_equal(dtype.encode(value), encoded) @@ -329,6 +329,7 @@ def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_messag Test TFHERSIntegerType encode. """ + dtype = parameterize_partial_dtype(dtype) with pytest.raises(expected_error) as excinfo: dtype.encode(value) @@ -354,6 +355,7 @@ def test_tfhers_integer_decode(dtype, encoded, decoded): Test TFHERSIntegerType decode. """ + dtype = parameterize_partial_dtype(dtype) assert np.array_equal(dtype.decode(encoded), decoded) @@ -379,6 +381,7 @@ def test_tfhers_integer_bad_decode(dtype, value, expected_error, expected_messag Test TFHERSIntegerType decode. """ + dtype = parameterize_partial_dtype(dtype) with pytest.raises(expected_error) as excinfo: dtype.decode(value) @@ -484,6 +487,7 @@ def test_tfhers_integer_bad_init(dtype, value, expected_error, expected_message) Test __init__ of TFHERSInteger with bad arguments. """ + dtype = parameterize_partial_dtype(dtype) with pytest.raises(expected_error) as excinfo: tfhers.TFHERSInteger(dtype, value)