From 2b368edb209dab10d3e4ccae92fcceb76bbd1b37 Mon Sep 17 00:00:00 2001 From: rudy Date: Thu, 21 Mar 2024 15:56:35 +0100 Subject: [PATCH] fix(frontend-python): optimize extract bits lsb and tlu calls were not minimized --- .../concrete/fhe/extensions/bits.py | 4 +- .../concrete/fhe/mlir/context.py | 156 +++++++++++++----- .../tests/execution/test_bit_extraction.py | 139 +++++++++++++++- 3 files changed, 259 insertions(+), 40 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/extensions/bits.py b/frontends/concrete-python/concrete/fhe/extensions/bits.py index 28c042b6a3..b48cb7a87b 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/bits.py +++ b/frontends/concrete-python/concrete/fhe/extensions/bits.py @@ -105,7 +105,7 @@ def __getitem__(self, index: Union[int, np.integer, slice]) -> Tracer: def evaluator(x, bits): # pylint: disable=redefined-outer-name if isinstance(bits, (int, np.integer)): - return (x & (1 << bits)) >> bits + return (x >> bits) & 1 assert isinstance(bits, slice) @@ -126,7 +126,7 @@ def evaluator(x, bits): # pylint: disable=redefined-outer-name result = 0 for i, bit in enumerate(range(start, stop, step)): - value = (x & (1 << bit)) >> bit + value = (x >> bit) & 1 result += value << i return result diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index e98250afcc..28781918db 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -41,6 +41,21 @@ # pylint: enable=import-error,no-name-in-module +LUT_COSTS = { + 1: 29, + 2: 33, + 3: 45, + 4: 74, + 5: 101, + 6: 231, + 7: 535, + 8: 1721, + 9: 3864, + 10: 8697, + 11: 19522, +} + + class Context: """ Context class, to perform operations on conversions. @@ -2173,6 +2188,22 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion: def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL}) + def shift_left(self, x: Conversion, rank): + assert rank >= 0 + assert rank < x.bit_width + shifter = 2**rank + shifter = self.constant(self.i(x.bit_width + 1), shifter) + return self.mul(x.type, x, shifter) + + def reduce_precision(self, x: Conversion, bit_width): + assert bit_width > 0 + assert bit_width <= x.type.bit_width + if bit_width == x.type.bit_width: + return x + scaled_x = self.shift_left(x, x.type.bit_width - bit_width) + x = self.reinterpret(scaled_x, bit_width=bit_width) + return x + def extract_bits( self, resulting_type: ConversionType, @@ -2200,66 +2231,116 @@ def extract_bits( start = bits.start or MIN_EXTRACTABLE_BIT stop = bits.stop or (MAX_EXTRACTABLE_BIT if step > 0 else (MIN_EXTRACTABLE_BIT - 1)) - bits_and_their_positions = [] - for position, bit in enumerate(range(start, stop, step)): - bits_and_their_positions.append((bit, position)) + bits = list(range(start, stop, step)) + + bits_and_their_positions = ((bit, position) for position, bit in enumerate(bits)) bits_and_their_positions = sorted( bits_and_their_positions, key=lambda bit_and_its_position: bit_and_its_position[0], ) + # we optimize bulk extract in low precision, used for identity + if ( + len(bits) > 1 + and LUT_COSTS.get(x.type.bit_width, float("inf")) <= (max(bits) + 1) * LUT_COSTS[1] + ): + + def is_positive(v): + return x.type.is_unsigned or v < 2 ** (x.type.bit_width - 1) + + def to_signed(v): + if is_positive(v): + return v + return -(2 ** (x.type.bit_width) - v) + + table = [ + sum( + ((to_signed(v) >> bit) & 1) << position + for bit, position in bits_and_their_positions + ) + + (0 if is_positive(v) else 2 ** (resulting_type.bit_width + 1)) + for v in range(2**x.type.bit_width) + ] + tlu_result = self.tlu(resulting_type, x, table) + return self.to_signedness(tlu_result, of=resulting_type) + + def same_type_as(type_, bit_width): + return self.typeof( + ValueDescription( + dtype=Integer(is_signed=type_.is_signed, bit_width=bit_width), + shape=type_.shape, + is_encrypted=type_.is_encrypted, + ) + ) + + def reduce_precision(x, delta): + assert delta < x.type.bit_width + new_bit_witdh = x.type.bit_width - delta + return self.reduce_precision(x, new_bit_witdh) + current_bit = 0 max_bit = x.original_bit_width lsb: Optional[Conversion] = None result: Optional[Conversion] = None - - for index, (bit, position) in enumerate(bits_and_their_positions): + for bit, position in bits_and_their_positions: if bit >= max_bit and x.is_unsigned: break - last = index == len(bits_and_their_positions) - 1 while bit != (current_bit - 1): - if bit == (max_bit - 1) and x.bit_width == 1 and x.is_unsigned: + if ( + bit == (max_bit - 1) + and x.bit_width == resulting_type.bit_width == 1 + and x.is_unsigned + ): + lsb_bit_witdh = 1 lsb = x - elif last and bit == current_bit: - lsb = self.lsb(resulting_type, x) else: - lsb = self.lsb(x.type, x) + lsb_bit_witdh = max(resulting_type.bit_width - position, x.type.bit_width) + lsb_type = same_type_as(x.type, lsb_bit_witdh) + lsb = self.lsb(lsb_type, x) + + # check that we only need to shift to emulate the initial and final position + # position are expressed for the final bit_width + initial_position = resulting_type.bit_width - x.type.bit_width + actual_position = resulting_type.bit_width - lsb.type.bit_width + assert ( + actual_position <= initial_position + ), "extract_bits: Cannot get back to initial precision" + assert ( + actual_position <= position + ), "extract_bits: Cannot get back to final precision" current_bit += 1 if current_bit >= max_bit: break - if not last or bit != (current_bit - 1): - cleared = self.sub(x.type, x, lsb) - x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) + delta_precision = initial_position - actual_position + assert 0 <= delta_precision < resulting_type.bit_width + clearing_bit = reduce_precision(lsb, delta_precision) + cleared = self.sub(x.type, x, clearing_bit) + x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) assert lsb is not None - lsb = self.to_signedness(lsb, of=resulting_type) - - if lsb.bit_width > resulting_type.bit_width: - difference = (lsb.bit_width - resulting_type.bit_width) + position - shifter = self.constant(self.i(lsb.bit_width + 1), 2**difference) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=resulting_type.bit_width) - - elif lsb.bit_width < resulting_type.bit_width: - shift = 2 ** (lsb.bit_width - 1) - if shift != 1: - shifter = self.constant(self.i(lsb.bit_width + 1), shift) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=1) - lsb = self.tlu(resulting_type, lsb, [0 << position, 1 << position]) - - elif position != 0: - shifter = self.constant(self.i(lsb.bit_width + 1), 2**position) - lsb = self.mul(lsb.type, lsb, shifter) + bit_value = self.to_signedness(lsb, of=resulting_type) + bit_value = self.reinterpret( + bit_value, bit_width=max(resulting_type.bit_width, max_bit) + ) - assert lsb is not None - result = lsb if result is None else self.add(resulting_type, result, lsb) + delta_precision = position - actual_position + assert actual_position < 0 or 0 <= delta_precision < resulting_type.bit_width, ( + position, + actual_position, + resulting_type.bit_width, + ) + if delta_precision: + bit_value = self.shift_left(bit_value, delta_precision) + + bit_value = self.reinterpret(bit_value, bit_width=resulting_type.bit_width) + + result = bit_value if result is None else self.add(resulting_type, result, bit_value) return result if result is not None else self.zeros(resulting_type) @@ -3763,11 +3844,12 @@ def reinterpret( ) -> Conversion: assert x.is_encrypted - if x.bit_width == bit_width: + result_unsigned = x.is_unsigned if signed is None else not signed + + if x.bit_width == bit_width and x.is_unsigned == result_unsigned: return x - result_signed = x.is_unsigned if signed is None else signed - resulting_element_type = (self.eint if result_signed else self.esint)(bit_width) + resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) operation = ( diff --git a/frontends/concrete-python/tests/execution/test_bit_extraction.py b/frontends/concrete-python/tests/execution/test_bit_extraction.py index 3285170643..61bf004e3b 100644 --- a/frontends/concrete-python/tests/execution/test_bit_extraction.py +++ b/frontends/concrete-python/tests/execution/test_bit_extraction.py @@ -203,7 +203,144 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers): compiler = fhe.Compiler(operation, {"x": "encrypted"}) circuit = compiler.compile(inputset, helpers.configuration()) - + print(circuit.mlir) values = inputset if len(inputset) <= 8 else random.sample(inputset, 8) for value in values: helpers.check_execution(circuit, operation, value, retries=3) + + +def mlir_count_ops(mlir, operation): + """ + Count op in mlir. + """ + return sum(operation in line for line in mlir.splitlines()) + + +def test_highest_bit_extraction_mlir(helpers): + """ + Test bit extraction of the highest bit. Saves one lsb. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[precision - 1] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test bit extraction to same. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + print(circuit.mlir) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision - 2) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_smaller_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + 2 ** (precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bit_extract_to_1_tlu(helpers): + """ + Test bit extract as 1 tlu for small precision. + """ + precision = 3 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 0 + assert mlir_count_ops(circuit.mlir, "lookup") == 1 + + precision = 4 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): # pylint: disable=function-redefined + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 2 + assert mlir_count_ops(circuit.mlir, "lookup") == 0