diff --git a/docs/guides/configure.md b/docs/guides/configure.md index b73f54ea73..8753aa86d8 100644 --- a/docs/guides/configure.md +++ b/docs/guides/configure.md @@ -135,6 +135,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * **print_tlu_fusing** : bool = False * Enables printing TLU fusing to see which table lookups are fused. * **compress\_evaluation\_keys**: bool = False, - - This specifies that serialization takes the compressed form of evaluation keys. + * This specifies that serialization takes the compressed form of evaluation keys. * **compress\_input\_ciphertexts**: bool = False, - * This specifies that serialization takes the compressed form of input ciphertexts. + * This specifies that serialization takes the compressed form of input ciphertexts. +* **optimize\_tlu\_based\_on\_original\_bit\_width**: Union\[bool, int] = 8, + * Configures whether to convert values to their original precision before doing a table lookup on them. + * True enables it for all cases. + * False disables it for all cases. + * An integer value enables it only for values whose original bit width is less than or equal to that integer, and disables it otherwise. + * With the default value of 8, only the values with original bit width <= 8 will be converted to their original precision. 
diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index c4fd7e094d..55c0c796b9 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -989,6 +989,7 @@ class Configuration: optimize_tlu_based_on_measured_bounds: bool enable_tlu_fusing: bool print_tlu_fusing: bool + optimize_tlu_based_on_original_bit_width: Union[bool, int] def __init__( self, @@ -1053,6 +1054,7 @@ def __init__( optimize_tlu_based_on_measured_bounds: bool = False, enable_tlu_fusing: bool = True, print_tlu_fusing: bool = False, + optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1151,6 +1153,8 @@ def __init__( self.enable_tlu_fusing = enable_tlu_fusing self.print_tlu_fusing = print_tlu_fusing + self.optimize_tlu_based_on_original_bit_width = optimize_tlu_based_on_original_bit_width + self._validate() class Keep: @@ -1218,6 +1222,7 @@ def fork( optimize_tlu_based_on_measured_bounds: Union[Keep, bool] = KEEP, enable_tlu_fusing: Union[Keep, bool] = KEEP, print_tlu_fusing: Union[Keep, bool] = KEEP, + optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. 
diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index ad553df280..982fd30a7c 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -2882,18 +2882,28 @@ def multi_tlu( on = self.cast_to_original_bit_width(offsetted) elif on.bit_width > on.original_bit_width: - if on.is_unsigned: - tables = [table[: 2**on.original_bit_width] for table in tables] - else: - tables = [ - ( - table[: 2 ** (on.original_bit_width - 1)] - + table[-(2 ** (on.original_bit_width - 1)) :] - ) - for table in tables - ] + optimize = ( + self.configuration.optimize_tlu_based_on_original_bit_width + if isinstance(self.configuration.optimize_tlu_based_on_original_bit_width, bool) + else ( + on.original_bit_width + <= self.configuration.optimize_tlu_based_on_original_bit_width + ) + ) - on = self.cast_to_original_bit_width(on) + if optimize: + if on.is_unsigned: + tables = [table[: 2**on.original_bit_width] for table in tables] + else: + tables = [ + ( + table[: 2 ** (on.original_bit_width - 1)] + + table[-(2 ** (on.original_bit_width - 1)) :] + ) + for table in tables + ] + + on = self.cast_to_original_bit_width(on) on = self.broadcast_to(on, mapping.shape) @@ -3633,16 +3643,26 @@ def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[in on = self.cast_to_original_bit_width(offsetted) elif on.bit_width > on.original_bit_width: - if len(table) != 2**on.original_bit_width: - if on.is_unsigned: - table = table[: 2**on.original_bit_width] - else: - table = ( - table[: 2 ** (on.original_bit_width - 1)] - + table[-(2 ** (on.original_bit_width - 1)) :] # type: ignore - ) + optimize = ( + self.configuration.optimize_tlu_based_on_original_bit_width + if isinstance(self.configuration.optimize_tlu_based_on_original_bit_width, bool) + else ( + on.original_bit_width + <= self.configuration.optimize_tlu_based_on_original_bit_width + ) + ) 
+ + if optimize: + if len(table) != 2**on.original_bit_width: + if on.is_unsigned: + table = table[: 2**on.original_bit_width] + else: + table = ( + table[: 2 ** (on.original_bit_width - 1)] + + table[-(2 ** (on.original_bit_width - 1)) :] # type: ignore + ) - on = self.cast_to_original_bit_width(on) + on = self.cast_to_original_bit_width(on) table = list(table) diff --git a/frontends/concrete-python/tests/execution/test_others.py b/frontends/concrete-python/tests/execution/test_others.py index b048c84574..fa65750cb1 100644 --- a/frontends/concrete-python/tests/execution/test_others.py +++ b/frontends/concrete-python/tests/execution/test_others.py @@ -966,6 +966,14 @@ def issue650(x): {}, id="issue-651", ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 2**14 - 1], "status": "encrypted", "shape": ()}, + }, + {}, + id="x + (x // 3)", + ), ], ) def test_others(function, parameters, configuration_overrides, helpers): diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index f6d015976d..e6269968d6 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -2,8 +2,11 @@ Tests of `Converter` class. 
""" +# pylint: disable=import-error,no-name-in-module + import numpy as np import pytest +from concrete.compiler import CompilationContext from concrete import fhe from concrete.fhe.compilation.configuration import ParameterSelectionStrategy @@ -11,6 +14,8 @@ from ..conftest import USE_MULTI_PRECISION +# pylint: enable=import-error,no-name-in-module + def assign(x, y): """ @@ -1508,6 +1513,147 @@ def test_converter_bad_convert( } } + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": True, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %c2_i7 = arith.constant 2 : i7 + %0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6> + %1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64> + %2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6> + %3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> + return %3 : !FHE.eint<6> + } +} + + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": False, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> + %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> 
!FHE.eint<6> + return %1 : !FHE.eint<6> + } +} + + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": 6, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %c2_i7 = arith.constant 2 : i7 + %0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6> + %1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64> + %2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6> + %3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> + return %3 : !FHE.eint<6> + } +} + + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": 5, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %c2_i7 = arith.constant 2 : i7 + %0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6> + %1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5> + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64> + %2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6> + %3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> + return %3 : !FHE.eint<6> + } +} + + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": 4, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> 
!FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> + %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> + return %1 : !FHE.eint<6> + } +} + + """, # noqa: E501 + ), + pytest.param( + lambda x: x + (x // 3), + { + "x": {"range": [0, 31], "status": "encrypted", "shape": ()}, + }, + { + "optimize_tlu_based_on_original_bit_width": 3, + }, + """ + +module { + func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c3_i3 = arith.constant 3 : i3 + %cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6> + %1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6> + return %1 : !FHE.eint<6> + } +} + """, # noqa: E501 ), ], @@ -1529,9 +1675,13 @@ def test_converter_convert_multi_precision( compiler = fhe.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset, configuration) + graph = compiler.trace(inputset, configuration) - helpers.check_str(expected_mlir.strip(), circuit.mlir.strip()) + compilation_context = CompilationContext.new() + mlir_context = compilation_context.mlir_context() + + module = GraphConverter(configuration).convert(graph, mlir_context) + helpers.check_str(expected_mlir.strip(), str(module).strip()) @pytest.mark.parametrize( @@ -1643,9 +1793,13 @@ def test_converter_convert_single_precision(function, parameters, expected_mlir, 
compiler = fhe.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset, configuration) + graph = compiler.trace(inputset, configuration) - helpers.check_str(expected_mlir.strip(), circuit.mlir.strip()) + compilation_context = CompilationContext.new() + mlir_context = compilation_context.mlir_context() + + module = GraphConverter(configuration).convert(graph, mlir_context) + helpers.check_str(expected_mlir.strip(), str(module).strip()) @pytest.mark.parametrize(