Skip to content

Commit

Permalink
fix(frontend-python): round_bit_pattern, prevent exactness argument m…
Browse files Browse the repository at this point in the history
…isuse
  • Loading branch information
rudy-6-4 committed Apr 2, 2024
1 parent a98feed commit 9717ec5
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import platform
from dataclasses import dataclass
from enum import Enum, IntEnum
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple, Union, get_type_hints

Expand Down Expand Up @@ -74,13 +74,13 @@ def parse(cls, string: str) -> "MultiParameterStrategy":
raise ValueError(message)


class Exactness(IntEnum):
class Exactness(Enum):
"""
Exactness, to specify for specific operator the implementation preference (default and local).
"""

EXACT = 0
APPROXIMATE = 1
APPROXIMATE = "approximate"
EXACT = "exact"


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def round_bit_pattern(

lsbs_to_remove = lsbs_to_remove.lsbs_to_remove

if not isinstance(exactness, (Exactness, None.__class__)):
msg = "exactness should be of type fhe.Exactness"
raise TypeError(msg)

assert isinstance(lsbs_to_remove, int)

def evaluator(
Expand Down
2 changes: 1 addition & 1 deletion frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
variable_input.origin.properties["exactness"]
or ctx.configuration.rounding_exactness
)
if exactness == Exactness.APPROXIMATE:
if exactness is Exactness.APPROXIMATE:
# we clip values to enforce input precision exactly as queried
original_bit_width = variable_input.origin.properties["original_bit_width"]
lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def apply(self, graph: Graph):
exactness = self.rounding_exactness

if original_lsbs_to_remove != 0 and final_lsbs_to_remove == 0:
if exactness != Exactness.APPROXIMATE:
if exactness is not Exactness.APPROXIMATE:
self.replace_with_tlu(graph, node)
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Tests of 'round_bit_pattern' extension.
"""

import pytest

from concrete import fhe


Expand Down Expand Up @@ -35,3 +37,16 @@ def test_dump_load_auto_rounder():
assert loaded.input_max == 20
assert loaded.input_bit_width == 5
assert loaded.lsbs_to_remove == 2


def test_bad_exactness():
"""
Test for incorrect 'exactness' argument.
"""

@fhe.compiler({"a": "encrypted"})
def f(a):
return fhe.round_bit_pattern(a, lsbs_to_remove=1, exactness=True)

with pytest.raises(TypeError):
f.compile([0])

0 comments on commit 9717ec5

Please sign in to comment.