From 863b485552a3d0b04db419473776dde9576fdea6 Mon Sep 17 00:00:00 2001 From: rudy Date: Thu, 21 Mar 2024 15:56:35 +0100 Subject: [PATCH] fix(frontend-python): optimize extract bits lsb and tlu calls were not minimized --- .../concrete/fhe/extensions/bits.py | 4 +- .../concrete/fhe/mlir/context.py | 178 ++++++++++++++---- .../tests/execution/test_bit_extraction.py | 142 +++++++++++++- 3 files changed, 281 insertions(+), 43 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/extensions/bits.py b/frontends/concrete-python/concrete/fhe/extensions/bits.py index 28c042b6a3..b48cb7a87b 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/bits.py +++ b/frontends/concrete-python/concrete/fhe/extensions/bits.py @@ -105,7 +105,7 @@ def __getitem__(self, index: Union[int, np.integer, slice]) -> Tracer: def evaluator(x, bits): # pylint: disable=redefined-outer-name if isinstance(bits, (int, np.integer)): - return (x & (1 << bits)) >> bits + return (x >> bits) & 1 assert isinstance(bits, slice) @@ -126,7 +126,7 @@ def evaluator(x, bits): # pylint: disable=redefined-outer-name result = 0 for i, bit in enumerate(range(start, stop, step)): - value = (x & (1 << bit)) >> bit + value = (x >> bit) & 1 result += value << i return result diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index e98250afcc..7b0704ffd1 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -40,6 +40,27 @@ # pylint: enable=import-error,no-name-in-module +# See https://raw.githubusercontent.com/zama-ai/concrete/main/compilers/concrete-optimizer/v0-parameters/ref/v0_last_128 +# Provide a very coarse way to compare 2 alternative code generation +LUT_COSTS_V0_NORM2_0 = { + 1: 29, + 2: 33, + 3: 45, + 4: 74, + 5: 101, + 6: 231, + 7: 535, + 8: 1721, + 9: 3864, + 10: 8697, + 11: 19522, +} + + +def default(param, value): + """Handle optional parameter with default value.""" + return value if param is None else param + class Context: """ @@ -125,6 +146,27 @@ def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType: return result if value.is_scalar else self.tensor(result, value.shape) + def fork_type( + self, + type_: ConversionType, + bit_width: Optional[int] = None, + is_signed: Optional[int] = None, + shape: Optional[Tuple[int, ...]] = None, + ) -> ConversionType: + """ + Fork a type with some properties update. + """ + bit_width = default(param=bit_width, value=type_.bit_width) + is_signed = default(param=is_signed, value=type_.is_signed) + shape = default(param=shape, value=type_.shape) + return self.typeof( + ValueDescription( + dtype=Integer(is_signed=is_signed, bit_width=bit_width), + shape=shape, + is_encrypted=type_.is_encrypted, + ) + ) + # utilities def location(self) -> MlirLocation: @@ -2173,6 +2215,22 @@ def encrypt(self, resulting_type: ConversionType, x: Conversion) -> Conversion: def equal(self, resulting_type: ConversionType, x: Conversion, y: Conversion) -> Conversion: return self.comparison(resulting_type, x, y, accept={Comparison.EQUAL}) + def shift_left(self, x: Conversion, rank: int) -> Conversion: + assert rank >= 0 + assert rank < x.bit_width + shifter = 2**rank + shifter = self.constant(self.i(x.bit_width + 1), shifter) + return self.mul(x.type, x, shifter) + + def reduce_precision(self, x: Conversion, bit_width: int) -> Conversion: + assert bit_width > 0 + assert bit_width <= x.bit_width + if bit_width == x.bit_width: + return x + scaled_x = self.shift_left(x, x.type.bit_width - bit_width) + x = self.reinterpret(scaled_x, bit_width=bit_width) + return x + def extract_bits( self, resulting_type: ConversionType, @@ -2200,66 +2258,105 @@ def extract_bits( start = bits.start or MIN_EXTRACTABLE_BIT stop = bits.stop or (MAX_EXTRACTABLE_BIT if step > 0 else (MIN_EXTRACTABLE_BIT - 1)) - bits_and_their_positions = [] - for position, bit in enumerate(range(start, stop, step)): - bits_and_their_positions.append((bit, position)) + bits = list(range(start, stop, step)) + + bits_and_their_positions = ((bit, position) for position, bit in enumerate(bits)) bits_and_their_positions = sorted( bits_and_their_positions, key=lambda bit_and_its_position: bit_and_its_position[0], ) - current_bit = 0 + # we optimize bulk extract in low precision, used for identity max_bit = x.original_bit_width + max_real_bit = max( + (bit for bit in bits if bit <= max_bit and x.is_unsigned), + default=0 + ) + cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) + cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max_real_bit + 1) + if cost_one_tlu < cost_many_lsbs: + + def tlu_cell_with_positive_value(i): + return x.type.is_unsigned or i < 2 ** (x.bit_width - 1) + + def tlu_cell_input_value(i): + if tlu_cell_with_positive_value(i): + return i + return -(2 ** (x.bit_width) - i) + + table = [ + sum( + ((tlu_cell_input_value(i) >> bit) & 1) << position + for bit, position in bits_and_their_positions + ) + + (0 if tlu_cell_with_positive_value(i) else 2 ** (resulting_type.bit_width + 1)) + for i in range(2**x.bit_width) + ] + tlu_result = self.tlu(resulting_type, x, table) + return self.to_signedness(tlu_result, of=resulting_type) + + current_bit = 0 lsb: Optional[Conversion] = None result: Optional[Conversion] = None - - for index, (bit, position) in enumerate(bits_and_their_positions): + for bit, position in bits_and_their_positions: if bit >= max_bit and x.is_unsigned: break - last = index == len(bits_and_their_positions) - 1 while bit != (current_bit - 1): - if bit == (max_bit - 1) and x.bit_width == 1 and x.is_unsigned: + if ( + bit == (max_bit - 1) + and x.bit_width == resulting_type.bit_width == 1 + and x.is_unsigned + ): + lsb_bit_witdh = 1 lsb = x - elif last and bit == current_bit: - lsb = self.lsb(resulting_type, x) else: - lsb = self.lsb(x.type, x) + lsb_bit_witdh = max(resulting_type.bit_width - position, x.bit_width) + lsb_type = self.fork_type(x.type, lsb_bit_witdh) + lsb = self.lsb(lsb_type, x) + + # check that we only need to shift to emulate the initial and final position + # position are expressed for the final bit_width + initial_position = resulting_type.bit_width - x.bit_width + actual_position = resulting_type.bit_width - lsb.type.bit_width + delta_precision = initial_position - actual_position + assert 0 <= delta_precision < resulting_type.bit_width + assert ( + actual_position <= initial_position + ), "extract_bits: Cannot get back to initial precision" + assert ( + actual_position <= position + ), "extract_bits: Cannot get back to final precision" current_bit += 1 if current_bit >= max_bit: break - if not last or bit != (current_bit - 1): - cleared = self.sub(x.type, x, lsb) - x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) + clearing_bit = self.reduce_precision(lsb, x.bit_width) + cleared = self.sub(x.type, x, clearing_bit) + x = self.reinterpret(cleared, bit_width=(x.bit_width - 1)) assert lsb is not None - lsb = self.to_signedness(lsb, of=resulting_type) - - if lsb.bit_width > resulting_type.bit_width: - difference = (lsb.bit_width - resulting_type.bit_width) + position - shifter = self.constant(self.i(lsb.bit_width + 1), 2**difference) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=resulting_type.bit_width) - - elif lsb.bit_width < resulting_type.bit_width: - shift = 2 ** (lsb.bit_width - 1) - if shift != 1: - shifter = self.constant(self.i(lsb.bit_width + 1), shift) - shifted = self.mul(lsb.type, lsb, shifter) - lsb = self.reinterpret(shifted, bit_width=1) - lsb = self.tlu(resulting_type, lsb, [0 << position, 1 << position]) - - elif position != 0: - shifter = self.constant(self.i(lsb.bit_width + 1), 2**position) - lsb = self.mul(lsb.type, lsb, shifter) + bit_value = self.to_signedness(lsb, of=resulting_type) + bit_value = self.reinterpret( + bit_value, bit_width=max(resulting_type.bit_width, max_bit) + ) - assert lsb is not None - result = lsb if result is None else self.add(resulting_type, result, lsb) + delta_precision = position - actual_position + assert actual_position < 0 or 0 <= delta_precision < resulting_type.bit_width, ( + position, + actual_position, + resulting_type.bit_width, + ) + if delta_precision: + bit_value = self.shift_left(bit_value, delta_precision) + + bit_value = self.reinterpret(bit_value, bit_width=resulting_type.bit_width) + + result = bit_value if result is None else self.add(resulting_type, result, bit_value) return result if result is not None else self.zeros(resulting_type) @@ -3209,8 +3306,8 @@ def round_bit_pattern( unskewed = x if approx_conf.symetrize_deltas: highest_supported_precision = 62 - delta_precision = highest_supported_precision - x.type.bit_width - full_precision = x.type.bit_width + delta_precision + delta_precision = highest_supported_precision - x.bit_width + full_precision = x.bit_width + delta_precision half_in_extra_precision = ( 1 << (delta_precision - 1) ) - 1 # slightly smaller then half @@ -3763,11 +3860,12 @@ def reinterpret( ) -> Conversion: assert x.is_encrypted - if x.bit_width == bit_width: + result_unsigned = x.is_unsigned if signed is None else not signed + + if x.bit_width == bit_width and x.is_unsigned == result_unsigned: return x - result_signed = x.is_unsigned if signed is None else signed - resulting_element_type = (self.eint if result_signed else self.esint)(bit_width) + resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) operation = ( diff --git a/frontends/concrete-python/tests/execution/test_bit_extraction.py b/frontends/concrete-python/tests/execution/test_bit_extraction.py index 3285170643..483fcd788f 100644 --- a/frontends/concrete-python/tests/execution/test_bit_extraction.py +++ b/frontends/concrete-python/tests/execution/test_bit_extraction.py @@ -152,6 +152,7 @@ def test_bad_plain_bit_extraction( "input_bit_width,input_is_signed,operation", [ # unsigned + pytest.param(3, False, lambda x: fhe.bits(x)[0:3], id="unsigned-3b[0:3]"), pytest.param(5, False, lambda x: fhe.bits(x)[0], id="unsigned-5b[0]"), pytest.param(5, False, lambda x: fhe.bits(x)[1], id="unsigned-5b[1]"), pytest.param(5, False, lambda x: fhe.bits(x)[2], id="unsigned-5b[2]"), @@ -166,6 +167,7 @@ def test_bad_plain_bit_extraction( pytest.param(5, False, lambda x: fhe.bits(x)[2::-1], id="unsigned-5b[2::-1]"), pytest.param(5, False, lambda x: fhe.bits(x)[1:30:10], id="unsigned-5b[1:30:10]"), # signed + pytest.param(3, True, lambda x: fhe.bits(x)[0:3], id="signed-3b[0:3]"), pytest.param(5, True, lambda x: fhe.bits(x)[0], id="signed-5b[0]"), pytest.param(5, True, lambda x: fhe.bits(x)[1], id="signed-5b[1]"), pytest.param(5, True, lambda x: fhe.bits(x)[2], id="signed-5b[2]"), @@ -179,9 +181,11 @@ def test_bad_plain_bit_extraction( pytest.param(5, True, lambda x: fhe.bits(x)[2::-1], id="signed-5b[2::-1]"), pytest.param(5, True, lambda x: fhe.bits(x)[1:30:10], id="signed-5b[1:30:10]"), # unsigned (result bit-width increased) + pytest.param(3, False, lambda x: fhe.bits(x)[0:3] + 100, id="unsigned-3b[0:3] + 100"), pytest.param(5, False, lambda x: fhe.bits(x)[0] + 100, id="unsigned-5b[0] + 100"), pytest.param(5, False, lambda x: fhe.bits(x)[1:3] + 100, id="unsigned-5b[1:3] + 100"), # signed (result bit-width increased) + pytest.param(3, True, lambda x: fhe.bits(x)[0:3], id="signed-3b[0:3] + 100"), pytest.param(5, True, lambda x: fhe.bits(x)[0] + 100, id="signed-5b[0] + 100"), pytest.param(5, True, lambda x: fhe.bits(x)[1:3] + 100, id="signed-5b[1:3] + 100"), ], @@ -203,7 +207,143 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers): compiler = fhe.Compiler(operation, {"x": "encrypted"}) circuit = compiler.compile(inputset, helpers.configuration()) - values = inputset if len(inputset) <= 8 else random.sample(inputset, 8) for value in values: helpers.check_execution(circuit, operation, value, retries=3) + + +def mlir_count_ops(mlir, operation): + """ + Count op in mlir. + """ + return sum(operation in line for line in mlir.splitlines()) + + +def test_highest_bit_extraction_mlir(helpers): + """ + Test bit extraction of the highest bit. Saves one lsb. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[precision - 1] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test bit extraction to same. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + print(circuit.mlir) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_same_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + (2**precision - 2) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_smaller_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to smaller bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1 + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_seq_bits_extraction_to_bigger_bitwidth_mlir(helpers): + """ + Test sequential bit extraction to bigger bitwidth. + """ + + precision = 8 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return tuple(fhe.bits(x)[i] + 2 ** (precision + 1) for i in range(precision)) + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == precision + assert mlir_count_ops(circuit.mlir, "lookup") == 0 + + +def test_bit_extract_to_1_tlu(helpers): + """ + Test bit extract as 1 tlu for small precision. + """ + precision = 3 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 0 + assert mlir_count_ops(circuit.mlir, "lookup") == 1 + + precision = 4 + inputset = list(range(2**precision)) + + @fhe.compiler({"x": "encrypted"}) + def operation(x): # pylint: disable=function-redefined + return fhe.bits(x)[0:2] + + circuit = operation.compile(inputset, helpers.configuration()) + assert mlir_count_ops(circuit.mlir, "lsb") == 2 + assert mlir_count_ops(circuit.mlir, "lookup") == 0