From b0e7c085ea109bb7935ae6366b1ef5784a441ed7 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 22 Oct 2024 11:41:39 +0100 Subject: [PATCH] fix(frontend): consider padding bit during to_native conversion --- .../concrete/fhe/mlir/converter.py | 26 ++++++++++++++++++- .../tests/execution/test_tfhers.py | 10 +++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index cbd174c2f9..f5cf6211bb 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -963,7 +963,31 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> ctx.esint(result_bit_width) if dtype.is_signed else ctx.eint(result_bit_width), result_shape, ) - return ctx.sum(result_type, mapped, axes=-1) + sum_result = ctx.sum(result_type, mapped, axes=-1) + + # we want to set the padding bit if the native type is signed + # and the ciphertext is negative (sign bit set to 1) + if dtype.is_signed: + # select MSBs of all tfhers ciphetexts + index = [slice(0, dim_size) for dim_size in tfhers_int.shape[:-1]] + [ + -1, + ] + msbs = ctx.index( + ctx.tensor(ctx.eint(msg_width + carry_width), tfhers_int.shape[:-1]), + tfhers_int, + index=index, + ) + # construct padding bits based on sign bits (carry would be considered negative) + padding_bit_table = [ + 0, + ] * 2 ** (msg_width - 1) + [ + 2**result_bit_width, + ] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1)) + padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table) + # set padding bits (where necessary) in the final result + return ctx.add(result_type, sum_result, padding_bits_inc) + + return sum_result def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index ce0eab473c..c747916084 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -363,12 +363,11 @@ def lut_add_lut(x, y): TFHERS_UINT_8_3_2_4096, id="x * y", ), - # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative pytest.param( lambda x, y: x * y, { - "x": {"range": [-(2**3), 0], "status": "encrypted"}, - "y": {"range": [-(2**3), 0], "status": "encrypted"}, + "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, + "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, }, TFHERS_INT_8_3_2_4096, id="signed(x) * signed(y)", @@ -663,12 +662,11 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), - # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative pytest.param( lambda x, y: x * y, { - "x": {"range": [-(2**3), -(2**3)], "status": "encrypted"}, - "y": {"range": [2**4, 2**4], "status": "encrypted"}, + "x": {"range": [-(2**3), 2], "status": "encrypted"}, + "y": {"range": [-2, 2**4], "status": "encrypted"}, }, TFHERS_INT_8_3_2_4096, id="signed(x) * signed(y)",