Skip to content

Commit

Permalink
fix(frontend): use esint when tfhers type is signed
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Oct 25, 2024
1 parent 6c7291c commit e5c643e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
5 changes: 4 additions & 1 deletion frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,10 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) ->
# sum will remove the last dim which is the dim of ciphertexts
result_shape = tfhers_int.shape[:-1]
# if result_shape is () then ctx.tensor would return a scalar type
result_type = ctx.tensor(ctx.eint(result_bit_width), result_shape)
result_type = ctx.tensor(
ctx.esint(result_bit_width) if dtype.is_signed else ctx.eint(result_bit_width),
result_shape,
)
return ctx.sum(result_type, mapped, axes=-1)

def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
Expand Down
29 changes: 29 additions & 0 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,16 @@ def lut_add_lut(x, y):
TFHERS_UINT_8_3_2_4096,
id="x * y",
),
# FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [-(2**3), 0], "status": "encrypted"},
"y": {"range": [-(2**3), 0], "status": "encrypted"},
},
TFHERS_INT_8_3_2_4096,
id="signed(x) * signed(y)",
),
pytest.param(
lambda x, y: x * y,
{
Expand Down Expand Up @@ -653,6 +663,25 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen(
TFHERS_UINT_8_3_2_4096,
id="x * clear(y)",
),
# FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [-(2**3), -(2**3)], "status": "encrypted"},
"y": {"range": [2**4, 2**4], "status": "encrypted"},
},
TFHERS_INT_8_3_2_4096,
id="signed(x) * signed(y)",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"range": [0, 2**3], "status": "encrypted"},
"y": {"range": [-(2**3), 0], "status": "clear"},
},
TFHERS_INT_8_3_2_4096,
id="signed(x) * clear(-y)",
),
pytest.param(
lut_add_lut,
{
Expand Down

0 comments on commit e5c643e

Please sign in to comment.