From ced1f378bfd141f1ac39e039c98bf5e1fd8d3dc2 Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 3 Apr 2024 17:07:09 +0200 Subject: [PATCH] fix(frontend-python): REVIEW TO SQUASH --- .../concrete/fhe/mlir/context.py | 77 +++++++++---------- 1 file changed, 36 insertions(+), 41 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 28781918db..71fea10458 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -40,8 +40,9 @@ # pylint: enable=import-error,no-name-in-module - -LUT_COSTS = { +# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128 +# Provide a very coarse way to compare 2 alternative code generation +LUT_COSTS_V0_NORM2_0 = { 1: 29, 2: 33, 3: 45, @@ -140,6 +141,15 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType: return result if value.is_scalar else self.tensor(result, value.shape) + def fork_type(self, type_, bit_width): + return self.typeof( + ValueDescription( + dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width), + shape=type_.shape, + is_encrypted=type_.is_encrypted, + ) + ) + # utilities def location(self) -> MlirLocation: @@ -2188,17 +2198,17 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion: def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL}) - def shift_left(self, x: Conversion, rank): + def shift_left(self, x: Conversion, rank: int) -> Conversion: assert rank >= 0 assert rank < x.bit_width shifter = 2**rank shifter = self.constant(self.i(x.bit_width + 1), shifter) return self.mul(x.type, x, shifter) - def reduce_precision(self, x: Conversion, bit_width): + def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion: assert bit_width > 0 - assert bit_width <= x.type.bit_width - if bit_width == x.type.bit_width: + assert bit_width <= x.bit_width + if bit_width == x.bit_width: return x scaled_x = self.shift_left(x, x.type.bit_width - bit_width) x = self.reinterpret(scaled_x, bit_width=bit_width) @@ -2241,44 +2251,29 @@ def extract_bits( ) # we optimize bulk extract in low precision, used for identity - if ( - len(bits) > 1 - and LUT_COSTS.get(x.type.bit_width, float("inf")) <= (max(bits) + 1) * LUT_COSTS[1] - ): + cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) + cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1) + if cost_one_tlu < cost_many_lsbs: - def is_positive(v): - return x.type.is_unsigned or v < 2 ** (x.type.bit_width - 1) + def tlu_cell_with_positive_value(i): + return x.type.is_unsigned or i < 2 ** (x.bit_width - 1) - def to_signed(v): - if is_positive(v): - return v - return -(2 ** (x.type.bit_width) - v) + def tlu_cell_input_value(i): + if tlu_cell_with_positive_value(i): + return i + return -(2 ** (x.bit_width) - i) table = [ sum( - ((to_signed(v) >> bit) & 1) << position + ((tlu_cell_input_value(i) >> bit) & 1) << position for bit, position in bits_and_their_positions ) - + (0 if is_positive(v) else 2 ** (resulting_type.bit_width + 1)) - for v in range(2**x.type.bit_width) + + (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1)) + for i in range(2**x.bit_width) ] tlu_result = self.tlu(resulting_type, x, table) return self.to_signedness(tlu_result, of=resulting_type) - def same_type_as(type_, bit_width): - return self.typeof( - ValueDescription( - dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width), - shape=type_.shape, - is_encrypted=type_.is_encrypted, - ) - ) - - def reduce_precision(x, delta): - assert delta < x.type.bit_width - new_bit_witdh = x.type.bit_width - delta - return self.reduce_precision(x, new_bit_witdh) - current_bit = 0 max_bit = x.original_bit_width @@ -2297,14 +2292,16 @@ def reduce_precision(x, delta): lsb_bit_witdh = 1 lsb = x else: - lsb_bit_witdh = max(resulting_type.bit_width - position, x.type.bit_width) - lsb_type = same_type_as(x.type, lsb_bit_witdh) + lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width) + lsb_type = self.fork_type(x.type, lsb_bit_witdh) lsb = self.lsb(lsb_type, x) # check that we only need to shift to emulate the initial and final position # position are expressed for the final bit_width - initial_position = resulting_type.bit_width - x.type.bit_width + initial_position = resulting_type.bit_width - x.bit_width actual_position = resulting_type.bit_width - lsb.type.bit_width + delta_precision = initial_position - actual_position + assert 0 <= delta_precision < resulting_type.bit_width assert ( actual_position <= initial_position ), "extract_bits: Cannot get back to initial precision" @@ -2317,9 +2314,7 @@ def reduce_precision(x, delta): if current_bit >= max_bit: break - delta_precision = initial_position - actual_position - assert 0 <= delta_precision < resulting_type.bit_width - clearing_bit = reduce_precision(lsb, delta_precision) + clearing_bit = self.reduce_precision(lsb, x.bit_width) cleared = self.sub(x.type, x, clearing_bit) x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) @@ -3290,8 +3285,8 @@ def round_bit_pattern( unskewed = x if approx_conf.symetrize_deltas: highest_supported_precision = 62 - delta_precision = highest_supported_precision - x.type.bit_width - full_precision = x.type.bit_width + delta_precision + delta_precision = highest_supported_precision - x.bit_width + full_precision = x.bit_width + delta_precision half_in_extra_precision = ( 1 << (delta_precision - 1) ) - 1 # slightly smaller then half