Skip to content

Commit

Permalink
feat(frontend): add flag to optim LSBs with LUT
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Feb 27, 2025
1 parent e5dba27 commit 2625ffd
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ class Configuration:
keyset_restriction: Optional[KeysetRestriction]
auto_schedule_run: bool
security_level: SecurityLevel
optim_lsbs_with_lut: bool

def __init__(
self,
Expand Down Expand Up @@ -1081,6 +1082,7 @@ def __init__(
keyset_restriction: Optional[KeysetRestriction] = None,
auto_schedule_run: bool = False,
security_level: SecurityLevel = SecurityLevel.SECURITY_128_BITS,
optim_lsbs_with_lut: bool = True,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1194,6 +1196,8 @@ def __init__(

self.security_level = security_level

self.optim_lsbs_with_lut = optim_lsbs_with_lut

self._validate()

class Keep:
Expand Down Expand Up @@ -1273,6 +1277,7 @@ def fork(
keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP,
auto_schedule_run: Union[Keep, bool] = KEEP,
security_level: Union[Keep, SecurityLevel] = KEEP,
optim_lsbs_with_lut: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
47 changes: 24 additions & 23 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2456,29 +2456,30 @@ def extract_bits(
)

# we optimize bulk extract in low precision, used for identity
cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf"))
cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1)
if cost_one_tlu < cost_many_lsbs:

def tlu_cell_input_value(i):
if x.type.is_unsigned or i < 2 ** (x.bit_width - 1):
return i
return -(2 ** (x.bit_width) - i)

table = [
sum(
((tlu_cell_input_value(i) >> bit) & 1) << position
for bit, position in bits_and_their_positions
)
+ ( # padding bit
0
if resulting_type.is_unsigned or i < 2 ** (x.bit_width - 1)
else 2**resulting_type.bit_width
)
for i in range(2**x.bit_width)
]
tlu_result = self.tlu(resulting_type, x, table)
return self.to_signedness(tlu_result, of=resulting_type)
if self.configuration.optim_lsbs_with_lut:
cost_one_tlu = LUT_COSTS_V0_NORM2_0.get(x.bit_width, float("inf"))
cost_many_lsbs = LUT_COSTS_V0_NORM2_0[1] * (max(bits, default=0) + 1)
if cost_one_tlu < cost_many_lsbs:

def tlu_cell_input_value(i):
if x.type.is_unsigned or i < 2 ** (x.bit_width - 1):
return i
return -(2 ** (x.bit_width) - i)

table = [
sum(
((tlu_cell_input_value(i) >> bit) & 1) << position
for bit, position in bits_and_their_positions
)
+ ( # padding bit
0
if resulting_type.is_unsigned or i < 2 ** (x.bit_width - 1)
else 2**resulting_type.bit_width
)
for i in range(2**x.bit_width)
]
tlu_result = self.tlu(resulting_type, x, table)
return self.to_signedness(tlu_result, of=resulting_type)

current_bit = 0
max_bit = x.original_bit_width
Expand Down
10 changes: 8 additions & 2 deletions frontends/concrete-python/tests/execution/test_bit_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ def test_bad_plain_bit_extraction(
pytest.param(10, False, lambda x: fhe.bits(x)[5:15], id="unsigned-10b[5:15]"),
],
)
def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers):
@pytest.mark.parametrize(
"optim_lsbs_with_lut",
[True, False],
)
def test_bit_extraction(input_bit_width, input_is_signed, operation, optim_lsbs_with_lut, helpers):
"""
Test bit extraction.
"""
Expand All @@ -208,7 +212,9 @@ def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers):
]

compiler = fhe.Compiler(operation, {"x": "encrypted"})
circuit = compiler.compile(inputset, helpers.configuration())
circuit = compiler.compile(
inputset, helpers.configuration().fork(optim_lsbs_with_lut=optim_lsbs_with_lut)
)
values = inputset if len(inputset) <= 8 else random.sample(inputset, 8)
for value in values:
helpers.check_execution(circuit, operation, value, retries=3)
Expand Down

0 comments on commit 2625ffd

Please sign in to comment.