From b1352554ea0a1eaf0e7f97e75e62072d8b243d59 Mon Sep 17 00:00:00 2001
From: rudy <rudy.sicard@gmail.com>
Date: Wed, 3 Apr 2024 17:07:09 +0200
Subject: [PATCH] fix(frontend-python): address review comments on extract_bits (to squash)

---
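Reviewer note: the renamed LUT_COSTS_V0_NORM2_0 table is what extract_bits
now uses to choose between a single full-width TLU and a chain of 1-bit lsb
extractions. A minimal standalone sketch of that decision, using the first
three entries of the table; prefer_single_tlu is an illustrative name, not a
function from the codebase:

    # excerpt of the cost table added by this patch (bit width -> cost)
    LUT_COSTS_V0_NORM2_0 = {1: 29, 2: 33, 3: 45}

    def prefer_single_tlu(bit_width, bits):
        # one table lookup over the whole input; unknown widths cost infinity
        cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(bit_width, float("inf"))
        # one 1-bit lsb extraction per position, up to the highest requested bit
        cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1)
        return cost_one_tlu < cost_many_lsbs

    # 3-bit input, bits (0, 1, 2): 45 < 29 * 3, so the single TLU wins
    assert prefer_single_tlu(3, (0, 1, 2))
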
 .../concrete/fhe/mlir/context.py              | 77 +++++++++----------
 1 file changed, 36 insertions(+), 41 deletions(-)
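
A second note, on the rewritten extraction loop: each iteration isolates the
current least significant bit with lsb, aligns it back to x's width with
reduce_precision, clears it with sub, and drops the emptied position with
reinterpret. A plain-integer analogue of that loop (extract_bits_model is an
illustrative name; real ciphertexts go through lsb/sub/reinterpret rather
than &, - and >>):

    def extract_bits_model(x, n):
        # peel off the least significant bit n times, LSB first
        bits = []
        for _ in range(n):
            lsb = x & 1          # self.lsb(...)
            bits.append(lsb)
            x = (x - lsb) >> 1   # self.sub(...), then reinterpret at bit_width - 1
        return bits

    assert extract_bits_model(0b1011, 4) == [1, 1, 0, 1]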

diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py
index 28781918db..71fea10458 100644
--- a/frontends/concrete-python/concrete/fhe/mlir/context.py
+++ b/frontends/concrete-python/concrete/fhe/mlir/context.py
@@ -40,8 +40,9 @@
 
 # pylint: enable=import-error,no-name-in-module
 
-
-LUT_COSTS = {
+# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128
+# Provides a very coarse way to compare two alternative code generation strategies.
+LUT_COSTS_V0_NORM2_0 = {
     1: 29,
     2: 33,
     3: 45,
@@ -140,6 +141,15 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType:
 
         return result if value.is_scalar else self.tensor(result, value.shape)
 
+    def fork_type(self, type_: ConversionType, bit_width: int) -> ConversionType:
+        return self.typeof(
+            ValueDescription(
+                dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width),
+                shape=type_.shape,
+                is_encrypted=type_.is_encrypted,
+            )
+        )
+
     # utilities
 
     def location(self) -> MlirLocation:
@@ -2188,17 +2198,17 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
     def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion:
         return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL})
 
-    def shift_left(self, x: Conversion, rank):
+    def shift_left(self, x: Conversion, rank: int) -> Conversion:
         assert rank >= 0
         assert rank < x.bit_width
         shifter = 2**rank
         shifter = self.constant(self.i(x.bit_width + 1), shifter)
         return self.mul(x.type, x, shifter)
 
-    def reduce_precision(self, x: Conversion, bit_width):
+    def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion:
         assert bit_width > 0
-        assert bit_width <= x.type.bit_width
-        if bit_width == x.type.bit_width:
+        assert bit_width <= x.bit_width
+        if bit_width == x.bit_width:
             return x
         scaled_x = self.shift_left(x, x.type.bit_width - bit_width)
         x = self.reinterpret(scaled_x, bit_width=bit_width)
@@ -2241,44 +2251,29 @@ def extract_bits(
         )
 
         # we optimize bulk extract in low precision, used for identity
-        if (
-            len(bits) > 1
-            and LUT_COSTS.get(x.type.bit_width, float("inf")) <= (max(bits) + 1) * LUT_COSTS[1]
-        ):
+        cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf"))
+        cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1)
+        if cost_one_tlu < cost_many_lsbs:
 
-            def is_positive(v):
-                return x.type.is_unsigned or v < 2 ** (x.type.bit_width - 1)
+            def tlu_cell_with_positive_value(i):
+                return x.type.is_unsigned or i < 2 ** (x.bit_width - 1)
 
-            def to_signed(v):
-                if is_positive(v):
-                    return v
-                return -(2 ** (x.type.bit_width) - v)
+            def tlu_cell_input_value(i):
+                if tlu_cell_with_positive_value(i):
+                    return i
+                return -(2 ** x.bit_width - i)
 
             table = [
                 sum(
-                    ((to_signed(v) >> bit) & 1) << position
+                    ((tlu_cell_input_value(i) >> bit) & 1) << position
                     for bit, position in bits_and_their_positions
                 )
-                + (0 if is_positive(v) else 2 ** (resulting_type.bit_width + 1))
-                for v in range(2**x.type.bit_width)
+                + (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1))
+                for i in range(2**x.bit_width)
             ]
             tlu_result = self.tlu(resulting_type, x, table)
             return self.to_signedness(tlu_result, of=resulting_type)
 
-        def same_type_as(type_, bit_width):
-            return self.typeof(
-                ValueDescription(
-                    dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width),
-                    shape=type_.shape,
-                    is_encrypted=type_.is_encrypted,
-                )
-            )
-
-        def reduce_precision(x, delta):
-            assert delta < x.type.bit_width
-            new_bit_witdh = x.type.bit_width - delta
-            return self.reduce_precision(x, new_bit_witdh)
-
         current_bit = 0
         max_bit = x.original_bit_width
 
@@ -2297,14 +2292,16 @@ def reduce_precision(x, delta):
                     lsb_bit_witdh = 1
                     lsb = x
                 else:
-                    lsb_bit_witdh = max(resulting_type.bit_width - position, x.type.bit_width)
-                    lsb_type = same_type_as(x.type, lsb_bit_witdh)
+                    lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width)
+                    lsb_type = self.fork_type(x.type, lsb_bit_witdh)
                     lsb = self.lsb(lsb_type, x)
 
                 # check that we only need to shift to emulate the initial and final position
                 # position are expressed for the final bit_width
-                initial_position = resulting_type.bit_width - x.type.bit_width
+                initial_position = resulting_type.bit_width - x.bit_width
                 actual_position = resulting_type.bit_width - lsb.type.bit_width
+                delta_precision = initial_position - actual_position
+                assert 0 <= delta_precision < resulting_type.bit_width
                 assert (
                     actual_position <= initial_position
                 ), "extract_bits: Cannot get back to initial precision"
@@ -2317,9 +2314,7 @@ def reduce_precision(x, delta):
                 if current_bit >= max_bit:
                     break
 
-                delta_precision = initial_position - actual_position
-                assert 0 <= delta_precision < resulting_type.bit_width
-                clearing_bit = reduce_precision(lsb, delta_precision)
+                clearing_bit = self.reduce_precision(lsb, x.bit_width)
                 cleared = self.sub(x.type, x, clearing_bit)
                 x = self.reinterpret(cleared, bit_width=(x.bit_width - 1))
 
@@ -3290,8 +3285,8 @@ def round_bit_pattern(
             unskewed = x
             if approx_conf.symetrize_deltas:
                 highest_supported_precision = 62
-                delta_precision = highest_supported_precision - x.type.bit_width
-                full_precision = x.type.bit_width + delta_precision
+                delta_precision = highest_supported_precision - x.bit_width
+                full_precision = x.bit_width + delta_precision
                 half_in_extra_precision = (
                     1 << (delta_precision - 1)
                 ) - 1  # slightly smaller then half
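
A last note, on the bulk-extraction table in the extract_bits hunk above: for
signed inputs, TLU cells 2**(bit_width - 1) .. 2**bit_width - 1 hold the
negative values in two's complement order, which is what tlu_cell_input_value
reconstructs. A plain-Python sketch of the table for a 3-bit signed input
extracting bits (0, 1, 2) to the same positions; cell_value is an
illustrative name, and the sign-correction term
2 ** (resulting_type.bit_width + 1) is omitted:

    bit_width = 3
    bits_and_their_positions = [(0, 0), (1, 1), (2, 2)]

    def cell_value(i):
        # cells 0..3 hold 0..3, cells 4..7 hold -4..-1 (two's complement)
        return i if i < 2 ** (bit_width - 1) else -(2 ** bit_width - i)

    table = [
        sum(
            ((cell_value(i) >> bit) & 1) << position
            for bit, position in bits_and_their_positions
        )
        for i in range(2 ** bit_width)
    ]

    # reassembling every bit reproduces the raw 3-bit patterns
    assert table == list(range(8))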