diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 5d239c159..47ef5fe5b 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -1008,6 +1008,7 @@ class Configuration: keyset_restriction: Optional[KeysetRestriction] auto_schedule_run: bool security_level: SecurityLevel + optim_lsbs_with_lut: bool def __init__( self, @@ -1081,6 +1082,7 @@ def __init__( keyset_restriction: Optional[KeysetRestriction] = None, auto_schedule_run: bool = False, security_level: SecurityLevel = SecurityLevel.SECURITY_128_BITS, + optim_lsbs_with_lut: bool = True, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1194,6 +1196,8 @@ def __init__( self.security_level = security_level + self.optim_lsbs_with_lut = optim_lsbs_with_lut + self._validate() class Keep: @@ -1273,6 +1277,7 @@ def fork( keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP, auto_schedule_run: Union[Keep, bool] = KEEP, security_level: Union[Keep, SecurityLevel] = KEEP, + optim_lsbs_with_lut: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index fb02547d5..3d8d51855 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -2456,29 +2456,30 @@ def extract_bits( ) # we optimize bulk extract in low precision, used for identity - cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) - cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1) - if cost_one_tlu < cost_many_lsbs: - - def tlu_cell_input_value(i): - if x.type.is_unsigned or i < 2 ** (x.bit_width - 1): - return i - return -(2 ** (x.bit_width) - i) - - table = [ - sum( - ((tlu_cell_input_value(i) >> bit) & 1) << position - for bit, position in bits_and_their_positions - ) - + ( # padding bit - 0 - if resulting_type.is_unsigned or i < 2 ** (x.bit_width - 1) - else 2**resulting_type.bit_width - ) - for i in range(2**x.bit_width) - ] - tlu_result = self.tlu(resulting_type, x, table) - return self.to_signedness(tlu_result, of=resulting_type) + if self.configuration.optim_lsbs_with_lut: + cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf")) + cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1) + if cost_one_tlu < cost_many_lsbs: + + def tlu_cell_input_value(i): + if x.type.is_unsigned or i < 2 ** (x.bit_width - 1): + return i + return -(2 ** (x.bit_width) - i) + + table = [ + sum( + ((tlu_cell_input_value(i) >> bit) & 1) << position + for bit, position in bits_and_their_positions + ) + + ( # padding bit + 0 + if resulting_type.is_unsigned or i < 2 ** (x.bit_width - 1) + else 2**resulting_type.bit_width + ) + for i in range(2**x.bit_width) + ] + tlu_result = self.tlu(resulting_type, x, table) + return self.to_signedness(tlu_result, of=resulting_type) current_bit = 0 max_bit = x.original_bit_width diff --git a/frontends/concrete-python/tests/execution/test_bit_extraction.py b/frontends/concrete-python/tests/execution/test_bit_extraction.py index 1bec4c697..b1ded6926 100644 --- a/frontends/concrete-python/tests/execution/test_bit_extraction.py +++ b/frontends/concrete-python/tests/execution/test_bit_extraction.py @@ -192,7 +192,11 @@ def test_bad_plain_bit_extraction( pytest.param(10, False, lambda x: fhe.bits(x)[5:15], id="unsigned-10b[5:15]"), ], ) -def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers): +@pytest.mark.parametrize( + "optim_lsbs_with_lut", + [True, False], +) +def test_bit_extraction(input_bit_width, input_is_signed, operation, optim_lsbs_with_lut, helpers): """ Test bit extraction. """ @@ -208,7 +212,9 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers): ] compiler = fhe.Compiler(operation, {"x": "encrypted"}) - circuit = compiler.compile(inputset, helpers.configuration()) + circuit = compiler.compile( + inputset, helpers.configuration().fork(optim_lsbs_with_lut=optim_lsbs_with_lut) + ) values = inputset if len(inputset) <= 8 else random.sample(inputset, 8) for value in values: helpers.check_execution(circuit, operation, value, retries=3)