Skip to content

Commit

Permalink
feat(frontend-python): add option to configure tlu on original bit width
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Apr 16, 2024
1 parent 34de883 commit 3d0727b
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 26 deletions.
10 changes: 8 additions & 2 deletions docs/guides/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* **print_tlu_fusing** : bool = False
* Enables printing TLU fusing to see which table lookups are fused.
* **compress\_evaluation\_keys**: bool = False,
- This specifies that serialization takes the compressed form of evaluation keys.
* This specifies that serialization takes the compressed form of evaluation keys.
* **compress\_input\_ciphertexts**: bool = False,
* This specifies that serialization takes the compressed form of input ciphertexts.
* This specifies that serialization takes the compressed form of input ciphertexts.
* **optimize\_tlu\_based\_on\_original\_bit\_width**: Union\[bool, int] = 8,
* Configures whether to convert values to their original precision before doing a table lookup on them.
* True enables it for all cases.
* False disables it for all cases.
* Integer value enables or disables it depending on the original bit width.
* With the default value of 8, only the values with original bit width <= 8 will be converted to their original precision.
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class Configuration:
optimize_tlu_based_on_measured_bounds: bool
enable_tlu_fusing: bool
print_tlu_fusing: bool
optimize_tlu_based_on_original_bit_width: Union[bool, int]

def __init__(
self,
Expand Down Expand Up @@ -1053,6 +1054,7 @@ def __init__(
optimize_tlu_based_on_measured_bounds: bool = False,
enable_tlu_fusing: bool = True,
print_tlu_fusing: bool = False,
optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1151,6 +1153,8 @@ def __init__(
self.enable_tlu_fusing = enable_tlu_fusing
self.print_tlu_fusing = print_tlu_fusing

self.optimize_tlu_based_on_original_bit_width = optimize_tlu_based_on_original_bit_width

self._validate()

class Keep:
Expand Down Expand Up @@ -1218,6 +1222,7 @@ def fork(
optimize_tlu_based_on_measured_bounds: Union[Keep, bool] = KEEP,
enable_tlu_fusing: Union[Keep, bool] = KEEP,
print_tlu_fusing: Union[Keep, bool] = KEEP,
optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
60 changes: 40 additions & 20 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2882,18 +2882,28 @@ def multi_tlu(
on = self.cast_to_original_bit_width(offsetted)

elif on.bit_width > on.original_bit_width:
if on.is_unsigned:
tables = [table[: 2**on.original_bit_width] for table in tables]
else:
tables = [
(
table[: 2 ** (on.original_bit_width - 1)]
+ table[-(2 ** (on.original_bit_width - 1)) :]
)
for table in tables
]
optimize = (
self.configuration.optimize_tlu_based_on_original_bit_width
if isinstance(self.configuration.optimize_tlu_based_on_original_bit_width, bool)
else (
on.original_bit_width
<= self.configuration.optimize_tlu_based_on_original_bit_width
)
)

on = self.cast_to_original_bit_width(on)
if optimize:
if on.is_unsigned:
tables = [table[: 2**on.original_bit_width] for table in tables]
else:
tables = [
(
table[: 2 ** (on.original_bit_width - 1)]
+ table[-(2 ** (on.original_bit_width - 1)) :]
)
for table in tables
]

on = self.cast_to_original_bit_width(on)

on = self.broadcast_to(on, mapping.shape)

Expand Down Expand Up @@ -3633,16 +3643,26 @@ def tlu(self, resulting_type: ConversionType, on: Conversion, table: Sequence[in
on = self.cast_to_original_bit_width(offsetted)

elif on.bit_width > on.original_bit_width:
if len(table) != 2**on.original_bit_width:
if on.is_unsigned:
table = table[: 2**on.original_bit_width]
else:
table = (
table[: 2 ** (on.original_bit_width - 1)]
+ table[-(2 ** (on.original_bit_width - 1)) :] # type: ignore
)
optimize = (
self.configuration.optimize_tlu_based_on_original_bit_width
if isinstance(self.configuration.optimize_tlu_based_on_original_bit_width, bool)
else (
on.original_bit_width
<= self.configuration.optimize_tlu_based_on_original_bit_width
)
)

if optimize:
if len(table) != 2**on.original_bit_width:
if on.is_unsigned:
table = table[: 2**on.original_bit_width]
else:
table = (
table[: 2 ** (on.original_bit_width - 1)]
+ table[-(2 ** (on.original_bit_width - 1)) :] # type: ignore
)

on = self.cast_to_original_bit_width(on)
on = self.cast_to_original_bit_width(on)

table = list(table)

Expand Down
8 changes: 8 additions & 0 deletions frontends/concrete-python/tests/execution/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,14 @@ def issue650(x):
{},
id="issue-651",
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 2**14 - 1], "status": "encrypted", "shape": ()},
},
{},
id="x + (x // 3)",
),
],
)
def test_others(function, parameters, configuration_overrides, helpers):
Expand Down
162 changes: 158 additions & 4 deletions frontends/concrete-python/tests/mlir/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
Tests of `Converter` class.
"""

# pylint: disable=import-error,no-name-in-module

import numpy as np
import pytest
from concrete.compiler import CompilationContext

from concrete import fhe
from concrete.fhe.compilation.configuration import ParameterSelectionStrategy
from concrete.fhe.mlir import GraphConverter

from ..conftest import USE_MULTI_PRECISION

# pylint: enable=import-error,no-name-in-module


def assign(x, y):
"""
Expand Down Expand Up @@ -1508,6 +1513,147 @@ def test_converter_bad_convert(
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": True,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
%1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6>
%3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %3 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": False,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": 6,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
%1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6>
%3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %3 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": 5,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
%1 = "FHE.reinterpret_precision"(%0) : (!FHE.eint<6>) -> !FHE.eint<5>
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]> : tensor<32xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<6>
%3 = "FHE.add_eint"(%arg0, %2) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %3 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": 4,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
}
}
""", # noqa: E501
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_original_bit_width": 3,
},
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
%1 = "FHE.add_eint"(%arg0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
}
}
""", # noqa: E501
),
],
Expand All @@ -1529,9 +1675,13 @@ def test_converter_convert_multi_precision(
compiler = fhe.Compiler(function, parameter_encryption_statuses)

inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
graph = compiler.trace(inputset, configuration)

helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
compilation_context = CompilationContext.new()
mlir_context = compilation_context.mlir_context()

module = GraphConverter(configuration).convert(graph, mlir_context)
helpers.check_str(expected_mlir.strip(), str(module).strip())


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1643,9 +1793,13 @@ def test_converter_convert_single_precision(function, parameters, expected_mlir,
compiler = fhe.Compiler(function, parameter_encryption_statuses)

inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
graph = compiler.trace(inputset, configuration)

helpers.check_str(expected_mlir.strip(), circuit.mlir.strip())
compilation_context = CompilationContext.new()
mlir_context = compilation_context.mlir_context()

module = GraphConverter(configuration).convert(graph, mlir_context)
helpers.check_str(expected_mlir.strip(), str(module).strip())


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3d0727b

Please sign in to comment.