Skip to content

Commit

Permalink
test(frontend): make tfhers test pass again
Browse files Browse the repository at this point in the history
there was an issue while rebasing to main
  • Loading branch information
youben11 committed Aug 23, 2024
1 parent 96cff21 commit 3e42d69
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
7 changes: 5 additions & 2 deletions frontends/concrete-python/tests/dtypes/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType:
def test_tfhers_encode_bad_type():
"""Test encoding of unsupported type"""
dtype = parameterize_partial_dtype(tfhers.uint16_2_2)
with pytest.raises(TypeError, match=r"can only encode int or ndarray, but got <class 'str'>"):
with pytest.raises(
TypeError,
match=r"can only encode int, np.integer, list or ndarray, but got <class 'str'>",
):
dtype.encode("bad type")


Expand All @@ -54,7 +57,7 @@ def test_tfhers_bad_decode():
bad_value = np.random.randint(0, 2**10, size=shape)
with pytest.raises(
ValueError,
match=r"bad encoding",
match=r"expected the last dimension of encoded value to be 4 but it's 10",
):
dtype.decode(bad_value)

Expand Down
8 changes: 6 additions & 2 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def test_tfhers_integer_eq(lhs, rhs, is_equal):
"""
Test TFHERSIntegerType equality.
"""
assert is_equal == (lhs == rhs)
assert is_equal == (parameterize_partial_dtype(lhs) == parameterize_partial_dtype(rhs))


@pytest.mark.parametrize(
Expand All @@ -309,7 +309,7 @@ def test_tfhers_integer_encode(dtype, value, encoded):
"""
Test TFHERSIntegerType encode.
"""

dtype = parameterize_partial_dtype(dtype)
assert np.array_equal(dtype.encode(value), encoded)


Expand All @@ -329,6 +329,7 @@ def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_messag
Test TFHERSIntegerType encode.
"""

dtype = parameterize_partial_dtype(dtype)
with pytest.raises(expected_error) as excinfo:
dtype.encode(value)

Expand All @@ -354,6 +355,7 @@ def test_tfhers_integer_decode(dtype, encoded, decoded):
Test TFHERSIntegerType decode.
"""

dtype = parameterize_partial_dtype(dtype)
assert np.array_equal(dtype.decode(encoded), decoded)


Expand All @@ -379,6 +381,7 @@ def test_tfhers_integer_bad_decode(dtype, value, expected_error, expected_messag
Test TFHERSIntegerType decode.
"""

dtype = parameterize_partial_dtype(dtype)
with pytest.raises(expected_error) as excinfo:
dtype.decode(value)

Expand Down Expand Up @@ -484,6 +487,7 @@ def test_tfhers_integer_bad_init(dtype, value, expected_error, expected_message)
Test __init__ of TFHERSInteger with bad arguments.
"""

dtype = parameterize_partial_dtype(dtype)
with pytest.raises(expected_error) as excinfo:
tfhers.TFHERSInteger(dtype, value)

Expand Down

0 comments on commit 3e42d69

Please sign in to comment.