From 161d870410176d8b3a84f91f0fa7dd02f19c9267 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 12 Nov 2024 14:00:23 +0100 Subject: [PATCH] fix(frontend): support higher bitwidth computation when using TFHE-rs --- .../concrete/fhe/mlir/converter.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index f5cf6211bb..6c244f670b 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -985,9 +985,20 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> ] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1)) padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table) # set padding bits (where necessary) in the final result - return ctx.add(result_type, sum_result, padding_bits_inc) - - return sum_result + result = ctx.add(result_type, sum_result, padding_bits_inc) + else: + result = sum_result + + # even if TFHE-rs value are using non-variable bit-width, we want the output + # to be pluggable into the rest of the computation. For example, two 8bits TFHE-rs integers + # could be used in a 9bits addition. If we don't cast, it won't pass the bitwidth + # compatibility check. + output_bit_width = ctx.typeof(node).bit_width + casted_result_type = ctx.tensor( + ctx.esint(output_bit_width) if dtype.is_signed else ctx.eint(output_bit_width), + result_shape, + ) + return ctx.cast(casted_result_type, result) def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1